Trait ReverseRule 

pub trait ReverseRule<V>
where
    V: Differentiable,
{
    // Required methods
    fn pullback(
        &self,
        cotangent: &<V as Differentiable>::Tangent,
    ) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent)>, AutodiffError>;
    fn inputs(&self) -> Vec<NodeId>;

    // Provided method
    fn pullback_with_tangents(
        &self,
        cotangent: &<V as Differentiable>::Tangent,
        cotangent_tangent: &<V as Differentiable>::Tangent,
    ) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent, <V as Differentiable>::Tangent)>, AutodiffError> { ... }
}

Reverse-mode AD rule interface (rrule).

Implemented by operation-specific nodes (einsum, reduce, permute, …). The name follows the convention of Julia’s ChainRules.jl, where rrule returns the primal result together with a pullback function.
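To illustrate the convention the name comes from: in ChainRules.jl, rrule(f, x) returns the primal output together with a pullback closure. The sketch below transliterates that idea into Rust for sin. rrule_sin is a hypothetical free function, not part of this crate, which instead exposes the pullback as a method on an operation node implementing ReverseRule.

```rust
/// Hypothetical ChainRules.jl-style rrule for sin, written in Rust:
/// it returns the primal output together with a pullback closure.
fn rrule_sin(x: f64) -> (f64, impl Fn(f64) -> f64) {
    let y = x.sin();
    // The pullback captures what it needs from the forward pass (here `x`)
    // and maps an output cotangent `ybar` to the input cotangent.
    let pullback = move |ybar: f64| ybar * x.cos();
    (y, pullback)
}

fn main() {
    let (y, pb) = rrule_sin(0.0);
    println!("y = {y}, xbar = {}", pb(1.0));
}
```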

The type parameter V is the differentiable value type (e.g., Tensor<f64>).

Examples

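use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};

// A rule with no inputs. A real rule would typically store its input
// NodeIds (and any primal values its pullback needs) from the forward pass.
struct MyRule;

impl<V: Differentiable> ReverseRule<V> for MyRule {
    fn pullback(&self, _cotangent: &V::Tangent)
        -> AdResult<Vec<(NodeId, V::Tangent)>> {
        // Map the output cotangent to (input NodeId, input cotangent) pairs.
        todo!()
    }
    fn inputs(&self) -> Vec<NodeId> { vec![] }
}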
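The example above leaves pullback as todo!(). As a filled-in sketch, the block below implements a rule for scalar multiplication z = x * y. Because chainrules_core is not assumed to be available here, NodeId, AutodiffError, AdResult, Differentiable and ReverseRule are simplified stand-ins that only mirror the signatures shown on this page (the real Differentiable may have more items, and V is fixed to f64). MulRule is hypothetical.

```rust
// Simplified stand-ins mirroring this page's signatures (assumptions, not the crate's code).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

#[derive(Debug)]
pub enum AutodiffError {
    HvpNotSupported,
}

pub type AdResult<T> = Result<T, AutodiffError>;

pub trait Differentiable {
    type Tangent;
}

impl Differentiable for f64 {
    type Tangent = f64;
}

pub trait ReverseRule<V: Differentiable> {
    fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<(NodeId, V::Tangent)>>;
    fn inputs(&self) -> Vec<NodeId>;
}

/// Hypothetical rule for z = x * y, holding the primal inputs saved in the forward pass.
struct MulRule {
    x: (NodeId, f64),
    y: (NodeId, f64),
}

impl ReverseRule<f64> for MulRule {
    fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
        // dz/dx = y and dz/dy = x, so each input cotangent is zbar times the other factor.
        Ok(vec![
            (self.x.0, cotangent * self.y.1),
            (self.y.0, cotangent * self.x.1),
        ])
    }

    fn inputs(&self) -> Vec<NodeId> {
        vec![self.x.0, self.y.0]
    }
}

fn main() {
    let rule = MulRule { x: (NodeId(0), 3.0), y: (NodeId(1), 4.0) };
    let grads = rule.pullback(&1.0).unwrap();
    println!("{grads:?}");
}
```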

Required Methods

fn pullback( &self, cotangent: &<V as Differentiable>::Tangent, ) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent)>, AutodiffError>

Computes the pullback: given the cotangent of this rule's output, returns a list of (input node ID, input cotangent) pairs.
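In the usual mathematical reading (a standard fact about reverse-mode AD rather than something this page states), the pullback of a node y = f(x_1, …, x_n) is a vector-Jacobian product: each input cotangent is the transposed Jacobian of f with respect to that input, applied to the output cotangent. For example, for z = x·y with output cotangent z̄, this gives x̄ = y·z̄ and ȳ = x·z̄.

```latex
\bar{x}_i = \left( \frac{\partial f}{\partial x_i} \right)^{\top} \bar{y}, \qquad i = 1, \dots, n
```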

fn inputs(&self) -> Vec<NodeId>

Returns the IDs of the input nodes this rule depends on.
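inputs() and pullback() together are enough to drive a backward pass. The sketch below uses simplified stand-ins for the crate's types (values fixed to f64) plus a hypothetical MulRule and backward function. It orders the graph with a depth-first walk over inputs(), then visits nodes in reverse so that each node's cotangent is fully accumulated before its pullback runs. It illustrates the general technique only, not this crate's actual tape or hvp() machinery.

```rust
use std::collections::{HashMap, HashSet};

// Simplified stand-ins mirroring this page's signatures (assumptions, not the crate's code).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

#[derive(Debug)]
pub enum AutodiffError {
    HvpNotSupported,
}

pub trait Differentiable {
    type Tangent;
}

impl Differentiable for f64 {
    type Tangent = f64;
}

pub trait ReverseRule<V: Differentiable> {
    fn pullback(&self, cotangent: &V::Tangent) -> Result<Vec<(NodeId, V::Tangent)>, AutodiffError>;
    fn inputs(&self) -> Vec<NodeId>;
}

/// Hypothetical rule for a * b, storing (node, primal value) for each input.
struct MulRule { a: (NodeId, f64), b: (NodeId, f64) }

impl ReverseRule<f64> for MulRule {
    fn pullback(&self, ct: &f64) -> Result<Vec<(NodeId, f64)>, AutodiffError> {
        Ok(vec![(self.a.0, ct * self.b.1), (self.b.0, ct * self.a.1)])
    }
    fn inputs(&self) -> Vec<NodeId> { vec![self.a.0, self.b.0] }
}

type Rules = HashMap<NodeId, Box<dyn ReverseRule<f64>>>;

/// Hypothetical backward pass: `rules` maps each non-leaf node to its rule.
fn backward(rules: &Rules, output: NodeId) -> Result<HashMap<NodeId, f64>, AutodiffError> {
    // Depth-first post-order over inputs(): every node lands after all of its inputs.
    fn visit(n: NodeId, rules: &Rules, seen: &mut HashSet<NodeId>, order: &mut Vec<NodeId>) {
        if !seen.insert(n) {
            return;
        }
        if let Some(rule) = rules.get(&n) {
            for input in rule.inputs() {
                visit(input, rules, seen, order);
            }
        }
        order.push(n);
    }
    let mut order = Vec::new();
    visit(output, rules, &mut HashSet::new(), &mut order);

    // Seed the output with cotangent 1 and walk in reverse topological order.
    let mut cot: HashMap<NodeId, f64> = HashMap::from([(output, 1.0)]);
    for n in order.into_iter().rev() {
        let Some(rule) = rules.get(&n) else { continue }; // leaves have no rule
        let ct = cot.get(&n).copied().unwrap_or(0.0);
        for (input, g) in rule.pullback(&ct)? {
            // Sum contributions when a node feeds several consumers.
            *cot.entry(input).or_insert(0.0) += g;
        }
    }
    Ok(cot)
}

fn main() {
    // w = (x * y) * x with x = 3, y = 4, so dw/dx = 2xy = 24 and dw/dy = x^2 = 9.
    let (x, y, z, w) = (NodeId(0), NodeId(1), NodeId(2), NodeId(3));
    let mut rules: Rules = HashMap::new();
    rules.insert(z, Box::new(MulRule { a: (x, 3.0), b: (y, 4.0) }));
    rules.insert(w, Box::new(MulRule { a: (z, 12.0), b: (x, 3.0) }));
    let grads = backward(&rules, w).unwrap();
    println!("dw/dx = {}, dw/dy = {}", grads[&x], grads[&y]);
}
```

Accumulating into a map keyed by NodeId is what handles fan-out here: x feeds both z and w, and its two contributions (12 from each) add up to 24.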

Provided Methods

fn pullback_with_tangents( &self, cotangent: &<V as Differentiable>::Tangent, cotangent_tangent: &<V as Differentiable>::Tangent, ) -> Result<Vec<(NodeId, <V as Differentiable>::Tangent, <V as Differentiable>::Tangent)>, AutodiffError>

Computes the pullback with tangent propagation, for Hessian-vector products (HVP).

Given an output cotangent and its tangent, returns (node_id, input_cotangent, input_cotangent_tangent) triples.

The default implementation returns AutodiffError::HvpNotSupported. Operations that support forward-over-reverse HVP override this method.
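To illustrate the default-versus-override behaviour, the sketch below declares pullback_with_tangents as a provided method that returns HvpNotSupported, then overrides it for a hypothetical ScaleRule (z = c * x with constant c). Because that pullback is linear in the cotangent and does not depend on x, the tangent of the input cotangent is simply c times the cotangent tangent. A hypothetical SquareRule keeps the default. All types are simplified stand-ins that mirror this page's signatures, not chainrules_core's definitions.

```rust
// Simplified stand-ins mirroring this page's signatures (assumptions, not the crate's code).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

#[derive(Debug, PartialEq)]
pub enum AutodiffError {
    HvpNotSupported,
}

pub trait Differentiable { type Tangent; }
impl Differentiable for f64 { type Tangent = f64; }

pub trait ReverseRule<V: Differentiable> {
    fn pullback(&self, cotangent: &V::Tangent) -> Result<Vec<(NodeId, V::Tangent)>, AutodiffError>;
    fn inputs(&self) -> Vec<NodeId>;

    // Provided method: by default a rule cannot propagate tangents through its pullback.
    fn pullback_with_tangents(
        &self,
        _cotangent: &V::Tangent,
        _cotangent_tangent: &V::Tangent,
    ) -> Result<Vec<(NodeId, V::Tangent, V::Tangent)>, AutodiffError> {
        Err(AutodiffError::HvpNotSupported)
    }
}

/// Hypothetical rule for z = c * x with constant c. xbar = c * zbar is linear in zbar
/// and independent of x, so the tangent of xbar is c * (tangent of zbar).
struct ScaleRule { x: NodeId, c: f64 }

impl ReverseRule<f64> for ScaleRule {
    fn pullback(&self, ct: &f64) -> Result<Vec<(NodeId, f64)>, AutodiffError> {
        Ok(vec![(self.x, self.c * ct)])
    }
    fn inputs(&self) -> Vec<NodeId> { vec![self.x] }
    fn pullback_with_tangents(&self, ct: &f64, ct_tan: &f64)
        -> Result<Vec<(NodeId, f64, f64)>, AutodiffError> {
        Ok(vec![(self.x, self.c * ct, self.c * ct_tan)])
    }
}

/// Hypothetical rule for z = x^2. A forward-over-reverse version would also need the
/// tangent of x, since d(2 x zbar) = 2 dx zbar + 2 x dzbar; this sketch does not store it,
/// so it keeps the default pullback_with_tangents.
struct SquareRule { x: (NodeId, f64) }

impl ReverseRule<f64> for SquareRule {
    fn pullback(&self, ct: &f64) -> Result<Vec<(NodeId, f64)>, AutodiffError> {
        Ok(vec![(self.x.0, 2.0 * self.x.1 * ct)])
    }
    fn inputs(&self) -> Vec<NodeId> { vec![self.x.0] }
}

fn main() {
    let scale = ScaleRule { x: NodeId(0), c: 2.0 };
    let square = SquareRule { x: (NodeId(1), 3.0) };
    println!("{:?}", scale.pullback_with_tangents(&1.0, &0.5));
    println!("{:?}", square.pullback_with_tangents(&1.0, &0.5));
}
```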

Examples
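// Called internally by hvp(); users rarely call this directly.
// Assumes `rule`, `cotangent`, and `cotangent_tangent` are already in scope.
let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?;
for (node_id, grad, grad_tangent) in results {
    // node_id: the input node this entry refers to
    // grad: standard cotangent for this input
    // grad_tangent: tangent of that cotangent, used for HVP
}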

Implementors