pub fn tracked_einsum<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: &[&TrackedTensor<Tensor<T>>],
) -> AdResult<TrackedTensor<Tensor<T>>>
where
    Tensor<T>: Differentiable,
Tracked einsum (reverse-mode AD).
This is the AD-aware counterpart of einsum. It records the operation
on the reverse-mode tape so that [chainrules::Tape::pullback] can
compute gradients through it.
§Examples
use chainrules::Tape;
use tenferro_einsum::tracked_einsum;
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;

// Record operations on a fresh reverse-mode tape.
let tape = Tape::<Tensor<f64>>::new();

// Two all-ones leaf tensors: a is 2x3, b is 3x4.
let a = tape.leaf(Tensor::ones(
    &[2, 3],
    LogicalMemorySpace::MainMemory,
    MemoryOrder::ColumnMajor,
));
let b = tape.leaf(Tensor::ones(
    &[3, 4],
    LogicalMemorySpace::MainMemory,
    MemoryOrder::ColumnMajor,
));

// Matrix product c = a · b, then a scalar loss = Σ_ij c_ij².
let c = tracked_einsum("ij,jk->ik", &[&a, &b]).unwrap();
let loss = tracked_einsum("ij,ij->", &[&c, &c]).unwrap();

// Pull back through the tape and read the gradient of the loss w.r.t. a.
let grads = tape.pullback(&loss).unwrap();
let _ga = grads.get(a.node_id().unwrap()).unwrap();
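For these all-ones operands the expected result follows directly from the contraction (this is a mathematical consequence of the example, not a statement about the library's internals): every entry of c is 3, the loss is 2·4·3² = 72, and the standard einsum adjoint gives ∂loss/∂a = (2c)·bᵀ, so _ga should be a 2×3 tensor whose entries are all 24.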