pub fn dual_einsum<T: Scalar + HasAlgebra>(
_subscripts: &str,
_operands: &[&DualTensor<Tensor<T>>],
) -> AdResult<DualTensor<Tensor<T>>>
where
Tensor<T>: Differentiable,
Expand description
Dual einsum (forward-mode JVP propagation).
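This is the automatic-differentiation (AD) aware counterpart of `einsum` for forward-mode differentiation.
It propagates tangent vectors through the einsum operation, i.e. it evaluates the Jacobian-vector product (JVP) alongside the primal result.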
§Examples
ⓘ
use chainrules::DualTensor;
use tenferro_einsum::dual_einsum;
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;
let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let a = Tensor::<f64>::ones(&[2, 3], mem, col);
let da = Tensor::<f64>::ones(&[2, 3], mem, col);
let b = Tensor::<f64>::ones(&[3, 4], mem, col);
let a_dual = DualTensor::with_tangent(a, da).unwrap();
let b_dual = DualTensor::new(b);
let c_dual = dual_einsum("ij,jk->ik", &[&a_dual, &b_dual]).unwrap();
let _tangent = c_dual.tangent();