Skip to main content

tenferro_ops/ad/
context.rs

1//! AD context for guard-based shape resolution and value metadata queries.
2//!
3//! During AD graph construction, linalg rules such as SVD, QR, and LU need
4//! concrete matrix dimensions to choose between structurally different
5//! subgraphs. `ShapeGuardContext` records those dimension comparisons as guards
6//! so cached AD graphs can later be invalidated when the observed shape
7//! relationship changes.
8
9use std::cmp::Ordering;
10use std::collections::HashMap;
11#[cfg(feature = "autodiff")]
12use std::sync::Arc;
13use std::sync::{Mutex, OnceLock};
14
15use computegraph::graph::Graph;
16use computegraph::types::{ValueKey, ValueRef};
17use tenferro_tensor::DType;
18
19use crate::dim_expr::{DimExpr, DimExprEvalError};
20#[cfg(feature = "autodiff")]
21use crate::ext_op::{ExtensionAdRule, ExtensionRuleSet};
22use crate::shape_extent::ShapeExtent;
23use crate::std_tensor_op::StdTensorOp;
24use crate::sym_dim::SymDim;
25
26type MetadataMap = HashMap<ValueKey<StdTensorOp>, TensorMeta>;
27
28type GlobalMetadataMap = HashMap<ValueKey<StdTensorOp>, GlobalMetadataEntry>;
29
30#[derive(Clone, Debug)]
31struct GlobalMetadataEntry {
32    meta: TensorMeta,
33    scoped_refs: usize,
34}
35
36/// Error returned when the process-global AD metadata registry is unavailable.
37#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
38pub enum MetadataRegistryError {
39    /// A previous panic poisoned the global metadata mutex.
40    #[error("AD global metadata registry lock poisoned")]
41    LockPoisoned,
42}
43
44/// Error returned when shape-guard metadata cannot be resolved.
45#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
46pub enum ShapeGuardError {
47    /// A local graph value was queried before a graph was attached.
48    #[error("cannot resolve local value {local_id} without an attached graph")]
49    LocalWithoutAttachedGraph {
50        /// Graph-local value id.
51        local_id: usize,
52    },
53    /// A local graph value id is outside the attached graph's value table.
54    #[error("local value {local_id} is out of bounds for the attached graph")]
55    LocalOutOfBounds {
56        /// Graph-local value id.
57        local_id: usize,
58    },
59    /// No metadata was registered for the resolved value key.
60    #[error("missing TensorMeta for {key:?}")]
61    MissingMetadata {
62        /// Resolved value key.
63        key: ValueKey<StdTensorOp>,
64    },
65    /// Metadata exists, but at least one axis is only bounded or unknown.
66    #[error("TensorMeta for {key:?} does not have an exact shape; query extents instead")]
67    NonExactShape {
68        /// Resolved value key.
69        key: ValueKey<StdTensorOp>,
70    },
71}
72
73/// Result type used by shape-guard metadata queries.
74pub type ShapeGuardResult<T> = Result<T, ShapeGuardError>;
75
/// Report shape-guard failures to `tidu` as invalid-input AD rule errors,
/// tagged with the `"tenferro.shape_guard"` source and the error's message.
///
/// NOTE(review): the rule kind is always `Jvp`; confirm this is intended when
/// the failing query happens while building a non-JVP rule.
#[cfg(feature = "autodiff")]
impl From<ShapeGuardError> for tidu::ADRuleError {
    fn from(err: ShapeGuardError) -> Self {
        let message = err.to_string();
        Self::invalid_input("tenferro.shape_guard", tidu::ADRuleKind::Jvp, message)
    }
}
86
87/// Global metadata registry.
88///
89/// Stored as `Mutex<MetadataMap>` directly: writes insert in place (O(1)),
90/// and reads lock briefly for targeted key lookups. `ShapeGuardContext::metadata_of`
91/// reaches into the registry lazily via [`lookup_global_metadata`] and caches the
92/// result into the context's local map.
93///
94/// Earlier designs either cloned the whole map up-front into each AD
95/// `ShapeGuardContext` or kept the map in an `Arc` and cloned on every write.
96/// Both variants were quadratic across the monotonically growing registry and
97/// dominated oracle_replay runtime.
98static GLOBAL_METADATA: OnceLock<Mutex<GlobalMetadataMap>> = OnceLock::new();
99
100fn global_metadata_registry() -> &'static Mutex<GlobalMetadataMap> {
101    GLOBAL_METADATA.get_or_init(|| Mutex::new(HashMap::new()))
102}
103
104/// Lifetime token for graph-scoped global metadata.
105///
106/// Dropping the last frontend owner of a traced graph drops this scope and
107/// releases the metadata keys that were registered for that graph graph.
108#[doc(hidden)]
109#[derive(Debug)]
110pub struct GlobalMetadataScope {
111    keys: Vec<ValueKey<StdTensorOp>>,
112}
113
114impl Drop for GlobalMetadataScope {
115    fn drop(&mut self) {
116        release_scoped_global_metadata(&self.keys);
117    }
118}
119
120/// Per-value tensor metadata used by AD rules.
121///
122/// Shape information is stored as per-axis [`ShapeExtent`] values. Callers must
123/// explicitly choose whether they need an exact shape or only a known bound.
124///
125/// # Examples
126///
127/// ```
128/// use tenferro_ops::{SymDim, TensorMeta};
129/// use tenferro_tensor::DType;
130///
131/// let meta = TensorMeta::exact(DType::F64, vec![SymDim::from(2usize), SymDim::from(3usize)]);
132/// assert_eq!(meta.rank(), 2);
133/// ```
134#[derive(Clone, Debug, PartialEq, Eq)]
135pub struct TensorMeta {
136    /// Element dtype of the tensor value.
137    pub dtype: DType,
138    /// Per-axis shape guarantees.
139    pub extents: Vec<ShapeExtent<SymDim>>,
140}
141
142impl TensorMeta {
143    /// Construct metadata whose every axis is exact.
144    ///
145    /// # Examples
146    ///
147    /// ```
148    /// use tenferro_ops::{SymDim, TensorMeta};
149    /// use tenferro_tensor::DType;
150    ///
151    /// let meta = TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]);
152    /// assert_eq!(meta.exact_shape(), Some(vec![SymDim::from(4usize)]));
153    /// ```
154    pub fn exact(dtype: DType, shape: Vec<SymDim>) -> Self {
155        let extents = shape.iter().cloned().map(ShapeExtent::exact).collect();
156        Self { dtype, extents }
157    }
158
159    /// Construct metadata from per-axis extents.
160    ///
161    /// # Examples
162    ///
163    /// ```
164    /// use tenferro_ops::{ShapeExtent, SymDim, TensorMeta};
165    /// use tenferro_tensor::DType;
166    ///
167    /// let meta = TensorMeta::with_extents(
168    ///     DType::F64,
169    ///     vec![ShapeExtent::upper_bound(SymDim::from(8usize))],
170    /// );
171    /// assert_eq!(meta.exact_shape(), None);
172    /// ```
173    pub fn with_extents(dtype: DType, extents: Vec<ShapeExtent<SymDim>>) -> Self {
174        Self { dtype, extents }
175    }
176
177    /// Return the tensor rank known by this metadata record.
178    pub fn rank(&self) -> usize {
179        self.extents.len()
180    }
181
182    /// Return the per-axis shape guarantees.
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// use tenferro_ops::{SymDim, TensorMeta};
188    /// use tenferro_tensor::DType;
189    ///
190    /// let meta = TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]);
191    /// assert_eq!(meta.extents().len(), 1);
192    /// ```
193    pub fn extents(&self) -> &[ShapeExtent<SymDim>] {
194        &self.extents
195    }
196
197    /// Return the shape only when every axis is exact.
198    ///
199    /// # Examples
200    ///
201    /// ```
202    /// use tenferro_ops::{ShapeExtent, SymDim, TensorMeta};
203    /// use tenferro_tensor::DType;
204    ///
205    /// let meta = TensorMeta::with_extents(
206    ///     DType::F64,
207    ///     vec![ShapeExtent::upper_bound(SymDim::from(8usize))],
208    /// );
209    /// assert_eq!(meta.exact_shape(), None);
210    /// ```
211    pub fn exact_shape(&self) -> Option<Vec<SymDim>> {
212        self.extents
213            .iter()
214            .map(|extent| extent.as_exact().cloned())
215            .collect()
216    }
217
218    /// Return one known bound per axis when every axis has a bound.
219    ///
220    /// This is intentionally separate from [`TensorMeta::exact_shape`]: a bound
221    /// is not proof of the runtime size.
222    pub fn bound_shape(&self) -> Option<Vec<SymDim>> {
223        self.extents
224            .iter()
225            .map(|extent| extent.bound_expr().cloned())
226            .collect()
227    }
228}
229
/// One dimension comparison observed while an AD graph was being built.
///
/// Guards are recorded so that cached AD graphs can later be invalidated when
/// the observed relationship between the two dimensions changes.
///
/// # Examples
///
/// ```
/// use std::cmp::Ordering;
/// use tenferro_ops::ShapeGuard;
///
/// let guard = ShapeGuard {
///     dim_a: 5,
///     dim_b: 3,
///     ordering: Ordering::Greater,
/// };
///
/// assert_eq!(guard.ordering, Ordering::Greater);
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ShapeGuard {
    /// Left-hand dimension of the comparison (for example `m`).
    pub dim_a: usize,
    /// Right-hand dimension of the comparison (for example `n`).
    pub dim_b: usize,
    /// Result of `dim_a.cmp(&dim_b)` when the guard was recorded.
    pub ordering: Ordering,
}
255
256/// AD context providing dimension resolution, guard recording, and value metadata.
257///
258/// # Examples
259///
260/// ```
261/// use tenferro_ops::ShapeGuardContext;
262///
263/// let ctx = ShapeGuardContext::default();
264/// assert!(ctx.guards().is_empty());
265/// ```
266#[derive(Clone, Debug, Default)]
267pub struct ShapeGuardContext {
268    guards: Vec<ShapeGuard>,
269    metadata: MetadataMap,
270    use_global_registry: bool,
271    local_keys: Option<Vec<ValueKey<StdTensorOp>>>,
272    #[cfg(feature = "autodiff")]
273    extension_rules: Option<ExtensionRuleSet>,
274}
275
276impl ShapeGuardContext {
277    /// Create a context backed by the global metadata registry.
278    ///
279    /// Instead of cloning the entire global registry up-front (which used
280    /// to be O(N) per AD pass and quadratic across oracle_replay), the
281    /// context keeps a flag and lazily fetches entries from the shared
282    /// [`lookup_global_metadata`] on first miss, caching into its local
283    /// `metadata` map for subsequent reads within the same pass.
284    ///
285    /// # Examples
286    ///
287    /// ```
288    /// let ctx = tenferro_ops::ShapeGuardContext::with_global_metadata();
289    /// assert!(ctx.guards().is_empty());
290    /// ```
291    pub fn with_global_metadata() -> Self {
292        Self {
293            use_global_registry: true,
294            ..Self::default()
295        }
296    }
297
298    #[doc(hidden)]
299    /// Keep global-registry lookup enabled after a pass boundary.
300    ///
301    /// This is intentionally a no-op for cached entries: global metadata is
302    /// already read lazily on cache misses, and clearing the local cache would
303    /// also discard metadata inserted directly into this context.
304    pub fn refresh_global_metadata(&mut self) {
305        self.use_global_registry = true;
306    }
307
308    /// Use an explicit extension AD rule set for this context.
309    ///
310    /// Extension AD lookup is context-owned: a context without an attached rule
311    /// set has no extension AD rules.
312    ///
313    /// # Examples
314    ///
315    /// ```
316    /// use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
317    ///
318    /// let _ctx = ShapeGuardContext::default().with_extension_rules(ExtensionRuleSet::new());
319    /// ```
320    #[cfg(feature = "autodiff")]
321    pub fn with_extension_rules(mut self, rules: ExtensionRuleSet) -> Self {
322        self.extension_rules = Some(rules);
323        self
324    }
325
326    /// Look up an extension AD rule using this context's ownership policy.
327    ///
328    /// Contexts without an explicit rule set have no extension AD rules.
329    #[doc(hidden)]
330    #[cfg(feature = "autodiff")]
331    pub(crate) fn extension_rule_for(&self, family_id: &str) -> Option<Arc<dyn ExtensionAdRule>> {
332        self.extension_rules
333            .as_ref()
334            .and_then(|rules| rules.lookup_rule(family_id))
335    }
336
337    /// Returns the guards recorded so far.
338    ///
339    /// # Examples
340    ///
341    /// ```
342    /// use tenferro_ops::ShapeGuardContext;
343    ///
344    /// let ctx = ShapeGuardContext::default();
345    /// assert_eq!(ctx.guards(), &[]);
346    /// ```
347    pub fn guards(&self) -> &[ShapeGuard] {
348        &self.guards
349    }
350
351    /// Clears all recorded guards.
352    ///
353    /// # Examples
354    ///
355    /// ```
356    /// use tenferro_ops::ShapeGuardContext;
357    ///
358    /// let mut ctx = ShapeGuardContext::default();
359    /// ctx.clear_guards();
360    /// assert!(ctx.guards().is_empty());
361    /// ```
362    pub fn clear_guards(&mut self) {
363        self.guards.clear();
364    }
365
366    /// Return the shape metadata for a value reference.
367    ///
368    /// # Examples
369    ///
370    /// ```
371    /// use computegraph::types::{ValueKey, ValueRef};
372    /// use tenferro_ops::input_key::TensorInputKey;
373    /// use tenferro_ops::std_tensor_op::StdTensorOp;
374    /// use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
375    /// use tenferro_tensor::DType;
376    ///
377    /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
378    /// let value = ValueRef::External(key.clone());
379    /// let mut ctx = ShapeGuardContext::default();
380    /// ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
381    ///
382    /// let shape = ctx.shape_of(&value).unwrap();
383    /// assert_eq!(shape, &[SymDim::from(4usize)]);
384    /// ```
385    pub fn shape_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<Vec<SymDim>> {
386        let key = self.resolve_key(val)?.clone();
387        self.ensure_metadata_loaded(&key);
388        let meta = self
389            .metadata
390            .get(&key)
391            .ok_or_else(|| ShapeGuardError::MissingMetadata { key: key.clone() })?;
392        meta.exact_shape()
393            .ok_or(ShapeGuardError::NonExactShape { key })
394    }
395
396    /// Return the rank for a value reference without requiring exact extents.
397    ///
398    /// Use this when an AD rule only needs axis count or needs to build
399    /// runtime-shape references. Calling [`ShapeGuardContext::shape_of`] in those
400    /// cases would reject valid values such as `DynamicTruncate` outputs whose
401    /// runtime extent is known only as an upper bound.
402    ///
403    /// # Examples
404    ///
405    /// ```
406    /// use computegraph::types::{ValueKey, ValueRef};
407    /// use tenferro_ops::input_key::TensorInputKey;
408    /// use tenferro_ops::std_tensor_op::StdTensorOp;
409    /// use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
410    /// use tenferro_tensor::DType;
411    ///
412    /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
413    /// let value = ValueRef::External(key.clone());
414    /// let mut ctx = ShapeGuardContext::default();
415    /// ctx.insert_metadata(
416    ///     key,
417    ///     TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
418    /// );
419    ///
420    /// assert_eq!(ctx.rank_of(&value).unwrap(), 1);
421    /// ```
422    pub fn rank_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<usize> {
423        self.metadata_of(val).map(TensorMeta::rank)
424    }
425
426    /// Return per-axis shape guarantees for a value reference.
427    ///
428    /// # Examples
429    ///
430    /// ```
431    /// use computegraph::types::{ValueKey, ValueRef};
432    /// use tenferro_ops::input_key::TensorInputKey;
433    /// use tenferro_ops::std_tensor_op::StdTensorOp;
434    /// use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
435    /// use tenferro_tensor::DType;
436    ///
437    /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
438    /// let value = ValueRef::External(key.clone());
439    /// let mut ctx = ShapeGuardContext::default();
440    /// ctx.insert_metadata(
441    ///     key,
442    ///     TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
443    /// );
444    ///
445    /// let extents = ctx.extents_of(&value).unwrap();
446    /// assert_eq!(extents[0], ShapeExtent::upper_bound(SymDim::from(8usize)));
447    /// ```
448    pub fn extents_of(
449        &mut self,
450        val: &ValueRef<StdTensorOp>,
451    ) -> ShapeGuardResult<&[ShapeExtent<SymDim>]> {
452        self.metadata_of(val).map(TensorMeta::extents)
453    }
454
455    /// Return the exact shape for a value reference, if all axes are exact.
456    ///
457    /// # Examples
458    ///
459    /// ```
460    /// use computegraph::types::{ValueKey, ValueRef};
461    /// use tenferro_ops::input_key::TensorInputKey;
462    /// use tenferro_ops::std_tensor_op::StdTensorOp;
463    /// use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
464    /// use tenferro_tensor::DType;
465    ///
466    /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
467    /// let value = ValueRef::External(key.clone());
468    /// let mut ctx = ShapeGuardContext::default();
469    /// ctx.insert_metadata(
470    ///     key,
471    ///     TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
472    /// );
473    ///
474    /// let maybe_shape = ctx.exact_shape_of(&value).unwrap();
475    /// assert_eq!(maybe_shape, None);
476    /// ```
477    pub fn exact_shape_of(
478        &mut self,
479        val: &ValueRef<StdTensorOp>,
480    ) -> ShapeGuardResult<Option<Vec<SymDim>>> {
481        self.metadata_of(val).map(TensorMeta::exact_shape)
482    }
483
484    #[doc(hidden)]
485    pub fn shape_if_available(&mut self, val: &ValueRef<StdTensorOp>) -> Option<Vec<SymDim>> {
486        self.metadata_if_available(val)
487            .and_then(TensorMeta::exact_shape)
488    }
489
490    /// Return the dtype metadata for a value reference.
491    ///
492    /// # Examples
493    ///
494    /// ```
495    /// use computegraph::types::{ValueKey, ValueRef};
496    /// use tenferro_ops::input_key::TensorInputKey;
497    /// use tenferro_ops::std_tensor_op::StdTensorOp;
498    /// use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
499    /// use tenferro_tensor::DType;
500    ///
501    /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
502    /// let value = ValueRef::External(key.clone());
503    /// let mut ctx = ShapeGuardContext::default();
504    /// ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
505    ///
506    /// let dtype = ctx.dtype_of(&value).unwrap();
507    /// assert_eq!(dtype, DType::F64);
508    /// ```
509    pub fn dtype_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<DType> {
510        self.metadata_of(val).map(|meta| meta.dtype)
511    }
512
513    /// Return the complete metadata record for a value reference.
514    ///
515    /// # Examples
516    ///
517    /// ```
518    /// use computegraph::types::{ValueKey, ValueRef};
519    /// use tenferro_ops::input_key::TensorInputKey;
520    /// use tenferro_ops::std_tensor_op::StdTensorOp;
521    /// use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
522    /// use tenferro_tensor::DType;
523    ///
524    /// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
525    /// let value = ValueRef::External(key.clone());
526    /// let mut ctx = ShapeGuardContext::default();
527    /// ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
528    ///
529    /// let meta = ctx.metadata_of(&value).unwrap();
530    /// assert_eq!(meta.dtype, DType::F64);
531    /// ```
532    pub fn metadata_of(&mut self, val: &ValueRef<StdTensorOp>) -> ShapeGuardResult<&TensorMeta> {
533        let key = self.resolve_key(val)?.clone();
534        self.ensure_metadata_loaded(&key);
535        self.metadata
536            .get(&key)
537            .ok_or(ShapeGuardError::MissingMetadata { key })
538    }
539
540    #[doc(hidden)]
541    pub fn metadata_if_available(&mut self, val: &ValueRef<StdTensorOp>) -> Option<&TensorMeta> {
542        let key = self.resolve_key_if_available(val)?.clone();
543        self.ensure_metadata_loaded(&key);
544        self.metadata.get(&key)
545    }
546
547    #[doc(hidden)]
548    pub fn attach_graph(&mut self, graph: &Graph<StdTensorOp>) {
549        self.local_keys = Some(graph.values().iter().map(|node| node.key.clone()).collect());
550    }
551
552    #[doc(hidden)]
553    pub fn insert_metadata(&mut self, key: ValueKey<StdTensorOp>, meta: TensorMeta) {
554        self.metadata.insert(key, meta);
555    }
556
557    #[doc(hidden)]
558    pub fn extend_metadata<I>(&mut self, entries: I)
559    where
560        I: IntoIterator<Item = (ValueKey<StdTensorOp>, TensorMeta)>,
561    {
562        self.metadata.extend(entries);
563    }
564
565    fn resolve_key_if_available<'a>(
566        &'a self,
567        val: &'a ValueRef<StdTensorOp>,
568    ) -> Option<&'a ValueKey<StdTensorOp>> {
569        match val {
570            ValueRef::External(key) => Some(key),
571            ValueRef::Local(local_id) => self
572                .local_keys
573                .as_ref()
574                .and_then(|keys| keys.get(*local_id)),
575        }
576    }
577
578    fn resolve_key<'a>(
579        &'a self,
580        val: &'a ValueRef<StdTensorOp>,
581    ) -> ShapeGuardResult<&'a ValueKey<StdTensorOp>> {
582        match val {
583            ValueRef::External(key) => Ok(key),
584            ValueRef::Local(local_id) if self.local_keys.is_none() => {
585                Err(ShapeGuardError::LocalWithoutAttachedGraph {
586                    local_id: *local_id,
587                })
588            }
589            ValueRef::Local(local_id) => self
590                .local_keys
591                .as_ref()
592                .and_then(|keys| keys.get(*local_id))
593                .ok_or(ShapeGuardError::LocalOutOfBounds {
594                    local_id: *local_id,
595                }),
596        }
597    }
598
599    fn ensure_metadata_loaded(&mut self, key: &ValueKey<StdTensorOp>) {
600        if !self.metadata.contains_key(key) && self.use_global_registry {
601            if let Ok(Some(meta)) = lookup_global_metadata(key) {
602                self.metadata.insert(key.clone(), meta);
603            }
604        }
605    }
606}
607
608/// Look up a single metadata entry from the global registry.
609///
610/// Locks the registry briefly for a single `HashMap::get` + clone.
611///
612/// # Examples
613///
614/// ```
615/// use computegraph::types::ValueKey;
616/// use tenferro_ops::ad::context::lookup_global_metadata;
617/// use tenferro_ops::input_key::TensorInputKey;
618/// use tenferro_ops::std_tensor_op::StdTensorOp;
619///
620/// let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 99 });
621/// let meta = lookup_global_metadata(&key).unwrap();
622/// assert!(meta.is_none());
623/// ```
624pub fn lookup_global_metadata(
625    key: &ValueKey<StdTensorOp>,
626) -> Result<Option<TensorMeta>, MetadataRegistryError> {
627    let guard = global_metadata_registry()
628        .lock()
629        .map_err(|_| MetadataRegistryError::LockPoisoned)?;
630    Ok(guard.get(key).map(|entry| entry.meta.clone()))
631}
632
633#[doc(hidden)]
634pub fn register_scoped_global_metadata_batch<I>(
635    entries: I,
636) -> Result<GlobalMetadataScope, MetadataRegistryError>
637where
638    I: IntoIterator<Item = (ValueKey<StdTensorOp>, TensorMeta)>,
639{
640    let mut guard = global_metadata_registry()
641        .lock()
642        .map_err(|_| MetadataRegistryError::LockPoisoned)?;
643    let mut keys = Vec::new();
644    for (key, meta) in entries {
645        let entry = guard.entry(key.clone()).or_insert(GlobalMetadataEntry {
646            meta: meta.clone(),
647            scoped_refs: 0,
648        });
649        entry.meta = meta;
650        entry.scoped_refs += 1;
651        keys.push(key);
652    }
653    Ok(GlobalMetadataScope { keys })
654}
655
656fn release_scoped_global_metadata(keys: &[ValueKey<StdTensorOp>]) {
657    let Ok(mut guard) = global_metadata_registry().lock() else {
658        // Drop cannot return an error. Failing closed here avoids reading or
659        // mutating data from a poisoned registry at the cost of leaking entries
660        // until process exit.
661        return;
662    };
663    for key in keys {
664        let should_remove = if let Some(entry) = guard.get_mut(key) {
665            entry.scoped_refs = entry.scoped_refs.saturating_sub(1);
666            entry.scoped_refs == 0
667        } else {
668            false
669        };
670        if should_remove {
671            guard.remove(key);
672        }
673    }
674}
675
676/// Resolve a [`DimExpr`] to a concrete `usize`.
677#[doc(hidden)]
678pub fn resolve_dim(dim: &DimExpr) -> Result<usize, DimExprEvalError> {
679    dim.eval(&[])
680}
681
682/// Resolve matrix dimensions and record their ordering as a guard.
683#[doc(hidden)]
684pub fn resolve_and_guard(
685    m: &DimExpr,
686    n: &DimExpr,
687    ctx: &mut ShapeGuardContext,
688) -> Result<(usize, usize), DimExprEvalError> {
689    let m_size = resolve_dim(m)?;
690    let n_size = resolve_dim(n)?;
691    ctx.guards.push(ShapeGuard {
692        dim_a: m_size,
693        dim_b: n_size,
694        ordering: m_size.cmp(&n_size),
695    });
696    Ok((m_size, n_size))
697}