ForwardRule

Trait ForwardRule 

pub trait ForwardRule<V: Differentiable>: Send + Sync {
    // Required method
    fn pushforward(
        &self,
        tangents: &[Option<&V::Tangent>],
    ) -> AdResult<V::Tangent>;
}

Forward-mode AD rule interface (frule).

The name follows Julia’s ChainRules.jl convention, in which an frule computes the pushforward (a Jacobian-vector product).
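In symbols, suppose a rule's output is y = f(x_1, …, x_n). Its pushforward combines the input tangents through the partial derivatives of f, which is the Jacobian-vector product. For the scalar product f(a, b) = a b, this yields the formula used in the example. The dot notation for tangents is introduced here only for illustration. An input whose tangent is None contributes nothing, which matches the example's treatment of a missing tangent as zero.

```latex
\dot{y} \;=\; \sum_{i=1}^{n} \frac{\partial f}{\partial x_i}\,\dot{x}_i \;=\; J_f(x)\,\dot{x},
\qquad
f(a, b) = a\,b \;\Rightarrow\; \dot{y} = b\,\dot{a} + a\,\dot{b}.
```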

The type parameter V is the differentiable value type (e.g., Tensor<f64>).
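To experiment with the trait outside the crate, it can be mirrored with small local stand-ins. This sketch makes several assumptions. `Differentiable`, `AdError`, and `AdResult` are hypothetical definitions that only reproduce the signature shown above; the crate's real definitions are not on this page. `Vec<f64>` stands in for a tensor type such as Tensor<f64>. The hypothetical `ElemMulFrule` applies the product rule elementwise and reports a tangent of the wrong length as an error.

```rust
/// Hypothetical stand-in for the crate's `Differentiable`: only the
/// associated `Tangent` type used by `ForwardRule` is modeled.
pub trait Differentiable {
    type Tangent;
}

/// `Vec<f64>` plays the role of a tensor; its tangent is another `Vec<f64>`.
impl Differentiable for Vec<f64> {
    type Tangent = Vec<f64>;
}

/// Hypothetical error type and result alias (assumed, not the crate's).
#[derive(Debug, PartialEq)]
pub struct AdError(pub String);
pub type AdResult<T> = Result<T, AdError>;

/// Same shape as the trait documented on this page.
pub trait ForwardRule<V: Differentiable>: Send + Sync {
    fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
}

/// Hypothetical forward rule for the elementwise product out[i] = a[i] * b[i].
/// Assumes `a` and `b` have the same length.
pub struct ElemMulFrule {
    pub a: Vec<f64>,
    pub b: Vec<f64>,
}

impl ForwardRule<Vec<f64>> for ElemMulFrule {
    fn pushforward(&self, tangents: &[Option<&Vec<f64>>]) -> AdResult<Vec<f64>> {
        let n = self.a.len();
        let mut out = vec![0.0; n];
        // Product rule per element: d(a * b) = da * b + a * db.
        // Pair tangent slot 0 (da) with b and slot 1 (db) with a.
        let terms = [(tangents.first(), &self.b), (tangents.get(1), &self.a)];
        for (slot, other) in terms {
            // A missing slot or a None tangent contributes nothing (acts as zero).
            if let Some(d) = slot.copied().flatten() {
                if d.len() != n {
                    return Err(AdError(format!("tangent length {} != {}", d.len(), n)));
                }
                for ((o, di), oi) in out.iter_mut().zip(d).zip(other) {
                    *o += di * oi;
                }
            }
        }
        Ok(out)
    }
}
```

Given a = [1, 2], b = [3, 4], da = [1, 0], and db = [0, 1], the rule returns [3, 2]. If db is passed as None, the result is [3, 0].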

Examples

A custom forward rule for scalar multiplication, output = a * b:

use chainrules_core::{AdResult, ForwardRule};

struct ScalarMulFrule {
    a: f64,
    b: f64,
}

impl ForwardRule<f64> for ScalarMulFrule {
    fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
        // d(a*b) = da*b + a*db
        // A missing (None) tangent is treated as zero.
        let da = tangents.first().and_then(|t| *t).copied().unwrap_or(0.0);
        let db = tangents.get(1).and_then(|t| *t).copied().unwrap_or(0.0);
        Ok(da * self.b + self.a * db)
    }
}

// Verify: for a=3, b=5, da=1, db=0 → d(a*b) = 1*5 + 3*0 = 5
let rule = ScalarMulFrule { a: 3.0, b: 5.0 };
let result = rule.pushforward(&[Some(&1.0), Some(&0.0)]).unwrap();
assert_eq!(result, 5.0);

// Both tangents active: da=1, db=1 → d(a*b) = 1*5 + 3*1 = 8
let result = rule.pushforward(&[Some(&1.0), Some(&1.0)]).unwrap();
assert_eq!(result, 8.0);
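One way to check a hand-written rule such as ScalarMulFrule is to compare its pushforward with a central finite difference of the primal function along the same direction. The sketch below re-declares hypothetical stand-ins for `Differentiable`, `AdError`, and `AdResult` so that it compiles without the crate. `check_frule` is an illustrative helper, not part of the crate, and the step `h` and tolerance `tol` are arbitrary choices.

```rust
/// Hypothetical stand-ins mirroring the trait signature on this page.
pub trait Differentiable {
    type Tangent;
}
impl Differentiable for f64 {
    type Tangent = f64;
}
#[derive(Debug, PartialEq)]
pub struct AdError(pub String);
pub type AdResult<T> = Result<T, AdError>;
pub trait ForwardRule<V: Differentiable>: Send + Sync {
    fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
}

/// Same rule as the example above: d(a*b) = da*b + a*db.
pub struct ScalarMulFrule {
    pub a: f64,
    pub b: f64,
}

impl ForwardRule<f64> for ScalarMulFrule {
    fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
        let da = tangents.first().copied().flatten().copied().unwrap_or(0.0);
        let db = tangents.get(1).copied().flatten().copied().unwrap_or(0.0);
        Ok(da * self.b + self.a * db)
    }
}

/// Hypothetical helper: compares a scalar rule's pushforward with the central
/// difference (f(x + h*dx) - f(x - h*dx)) / (2h) of the primal `f`.
pub fn check_frule<R, F>(rule: &R, f: F, x: &[f64], dx: &[f64], h: f64, tol: f64) -> AdResult<bool>
where
    R: ForwardRule<f64>,
    F: Fn(&[f64]) -> f64,
{
    // Every input gets an explicit tangent taken from `dx`.
    let tangents: Vec<Option<&f64>> = dx.iter().map(Some).collect();
    let analytic = rule.pushforward(&tangents)?;
    let plus: Vec<f64> = x.iter().zip(dx).map(|(xi, di)| xi + h * di).collect();
    let minus: Vec<f64> = x.iter().zip(dx).map(|(xi, di)| xi - h * di).collect();
    let numeric = (f(plus.as_slice()) - f(minus.as_slice())) / (2.0 * h);
    // Relative-plus-absolute tolerance on the difference.
    Ok((analytic - numeric).abs() <= tol * (1.0 + numeric.abs()))
}
```

For a = 3 and b = 5 with both tangents set to 1, the rule and the finite difference both give 8. A rule constructed with the wrong b would fail the check.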

Required Methods


fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>

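Computes the output tangent from the input tangents (the pushforward). Each entry of tangents corresponds to one input, and the example above treats a None entry as a zero tangent.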
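Because pushforward maps input tangents to an output tangent, forward rules compose: one rule's output tangent becomes an input tangent of the next rule. The sketch below differentiates z = sin(x) * w this way. As before, `Differentiable`, `AdError`, and `AdResult` are hypothetical local stand-ins, and `SinFrule` and `chained_tangent` are illustrative names, not crate items. A None tangent is treated as zero, as in the example above.

```rust
/// Hypothetical stand-ins mirroring the trait signature on this page.
pub trait Differentiable {
    type Tangent;
}
impl Differentiable for f64 {
    type Tangent = f64;
}
#[derive(Debug, PartialEq)]
pub struct AdError(pub String);
pub type AdResult<T> = Result<T, AdError>;
pub trait ForwardRule<V: Differentiable>: Send + Sync {
    fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
}

/// Hypothetical rule for y = sin(x): dy = cos(x) * dx.
pub struct SinFrule {
    pub x: f64,
}
impl ForwardRule<f64> for SinFrule {
    fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
        let dx = tangents.first().copied().flatten().copied().unwrap_or(0.0);
        Ok(self.x.cos() * dx)
    }
}

/// Same rule as the example above: d(a*b) = da*b + a*db.
pub struct ScalarMulFrule {
    pub a: f64,
    pub b: f64,
}
impl ForwardRule<f64> for ScalarMulFrule {
    fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
        let da = tangents.first().copied().flatten().copied().unwrap_or(0.0);
        let db = tangents.get(1).copied().flatten().copied().unwrap_or(0.0);
        Ok(da * self.b + self.a * db)
    }
}

/// Tangent of z = sin(x) * w: the output tangent of the sin rule is passed
/// as the first input tangent of the multiplication rule.
pub fn chained_tangent(x: f64, w: f64, dx: Option<&f64>, dw: Option<&f64>) -> AdResult<f64> {
    let y = x.sin(); // primal value of the intermediate
    let sin_rule = SinFrule { x };
    let dy = sin_rule.pushforward(&[dx])?;
    let mul_rule = ScalarMulFrule { a: y, b: w };
    mul_rule.pushforward(&[Some(&dy), dw])
}
```

For x = 1 and w = 3 with dx = dw = 1, the result is cos(1) * 3 + sin(1), which is the directional derivative of sin(x) * w along that direction.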
