
Function einsum_rrule 

pub fn einsum_rrule<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: &[&Tensor<T>],
    _cotangent: &Tensor<T>,
) -> Result<Vec<Tensor<T>>>

Reverse-mode rule (rrule) for einsum that computes gradients directly, without recording operations on a global tape.

Given the cotangent of the einsum output, computes the pullback (vector-Jacobian product) of the einsum operation described by the subscripts string. Returns one gradient tensor per input operand.
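For a simple einsum, the pullback for each operand can itself be written as an einsum. The cotangent keeps the output labels, the other operands keep theirs, and the labels of the operand being differentiated become the new output. The sketch below builds those backward subscripts for a spec with an explicit `->`. This is a common way to express einsum pullbacks; whether `einsum_rrule` works this way internally is not stated here. The helper `backward_subscripts` is hypothetical and not part of tenferro. It assumes that every label of the target operand also appears in the output or in another operand and that no label repeats within that operand. Traces, diagonals, and labels summed out of a single operand would need extra broadcasting, which this sketch does not handle.

```rust
/// Hypothetical helper (not a tenferro API): builds the einsum subscripts whose
/// evaluation yields the gradient of operand `k` for an explicit-output spec
/// such as "ij,jk->ik". Returns None if there is no "->" or `k` is out of range.
/// Assumes each label of operand `k` also appears in the output or another
/// operand, and that no label repeats within operand `k`.
fn backward_subscripts(subscripts: &str, k: usize) -> Option<String> {
    let (lhs, out) = subscripts.split_once("->")?;
    let inputs: Vec<&str> = lhs.split(',').map(str::trim).collect();
    let target = *inputs.get(k)?;
    // The cotangent carries the output labels and is listed first.
    let mut terms = vec![out.trim()];
    // Every other operand follows, in its original order.
    terms.extend(
        inputs
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != k)
            .map(|(_, s)| *s),
    );
    // The differentiated operand's labels become the backward output.
    Some(format!("{}->{}", terms.join(","), target))
}

fn main() {
    // Matrix product C = A·B: dA = dC·Bᵀ and dB = Aᵀ·dC.
    assert_eq!(backward_subscripts("ij,jk->ik", 0).as_deref(), Some("ik,jk->ij"));
    assert_eq!(backward_subscripts("ij,jk->ik", 1).as_deref(), Some("ik,ij->jk"));
}
```

Placing the cotangent first is only a convention. Einsum is insensitive to operand order, so any order of the input terms gives the same result.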

The name follows the `rrule` convention of Julia’s ChainRules.jl. This API is intended for language interop and for hand-written (manual) automatic differentiation.

Examples

use tenferro_einsum::einsum_rrule;
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;
let a = Tensor::<f64>::ones(&[2, 3], mem, col);
let b = Tensor::<f64>::ones(&[3, 4], mem, col);
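// Cotangent with respect to the output C = A·B; it has C's shape [2, 4].
let grad_c = Tensor::<f64>::ones(&[2, 4], mem, col);

// One gradient per input operand (a and b).
let grads = einsum_rrule("ij,jk->ik", &[&a, &b], &grad_c).unwrap();
assert_eq!(grads.len(), 2);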
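The values in `grads` for this example can be checked by hand. With all-ones inputs and an all-ones cotangent, each entry of the gradient for `a` sums 4 ones over `k`, and each entry of the gradient for `b` sums 2 ones over `i`. The plain-Rust sketch below computes the same matrix-product pullback using nested `Vec`s in place of tenferro tensors. `matmul_pullback` and the `Mat` alias are hypothetical illustrations, not tenferro APIs. Matching `grads[0]` to `a` and `grads[1]` to `b` assumes the gradients come back in operand order.

```rust
/// Illustrative dense matrix as nested Vecs (not tenferro's Tensor).
type Mat = Vec<Vec<f64>>;

/// Hypothetical helper: pullback of C = A·B (einsum "ij,jk->ik") for a
/// cotangent G = dL/dC, returning (dA, dB) with
/// dA[i][j] = sum_k G[i][k] * B[j][k]   ("ik,jk->ij")
/// dB[j][k] = sum_i A[i][j] * G[i][k]   ("ij,ik->jk")
fn matmul_pullback(a: &Mat, b: &Mat, g: &Mat) -> (Mat, Mat) {
    let (n_i, n_j, n_k) = (a.len(), b.len(), b[0].len());
    let mut da = vec![vec![0.0; n_j]; n_i];
    let mut db = vec![vec![0.0; n_k]; n_j];
    for i in 0..n_i {
        for j in 0..n_j {
            for k in 0..n_k {
                da[i][j] += g[i][k] * b[j][k];
                db[j][k] += a[i][j] * g[i][k];
            }
        }
    }
    (da, db)
}

fn main() {
    // Same shapes as the example: a is 2x3, b is 3x4, the cotangent is 2x4.
    let a = vec![vec![1.0; 3]; 2];
    let b = vec![vec![1.0; 4]; 3];
    let g = vec![vec![1.0; 4]; 2];
    let (da, db) = matmul_pullback(&a, &b, &g);
    // dA has a's shape and each entry sums 4 terms; dB has b's shape, 2 terms each.
    assert!(da.iter().flatten().all(|&x| x == 4.0));
    assert!(db.iter().flatten().all(|&x| x == 2.0));
}
```

An all-ones cotangent corresponds to the loss L = sum(C). This makes it a convenient smoke test, because the expected gradients are easy to predict from the shapes alone.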