Skip to main content

tenferro_einsum/planning/
tree.rs

1use std::collections::{HashMap, HashSet};
2
3use omeco::{
4    CodeOptimizer, EinCode as OmecoEinCode, Initializer, NestedEinsum, ScoreFunction, TreeSA,
5};
6use tenferro_device::{Error, Result};
7
8use crate::planning::plan::{compile_step_plans, StepPlan};
9use crate::syntax::subscripts::Subscripts;
10use crate::util::{build_size_dict, compute_output_shape, contraction_cost, intermediate_subs};
11
/// A single pairwise step in the contraction sequence.
///
/// Both fields are operand indices: values below the number of inputs refer
/// to the original input tensors, larger values refer to intermediates
/// produced by earlier steps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ContractionStep {
    /// Index of the left-hand operand of this step.
    pub(crate) left: usize,
    /// Index of the right-hand operand of this step.
    pub(crate) right: usize,
}
17
18/// Public options for automatic contraction-path optimization.
19///
20/// The default planner uses TreeSA with a greedy initializer and zero annealing
21/// iterations. This keeps the public API on a single optimizer family while
22/// making the default behavior effectively "greedy-only".
23#[derive(Debug, Clone)]
24pub struct ContractionOptimizerOptions {
25    /// Inverse-temperature schedule for TreeSA.
26    pub betas: Vec<f64>,
27    /// Number of independent TreeSA trials.
28    pub ntrials: usize,
29    /// Annealing iterations per temperature level.
30    pub niters: usize,
31    /// Score function used by TreeSA.
32    pub score: ScoreFunction,
33}
34
35impl Default for ContractionOptimizerOptions {
36    fn default() -> Self {
37        Self {
38            betas: Vec::new(),
39            ntrials: 1,
40            niters: 0,
41            score: ScoreFunction::default(),
42        }
43    }
44}
45
46impl ContractionOptimizerOptions {
47    fn to_treesa(&self) -> TreeSA {
48        TreeSA::new(
49            self.betas.clone(),
50            self.ntrials,
51            self.niters,
52            Initializer::Greedy,
53            self.score.clone(),
54        )
55    }
56
57    fn validate(&self) -> Result<()> {
58        if self.ntrials == 0 {
59            return Err(Error::InvalidArgument(
60                "contraction optimizer ntrials must be at least 1".into(),
61            ));
62        }
63        Ok(())
64    }
65}
66
/// One link of a linear contraction chain: a single input operand that is
/// contracted with the result of the previous step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ChainAttachment {
    /// `true` when the previous step's result is the left operand of this
    /// step, `false` when it is the right operand.
    pub(crate) prev_on_left: bool,
    /// Index of the input operand attached by this step.
    pub(crate) operand: usize,
}
72
73#[derive(Debug, Clone, PartialEq, Eq)]
74pub(crate) struct LinearChainPlan {
75    pub(crate) first_pair: (usize, usize),
76    pub(crate) attachments: Vec<ChainAttachment>,
77}
78
79/// Contraction tree determining pairwise contraction order for N-ary einsum.
80///
81/// When contracting more than two tensors, the order in which pairwise
82/// contractions are performed significantly affects performance.
83/// `ContractionTree` encodes this order as a binary tree.
84///
85/// # Optimization
86///
87/// Use [`ContractionTree::optimize`] for automatic cost-based optimization
88/// (e.g., greedy algorithm based on tensor sizes), or
89/// [`ContractionTree::from_pairs`] for manual specification.
90pub struct ContractionTree {
91    /// Original subscripts.
92    pub(crate) subscripts: Subscripts,
93    /// Steps in the contraction (empty for single-tensor case).
94    pub(crate) steps: Vec<ContractionStep>,
95    /// Label → dimension size mapping.
96    pub(crate) size_dict: HashMap<u32, usize>,
97    /// Subscripts for each operand (0..n_inputs from input, then intermediates).
98    pub(crate) operand_subs: Vec<Vec<u32>>,
99    /// Pre-computed output shapes for each intermediate step (indexed by step_idx).
100    pub(crate) step_output_shapes: Vec<Vec<usize>>,
101    /// Pre-compiled step plans (cached to avoid recomputation per execute call).
102    pub(crate) step_plans: Vec<StepPlan>,
103}
104
105impl ContractionTree {
106    /// Automatically compute an optimized contraction order.
107    ///
108    /// Uses a cost-based heuristic (greedy algorithm) to determine
109    /// the pairwise contraction sequence that minimizes total operation count.
110    ///
111    /// # Arguments
112    ///
113    /// * `subscripts` — Einsum subscripts for all tensors
114    /// * `shapes` — Shape of each input tensor
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if subscripts and shapes are inconsistent.
119    pub fn optimize(subscripts: &Subscripts, shapes: &[&[usize]]) -> Result<Self> {
120        Self::optimize_with_options(subscripts, shapes, &ContractionOptimizerOptions::default())
121    }
122
123    /// Automatically compute an optimized contraction order with explicit
124    /// planner options.
125    ///
126    /// This routes automatic planning through TreeSA using the provided
127    /// configuration. The default options correspond to a greedy-initialized
128    /// TreeSA with zero annealing iterations.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if subscripts, shapes, or planner options are invalid.
133    pub fn optimize_with_options(
134        subscripts: &Subscripts,
135        shapes: &[&[usize]],
136        options: &ContractionOptimizerOptions,
137    ) -> Result<Self> {
138        let n_inputs = subscripts.inputs.len();
139        if n_inputs <= 1 {
140            return Self::from_pairs(subscripts, shapes, &[]);
141        }
142
143        options.validate()?;
144        let size_dict = build_size_dict(subscripts, shapes, None)?;
145        let pairs =
146            if let Some(omeco_pairs) = optimize_omeco_pairs(subscripts, &size_dict, options)? {
147                omeco_pairs
148            } else {
149                optimize_self_greedy_pairs(subscripts, &size_dict)
150            };
151        Self::from_pairs(subscripts, shapes, &pairs)
152    }
153
154    /// Manually build a contraction tree from a pairwise contraction sequence.
155    ///
156    /// Each pair `(i, j)` specifies which two tensors (or intermediate results)
157    /// to contract next. Intermediate results are assigned indices starting
158    /// from the number of input tensors.
159    ///
160    /// # Arguments
161    ///
162    /// * `subscripts` — Einsum subscripts for all tensors
163    /// * `shapes` — Shape of each input tensor
164    /// * `pairs` — Ordered list of pairwise contractions
165    ///
166    /// # Examples
167    ///
168    /// ```ignore
169    /// // Three tensors: A[ij] B[jk] C[kl] -> D[il]
170    /// // Contract B and C first, then A with the result:
171    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
172    /// let shapes = [&[3, 4][..], &[4, 5], &[5, 6]];
173    /// let tree = ContractionTree::from_pairs(
174    ///     &subs,
175    ///     &shapes,
176    ///     &[(1, 2), (0, 3)],  // B*C -> T(index=3), then A*T -> D
177    /// ).unwrap();
178    /// ```
179    ///
180    /// # Errors
181    ///
182    /// Returns an error if the pairs do not form a valid contraction sequence.
183    pub fn from_pairs(
184        subscripts: &Subscripts,
185        shapes: &[&[usize]],
186        pairs: &[(usize, usize)],
187    ) -> Result<Self> {
188        let n_inputs = subscripts.inputs.len();
189        let required_steps = n_inputs.saturating_sub(1);
190        if pairs.len() != required_steps {
191            return Err(Error::InvalidArgument(format!(
192                "explicit contraction path for {n_inputs} operands must have {required_steps} steps, got {}",
193                pairs.len()
194            )));
195        }
196        let size_dict = build_size_dict(subscripts, shapes, None)?;
197
198        let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
199        let mut live = vec![false; n_inputs + pairs.len()];
200        for slot in live.iter_mut().take(n_inputs) {
201            *slot = true;
202        }
203        let mut steps = Vec::new();
204
205        for (step_idx, &(left, right)) in pairs.iter().enumerate() {
206            let next_idx = n_inputs + step_idx;
207            if left == right {
208                return Err(Error::InvalidArgument(format!(
209                    "pair ({left}, {right}) must reference two distinct live operands"
210                )));
211            }
212            if left >= next_idx || right >= next_idx {
213                return Err(Error::InvalidArgument(format!(
214                    "pair ({left}, {right}) references non-existent operand"
215                )));
216            }
217            if !live[left] || !live[right] {
218                return Err(Error::InvalidArgument(format!(
219                    "pair ({left}, {right}) references an operand or intermediate that is no longer live"
220                )));
221            }
222
223            // Labels needed by other live operands + final output
224            let mut needed: HashSet<u32> = subscripts.output.iter().copied().collect();
225            for (idx, subs) in operand_subs.iter().enumerate() {
226                if idx != left && idx != right && live[idx] {
227                    needed.extend(subs.iter().copied());
228                }
229            }
230
231            let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
232            operand_subs.push(new_subs);
233            live[left] = false;
234            live[right] = false;
235            live[next_idx] = true;
236            steps.push(ContractionStep { left, right });
237        }
238
239        let live_count = live.iter().filter(|&&is_live| is_live).count();
240        if live_count != 1 {
241            return Err(Error::InvalidArgument(format!(
242                "explicit contraction path must leave exactly one live result, got {live_count}"
243            )));
244        }
245
246        // Pre-compute output shapes for each intermediate step.
247        let step_output_shapes: Vec<Vec<usize>> = (0..steps.len())
248            .map(|step_idx| {
249                let result_idx = n_inputs + step_idx;
250                compute_output_shape(&operand_subs[result_idx], &size_dict)
251            })
252            .collect::<Result<Vec<_>>>()?;
253
254        let mut tree = Self {
255            subscripts: subscripts.clone(),
256            steps,
257            size_dict,
258            operand_subs,
259            step_output_shapes,
260            step_plans: Vec::new(),
261        };
262        tree.step_plans = compile_step_plans(&tree).map_err(Error::InvalidArgument)?;
263        Ok(tree)
264    }
265
266    /// Return the number of pairwise contraction steps in this tree.
267    ///
268    /// # Examples
269    ///
270    /// ```ignore
271    /// use tenferro_einsum::{ContractionTree, Subscripts};
272    ///
273    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
274    /// let tree = ContractionTree::from_pairs(
275    ///     &subs,
276    ///     &[&[2, 2], &[2, 2], &[2, 2]],
277    ///     &[(1, 2), (0, 3)],
278    /// )
279    /// .unwrap();
280    /// assert_eq!(tree.step_count(), 2);
281    /// ```
282    #[must_use]
283    pub fn step_count(&self) -> usize {
284        self.steps.len()
285    }
286
287    /// Return the operand indices for a pairwise contraction step.
288    ///
289    /// The returned indices refer to the original inputs (`0..n_inputs`) and
290    /// then to intermediates (`n_inputs..`) produced by earlier steps.
291    ///
292    /// # Examples
293    ///
294    /// ```ignore
295    /// use tenferro_einsum::{ContractionTree, Subscripts};
296    ///
297    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
298    /// let tree = ContractionTree::from_pairs(
299    ///     &subs,
300    ///     &[&[2, 2], &[2, 2], &[2, 2]],
301    ///     &[(1, 2), (0, 3)],
302    /// )
303    /// .unwrap();
304    /// assert_eq!(tree.step_pair(0), Some((1, 2)));
305    /// ```
306    #[must_use]
307    pub fn step_pair(&self, step_idx: usize) -> Option<(usize, usize)> {
308        self.steps.get(step_idx).map(|step| (step.left, step.right))
309    }
310
311    /// Return the `(lhs, rhs, output)` subscripts for a pairwise step.
312    ///
313    /// The output subscripts are the intermediate labels preserved after the
314    /// contraction, or the final output labels on the last step.
315    ///
316    /// # Examples
317    ///
318    /// ```ignore
319    /// use tenferro_einsum::{ContractionTree, Subscripts};
320    ///
321    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
322    /// let tree = ContractionTree::from_pairs(
323    ///     &subs,
324    ///     &[&[2, 2], &[2, 2], &[2, 2]],
325    ///     &[(1, 2), (0, 3)],
326    /// )
327    /// .unwrap();
328    /// let (lhs, rhs, out) = tree.step_subscripts(0).unwrap();
329    /// assert_eq!(lhs, &[1, 2]);
330    /// assert_eq!(rhs, &[2, 3]);
331    /// assert_eq!(out, &[1, 3]);
332    /// ```
333    #[must_use]
334    pub fn step_subscripts(&self, step_idx: usize) -> Option<(&[u32], &[u32], &[u32])> {
335        let n_inputs = self.subscripts.inputs.len();
336        let step = self.steps.get(step_idx)?;
337        let result_idx = n_inputs + step_idx;
338        Some((
339            &self.operand_subs[step.left],
340            &self.operand_subs[step.right],
341            &self.operand_subs[result_idx],
342        ))
343    }
344
345    pub(crate) fn linear_chain_plan(&self) -> Option<LinearChainPlan> {
346        if self.steps.is_empty() {
347            return Some(LinearChainPlan {
348                first_pair: (0, 0),
349                attachments: Vec::new(),
350            });
351        }
352
353        let n_inputs = self.subscripts.inputs.len();
354        let first = self.steps.first()?;
355        if first.left >= n_inputs || first.right >= n_inputs {
356            return None;
357        }
358
359        let mut seen_inputs = vec![false; n_inputs];
360        seen_inputs[first.left] = true;
361        seen_inputs[first.right] = true;
362        let mut attachments = Vec::with_capacity(self.steps.len().saturating_sub(1));
363        let mut prev_result_idx = n_inputs;
364
365        for (step_idx, step) in self.steps.iter().enumerate().skip(1) {
366            let (prev_on_left, operand) = if step.left == prev_result_idx && step.right < n_inputs {
367                (true, step.right)
368            } else if step.right == prev_result_idx && step.left < n_inputs {
369                (false, step.left)
370            } else {
371                return None;
372            };
373
374            if seen_inputs[operand] {
375                return None;
376            }
377            seen_inputs[operand] = true;
378            attachments.push(ChainAttachment {
379                prev_on_left,
380                operand,
381            });
382            prev_result_idx = n_inputs + step_idx;
383        }
384
385        Some(LinearChainPlan {
386            first_pair: (first.left, first.right),
387            attachments,
388        })
389    }
390}
391
392fn optimize_omeco_pairs(
393    subscripts: &Subscripts,
394    size_dict: &HashMap<u32, usize>,
395    options: &ContractionOptimizerOptions,
396) -> Result<Option<Vec<(usize, usize)>>> {
397    let code = OmecoEinCode::new(subscripts.inputs.clone(), subscripts.output.clone());
398    let optimizer = options.to_treesa();
399    let Some(nested) = optimizer.optimize(&code, size_dict) else {
400        return Ok(None);
401    };
402
403    let mut next_operand = subscripts.inputs.len();
404    let mut pairs = Vec::with_capacity(subscripts.inputs.len().saturating_sub(1));
405    nested_to_pairs(&nested, &mut next_operand, &mut pairs)?;
406    Ok(Some(pairs))
407}
408
409fn nested_to_pairs(
410    nested: &NestedEinsum<u32>,
411    next_operand: &mut usize,
412    pairs: &mut Vec<(usize, usize)>,
413) -> Result<usize> {
414    match nested {
415        NestedEinsum::Leaf { tensor_index } => Ok(*tensor_index),
416        NestedEinsum::Node { args, .. } => {
417            if args.len() != 2 {
418                return Err(Error::InvalidArgument(format!(
419                    "omeco returned non-binary contraction node with {} children",
420                    args.len()
421                )));
422            }
423            let left = nested_to_pairs(&args[0], next_operand, pairs)?;
424            let right = nested_to_pairs(&args[1], next_operand, pairs)?;
425            pairs.push((left, right));
426            let result_idx = *next_operand;
427            *next_operand += 1;
428            Ok(result_idx)
429        }
430    }
431}
432
433fn optimize_self_greedy_pairs(
434    subscripts: &Subscripts,
435    size_dict: &HashMap<u32, usize>,
436) -> Vec<(usize, usize)> {
437    let n_inputs = subscripts.inputs.len();
438    let mut available: Vec<usize> = (0..n_inputs).collect();
439    let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
440    let mut pairs: Vec<(usize, usize)> = Vec::new();
441
442    while available.len() > 1 {
443        let mut best_i = 0;
444        let mut best_j = 1;
445        let mut best_cost = usize::MAX;
446
447        for i in 0..available.len() {
448            for j in (i + 1)..available.len() {
449                let li = available[i];
450                let lj = available[j];
451                let mut needed = HashSet::new();
452                needed.extend(subscripts.output.iter().copied());
453                for &idx in &available {
454                    if idx != li && idx != lj {
455                        needed.extend(operand_subs[idx].iter().copied());
456                    }
457                }
458                let cost =
459                    contraction_cost(&operand_subs[li], &operand_subs[lj], &needed, size_dict);
460                if cost < best_cost {
461                    best_cost = cost;
462                    best_i = i;
463                    best_j = j;
464                }
465            }
466        }
467
468        let left = available[best_i];
469        let right = available[best_j];
470        pairs.push((left, right));
471
472        let mut needed = HashSet::new();
473        needed.extend(subscripts.output.iter().copied());
474        for &idx in &available {
475            if idx != left && idx != right {
476                needed.extend(operand_subs[idx].iter().copied());
477            }
478        }
479        let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
480        let new_idx = operand_subs.len();
481        operand_subs.push(new_subs);
482        available.remove(best_j);
483        available.remove(best_i);
484        available.push(new_idx);
485    }
486
487    pairs
488}
489
490#[cfg(test)]
491mod tests;