
Trait ExtensionAdRule 

pub trait ExtensionAdRule:
    Debug
    + Send
    + Sync
    + 'static {
    // Required methods
    fn family_id(&self) -> &'static str;
    fn linearize(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        primal_in: &[ValueKey<StdTensorOp>],
        primal_out: &[ValueKey<StdTensorOp>],
        tangent_in: &[Option<LocalValueId>],
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>>;
    fn transpose_rule(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        inputs: &[ValueRef<StdTensorOp>],
        mode: &OperationRole,
        ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>>;
}

AD rule provider for an extension family.

Rules are registered independently of the primal operation, so an out-of-tree crate can provide forward execution without AD or gate AD support behind an optional feature. Rule methods receive the concrete ExtensionOp payload as a trait object; implementations should downcast through ExtensionOp::as_any when they need payload-specific parameters.
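The downcast described above follows the standard std::any::Any pattern. The sketch below is self-contained, so it does not use the crate's real types: ExtensionOp is a reduced, hypothetical stand-in that models only the as_any accessor, and ScaleOp, OtherOp and scale_factor are illustrative names, not part of the API.

```rust
use std::any::Any;
use std::fmt::Debug;

// Hypothetical stand-in for the crate's ExtensionOp trait; only the
// as_any accessor mentioned in the docs is modeled here.
trait ExtensionOp: Debug {
    fn as_any(&self) -> &dyn Any;
}

// Hypothetical payload for a "scale by constant" extension op.
#[derive(Debug)]
struct ScaleOp {
    factor: f64,
}

impl ExtensionOp for ScaleOp {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// A second hypothetical op type, used to show a failed downcast.
#[derive(Debug)]
struct OtherOp;

impl ExtensionOp for OtherOp {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Recover the payload-specific parameter from the trait object, or report
// which op was received if it is not a ScaleOp.
fn scale_factor(op: &dyn ExtensionOp) -> Result<f64, String> {
    op.as_any()
        .downcast_ref::<ScaleOp>()
        .map(|scale| scale.factor)
        .ok_or_else(|| format!("expected ScaleOp, got {:?}", op))
}

fn main() {
    let op: Box<dyn ExtensionOp> = Box::new(ScaleOp { factor: 2.5 });
    println!("{:?}", scale_factor(&*op));
    println!("{:?}", scale_factor(&OtherOp));
}
```

Returning an error instead of panicking on a failed downcast lets a rule report a family/payload mismatch through its normal result type.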

Required Methods

fn family_id(&self) -> &'static str

Returns the identifier of the extension family this rule handles.
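Because rules are registered separately from the primal operations, a host needs some way to find the rule for a given family; keying by family_id is one natural choice. The following is a minimal sketch under that assumption: AdRule is a hypothetical stand-in that keeps only family_id (with the same Debug + Send + Sync + 'static bounds), and AdRuleRegistry, ScaleRule and the family name "example.scale" are invented for illustration. The crate's actual registration mechanism is not shown in this page.

```rust
use std::collections::HashMap;
use std::fmt::Debug;

// Hypothetical, reduced stand-in for ExtensionAdRule: only family_id is modeled.
trait AdRule: Debug + Send + Sync + 'static {
    fn family_id(&self) -> &'static str;
}

#[derive(Debug)]
struct ScaleRule;

impl AdRule for ScaleRule {
    fn family_id(&self) -> &'static str {
        "example.scale" // hypothetical family name
    }
}

// Hypothetical registry: rules are stored by family id, independently of
// whatever registers the primal (forward) operations.
#[derive(Default)]
struct AdRuleRegistry {
    rules: HashMap<&'static str, Box<dyn AdRule>>,
}

impl AdRuleRegistry {
    // Register a rule under its own family id. Returns false, and keeps the
    // existing rule, if that family already has one.
    fn register(&mut self, rule: Box<dyn AdRule>) -> bool {
        let id = rule.family_id();
        if self.rules.contains_key(id) {
            return false;
        }
        self.rules.insert(id, rule);
        true
    }

    // Look up the rule for a family, if AD support was registered for it.
    fn lookup(&self, family: &str) -> Option<&dyn AdRule> {
        self.rules.get(family).map(|rule| &**rule)
    }
}

fn main() {
    let mut registry = AdRuleRegistry::default();
    assert!(registry.register(Box::new(ScaleRule)));
    // A second rule for the same family is rejected in this sketch.
    assert!(!registry.register(Box::new(ScaleRule)));
    println!("{:?}", registry.lookup("example.scale"));
}
```

Rejecting duplicates is a design choice of this sketch; a real host might instead overwrite, or treat a duplicate as a configuration error.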

fn linearize(&self, op: &dyn ExtensionOp, builder: &mut dyn PrimitiveRuleBuilder, primal_in: &[ValueKey<StdTensorOp>], primal_out: &[ValueKey<StdTensorOp>], tangent_in: &[Option<LocalValueId>], ctx: &mut ShapeGuardContext) -> ADRuleResult<Vec<Option<LocalValueId>>>

Emit the linear (JVP) rule, which maps the input tangents in tangent_in to output tangents.
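As an illustration of what a JVP rule emits, consider elementwise multiplication z = x * y, whose tangent is dz = dx * y + x * dy. The sketch below makes several assumptions that this page does not confirm: a None entry in tangent_in is treated as a tangent known to be zero (so no term is emitted for it), the result has one entry per primal output, and primal values are addressed by the same id type as tangents. LocalId and ToyBuilder are hypothetical stand-ins for LocalValueId and PrimitiveRuleBuilder; ToyBuilder simply records the ops it would emit.

```rust
// Hypothetical stand-in for LocalValueId.
#[derive(Clone, Copy, Debug, PartialEq)]
struct LocalId(usize);

// Hypothetical stand-in for PrimitiveRuleBuilder: records emitted ops as text.
struct ToyBuilder {
    next_id: usize,
    ops: Vec<String>,
}

impl ToyBuilder {
    fn new(first_free_id: usize) -> Self {
        ToyBuilder { next_id: first_free_id, ops: Vec::new() }
    }

    // Emit `op(args...)` and return the id of its result.
    fn emit(&mut self, op: &str, args: &[LocalId]) -> LocalId {
        let out = LocalId(self.next_id);
        self.next_id += 1;
        let args: Vec<String> = args.iter().map(|a| format!("%{}", a.0)).collect();
        self.ops.push(format!("%{} = {}({})", out.0, op, args.join(", ")));
        out
    }
}

// JVP of z = x * y:  dz = dx * y + x * dy.
fn linearize_mul(
    builder: &mut ToyBuilder,
    primal_in: &[LocalId],
    tangent_in: &[Option<LocalId>],
) -> Result<Vec<Option<LocalId>>, String> {
    let (x, y) = match primal_in {
        [x, y] => (*x, *y),
        _ => return Err(format!("mul expects 2 primal inputs, got {}", primal_in.len())),
    };
    if tangent_in.len() != 2 {
        return Err(format!("mul expects 2 tangents, got {}", tangent_in.len()));
    }
    // A None tangent is assumed to be zero and contributes no term.
    let term_x = tangent_in[0].map(|dx| builder.emit("mul", &[dx, y]));
    let term_y = tangent_in[1].map(|dy| builder.emit("mul", &[x, dy]));
    let dz = match (term_x, term_y) {
        (Some(a), Some(b)) => Some(builder.emit("add", &[a, b])),
        (a, b) => a.or(b),
    };
    // One entry per primal output (mul has a single output).
    Ok(vec![dz])
}

fn main() {
    // %0 = x and %1 = y are primal inputs; %2 is the tangent of x.
    let mut builder = ToyBuilder::new(3);
    let out = linearize_mul(&mut builder, &[LocalId(0), LocalId(1)], &[Some(LocalId(2)), None]).unwrap();
    println!("{:?}", out);
    for line in &builder.ops {
        println!("{}", line);
    }
}
```

Skipping work for zero tangents keeps the emitted linear program small; when every input tangent is None, the rule emits nothing and reports a zero output tangent.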

fn transpose_rule(&self, op: &dyn ExtensionOp, builder: &mut dyn PrimitiveRuleBuilder, cotangent_out: &[Option<LocalValueId>], inputs: &[ValueRef<StdTensorOp>], mode: &OperationRole, ctx: &mut ShapeGuardContext) -> ADRuleResult<Vec<Option<LocalValueId>>>

Emit the transpose (VJP) rule, which maps the output cotangents in cotangent_out to input cotangents.
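A transpose rule runs the linear map backwards. For dz = s * dx, where s is a primal constant and dx the linear input, the adjoint is ct_x = s * ct_z, and the constant receives no cotangent. This is a minimal sketch under assumptions not confirmed by this page: input 0 is the constant s and input 1 the linear input, a None cotangent means zero, the result holds one entry per input, and the mode argument is not modeled. LocalId and ToyBuilder are the same hypothetical stand-ins used for the JVP illustration, redefined so the block compiles on its own.

```rust
// Hypothetical stand-in for LocalValueId.
#[derive(Clone, Copy, Debug, PartialEq)]
struct LocalId(usize);

// Hypothetical stand-in for PrimitiveRuleBuilder: records emitted ops as text.
struct ToyBuilder {
    next_id: usize,
    ops: Vec<String>,
}

impl ToyBuilder {
    fn new(first_free_id: usize) -> Self {
        ToyBuilder { next_id: first_free_id, ops: Vec::new() }
    }

    fn emit(&mut self, op: &str, args: &[LocalId]) -> LocalId {
        let out = LocalId(self.next_id);
        self.next_id += 1;
        let args: Vec<String> = args.iter().map(|a| format!("%{}", a.0)).collect();
        self.ops.push(format!("%{} = {}({})", out.0, op, args.join(", ")));
        out
    }
}

// Transpose of the linear map dz = s * dx (inputs: [s, x]).
fn transpose_scale(
    builder: &mut ToyBuilder,
    cotangent_out: &[Option<LocalId>],
    inputs: &[LocalId],
) -> Result<Vec<Option<LocalId>>, String> {
    let s = match inputs {
        [s, _x] => *s,
        _ => return Err(format!("scale expects 2 inputs, got {}", inputs.len())),
    };
    let ct_z = match cotangent_out {
        [ct] => *ct,
        _ => return Err(format!("scale expects 1 cotangent, got {}", cotangent_out.len())),
    };
    // ct_x = s * ct_z; a zero (None) output cotangent yields a zero input cotangent.
    let ct_x = ct_z.map(|ct| builder.emit("mul", &[s, ct]));
    // The constant s is not a linear input, so it gets no cotangent.
    Ok(vec![None, ct_x])
}

fn main() {
    // %0 = s, %1 = x, %4 = cotangent of the output z.
    let mut builder = ToyBuilder::new(5);
    let cts = transpose_scale(&mut builder, &[Some(LocalId(4))], &[LocalId(0), LocalId(1)]).unwrap();
    println!("{:?} via {:?}", cts, builder.ops);
}
```

Note that the transpose consumes the primal constant s as an ordinary operand; only the linear input receives a cotangent.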
