// tenferro_einsum/builder.rs

// TODO: Remaining einsum optimizations
//
// The current v2 einsum lowering is correct and already removes some
// intermediate permutations by keeping N-ary intermediates in canonical dot
// order. The following optimizations from v1 / the spec are still partial or
// not yet implemented:
//
// Compiler passes (spec: optimizer-passes.md):
//   - TransposeFolding: partially absorb Transpose into DotGeneral
//     dimension_numbers when the free/contract/batch axis order remains
//     compatible with the lowering.
//     v1 equivalent: lazy permutation (dispatch.rs:446-454).
//     Impact: eliminates physical copies for supported permuted GEMM inputs.
//   - DotDimensionSorter: sort contracting dims to avoid transposes.
//     v1 equivalent: implicit (modes already ordered).
//   - DotDecomposer: canonicalize DotGeneral to [batch, M, K] × [batch, K, N].
//     v1 equivalent: fusability check + partial materialization.
//     Impact: maps arbitrary DotGeneral to BatchedGemm without extra copies.
//   - ReductionSimplification: hoist independent ReduceSum before DotGeneral.
//     v1 equivalent: pre-reduction (dispatch.rs:121-139).
//
// Einsum-level optimizations:
//   - Diagonal embedding ("i->ii"): requires Scatter op (not yet implemented).
//     NOTE(review): `embed_repeated` below emits Op::embed_diag for this case —
//     confirm whether this item is stale.
//   - Hyper-edge einsum ("ik,k,kj->ij"): 3+ tensors sharing an index.
//     Currently decomposed into binary steps; v1 handled this with a
//     specialized dispatch path.
//   - Binary diagonal ("ii,jk->ijk"): v1 diagonal plan in dispatch.rs.
//     Currently works via ExtractDiag + standard contraction, but v1 had
//     fused paths for better performance.
//
// Execution-level optimizations:
//   - Stride-aware engine: v1 inspects strides at dispatch time and uses
//     BLAS trans flags for transposed inputs. v2 engine does physical copies.
//   - Buffer pooling: v1 reuses buffers via Arc refcount + pool.
//     v2 has last_use liveness analysis but no pool.
//   - SemiringFastPath: optional fused patterns (contract, elementwise_mul/add).
//     Trait exists but no implementation.

use std::collections::{HashMap, HashSet};

use computegraph::fragment::FragmentBuilder;
use computegraph::types::{OpMode, ValRef};
use computegraph::GraphOp;

use tenferro_ops::dim_expr::DimExpr;
use tenferro_ops::semiring_ops::SemiringOps;
use tenferro_tensor::DotGeneralConfig;

use crate::planning::tree::ContractionTree;

/// A value in the fragment graph tagged with einsum metadata.
///
/// Invariant: `labels.len() == shape.len()` — one label and one extent per
/// axis (asserted at construction in `build_einsum_fragment`).
#[derive(Clone, Debug)]
struct LabeledVal<Op: GraphOp> {
    /// The graph value this metadata describes.
    val: ValRef<Op>,
    /// Einsum index label for each axis.
    labels: Vec<u32>,
    /// Concrete extent of each axis.
    shape: Vec<usize>,
}

/// Pair each axis label with the corresponding axis extent.
///
/// The output has one `(label, size)` entry per axis, in axis order. If the
/// slices differ in length, pairing stops at the shorter one (zip semantics).
fn label_size_map(labels: &[u32], shape: &[usize]) -> Vec<(u32, usize)> {
    let mut pairs = Vec::with_capacity(labels.len());
    for (&label, &size) in labels.iter().zip(shape.iter()) {
        pairs.push((label, size));
    }
    pairs
}

