// tenferro/einsum.rs
1//! N-ary einsum with configurable contraction strategy.
2//!
3//! This module provides free functions [`einsum`] and [`einsum_with`]. They
4//! build a lazy computation graph; call `.eval(&mut engine)` on the result to
5//! trigger execution.
6//!
7//! # Quick start
8//!
9//! ```ignore
10//! use tenferro::einsum::einsum;
11//! use tenferro::engine::Engine;
12//! use tenferro::traced::TracedTensor;
13//!
14//! let mut engine = Engine::new(CpuBackend::new());
15//! let a = TracedTensor::from_tensor_concrete_shape(tensor_a);
16//! let b = TracedTensor::from_tensor_concrete_shape(tensor_b);
17//!
18//! // Matrix multiply
19//! let c = einsum(&mut engine, &[&a, &b], "ij,jk->ik");
20//! let result = c.eval(&mut engine);
21//! ```
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use computegraph::fragment::FragmentBuilder;
27use computegraph::types::ValRef;
28use omeco::ScoreFunction;
29use tenferro_einsum::builder::build_einsum_fragment;
30use tenferro_einsum::{ContractionOptimizerOptions, ContractionTree, NestedEinsum, Subscripts};
31use tenferro_ops::std_tensor_op::StdTensorOp;
32use tenferro_tensor::TensorBackend;
33
34use super::checkpoint::CheckpointNode;
35use super::engine::Engine;
36use super::error::{Error, Result};
37use super::sym_dim::SymDim;
38use super::traced::{concrete_shape, next_traced_id, try_concrete_shape, TracedTensor};
39
/// Controls how the contraction path is determined for N-ary einsum.
///
/// # Variants
///
/// ## `Auto` -- Automatic optimization (default: FLOPS-first)
///
/// Uses omeco's TreeSA optimizer. The default scoring prioritizes
/// time complexity (FLOPS). Customize via `ContractionOptimizerOptions`.
///
/// ```ignore
/// use omeco::ScoreFunction;
/// use tenferro_einsum::ContractionOptimizerOptions;
/// use tenferro::einsum::{einsum_with, EinsumOptimize};
///
/// // Default: FLOPS-first (minimize computation time)
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::default());
///
/// // Space-optimized (minimize peak intermediate tensor size)
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Auto(ContractionOptimizerOptions {
///         score: ScoreFunction::space_optimized(20.0),
///         ..Default::default()
///     }));
///
/// // Balanced (FLOPS + space, omeco default)
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Auto(ContractionOptimizerOptions {
///         score: ScoreFunction::default(),
///         ..Default::default()
///     }));
///
/// // Custom: space-heavy with FLOPS tiebreaker
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Auto(ContractionOptimizerOptions {
///         score: ScoreFunction::new(
///             0.1,   // tc_weight (FLOPS, low priority)
///             1.0,   // sc_weight (space, high priority)
///             0.0,   // rw_weight (read-write, ignored)
///             15.0,  // sc_target (no penalty below 2^15 elements)
///         ),
///         ..Default::default()
///     }));
///
/// // Full TreeSA: multiple trials with annealing
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Auto(ContractionOptimizerOptions {
///         score: ScoreFunction::time_optimized(),
///         ntrials: 10,
///         niters: 50,
///         betas: vec![0.01, 0.1, 1.0, 10.0],
///         ..Default::default()
///     }));
/// ```
///
/// ## `False` -- No optimization
///
/// Contracts operands left-to-right in the order given.
/// Useful for debugging or when the input order is already optimal.
///
/// ```ignore
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::False);
/// ```
///
/// ## `Nested` -- Parenthesized notation
///
/// Specifies contraction order using a pre-parsed [`NestedEinsum`] tree.
/// Most human-readable way to control order.
///
/// ```ignore
/// use tenferro_einsum::NestedEinsum;
///
/// // "Contract A*B first, then result with C"
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Nested(NestedEinsum::parse("(ij,jk),kl->il").unwrap()));
///
/// // "Contract B*C first, then A with result"
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Nested(NestedEinsum::parse("ij,(jk,kl)->il").unwrap()));
/// ```
///
/// ## `Path` -- JAX-compatible explicit path
///
/// Each pair specifies positions in a shrinking operand list.
/// After each step, the two contracted operands are removed and
/// the result is appended to the end.
///
/// Compatible with `jax.numpy.einsum(optimize=path)` and
/// `opt_einsum.contract_path` output.
///
/// ```ignore
/// // 3 operands: A(0), B(1), C(2)
/// // Step 1: contract positions 1,2 (B,C) -> T. List: [A, T]
/// // Step 2: contract positions 0,1 (A,T) -> result
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Path(vec![(1, 2), (0, 1)]));
///
/// // Step 1: contract positions 0,1 (A,B) -> T. List: [C, T]
/// // Step 2: contract positions 0,1 (C,T) -> result
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Path(vec![(0, 1), (0, 1)]));
/// ```
///
/// ## `Tree` -- Pre-computed ContractionTree
///
/// Pass a tree obtained from `ContractionTree::optimize` or other
/// optimization tools. Skips all path computation.
///
/// ```ignore
/// use tenferro_einsum::{ContractionTree, Subscripts};
///
/// let subs = Subscripts::parse("ij,jk,kl->il").unwrap();
/// let shapes = [&[2, 3][..], &[3, 4], &[4, 5]];
/// let tree = ContractionTree::optimize(&subs, &shapes).unwrap();
/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Tree(tree));
/// ```
pub enum EinsumOptimize {
    /// Automatic optimization via omeco TreeSA.
    Auto(ContractionOptimizerOptions),
    /// No optimization -- contract left-to-right in the order given.
    False,
    /// Parenthesized notation specifying contraction order.
    Nested(NestedEinsum),
    /// JAX-compatible position-based contraction path
    /// (shrinking-list positions, as produced by `opt_einsum.contract_path`).
    Path(Vec<(usize, usize)>),
    /// Pre-computed contraction tree.
    /// NOTE(review): the tree is used as-is and does not appear to be
    /// re-validated against the subscripts/shapes — confirm callers keep
    /// them consistent.
    Tree(ContractionTree),
}
170
171impl Default for EinsumOptimize {
172    /// Default: FLOPS-first automatic optimization.
173    ///
174    /// Uses `ScoreFunction::time_optimized()`:
175    /// - `tc_weight = 1.0` (minimize FLOPS)
176    /// - `sc_weight = 0.0` (ignore space)
177    fn default() -> Self {
178        EinsumOptimize::Auto(ContractionOptimizerOptions {
179            score: ScoreFunction::time_optimized(),
180            ..Default::default()
181        })
182    }
183}
184
185/// N-ary einsum with default FLOPS-first optimization.
186///
187/// Builds a lazy computation graph. Call `.eval(&mut engine)` on the
188/// result to trigger execution.
189///
190/// # Examples
191///
192/// ```ignore
193/// use tenferro::einsum::einsum;
194/// use tenferro::engine::Engine;
195/// use tenferro::traced::TracedTensor;
196///
197/// // Matrix multiply
198/// let c = einsum(&mut engine, &[&a, &b], "ij,jk->ik");
199///
200/// // 3-tensor chain multiply
201/// let d = einsum(&mut engine, &[&a, &b, &c], "ij,jk,kl->il");
202///
203/// // Inner product
204/// let s = einsum(&mut engine, &[&x, &y], "i,i->");
205///
206/// // Row sum (unary)
207/// let r = einsum(&mut engine, &[&a], "ij->i");
208///
209/// // Hadamard product
210/// let h = einsum(&mut engine, &[&a, &b], "ij,ij->ij");
211///
212/// // Outer product
213/// let o = einsum(&mut engine, &[&x, &y], "i,j->ij");
214/// ```
215pub fn einsum<B: TensorBackend>(
216    engine: &mut Engine<B>,
217    inputs: &[&TracedTensor],
218    subscripts: &str,
219) -> Result<TracedTensor> {
220    einsum_with(engine, inputs, subscripts, EinsumOptimize::default())
221}
222
/// N-ary einsum with explicit contraction strategy.
///
/// See [`EinsumOptimize`] for all available strategies and examples.
///
/// # Errors
///
/// Returns [`Error::ContractionError`] if `inputs` is empty or its length
/// does not match the number of operands in the equation, and
/// [`Error::InvalidSubscripts`] if the equation fails to parse.
///
/// # Examples
///
/// ```ignore
/// use tenferro::einsum::{einsum_with, EinsumOptimize};
/// use tenferro::engine::Engine;
/// use tenferro::traced::TracedTensor;
///
/// // Left-to-right, no optimizer
/// let c = einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::False);
///
/// // JAX-compatible explicit path
/// let c = einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Path(vec![(1, 2), (0, 1)]));
/// ```
pub fn einsum_with<B: TensorBackend>(
    engine: &mut Engine<B>,
    inputs: &[&TracedTensor],
    subscripts: &str,
    optimize: EinsumOptimize,
) -> Result<TracedTensor> {
    if inputs.is_empty() {
        return Err(Error::ContractionError(
            "einsum requires at least one input tensor".into(),
        ));
    }

    let subs =
        Subscripts::parse(subscripts).map_err(|e| Error::InvalidSubscripts(format!("{e}")))?;
    if subs.inputs.len() != inputs.len() {
        return Err(Error::ContractionError(format!(
            "einsum subscripts expect {} inputs, got {}",
            subs.inputs.len(),
            inputs.len()
        )));
    }
    // If any operand's shape is still symbolic, path selection is deferred:
    // emit a single opaque NaryEinsum node instead of building a tree now.
    if inputs
        .iter()
        .any(|tensor| try_concrete_shape(tensor).is_none())
    {
        return Ok(build_symbolic_nary_einsum(inputs, subscripts, &subs));
    }
    let shapes: Vec<Vec<usize>> = inputs.iter().map(|t| concrete_shape(t)).collect();
    let shape_refs: Vec<&[usize]> = shapes.iter().map(|s| s.as_slice()).collect();

    match optimize {
        // Reuse TreeSA results for repeated calls with the same equation and input shapes.
        // NOTE(review): the cache key is (equation, shapes) only — the optimizer
        // options `opts` are not part of the key, so two Auto calls differing only
        // in options can share one cached tree. Confirm this is intended.
        EinsumOptimize::Auto(opts) => {
            let cache_key = (subscripts.to_string(), shapes.clone());
            let tree = if let Some(cached) = engine.einsum_cache.get(&cache_key) {
                cached.clone()
            } else {
                let tree = Arc::new(resolve_strategy(
                    EinsumOptimize::Auto(opts),
                    &subs,
                    &shape_refs,
                )?);
                engine.einsum_cache.put(cache_key, tree.clone());
                tree
            };
            Ok(build_traced_from_tree(
                inputs,
                &subs,
                tree.as_ref(),
                &shapes,
            ))
        }
        // Non-Auto strategies are deterministic and cheap to resolve, so they
        // bypass the cache entirely.
        optimize => {
            let tree = resolve_strategy(optimize, &subs, &shape_refs)?;
            Ok(build_traced_from_tree(inputs, &subs, &tree, &shapes))
        }
    }
}
300
301fn build_symbolic_nary_einsum(
302    inputs: &[&TracedTensor],
303    subscripts: &str,
304    parsed: &Subscripts,
305) -> TracedTensor {
306    let mut builder = FragmentBuilder::new();
307    let mut input_vals = Vec::with_capacity(inputs.len());
308    let mut merged = HashMap::new();
309    let mut extra_roots = Vec::new();
310
311    for input in inputs {
312        builder.add_parent(input.fragment.clone());
313        input_vals.push(ValRef::External(
314            input.fragment.vals()[input.val].key.clone(),
315        ));
316        merged.extend(
317            input
318                .inputs_map
319                .iter()
320                .map(|(key, value)| (key.clone(), value.clone())),
321        );
322        extra_roots.extend(input.extra_roots.iter().cloned());
323    }
324
325    let outputs = builder.add_op(
326        StdTensorOp::NaryEinsum {
327            subscripts: subscripts.to_string(),
328            n_inputs: inputs.len(),
329        },
330        input_vals,
331        computegraph::types::OpMode::Primal,
332    );
333    builder.set_outputs(outputs.clone());
334
335    TracedTensor {
336        id: next_traced_id(),
337        rank: parsed.output.len(),
338        dtype: inputs[0].dtype,
339        fragment: Arc::new(builder.build()),
340        val: outputs[0],
341        data: None,
342        shape_hint: None,
343        inputs_map: Arc::new(merged),
344        extra_roots,
345        checkpoint_chain: None,
346    }
347}
348
349/// Resolve an [`EinsumOptimize`] strategy to a [`ContractionTree`].
350fn resolve_strategy(
351    optimize: EinsumOptimize,
352    subs: &Subscripts,
353    shapes: &[&[usize]],
354) -> Result<ContractionTree> {
355    match optimize {
356        EinsumOptimize::Auto(opts) => ContractionTree::optimize_with_options(subs, shapes, &opts)
357            .map_err(|e| Error::ContractionError(format!("{e}"))),
358        EinsumOptimize::False => {
359            let n = subs.inputs.len();
360            if n <= 1 {
361                ContractionTree::from_pairs(subs, shapes, &[])
362                    .map_err(|e| Error::ContractionError(format!("{e}")))
363            } else {
364                let jax_path: Vec<(usize, usize)> = (0..n - 1).map(|_| (0, 1)).collect();
365                let v1_pairs = jax_path_to_v1_pairs(&jax_path, n);
366                ContractionTree::from_pairs(subs, shapes, &v1_pairs)
367                    .map_err(|e| Error::ContractionError(format!("{e}")))
368            }
369        }
370        EinsumOptimize::Nested(nested) => {
371            let n = subs.inputs.len();
372            let v1_pairs = nested_to_v1_pairs(&nested, n);
373            ContractionTree::from_pairs(subs, shapes, &v1_pairs)
374                .map_err(|e| Error::ContractionError(format!("{e}")))
375        }
376        EinsumOptimize::Path(jax_path) => {
377            let n = subs.inputs.len();
378            let v1_pairs = jax_path_to_v1_pairs(&jax_path, n);
379            ContractionTree::from_pairs(subs, shapes, &v1_pairs)
380                .map_err(|e| Error::ContractionError(format!("{e}")))
381        }
382        EinsumOptimize::Tree(tree) => Ok(tree),
383    }
384}
385
/// Translate a JAX-style positional path into v1 fixed-ID pairs.
///
/// JAX format: each pair `(i, j)` indexes into a shrinking operand list.
/// After a contraction both operands are removed (higher position first)
/// and the intermediate result is appended at the end of the list.
///
/// v1 format: original inputs keep IDs `0..n`; the intermediate produced
/// at step `k` gets the stable ID `n + k`.
fn jax_path_to_v1_pairs(jax_path: &[(usize, usize)], n_inputs: usize) -> Vec<(usize, usize)> {
    // `slots[p]` holds the fixed ID currently occupying position `p`.
    let mut slots: Vec<usize> = (0..n_inputs).collect();
    let mut out = Vec::with_capacity(jax_path.len());

    for (step, &(a, b)) in jax_path.iter().enumerate() {
        let (lo, hi) = (a.min(b), a.max(b));
        out.push((slots[lo], slots[hi]));

        // Drop the higher position first so `lo` remains a valid index,
        // then append the fresh intermediate ID at the end.
        slots.remove(hi);
        slots.remove(lo);
        slots.push(n_inputs + step);
    }

    out
}
417
418/// Convert a [`NestedEinsum`] tree into v1 fixed-ID pairs.
419///
420/// Walks the tree bottom-up. Each `Leaf(i)` maps to original input `i`.
421/// Each binary `Node` emits a pair `(left_id, right_id)` and is assigned
422/// the next intermediate ID (`n_inputs + step`).
423fn nested_to_v1_pairs(nested: &NestedEinsum, n_inputs: usize) -> Vec<(usize, usize)> {
424    let mut pairs = Vec::new();
425    let mut next_id = n_inputs;
426    walk_nested(nested, &mut pairs, &mut next_id);
427    pairs
428}
429
430/// Recursive walk of `NestedEinsum` that emits v1-style pairs.
431///
432/// Returns the operand ID for this sub-expression (either a leaf input index
433/// or an intermediate ID).
434fn walk_nested(
435    nested: &NestedEinsum,
436    pairs: &mut Vec<(usize, usize)>,
437    next_id: &mut usize,
438) -> usize {
439    match nested {
440        NestedEinsum::Leaf(idx) => *idx,
441        NestedEinsum::Node { children, .. } => {
442            // For binary nodes (the normal case), contract the two children.
443            // For N-ary nodes (N > 2), contract left-to-right.
444            assert!(
445                !children.is_empty(),
446                "NestedEinsum::Node must have at least one child"
447            );
448            let mut result_id = walk_nested(&children[0], pairs, next_id);
449            for child in &children[1..] {
450                let child_id = walk_nested(child, pairs, next_id);
451                pairs.push((result_id, child_id));
452                result_id = *next_id;
453                *next_id += 1;
454            }
455            result_id
456        }
457    }
458}
459
/// Build a [`TracedTensor`] from a contraction tree and inputs.
///
/// Wires every input fragment in as a parent, lets `build_einsum_fragment`
/// lower the tree into ops, and wraps the result. Two cases:
/// - `ValRef::Local`: the fragment produced new ops; a fresh fragment is
///   built and the inputs' bookkeeping (inputs_map, extra_roots,
///   checkpoint chains) is merged into the result.
/// - `ValRef::External`: the tree added no ops (identity pass-through);
///   the matching input's traced state is cloned instead.
///
/// # Panics
///
/// Panics (via `compute_einsum_output_shape`) if the shapes are inconsistent
/// with the subscripts, or if `build_einsum_fragment` returns an external
/// ref that matches none of the inputs (a broken invariant).
fn build_traced_from_tree(
    inputs: &[&TracedTensor],
    subscripts: &Subscripts,
    tree: &ContractionTree,
    shapes: &[Vec<usize>],
) -> TracedTensor {
    let out_shape = compute_einsum_output_shape(subscripts, shapes);

    let mut builder = FragmentBuilder::new();

    // Add parents and create ValRef for each input
    let mut input_vals = Vec::new();
    for input in inputs {
        builder.add_parent(input.fragment.clone());
        let val_ref = ValRef::External(input.fragment.vals()[input.val].key.clone());
        input_vals.push(val_ref);
    }

    let result_ref = build_einsum_fragment(&mut builder, tree, &input_vals, shapes);

    match result_ref {
        ValRef::Local(result_local) => {
            builder.set_outputs(vec![result_local]);
            let fragment = Arc::new(builder.build());

            // Merge bookkeeping from all inputs. Later inputs overwrite
            // earlier ones on duplicate keys in `inputs_map`.
            let mut merged = HashMap::new();
            let mut extra_roots = Vec::new();
            for input in inputs {
                merged.extend(input.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
                extra_roots.extend(input.extra_roots.iter().cloned());
            }

            // Fold all inputs' checkpoint chains into one (None when no
            // input carries a chain).
            let merged_chain = inputs.iter().fold(None, |acc, input| {
                CheckpointNode::merge_chains(acc, input.checkpoint_chain.clone())
            });

            TracedTensor {
                id: next_traced_id(),
                rank: out_shape.len(),
                // NOTE(review): dtype comes from the first input — assumes all
                // operands share a dtype; confirm upstream validation.
                dtype: inputs[0].dtype,
                fragment,
                val: result_local,
                data: None,
                shape_hint: Some(out_shape.into_iter().map(SymDim::from).collect()),
                inputs_map: Arc::new(merged),
                extra_roots,
                checkpoint_chain: merged_chain,
            }
        }
        ValRef::External(_) => {
            // Identity pass-through: the einsum doesn't add any ops.
            // Find which input was returned and clone its TracedTensor.
            for (i, iv) in input_vals.iter().enumerate() {
                if *iv == result_ref {
                    return TracedTensor {
                        id: next_traced_id(),
                        rank: out_shape.len(),
                        dtype: inputs[i].dtype,
                        fragment: inputs[i].fragment.clone(),
                        val: inputs[i].val,
                        data: inputs[i].data.clone(),
                        shape_hint: Some(out_shape.into_iter().map(SymDim::from).collect()),
                        inputs_map: inputs[i].inputs_map.clone(),
                        extra_roots: inputs[i].extra_roots.clone(),
                        checkpoint_chain: inputs[i].checkpoint_chain.clone(),
                    };
                }
            }
            panic!("build_einsum_fragment returned unrecognized external ref");
        }
    }
}
533
534/// Compute the output shape from einsum subscripts and input shapes.
535fn compute_einsum_output_shape(subscripts: &Subscripts, shapes: &[Vec<usize>]) -> Vec<usize> {
536    let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
537    let size_dict = tenferro_einsum::build_size_dict(subscripts, &shape_refs, None)
538        .unwrap_or_else(|err| panic!("einsum shape computation failed: {err}"));
539    tenferro_einsum::compute_output_shape(&subscripts.output, &size_dict)
540        .unwrap_or_else(|err| panic!("einsum output shape computation failed: {err}"))
541}