pub fn einsum_hvp<T: Scalar + HasAlgebra>(
_subscripts: &str,
_primals: &[&Tensor<T>],
_tangents: &[Option<&Tensor<T>>],
_cotangent: &Tensor<T>,
_cotangent_tangent: &Tensor<T>,
) -> Result<Vec<(Tensor<T>, Tensor<T>)>>
Local HVP rule for einsum without building a global tape.
Computes the forward-over-reverse Hessian-vector product for an einsum
operation. Given the primals, their tangents (the direction v), the output
cotangent ḡ, and its tangent dḡ, it returns one (gradient, hvp) pair
per input operand.
For C = einsum(subscripts, [A, B]):
- gradient: standard pullback (e.g., ḡ_A = einsum(ḡ_C, B))
- hvp: tangent of pullback (e.g., dḡ_A = einsum(dḡ_C, B) + einsum(ḡ_C, dB))
This API is intended for language interop and manual AD.
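Concretely, for the matmul case "ij,jk->ik" (so C = A·B), the rule above expands to the closed forms below. This is a hand-written sketch in the same ḡ/dḡ notation, not additional API surface, writing ᵀ for transpose:

ḡ_A = ḡ_C · Bᵀ        dḡ_A = dḡ_C · Bᵀ + ḡ_C · dBᵀ
ḡ_B = Aᵀ · ḡ_C        dḡ_B = Aᵀ · dḡ_C + dAᵀ · ḡ_C

Each tangent line is just the product rule applied to the corresponding pullback: the pullback is linear in each of its factors, so differentiating along (dḡ_C, dA, dB) contributes one term per factor. A None tangent (as for dB in the example below) presumably stands for a zero tangent, so its term drops out.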
§Examples
use tenferro_einsum::einsum_hvp;
use tenferro_tensor::{MemoryOrder, Tensor};
use tenferro_device::LogicalMemorySpace;

let col = MemoryOrder::ColumnMajor;
let mem = LogicalMemorySpace::MainMemory;

// Primals A: [2, 3] and B: [3, 4], so C = einsum("ij,jk->ik", [A, B])
// has shape [2, 4].
let a = Tensor::<f64>::ones(&[2, 3], mem, col);
let b = Tensor::<f64>::ones(&[3, 4], mem, col);

// Tangent direction v: dA is supplied, dB is absent (None).
let da = Tensor::<f64>::ones(&[2, 3], mem, col);

// Output cotangent ḡ_C and its tangent dḡ_C, both shaped like C.
let grad_c = Tensor::<f64>::ones(&[2, 4], mem, col);
let dgrad_c = Tensor::<f64>::ones(&[2, 4], mem, col);

let results = einsum_hvp(
    "ij,jk->ik",
    &[&a, &b],
    &[Some(&da), None],
    &grad_c,
    &dgrad_c,
).unwrap();

// One (gradient, hvp) pair per input operand.
assert_eq!(results.len(), 2);
let (_grad_a, _hvp_a) = &results[0];
let (_grad_b, _hvp_b) = &results[1];
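For these all-ones inputs, the closed forms above give a quick hand check (values not asserted by the example, and assuming a None tangent counts as zero): grad_A = ḡ_C·Bᵀ is 4·ones([2, 3]) and, with dB's term dropped, hvp_A = dḡ_C·Bᵀ is also 4·ones([2, 3]); grad_B = Aᵀ·ḡ_C is 2·ones([3, 4]) and hvp_B = Aᵀ·dḡ_C + dAᵀ·ḡ_C is 4·ones([3, 4]).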