dual_einsum

Function dual_einsum 

Source
pub fn dual_einsum<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: &[&DualTensor<Tensor<T>>],
) -> AdResult<DualTensor<Tensor<T>>>
where Tensor<T>: Differentiable,
Expand description

Dual einsum (forward-mode JVP propagation).

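This is the forward-mode, AD-aware counterpart of `einsum`. It evaluates the contraction on the primal operands and propagates their tangent vectors through the same contraction. Because einsum is linear in each operand, the output tangent is the sum, over the operands that carry a tangent, of the einsum with that operand replaced by its tangent. For example, for `"ij,jk->ik"` the tangent is dC = einsum(dA, B) + einsum(A, dB). An operand without a tangent (such as one built with `DualTensor::new`) contributes no term.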

Examples

use chainrules::DualTensor;
use tenferro_einsum::dual_einsum;
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let a = Tensor::<f64>::ones(&[2, 3], mem, col);
let da = Tensor::<f64>::ones(&[2, 3], mem, col);
let b = Tensor::<f64>::ones(&[3, 4], mem, col);

let a_dual = DualTensor::with_tangent(a, da).unwrap();
let b_dual = DualTensor::new(b);
let c_dual = dual_einsum("ij,jk->ik", &[&a_dual, &b_dual]).unwrap();
let _tangent = c_dual.tangent();