1use std::collections::{HashMap, HashSet};
40
41use computegraph::fragment::FragmentBuilder;
42use computegraph::types::{OpMode, ValRef};
43use computegraph::GraphOp;
44
45use tenferro_ops::dim_expr::DimExpr;
46use tenferro_ops::semiring_ops::SemiringOps;
47use tenferro_tensor::DotGeneralConfig;
48
49use crate::planning::tree::ContractionTree;
50
51#[derive(Clone, Debug)]
52struct LabeledVal<Op: GraphOp> {
53 val: ValRef<Op>,
54 labels: Vec<u32>,
55 shape: Vec<usize>,
56}
57
/// Pairs every label with the extent found at the same position in `shape`.
///
/// The result has one `(label, extent)` entry per axis, in axis order. When
/// the two slices differ in length, the surplus tail of the longer one is
/// ignored.
fn label_size_map(labels: &[u32], shape: &[usize]) -> Vec<(u32, usize)> {
    let mut pairs = Vec::with_capacity(labels.len().min(shape.len()));
    for (&label, &extent) in labels.iter().zip(shape) {
        pairs.push((label, extent));
    }
    pairs
}
61
62fn reduce_val<Op: GraphOp + SemiringOps>(
63 builder: &mut FragmentBuilder<Op>,
64 lv: &LabeledVal<Op>,
65 reduce_labels: &HashSet<u32>,
66) -> LabeledVal<Op> {
67 if reduce_labels.is_empty() {
68 return lv.clone();
69 }
70 let reduce_axes: Vec<usize> = lv
71 .labels
72 .iter()
73 .enumerate()
74 .filter(|(_, l)| reduce_labels.contains(l))
75 .map(|(i, _)| i)
76 .collect();
77 if reduce_axes.is_empty() {
78 return lv.clone();
79 }
80 let reduce_set: HashSet<usize> = reduce_axes.iter().copied().collect();
81 let new_labels: Vec<u32> = lv
82 .labels
83 .iter()
84 .enumerate()
85 .filter(|(i, _)| !reduce_set.contains(i))
86 .map(|(_, &l)| l)
87 .collect();
88 let new_shape: Vec<usize> = lv
89 .shape
90 .iter()
91 .enumerate()
92 .filter(|(i, _)| !reduce_set.contains(i))
93 .map(|(_, &s)| s)
94 .collect();
95 let outputs = builder.add_op(
96 Op::reduce_sum(reduce_axes, DimExpr::input_shape(0, lv.shape.len())),
97 vec![lv.val.clone()],
98 OpMode::Primal,
99 );
100 LabeledVal {
101 val: ValRef::Local(outputs[0]),
102 labels: new_labels,
103 shape: new_shape,
104 }
105}
106
107fn embed_repeated<Op: GraphOp + SemiringOps>(
111 builder: &mut FragmentBuilder<Op>,
112 lv: &LabeledVal<Op>,
113 output_labels: &[u32],
114) -> LabeledVal<Op> {
115 let mut result = lv.clone();
117 for &label in output_labels {
118 let current_count = result.labels.iter().filter(|&&l| l == label).count();
119 let output_count = output_labels.iter().filter(|&&l| l == label).count();
120 if output_count > current_count {
121 let axis_a = result
124 .labels
125 .iter()
126 .position(|&l| l == label)
127 .expect("label must exist in current tensor for embedding");
128 let axis_b = axis_a + 1;
130 let n = result.shape[axis_a];
131 let outputs = builder.add_op(
132 Op::embed_diag(axis_a, axis_b),
133 vec![result.val.clone()],
134 OpMode::Primal,
135 );
136 let mut new_labels = result.labels.clone();
137 new_labels.insert(axis_b, label);
138 let mut new_shape = result.shape.clone();
139 new_shape.insert(axis_b, n);
140 result = LabeledVal {
141 val: ValRef::Local(outputs[0]),
142 labels: new_labels,
143 shape: new_shape,
144 };
145 return embed_repeated(builder, &result, output_labels);
147 }
148 }
149 result
150}
151
152fn diagonalize_repeated<Op: GraphOp + SemiringOps>(
153 builder: &mut FragmentBuilder<Op>,
154 lv: &LabeledVal<Op>,
155) -> LabeledVal<Op> {
156 let mut seen: HashMap<u32, usize> = HashMap::new();
157 for (i, &label) in lv.labels.iter().enumerate() {
158 if let Some(&first) = seen.get(&label) {
159 let outputs = builder.add_op(
161 Op::extract_diag(first, i),
162 vec![lv.val.clone()],
163 OpMode::Primal,
164 );
165 let mut new_labels = lv.labels.clone();
166 new_labels.remove(i);
167 let mut new_shape = lv.shape.clone();
168 new_shape.remove(i);
169 let result = LabeledVal {
170 val: ValRef::Local(outputs[0]),
171 labels: new_labels,
172 shape: new_shape,
173 };
174 return diagonalize_repeated(builder, &result);
176 }
177 seen.insert(label, i);
178 }
179 lv.clone()
180}
181
182fn binary_contract<Op: GraphOp + SemiringOps>(
183 builder: &mut FragmentBuilder<Op>,
184 lhs: &LabeledVal<Op>,
185 rhs: &LabeledVal<Op>,
186 survive_labels: &[u32],
187 reorder_result: bool,
188) -> LabeledVal<Op> {
189 let survive_set: HashSet<u32> = survive_labels.iter().copied().collect();
190 let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();
191 let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();
192
193 let lhs_reduce: HashSet<u32> = lhs
195 .labels
196 .iter()
197 .filter(|l| !rhs_label_set.contains(l) && !survive_set.contains(l))
198 .copied()
199 .collect();
200 let rhs_reduce: HashSet<u32> = rhs
201 .labels
202 .iter()
203 .filter(|l| !lhs_label_set.contains(l) && !survive_set.contains(l))
204 .copied()
205 .collect();
206
207 let lhs = reduce_val(builder, lhs, &lhs_reduce);
208 let rhs = reduce_val(builder, rhs, &rhs_reduce);
209
210 let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();
211 let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();
212
213 let mut batch_labels = Vec::new();
215 let mut contracting_labels = Vec::new();
216 let mut lhs_free_labels = Vec::new();
217 let mut rhs_free_labels = Vec::new();
218
219 for &l in &lhs.labels {
221 if rhs_label_set.contains(&l) {
222 if survive_set.contains(&l) {
223 if !batch_labels.contains(&l) {
224 batch_labels.push(l);
225 }
226 } else if !contracting_labels.contains(&l) {
227 contracting_labels.push(l);
228 }
229 } else if !lhs_free_labels.contains(&l) {
230 lhs_free_labels.push(l);
231 }
232 }
233
234 for &l in &rhs.labels {
235 if !lhs_label_set.contains(&l) && !rhs_free_labels.contains(&l) {
236 rhs_free_labels.push(l);
237 }
238 }
239
240 let lhs_sizes: Vec<(u32, usize)> = label_size_map(&lhs.labels, &lhs.shape);
242 let rhs_sizes: Vec<(u32, usize)> = label_size_map(&rhs.labels, &rhs.shape);
243
244 let label_to_size = |l: u32| -> usize {
245 for &(label, size) in &lhs_sizes {
246 if label == l {
247 return size;
248 }
249 }
250 for &(label, size) in &rhs_sizes {
251 if label == l {
252 return size;
253 }
254 }
255 panic!("label {} not found in any operand", l);
256 };
257
258 let result = if !contracting_labels.is_empty() {
259 let lhs_contracting_dims: Vec<usize> = contracting_labels
261 .iter()
262 .map(|l| lhs.labels.iter().position(|x| x == l).unwrap())
263 .collect();
264 let rhs_contracting_dims: Vec<usize> = contracting_labels
265 .iter()
266 .map(|l| rhs.labels.iter().position(|x| x == l).unwrap())
267 .collect();
268 let lhs_batch_dims: Vec<usize> = batch_labels
269 .iter()
270 .map(|l| lhs.labels.iter().position(|x| x == l).unwrap())
271 .collect();
272 let rhs_batch_dims: Vec<usize> = batch_labels
273 .iter()
274 .map(|l| rhs.labels.iter().position(|x| x == l).unwrap())
275 .collect();
276
277 let config = DotGeneralConfig {
278 lhs_contracting_dims,
279 rhs_contracting_dims,
280 lhs_batch_dims,
281 rhs_batch_dims,
282 lhs_rank: lhs.shape.len(),
283 rhs_rank: rhs.shape.len(),
284 };
285
286 let result_labels: Vec<u32> = lhs_free_labels
288 .iter()
289 .chain(rhs_free_labels.iter())
290 .chain(batch_labels.iter())
291 .copied()
292 .collect();
293 let result_shape: Vec<usize> = result_labels.iter().map(|&l| label_to_size(l)).collect();
294
295 let outputs = builder.add_op(
296 Op::dot_general(config),
297 vec![lhs.val.clone(), rhs.val.clone()],
298 OpMode::Primal,
299 );
300
301 LabeledVal {
302 val: ValRef::Local(outputs[0]),
303 labels: result_labels,
304 shape: result_shape,
305 }
306 } else {
307 outer_product(
309 builder,
310 &lhs,
311 &rhs,
312 &batch_labels,
313 &lhs_free_labels,
314 &rhs_free_labels,
315 &label_to_size,
316 )
317 };
318
319 if !reorder_result {
320 return result;
321 }
322
323 let current_labels = &result.labels;
325 if current_labels.is_empty() {
326 return result;
327 }
328
329 let result_label_set: HashSet<u32> = current_labels.iter().copied().collect();
331 let target_labels: Vec<u32> = survive_labels
332 .iter()
333 .filter(|l| result_label_set.contains(l))
334 .copied()
335 .collect();
336
337 if current_labels.len() == target_labels.len() && *current_labels == target_labels {
338 return result;
339 }
340
341 let perm: Vec<usize> = target_labels
343 .iter()
344 .map(|l| current_labels.iter().position(|x| x == l).unwrap())
345 .collect();
346
347 if perm.iter().enumerate().all(|(i, &p)| i == p) {
348 return result;
349 }
350
351 let new_shape: Vec<usize> = perm.iter().map(|&p| result.shape[p]).collect();
352 let outputs = builder.add_op(
353 Op::transpose_op(perm),
354 vec![result.val.clone()],
355 OpMode::Primal,
356 );
357
358 LabeledVal {
359 val: ValRef::Local(outputs[0]),
360 labels: target_labels,
361 shape: new_shape,
362 }
363}
364
365fn outer_product<Op: GraphOp + SemiringOps>(
366 builder: &mut FragmentBuilder<Op>,
367 lhs: &LabeledVal<Op>,
368 rhs: &LabeledVal<Op>,
369 batch_labels: &[u32],
370 lhs_free_labels: &[u32],
371 rhs_free_labels: &[u32],
372 label_to_size: &dyn Fn(u32) -> usize,
373) -> LabeledVal<Op> {
374 let combined_labels: Vec<u32> = lhs_free_labels
375 .iter()
376 .chain(rhs_free_labels.iter())
377 .chain(batch_labels.iter())
378 .copied()
379 .collect();
380 let combined_shape: Vec<usize> = combined_labels.iter().map(|&l| label_to_size(l)).collect();
381
382 if lhs.labels == rhs.labels {
383 let outputs = builder.add_op(
385 Op::mul_op(),
386 vec![lhs.val.clone(), rhs.val.clone()],
387 OpMode::Primal,
388 );
389 return LabeledVal {
390 val: ValRef::Local(outputs[0]),
391 labels: lhs.labels.clone(),
392 shape: lhs.shape.clone(),
393 };
394 }
395
396 let lhs_dims: Vec<usize> = lhs
398 .labels
399 .iter()
400 .map(|l| combined_labels.iter().position(|x| x == l).unwrap())
401 .collect();
402 let rhs_dims: Vec<usize> = rhs
403 .labels
404 .iter()
405 .map(|l| combined_labels.iter().position(|x| x == l).unwrap())
406 .collect();
407
408 let lhs_bc = builder.add_op(
409 Op::broadcast_in_dim(DimExpr::from_concrete(&combined_shape), lhs_dims),
410 vec![lhs.val.clone()],
411 OpMode::Primal,
412 );
413 let rhs_bc = builder.add_op(
414 Op::broadcast_in_dim(DimExpr::from_concrete(&combined_shape), rhs_dims),
415 vec![rhs.val.clone()],
416 OpMode::Primal,
417 );
418 let outputs = builder.add_op(
419 Op::mul_op(),
420 vec![ValRef::Local(lhs_bc[0]), ValRef::Local(rhs_bc[0])],
421 OpMode::Primal,
422 );
423 LabeledVal {
424 val: ValRef::Local(outputs[0]),
425 labels: combined_labels,
426 shape: combined_shape,
427 }
428}
429
430pub fn build_einsum_fragment<Op: GraphOp + SemiringOps>(
431 builder: &mut FragmentBuilder<Op>,
432 tree: &ContractionTree,
433 input_vals: &[ValRef<Op>],
434 input_shapes: &[Vec<usize>],
435) -> ValRef<Op> {
436 let subscripts = &tree.subscripts;
437 let n_inputs = subscripts.inputs.len();
438 assert_eq!(
439 n_inputs,
440 input_vals.len(),
441 "number of subscripts inputs must match number of input values"
442 );
443 assert_eq!(
444 input_vals.len(),
445 input_shapes.len(),
446 "number of input values must match number of input shapes"
447 );
448
449 let output_labels = &subscripts.output;
450
451 let mut labeled: Vec<LabeledVal<Op>> = input_vals
452 .iter()
453 .zip(subscripts.inputs.iter())
454 .zip(input_shapes.iter())
455 .map(|((val, labels), shape)| {
456 assert_eq!(
457 labels.len(),
458 shape.len(),
459 "labels length must match shape rank"
460 );
461 LabeledVal {
462 val: val.clone(),
463 labels: labels.clone(),
464 shape: shape.clone(),
465 }
466 })
467 .collect();
468
469 for lv in &mut labeled {
471 *lv = diagonalize_repeated(builder, lv);
472 }
473
474 if n_inputs == 1 || tree.step_count() == 0 {
475 let lv = &labeled[0];
477 let output_set: HashSet<u32> = output_labels.iter().copied().collect();
478 let reduce_labels: HashSet<u32> = lv
479 .labels
480 .iter()
481 .filter(|l| !output_set.contains(l))
482 .copied()
483 .collect();
484 let result = reduce_val(builder, lv, &reduce_labels);
485
486 let result = embed_repeated(builder, &result, output_labels);
488
489 if result.labels == *output_labels {
491 return result.val;
492 }
493 let perm: Vec<usize> = output_labels
494 .iter()
495 .map(|l| result.labels.iter().position(|x| x == l).unwrap())
496 .collect();
497 if perm.iter().enumerate().all(|(i, &p)| i == p) {
498 return result.val;
499 }
500 let outputs = builder.add_op(Op::transpose_op(perm), vec![result.val], OpMode::Primal);
501 return ValRef::Local(outputs[0]);
502 }
503
504 for step_idx in 0..tree.step_count() {
507 let (left, right) = tree.step_pair(step_idx).unwrap();
508 let (_, _, step_out_labels) = tree.step_subscripts(step_idx).unwrap();
511 let is_last = step_idx + 1 == tree.step_count();
512 let result = binary_contract(
513 builder,
514 &labeled[left],
515 &labeled[right],
516 step_out_labels,
517 is_last,
518 );
519 labeled.push(result);
521 }
522
523 let final_idx = n_inputs + tree.step_count() - 1;
525 let result = &labeled[final_idx];
526
527 let output_set: HashSet<u32> = output_labels.iter().copied().collect();
529 let extra_labels: HashSet<u32> = result
530 .labels
531 .iter()
532 .filter(|l| !output_set.contains(l))
533 .copied()
534 .collect();
535 let result = reduce_val(builder, result, &extra_labels);
536
537 if result.labels == *output_labels {
539 return result.val;
540 }
541
542 if result.labels.is_empty() && output_labels.is_empty() {
543 return result.val;
544 }
545
546 let perm: Vec<usize> = output_labels
547 .iter()
548 .map(|l| result.labels.iter().position(|x| x == l).unwrap())
549 .collect();
550 if perm.iter().enumerate().all(|(i, &p)| i == p) {
551 return result.val;
552 }
553 let outputs = builder.add_op(
554 Op::transpose_op(perm),
555 vec![result.val.clone()],
556 OpMode::Primal,
557 );
558 ValRef::Local(outputs[0])
559}