pub trait TracedTensorAdExt {
// Required methods
fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor>;
fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>>;
fn checkpoint<B: TensorBackend>(
&mut self,
compiler: &mut GraphCompiler,
executor: &mut GraphExecutor<B>,
) -> Result<()>;
fn jvp(
&self,
wrt: &TracedTensor,
tangent: &TracedTensor,
) -> Result<TracedTensor>;
fn jvp_optional(
&self,
wrt: &TracedTensor,
tangent: &TracedTensor,
) -> Result<Option<TracedTensor>>;
fn vjp(
&self,
wrt: &TracedTensor,
cotangent: &TracedTensor,
) -> Result<TracedTensor>;
fn vjp_optional(
&self,
wrt: &TracedTensor,
cotangent: &TracedTensor,
) -> Result<Option<TracedTensor>>;
}Expand description
Automatic differentiation helpers for TracedTensor.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let loss = x.scale_real(2.0);
let maybe_dx = loss.grad_optional(&x).unwrap();
assert!(maybe_dx.is_some());
Required Methods§
fn grad(&self, wrt: &TracedTensor) -> Result<TracedTensor>
Gradient of a scalar output with respect to a traced input.
For complex scalar outputs, tenferro returns the Hermitian-adjoint
cotangent. To compare seed-1 scalar gradients with JAX’s public
grad values, use the complex conjugate of this result. See
https://tensor4all.org/tenferro-rs/guides/complex-ad.html.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
let mut compiler = GraphCompiler::new();
let program = compiler.compile(tensor).unwrap();
let mut executor = GraphExecutor::new(CpuBackend::new());
executor.run(&program).unwrap()
}
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let loss = (&x * &x).unwrap();
let dx = loss.grad(&x).unwrap();
assert_eq!(eval(&dx).as_slice::<f64>().unwrap(), &[6.0]);
fn grad_optional(&self, wrt: &TracedTensor) -> Result<Option<TracedTensor>>
Like grad, but returns None when wrt is inactive — for example, when the output does not depend on wrt at all, as in the example below.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
let loss = (&y * &y).unwrap();
assert!(loss.grad_optional(&x).unwrap().is_none());
fn checkpoint<B: TensorBackend>(
    &mut self,
    compiler: &mut GraphCompiler,
    executor: &mut GraphExecutor<B>,
) -> Result<()>
Evaluate this tensor and replace its graph with a concrete leaf while preserving the previous graph for AD replay.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
let mut compiler = GraphCompiler::new();
let mut executor = GraphExecutor::new(CpuBackend::new());
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let mut y = (&x * &x).unwrap();
y.checkpoint(&mut compiler, &mut executor).unwrap();
let value = y.attached_data().unwrap();
assert_eq!(value.as_slice::<f64>().unwrap(), &[9.0]);
fn jvp(
    &self,
    wrt: &TracedTensor,
    tangent: &TracedTensor,
) -> Result<TracedTensor>
Forward-mode Jacobian-vector product: pushes tangent, a tangent for the input wrt, forward to the corresponding tangent of this output.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
let mut compiler = GraphCompiler::new();
let program = compiler.compile(tensor).unwrap();
let mut executor = GraphExecutor::new(CpuBackend::new());
executor.run(&program).unwrap()
}
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let tangent = TracedTensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap();
let y = (&x * &x).unwrap();
let dy = y.jvp(&x, &tangent).unwrap();
assert_eq!(eval(&dy).as_slice::<f64>().unwrap(), &[12.0]);
fn jvp_optional(
    &self,
    wrt: &TracedTensor,
    tangent: &TracedTensor,
) -> Result<Option<TracedTensor>>
Like jvp, but returns None when wrt is inactive — for example, when the output does not depend on wrt at all, as in the example below.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
let tangent = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
let loss = (&y * &y).unwrap();
assert!(loss.jvp_optional(&x, &tangent).unwrap().is_none());
fn vjp(
    &self,
    wrt: &TracedTensor,
    cotangent: &TracedTensor,
) -> Result<TracedTensor>
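Reverse-mode vector-Jacobian product: pulls cotangent, a cotangent for this output, back to the corresponding cotangent for the input wrt.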
Complex cotangents use tenferro’s Hermitian real-inner-product convention. Non-real complex cotangent seeds therefore need an explicit seed-convention comparison when matching JAX. See https://tensor4all.org/tenferro-rs/guides/complex-ad.html.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_cpu::CpuBackend;
use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
fn eval(tensor: &TracedTensor) -> tenferro_runtime::Tensor {
let mut compiler = GraphCompiler::new();
let program = compiler.compile(tensor).unwrap();
let mut executor = GraphExecutor::new(CpuBackend::new());
executor.run(&program).unwrap()
}
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let cotangent = TracedTensor::from_vec_col_major(vec![], vec![0.5_f64]).unwrap();
let y = (&x * &x).unwrap();
let dx = y.vjp(&x, &cotangent).unwrap();
assert_eq!(eval(&dx).as_slice::<f64>().unwrap(), &[3.0]);
fn vjp_optional(
    &self,
    wrt: &TracedTensor,
    cotangent: &TracedTensor,
) -> Result<Option<TracedTensor>>
Like vjp, but returns None when wrt is inactive — for example, when the output does not depend on wrt at all, as in the example below.
§Examples
use tenferro_ad::TracedTensorAdExt;
use tenferro_runtime::TracedTensor;
let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
let y = TracedTensor::from_vec_col_major(vec![], vec![4.0_f64]).unwrap();
let cotangent = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
let loss = (&y * &y).unwrap();
assert!(loss.vjp_optional(&x, &cotangent).unwrap().is_none());
Dyn Compatibility§
This trait is not dyn compatible because its checkpoint method is generic over the backend type B.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.