pub trait OpRule<V: Differentiable> {
// Required methods
fn eval(&self, inputs: &[&V]) -> Result<V>;
fn rrule(
&self,
inputs: &[&V],
out: &V,
cotangent: &V::Primal,
) -> AdResult<Vec<V::Primal>>;
fn frule(
&self,
inputs: &[&V],
tangents: &[Option<&V::Primal>],
) -> AdResult<V::Primal>;
fn hvp(
&self,
inputs: &[&V],
cotangent: &V::Primal,
cotangent_tangent: Option<&V::Primal>,
input_tangents: &[Option<&V::Primal>],
) -> AdResult<Vec<(V::Primal, V::Primal)>>;
}Expand description
Operation-level automatic differentiation (AD) rules: primal evaluation (`eval`) plus the `rrule`, `frule`, and `hvp` derivative rules.
§Examples
use ad_tensors_rs::{AdResult, AdValue, Differentiable, OpRule, Result};
/// Identity operation. It returns its single input unchanged and passes
/// cotangents and tangents straight through in every derivative rule.
struct IdentityRule;

impl OpRule<AdValue<f64>> for IdentityRule {
    /// Returns a clone of the first (and only expected) input.
    fn eval(&self, inputs: &[&AdValue<f64>]) -> Result<AdValue<f64>> {
        let first: &AdValue<f64> = inputs[0];
        Ok(first.clone())
    }

    /// The derivative of the identity is 1, so the input cotangent is the
    /// output cotangent.
    fn rrule(
        &self,
        _ins: &[&AdValue<f64>],
        _output: &AdValue<f64>,
        cotangent: &f64,
    ) -> AdResult<Vec<f64>> {
        let grad = *cotangent;
        Ok(vec![grad])
    }

    /// The output tangent equals the input tangent; a missing tangent is read as 0.0.
    fn frule(
        &self,
        _ins: &[&AdValue<f64>],
        tangents: &[Option<&f64>],
    ) -> AdResult<f64> {
        let tangent = tangents[0].map_or(0.0, |t| *t);
        Ok(tangent)
    }

    /// Passes the cotangent and its tangent through unchanged; a missing
    /// cotangent tangent is read as 0.0.
    fn hvp(
        &self,
        _ins: &[&AdValue<f64>],
        cotangent: &f64,
        cotangent_tangent: Option<&f64>,
        _in_tangents: &[Option<&f64>],
    ) -> AdResult<Vec<(f64, f64)>> {
        let d_cotangent = match cotangent_tangent {
            Some(t) => *t,
            None => 0.0,
        };
        Ok(vec![(*cotangent, d_cotangent)])
    }
}
// Wrap 2.0 as an AdValue via `AdValue::primal`, then evaluate the identity rule on it.
let x = AdValue::primal(2.0_f64);
let rule = IdentityRule;
let y = rule.eval(&[&x]).unwrap();
// The identity leaves the primal unchanged, so the output's primal is still 2.0.
assert_eq!(y.primal_ref(), &2.0);Required Methods§
fn rrule(
&self,
inputs: &[&V],
out: &V,
cotangent: &V::Primal,
) -> AdResult<Vec<V::Primal>>
fn rrule( &self, inputs: &[&V], out: &V, cotangent: &V::Primal, ) -> AdResult<Vec<V::Primal>>
Reverse-mode pullback.