tracked_einsum

Function tracked_einsum 

Source
pub fn tracked_einsum<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: &[&TrackedTensor<Tensor<T>>],
) -> AdResult<TrackedTensor<Tensor<T>>>
where Tensor<T>: Differentiable,
Expand description

Tracked einsum (reverse-mode AD).

This is the AD-aware counterpart of `einsum`. It records the operation on the reverse-mode tape, so that `chainrules::Tape::pullback` can compute gradients through it.

Examples

use chainrules::Tape;
use tenferro_einsum::tracked_einsum;
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;

let tape = Tape::<Tensor<f64>>::new();
let a = tape.leaf(Tensor::ones(
    &[2, 3],
    LogicalMemorySpace::MainMemory,
    MemoryOrder::ColumnMajor,
));
let b = tape.leaf(Tensor::ones(
    &[3, 4],
    LogicalMemorySpace::MainMemory,
    MemoryOrder::ColumnMajor,
));
let c = tracked_einsum("ij,jk->ik", &[&a, &b]).unwrap();
let loss = tracked_einsum("ij,ij->", &[&c, &c]).unwrap();
let grads = tape.pullback(&loss).unwrap();
let _ga = grads.get(a.node_id().unwrap()).unwrap();