Function einsum_hvp 

pub fn einsum_hvp<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _primals: &[&Tensor<T>],
    _tangents: &[Option<&Tensor<T>>],
    _cotangent: &Tensor<T>,
    _cotangent_tangent: &Tensor<T>,
) -> Result<Vec<(Tensor<T>, Tensor<T>)>>

Local HVP rule for einsum without building a global tape.

Computes the forward-over-reverse Hessian-vector product for an einsum operation. Given primals, their tangents (direction v), an output cotangent ḡ, and its tangent dḡ, returns (gradient, hvp) pairs for each input operand.
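
Equivalently, in standard forward-over-reverse terms (not specific to this crate): the gradient for operand X is the pullback ḡ_X(primals, ḡ_C), and the hvp is its directional derivative along the supplied tangents,

  hvp_X = d/dε ḡ_X(primals + ε·tangents, ḡ_C + ε·dḡ_C) at ε = 0,

which is the Hessian-vector product H·v when ḡ_C and dḡ_C are themselves propagated from a scalar loss.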

For C = einsum(subscripts, [A, B]):

  • gradient: the standard pullback of the output cotangent into each operand (e.g., ḡ_A = einsum(ḡ_C, B) with suitably permuted subscripts)
  • hvp: the tangent of that pullback, obtained by the product rule (e.g., dḡ_A = einsum(dḡ_C, B) + einsum(ḡ_C, dB)); see the worked matrix-product case below
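
For the concrete subscripts "ij,jk->ik" (the ordinary matrix product C = A·B), these rules specialize to standard matrix calculus; the einsum spellings below are just one way to write the contractions, not a crate-specific API:

  ḡ_A  = ḡ_C·Bᵀ             = einsum("ik,jk->ij", ḡ_C, B)
  dḡ_A = dḡ_C·Bᵀ + ḡ_C·dBᵀ
  ḡ_B  = Aᵀ·ḡ_C             = einsum("ij,ik->jk", A, ḡ_C)
  dḡ_B = dAᵀ·ḡ_C + Aᵀ·dḡ_C

Assuming, as is conventional, that a None tangent is treated as zero, any term involving a missing tangent drops out (e.g., with dB = None, dḡ_A reduces to dḡ_C·Bᵀ).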

This API is intended for language interop and manual AD.

Examples

use tenferro_einsum::einsum_hvp;
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;

// Primals A ([2, 3]) and B ([3, 4]) for C = einsum("ij,jk->ik", A, B).
let a = Tensor::<f64>::ones(&[2, 3], mem, col);
let b = Tensor::<f64>::ones(&[3, 4], mem, col);
// Tangent direction for A; B gets no tangent (None) in the call below.
let da = Tensor::<f64>::ones(&[2, 3], mem, col);

// Output cotangent ḡ_C and its tangent dḡ_C, both shaped like C ([2, 4]).
let grad_c = Tensor::<f64>::ones(&[2, 4], mem, col);
let dgrad_c = Tensor::<f64>::ones(&[2, 4], mem, col);

let results = einsum_hvp(
    "ij,jk->ik",
    &[&a, &b],
    &[Some(&da), None],
    &grad_c,
    &dgrad_c,
).unwrap();

// One (gradient, hvp) pair per input operand, in operand order.
assert_eq!(results.len(), 2);
let (_grad_a, _hvp_a) = &results[0];
let (_grad_b, _hvp_b) = &results[1];
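
Since every operand above is all ones, the outputs can be checked by hand (again assuming None tangents act as zero): grad_a = ḡ_C·Bᵀ is 4.0 everywhere with shape [2, 3], and hvp_a = dḡ_C·Bᵀ is likewise 4.0 everywhere; grad_b = Aᵀ·ḡ_C is 2.0 everywhere with shape [3, 4], and hvp_b = dAᵀ·ḡ_C + Aᵀ·dḡ_C is 4.0 everywhere.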