tenferro/einsum.rs
1//! N-ary einsum with configurable contraction strategy.
2//!
3//! This module provides free functions [`einsum`] and [`einsum_with`]. They
4//! build a lazy computation graph; call `.eval(&mut engine)` on the result to
5//! trigger execution.
6//!
7//! # Quick start
8//!
9//! ```ignore
10//! use tenferro::einsum::einsum;
11//! use tenferro::engine::Engine;
12//! use tenferro::traced::TracedTensor;
13//!
14//! let mut engine = Engine::new(CpuBackend::new());
15//! let a = TracedTensor::from_tensor_concrete_shape(tensor_a);
16//! let b = TracedTensor::from_tensor_concrete_shape(tensor_b);
17//!
18//! // Matrix multiply
19//! let c = einsum(&mut engine, &[&a, &b], "ij,jk->ik");
20//! let result = c.eval(&mut engine);
21//! ```
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use computegraph::fragment::FragmentBuilder;
27use computegraph::types::ValRef;
28use omeco::ScoreFunction;
29use tenferro_einsum::builder::build_einsum_fragment;
30use tenferro_einsum::{ContractionOptimizerOptions, ContractionTree, NestedEinsum, Subscripts};
31use tenferro_ops::std_tensor_op::StdTensorOp;
32use tenferro_tensor::TensorBackend;
33
34use super::checkpoint::CheckpointNode;
35use super::engine::Engine;
36use super::error::{Error, Result};
37use super::sym_dim::SymDim;
38use super::traced::{concrete_shape, next_traced_id, try_concrete_shape, TracedTensor};
39
40/// Controls how the contraction path is determined for N-ary einsum.
41///
42/// # Variants
43///
44/// ## `Auto` -- Automatic optimization (default: FLOPS-first)
45///
46/// Uses omeco's TreeSA optimizer. The default scoring prioritizes
47/// time complexity (FLOPS). Customize via `ContractionOptimizerOptions`.
48///
49/// ```ignore
50/// use omeco::ScoreFunction;
51/// use tenferro_einsum::ContractionOptimizerOptions;
52/// use tenferro::einsum::{einsum_with, EinsumOptimize};
53///
54/// // Default: FLOPS-first (minimize computation time)
55/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
56/// EinsumOptimize::default());
57///
58/// // Space-optimized (minimize peak intermediate tensor size)
59/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
60/// EinsumOptimize::Auto(ContractionOptimizerOptions {
61/// score: ScoreFunction::space_optimized(20.0),
62/// ..Default::default()
63/// }));
64///
65/// // Balanced (FLOPS + space, omeco default)
66/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
67/// EinsumOptimize::Auto(ContractionOptimizerOptions {
68/// score: ScoreFunction::default(),
69/// ..Default::default()
70/// }));
71///
72/// // Custom: space-heavy with FLOPS tiebreaker
73/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
74/// EinsumOptimize::Auto(ContractionOptimizerOptions {
75/// score: ScoreFunction::new(
76/// 0.1, // tc_weight (FLOPS, low priority)
77/// 1.0, // sc_weight (space, high priority)
78/// 0.0, // rw_weight (read-write, ignored)
79/// 15.0, // sc_target (no penalty below 2^15 elements)
80/// ),
81/// ..Default::default()
82/// }));
83///
84/// // Full TreeSA: multiple trials with annealing
85/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
86/// EinsumOptimize::Auto(ContractionOptimizerOptions {
87/// score: ScoreFunction::time_optimized(),
88/// ntrials: 10,
89/// niters: 50,
90/// betas: vec![0.01, 0.1, 1.0, 10.0],
91/// ..Default::default()
92/// }));
93/// ```
94///
95/// ## `False` -- No optimization
96///
97/// Contracts operands left-to-right in the order given.
98/// Useful for debugging or when the input order is already optimal.
99///
100/// ```ignore
101/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
102/// EinsumOptimize::False);
103/// ```
104///
105/// ## `Nested` -- Parenthesized notation
106///
107/// Specifies contraction order using a pre-parsed [`NestedEinsum`] tree.
108/// Most human-readable way to control order.
109///
110/// ```ignore
111/// use tenferro_einsum::NestedEinsum;
112///
113/// // "Contract A*B first, then result with C"
114/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
115/// EinsumOptimize::Nested(NestedEinsum::parse("(ij,jk),kl->il").unwrap()));
116///
117/// // "Contract B*C first, then A with result"
118/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
119/// EinsumOptimize::Nested(NestedEinsum::parse("ij,(jk,kl)->il").unwrap()));
120/// ```
121///
122/// ## `Path` -- JAX-compatible explicit path
123///
124/// Each pair specifies positions in a shrinking operand list.
125/// After each step, the two contracted operands are removed and
126/// the result is appended to the end.
127///
128/// Compatible with `jax.numpy.einsum(optimize=path)` and
129/// `opt_einsum.contract_path` output.
130///
131/// ```ignore
132/// // 3 operands: A(0), B(1), C(2)
133/// // Step 1: contract positions 1,2 (B,C) -> T. List: [A, T]
134/// // Step 2: contract positions 0,1 (A,T) -> result
135/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
136/// EinsumOptimize::Path(vec![(1, 2), (0, 1)]));
137///
138/// // Step 1: contract positions 0,1 (A,B) -> T. List: [C, T]
139/// // Step 2: contract positions 0,1 (C,T) -> result
140/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
141/// EinsumOptimize::Path(vec![(0, 1), (0, 1)]));
142/// ```
143///
144/// ## `Tree` -- Pre-computed ContractionTree
145///
146/// Pass a tree obtained from `ContractionTree::optimize` or other
147/// optimization tools. Skips all path computation.
148///
149/// ```ignore
150/// use tenferro_einsum::{ContractionTree, Subscripts};
151///
152/// let subs = Subscripts::parse("ij,jk,kl->il").unwrap();
153/// let shapes = [&[2, 3][..], &[3, 4], &[4, 5]];
154/// let tree = ContractionTree::optimize(&subs, &shapes).unwrap();
155/// einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
156/// EinsumOptimize::Tree(tree));
157/// ```
pub enum EinsumOptimize {
    /// Automatic optimization via omeco TreeSA; tune scoring and annealing
    /// through the embedded [`ContractionOptimizerOptions`].
    Auto(ContractionOptimizerOptions),
    /// No optimization -- contract operands left-to-right in the order given.
    False,
    /// Parenthesized notation specifying contraction order explicitly.
    Nested(NestedEinsum),
    /// JAX-compatible position-based contraction path: each pair indexes a
    /// shrinking operand list (`opt_einsum.contract_path` format).
    Path(Vec<(usize, usize)>),
    /// Pre-computed contraction tree; skips all path computation.
    Tree(ContractionTree),
}
170
171impl Default for EinsumOptimize {
172 /// Default: FLOPS-first automatic optimization.
173 ///
174 /// Uses `ScoreFunction::time_optimized()`:
175 /// - `tc_weight = 1.0` (minimize FLOPS)
176 /// - `sc_weight = 0.0` (ignore space)
177 fn default() -> Self {
178 EinsumOptimize::Auto(ContractionOptimizerOptions {
179 score: ScoreFunction::time_optimized(),
180 ..Default::default()
181 })
182 }
183}
184
185/// N-ary einsum with default FLOPS-first optimization.
186///
187/// Builds a lazy computation graph. Call `.eval(&mut engine)` on the
188/// result to trigger execution.
189///
190/// # Examples
191///
192/// ```ignore
193/// use tenferro::einsum::einsum;
194/// use tenferro::engine::Engine;
195/// use tenferro::traced::TracedTensor;
196///
197/// // Matrix multiply
198/// let c = einsum(&mut engine, &[&a, &b], "ij,jk->ik");
199///
200/// // 3-tensor chain multiply
201/// let d = einsum(&mut engine, &[&a, &b, &c], "ij,jk,kl->il");
202///
203/// // Inner product
204/// let s = einsum(&mut engine, &[&x, &y], "i,i->");
205///
206/// // Row sum (unary)
207/// let r = einsum(&mut engine, &[&a], "ij->i");
208///
209/// // Hadamard product
210/// let h = einsum(&mut engine, &[&a, &b], "ij,ij->ij");
211///
212/// // Outer product
213/// let o = einsum(&mut engine, &[&x, &y], "i,j->ij");
214/// ```
215pub fn einsum<B: TensorBackend>(
216 engine: &mut Engine<B>,
217 inputs: &[&TracedTensor],
218 subscripts: &str,
219) -> Result<TracedTensor> {
220 einsum_with(engine, inputs, subscripts, EinsumOptimize::default())
221}
222
/// N-ary einsum with explicit contraction strategy.
///
/// Builds a lazy computation graph; call `.eval(&mut engine)` on the result
/// to trigger execution. See [`EinsumOptimize`] for all available strategies
/// and examples.
///
/// If any input lacks a concrete shape, the `optimize` strategy is ignored
/// and a single symbolic N-ary einsum op is recorded instead, since path
/// resolution requires concrete shapes.
///
/// # Errors
///
/// Returns an error when `inputs` is empty, when `subscripts` cannot be
/// parsed, or when the number of subscript operands does not match
/// `inputs.len()`. Failures from strategy resolution are propagated.
///
/// # Examples
///
/// ```ignore
/// use tenferro::einsum::{einsum_with, EinsumOptimize};
/// use tenferro::engine::Engine;
/// use tenferro::traced::TracedTensor;
///
/// // Left-to-right, no optimizer
/// let c = einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::False);
///
/// // JAX-compatible explicit path
/// let c = einsum_with(&mut engine, &[&a, &b, &c], "ij,jk,kl->il",
///     EinsumOptimize::Path(vec![(1, 2), (0, 1)]));
/// ```
pub fn einsum_with<B: TensorBackend>(
    engine: &mut Engine<B>,
    inputs: &[&TracedTensor],
    subscripts: &str,
    optimize: EinsumOptimize,
) -> Result<TracedTensor> {
    if inputs.is_empty() {
        return Err(Error::ContractionError(
            "einsum requires at least one input tensor".into(),
        ));
    }

    let subs =
        Subscripts::parse(subscripts).map_err(|e| Error::InvalidSubscripts(format!("{e}")))?;
    if subs.inputs.len() != inputs.len() {
        return Err(Error::ContractionError(format!(
            "einsum subscripts expect {} inputs, got {}",
            subs.inputs.len(),
            inputs.len()
        )));
    }
    // Symbolic fallback: without concrete shapes a contraction path cannot
    // be scored, so a single opaque N-ary einsum node is emitted. Note that
    // `optimize` is intentionally unused on this path.
    if inputs
        .iter()
        .any(|tensor| try_concrete_shape(tensor).is_none())
    {
        return Ok(build_symbolic_nary_einsum(inputs, subscripts, &subs));
    }
    let shapes: Vec<Vec<usize>> = inputs.iter().map(|t| concrete_shape(t)).collect();
    let shape_refs: Vec<&[usize]> = shapes.iter().map(|s| s.as_slice()).collect();

    match optimize {
        // Reuse TreeSA results for repeated calls with the same equation and input shapes.
        // The tree is stored behind an `Arc` in the engine's cache, so a hit
        // is a cheap clone rather than a re-optimization.
        EinsumOptimize::Auto(opts) => {
            let cache_key = (subscripts.to_string(), shapes.clone());
            let tree = if let Some(cached) = engine.einsum_cache.get(&cache_key) {
                cached.clone()
            } else {
                let tree = Arc::new(resolve_strategy(
                    EinsumOptimize::Auto(opts),
                    &subs,
                    &shape_refs,
                )?);
                engine.einsum_cache.put(cache_key, tree.clone());
                tree
            };
            Ok(build_traced_from_tree(
                inputs,
                &subs,
                tree.as_ref(),
                &shapes,
            ))
        }
        // Non-`Auto` strategies are deterministic conversions (no annealing),
        // so their trees are resolved directly and not cached.
        optimize => {
            let tree = resolve_strategy(optimize, &subs, &shape_refs)?;
            Ok(build_traced_from_tree(inputs, &subs, &tree, &shapes))
        }
    }
}
300
/// Build a single symbolic N-ary einsum node for inputs without concrete
/// shapes.
///
/// No contraction path is chosen here: one `StdTensorOp::NaryEinsum` op is
/// recorded carrying the raw subscripts string, and every input's fragment
/// becomes a parent of the new fragment. The inputs' `inputs_map`s and
/// `extra_roots` are merged into the result.
///
/// NOTE(review): unlike `build_traced_from_tree`, the inputs' checkpoint
/// chains are NOT merged here (`checkpoint_chain: None`) -- confirm this is
/// intentional for the symbolic path.
fn build_symbolic_nary_einsum(
    inputs: &[&TracedTensor],
    subscripts: &str,
    parsed: &Subscripts,
) -> TracedTensor {
    let mut builder = FragmentBuilder::new();
    let mut input_vals = Vec::with_capacity(inputs.len());
    let mut merged = HashMap::new();
    let mut extra_roots = Vec::new();

    for input in inputs {
        // Each input fragment becomes a parent; the op consumes the input's
        // value through an external reference keyed by that fragment.
        builder.add_parent(input.fragment.clone());
        input_vals.push(ValRef::External(
            input.fragment.vals()[input.val].key.clone(),
        ));
        merged.extend(
            input
                .inputs_map
                .iter()
                .map(|(key, value)| (key.clone(), value.clone())),
        );
        extra_roots.extend(input.extra_roots.iter().cloned());
    }

    let outputs = builder.add_op(
        StdTensorOp::NaryEinsum {
            subscripts: subscripts.to_string(),
            n_inputs: inputs.len(),
        },
        input_vals,
        computegraph::types::OpMode::Primal,
    );
    builder.set_outputs(outputs.clone());

    TracedTensor {
        id: next_traced_id(),
        // Output rank equals the number of output subscript labels.
        rank: parsed.output.len(),
        // Dtype taken from the first input -- assumes all inputs share a
        // dtype; TODO confirm this invariant is enforced upstream.
        dtype: inputs[0].dtype,
        fragment: Arc::new(builder.build()),
        val: outputs[0],
        data: None,
        // No concrete shapes available, hence no shape hint either.
        shape_hint: None,
        inputs_map: Arc::new(merged),
        extra_roots,
        checkpoint_chain: None,
    }
}
348
349/// Resolve an [`EinsumOptimize`] strategy to a [`ContractionTree`].
350fn resolve_strategy(
351 optimize: EinsumOptimize,
352 subs: &Subscripts,
353 shapes: &[&[usize]],
354) -> Result<ContractionTree> {
355 match optimize {
356 EinsumOptimize::Auto(opts) => ContractionTree::optimize_with_options(subs, shapes, &opts)
357 .map_err(|e| Error::ContractionError(format!("{e}"))),
358 EinsumOptimize::False => {
359 let n = subs.inputs.len();
360 if n <= 1 {
361 ContractionTree::from_pairs(subs, shapes, &[])
362 .map_err(|e| Error::ContractionError(format!("{e}")))
363 } else {
364 let jax_path: Vec<(usize, usize)> = (0..n - 1).map(|_| (0, 1)).collect();
365 let v1_pairs = jax_path_to_v1_pairs(&jax_path, n);
366 ContractionTree::from_pairs(subs, shapes, &v1_pairs)
367 .map_err(|e| Error::ContractionError(format!("{e}")))
368 }
369 }
370 EinsumOptimize::Nested(nested) => {
371 let n = subs.inputs.len();
372 let v1_pairs = nested_to_v1_pairs(&nested, n);
373 ContractionTree::from_pairs(subs, shapes, &v1_pairs)
374 .map_err(|e| Error::ContractionError(format!("{e}")))
375 }
376 EinsumOptimize::Path(jax_path) => {
377 let n = subs.inputs.len();
378 let v1_pairs = jax_path_to_v1_pairs(&jax_path, n);
379 ContractionTree::from_pairs(subs, shapes, &v1_pairs)
380 .map_err(|e| Error::ContractionError(format!("{e}")))
381 }
382 EinsumOptimize::Tree(tree) => Ok(tree),
383 }
384}
385
/// Convert a JAX-style position-based path to v1 fixed-ID pairs.
///
/// In the JAX/`opt_einsum` format each pair `(i, j)` indexes a shrinking
/// operand list: after a contraction both operands are removed and the
/// intermediate is appended at the end. In the v1 format the original
/// inputs keep IDs `0..n` and the intermediate produced at step `k` has
/// the fixed ID `n + k`.
fn jax_path_to_v1_pairs(jax_path: &[(usize, usize)], n_inputs: usize) -> Vec<(usize, usize)> {
    // `live[p]` holds the fixed ID currently sitting at position `p`.
    let mut live: Vec<usize> = (0..n_inputs).collect();

    jax_path
        .iter()
        .enumerate()
        .map(|(step, &(a, b))| {
            // Normalize ordering so the higher position is removed first
            // and the lower index stays valid.
            let (lo, hi) = if a <= b { (a, b) } else { (b, a) };
            let pair = (live[lo], live[hi]);

            // Drop the two consumed operands, then append the fixed ID of
            // the new intermediate at the end of the list.
            live.remove(hi);
            live.remove(lo);
            live.push(n_inputs + step);

            pair
        })
        .collect()
}
417
418/// Convert a [`NestedEinsum`] tree into v1 fixed-ID pairs.
419///
420/// Walks the tree bottom-up. Each `Leaf(i)` maps to original input `i`.
421/// Each binary `Node` emits a pair `(left_id, right_id)` and is assigned
422/// the next intermediate ID (`n_inputs + step`).
423fn nested_to_v1_pairs(nested: &NestedEinsum, n_inputs: usize) -> Vec<(usize, usize)> {
424 let mut pairs = Vec::new();
425 let mut next_id = n_inputs;
426 walk_nested(nested, &mut pairs, &mut next_id);
427 pairs
428}
429
430/// Recursive walk of `NestedEinsum` that emits v1-style pairs.
431///
432/// Returns the operand ID for this sub-expression (either a leaf input index
433/// or an intermediate ID).
434fn walk_nested(
435 nested: &NestedEinsum,
436 pairs: &mut Vec<(usize, usize)>,
437 next_id: &mut usize,
438) -> usize {
439 match nested {
440 NestedEinsum::Leaf(idx) => *idx,
441 NestedEinsum::Node { children, .. } => {
442 // For binary nodes (the normal case), contract the two children.
443 // For N-ary nodes (N > 2), contract left-to-right.
444 assert!(
445 !children.is_empty(),
446 "NestedEinsum::Node must have at least one child"
447 );
448 let mut result_id = walk_nested(&children[0], pairs, next_id);
449 for child in &children[1..] {
450 let child_id = walk_nested(child, pairs, next_id);
451 pairs.push((result_id, child_id));
452 result_id = *next_id;
453 *next_id += 1;
454 }
455 result_id
456 }
457 }
458}
459
/// Build a [`TracedTensor`] from a contraction tree and inputs.
///
/// Emits the tree's pairwise contractions into a fresh fragment via
/// [`build_einsum_fragment`]. `shapes[i]` must be the concrete shape of
/// `inputs[i]`; the output shape is computed up front and stored as the
/// result's `shape_hint`. The inputs' `inputs_map`s, `extra_roots`, and
/// checkpoint chains are merged into the result.
fn build_traced_from_tree(
    inputs: &[&TracedTensor],
    subscripts: &Subscripts,
    tree: &ContractionTree,
    shapes: &[Vec<usize>],
) -> TracedTensor {
    let out_shape = compute_einsum_output_shape(subscripts, shapes);

    let mut builder = FragmentBuilder::new();

    // Add parents and create ValRef for each input
    let mut input_vals = Vec::new();
    for input in inputs {
        builder.add_parent(input.fragment.clone());
        let val_ref = ValRef::External(input.fragment.vals()[input.val].key.clone());
        input_vals.push(val_ref);
    }

    let result_ref = build_einsum_fragment(&mut builder, tree, &input_vals, shapes);

    match result_ref {
        // Normal case: the fragment produced a new local value.
        ValRef::Local(result_local) => {
            builder.set_outputs(vec![result_local]);
            let fragment = Arc::new(builder.build());

            // Merge bookkeeping from every input so the result still knows
            // all of its root tensors and extra graph roots.
            let mut merged = HashMap::new();
            let mut extra_roots = Vec::new();
            for input in inputs {
                merged.extend(input.inputs_map.iter().map(|(k, v)| (k.clone(), v.clone())));
                extra_roots.extend(input.extra_roots.iter().cloned());
            }

            // Fold all input checkpoint chains into one merged chain.
            let merged_chain = inputs.iter().fold(None, |acc, input| {
                CheckpointNode::merge_chains(acc, input.checkpoint_chain.clone())
            });

            TracedTensor {
                id: next_traced_id(),
                rank: out_shape.len(),
                // Dtype taken from the first input -- assumes all inputs
                // share a dtype; TODO confirm upstream.
                dtype: inputs[0].dtype,
                fragment,
                val: result_local,
                data: None,
                shape_hint: Some(out_shape.into_iter().map(SymDim::from).collect()),
                inputs_map: Arc::new(merged),
                extra_roots,
                checkpoint_chain: merged_chain,
            }
        }
        ValRef::External(_) => {
            // Identity pass-through: the einsum doesn't add any ops.
            // Find which input was returned and clone its TracedTensor.
            for (i, iv) in input_vals.iter().enumerate() {
                if *iv == result_ref {
                    return TracedTensor {
                        id: next_traced_id(),
                        rank: out_shape.len(),
                        dtype: inputs[i].dtype,
                        fragment: inputs[i].fragment.clone(),
                        val: inputs[i].val,
                        data: inputs[i].data.clone(),
                        shape_hint: Some(out_shape.into_iter().map(SymDim::from).collect()),
                        inputs_map: inputs[i].inputs_map.clone(),
                        extra_roots: inputs[i].extra_roots.clone(),
                        checkpoint_chain: inputs[i].checkpoint_chain.clone(),
                    };
                }
            }
            // Presumably `build_einsum_fragment` only returns refs it was
            // handed; reaching here means that invariant broke.
            panic!("build_einsum_fragment returned unrecognized external ref");
        }
    }
}
533
534/// Compute the output shape from einsum subscripts and input shapes.
535fn compute_einsum_output_shape(subscripts: &Subscripts, shapes: &[Vec<usize>]) -> Vec<usize> {
536 let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect();
537 let size_dict = tenferro_einsum::build_size_dict(subscripts, &shape_refs, None)
538 .unwrap_or_else(|err| panic!("einsum shape computation failed: {err}"));
539 tenferro_einsum::compute_output_shape(&subscripts.output, &size_dict)
540 .unwrap_or_else(|err| panic!("einsum output shape computation failed: {err}"))
541}