62fn reduce_val<Op: GraphOp + SemiringOps>(
63    builder: &mut FragmentBuilder<Op>,
64    lv: &LabeledVal<Op>,
65    reduce_labels: &HashSet<u32>,
66) -> LabeledVal<Op> {
67    if reduce_labels.is_empty() {
68        return lv.clone();
69    }
70    let reduce_axes: Vec<usize> = lv
71        .labels
72        .iter()
73        .enumerate()
74        .filter(|(_, l)| reduce_labels.contains(l))
75        .map(|(i, _)| i)
76        .collect();
77    if reduce_axes.is_empty() {
78        return lv.clone();
79    }
80    let reduce_set: HashSet<usize> = reduce_axes.iter().copied().collect();
81    let new_labels: Vec<u32> = lv
82        .labels
83        .iter()
84        .enumerate()
85        .filter(|(i, _)| !reduce_set.contains(i))
86        .map(|(_, &l)| l)
87        .collect();
88    let new_shape: Vec<usize> = lv
89        .shape
90        .iter()
91        .enumerate()
92        .filter(|(i, _)| !reduce_set.contains(i))
93        .map(|(_, &s)| s)
94        .collect();
95    let outputs = builder.add_op(
96        Op::reduce_sum(reduce_axes, DimExpr::input_shape(0, lv.shape.len())),
97        vec![lv.val.clone()],
98        OpMode::Primal,
99    );
100    LabeledVal {
101        val: ValRef::Local(outputs[0]),
102        labels: new_labels,
103        shape: new_shape,
104    }
105}
106
107/// Embed diagonal axes when the output requires higher multiplicity of a label
108/// than the current tensor has. For example, "i->ii" needs to embed axis 0
109/// into a new axis 1 of the same size.
110fn embed_repeated<Op: GraphOp + SemiringOps>(
111    builder: &mut FragmentBuilder<Op>,
112    lv: &LabeledVal<Op>,
113    output_labels: &[u32],
114) -> LabeledVal<Op> {
115    // Count how many times each label appears in output vs current labels.
116    let mut result = lv.clone();
117    for &label in output_labels {
118        let current_count = result.labels.iter().filter(|&&l| l == label).count();
119        let output_count = output_labels.iter().filter(|&&l| l == label).count();
120        if output_count > current_count {
121            // Need to embed: find the existing axis with this label and
122            // insert a duplicate axis after it.
123            let axis_a = result
124                .labels
125                .iter()
126                .position(|&l| l == label)
127                .expect("label must exist in current tensor for embedding");
128            // Insert the new axis right after axis_a.
129            let axis_b = axis_a + 1;
130            let n = result.shape[axis_a];
131            let outputs = builder.add_op(
132                Op::embed_diag(axis_a, axis_b),
133                vec![result.val.clone()],
134                OpMode::Primal,
135            );
136            let mut new_labels = result.labels.clone();
137            new_labels.insert(axis_b, label);
138            let mut new_shape = result.shape.clone();
139            new_shape.insert(axis_b, n);
140            result = LabeledVal {
141                val: ValRef::Local(outputs[0]),
142                labels: new_labels,
143                shape: new_shape,
144            };
145            // Recurse to handle cases like "i->iii" (multiple embeddings).
146            return embed_repeated(builder, &result, output_labels);
147        }
148    }
149    result
150}
151
152fn diagonalize_repeated<Op: GraphOp + SemiringOps>(
153    builder: &mut FragmentBuilder<Op>,
154    lv: &LabeledVal<Op>,
155) -> LabeledVal<Op> {
156    let mut seen: HashMap<u32, usize> = HashMap::new();
157    for (i, &label) in lv.labels.iter().enumerate() {
158        if let Some(&first) = seen.get(&label) {
159            // Found repeated label at axes `first` and `i`
160            let outputs = builder.add_op(
161                Op::extract_diag(first, i),
162                vec![lv.val.clone()],
163                OpMode::Primal,
164            );
165            let mut new_labels = lv.labels.clone();
166            new_labels.remove(i);
167            let mut new_shape = lv.shape.clone();
168            new_shape.remove(i);
169            let result = LabeledVal {
170                val: ValRef::Local(outputs[0]),
171                labels: new_labels,
172                shape: new_shape,
173            };
174            // Recurse in case there are more repeated labels
175            return diagonalize_repeated(builder, &result);
176        }
177        seen.insert(label, i);
178    }
179    lv.clone()
180}
181
/// Contract two labeled operands into one, keeping exactly the labels listed
/// in `survive_labels`.
///
/// Pipeline:
/// 1. Pre-reduce: sum out labels that appear in only one operand and do not
///    survive, before the contraction (cf. v1 "pre-reduction").
/// 2. Classify remaining labels: shared + surviving -> batch, shared + not
///    surviving -> contracting, one-sided -> free.
/// 3. Emit a DotGeneral when contracting labels exist, otherwise a
///    broadcasted element-wise multiply (`outer_product`).
/// 4. If `reorder_result`, transpose the result into `survive_labels` order;
///    otherwise leave it in canonical dot order (lhs free + rhs free + batch)
///    so intermediate steps avoid permutation copies.
fn binary_contract<Op: GraphOp + SemiringOps>(
    builder: &mut FragmentBuilder<Op>,
    lhs: &LabeledVal<Op>,
    rhs: &LabeledVal<Op>,
    survive_labels: &[u32],
    reorder_result: bool,
) -> LabeledVal<Op> {
    let survive_set: HashSet<u32> = survive_labels.iter().copied().collect();
    let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();
    let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();

    // Pre-reduce: labels in lhs only, not in rhs and not in output
    let lhs_reduce: HashSet<u32> = lhs
        .labels
        .iter()
        .filter(|l| !rhs_label_set.contains(l) && !survive_set.contains(l))
        .copied()
        .collect();
    let rhs_reduce: HashSet<u32> = rhs
        .labels
        .iter()
        .filter(|l| !lhs_label_set.contains(l) && !survive_set.contains(l))
        .copied()
        .collect();

    let lhs = reduce_val(builder, lhs, &lhs_reduce);
    let rhs = reduce_val(builder, rhs, &rhs_reduce);

    // Label sets may have shrunk after pre-reduction; recompute them.
    let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();
    let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();

    // Classify labels
    let mut batch_labels = Vec::new();
    let mut contracting_labels = Vec::new();
    let mut lhs_free_labels = Vec::new();
    let mut rhs_free_labels = Vec::new();

    // Preserve order from lhs for batch and contracting
    for &l in &lhs.labels {
        if rhs_label_set.contains(&l) {
            if survive_set.contains(&l) {
                if !batch_labels.contains(&l) {
                    batch_labels.push(l);
                }
            } else if !contracting_labels.contains(&l) {
                contracting_labels.push(l);
            }
        } else if !lhs_free_labels.contains(&l) {
            lhs_free_labels.push(l);
        }
    }

    // rhs-only labels become rhs free dims, in rhs order.
    for &l in &rhs.labels {
        if !lhs_label_set.contains(&l) && !rhs_free_labels.contains(&l) {
            rhs_free_labels.push(l);
        }
    }

    // Build label->size map
    let lhs_sizes: Vec<(u32, usize)> = label_size_map(&lhs.labels, &lhs.shape);
    let rhs_sizes: Vec<(u32, usize)> = label_size_map(&rhs.labels, &rhs.shape);

    // Linear scan lookup; operand ranks are small so this beats a map.
    // Panics if a queried label is in neither operand (a planner bug).
    let label_to_size = |l: u32| -> usize {
        for &(label, size) in &lhs_sizes {
            if label == l {
                return size;
            }
        }
        for &(label, size) in &rhs_sizes {
            if label == l {
                return size;
            }
        }
        panic!("label {} not found in any operand", l);
    };

    let result = if !contracting_labels.is_empty() {
        // Use DotGeneral
        let lhs_contracting_dims: Vec<usize> = contracting_labels
            .iter()
            .map(|l| lhs.labels.iter().position(|x| x == l).unwrap())
            .collect();
        let rhs_contracting_dims: Vec<usize> = contracting_labels
            .iter()
            .map(|l| rhs.labels.iter().position(|x| x == l).unwrap())
            .collect();
        let lhs_batch_dims: Vec<usize> = batch_labels
            .iter()
            .map(|l| lhs.labels.iter().position(|x| x == l).unwrap())
            .collect();
        let rhs_batch_dims: Vec<usize> = batch_labels
            .iter()
            .map(|l| rhs.labels.iter().position(|x| x == l).unwrap())
            .collect();

        let config = DotGeneralConfig {
            lhs_contracting_dims,
            rhs_contracting_dims,
            lhs_batch_dims,
            rhs_batch_dims,
            lhs_rank: lhs.shape.len(),
            rhs_rank: rhs.shape.len(),
        };

        // DotGeneral output order: lhs_free + rhs_free + batch (col-major batch trailing)
        let result_labels: Vec<u32> = lhs_free_labels
            .iter()
            .chain(rhs_free_labels.iter())
            .chain(batch_labels.iter())
            .copied()
            .collect();
        let result_shape: Vec<usize> = result_labels.iter().map(|&l| label_to_size(l)).collect();

        let outputs = builder.add_op(
            Op::dot_general(config),
            vec![lhs.val.clone(), rhs.val.clone()],
            OpMode::Primal,
        );

        LabeledVal {
            val: ValRef::Local(outputs[0]),
            labels: result_labels,
            shape: result_shape,
        }
    } else {
        // No contracting dims -> element-wise multiply with broadcasting
        outer_product(
            builder,
            &lhs,
            &rhs,
            &batch_labels,
            &lhs_free_labels,
            &rhs_free_labels,
            &label_to_size,
        )
    };

    if !reorder_result {
        return result;
    }

    // Reorder to match the caller-visible order if needed.
    let current_labels = &result.labels;
    if current_labels.is_empty() {
        return result;
    }

    // Filter survivor labels to those present in result (to handle final reduction later)
    let result_label_set: HashSet<u32> = current_labels.iter().copied().collect();
    let target_labels: Vec<u32> = survive_labels
        .iter()
        .filter(|l| result_label_set.contains(l))
        .copied()
        .collect();

    // NOTE(review): the length check is redundant — Vec equality already
    // compares lengths — but harmless.
    if current_labels.len() == target_labels.len() && *current_labels == target_labels {
        return result;
    }

    // Build permutation
    let perm: Vec<usize> = target_labels
        .iter()
        .map(|l| current_labels.iter().position(|x| x == l).unwrap())
        .collect();

    // Identity permutation: skip the transpose entirely.
    if perm.iter().enumerate().all(|(i, &p)| i == p) {
        return result;
    }

    let new_shape: Vec<usize> = perm.iter().map(|&p| result.shape[p]).collect();
    let outputs = builder.add_op(
        Op::transpose_op(perm),
        vec![result.val.clone()],
        OpMode::Primal,
    );

    LabeledVal {
        val: ValRef::Local(outputs[0]),
        labels: target_labels,
        shape: new_shape,
    }
}

