Backward-mode (autodiff) implementation of TensorNetworkOps for the
[Autodiff<B, C>] backend.
This module registers einsum as a differentiable operation in Burn’s
autodiff graph. When a backward pass is triggered, the VJP (vector–
Jacobian product) will be computed via tenferro’s einsum_rrule.
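As a concrete, library-agnostic illustration of that rule (standard einsum calculus; bars denote cotangents, none of this is tenferro-specific API): for a matrix-product-like contraction, the cotangent of each input is itself an einsum of the output cotangent with the remaining inputs, which is the kind of per-input gradient einsum_rrule is expected to produce.

$$
C_{ij} = \sum_k A_{ik} B_{kj}
\quad\Longrightarrow\quad
\bar{A}_{ik} = \sum_j \bar{C}_{ij}\, B_{kj},
\qquad
\bar{B}_{kj} = \sum_i A_{ik}\, \bar{C}_{ij}
$$

In subscript-string notation, with dC the output cotangent, these are einsum("ij,kj->ik", dC, B) and einsum("ik,ij->kj", A, dC).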
§Variable-arity operations and Burn’s Backward<B, N>
Burn’s [Backward] trait uses a const generic N that fixes the number
of parent tensors at compile time (e.g., Backward<B, 2> for binary
ops). Einsum accepts a variable number of inputs known only at run time.
§Chosen approach: macro-generated impls for N = 1..8
We generate Backward<B, N> implementations for each arity via a macro.
Every impl delegates to the same tenferro einsum_rrule internally —
only the [NodeRef; N] / [Option<NodeRef>; N] array sizes differ.
At the call site, tn_einsum dispatches to the matching arity based on
inputs.len(); a sketch of that dispatch follows the macro example below.
```rust
macro_rules! impl_einsum_backward {
    ($n:literal) => {
        impl<B: TensorNetworkOps> Backward<B, $n> for EinsumBackward {
            type State = (String, Vec<Shape> /*, checkpointed ids */);

            fn backward(self, ops: Ops<Self::State, $n>, grads: &mut Gradients, ..) {
                // All N parents are real inputs; no dummy padding.
                // Convert Burn grads → tenferro, call einsum_rrule,
                // convert back, register per-input gradients.
            }
        }
    };
}

impl_einsum_backward!(1);
impl_einsum_backward!(2);
// ... up to 8
```

This keeps tenferro’s N-ary einsum intact (no forced binary decomposition) while fitting cleanly into Burn’s compile-time–sized backward infrastructure. Eight inputs cover all practical tensor network contractions, and the limit can be raised trivially if needed.
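A minimal sketch of the call-site dispatch, assuming a placeholder AdTensor type and an execute::<N> helper that stands in for the real per-arity autodiff plumbing (neither name is actual Burn or tenferro API):

```rust
/// Placeholder for the autodiff tensor type seen at the call site.
struct AdTensor;

/// Stand-in for the per-arity path that would build the `Backward<B, N>` op.
fn execute<const N: usize>(spec: &str, inputs: [AdTensor; N]) -> AdTensor {
    let _ = (spec, inputs);
    AdTensor
}

/// Dispatch on the runtime input count to the matching compile-time arity.
fn tn_einsum(spec: &str, inputs: Vec<AdTensor>) -> AdTensor {
    match inputs.len() {
        1 => execute::<1>(spec, inputs.try_into().ok().unwrap()),
        2 => execute::<2>(spec, inputs.try_into().ok().unwrap()),
        // ... arities 3 through 7 follow the same pattern ...
        8 => execute::<8>(spec, inputs.try_into().ok().unwrap()),
        n => panic!("tn_einsum: unsupported input count {n} (max 8)"),
    }
}
```

In the actual module each arm would instead go through Burn’s op-registration path with EinsumBackward at the corresponding arity.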
For now each backward body is todo!(), deferring the implementation until
tenferro’s AD infrastructure (chainrules / rrule) is complete.