tenferro_burn/
backward.rs

//! Backward-mode (autodiff) implementation of [`TensorNetworkOps`] for the
//! [`Autodiff<B, C>`] backend.
//!
//! This module registers einsum as a differentiable operation in Burn's
//! autodiff graph.  When a backward pass is triggered, the VJP (vector–
//! Jacobian product) will be computed via tenferro's `einsum_rrule`.
//!
//! # Variable-arity operations and Burn's `Backward<B, N>`
//!
//! Burn's [`Backward`] trait uses a const generic `N` that fixes the number
//! of parent tensors at compile time (e.g., `Backward<B, 2>` for binary
//! ops).  Einsum accepts a variable number of inputs known only at run time.
//!
//! ## Chosen approach: macro-generated impls for N = 1..=8
//!
//! We generate `Backward<B, N>` implementations for each arity via a macro.
//! Every impl delegates to the same tenferro `einsum_rrule` internally —
//! only the `[NodeRef; N]` / `[Option<NodeRef>; N]` array sizes differ.
//! At the call site, `tn_einsum` dispatches to the matching arity based on
//! `inputs.len()`; a sketch of that dispatch follows the macro example below.
//!
//! ```ignore
//! macro_rules! impl_einsum_backward {
//!     ($n:literal) => {
//!         impl<B: TensorNetworkOps> Backward<B, $n> for EinsumBackward {
//!             type State = (String, Vec<Shape> /*, checkpointed ids */);
//!             fn backward(self, ops: Ops<Self::State, $n>, grads: &mut Gradients, ..) {
//!                 // All N parents are real inputs — no dummy padding.
//!                 // Convert Burn grads → tenferro, call einsum_rrule,
//!                 // convert back, register per-input gradients.
//!             }
//!         }
//!     };
//! }
//! impl_einsum_backward!(1);
//! impl_einsum_backward!(2);
//! // ... up to 8
//! ```
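//!
//! A rough sketch of the call-site dispatch mentioned above (`einsum_arity`
//! is a hypothetical helper, not an existing API; it would convert the `Vec`
//! into a fixed-size array, run the inner-backend forward pass, and register
//! the arity-`N` `EinsumBackward` node on the autodiff graph):
//!
//! ```ignore
//! match inputs.len() {
//!     1 => einsum_arity::<1>(subscripts, inputs),
//!     2 => einsum_arity::<2>(subscripts, inputs),
//!     // ... arms for 3 through 8 ...
//!     n => panic!("tn_einsum supports at most 8 inputs, got {n}"),
//! }
//! ```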
//!
//! This approach keeps tenferro's N-ary einsum intact (no forced binary
//! decomposition) while fitting cleanly into Burn's compile-time–sized
//! backward infrastructure.  Eight inputs cover all practical tensor network
//! contractions; the limit can be raised trivially if needed.
//!
//! For now the function body is `todo!()`, deferring implementation until
//! the AD infrastructure in tenferro (chainrules / rrule) is complete.

use burn::backend::autodiff::checkpoint::strategy::CheckpointStrategy;
use burn::backend::Autodiff;
use burn::tensor::ops::FloatTensor;

use crate::TensorNetworkOps;

impl<B, C> TensorNetworkOps for Autodiff<B, C>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    /// Performs an einsum contraction, recording the operation on the autodiff
    /// tape so that gradients can be computed during the backward pass.
    ///
    /// # Future implementation plan
    ///
    /// The backward pass will invoke tenferro's `einsum_rrule` to obtain the
    /// VJP for each input tensor.  The contraction tree used in the forward
    /// pass will be cached (or re-derived) for the backward pass so that
    /// each partial derivative is itself an optimised einsum.
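    ///
    /// As a concrete illustration (a sketch, not necessarily how
    /// `einsum_rrule` will express it): for `"ij,jk->ik"`, the VJP of each
    /// operand is again an einsum of the output cotangent with the remaining
    /// operand.
    ///
    /// ```ignore
    /// // grad_a[i, j] = Σ_k grad_out[i, k] * b[j, k]
    /// let grad_a = B::tn_einsum("ik,jk->ij", vec![grad_out.clone(), b]);
    /// // grad_b[j, k] = Σ_i a[i, j] * grad_out[i, k]
    /// let grad_b = B::tn_einsum("ij,ik->jk", vec![a, grad_out]);
    /// ```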
    fn tn_einsum(_subscripts: &str, _inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self> {
        todo!()
    }
}
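
// Usage sketch (hypothetical): given an inner backend `B` that implements
// `TensorNetworkOps` (e.g. a tenferro-backed backend defined elsewhere in the
// workspace), wrapping it in `Autodiff` is enough to make the contraction
// differentiable:
//
//     let c = Autodiff::<B>::tn_einsum("ij,jk->ik", vec![a, b]);
//
// The call is recorded on the autodiff tape; a later backward pass populates
// gradients for `a` and `b` via `einsum_rrule`.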