pub trait ReverseRule<V: Differentiable> {
// Required methods
fn pullback(
&self,
cotangent: &V::Tangent,
) -> AdResult<Vec<(NodeId, V::Tangent)>>;
fn inputs(&self) -> Vec<NodeId>;
// Provided method
fn pullback_with_tangents(
&self,
cotangent: &V::Tangent,
cotangent_tangent: &V::Tangent,
) -> AdResult<Vec<(NodeId, V::Tangent, V::Tangent)>> { ... }
}
Reverse-mode AD rule interface (rrule).
Implemented by operation-specific nodes (einsum, reduce, permute, …).
Named after Julia’s ChainRules.jl convention: rrule returns a pullback.
The type parameter V is the differentiable value type (e.g., Tensor<f64>).
Examples
(Note: this example is illustrative and is not run as a doctest.)
use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
struct MyRule;
impl<V: Differentiable> ReverseRule<V> for MyRule {
fn pullback(&self, cotangent: &V::Tangent)
-> AdResult<Vec<(NodeId, V::Tangent)>> {
todo!()
}
fn inputs(&self) -> Vec<NodeId> { vec![] }
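}
Required Methods
fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<(NodeId, V::Tangent)>>
fn inputs(&self) -> Vec<NodeId>
Provided Methods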
fn pullback_with_tangents(
&self,
cotangent: &V::Tangent,
cotangent_tangent: &V::Tangent,
) -> AdResult<Vec<(NodeId, V::Tangent, V::Tangent)>>
Computes the pullback together with tangent propagation, for Hessian-vector products (HVP).
Given an output cotangent and its tangent, returns
(node_id, input_cotangent, input_cotangent_tangent) triples.
The default implementation returns an AutodiffError::HvpNotSupported error.
Operations that support forward-over-reverse HVP override this method.
Examples
(Note: this example is illustrative and is not run as a doctest.)
// Called internally by hvp(); users rarely call this directly.
let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?;
for (node_id, grad, grad_tangent) in results {
// grad: standard cotangent for this input
// grad_tangent: cotangent tangent for HVP
}