pub struct ExtensionRuleSet { /* private fields */ }
Explicit, owned set of extension AD rules.
This is the rule container used by higher-level automatic-differentiation (AD) contexts. Extension AD intentionally has no process-global fallback: callers must explicitly pass the rule set that their graph needs.
Implementations§
impl ExtensionRuleSet
pub fn new() -> Self
Create an empty extension rule set.
§Examples
use tenferro_ops::ExtensionRuleSet;
let rules = ExtensionRuleSet::new();
assert!(!rules.is_rule_registered("example.missing.v1"));
pub fn register_rule(
    &mut self,
    rule: Arc<dyn ExtensionAdRule>,
) -> Result<(), ExtensionRegistryError>
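Add one rule to this owned set, keyed by the rule's `family_id`. Returns an `ExtensionRegistryError` if the rule cannot be registered.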
§Examples
use std::sync::Arc;
use tidu::ADRuleResult;
use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
use tenferro_ops::ad::PrimitiveRuleBuilder;
use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionOp};
use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
use tenferro_ops::std_tensor_op::StdTensorOp;
#[derive(Debug)]
struct Rule;
impl ExtensionAdRule for Rule {
fn family_id(&self) -> &'static str { "example.register_rule.v1" }
fn linearize(
&self,
_op: &dyn ExtensionOp,
_builder: &mut dyn PrimitiveRuleBuilder,
_primal_in: &[ValueKey<StdTensorOp>],
_primal_out: &[ValueKey<StdTensorOp>],
tangent_in: &[Option<LocalValueId>],
_ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
Ok(tangent_in.to_vec())
}
fn transpose_rule(
&self,
_op: &dyn ExtensionOp,
_builder: &mut dyn PrimitiveRuleBuilder,
cotangent_out: &[Option<LocalValueId>],
_inputs: &[ValueRef<StdTensorOp>],
_mode: &OperationRole,
_ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
Ok(cotangent_out.to_vec())
}
}
let mut rules = ExtensionRuleSet::new();
rules.register_rule(Arc::new(Rule)).unwrap();
assert!(rules.lookup_rule("example.register_rule.v1").is_some());
pub fn with_rule(
    self,
    rule: Arc<dyn ExtensionAdRule>,
) -> Result<Self, ExtensionRegistryError>
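Consume this rule set and return it with `rule` added, so rules can be chained onto `ExtensionRuleSet::new()`. Returns an `ExtensionRegistryError` if the rule cannot be registered.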
§Examples
use std::sync::Arc;
use tidu::ADRuleResult;
use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
use tenferro_ops::ad::PrimitiveRuleBuilder;
use tenferro_ops::ext_op::{ExtensionAdRule, ExtensionOp};
use tenferro_ops::{ExtensionRuleSet, ShapeGuardContext};
use tenferro_ops::std_tensor_op::StdTensorOp;
#[derive(Debug)]
struct Rule;
impl ExtensionAdRule for Rule {
fn family_id(&self) -> &'static str { "example.with_rule.v1" }
fn linearize(
&self,
_op: &dyn ExtensionOp,
_builder: &mut dyn PrimitiveRuleBuilder,
_primal_in: &[ValueKey<StdTensorOp>],
_primal_out: &[ValueKey<StdTensorOp>],
tangent_in: &[Option<LocalValueId>],
_ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
Ok(tangent_in.to_vec())
}
fn transpose_rule(
&self,
_op: &dyn ExtensionOp,
_builder: &mut dyn PrimitiveRuleBuilder,
cotangent_out: &[Option<LocalValueId>],
_inputs: &[ValueRef<StdTensorOp>],
_mode: &OperationRole,
_ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>> {
Ok(cotangent_out.to_vec())
}
}
let rules = ExtensionRuleSet::new().with_rule(Arc::new(Rule)).unwrap();
assert!(rules.is_rule_registered("example.with_rule.v1"));
pub fn merge(
    &mut self,
    other: ExtensionRuleSet,
) -> Result<(), ExtensionRegistryError>
Merge another owned rule set into this one.
The merge is atomic: if any rule in `other` is invalid or duplicates an
existing family, `self` is left unchanged.
§Examples
use tenferro_ops::ExtensionRuleSet;
let mut rules = ExtensionRuleSet::new();
rules.merge(ExtensionRuleSet::new()).unwrap();
assert!(!rules.is_rule_registered("example.missing.v1"));
pub fn lookup_rule(&self, family_id: &str) -> Option<Arc<dyn ExtensionAdRule>>
Look up the rule registered under `family_id` in this set, returning `None` if no such rule is present.
§Examples
use tenferro_ops::ExtensionRuleSet;
let rules = ExtensionRuleSet::new();
assert!(rules.lookup_rule("example.missing.v1").is_none());
pub fn is_rule_registered(&self, family_id: &str) -> bool
Return whether a rule for `family_id` is present in this set.
§Examples
use tenferro_ops::ExtensionRuleSet;
let rules = ExtensionRuleSet::new();
assert!(!rules.is_rule_registered("example.missing.v1"));
Trait Implementations§
impl Clone for ExtensionRuleSet
fn clone(&self) -> ExtensionRuleSet
Returns a duplicate of the value.
1.0.0 · fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`.
impl Debug for ExtensionRuleSet
Available on crate feature `autodiff` only.
impl Default for ExtensionRuleSet
fn default() -> ExtensionRuleSet
Returns the “default value” for a type.
Auto Trait Implementations§
impl Freeze for ExtensionRuleSet
impl !RefUnwindSafe for ExtensionRuleSet
impl Send for ExtensionRuleSet
impl Sync for ExtensionRuleSet
impl Unpin for ExtensionRuleSet
impl UnsafeUnpin for ExtensionRuleSet
impl !UnwindSafe for ExtensionRuleSet
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.