Trait ReverseRule

pub trait ReverseRule<V: Differentiable>: Send + Sync {
    // Required methods
    fn pullback(
        &self,
        cotangent: &V::Tangent,
    ) -> AdResult<Vec<PullbackEntry<V>>>;
    fn inputs(&self) -> Vec<NodeId>;

    // Provided methods
    fn forward_tangents<'t>(
        &self,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
    ) -> AdResult<Option<V::Tangent>>
       where V::Tangent: 't { ... }
    fn pullback_with_tangents<'t>(
        &self,
        cotangent: &V::Tangent,
        cotangent_tangent: &V::Tangent,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
    ) -> AdResult<Vec<PullbackWithTangentsEntry<V>>>
       where V::Tangent: 't { ... }
}

Reverse-mode AD rule interface (rrule).

Implemented by operation-specific nodes (einsum, reduce, permute, …). The name follows the convention of Julia’s ChainRules.jl, where rrule returns the primal output together with a pullback.
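In ChainRules.jl, `rrule(f, args...)` hands back the primal result together with a pullback closure; here the rule object plays the pullback's role through its `pullback` method. A minimal Rust sketch of the original closure-returning convention, using a hypothetical free function `rrule_mul` that is not part of this crate:

```rust
/// Hypothetical rrule-style helper for `a * b` (not part of this crate): returns
/// the primal output and a pullback mapping an output cotangent to input cotangents.
fn rrule_mul(a: f64, b: f64) -> (f64, Box<dyn Fn(f64) -> (f64, f64)>) {
    let output = a * b;
    // The pullback captures the primal inputs it needs: d(a*b)/da = b, d(a*b)/db = a.
    let pullback: Box<dyn Fn(f64) -> (f64, f64)> =
        Box::new(move |cotangent: f64| (cotangent * b, cotangent * a));
    (output, pullback)
}

fn main() {
    let (y, pullback) = rrule_mul(3.0, 5.0);
    let (da, db) = pullback(1.0);
    println!("y = {y}, da = {da}, db = {db}");
}
```

Storing the pullback as a trait object instead of a closure lets a tape hold many different operations in one collection, which is the role `ReverseRule` fills.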

The type parameter V is the differentiable value type (e.g., Tensor<f64>).
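To illustrate what `V::Tangent` refers to, here is a hypothetical, simplified `Differentiable` trait in which each value type names its tangent type. The crate's real trait may have different methods and bounds; `zero_tangent` and the toy `Tensor` struct are assumptions for illustration only.

```rust
/// Hypothetical, simplified stand-in for the crate's `Differentiable` trait.
trait Differentiable {
    /// Type of tangents (and cotangents) for this value type.
    type Tangent;
    /// Additive identity in tangent space, shaped like `self` (assumed method).
    fn zero_tangent(&self) -> Self::Tangent;
}

/// A scalar is its own tangent type.
impl Differentiable for f64 {
    type Tangent = f64;
    fn zero_tangent(&self) -> f64 {
        0.0
    }
}

/// Toy dense tensor (hypothetical): flat data plus shape.
#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    data: Vec<f64>,
    shape: Vec<usize>,
}

impl Differentiable for Tensor {
    // A tensor's tangent is another tensor of the same shape.
    type Tangent = Tensor;
    fn zero_tangent(&self) -> Tensor {
        Tensor { data: vec![0.0; self.data.len()], shape: self.shape.clone() }
    }
}

fn main() {
    let t = Tensor { data: vec![1.0, 2.0, 3.0, 4.0], shape: vec![2, 2] };
    println!("{:?}", t.zero_tangent());
    println!("{}", 3.0_f64.zero_tangent());
}
```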

Implementors must be Send + Sync because rule objects may be stored on an AD tape that is shared across threads.
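A sketch of why the bound matters, assuming a hypothetical scalar-only version of the trait and a hypothetical `Tape` that stores boxed rules. Because `Send + Sync` are supertraits, every `Box<dyn ReverseRule>` is itself `Send + Sync`, so the tape can be wrapped in an `Arc` and read from several threads at once.

```rust
use std::sync::Arc;
use std::thread;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct NodeId(usize);

/// Hypothetical scalar-only stand-in for `ReverseRule`, keeping the
/// `Send + Sync` supertrait bounds.
trait ReverseRule: Send + Sync {
    fn pullback(&self, cotangent: f64) -> Vec<(NodeId, f64)>;
}

/// Rule for `output = factor * input`.
struct ScaleRule {
    factor: f64,
    input: NodeId,
}

impl ReverseRule for ScaleRule {
    fn pullback(&self, cotangent: f64) -> Vec<(NodeId, f64)> {
        vec![(self.input, cotangent * self.factor)]
    }
}

/// Hypothetical tape of type-erased rules. `Arc<Tape>` can cross threads only
/// because each `dyn ReverseRule` is `Send + Sync` via the supertrait bounds.
struct Tape {
    rules: Vec<Box<dyn ReverseRule>>,
}

fn main() {
    let rule: Box<dyn ReverseRule> = Box::new(ScaleRule { factor: 2.0, input: NodeId(0) });
    let tape = Arc::new(Tape { rules: vec![rule] });

    // Run pullbacks for two different seeds on separate threads sharing one tape.
    let handles: Vec<_> = [1.0, 10.0]
        .iter()
        .map(|&seed| {
            let tape = Arc::clone(&tape);
            thread::spawn(move || tape.rules[0].pullback(seed))
        })
        .collect();
    for handle in handles {
        println!("{:?}", handle.join().unwrap());
    }
}
```

Dropping `Send + Sync` from the trait would make `thread::spawn` reject the closure at compile time, since `Box<dyn ReverseRule>` would no longer be known to be thread-safe.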

Examples

A custom reverse rule for scalar multiplication, output = a * b:

use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};

struct ScalarMulRule {
    a: f64,
    b: f64,
    a_node: NodeId,
    b_node: NodeId,
}

impl ReverseRule<f64> for ScalarMulRule {
    fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
        // d(a*b)/da = b, d(a*b)/db = a
        let da = cotangent * self.b;
        let db = cotangent * self.a;
        Ok(vec![(self.a_node, da), (self.b_node, db)])
    }

    fn inputs(&self) -> Vec<NodeId> {
        vec![self.a_node, self.b_node]
    }
}

// Verify: for a=3, b=5, cotangent=1 → da=5, db=3
let rule = ScalarMulRule {
    a: 3.0, b: 5.0,
    a_node: NodeId::new(0), b_node: NodeId::new(1),
};
let grads = rule.pullback(&1.0).unwrap();
assert_eq!(grads[0], (NodeId::new(0), 5.0)); // da = cotangent * b
assert_eq!(grads[1], (NodeId::new(1), 3.0)); // db = cotangent * a

Required Methods

fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<PullbackEntry<V>>>

Computes the input cotangents from an output cotangent (the pullback), returned as (node_id, input_cotangent) entries.
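Nothing in ScalarMulRule above prevents a_node and b_node from being the same node (as in y = x * x), in which case the pullback returns two entries for that node. A reverse sweep should therefore add entries into a per-node accumulator rather than overwrite them. A minimal sketch, assuming a hypothetical scalar-only mirror of the trait and a hypothetical `accumulate` helper:

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct NodeId(usize);

/// Hypothetical scalar-only mirror of `ReverseRule::pullback`.
trait ReverseRule {
    fn pullback(&self, cotangent: f64) -> Vec<(NodeId, f64)>;
}

/// Rule for `output = a * b`.
struct MulRule {
    a: f64,
    b: f64,
    a_node: NodeId,
    b_node: NodeId,
}

impl ReverseRule for MulRule {
    fn pullback(&self, cotangent: f64) -> Vec<(NodeId, f64)> {
        vec![(self.a_node, cotangent * self.b), (self.b_node, cotangent * self.a)]
    }
}

/// Hypothetical helper: applies one rule during a reverse sweep, adding each
/// returned input cotangent into the running gradient for that node.
fn accumulate(grads: &mut HashMap<NodeId, f64>, rule: &dyn ReverseRule, cotangent: f64) {
    for (node, g) in rule.pullback(cotangent) {
        *grads.entry(node).or_insert(0.0) += g;
    }
}

fn main() {
    // y = x * x with x = 3: both operands are the same node, so the two
    // entries (3 and 3) are summed to give dy/dx = 2x = 6.
    let x = NodeId(0);
    let rule = MulRule { a: 3.0, b: 3.0, a_node: x, b_node: x };
    let mut grads = HashMap::new();
    accumulate(&mut grads, &rule, 1.0);
    println!("dy/dx = {}", grads[&x]);
}
```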

fn inputs(&self) -> Vec<NodeId>

Returns the IDs of the input nodes this rule depends on.
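One way a tape could use `inputs()` is to find the nodes an output actually depends on, so that pullbacks run only for those. A sketch under two assumptions: a hypothetical tape layout where `rules[i]` is the rule that produced node `i` (`None` for a leaf), and a hypothetical mirror of the trait that keeps only `inputs`.

```rust
use std::collections::HashSet;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct NodeId(usize);

/// Hypothetical mirror of the `inputs` half of `ReverseRule`.
trait ReverseRule {
    fn inputs(&self) -> Vec<NodeId>;
}

/// Hypothetical two-input operation.
struct BinaryOp(NodeId, NodeId);

impl ReverseRule for BinaryOp {
    fn inputs(&self) -> Vec<NodeId> {
        vec![self.0, self.1]
    }
}

/// Returns every node that `output` depends on (including itself),
/// found by a depth-first walk over `inputs()`.
fn ancestors(rules: &[Option<Box<dyn ReverseRule>>], output: NodeId) -> HashSet<NodeId> {
    let mut seen = HashSet::new();
    let mut stack = vec![output];
    while let Some(node) = stack.pop() {
        // Skip nodes already visited (shared subexpressions).
        if !seen.insert(node) {
            continue;
        }
        // Leaves have no rule and therefore no inputs.
        if let Some(rule) = &rules[node.0] {
            stack.extend(rule.inputs());
        }
    }
    seen
}

fn main() {
    // Nodes 0..=2 are leaves; node 3 = op(0, 1); node 4 = op(3, 3). Node 2 is unused.
    let mut rules: Vec<Option<Box<dyn ReverseRule>>> = vec![None, None, None];
    rules.push(Some(Box::new(BinaryOp(NodeId(0), NodeId(1)))));
    rules.push(Some(Box::new(BinaryOp(NodeId(3), NodeId(3)))));
    let needed = ancestors(&rules, NodeId(4));
    println!("node 2 needed: {}", needed.contains(&NodeId(2)));
}
```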

Provided Methods

fn forward_tangents<'t>(
    &self,
    input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
) -> AdResult<Option<V::Tangent>>
where
    V::Tangent: 't,

Computes the forward tangent of this operation’s output.

Given a closure that returns the tangent for each input node (or None if the input has no tangent), returns the output tangent.

The default implementation returns Err(AutodiffError::HvpNotSupported). Operations that support deferred Hessian-vector products (HVP) override this method.
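The sketch below mirrors the default-then-override pattern with a hypothetical, scalar-only trait and a local `AutodiffError` enum (not the crate's types). `MulRule` applies the product rule; treating a missing input tangent as zero, and returning `None` when no input has a tangent, are assumptions of this sketch rather than documented behavior.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct NodeId(usize);

/// Local stand-in for the crate's error type (hypothetical).
#[derive(Debug, PartialEq)]
enum AutodiffError {
    HvpNotSupported,
}

type AdResult<T> = Result<T, AutodiffError>;

/// Hypothetical scalar-only mirror of the trait's `forward_tangents`.
trait ReverseRule {
    /// Default: this operation does not support deferred HVP.
    fn forward_tangents<'t>(
        &self,
        _input_tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
    ) -> AdResult<Option<f64>> {
        Err(AutodiffError::HvpNotSupported)
    }
}

/// `output = a * b`, overriding the default with the product rule.
struct MulRule {
    a: f64,
    b: f64,
    a_node: NodeId,
    b_node: NodeId,
}

impl ReverseRule for MulRule {
    fn forward_tangents<'t>(
        &self,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
    ) -> AdResult<Option<f64>> {
        let da = input_tangents(self.a_node).copied();
        let db = input_tangents(self.b_node).copied();
        Ok(match (da, db) {
            // Assumption: no input tangent at all means no output tangent.
            (None, None) => None,
            // d(a * b) = da * b + a * db, with a missing tangent treated as zero.
            (da, db) => Some(da.unwrap_or(0.0) * self.b + self.a * db.unwrap_or(0.0)),
        })
    }
}

/// An operation that keeps the default implementation.
struct OpaqueRule;

impl ReverseRule for OpaqueRule {}

fn main() {
    // Only node 0 carries a tangent.
    let tangents = [Some(1.0), None];
    let lookup = |id: NodeId| tangents[id.0].as_ref();
    let mul = MulRule { a: 3.0, b: 5.0, a_node: NodeId(0), b_node: NodeId(1) };
    println!("{:?}", mul.forward_tangents(&lookup)); // Ok(Some(5.0))
    println!("{:?}", OpaqueRule.forward_tangents(&lookup)); // Err(HvpNotSupported)
}
```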

fn pullback_with_tangents<'t>(
    &self,
    cotangent: &V::Tangent,
    cotangent_tangent: &V::Tangent,
    input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
) -> AdResult<Vec<PullbackWithTangentsEntry<V>>>
where
    V::Tangent: 't,

Computes the pullback together with tangent propagation, as needed for HVP.

Given an output cotangent, its tangent, and a closure providing input tangents by node ID, returns (node_id, input_cotangent, input_cotangent_tangent) triples.
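For the scalar multiplication z = a * b from the example above, the triples can be worked out by hand. Writing $\bar z$ for the output cotangent, $\dot{\bar z}$ for its tangent, and $\dot a$, $\dot b$ for the input tangents, differentiating the pullback $\bar a = \bar z\, b$, $\bar b = \bar z\, a$ along the tangent direction gives:

```latex
\begin{aligned}
\bar a &= \bar z\, b, & \dot{\bar a} &= \dot{\bar z}\, b + \bar z\, \dot b,\\
\bar b &= \bar z\, a, & \dot{\bar b} &= \dot{\bar z}\, a + \bar z\, \dot a.
\end{aligned}
```

So such a rule would return $(a, \bar a, \dot{\bar a})$ and $(b, \bar b, \dot{\bar b})$. With $\bar z = 1$ and $\dot{\bar z} = 0$, the tangent column reduces to $(\dot b, \dot a)$, which is the Hessian of $ab$, $\begin{pmatrix}0 & 1\\ 1 & 0\end{pmatrix}$, applied to $(\dot a, \dot b)$.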

The input_tangents closure provides access to forward-propagated tangents for each input node, enabling deferred tangent injection without storing tangents in the rule struct.
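A sketch of how a rule can read tangents through the closure while storing only primal values and node IDs. It uses a hypothetical scalar-only `MulRule` with an inherent method that mirrors `pullback_with_tangents` (no `AdResult`, tangents as `f64`); treating a missing tangent as zero is an assumption. Seeding the output with cotangent 1 and cotangent tangent 0 makes the returned tangents a Hessian-vector product of f(a, b) = a * b.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct NodeId(usize);

/// Hypothetical scalar rule for `output = a * b`. It stores primal values and
/// node IDs only; input tangents are looked up through the closure on demand.
struct MulRule {
    a: f64,
    b: f64,
    a_node: NodeId,
    b_node: NodeId,
}

impl MulRule {
    /// Mirrors `pullback_with_tangents`, returning
    /// (node_id, input_cotangent, input_cotangent_tangent) triples.
    fn pullback_with_tangents<'t>(
        &self,
        cotangent: &f64,
        cotangent_tangent: &f64,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
    ) -> Vec<(NodeId, f64, f64)> {
        // Assumption of this sketch: a missing input tangent counts as zero.
        let da = input_tangents(self.a_node).copied().unwrap_or(0.0);
        let db = input_tangents(self.b_node).copied().unwrap_or(0.0);
        vec![
            // a_bar = z_bar * b; its tangent is z_bar_dot * b + z_bar * db.
            (self.a_node, cotangent * self.b, cotangent_tangent * self.b + cotangent * db),
            // b_bar = z_bar * a; its tangent is z_bar_dot * a + z_bar * da.
            (self.b_node, cotangent * self.a, cotangent_tangent * self.a + cotangent * da),
        ]
    }
}

fn main() {
    // f(a, b) = a * b at (3, 5), tangent direction (0.5, 2.0).
    let rule = MulRule { a: 3.0, b: 5.0, a_node: NodeId(0), b_node: NodeId(1) };
    let tangents: Vec<Option<f64>> = vec![Some(0.5), Some(2.0)];
    // Seed the output: cotangent 1, and 0 for the tangent of that cotangent.
    let out = rule.pullback_with_tangents(&1.0, &0.0, &|id: NodeId| tangents[id.0].as_ref());
    for (node, grad, grad_tangent) in &out {
        println!("{node:?}: grad = {grad}, grad_tangent = {grad_tangent}");
    }
}
```

At (a, b) = (3, 5) with direction (0.5, 2.0), the gradients are (5, 3) and the cotangent tangents are (2.0, 0.5), matching the Hessian [[0, 1], [1, 0]] applied to the direction.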

The default implementation returns Err(AutodiffError::HvpNotSupported). Operations that support forward-over-reverse HVP override this method.
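To show how the two provided methods could fit together, here is one plausible forward-over-reverse driver. It is not this crate's hvp(): the tape layout (`rules[i]` produced node `i`, `None` for leaves, last node is the scalar output), the scalar `MulRule`, and the `hvp` function are all hypothetical. A forward sweep fills in tangents with `forward_tangents`, then a reverse sweep calls `pullback_with_tangents` with those tangents and accumulates both columns of each triple.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct NodeId(usize);

/// Hypothetical scalar rule for `output = a * b`; stores primal inputs only.
struct MulRule {
    a: f64,
    b: f64,
    a_node: NodeId,
    b_node: NodeId,
}

impl MulRule {
    /// Output tangent `da * b + a * db`; a missing input tangent counts as zero.
    fn forward_tangents<'t>(&self, tangents: &dyn Fn(NodeId) -> Option<&'t f64>) -> f64 {
        let da = tangents(self.a_node).copied().unwrap_or(0.0);
        let db = tangents(self.b_node).copied().unwrap_or(0.0);
        da * self.b + self.a * db
    }

    /// Returns (node_id, input_cotangent, input_cotangent_tangent) triples.
    fn pullback_with_tangents<'t>(
        &self,
        cot: f64,
        cot_tangent: f64,
        tangents: &dyn Fn(NodeId) -> Option<&'t f64>,
    ) -> Vec<(NodeId, f64, f64)> {
        let da = tangents(self.a_node).copied().unwrap_or(0.0);
        let db = tangents(self.b_node).copied().unwrap_or(0.0);
        vec![
            (self.a_node, cot * self.b, cot_tangent * self.b + cot * db),
            (self.b_node, cot * self.a, cot_tangent * self.a + cot * da),
        ]
    }
}

/// Hypothetical forward-over-reverse driver. `v[i]` is the direction component
/// for leaf `i`. Returns (gradient, H·v), indexed by node.
fn hvp(rules: &[Option<MulRule>], v: &[Option<f64>]) -> (Vec<f64>, Vec<f64>) {
    let n = rules.len();

    // Forward sweep: leaves take their tangent from `v`, other nodes from forward_tangents.
    let mut tangents: Vec<Option<f64>> = vec![None; n];
    for i in 0..n {
        let t = match &rules[i] {
            None => v[i],
            Some(rule) => Some(rule.forward_tangents(&|id: NodeId| tangents[id.0].as_ref())),
        };
        tangents[i] = t;
    }

    // Reverse sweep: seed the output with cotangent 1 and cotangent tangent 0,
    // then sum every triple into per-node accumulators.
    let mut grad = vec![0.0; n];
    let mut grad_tangent = vec![0.0; n];
    grad[n - 1] = 1.0;
    for i in (0..n).rev() {
        if let Some(rule) = &rules[i] {
            let entries = rule.pullback_with_tangents(grad[i], grad_tangent[i], &|id: NodeId| {
                tangents[id.0].as_ref()
            });
            for (node, g, gt) in entries {
                grad[node.0] += g;
                grad_tangent[node.0] += gt;
            }
        }
    }
    (grad, grad_tangent)
}

fn main() {
    // f(x, y) = (x * y) * x at x = 3, y = 5. Nodes: 0 = x, 1 = y,
    // 2 = x * y (value 15), 3 = node2 * x. Direction v = (1, 0).
    let rules = vec![
        None,
        None,
        Some(MulRule { a: 3.0, b: 5.0, a_node: NodeId(0), b_node: NodeId(1) }),
        Some(MulRule { a: 15.0, b: 3.0, a_node: NodeId(2), b_node: NodeId(0) }),
    ];
    let v = [Some(1.0), Some(0.0), None, None];
    let (grad, hv) = hvp(&rules, &v);
    println!("grad = ({}, {}), Hv = ({}, {})", grad[0], grad[1], hv[0], hv[1]);
}
```

For f(x, y) = x^2 * y at (3, 5) with direction (1, 0), this yields the gradient (30, 9) and H·v = (10, 6), matching the analytic values (2xy, x^2) and (2y, 2x).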

Examples
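use chainrules_core::{AdResult, Differentiable, NodeId, ReverseRule};

// Called internally by hvp(); users rarely call this directly.
// Here `tangents_vec[i]` holds the forward-propagated tangent of node i, if any.
fn visit_entries<V: Differentiable>(rule: &dyn ReverseRule<V>, cotangent: &V::Tangent,
    cotangent_tangent: &V::Tangent, tangents_vec: &[Option<V::Tangent>]) -> AdResult<()> {
    let results = rule.pullback_with_tangents(
        cotangent, cotangent_tangent, &|node: NodeId| tangents_vec[node.index()].as_ref(),
    )?;
    for (_node_id, _grad, _grad_tangent) in results {
        // _grad: standard cotangent for this input
        // _grad_tangent: tangent of that cotangent, used for the HVP
    }
    Ok(())
}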
