Trait ForwardRule
pub trait ForwardRule<V>: Send + Sync
where
    V: Differentiable,
{
// Required method
fn pushforward(
&self,
tangents: &[Option<&<V as Differentiable>::Tangent>],
) -> Result<<V as Differentiable>::Tangent, AutodiffError>;
}
Expand description
Forward-mode AD rule interface (frule).
Named after Julia’s ChainRules.jl convention: frule computes pushforward.
The type parameter V is the differentiable value type (e.g., Tensor<f64>).
§Examples
Custom forward rule for the scalar multiplication `output = a * b`:
use chainrules_core::{ForwardRule, Differentiable, AdResult};
/// Forward rule for the scalar product `a * b`, holding the primal values of
/// both factors.
struct ScalarMulFrule {
    /// Primal value of the first factor.
    a: f64,
    /// Primal value of the second factor.
    b: f64,
}

impl ForwardRule<f64> for ScalarMulFrule {
    /// Applies the product rule `d(a*b) = da*b + a*db`.
    ///
    /// `tangents[0]` is read as `da` and `tangents[1]` as `db`. A slot that is
    /// absent from the slice, or present but `None`, contributes a zero tangent.
    fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
        // Resolve one input tangent, falling back to zero when it is missing.
        let tangent_or_zero = |slot: usize| -> f64 {
            match tangents.get(slot) {
                Some(Some(value)) => **value,
                _ => 0.0,
            }
        };
        let (da, db) = (tangent_or_zero(0), tangent_or_zero(1));
        Ok(da * self.b + self.a * db)
    }
}
// Verify: for a=3, b=5, da=1, db=0 → d(a*b) = 1*5 + 3*0 = 5
let rule = ScalarMulFrule { a: 3.0, b: 5.0 };
let result = rule.pushforward(&[Some(&1.0), Some(&0.0)]).unwrap();
assert_eq!(result, 5.0);
// Both tangents active: da=1, db=1 → d(a*b) = 1*5 + 3*1 = 8
let result = rule.pushforward(&[Some(&1.0), Some(&1.0)]).unwrap();
assert_eq!(result, 8.0);
Required Methods§
fn pushforward(
&self,
tangents: &[Option<&<V as Differentiable>::Tangent>],
) -> Result<<V as Differentiable>::Tangent, AutodiffError>
fn pushforward( &self, tangents: &[Option<&<V as Differentiable>::Tangent>], ) -> Result<<V as Differentiable>::Tangent, AutodiffError>
Computes output tangent from input tangents (pushforward).