OpRule

Trait OpRule

Source
/// Operation-level AD rules for a single primitive operation.
///
/// An implementor provides primal evaluation (`eval`) together with its
/// reverse-mode (`rrule`), forward-mode (`frule`) and second-order (`hvp`)
/// derivative rules. Values flow as `V`; tangents and cotangents are carried
/// as the associated primal type `V::Primal`.
pub trait OpRule<V>
where
    V: Differentiable,
{
    /// Computes the primal output of the operation from `inputs`.
    fn eval(&self, inputs: &[&V]) -> Result<V>;

    /// Reverse-mode pullback.
    ///
    /// Takes the operation's `inputs`, its primal output `out` (presumably the
    /// value produced by `eval` — confirm against implementors) and the output
    /// `cotangent`, and returns a vector of input cotangents.
    fn rrule(&self, inputs: &[&V], out: &V, cotangent: &V::Primal) -> AdResult<Vec<V::Primal>>;

    /// Forward-mode pushforward.
    ///
    /// Takes the operation's `inputs` and one optional tangent per input slot,
    /// and returns the tangent of the output.
    fn frule(&self, inputs: &[&V], tangents: &[Option<&V::Primal>]) -> AdResult<V::Primal>;

    /// Hessian-vector product.
    ///
    /// Takes the `inputs`, the output `cotangent`, an optional tangent of that
    /// cotangent, and optional input tangents. It returns a vector of pairs of
    /// primal values; the doc example builds each pair as
    /// `(cotangent, cotangent_tangent)`.
    fn hvp(
        &self,
        inputs: &[&V],
        cotangent: &V::Primal,
        cotangent_tangent: Option<&V::Primal>,
        input_tangents: &[Option<&V::Primal>],
    ) -> AdResult<Vec<(V::Primal, V::Primal)>>;
}
Expand description

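Operation-level automatic differentiation (AD) rules. Each implementor provides primal evaluation (`eval`), a reverse-mode pullback (`rrule`), a forward-mode pushforward (`frule`), and a Hessian-vector product (`hvp`).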

Examples

use ad_tensors_rs::{AdResult, AdValue, Differentiable, OpRule, Result};

/// Example rule for the identity operation `y = x`.
///
/// The output is the first input. Each derivative rule passes its incoming
/// (co)tangent through unchanged, treating a missing tangent as `0.0`.
struct IdentityRule;

impl OpRule<AdValue<f64>> for IdentityRule {
    /// Returns a clone of the first input (panics if `inputs` is empty).
    fn eval(&self, inputs: &[&AdValue<f64>]) -> Result<AdValue<f64>> {
        let first: &AdValue<f64> = inputs[0];
        Ok(first.clone())
    }

    /// The Jacobian of the identity is 1, so the cotangent is forwarded as-is.
    fn rrule(
        &self,
        _inputs: &[&AdValue<f64>],
        _out: &AdValue<f64>,
        cotangent: &f64,
    ) -> AdResult<Vec<f64>> {
        let passthrough = *cotangent;
        Ok(vec![passthrough])
    }

    /// Forwards the first input's tangent, or `0.0` when it is absent
    /// (panics if `tangents` is empty).
    fn frule(&self, _inputs: &[&AdValue<f64>], tangents: &[Option<&f64>]) -> AdResult<f64> {
        let tangent = match tangents[0] {
            Some(t) => *t,
            None => 0.0,
        };
        Ok(tangent)
    }

    /// Returns `(cotangent, tangent of the cotangent)`. The second entry is
    /// `0.0` when no cotangent tangent is supplied.
    fn hvp(
        &self,
        _inputs: &[&AdValue<f64>],
        cotangent: &f64,
        cotangent_tangent: Option<&f64>,
        _input_tangents: &[Option<&f64>],
    ) -> AdResult<Vec<(f64, f64)>> {
        let second_order = cotangent_tangent.map_or(0.0, |dt| *dt);
        Ok(vec![(*cotangent, second_order)])
    }
}

// Run the identity rule on a plain primal value; the output's primal must
// equal the input's.
let rule = IdentityRule;
let input = AdValue::primal(2.0_f64);
let output = rule.eval(&[&input]).unwrap();
assert_eq!(*output.primal_ref(), 2.0);

Required Methods

Source

fn eval(&self, inputs: &[&V]) -> Result<V>

Compute primal output.

Source

fn rrule( &self, inputs: &[&V], out: &V, cotangent: &V::Primal, ) -> AdResult<Vec<V::Primal>>

Reverse-mode pullback: maps the output cotangent back to cotangents for the inputs.

Source

fn frule( &self, inputs: &[&V], tangents: &[Option<&V::Primal>], ) -> AdResult<V::Primal>

Forward-mode pushforward: maps the optional input tangents to the tangent of the output.

Source

fn hvp( &self, inputs: &[&V], cotangent: &V::Primal, cotangent_tangent: Option<&V::Primal>, input_tangents: &[Option<&V::Primal>], ) -> AdResult<Vec<(V::Primal, V::Primal)>>

Hessian-vector product.

Implementors