Skip to main content

TracedTensorAdExt

Trait TracedTensorAdExt 

Source
/// Automatic differentiation helpers for [`TracedTensor`].
///
/// Every method except [`checkpoint`](Self::checkpoint) borrows `self`
/// immutably and returns a new traced tensor, or `None` for the `_optional`
/// variants. `checkpoint` takes `&mut self` and rewrites the tensor in place.
///
/// This trait is not dyn compatible because `checkpoint` is generic over the
/// backend type `B`.
pub trait TracedTensorAdExt {
    // Required methods

    /// Gradient of a scalar output (`self`) with respect to the traced input `wrt`.
    ///
    /// For complex scalar outputs, the Hermitian-adjoint cotangent is returned.
    /// To compare seed-1 scalar gradients with the values from JAX's public
    /// `grad`, take the complex conjugate of this result.
    ///
    /// NOTE(review): this does not state what `grad` does when `self` does not
    /// depend on `wrt` (error vs. zero tensor); confirm. Use
    /// [`grad_optional`](Self::grad_optional) to detect that case explicitly.
    fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor>;
    /// Like [`grad`](Self::grad), but yields `None` when `wrt` is inactive,
    /// e.g. when `self` was computed without using `wrt`.
    fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>>;
    /// Evaluates `self` using `compiler` and `executor`, then replaces its graph
    /// with a concrete leaf.
    ///
    /// The previous graph is preserved so AD can still replay through it. After
    /// a successful call, the computed value is exposed via `attached_data()`.
    fn checkpoint<B: TensorBackend>(
        &mut self,
        compiler: &mut GraphCompiler,
        executor: &mut GraphExecutor<B>,
    ) -> Result<()>;
    /// Forward-mode Jacobian-vector product.
    ///
    /// Returns the directional derivative of `self` with respect to `wrt` in
    /// the direction `tangent`.
    fn jvp(
        &self,
        wrt: &TracedTensor,
        tangent: &TracedTensor,
    ) -> Result<TracedTensor>;
    /// Like [`jvp`](Self::jvp), but yields `None` when `wrt` is inactive,
    /// e.g. when `self` was computed without using `wrt`.
    fn jvp_optional(
        &self,
        wrt: &TracedTensor,
        tangent: &TracedTensor,
    ) -> Result<Option<TracedTensor>>;
    /// Reverse-mode vector-Jacobian product.
    ///
    /// Pulls `cotangent`, seeded on `self`, back to a cotangent for `wrt`.
    /// Complex cotangents follow tenferro's Hermitian real-inner-product
    /// convention, so non-real complex seeds need an explicit seed-convention
    /// comparison when matching JAX.
    fn vjp(
        &self,
        wrt: &TracedTensor,
        cotangent: &TracedTensor,
    ) -> Result<TracedTensor>;
    /// Like [`vjp`](Self::vjp), but yields `None` when `wrt` is inactive,
    /// e.g. when `self` was computed without using `wrt`.
    fn vjp_optional(
        &self,
        wrt: &TracedTensor,
        cotangent: &TracedTensor,
    ) -> Result<Option<TracedTensor>>;
}
Expand description

Automatic differentiation helpers for TracedTensor.

§Examples

use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
// `doubled` = 2·x is built from `x`, so a gradient must exist.
let doubled = x.scale_real(2.0);
assert!(doubled.grad_optional(&x).unwrap().is_some());

Required Methods§

Source

fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor>

Gradient of a scalar output with respect to a traced input.

For complex scalar outputs, tenferro returns the Hermitian-adjoint cotangent. To compare seed-1 scalar gradients with JAX’s public grad values, use the complex conjugate of this result. See https://tensor4all.org/tenferro-rs/guides/complex-ad.html.

§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};

/// Compiles `tensor` and runs it on a fresh CPU executor, panicking on any error.
fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
    let mut graph_compiler = GraphCompiler::new();
    let compiled = graph_compiler.compile(tensor).unwrap();
    let backend = CpuBackend::new();
    let mut cpu_executor = GraphExecutor::new(backend);
    cpu_executor.run(&compiled).unwrap()
}

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let square = (&x * &x).unwrap();
// d(x²)/dx = 2x, which is 6 at x = 3.
let gradient = square.grad(&x).unwrap();

let evaluated = eval(&gradient);
assert_eq!(evaluated.as_slice::<f64>().unwrap(), &[6.0]);
Source

fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>>

Like `grad`, but returns `None` when `wrt` is inactive, for example when the output was computed without using `wrt`.

§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
// Built only from `y`, so `x` is inactive here.
let y_squared = (&y * &y).unwrap();

let missing = y_squared.grad_optional(&x).unwrap();
assert!(missing.is_none());
Source

fn checkpoint<B: TensorBackend>( &mut self, compiler: &mut GraphCompiler, executor: &mut GraphExecutor<B>, ) -> Result<()>

Evaluate this tensor and replace its graph with a concrete leaf while preserving the previous graph for AD replay.

§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let mut squared = (&x * &x).unwrap();

// Evaluate `x * x` on the CPU; afterwards the tensor carries concrete data.
let mut compiler = GraphCompiler::new();
let mut executor = GraphExecutor::new(CpuBackend::new());
squared.checkpoint(&mut compiler, &mut executor).unwrap();

let data = squared.attached_data().unwrap();
assert_eq!(data.as_slice::<f64>().unwrap(), &[9.0]);
Source

fn jvp( &self, wrt: &TracedTensor, tangent: &TracedTensor, ) -> Result<TracedTensor>

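Forward-mode Jacobian-vector product: returns the directional derivative of this tensor with respect to `wrt` in the direction `tangent`.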

§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};

/// Compiles `tensor` and runs it on a fresh CPU executor, panicking on any error.
fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
    let mut graph_compiler = GraphCompiler::new();
    let compiled = graph_compiler.compile(tensor).unwrap();
    let backend = CpuBackend::new();
    let mut cpu_executor = GraphExecutor::new(backend);
    cpu_executor.run(&compiled).unwrap()
}

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let direction = TracedTensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap();
let square = (&x * &x).unwrap();
// Directional derivative of x²: 2x · direction = 2 · 3 · 2 = 12.
let pushed = square.jvp(&x, &direction).unwrap();

let evaluated = eval(&pushed);
assert_eq!(evaluated.as_slice::<f64>().unwrap(), &[12.0]);
Source

fn jvp_optional( &self, wrt: &TracedTensor, tangent: &TracedTensor, ) -> Result<Option<TracedTensor>>

Like `jvp`, but returns `None` when `wrt` is inactive, for example when the output was computed without using `wrt`.

§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
let direction = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
// Built only from `y`, so perturbing `x` has no effect.
let y_squared = (&y * &y).unwrap();

let missing = y_squared.jvp_optional(&x, &direction).unwrap();
assert!(missing.is_none());
Source

fn vjp( &self, wrt: &TracedTensor, cotangent: &TracedTensor, ) -> Result<TracedTensor>

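Reverse-mode vector-Jacobian product: pulls `cotangent`, seeded on this tensor, back through the graph and returns the corresponding cotangent for `wrt`.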

Complex cotangents use tenferro’s Hermitian real-inner-product convention. Non-real complex cotangent seeds therefore need an explicit seed-convention comparison when matching JAX. See https://tensor4all.org/tenferro-rs/guides/complex-ad.html.

§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};

/// Compiles `tensor` and runs it on a fresh CPU executor, panicking on any error.
fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
    let mut graph_compiler = GraphCompiler::new();
    let compiled = graph_compiler.compile(tensor).unwrap();
    let backend = CpuBackend::new();
    let mut cpu_executor = GraphExecutor::new(backend);
    cpu_executor.run(&compiled).unwrap()
}

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let seed = TracedTensor::from_vec_col_major(vec![], vec![0.5_f64]).unwrap();
let square = (&x * &x).unwrap();
// Pull the seed back through x²: seed · 2x = 0.5 · 2 · 3 = 3.
let pulled = square.vjp(&x, &seed).unwrap();

let evaluated = eval(&pulled);
assert_eq!(evaluated.as_slice::<f64>().unwrap(), &[3.0]);
Source

fn vjp_optional( &self, wrt: &TracedTensor, cotangent: &TracedTensor, ) -> Result<Option<TracedTensor>>

Like `vjp`, but returns `None` when `wrt` is inactive, for example when the output was computed without using `wrt`.

§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;

let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
let seed = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
// Built only from `y`, so nothing flows back to `x`.
let y_squared = (&y * &y).unwrap();

let missing = y_squared.vjp_optional(&x, &seed).unwrap();
assert!(missing.is_none());

Dyn Compatibility§

This trait is not dyn compatible because its `checkpoint` method is generic over the backend type `B: TensorBackend`.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementations on Foreign Types§

Source§

impl TracedTensorAdExt for TracedTensor

Source§

fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor>

Source§

fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>>

Source§

fn checkpoint<B: TensorBackend>( &mut self, compiler: &mut GraphCompiler, executor: &mut GraphExecutor<B>, ) -> Result<()>

Source§

fn jvp( &self, wrt: &TracedTensor, tangent: &TracedTensor, ) -> Result<TracedTensor>

Source§

fn jvp_optional( &self, wrt: &TracedTensor, tangent: &TracedTensor, ) -> Result<Option<TracedTensor>>

Source§

fn vjp( &self, wrt: &TracedTensor, cotangent: &TracedTensor, ) -> Result<TracedTensor>

Source§

fn vjp_optional( &self, wrt: &TracedTensor, cotangent: &TracedTensor, ) -> Result<Option<TracedTensor>>

Implementors§