pub trait ReverseRule<V>: Send + Sync
where
    V: Differentiable,
{
// Required methods
fn pullback(
&self,
cotangent: &<V as Differentiable>::Tangent,
) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent)>, AutodiffError>;
fn inputs(&self) -> Vec<NodeId>;
// Provided methods
fn forward_tangents<'t>(
&self,
input_tangents: &dyn Fn(NodeId) -> Option<&'t <V as Differentiable>::Tangent>,
) -> Result<Option<<V as Differentiable>::Tangent>, AutodiffError>
where <V as Differentiable>::Tangent: 't { ... }
fn pullback_with_tangents<'t>(
&self,
cotangent: &<V as Differentiable>::Tangent,
cotangent_tangent: &<V as Differentiable>::Tangent,
input_tangents: &dyn Fn(NodeId) -> Option<&'t <V as Differentiable>::Tangent>,
) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent, <V as Differentiable>::Tangent)>, AutodiffError>
where <V as Differentiable>::Tangent: 't { ... }
}

Expand description
Reverse-mode AD rule interface (rrule).
Implemented by operation-specific nodes (einsum, reduce, permute, …).
Named after Julia’s ChainRules.jl convention: rrule returns a pullback.
The type parameter V is the differentiable value type (e.g., Tensor<f64>).
Implementors must be Send + Sync because rule objects may be stored on an
AD tape that is shared across threads.
§Examples
Custom reverse rule for scalar multiplication output = a * b:
use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
struct ScalarMulRule {
a: f64,
b: f64,
a_node: NodeId,
b_node: NodeId,
}
impl ReverseRule<f64> for ScalarMulRule {
fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
// d(a*b)/da = b, d(a*b)/db = a
let da = cotangent * self.b;
let db = cotangent * self.a;
Ok(vec![(self.a_node, da), (self.b_node, db)])
}
fn inputs(&self) -> Vec<NodeId> {
vec![self.a_node, self.b_node]
}
}
// Verify: for a=3, b=5, cotangent=1 → da=5, db=3
let rule = ScalarMulRule {
a: 3.0, b: 5.0,
a_node: NodeId::new(0), b_node: NodeId::new(1),
};
let grads = rule.pullback(&1.0).unwrap();
assert_eq!(grads[0], (NodeId::new(0), 5.0)); // da = cotangent * b
assert_eq!(grads[1], (NodeId::new(1), 3.0)); // db = cotangent * a

Required Methods§
Source
fn pullback(
&self,
cotangent: &<V as Differentiable>::Tangent,
) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent)>, AutodiffError>
fn pullback( &self, cotangent: &<V as Differentiable>::Tangent, ) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent)>, AutodiffError>
Computes input cotangents from an output cotangent (pullback).
Provided Methods§
Source
fn forward_tangents<'t>(
&self,
input_tangents: &dyn Fn(NodeId) -> Option<&'t <V as Differentiable>::Tangent>,
) -> Result<Option<<V as Differentiable>::Tangent>, AutodiffError>
where
<V as Differentiable>::Tangent: 't,
fn forward_tangents<'t>(
&self,
input_tangents: &dyn Fn(NodeId) -> Option<&'t <V as Differentiable>::Tangent>,
) -> Result<Option<<V as Differentiable>::Tangent>, AutodiffError>
where
<V as Differentiable>::Tangent: 't,
Computes the forward tangent of this operation’s output.
Given a closure that returns the tangent for each input node
(or None if the input has no tangent), returns the output tangent.
The default implementation returns AutodiffError::HvpNotSupported.
Operations that support deferred HVP override this method.
Source
fn pullback_with_tangents<'t>(
&self,
cotangent: &<V as Differentiable>::Tangent,
cotangent_tangent: &<V as Differentiable>::Tangent,
input_tangents: &dyn Fn(NodeId) -> Option<&'t <V as Differentiable>::Tangent>,
) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent, <V as Differentiable>::Tangent)>, AutodiffError>
where
<V as Differentiable>::Tangent: 't,
fn pullback_with_tangents<'t>(
&self,
cotangent: &<V as Differentiable>::Tangent,
cotangent_tangent: &<V as Differentiable>::Tangent,
input_tangents: &dyn Fn(NodeId) -> Option<&'t <V as Differentiable>::Tangent>,
) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent, <V as Differentiable>::Tangent)>, AutodiffError>
where
<V as Differentiable>::Tangent: 't,
Computes pullback with tangent propagation for HVP.
Given an output cotangent, its tangent, and a closure providing input
tangents by node ID, returns
(node_id, input_cotangent, input_cotangent_tangent) triples.
The input_tangents closure provides access to forward-propagated
tangents for each input node, enabling deferred tangent injection
without storing tangents in the rule struct.
The default implementation returns AutodiffError::HvpNotSupported.
Operations that support forward-over-reverse HVP override this method.
§Examples
// Called internally by hvp(); users rarely call this directly.
let results = rule.pullback_with_tangents(
&cotangent, &cotangent_tangent, &|node| tangents_vec[node.index()].as_ref(),
)?;
for (node_id, grad, grad_tangent) in results {
// grad: standard cotangent for this input
// grad_tangent: cotangent tangent for HVP
}