tenferro_burn/backward.rs
//! Backward-mode (autodiff) implementation of [`TensorNetworkOps`] for the
//! [`Autodiff<B, C>`] backend.
//!
//! This module registers einsum as a differentiable operation in Burn's
//! autodiff graph. When a backward pass is triggered, the VJP (vector–
//! Jacobian product) will be computed via tenferro's `einsum_rrule`.
//!
//! # Variable-arity operations and Burn's `Backward<B, N>`
//!
//! Burn's [`Backward`] trait uses a const generic `N` that fixes the number
//! of parent tensors at compile time (e.g., `Backward<B, 2>` for binary
//! ops). Einsum accepts a variable number of inputs known only at run time.
//!
//! ## Chosen approach: macro-generated impls for N = 1..=8
//!
//! We generate `Backward<B, N>` implementations for each arity via a macro.
//! Every impl delegates internally to the same tenferro `einsum_rrule`;
//! only the `[NodeRef; N]` / `[Option<NodeRef>; N]` array sizes differ.
//! At the call site, `tn_einsum` dispatches to the matching arity based on
//! `inputs.len()` (see the dispatch sketch after the macro example below).
//!
//! ```ignore
//! macro_rules! impl_einsum_backward {
//!     ($n:literal) => {
//!         impl<B: TensorNetworkOps> Backward<B, $n> for EinsumBackward {
//!             type State = (String, Vec<Shape> /*, checkpointed ids */);
//!             fn backward(self, ops: Ops<Self::State, $n>, grads: &mut Gradients, /* ... */) {
//!                 // All N parents are real inputs; no dummy padding.
//!                 // Convert Burn grads → tenferro, call einsum_rrule,
//!                 // convert back, register per-input gradients.
//!             }
//!         }
//!     };
//! }
//! impl_einsum_backward!(1);
//! impl_einsum_backward!(2);
//! // ... up to 8
//! ```
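//!
//! The call-site dispatch is expected to be a plain `match` on the runtime
//! arity. A minimal sketch, assuming a hypothetical per-arity helper
//! `einsum_tracked::<B, C, N>` that builds the fixed-size node array and
//! routes through the arity-`N` `Backward` impl (the helper name and its
//! signature are illustrative, not part of the current API):
//!
//! ```ignore
//! fn tn_einsum(subscripts: &str, inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self> {
//!     match inputs.len() {
//!         1 => einsum_tracked::<B, C, 1>(subscripts, inputs),
//!         2 => einsum_tracked::<B, C, 2>(subscripts, inputs),
//!         // ... arms for 3..=8, generated alongside the impls
//!         n => panic!("tn_einsum: unsupported number of inputs: {n} (max 8)"),
//!     }
//! }
//! ```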
//!
//! This keeps tenferro's N-ary einsum intact (no forced binary
//! decomposition) while fitting cleanly into Burn's compile-time-sized
//! backward infrastructure. Eight inputs cover all practical tensor network
//! contractions; the limit can be raised trivially if needed.
//!
//! For now the function body is `todo!()`, deferring implementation until
//! the AD infrastructure in tenferro (chainrules / rrule) is complete.

use burn::backend::autodiff::checkpoint::strategy::CheckpointStrategy;
use burn::backend::Autodiff;
use burn::tensor::ops::FloatTensor;

use crate::TensorNetworkOps;

impl<B, C> TensorNetworkOps for Autodiff<B, C>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    /// Perform an einsum contraction, recording the operation on the autodiff
    /// tape so that gradients can be computed during the backward pass.
    ///
    /// # Future implementation plan
    ///
    /// The backward pass will invoke tenferro's `einsum_rrule` to obtain the
    /// VJP for each input tensor. The contraction tree used in the forward
    /// pass will be cached (or re-derived) for the backward pass so that
    /// each partial derivative is itself an optimised einsum.
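    ///
    /// As a concrete illustration of the rule `einsum_rrule` is expected to
    /// encode (its exact signature is not fixed here), the VJP of a typical
    /// contraction swaps the differentiated input's subscripts with the
    /// output's, contracting the output cotangent against the remaining
    /// inputs:
    ///
    /// ```ignore
    /// // Forward:  y = einsum("ij,jk->ik", [a, b])        (matrix product)
    /// // Backward, given the output cotangent dy:
    /// //   da = einsum("ik,jk->ij", [dy, b])
    /// //   db = einsum("ij,ik->jk", [a, dy])
    /// //
    /// // The same pattern extends to N-ary contractions, e.g. for
    /// // y = einsum("abc,cd,de->abe", [x1, x2, x3]):
    /// //   dx2 = einsum("abe,abc,de->cd", [dy, x1, x3])
    /// ```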
    fn tn_einsum(_subscripts: &str, _inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self> {
        todo!()
    }
}