use burn::backend::autodiff::checkpoint::{base::Checkpointer, strategy::CheckpointStrategy};
use burn::backend::autodiff::grads::Gradients;
use burn::backend::autodiff::ops::{Backward, Ops, OpsKind};
use burn::backend::Autodiff;
use burn::tensor::ops::FloatTensor;
use burn::tensor::TensorMetadata;

use tenferro_algebra::Standard;
use tenferro_einsum::{ContractionTree, NestedEinsum, Subscripts};
use tenferro_prims::{CpuBackend, CpuContext};

use crate::{panic_on_error, Error, Result, TensorNetworkOps};

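/// Forward-pass state captured for the backward pass: the einsum notation
/// string plus clones of the input primitives needed to replay the rrule.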
#[derive(Clone, Debug)]
struct EinsumState<T> {
    subscripts: String,
    inputs: Vec<T>,
}

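/// Converts a slice of `u32` einsum labels into their `char` notation form,
/// rejecting values that are not valid Unicode scalar values. For example,
/// the labels `[105, 106]` (the code points of `'i'` and `'j'`) render as:
///
/// ```ignore
/// assert_eq!(labels_to_notation(&[105, 106]).unwrap(), "ij");
/// ```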
fn labels_to_notation(labels: &[u32]) -> Result<String> {
    labels
        .iter()
        .map(|&label| {
            char::from_u32(label).ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "tenferro-ext-burn received a non-Unicode einsum label {label}"
                ))
            })
        })
        .collect()
}

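/// Renders parsed `Subscripts` back into `"lhs,rhs,...->out"` notation.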
fn subscripts_to_notation(subscripts: &Subscripts) -> Result<String> {
    let inputs = subscripts
        .inputs
        .iter()
        .map(|labels| labels_to_notation(labels))
        .collect::<Result<Vec<_>>>()?
        .join(",");
    Ok(format!(
        "{inputs}->{}",
        labels_to_notation(&subscripts.output)?
    ))
}

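/// Builds the `"lhs,rhs->out"` notation string for one pairwise contraction step.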
fn binary_step_notation(lhs: &[u32], rhs: &[u32], output: &[u32]) -> Result<String> {
    Ok(format!(
        "{},{}->{}",
        labels_to_notation(lhs)?,
        labels_to_notation(rhs)?,
        labels_to_notation(output)?
    ))
}

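/// Pulls the next item from an iterator, mapping exhaustion to an
/// `InternalInvariant` error with the given message.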
fn require_next<T>(iter: &mut impl Iterator<Item = T>, message: &'static str) -> Result<T> {
    iter.next().ok_or(Error::InternalInvariant(message))
}

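/// Computes input gradients for an einsum by converting the Burn tensors to
/// tenferro tensors, evaluating the reverse rule (`einsum_rrule`) on the
/// tenferro CPU backend, and converting the resulting gradients back to Burn
/// tensors on the cotangent's device.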
fn try_rrule_grads<B: burn::tensor::backend::Backend<FloatElem = f64>>(
    subscripts: &str,
    inputs: &[FloatTensor<B>],
    cotangent: FloatTensor<B>,
) -> Result<Vec<FloatTensor<B>>> {
    let device = B::float_device(&cotangent);
    let tenferro_inputs: Vec<_> = inputs
        .iter()
        .cloned()
        .map(crate::convert::try_burn_to_tenferro::<B>)
        .collect::<Result<_>>()?;
    let input_refs: Vec<_> = tenferro_inputs.iter().collect();
    let tenferro_cotangent = crate::convert::try_burn_to_tenferro::<B>(cotangent)?;
    let mut ctx = CpuContext::new(1);
    let grads = tenferro_einsum::einsum_rrule::<Standard<f64>, CpuBackend>(
        &mut ctx,
        subscripts,
        &input_refs,
        &tenferro_cotangent,
    )
    .map_err(|err| Error::InvalidArgument(err.to_string()))?;

    grads
        .into_iter()
        .map(|grad| crate::convert::try_tenferro_to_burn::<B>(grad, &device))
        .collect()
}

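/// Registers a single-input einsum with Burn's autodiff graph. The backward
/// step replays the stored subscripts through `try_rrule_grads` and registers
/// the lone gradient on the parent node if it is tracked.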
fn unary_einsum<B, C>(
    subscripts: &str,
    input: FloatTensor<Autodiff<B, C>>,
) -> FloatTensor<Autodiff<B, C>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    #[derive(Debug)]
    struct UnaryEinsum;

    impl<B: burn::tensor::backend::Backend<FloatElem = f64>> Backward<B, 1> for UnaryEinsum {
        type State = EinsumState<B::FloatTensorPrimitive>;

        fn backward(
            self,
            ops: Ops<Self::State, 1>,
            grads: &mut Gradients,
            _checkpointer: &mut Checkpointer,
        ) {
            let mut grad_iter = panic_on_error(try_rrule_grads::<B>(
                &ops.state.subscripts,
                &ops.state.inputs,
                grads.consume::<B>(&ops.node),
            ))
            .into_iter();

            if let Some(node) = ops.parents[0].clone() {
                let grad = panic_on_error(require_next(
                    &mut grad_iter,
                    "unary einsum rrule must return exactly one gradient",
                ));
                grads.register::<B>(node.id, grad);
            }
        }
    }

    let state = EinsumState {
        subscripts: subscripts.to_owned(),
        inputs: vec![input.primitive.clone()],
    };

    match UnaryEinsum
        .prepare::<C>([input.node.clone()])
        .compute_bound()
        .stateful()
    {
        OpsKind::Tracked(prep) => {
            prep.finish(state, B::tn_einsum(subscripts, vec![input.primitive]))
        }
        OpsKind::UnTracked(prep) => prep.finish(B::tn_einsum(subscripts, vec![input.primitive])),
    }
}

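/// Registers a two-input einsum with Burn's autodiff graph. The backward step
/// produces one gradient per input and registers each on its parent node if
/// that parent is tracked.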
fn binary_einsum<B, C>(
    subscripts: &str,
    lhs: FloatTensor<Autodiff<B, C>>,
    rhs: FloatTensor<Autodiff<B, C>>,
) -> FloatTensor<Autodiff<B, C>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    #[derive(Debug)]
    struct BinaryEinsum;

    impl<B: burn::tensor::backend::Backend<FloatElem = f64>> Backward<B, 2> for BinaryEinsum {
        type State = EinsumState<B::FloatTensorPrimitive>;

        fn backward(
            self,
            ops: Ops<Self::State, 2>,
            grads: &mut Gradients,
            _checkpointer: &mut Checkpointer,
        ) {
            let mut grad_iter = panic_on_error(try_rrule_grads::<B>(
                &ops.state.subscripts,
                &ops.state.inputs,
                grads.consume::<B>(&ops.node),
            ))
            .into_iter();

            if let Some(node) = ops.parents[0].clone() {
                let grad = panic_on_error(require_next(
                    &mut grad_iter,
                    "binary einsum rrule must return a gradient for lhs",
                ));
                grads.register::<B>(node.id, grad);
            }

            if let Some(node) = ops.parents[1].clone() {
                let grad = panic_on_error(require_next(
                    &mut grad_iter,
                    "binary einsum rrule must return a gradient for rhs",
                ));
                grads.register::<B>(node.id, grad);
            }
        }
    }

    let state = EinsumState {
        subscripts: subscripts.to_owned(),
        inputs: vec![lhs.primitive.clone(), rhs.primitive.clone()],
    };

    match BinaryEinsum
        .prepare::<C>([lhs.node.clone(), rhs.node.clone()])
        .compute_bound()
        .stateful()
    {
        OpsKind::Tracked(prep) => prep.finish(
            state,
            B::tn_einsum(subscripts, vec![lhs.primitive, rhs.primitive]),
        ),
        OpsKind::UnTracked(prep) => {
            prep.finish(B::tn_einsum(subscripts, vec![lhs.primitive, rhs.primitive]))
        }
    }
}

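/// Executes one (non-nested) einsum over any number of inputs: zero inputs is
/// an error, one and two dispatch straight to `unary_einsum` /
/// `binary_einsum`, and three or more are decomposed into a pairwise
/// contraction path via `ContractionTree::optimize` over the input shapes.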
fn try_execute_einsum_tree<B, C>(
    subscripts: &Subscripts,
    inputs: Vec<FloatTensor<Autodiff<B, C>>>,
) -> Result<FloatTensor<Autodiff<B, C>>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    match inputs.len() {
        0 => Err(Error::InvalidArgument(
            "tenferro-ext-burn autodiff einsum requires at least one input tensor".into(),
        )),
        1 => Ok(unary_einsum::<B, C>(
            &subscripts_to_notation(subscripts)?,
            inputs.into_iter().next().ok_or(Error::InternalInvariant(
                "unary einsum dispatch lost its only input",
            ))?,
        )),
        2 => {
            let mut iter = inputs.into_iter();
            let lhs = iter.next().ok_or(Error::InternalInvariant(
                "binary einsum dispatch lost its lhs input",
            ))?;
            let rhs = iter.next().ok_or(Error::InternalInvariant(
                "binary einsum dispatch lost its rhs input",
            ))?;
            Ok(binary_einsum::<B, C>(
                &subscripts_to_notation(subscripts)?,
                lhs,
                rhs,
            ))
        }
        n_inputs => {
            let shapes: Vec<Vec<usize>> = inputs.iter().map(|input| input.shape().dims).collect();
            let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
            let tree = ContractionTree::optimize(subscripts, &shape_refs).map_err(|err| {
                Error::InvalidArgument(format!(
                    "tenferro-ext-burn autodiff einsum could not optimize the pairwise contraction path: {err}"
                ))
            })?;
            // Slot table: inputs occupy the first n_inputs slots; step results
            // land in slot n_inputs + step_idx. Operands are `take`n so that
            // each tensor is consumed exactly once.
            let mut slots: Vec<Option<FloatTensor<Autodiff<B, C>>>> =
                inputs.into_iter().map(Some).collect();
            slots.resize(n_inputs + tree.step_count(), None);

            for step_idx in 0..tree.step_count() {
                let (left, right) = tree.step_pair(step_idx).ok_or(Error::InternalInvariant(
                    "contraction tree is missing a recorded step",
                ))?;
                let (lhs_subs, rhs_subs, out_subs) =
                    tree.step_subscripts(step_idx)
                        .ok_or(Error::InternalInvariant(
                            "contraction tree is missing step subscripts",
                        ))?;
                let lhs = slots[left].take().ok_or(Error::InternalInvariant(
                    "contraction tree referenced a consumed lhs operand",
                ))?;
                let rhs = slots[right].take().ok_or(Error::InternalInvariant(
                    "contraction tree referenced a consumed rhs operand",
                ))?;
                let step_notation = binary_step_notation(lhs_subs, rhs_subs, out_subs)?;
                let result = binary_einsum::<B, C>(&step_notation, lhs, rhs);
                slots[n_inputs + step_idx] = Some(result);
            }

            // The last occupied slot holds the root of the contraction tree.
            slots
                .into_iter()
                .rev()
                .flatten()
                .next()
                .ok_or(Error::InternalInvariant(
                    "contraction tree did not leave a final result",
                ))
        }
    }
}

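/// Recursively evaluates a `NestedEinsum`: a `Leaf` picks an input tensor by
/// index, while a `Node` first evaluates its children and then contracts the
/// child results with `try_execute_einsum_tree`.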
fn try_execute_nested_einsum<B, C>(
    nested: &NestedEinsum,
    inputs: &[FloatTensor<Autodiff<B, C>>],
) -> Result<FloatTensor<Autodiff<B, C>>>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    match nested {
        NestedEinsum::Leaf(index) => inputs.get(*index).cloned().ok_or(Error::InternalInvariant(
            "nested einsum referenced a missing input tensor",
        )),
        NestedEinsum::Node {
            subscripts,
            children,
        } => {
            let child_results = children
                .iter()
                .map(|child| try_execute_nested_einsum::<B, C>(child, inputs))
                .collect::<Result<Vec<_>>>()?;
            try_execute_einsum_tree::<B, C>(subscripts, child_results)
        }
    }
}

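/// Differentiable `tn_einsum` for the `Autodiff` decorator backend: the
/// subscripts are parsed as a (possibly nested) einsum and executed as a
/// sequence of tracked unary/binary contractions.
///
/// A minimal sketch of a call, assuming some inner backend `Inner` that
/// implements `TensorNetworkOps` (the backend name and tensors here are
/// hypothetical):
///
/// ```ignore
/// // c = a @ b, with gradients tracked for both operands.
/// let c = <Autodiff<Inner> as TensorNetworkOps>::tn_einsum("ij,jk->ik", vec![a, b]);
/// ```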
impl<B, C> TensorNetworkOps for Autodiff<B, C>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
{
    fn tn_einsum(subscripts: &str, inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self> {
        let nested = panic_on_error(NestedEinsum::parse(subscripts).map_err(|err| {
            Error::InvalidArgument(format!(
                "tenferro-ext-burn autodiff einsum received invalid subscripts or mismatched parentheses: {err}"
            ))
        }));
        panic_on_error(try_execute_nested_einsum::<B, C>(&nested, &inputs))
    }
}