pub struct ShapeGuardContext { /* private fields */ }
Expand description
AD context that resolves symbolic dimensions, records shape guards, and supplies per-value tensor metadata (dtype and shape extents).
§Examples
use tenferro_ops::ShapeGuardContext;
let ctx = ShapeGuardContext::default();
assert!(ctx.guards().is_empty());
Implementations§
Source§impl ShapeGuardContext
impl ShapeGuardContext
Source§ pub fn with_global_metadata() -> Self
pub fn with_global_metadata() -> Self
Create a context backed by the global metadata registry.
Rather than cloning the entire global registry up front (which used
to cost O(N) per AD pass and became quadratic across oracle_replay),
the context keeps a flag and, on a local miss, lazily fetches the entry
from the shared lookup_global_metadata. It caches each fetched entry in
its local metadata map, so later reads within the same pass are served locally.
§Examples
let ctx = tenferro_ops::ShapeGuardContext::with_global_metadata();
assert!(ctx.guards().is_empty());
Source§ pub fn with_extension_rules(self, rules: ExtensionRuleSet) -> Self
pub fn with_extension_rules(self, rules: ExtensionRuleSet) -> Self
Use an explicit extension AD rule set for this context.
Extension AD rule lookup is scoped to the context: a context with no attached rule set has no extension AD rules available.
§Examples
use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
let _ctx = ShapeGuardContext::default().with_extension_rules(ExtensionRuleSet::new());Sourcepub fn guards(&self) -> &[ShapeGuard]
pub fn guards(&self) -> &[ShapeGuard]
Returns the guards recorded so far.
§Examples
use tenferro_ops::ShapeGuardContext;
let ctx = ShapeGuardContext::default();
assert_eq!(ctx.guards(), &[]);
Source§ pub fn clear_guards(&mut self)
pub fn clear_guards(&mut self)
Clears all recorded guards.
§Examples
use tenferro_ops::ShapeGuardContext;
let mut ctx = ShapeGuardContext::default();
ctx.clear_guards();
assert!(ctx.guards().is_empty());
Source§ pub fn shape_of(
&mut self,
val: &ValueRef<StdTensorOp>,
) -> ShapeGuardResult<Vec<SymDim>>
pub fn shape_of( &mut self, val: &ValueRef<StdTensorOp>, ) -> ShapeGuardResult<Vec<SymDim>>
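Return the exact shape of a value reference. Values with any axis known only as an upper bound (for example, DynamicTruncate outputs) are rejected; use rank_of, extents_of, or exact_shape_of for those.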
§Examples
use computegraph::types::{ValueKey, ValueRef};
use tenferro_ops::input_key::TensorInputKey;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
use tenferro_tensor::DType;
let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
let value = ValueRef::External(key.clone());
let mut ctx = ShapeGuardContext::default();
ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
let shape = ctx.shape_of(&value).unwrap();
assert_eq!(shape, &[SymDim::from(4usize)]);
Source§ pub fn rank_of(
&mut self,
val: &ValueRef<StdTensorOp>,
) -> ShapeGuardResult<usize>
pub fn rank_of( &mut self, val: &ValueRef<StdTensorOp>, ) -> ShapeGuardResult<usize>
Return the rank for a value reference without requiring exact extents.
Use this when an AD rule only needs the axis count or needs to build
runtime-shape references. Calling ShapeGuardContext::shape_of in those
cases would reject valid values such as DynamicTruncate outputs whose
runtime extent is known only as an upper bound.
§Examples
use computegraph::types::{ValueKey, ValueRef};
use tenferro_ops::input_key::TensorInputKey;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
use tenferro_tensor::DType;
let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
let value = ValueRef::External(key.clone());
let mut ctx = ShapeGuardContext::default();
ctx.insert_metadata(
key,
TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
);
assert_eq!(ctx.rank_of(&value).unwrap(), 1);
Source§ pub fn extents_of(
&mut self,
val: &ValueRef<StdTensorOp>,
) -> ShapeGuardResult<&[ShapeExtent<SymDim>]>
pub fn extents_of( &mut self, val: &ValueRef<StdTensorOp>, ) -> ShapeGuardResult<&[ShapeExtent<SymDim>]>
Return the per-axis shape guarantees of a value reference; each axis carries a ShapeExtent, for example an exact extent or only an upper bound.
§Examples
use computegraph::types::{ValueKey, ValueRef};
use tenferro_ops::input_key::TensorInputKey;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
use tenferro_tensor::DType;
let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
let value = ValueRef::External(key.clone());
let mut ctx = ShapeGuardContext::default();
ctx.insert_metadata(
key,
TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
);
let extents = ctx.extents_of(&value).unwrap();
assert_eq!(extents[0], ShapeExtent::upper_bound(SymDim::from(8usize)));
Source§ pub fn exact_shape_of(
&mut self,
val: &ValueRef<StdTensorOp>,
) -> ShapeGuardResult<Option<Vec<SymDim>>>
pub fn exact_shape_of( &mut self, val: &ValueRef<StdTensorOp>, ) -> ShapeGuardResult<Option<Vec<SymDim>>>
Return Some(shape) when every axis of the value reference is exact, or None when any axis is not (for example, when it is known only as an upper bound).
§Examples
use computegraph::types::{ValueKey, ValueRef};
use tenferro_ops::input_key::TensorInputKey;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_ops::{ShapeExtent, ShapeGuardContext, SymDim, TensorMeta};
use tenferro_tensor::DType;
let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
let value = ValueRef::External(key.clone());
let mut ctx = ShapeGuardContext::default();
ctx.insert_metadata(
key,
TensorMeta::with_extents(DType::F64, vec![ShapeExtent::upper_bound(SymDim::from(8usize))]),
);
let maybe_shape = ctx.exact_shape_of(&value).unwrap();
assert_eq!(maybe_shape, None);
Source§ pub fn dtype_of(
&mut self,
val: &ValueRef<StdTensorOp>,
) -> ShapeGuardResult<DType>
pub fn dtype_of( &mut self, val: &ValueRef<StdTensorOp>, ) -> ShapeGuardResult<DType>
Return the dtype metadata for a value reference.
§Examples
use computegraph::types::{ValueKey, ValueRef};
use tenferro_ops::input_key::TensorInputKey;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
use tenferro_tensor::DType;
let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
let value = ValueRef::External(key.clone());
let mut ctx = ShapeGuardContext::default();
ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
let dtype = ctx.dtype_of(&value).unwrap();
assert_eq!(dtype, DType::F64);
Source§ pub fn metadata_of(
&mut self,
val: &ValueRef<StdTensorOp>,
) -> ShapeGuardResult<&TensorMeta>
pub fn metadata_of( &mut self, val: &ValueRef<StdTensorOp>, ) -> ShapeGuardResult<&TensorMeta>
Return the complete metadata record for a value reference.
§Examples
use computegraph::types::{ValueKey, ValueRef};
use tenferro_ops::input_key::TensorInputKey;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_ops::{ShapeGuardContext, SymDim, TensorMeta};
use tenferro_tensor::DType;
let key = ValueKey::<StdTensorOp>::Input(TensorInputKey::User { id: 1 });
let value = ValueRef::External(key.clone());
let mut ctx = ShapeGuardContext::default();
ctx.insert_metadata(key, TensorMeta::exact(DType::F64, vec![SymDim::from(4usize)]));
let meta = ctx.metadata_of(&value).unwrap();
assert_eq!(meta.dtype, DType::F64);
Trait Implementations§
Source§impl Clone for ShapeGuardContext
impl Clone for ShapeGuardContext
Source§fn clone(&self) -> ShapeGuardContext
fn clone(&self) -> ShapeGuardContext
Returns a duplicate of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Read more