365fn outer_product<Op: GraphOp + SemiringOps>(
366    builder: &mut FragmentBuilder<Op>,
367    lhs: &LabeledVal<Op>,
368    rhs: &LabeledVal<Op>,
369    batch_labels: &[u32],
370    lhs_free_labels: &[u32],
371    rhs_free_labels: &[u32],
372    label_to_size: &dyn Fn(u32) -> usize,
373) -> LabeledVal<Op> {
374    let combined_labels: Vec<u32> = lhs_free_labels
375        .iter()
376        .chain(rhs_free_labels.iter())
377        .chain(batch_labels.iter())
378        .copied()
379        .collect();
380    let combined_shape: Vec<usize> = combined_labels.iter().map(|&l| label_to_size(l)).collect();
381
382    if lhs.labels == rhs.labels {
383        // Same labels: just Mul
384        let outputs = builder.add_op(
385            Op::mul_op(),
386            vec![lhs.val.clone(), rhs.val.clone()],
387            OpMode::Primal,
388        );
389        return LabeledVal {
390            val: ValRef::Local(outputs[0]),
391            labels: lhs.labels.clone(),
392            shape: lhs.shape.clone(),
393        };
394    }
395
396    // Broadcast both to combined shape, then Mul
397    let lhs_dims: Vec<usize> = lhs
398        .labels
399        .iter()
400        .map(|l| combined_labels.iter().position(|x| x == l).unwrap())
401        .collect();
402    let rhs_dims: Vec<usize> = rhs
403        .labels
404        .iter()
405        .map(|l| combined_labels.iter().position(|x| x == l).unwrap())
406        .collect();
407
408    let lhs_bc = builder.add_op(
409        Op::broadcast_in_dim(DimExpr::from_concrete(&combined_shape), lhs_dims),
410        vec![lhs.val.clone()],
411        OpMode::Primal,
412    );
413    let rhs_bc = builder.add_op(
414        Op::broadcast_in_dim(DimExpr::from_concrete(&combined_shape), rhs_dims),
415        vec![rhs.val.clone()],
416        OpMode::Primal,
417    );
418    let outputs = builder.add_op(
419        Op::mul_op(),
420        vec![ValRef::Local(lhs_bc[0]), ValRef::Local(rhs_bc[0])],
421        OpMode::Primal,
422    );
423    LabeledVal {
424        val: ValRef::Local(outputs[0]),
425        labels: combined_labels,
426        shape: combined_shape,
427    }
428}
429
/// Lower the einsum described by `tree` into ops appended to `builder`,
/// returning the final output value.
///
/// `input_vals` and `input_shapes` pair up positionally with
/// `tree.subscripts.inputs`; the counts and per-input ranks are asserted.
/// Repeated labels inside a single operand are first collapsed via
/// `diagonalize_repeated`; then either the unary path (reduce / embed /
/// transpose) or the contraction tree's binary steps produce the result in
/// `subscripts.output` label order.
///
/// # Panics
/// On mismatched input counts, label/shape rank mismatches, or output labels
/// absent from the computed result (planner bugs).
pub fn build_einsum_fragment<Op: GraphOp + SemiringOps>(
    builder: &mut FragmentBuilder<Op>,
    tree: &ContractionTree,
    input_vals: &[ValRef<Op>],
    input_shapes: &[Vec<usize>],
) -> ValRef<Op> {
    let subscripts = &tree.subscripts;
    let n_inputs = subscripts.inputs.len();
    assert_eq!(
        n_inputs,
        input_vals.len(),
        "number of subscripts inputs must match number of input values"
    );
    assert_eq!(
        input_vals.len(),
        input_shapes.len(),
        "number of input values must match number of input shapes"
    );

    let output_labels = &subscripts.output;

    // Attach einsum labels and concrete shapes to each input value.
    let mut labeled: Vec<LabeledVal<Op>> = input_vals
        .iter()
        .zip(subscripts.inputs.iter())
        .zip(input_shapes.iter())
        .map(|((val, labels), shape)| {
            assert_eq!(
                labels.len(),
                shape.len(),
                "labels length must match shape rank"
            );
            LabeledVal {
                val: val.clone(),
                labels: labels.clone(),
                shape: shape.clone(),
            }
        })
        .collect();

    // Diagonalize repeated indices in each input
    for lv in &mut labeled {
        *lv = diagonalize_repeated(builder, lv);
    }

    // NOTE(review): when step_count() == 0 but n_inputs > 1, only labeled[0]
    // is consumed below — presumably the planner never emits that
    // combination; confirm against ContractionTree's invariants.
    if n_inputs == 1 || tree.step_count() == 0 {
        // Unary: reduce, embed, and reorder
        let lv = &labeled[0];
        let output_set: HashSet<u32> = output_labels.iter().copied().collect();
        // Labels absent from the output get summed out.
        let reduce_labels: HashSet<u32> = lv
            .labels
            .iter()
            .filter(|l| !output_set.contains(l))
            .copied()
            .collect();
        let result = reduce_val(builder, lv, &reduce_labels);

        // Embed diagonal axes if output needs higher multiplicity
        let result = embed_repeated(builder, &result, output_labels);

        // Reorder if needed
        if result.labels == *output_labels {
            return result.val;
        }
        let perm: Vec<usize> = output_labels
            .iter()
            .map(|l| result.labels.iter().position(|x| x == l).unwrap())
            .collect();
        if perm.iter().enumerate().all(|(i, &p)| i == p) {
            return result.val;
        }
        let outputs = builder.add_op(Op::transpose_op(perm), vec![result.val], OpMode::Primal);
        return ValRef::Local(outputs[0]);
    }

    // N >= 2: use contraction tree from v1
    // Operand indices: 0..n_inputs are originals, n_inputs+step_idx are intermediates
    for step_idx in 0..tree.step_count() {
        let (left, right) = tree.step_pair(step_idx).unwrap();
        // Use the step's intermediate output subscripts so that labels needed
        // by later contractions are preserved (not pre-reduced away).
        let (_, _, step_out_labels) = tree.step_subscripts(step_idx).unwrap();
        // Only the final step pays for a reorder; intermediates stay in
        // canonical dot order to avoid permutation copies.
        let is_last = step_idx + 1 == tree.step_count();
        let result = binary_contract(
            builder,
            &labeled[left],
            &labeled[right],
            step_out_labels,
            is_last,
        );
        // Push intermediate as new entry in labeled
        labeled.push(result);
    }

    // The final result is the last intermediate: labeled[n_inputs + step_count - 1]
    let final_idx = n_inputs + tree.step_count() - 1;
    let result = &labeled[final_idx];

    // Final reduction if result has labels not in output
    let output_set: HashSet<u32> = output_labels.iter().copied().collect();
    let extra_labels: HashSet<u32> = result
        .labels
        .iter()
        .filter(|l| !output_set.contains(l))
        .copied()
        .collect();
    let result = reduce_val(builder, result, &extra_labels);

    // Final reorder if needed
    if result.labels == *output_labels {
        return result.val;
    }

    // NOTE(review): appears unreachable — two empty vectors compare equal,
    // so the check above already returned; kept as-is for safety.
    if result.labels.is_empty() && output_labels.is_empty() {
        return result.val;
    }

    let perm: Vec<usize> = output_labels
        .iter()
        .map(|l| result.labels.iter().position(|x| x == l).unwrap())
        .collect();
    if perm.iter().enumerate().all(|(i, &p)| i == p) {
        return result.val;
    }
    let outputs = builder.add_op(
        Op::transpose_op(perm),
        vec![result.val.clone()],
        OpMode::Primal,
    );
    ValRef::Local(outputs[0])
}