pub trait ExtensionAdRule:
Debug
+ Send
+ Sync
+ 'static {
// Required methods
fn family_id(&self) -> &'static str;
fn linearize(
&self,
op: &dyn ExtensionOp,
builder: &mut dyn PrimitiveRuleBuilder,
primal_in: &[ValueKey<StdTensorOp>],
primal_out: &[ValueKey<StdTensorOp>],
tangent_in: &[Option<LocalValueId>],
ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>>;
fn transpose_rule(
&self,
op: &dyn ExtensionOp,
builder: &mut dyn PrimitiveRuleBuilder,
cotangent_out: &[Option<LocalValueId>],
inputs: &[ValueRef<StdTensorOp>],
mode: &OperationRole,
ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>>;
}Expand description
AD rule provider for an extension family.
Rules are registered independently from the primal operation so an
out-of-tree crate can provide forward execution without AD, or gate AD
support behind an optional feature. Rule methods receive the concrete
ExtensionOp payload as a trait object; implementations should downcast
through ExtensionOp::as_any when they need payload-specific parameters.
Required Methods
fn linearize(
&self,
op: &dyn ExtensionOp,
builder: &mut dyn PrimitiveRuleBuilder,
primal_in: &[ValueKey<StdTensorOp>],
primal_out: &[ValueKey<StdTensorOp>],
tangent_in: &[Option<LocalValueId>],
ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>>
fn linearize( &self, op: &dyn ExtensionOp, builder: &mut dyn PrimitiveRuleBuilder, primal_in: &[ValueKey<StdTensorOp>], primal_out: &[ValueKey<StdTensorOp>], tangent_in: &[Option<LocalValueId>], ctx: &mut ShapeGuardContext, ) -> ADRuleResult<Vec<Option<LocalValueId>>>
Emit the linear (JVP) rule.
fn transpose_rule(
&self,
op: &dyn ExtensionOp,
builder: &mut dyn PrimitiveRuleBuilder,
cotangent_out: &[Option<LocalValueId>],
inputs: &[ValueRef<StdTensorOp>],
mode: &OperationRole,
ctx: &mut ShapeGuardContext,
) -> ADRuleResult<Vec<Option<LocalValueId>>>
fn transpose_rule( &self, op: &dyn ExtensionOp, builder: &mut dyn PrimitiveRuleBuilder, cotangent_out: &[Option<LocalValueId>], inputs: &[ValueRef<StdTensorOp>], mode: &OperationRole, ctx: &mut ShapeGuardContext, ) -> ADRuleResult<Vec<Option<LocalValueId>>>
Emit the transpose (VJP) rule.