tenferro_ext_burn/
backward.rs

//! Backward-mode support for [`TensorNetworkOps`] on Burn's autodiff backend.
//!
//! N-ary einsum calls are lowered to a sequence of unary/binary autodiff
//! nodes following tenferro's contraction tree, so the public Burn surface
//! matches the forward N-ary einsum contract.
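//!
//! For example, a three-operand call such as `"ij,jk,kl->il"` is parsed into a
//! [`NestedEinsum`] tree and then evaluated as two binary contractions (the
//! exact pairing is chosen by [`ContractionTree::optimize`]), each of which
//! records its own backward node.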

use burn::backend::autodiff::checkpoint::{base::Checkpointer, strategy::CheckpointStrategy};
use burn::backend::autodiff::grads::Gradients;
use burn::backend::autodiff::ops::{Backward, Ops, OpsKind};
use burn::backend::Autodiff;
use burn::tensor::ops::FloatTensor;
use burn::tensor::TensorMetadata;

use tenferro_algebra::Standard;
use tenferro_einsum::{ContractionTree, NestedEinsum, Subscripts};
use tenferro_prims::{CpuBackend, CpuContext};

use crate::{panic_on_error, Error, Result, TensorNetworkOps};

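/// State saved for an einsum backward node: the subscripts string plus clones
/// of the primitive inputs, so the reverse rule can be re-evaluated later.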
#[derive(Clone, Debug)]
struct EinsumState<T> {
    subscripts: String,
    inputs: Vec<T>,
}

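/// Converts numeric einsum labels back into `char`-based subscript notation.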
fn labels_to_notation(labels: &[u32]) -> Result<String> {
    labels
        .iter()
        .map(|&label| {
            char::from_u32(label).ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "tenferro-ext-burn received einsum label {label}, which is not a valid Unicode scalar value"
                ))
            })
        })
        .collect()
}

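/// Renders parsed [`Subscripts`] back into textual `"ab,bc,...->ac"` notation.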
fn subscripts_to_notation(subscripts: &Subscripts) -> Result<String> {
    let inputs = subscripts
        .inputs
        .iter()
        .map(|labels| labels_to_notation(labels))
        .collect::<Result<Vec<_>>>()?
        .join(",");
    Ok(format!(
        "{inputs}->{}",
        labels_to_notation(&subscripts.output)?
    ))
}

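/// Builds the `"lhs,rhs->out"` notation for a single pairwise contraction step.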
fn binary_step_notation(lhs: &[u32], rhs: &[u32], output: &[u32]) -> Result<String> {
    Ok(format!(
        "{},{}->{}",
        labels_to_notation(lhs)?,
        labels_to_notation(rhs)?,
        labels_to_notation(output)?
    ))
}

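/// Pulls the next item from an iterator, turning exhaustion into an
/// [`Error::InternalInvariant`] with the given message.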
fn require_next<T>(iter: &mut impl Iterator<Item = T>, message: &'static str) -> Result<T> {
    iter.next().ok_or(Error::InternalInvariant(message))
}

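/// Evaluates the einsum reverse rule on the CPU backend: the Burn primitives
/// are converted to tenferro tensors, `einsum_rrule` produces one cotangent per
/// input, and the results are converted back onto the cotangent's device.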
fn try_rrule_grads<B: burn::tensor::backend::Backend<FloatElem = f64>>(
    subscripts: &str,
    inputs: &[FloatTensor<B>],
    cotangent: FloatTensor<B>,
) -> Result<Vec<FloatTensor<B>>> {
    let device = B::float_device(&cotangent);
    let tenferro_inputs: Vec<_> = inputs
        .iter()
        .cloned()
        .map(crate::convert::try_burn_to_tenferro::<B>)
        .collect::<Result<_>>()?;
    let input_refs: Vec<_> = tenferro_inputs.iter().collect();
    let tenferro_cotangent = crate::convert::try_burn_to_tenferro::<B>(cotangent)?;
    let mut ctx = CpuContext::new(1);
    let grads = tenferro_einsum::einsum_rrule::<Standard<f64>, CpuBackend>(
        &mut ctx,
        subscripts,
        &input_refs,
        &tenferro_cotangent,
    )
    .map_err(|err| Error::InvalidArgument(err.to_string()))?;

    grads
        .into_iter()
        .map(|grad| crate::convert::try_tenferro_to_burn::<B>(grad, &device))
        .collect()
}

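/// Lowers a single-input einsum to one node on Burn's autodiff tape. The
/// forward value is computed by `B::tn_einsum`; the backward pass re-runs the
/// reverse rule through [`try_rrule_grads`].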
fn unary_einsum<B, C>(
    subscripts: &str,
    input: FloatTensor<Autodiff<B, C>>,
) -> FloatTensor<Autodiff<B, C>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    #[derive(Debug)]
    struct UnaryEinsum;

    impl<B: burn::tensor::backend::Backend<FloatElem = f64>> Backward<B, 1> for UnaryEinsum {
        type State = EinsumState<B::FloatTensorPrimitive>;

        fn backward(
            self,
            ops: Ops<Self::State, 1>,
            grads: &mut Gradients,
            _checkpointer: &mut Checkpointer,
        ) {
            // Re-run the einsum reverse rule from the saved forward state and the
            // cotangent flowing into this node.
            let mut grad_iter = panic_on_error(try_rrule_grads::<B>(
                &ops.state.subscripts,
                &ops.state.inputs,
                grads.consume::<B>(&ops.node),
            ))
            .into_iter();

            if let Some(node) = ops.parents[0].clone() {
                let grad = panic_on_error(require_next(
                    &mut grad_iter,
                    "unary einsum rrule must return exactly one gradient",
                ));
                grads.register::<B>(node.id, grad);
            }
        }
    }

    let state = EinsumState {
        subscripts: subscripts.to_owned(),
        inputs: vec![input.primitive.clone()],
    };

    // Tracked nodes keep the state for the backward pass; untracked nodes only
    // need the forward computation.
    match UnaryEinsum
        .prepare::<C>([input.node.clone()])
        .compute_bound()
        .stateful()
    {
        OpsKind::Tracked(prep) => {
            prep.finish(state, B::tn_einsum(subscripts, vec![input.primitive]))
        }
        OpsKind::UnTracked(prep) => prep.finish(B::tn_einsum(subscripts, vec![input.primitive])),
    }
}

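/// Lowers a two-input einsum to one node on Burn's autodiff tape. The reverse
/// rule returns gradients positionally (lhs first, then rhs); each one is
/// registered only if the corresponding parent is tracked.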
fn binary_einsum<B, C>(
    subscripts: &str,
    lhs: FloatTensor<Autodiff<B, C>>,
    rhs: FloatTensor<Autodiff<B, C>>,
) -> FloatTensor<Autodiff<B, C>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    #[derive(Debug)]
    struct BinaryEinsum;

    impl<B: burn::tensor::backend::Backend<FloatElem = f64>> Backward<B, 2> for BinaryEinsum {
        type State = EinsumState<B::FloatTensorPrimitive>;

        fn backward(
            self,
            ops: Ops<Self::State, 2>,
            grads: &mut Gradients,
            _checkpointer: &mut Checkpointer,
        ) {
            let mut grad_iter = panic_on_error(try_rrule_grads::<B>(
                &ops.state.subscripts,
                &ops.state.inputs,
                grads.consume::<B>(&ops.node),
            ))
            .into_iter();

            // The rrule returns one gradient per input in positional order, so both
            // are consumed unconditionally to keep the iterator aligned even when
            // only one parent is tracked.
            let lhs_grad = panic_on_error(require_next(
                &mut grad_iter,
                "binary einsum rrule must return a gradient for lhs",
            ));
            let rhs_grad = panic_on_error(require_next(
                &mut grad_iter,
                "binary einsum rrule must return a gradient for rhs",
            ));

            if let Some(node) = ops.parents[0].clone() {
                grads.register::<B>(node.id, lhs_grad);
            }

            if let Some(node) = ops.parents[1].clone() {
                grads.register::<B>(node.id, rhs_grad);
            }
        }
    }

    let state = EinsumState {
        subscripts: subscripts.to_owned(),
        inputs: vec![lhs.primitive.clone(), rhs.primitive.clone()],
    };

    match BinaryEinsum
        .prepare::<C>([lhs.node.clone(), rhs.node.clone()])
        .compute_bound()
        .stateful()
    {
        OpsKind::Tracked(prep) => prep.finish(
            state,
            B::tn_einsum(subscripts, vec![lhs.primitive, rhs.primitive]),
        ),
        OpsKind::UnTracked(prep) => {
            prep.finish(B::tn_einsum(subscripts, vec![lhs.primitive, rhs.primitive]))
        }
    }
}

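/// Executes one flat einsum over autodiff tensors. One and two operands map
/// directly to a unary/binary node; larger networks are decomposed into
/// pairwise steps along an optimized [`ContractionTree`].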
fn try_execute_einsum_tree<B, C>(
    subscripts: &Subscripts,
    inputs: Vec<FloatTensor<Autodiff<B, C>>>,
) -> Result<FloatTensor<Autodiff<B, C>>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    match inputs.len() {
        0 => Err(Error::InvalidArgument(
            "tenferro-ext-burn autodiff einsum requires at least one input tensor".into(),
        )),
        1 => Ok(unary_einsum::<B, C>(
            &subscripts_to_notation(subscripts)?,
            inputs.into_iter().next().ok_or(Error::InternalInvariant(
                "unary einsum dispatch lost its only input",
            ))?,
        )),
        2 => {
            let mut iter = inputs.into_iter();
            let lhs = iter.next().ok_or(Error::InternalInvariant(
                "binary einsum dispatch lost its lhs input",
            ))?;
            let rhs = iter.next().ok_or(Error::InternalInvariant(
                "binary einsum dispatch lost its rhs input",
            ))?;
            Ok(binary_einsum::<B, C>(
                &subscripts_to_notation(subscripts)?,
                lhs,
                rhs,
            ))
        }
        n_inputs => {
            let shapes: Vec<Vec<usize>> = inputs.iter().map(|input| input.shape().dims).collect();
            let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
            let tree = ContractionTree::optimize(subscripts, &shape_refs).map_err(|err| {
                Error::InvalidArgument(format!(
                    "tenferro-ext-burn autodiff einsum could not optimize the pairwise contraction path: {err}"
                ))
            })?;
            // Slots 0..n_inputs hold the original operands; slot n_inputs + i holds the
            // result of step i. Each step takes its two operands out of the slots so a
            // tensor is never contracted twice.
            let mut slots: Vec<Option<FloatTensor<Autodiff<B, C>>>> =
                inputs.into_iter().map(Some).collect();
            slots.resize(n_inputs + tree.step_count(), None);

            for step_idx in 0..tree.step_count() {
                let (left, right) = tree.step_pair(step_idx).ok_or(Error::InternalInvariant(
                    "contraction tree is missing a recorded step",
                ))?;
                let (lhs_subs, rhs_subs, out_subs) =
                    tree.step_subscripts(step_idx)
                        .ok_or(Error::InternalInvariant(
                            "contraction tree is missing step subscripts",
                        ))?;
                let lhs = slots[left].take().ok_or(Error::InternalInvariant(
                    "contraction tree referenced a consumed lhs operand",
                ))?;
                let rhs = slots[right].take().ok_or(Error::InternalInvariant(
                    "contraction tree referenced a consumed rhs operand",
                ))?;
                let step_notation = binary_step_notation(lhs_subs, rhs_subs, out_subs)?;
                let result = binary_einsum::<B, C>(&step_notation, lhs, rhs);
                slots[n_inputs + step_idx] = Some(result);
            }

            // The last occupied slot is the root of the contraction tree.
            slots
                .into_iter()
                .rev()
                .flatten()
                .next()
                .ok_or(Error::InternalInvariant(
                    "contraction tree did not leave a final result",
                ))
        }
    }
}

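/// Recursively evaluates a parsed [`NestedEinsum`]: a leaf looks up its input
/// tensor, a node first evaluates its children and then contracts the results
/// with [`try_execute_einsum_tree`].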
fn try_execute_nested_einsum<B, C>(
    nested: &NestedEinsum,
    inputs: &[FloatTensor<Autodiff<B, C>>],
) -> Result<FloatTensor<Autodiff<B, C>>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    match nested {
        NestedEinsum::Leaf(index) => inputs.get(*index).cloned().ok_or(Error::InternalInvariant(
            "nested einsum referenced a missing input tensor",
        )),
        NestedEinsum::Node {
            subscripts,
            children,
        } => {
            let child_results = children
                .iter()
                .map(|child| try_execute_nested_einsum::<B, C>(child, inputs))
                .collect::<Result<Vec<_>>>()?;
            try_execute_einsum_tree::<B, C>(subscripts, child_results)
        }
    }
}

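/// [`TensorNetworkOps`] for the autodiff decorator: forward values are computed
/// by the inner backend `B`, while the backward pass is assembled from the
/// unary/binary einsum nodes defined above.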
impl<B, C> TensorNetworkOps for Autodiff<B, C>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    fn tn_einsum(subscripts: &str, inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self> {
        let nested = panic_on_error(NestedEinsum::parse(subscripts).map_err(|err| {
            Error::InvalidArgument(format!(
                "tenferro-ext-burn autodiff einsum received invalid subscripts or mismatched parentheses: {err}"
            ))
        }));
        panic_on_error(try_execute_nested_einsum::<B, C>(&nested, &inputs))
    }
}