1use std::collections::{HashMap, HashSet};
2
3use tenferro_tensor::{DotGeneralConfig, Error, Result, Tensor, TensorBackend, TensorExec};
4
5use crate::{ContractionTree, Subscripts};
6
/// Operation name attached to `Error::InvalidConfig` values raised by this module.
const EAGER_EINSUM_OP: &str = "eager_einsum";
8
/// A tensor that is either borrowed from the caller's input slice or owned
/// as an intermediate produced during contraction.
enum TensorValue<'a> {
    // Borrowed directly from the `inputs` slice; cloned only if it must
    // be returned as the final result.
    Borrowed(&'a Tensor),
    // Owned intermediate created by an exec op (reduce, transpose, ...).
    Owned(Tensor),
}
13
14impl TensorValue<'_> {
15 fn as_tensor(&self) -> &Tensor {
16 match self {
17 Self::Borrowed(tensor) => tensor,
18 Self::Owned(tensor) => tensor,
19 }
20 }
21
22 fn into_tensor(self) -> Tensor {
23 match self {
24 Self::Borrowed(tensor) => tensor.clone(),
25 Self::Owned(tensor) => tensor,
26 }
27 }
28}
29
/// A tensor paired with one einsum label per axis.
struct LabeledTensor<'a> {
    tensor: TensorValue<'a>,
    // labels[axis] is the subscript label of that axis; `label_size` and
    // the axis lookups below rely on this positional correspondence.
    labels: Vec<u32>,
}
34
35impl LabeledTensor<'_> {
36 fn tensor(&self) -> &Tensor {
37 self.tensor.as_tensor()
38 }
39
40 fn shape(&self) -> &[usize] {
41 self.tensor().shape()
42 }
43}
44
45fn eager_invalid_config(message: impl Into<String>) -> Error {
46 Error::InvalidConfig {
47 op: EAGER_EINSUM_OP,
48 message: message.into(),
49 }
50}
51
52fn take_labeled<'a>(
53 labeled: &mut [Option<LabeledTensor<'a>>],
54 index: usize,
55 role: &'static str,
56) -> Result<LabeledTensor<'a>> {
57 labeled
58 .get_mut(index)
59 .ok_or_else(|| eager_invalid_config(format!("missing {role} operand at index {index}")))?
60 .take()
61 .ok_or_else(|| eager_invalid_config(format!("missing {role} operand at index {index}")))
62}
63
64fn find_label_axis(labels: &[u32], label: u32) -> Result<usize> {
65 labels
66 .iter()
67 .position(|candidate| *candidate == label)
68 .ok_or_else(|| eager_invalid_config(format!("label {label} missing from tensor labels")))
69}
70
71fn label_size(label: u32, operands: &[&LabeledTensor<'_>]) -> Result<usize> {
72 for operand in operands {
73 if let Some(axis) = operand
74 .labels
75 .iter()
76 .position(|candidate| *candidate == label)
77 {
78 return Ok(operand.shape()[axis]);
79 }
80 }
81 Err(eager_invalid_config(format!(
82 "label {label} missing from eager einsum operands"
83 )))
84}
85
86fn reduce_tensor<'a>(
87 exec: &mut dyn TensorExec,
88 operand: LabeledTensor<'a>,
89 reduce_labels: &HashSet<u32>,
90) -> Result<LabeledTensor<'a>> {
91 if reduce_labels.is_empty() {
92 return Ok(operand);
93 }
94
95 let reduce_axes: Vec<usize> = operand
96 .labels
97 .iter()
98 .enumerate()
99 .filter(|(_, label)| reduce_labels.contains(label))
100 .map(|(axis, _)| axis)
101 .collect();
102 if reduce_axes.is_empty() {
103 return Ok(operand);
104 }
105
106 let reduce_set: HashSet<usize> = reduce_axes.iter().copied().collect();
107 let labels: Vec<u32> = operand
108 .labels
109 .iter()
110 .enumerate()
111 .filter(|(axis, _)| !reduce_set.contains(axis))
112 .map(|(_, label)| *label)
113 .collect();
114 let tensor = exec.reduce_sum(operand.tensor(), &reduce_axes)?;
115 Ok(LabeledTensor {
116 tensor: TensorValue::Owned(tensor),
117 labels,
118 })
119}
120
121fn diagonalize_repeated<'a>(
122 exec: &mut dyn TensorExec,
123 mut operand: LabeledTensor<'a>,
124) -> Result<LabeledTensor<'a>> {
125 loop {
126 let mut seen = HashMap::new();
127 let mut repeated_pair = None;
128 for (axis, label) in operand.labels.iter().copied().enumerate() {
129 if let Some(first_axis) = seen.insert(label, axis) {
130 repeated_pair = Some((first_axis, axis));
131 break;
132 }
133 }
134
135 let Some((axis_a, axis_b)) = repeated_pair else {
136 return Ok(operand);
137 };
138
139 let tensor = exec.extract_diagonal(operand.tensor(), axis_a, axis_b)?;
140 let mut labels = operand.labels;
141 labels.remove(axis_b);
142 operand = LabeledTensor {
143 tensor: TensorValue::Owned(tensor),
144 labels,
145 };
146 }
147}
148
149fn embed_repeated<'a>(
150 exec: &mut dyn TensorExec,
151 mut operand: LabeledTensor<'a>,
152 output_labels: &[u32],
153) -> Result<LabeledTensor<'a>> {
154 loop {
155 let mut embedded = false;
156 for &label in output_labels {
157 let current_count = operand
158 .labels
159 .iter()
160 .filter(|candidate| **candidate == label)
161 .count();
162 let output_count = output_labels
163 .iter()
164 .filter(|candidate| **candidate == label)
165 .count();
166 if output_count > current_count {
167 let axis_a = find_label_axis(&operand.labels, label)?;
168 let axis_b = axis_a + 1;
169 let tensor = exec.embed_diagonal(operand.tensor(), axis_a, axis_b)?;
170 let mut labels = operand.labels;
171 labels.insert(axis_b, label);
172 operand = LabeledTensor {
173 tensor: TensorValue::Owned(tensor),
174 labels,
175 };
176 embedded = true;
177 break;
178 }
179 }
180
181 if !embedded {
182 return Ok(operand);
183 }
184 }
185}
186
187fn transpose_to_labels<'a>(
188 exec: &mut dyn TensorExec,
189 operand: LabeledTensor<'a>,
190 target_labels: &[u32],
191) -> Result<LabeledTensor<'a>> {
192 if operand.labels == target_labels {
193 return Ok(operand);
194 }
195
196 let perm: Vec<usize> = target_labels
197 .iter()
198 .map(|label| find_label_axis(&operand.labels, *label))
199 .collect::<Result<_>>()?;
200 if perm
201 .iter()
202 .enumerate()
203 .all(|(axis, target)| axis == *target)
204 {
205 return Ok(operand);
206 }
207
208 let tensor = exec.transpose(operand.tensor(), &perm)?;
209 Ok(LabeledTensor {
210 tensor: TensorValue::Owned(tensor),
211 labels: target_labels.to_vec(),
212 })
213}
214
215fn outer_product<'a>(
216 exec: &mut dyn TensorExec,
217 lhs: LabeledTensor<'a>,
218 rhs: LabeledTensor<'a>,
219 batch_labels: &[u32],
220 lhs_free_labels: &[u32],
221 rhs_free_labels: &[u32],
222) -> Result<LabeledTensor<'a>> {
223 if lhs.labels == rhs.labels {
224 let tensor = exec.mul(lhs.tensor(), rhs.tensor())?;
225 return Ok(LabeledTensor {
226 tensor: TensorValue::Owned(tensor),
227 labels: lhs.labels,
228 });
229 }
230
231 let combined_labels: Vec<u32> = lhs_free_labels
232 .iter()
233 .chain(rhs_free_labels.iter())
234 .chain(batch_labels.iter())
235 .copied()
236 .collect();
237 let combined_shape: Vec<usize> = combined_labels
238 .iter()
239 .map(|label| label_size(*label, &[&lhs, &rhs]))
240 .collect::<Result<_>>()?;
241 let lhs_dims: Vec<usize> = lhs
242 .labels
243 .iter()
244 .map(|label| find_label_axis(&combined_labels, *label))
245 .collect::<Result<_>>()?;
246 let rhs_dims: Vec<usize> = rhs
247 .labels
248 .iter()
249 .map(|label| find_label_axis(&combined_labels, *label))
250 .collect::<Result<_>>()?;
251
252 let lhs_tensor = exec.broadcast_in_dim(lhs.tensor(), &combined_shape, &lhs_dims)?;
253 let rhs_tensor = exec.broadcast_in_dim(rhs.tensor(), &combined_shape, &rhs_dims)?;
254 let tensor = exec.mul(&lhs_tensor, &rhs_tensor)?;
255 Ok(LabeledTensor {
256 tensor: TensorValue::Owned(tensor),
257 labels: combined_labels,
258 })
259}
260
/// Contracts `lhs` with `rhs`, keeping only axes whose labels appear in
/// `survive_labels`.
///
/// Steps: (1) sum away labels private to one operand that do not survive;
/// (2) classify the remaining labels as batch (shared and surviving),
/// contracting (shared, not surviving), or free; (3) use an outer product
/// when nothing contracts, otherwise `dot_general`; (4) when
/// `reorder_result` is set, transpose the result into `survive_labels`
/// order (final step only — intermediates keep their natural layout).
fn binary_contract<'a>(
    exec: &mut dyn TensorExec,
    lhs: LabeledTensor<'a>,
    rhs: LabeledTensor<'a>,
    survive_labels: &[u32],
    reorder_result: bool,
) -> Result<LabeledTensor<'a>> {
    let survive_set: HashSet<u32> = survive_labels.iter().copied().collect();
    let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();
    let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();

    // Labels that occur in only one operand and are not needed downstream
    // can be summed away before the pairwise contraction.
    let lhs_reduce: HashSet<u32> = lhs
        .labels
        .iter()
        .filter(|label| !rhs_label_set.contains(label) && !survive_set.contains(label))
        .copied()
        .collect();
    let rhs_reduce: HashSet<u32> = rhs
        .labels
        .iter()
        .filter(|label| !lhs_label_set.contains(label) && !survive_set.contains(label))
        .copied()
        .collect();

    let lhs = reduce_tensor(exec, lhs, &lhs_reduce)?;
    let rhs = reduce_tensor(exec, rhs, &rhs_reduce)?;

    // Recompute label sets: reduction may have removed labels.
    let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();
    let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();

    let mut batch_labels = Vec::new();
    let mut contracting_labels = Vec::new();
    let mut lhs_free_labels = Vec::new();
    let mut rhs_free_labels = Vec::new();

    // Classify lhs labels; `contains` de-duplicates while preserving the
    // first-seen order of each category.
    for &label in &lhs.labels {
        if rhs_label_set.contains(&label) {
            if survive_set.contains(&label) {
                if !batch_labels.contains(&label) {
                    batch_labels.push(label);
                }
            } else if !contracting_labels.contains(&label) {
                contracting_labels.push(label);
            }
        } else if !lhs_free_labels.contains(&label) {
            lhs_free_labels.push(label);
        }
    }

    // Labels unique to rhs become rhs free labels.
    for &label in &rhs.labels {
        if !lhs_label_set.contains(&label) && !rhs_free_labels.contains(&label) {
            rhs_free_labels.push(label);
        }
    }

    let result = if contracting_labels.is_empty() {
        // Nothing to contract: pure broadcast-multiply.
        outer_product(
            exec,
            lhs,
            rhs,
            &batch_labels,
            &lhs_free_labels,
            &rhs_free_labels,
        )?
    } else {
        // Map each label category back to per-operand axis indices for
        // the dot_general configuration.
        let lhs_contracting_dims: Vec<usize> = contracting_labels
            .iter()
            .map(|label| find_label_axis(&lhs.labels, *label))
            .collect::<Result<_>>()?;
        let rhs_contracting_dims: Vec<usize> = contracting_labels
            .iter()
            .map(|label| find_label_axis(&rhs.labels, *label))
            .collect::<Result<_>>()?;
        let lhs_batch_dims: Vec<usize> = batch_labels
            .iter()
            .map(|label| find_label_axis(&lhs.labels, *label))
            .collect::<Result<_>>()?;
        let rhs_batch_dims: Vec<usize> = batch_labels
            .iter()
            .map(|label| find_label_axis(&rhs.labels, *label))
            .collect::<Result<_>>()?;
        // NOTE(review): result labels assume dot_general emits axes in
        // [lhs free, rhs free, batch] order — confirm against the backend.
        let labels: Vec<u32> = lhs_free_labels
            .iter()
            .chain(rhs_free_labels.iter())
            .chain(batch_labels.iter())
            .copied()
            .collect();
        let config = DotGeneralConfig {
            lhs_contracting_dims,
            rhs_contracting_dims,
            lhs_batch_dims,
            rhs_batch_dims,
            lhs_rank: lhs.labels.len(),
            rhs_rank: rhs.labels.len(),
        };
        let tensor = exec.dot_general(lhs.tensor(), rhs.tensor(), &config)?;
        LabeledTensor {
            tensor: TensorValue::Owned(tensor),
            labels,
        }
    };

    if !reorder_result {
        return Ok(result);
    }

    // Reorder to survive_labels, skipping labels the result does not carry.
    let result_label_set: HashSet<u32> = result.labels.iter().copied().collect();
    let target_labels: Vec<u32> = survive_labels
        .iter()
        .filter(|label| result_label_set.contains(label))
        .copied()
        .collect();
    transpose_to_labels(exec, result, &target_labels)
}
375
/// Executes the optimized contraction `tree` over `inputs` inside a single
/// exec session and returns the final tensor.
fn eager_einsum_exec(
    exec: &mut dyn TensorExec,
    inputs: &[&Tensor],
    tree: &ContractionTree,
) -> Result<Tensor> {
    let subscripts = &tree.subscripts;
    let n_inputs = subscripts.inputs.len();
    let output_labels = &subscripts.output;

    // Pair each input tensor with its subscript labels. Slots are consumed
    // (set to None) via take_labeled as the tree references them; step
    // results are appended after the inputs, so step k's output lives at
    // index n_inputs + k.
    let mut labeled: Vec<Option<LabeledTensor<'_>>> = inputs
        .iter()
        .zip(subscripts.inputs.iter())
        .map(|(tensor, labels)| {
            Some(LabeledTensor {
                tensor: TensorValue::Borrowed(tensor),
                labels: labels.clone(),
            })
        })
        .collect();

    // Collapse repeated labels within each operand (e.g. "ii") up front so
    // later stages see unique per-operand labels.
    for index in 0..labeled.len() {
        let operand = take_labeled(&mut labeled, index, "input")?;
        labeled[index] = Some(diagonalize_repeated(exec, operand)?);
    }

    // Single-operand (or trivial-tree) fast path: sum away non-output
    // labels, re-embed labels the output repeats, then reorder.
    if n_inputs == 1 || tree.step_count() == 0 {
        let operand = take_labeled(&mut labeled, 0, "input")?;
        let output_set: HashSet<u32> = output_labels.iter().copied().collect();
        let reduce_labels: HashSet<u32> = operand
            .labels
            .iter()
            .filter(|label| !output_set.contains(label))
            .copied()
            .collect();
        let reduced = reduce_tensor(exec, operand, &reduce_labels)?;
        let embedded = embed_repeated(exec, reduced, output_labels)?;
        let reordered = transpose_to_labels(exec, embedded, output_labels)?;
        return Ok(reordered.tensor.into_tensor());
    }

    // Pairwise contraction steps in tree order; only the final step asks
    // binary_contract to reorder its result.
    for step_idx in 0..tree.step_count() {
        let (left, right) = tree.step_pair(step_idx).ok_or_else(|| {
            eager_invalid_config(format!("missing contraction pair for step {step_idx}"))
        })?;
        let (_, _, step_output_labels) = tree.step_subscripts(step_idx).ok_or_else(|| {
            eager_invalid_config(format!(
                "missing contraction subscripts for step {step_idx}"
            ))
        })?;
        let lhs = take_labeled(&mut labeled, left, "lhs")?;
        let rhs = take_labeled(&mut labeled, right, "rhs")?;
        let result = binary_contract(
            exec,
            lhs,
            rhs,
            step_output_labels,
            step_idx + 1 == tree.step_count(),
        )?;
        labeled.push(Some(result));
    }

    // Final result is the last appended step output. Any labels not in the
    // output are summed away before the final reorder.
    let final_index = n_inputs + tree.step_count() - 1;
    let result = take_labeled(&mut labeled, final_index, "result")?;
    let output_set: HashSet<u32> = output_labels.iter().copied().collect();
    let extra_labels: HashSet<u32> = result
        .labels
        .iter()
        .filter(|label| !output_set.contains(label))
        .copied()
        .collect();
    let reduced = reduce_tensor(exec, result, &extra_labels)?;
    let reordered = transpose_to_labels(exec, reduced, output_labels)?;
    Ok(reordered.tensor.into_tensor())
}
450
451pub fn eager_einsum(
472 ctx: &mut impl TensorBackend,
473 inputs: &[&Tensor],
474 subscripts: &str,
475) -> Result<Tensor> {
476 if inputs.is_empty() {
477 return Err(eager_invalid_config(
478 "eager einsum requires at least one input tensor",
479 ));
480 }
481
482 let subs = Subscripts::parse(subscripts)
483 .map_err(|err| eager_invalid_config(format!("invalid subscripts: {err}")))?;
484 if subs.inputs.len() != inputs.len() {
485 return Err(eager_invalid_config(format!(
486 "eager einsum subscripts expect {} inputs, got {}",
487 subs.inputs.len(),
488 inputs.len()
489 )));
490 }
491
492 let shapes: Vec<&[usize]> = inputs.iter().map(|tensor| tensor.shape()).collect();
493 let tree = ContractionTree::optimize(&subs, &shapes).map_err(|err| {
494 eager_invalid_config(format!("failed to optimize contraction tree: {err}"))
495 })?;
496 ctx.with_exec_session(|exec| eager_einsum_exec(exec, inputs, &tree))
497}