tenferro_ops/ad/context.rs
1//! AD context for guard-based shape resolution and value metadata queries.
2//!
3//! During AD graph construction, linalg rules such as SVD, QR, and LU need
4//! concrete matrix dimensions to choose between structurally different
5//! subgraphs. `ShapeGuardContext` records those dimension comparisons as guards
6//! so cached AD graphs can later be invalidated when the observed shape
7//! relationship changes.
8
9use std::cmp::Ordering;
10use std::collections::HashMap;
11#[cfg(feature = "autodiff")]
12use std::sync::Arc;
13use std::sync::{Mutex, OnceLock};
14
15use computegraph::graph::Graph;
16use computegraph::types::{ValueKey, ValueRef};
17use tenferro_tensor::DType;
18
19use crate::dim_expr::{DimExpr, DimExprEvalError};
20#[cfg(feature = "autodiff")]
21use crate::ext_op::{ExtensionAdRule, ExtensionRuleSet};
22use crate::shape_extent::ShapeExtent;
23use crate::std_tensor_op::StdTensorOp;
24use crate::sym_dim::SymDim;
25
26type MetadataMap = HashMap<ValueKey<StdTensorOp>, TensorMeta>;
27
28type GlobalMetadataMap = HashMap<ValueKey<StdTensorOp>, GlobalMetadataEntry>;
29
30#[derive(Clone, Debug)]
31struct GlobalMetadataEntry {
32 meta: TensorMeta,
33 scoped_refs: usize,
34}
35
36/// Error returned when the process-global AD metadata registry is unavailable.
37#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
38pub enum MetadataRegistryError {
39 /// A previous panic poisoned the global metadata mutex.
40 #[error("AD global metadata registry lock poisoned")]
41 LockPoisoned,
42}
43
44/// Error returned when shape-guard metadata cannot be resolved.
45#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
46pub enum ShapeGuardError {
47 /// A local graph value was queried before a graph was attached.
48 #[error("cannot resolve local value {local_id} without an attached graph")]
49 LocalWithoutAttachedGraph {
50 /// Graph-local value id.
51 local_id: usize,
52 },
53 /// A local graph value id is outside the attached graph's value table.
54 #[error("local value {local_id} is out of bounds for the attached graph")]
55 LocalOutOfBounds {
56 /// Graph-local value id.
57 local_id: usize,
58 },
59 /// No metadata was registered for the resolved value key.
60 #[error("missing TensorMeta for {key:?}")]
61 MissingMetadata {
62 /// Resolved value key.
63 key: ValueKey<StdTensorOp>,
64 },
65 /// Metadata exists, but at least one axis is only bounded or unknown.
66 #[error("TensorMeta for {key:?} does not have an exact shape; query extents instead")]
67 NonExactShape {
68 /// Resolved value key.
69 key: ValueKey<StdTensorOp>,
70 },
71}
72
73/// Result type used by shape-guard metadata queries.
74pub type ShapeGuardResult<T> = Result<T, ShapeGuardError>;
75
#[cfg(feature = "autodiff")]
impl From<ShapeGuardError> for tidu::ADRuleError {
    /// Wrap a shape-guard failure as an invalid-input AD rule error.
    ///
    /// The error is attributed to the `tenferro.shape_guard` rule, tagged with
    /// `ADRuleKind::Jvp`, and carries the `Display` text of `err`.
    fn from(err: ShapeGuardError) -> Self {
        let message = err.to_string();
        Self::invalid_input("tenferro.shape_guard", tidu::ADRuleKind::Jvp, message)
    }
}
86
87/// Global metadata registry.
88///
89/// Stored as `Mutex<MetadataMap>` directly: writes insert in place (O(1)),
90/// and reads lock briefly for targeted key lookups. `ShapeGuardContext::metadata_of`
91/// reaches into the registry lazily via [`lookup_global_metadata`] and caches the
92/// result into the context's local map.
93///
94/// Earlier designs either cloned the whole map up-front into each AD
95/// `ShapeGuardContext` or kept the map in an `Arc` and cloned on every write.
96/// Both variants were quadratic across the monotonically growing registry and
97/// dominated oracle_replay runtime.
98static GLOBAL_METADATA: OnceLock<Mutex<GlobalMetadataMap>> = OnceLock::new();
99
100fn global_metadata_registry() -> &'static Mutex<GlobalMetadataMap> {
101 GLOBAL_METADATA.get_or_init(|| Mutex::new(HashMap::new()))
102}
103
104/// Lifetime token for graph-scoped global metadata.
105///
106/// Dropping the last frontend owner of a traced graph drops this scope and
107/// releases the metadata keys that were registered for that graph graph.
108#[doc(hidden)]
109#[derive(Debug)]
110pub struct GlobalMetadataScope {
111 keys: Vec<ValueKey<StdTensorOp>>,
112}
113
114impl Drop for GlobalMetadataScope {
115 fn drop(&mut self) {
116 release_scoped_global_metadata(&self.keys);
117 }
118}
119
120/// Per-value tensor metadata used by AD rules.
121///
122/// Shape information is stored as per-axis [`ShapeExtent`] values. Callers must
123/// explicitly choose whether they need an exact shape or only a known bound.
124///
125/// # Examples
126///
127/// ```
128/// use tenferro_ops::{SymDim, TensorMeta};
129/// use tenferro_tensor::DType;
130///
131/// let meta = TensorMeta::exact(DType::F64, vec![SymDim::from(2usize), SymDim::from(3usize)]);
132/// assert_eq!(meta.rank(), 2);
133/// ```
134#[derive(Clone, Debug, PartialEq, Eq)]
135pub struct TensorMeta {
136 /// Element dtype of the tensor value.
137 pub dtype: DType,
138 /// Per-axis shape guarantees.
139 pub extents: Vec<ShapeExtent<SymDim>>,
140}
141
142impl TensorMeta {
143 /// Construct metadata whose every axis is exact.
144 ///
145 /// # Examples
146 ///
147 /// ```
148 /// use tenferro_ops::{SymDim, TensorMeta};
149 /// use tenferro_tensor::DType;
150 ///
151 /// let meta = TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]);
152 /// assert_eq!(meta.exact_shape(), Some(vec![SymDim::from(4usize)]));
153 /// ```
154 pub fn exact(dtype: DType, shape: Vec<SymDim>) -> Self {
155 let extents = shape.iter().cloned().map(ShapeExtent::exact).collect();
156 Self { dtype, extents }
157 }
158
159 /// Construct metadata from per-axis extents.
160 ///
161 /// # Examples
162 ///
163 /// ```
164 /// use tenferro_ops::{ShapeExtent, SymDim, TensorMeta};
165 /// use tenferro_tensor::DType;
166 ///
167 /// let meta = TensorMeta::with_extents(
168 /// DType::F64,
169 /// vec![ShapeExtent::upper_bound(SymDim::from(8usize))],
170 /// );
171 /// assert_eq!(meta.exact_shape(), None);
172 /// ```
173 pub fn with_extents(dtype: DType, extents: Vec<ShapeExtent<SymDim>>) -> Self {
174 Self { dtype, extents }
175 }
176
177 /// Return the tensor rank known by this metadata record.
178 pub fn rank(&self) -> usize {
179 self.extents.len()
180 }
181
182 /// Return the per-axis shape guarantees.
183 ///
184 /// # Examples
185 ///
186 /// ```
187 /// use tenferro_ops::{SymDim, TensorMeta};
188 /// use tenferro_tensor::DType;
189 ///
190 /// let meta = TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]);
191 /// assert_eq!(meta.extents().len(), 1);
192 /// ```
193 pub fn extents(&self) -> &[ShapeExtent<SymDim>] {
194 &self.extents
195 }
196
197 /// Return the shape only when every axis is exact.
198 ///
199 /// # Examples
200 ///
201 /// ```
202 /// use tenferro_ops::{ShapeExtent, SymDim, TensorMeta};
203 /// use tenferro_tensor::DType;
204 ///
205 /// let meta = TensorMeta::with_extents(
206 /// DType::F64,
207 /// vec![ShapeExtent::upper_bound(SymDim::from(8usize))],
208 /// );
209 /// assert_eq!(meta.exact_shape(), None);
210 /// ```
211 pub fn exact_shape(&self) -> Option<Vec<SymDim>> {
212 self.extents
213 .iter()
214 .map(|extent| extent.as_exact().cloned())
215 .collect()
216 }
217
218 /// Return one known bound per axis when every axis has a bound.
219 ///
220 /// This is intentionally separate from [`TensorMeta::exact_shape`]: a bound
221 /// is not proof of the runtime size.
222 pub fn bound_shape(&self) -> Option<Vec<SymDim>> {
223 self.extents
224 .iter()
225 .map(|extent| extent.bound_expr().cloned())
226 .collect()
227 }
228}
229
/// One dimension comparison observed while building an AD graph.
///
/// Cached AD graphs are keyed on these records. If the observed ordering of
/// `dim_a` and `dim_b` changes, the cached graph no longer applies.
///
/// # Examples
///
/// ```
/// use std::cmp::Ordering;
/// use tenferro_ops::ShapeGuard;
///
/// let guard = ShapeGuard {
///     dim_a: 5,
///     dim_b: 3,
///     ordering: Ordering::Greater,
/// };
///
/// assert_eq!(guard.ordering, Ordering::Greater);
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShapeGuard {
    /// Left-hand dimension of the comparison (for example `m`).
    pub dim_a: usize,
    /// Right-hand dimension of the comparison (for example `n`).
    pub dim_b: usize,
    /// Result of `dim_a.cmp(&dim_b)` at recording time.
    pub ordering: Ordering,
}
255
256/// AD context providing dimension resolution, guard recording, and value metadata.
257///
258/// # Examples
259///
260/// ```
261/// use tenferro_ops::ShapeGuardContext;
262///
263/// let ctx = ShapeGuardContext::default();
264/// assert!(ctx.guards().is_empty());
265/// ```
266#[derive(Clone, Debug, Default)]
267pub struct ShapeGuardContext {
268 guards: Vec<ShapeGuard>,
269 metadata: MetadataMap,
270 use_global_registry: bool,
271 local_keys: Option<Vec<ValueKey<StdTensorOp>>>,
272 #[cfg(feature = "autodiff")]
273 extension_rules: Option<ExtensionRuleSet>,
274}
275
276impl ShapeGuardContext {
277 /// Create a context backed by the global metadata registry.
278 ///
279 /// Instead of cloning the entire global registry up-front (which used
280 /// to be O(N) per AD pass and quadratic across oracle_replay), the
281 /// context keeps a flag and lazily fetches entries from the shared
282 /// [`lookup_global_metadata`] on first miss, caching into its local
283 /// `metadata` map for subsequent reads within the same pass.
284 ///
285 /// # Examples
286 ///
287 /// ```
288 /// let ctx = tenferro_ops::ShapeGuardContext::with_global_metadata();
289 /// assert!(ctx.guards().is_empty());
290 /// ```
291 pub fn with_global_metadata() -> Self {
292 Self {
293 use_global_registry: true,
294 ..Self::default()
295 }
296 }
297
298 #[doc(hidden)]
299 /// Keep global-registry lookup enabled after a pass boundary.
300 ///
301 /// This is intentionally a no-op for cached entries: global metadata is
302 /// already read lazily on cache misses, and clearing the local cache would
303 /// also discard metadata inserted directly into this context.
304 pub fn refresh_global_metadata(&mut self) {
305 self.use_global_registry = true;
306 }
307
308 /// Use an explicit extension AD rule set for this context.
309 ///
310 /// Extension AD lookup is context-owned: a context without an attached rule
311 /// set has no extension AD rules.
312 ///
313 /// # Examples
314 ///
315 /// ```
316 /// use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
317 ///
318 /// let _ctx = ShapeGuardContext::default().with_extension_rules(ExtensionRuleSet::new());
319 /// ```
320 #[cfg(feature = "autodiff")]
321 pub fn with_extension_rules(mut self, rules: ExtensionRuleSet) -> Self {
322 self.extension_rules = Some(rules);
323 self
324 }
325
326 /// Look up an extension AD rule using this context's ownership policy.
327 ///
328 /// Contexts without an explicit rule set have no extension AD rules.
329 #[doc(hidden)]
330 #[cfg(feature = "autodiff")]
331 pub(crate) fn extension_rule_for(&self, family_id: &str) -> Option<Arc<dyn ExtensionAdRule>> {
332 self.extension_rules
333 .as_ref()
334 .and_then(|rules| rules.lookup_rule(family_id))
335 }
336
337 /// Returns the guards recorded so far.
338 ///
339 /// # Examples
340 ///
341 /// ```
342 /// use tenferro_ops::ShapeGuardContext;
343 ///
344 /// let ctx = ShapeGuardContext::default();
345 /// assert_eq!(ctx.guards(), &[]);
346 /// ```
347 pub fn guards(&self) -> &[ShapeGuard] {
348 &self.guards
349 }
350
351 /// Clears all recorded guards.
352 ///
353 /// # Examples
354 ///
355 /// ```
356 /// use tenferro_ops::ShapeGuardContext;
357 ///
358 /// let mut ctx = ShapeGuardContext::default();
359 /// ctx.clear_guards();
360 /// assert!(ctx.guards().is_empty());
361 /// ```
362 pub fn clear_guards(&mut self) {
363 self.guards.clear();
364 }
365
366 /// Return the shape metadata for a value reference.
367 ///
368 /// # Examples
369 ///
370 /// ```
371 /// use computegraph::types::{ValueKey, ValueRef};
372 /// use tenferro_ops::input_key::TensorInputKey;
373 /// use tenferro_ops::std_tensor_op::StdTensorOp;
374 /// use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
375 /// use tenferro_tensor::DType;
376 ///
377 /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
378 /// let value = ValueRef::External(key.clone());
379 /// let mut ctx = ShapeGuardContext::default();
380 /// ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
381 ///
382 /// let shape = ctx.shape_of(&value).unwrap();
383 /// assert_eq!(shape, &[SymDim::from(4usize)]);
384 /// ```
385 pub fn shape_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<Vec<SymDim>> {
386 let key = self.resolve_key(val)?.clone();
387 self.ensure_metadata_loaded(&key);
388 let meta = self
389 .metadata
390 .get(&key)
391 .ok_or_else(|| ShapeGuardError::MissingMetadata { key: key.clone() })?;
392 meta.exact_shape()
393 .ok_or(ShapeGuardError::NonExactShape { key })
394 }
395
396 /// Return the rank for a value reference without requiring exact extents.
397 ///
398 /// Use this when an AD rule only needs axis count or needs to build
399 /// runtime-shape references. Calling [`ShapeGuardContext::shape_of`] in those
400 /// cases would reject valid values such as `DynamicTruncate` outputs whose
401 /// runtime extent is known only as an upper bound.
402 ///
403 /// # Examples
404 ///
405 /// ```
406 /// use computegraph::types::{ValueKey, ValueRef};
407 /// use tenferro_ops::input_key::TensorInputKey;
408 /// use tenferro_ops::std_tensor_op::StdTensorOp;
409 /// use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
410 /// use tenferro_tensor::DType;
411 ///
412 /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
413 /// let value = ValueRef::External(key.clone());
414 /// let mut ctx = ShapeGuardContext::default();
415 /// ctx.insert_metadata(
416 /// key,
417 /// TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
418 /// );
419 ///
420 /// assert_eq!(ctx.rank_of(&value).unwrap(), 1);
421 /// ```
422 pub fn rank_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<usize> {
423 self.metadata_of(val).map(TensorMeta::rank)
424 }
425
426 /// Return per-axis shape guarantees for a value reference.
427 ///
428 /// # Examples
429 ///
430 /// ```
431 /// use computegraph::types::{ValueKey, ValueRef};
432 /// use tenferro_ops::input_key::TensorInputKey;
433 /// use tenferro_ops::std_tensor_op::StdTensorOp;
434 /// use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
435 /// use tenferro_tensor::DType;
436 ///
437 /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
438 /// let value = ValueRef::External(key.clone());
439 /// let mut ctx = ShapeGuardContext::default();
440 /// ctx.insert_metadata(
441 /// key,
442 /// TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
443 /// );
444 ///
445 /// let extents = ctx.extents_of(&value).unwrap();
446 /// assert_eq!(extents[0], ShapeExtent::upper_bound(SymDim::from(8usize)));
447 /// ```
448 pub fn extents_of(
449 &mut self,
450 val: &ValueRef<StdTensorOp>,
451 ) -> ShapeGuardResult<&[ShapeExtent<SymDim>]> {
452 self.metadata_of(val).map(TensorMeta::extents)
453 }
454
455 /// Return the exact shape for a value reference, if all axes are exact.
456 ///
457 /// # Examples
458 ///
459 /// ```
460 /// use computegraph::types::{ValueKey, ValueRef};
461 /// use tenferro_ops::input_key::TensorInputKey;
462 /// use tenferro_ops::std_tensor_op::StdTensorOp;
463 /// use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
464 /// use tenferro_tensor::DType;
465 ///
466 /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
467 /// let value = ValueRef::External(key.clone());
468 /// let mut ctx = ShapeGuardContext::default();
469 /// ctx.insert_metadata(
470 /// key,
471 /// TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
472 /// );
473 ///
474 /// let maybe_shape = ctx.exact_shape_of(&value).unwrap();
475 /// assert_eq!(maybe_shape, None);
476 /// ```
477 pub fn exact_shape_of(
478 &mut self,
479 val: &ValueRef<StdTensorOp>,
480 ) -> ShapeGuardResult<Option<Vec<SymDim>>> {
481 self.metadata_of(val).map(TensorMeta::exact_shape)
482 }
483
484 #[doc(hidden)]
485 pub fn shape_if_available(&mut self, val: &ValueRef<StdTensorOp>) -> Option<Vec<SymDim>> {
486 self.metadata_if_available(val)
487 .and_then(TensorMeta::exact_shape)
488 }
489
490 /// Return the dtype metadata for a value reference.
491 ///
492 /// # Examples
493 ///
494 /// ```
495 /// use computegraph::types::{ValueKey, ValueRef};
496 /// use tenferro_ops::input_key::TensorInputKey;
497 /// use tenferro_ops::std_tensor_op::StdTensorOp;
498 /// use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
499 /// use tenferro_tensor::DType;
500 ///
501 /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
502 /// let value = ValueRef::External(key.clone());
503 /// let mut ctx = ShapeGuardContext::default();
504 /// ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
505 ///
506 /// let dtype = ctx.dtype_of(&value).unwrap();
507 /// assert_eq!(dtype, DType::F64);
508 /// ```
509 pub fn dtype_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<DType> {
510 self.metadata_of(val).map(|meta| meta.dtype)
511 }
512
513 /// Return the complete metadata record for a value reference.
514 ///
515 /// # Examples
516 ///
517 /// ```
518 /// use computegraph::types::{ValueKey, ValueRef};
519 /// use tenferro_ops::input_key::TensorInputKey;
520 /// use tenferro_ops::std_tensor_op::StdTensorOp;
521 /// use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
522 /// use tenferro_tensor::DType;
523 ///
524 /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
525 /// let value = ValueRef::External(key.clone());
526 /// let mut ctx = ShapeGuardContext::default();
527 /// ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
528 ///
529 /// let meta = ctx.metadata_of(&value).unwrap();
530 /// assert_eq!(meta.dtype, DType::F64);
531 /// ```
532 pub fn metadata_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<&TensorMeta> {
533 let key = self.resolve_key(val)?.clone();
534 self.ensure_metadata_loaded(&key);
535 self.metadata
536 .get(&key)
537 .ok_or(ShapeGuardError::MissingMetadata { key })
538 }
539
540 #[doc(hidden)]
541 pub fn metadata_if_available(&mut self, val: &ValueRef<StdTensorOp>) -> Option<&TensorMeta> {
542 let key = self.resolve_key_if_available(val)?.clone();
543 self.ensure_metadata_loaded(&key);
544 self.metadata.get(&key)
545 }
546
547 #[doc(hidden)]
548 pub fn attach_graph(&mut self, graph: &Graph<StdTensorOp>) {
549 self.local_keys = Some(graph.values().iter().map(|node| node.key.clone()).collect());
550 }
551
552 #[doc(hidden)]
553 pub fn insert_metadata(&mut self, key: ValueKey<StdTensorOp>, meta: TensorMeta) {
554 self.metadata.insert(key, meta);
555 }
556
557 #[doc(hidden)]
558 pub fn extend_metadata<I>(&mut self, entries: I)
559 where
560 I: IntoIterator<Item = (ValueKey<StdTensorOp>, TensorMeta)>,
561 {
562 self.metadata.extend(entries);
563 }
564
565 fn resolve_key_if_available<'a>(
566 &'a self,
567 val: &'a ValueRef<StdTensorOp>,
568 ) -> Option<&'a ValueKey<StdTensorOp>> {
569 match val {
570 ValueRef::External(key) => Some(key),
571 ValueRef::Local(local_id) => self
572 .local_keys
573 .as_ref()
574 .and_then(|keys| keys.get(*local_id)),
575 }
576 }
577
578 fn resolve_key<'a>(
579 &'a self,
580 val: &'a ValueRef<StdTensorOp>,
581 ) -> ShapeGuardResult<&'a ValueKey<StdTensorOp>> {
582 match val {
583 ValueRef::External(key) => Ok(key),
584 ValueRef::Local(local_id) if self.local_keys.is_none() => {
585 Err(ShapeGuardError::LocalWithoutAttachedGraph {
586 local_id: *local_id,
587 })
588 }
589 ValueRef::Local(local_id) => self
590 .local_keys
591 .as_ref()
592 .and_then(|keys| keys.get(*local_id))
593 .ok_or(ShapeGuardError::LocalOutOfBounds {
594 local_id: *local_id,
595 }),
596 }
597 }
598
599 fn ensure_metadata_loaded(&mut self, key: &ValueKey<StdTensorOp>) {
600 if !self.metadata.contains_key(key) && self.use_global_registry {
601 if let Ok(Some(meta)) = lookup_global_metadata(key) {
602 self.metadata.insert(key.clone(), meta);
603 }
604 }
605 }
606}
607
608/// Look up a single metadata entry from the global registry.
609///
610/// Locks the registry briefly for a single `HashMap::get` + clone.
611///
612/// # Examples
613///
614/// ```
615/// use computegraph::types::ValueKey;
616/// use tenferro_ops::ad::context::lookup_global_metadata;
617/// use tenferro_ops::input_key::TensorInputKey;
618/// use tenferro_ops::std_tensor_op::StdTensorOp;
619///
620/// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 99 });
621/// let meta = lookup_global_metadata(&key).unwrap();
622/// assert!(meta.is_none());
623/// ```
624pub fn lookup_global_metadata(
625 key: &ValueKey<StdTensorOp>,
626) -> Result<Option<TensorMeta>, MetadataRegistryError> {
627 let guard = global_metadata_registry()
628 .lock()
629 .map_err(|_| MetadataRegistryError::LockPoisoned)?;
630 Ok(guard.get(key).map(|entry| entry.meta.clone()))
631}
632
633#[doc(hidden)]
634pub fn register_scoped_global_metadata_batch<I>(
635 entries: I,
636) -> Result<GlobalMetadataScope, MetadataRegistryError>
637where
638 I: IntoIterator<Item = (ValueKey<StdTensorOp>, TensorMeta)>,
639{
640 let mut guard = global_metadata_registry()
641 .lock()
642 .map_err(|_| MetadataRegistryError::LockPoisoned)?;
643 let mut keys = Vec::new();
644 for (key, meta) in entries {
645 let entry = guard.entry(key.clone()).or_insert(GlobalMetadataEntry {
646 meta: meta.clone(),
647 scoped_refs: 0,
648 });
649 entry.meta = meta;
650 entry.scoped_refs += 1;
651 keys.push(key);
652 }
653 Ok(GlobalMetadataScope { keys })
654}
655
656fn release_scoped_global_metadata(keys: &[ValueKey<StdTensorOp>]) {
657 let Ok(mut guard) = global_metadata_registry().lock() else {
658 // Drop cannot return an error. Failing closed here avoids reading or
659 // mutating data from a poisoned registry at the cost of leaking entries
660 // until process exit.
661 return;
662 };
663 for key in keys {
664 let should_remove = if let Some(entry) = guard.get_mut(key) {
665 entry.scoped_refs = entry.scoped_refs.saturating_sub(1);
666 entry.scoped_refs == 0
667 } else {
668 false
669 };
670 if should_remove {
671 guard.remove(key);
672 }
673 }
674}
675
676/// Resolve a [`DimExpr`] to a concrete `usize`.
677#[doc(hidden)]
678pub fn resolve_dim(dim: &DimExpr) -> Result<usize, DimExprEvalError> {
679 dim.eval(&[])
680}
681
682/// Resolve matrix dimensions and record their ordering as a guard.
683#[doc(hidden)]
684pub fn resolve_and_guard(
685 m: &DimExpr,
686 n: &DimExpr,
687 ctx: &mut ShapeGuardContext,
688) -> Result<(usize, usize), DimExprEvalError> {
689 let m_size = resolve_dim(m)?;
690 let n_size = resolve_dim(n)?;
691 ctx.guards.push(ShapeGuard {
692 dim_a: m_size,
693 dim_b: n_size,
694 ordering: m_size.cmp(&n_size),
695 });
696 Ok((m_size, n_size))
697}