Skip to main content

tenferro_einsum/planning/
tree.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::mem::{size_of, size_of_val};
4
5use omeco::{
6    CodeOptimizer, EinCode as OmecoEinCode, Initializer, NestedEinsum, ScoreFunction, TreeSA,
7};
8
9use crate::cache::{saturating_sum, vec_of_vec_retained_bytes, vec_retained_bytes};
10use crate::planning::plan::{compile_step_plans, DiagPlan, GemmPlan, ReducePlan, StepPlan};
11use crate::syntax::subscripts::Subscripts;
12use crate::util::{build_size_dict, intermediate_subs};
13use crate::{Error, Result};
14
/// A single pairwise step in the contraction sequence.
///
/// Both fields are operand indices: values below the number of input tensors
/// refer to the original inputs, larger values refer to intermediates produced
/// by earlier steps (see `ContractionTree::from_pairs`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ContractionStep {
    /// Index of the first operand consumed by this step.
    pub(crate) left: usize,
    /// Index of the second operand consumed by this step.
    pub(crate) right: usize,
}
20
21/// Public options for automatic contraction-path optimization.
22///
23/// The default planner uses TreeSA with a greedy initializer and zero annealing
24/// iterations. This keeps the public API on a single optimizer family while
25/// making the default behavior effectively "greedy-only".
26#[derive(Debug, Clone)]
27pub struct ContractionOptimizerOptions {
28    /// Inverse-temperature schedule for TreeSA.
29    pub betas: Vec<f64>,
30    /// Number of independent TreeSA trials.
31    pub ntrials: usize,
32    /// Annealing iterations per temperature level.
33    pub niters: usize,
34    /// Score function used by TreeSA.
35    pub score: ScoreFunction,
36}
37
38impl Default for ContractionOptimizerOptions {
39    fn default() -> Self {
40        Self {
41            betas: Vec::new(),
42            ntrials: 1,
43            niters: 0,
44            score: ScoreFunction::default(),
45        }
46    }
47}
48
49impl ContractionOptimizerOptions {
50    fn to_treesa(&self) -> TreeSA {
51        TreeSA::new(
52            self.betas.clone(),
53            self.ntrials,
54            self.niters,
55            Initializer::Greedy,
56            self.score.clone(),
57        )
58    }
59
60    pub(crate) fn validate(&self) -> Result<()> {
61        if self.ntrials == 0 {
62            return Err(Error::InvalidArgument(
63                "contraction optimizer ntrials must be at least 1".into(),
64            ));
65        }
66        if self.betas.iter().any(|value| value.is_nan()) {
67            return Err(Error::InvalidArgument(
68                "contraction optimizer betas must not contain NaN".into(),
69            ));
70        }
71        if self.score.tc_weight.is_nan()
72            || self.score.sc_weight.is_nan()
73            || self.score.rw_weight.is_nan()
74            || self.score.sc_target.is_nan()
75        {
76            return Err(Error::InvalidArgument(
77                "contraction optimizer score fields must not contain NaN".into(),
78            ));
79        }
80        Ok(())
81    }
82}
83
84/// Contraction tree determining pairwise contraction order for N-ary einsum.
85///
86/// When contracting more than two tensors, the order in which pairwise
87/// contractions are performed significantly affects performance.
88/// `ContractionTree` encodes this order as a binary tree.
89///
90/// # Optimization
91///
92/// Use [`ContractionTree::optimize`] for automatic cost-based optimization
93/// (e.g., greedy algorithm based on tensor sizes), or
94/// [`ContractionTree::from_pairs`] for manual specification.
95pub struct ContractionTree {
96    /// Original subscripts.
97    pub(crate) subscripts: Subscripts,
98    /// Steps in the contraction (empty for single-tensor case).
99    pub(crate) steps: Vec<ContractionStep>,
100    /// Label → dimension size mapping.
101    pub(crate) size_dict: HashMap<u32, usize>,
102    /// Subscripts for each operand (0..input_count from input, then intermediates).
103    pub(crate) operand_subs: Vec<Vec<u32>>,
104    /// Pre-compiled step plans (cached to avoid recomputation per execute call).
105    pub(crate) step_plans: Vec<StepPlan>,
106}
107
108impl fmt::Debug for ContractionTree {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        f.debug_struct("ContractionTree")
111            .field("input_count", &self.subscripts.inputs.len())
112            .field("output_rank", &self.subscripts.output.len())
113            .field("steps_len", &self.steps.len())
114            .field("size_dict_len", &self.size_dict.len())
115            .field("operand_subs_len", &self.operand_subs.len())
116            .field("step_plans_len", &self.step_plans.len())
117            .finish_non_exhaustive()
118    }
119}
120
121impl ContractionTree {
122    /// Automatically compute an optimized contraction order.
123    ///
124    /// Uses a cost-based heuristic (greedy algorithm) to determine
125    /// the pairwise contraction sequence that minimizes total operation count.
126    ///
127    /// # Arguments
128    ///
129    /// * `subscripts` — Einsum subscripts for all tensors
130    /// * `shapes` — Shape of each input tensor
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if subscripts and shapes are inconsistent.
135    pub fn optimize(subscripts: &Subscripts, shapes: &[&[usize]]) -> Result<Self> {
136        Self::optimize_with_options(subscripts, shapes, &ContractionOptimizerOptions::default())
137    }
138
139    /// Automatically compute an optimized contraction order with explicit
140    /// planner options.
141    ///
142    /// This routes automatic planning through TreeSA using the provided
143    /// configuration. The default options correspond to a greedy-initialized
144    /// TreeSA with zero annealing iterations.
145    ///
146    /// # Errors
147    ///
148    /// Returns an error if subscripts, shapes, or planner options are invalid.
149    pub fn optimize_with_options(
150        subscripts: &Subscripts,
151        shapes: &[&[usize]],
152        options: &ContractionOptimizerOptions,
153    ) -> Result<Self> {
154        options.validate()?;
155        let input_count = subscripts.inputs.len();
156        if input_count <= 1 {
157            return Self::from_pairs(subscripts, shapes, &[]);
158        }
159
160        let size_dict = build_size_dict(subscripts, shapes, None)?;
161        let pairs =
162            if let Some(omeco_pairs) = optimize_omeco_pairs(subscripts, &size_dict, options)? {
163                omeco_pairs
164            } else {
165                optimize_self_greedy_pairs(subscripts, &size_dict)?
166            };
167        Self::from_pairs(subscripts, shapes, &pairs)
168    }
169
170    /// Manually build a contraction tree from a pairwise contraction sequence.
171    ///
172    /// Each pair `(i, j)` specifies which two tensors (or intermediate results)
173    /// to contract next. Intermediate results are assigned indices starting
174    /// from the number of input tensors.
175    ///
176    /// # Arguments
177    ///
178    /// * `subscripts` — Einsum subscripts for all tensors
179    /// * `shapes` — Shape of each input tensor
180    /// * `pairs` — Ordered list of pairwise contractions
181    ///
182    /// # Examples
183    ///
184    /// ```rust
185    /// use tenferro_einsum::{ContractionTree, Subscripts};
186    ///
187    /// // Three tensors: A[ij] B[jk] C[kl] -> D[il]
188    /// // Contract B and C first, then A with the result:
189    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
190    /// let shapes = [&[3, 4][..], &[4, 5], &[5, 6]];
191    /// let tree = ContractionTree::from_pairs(
192    ///     &subs,
193    ///     &shapes,
194    ///     &[(1, 2), (0, 3)],  // B*C -> T(index=3), then A*T -> D
195    /// ).unwrap();
196    /// ```
197    ///
198    /// # Errors
199    ///
200    /// Returns an error if the pairs do not form a valid contraction sequence.
201    pub fn from_pairs(
202        subscripts: &Subscripts,
203        shapes: &[&[usize]],
204        pairs: &[(usize, usize)],
205    ) -> Result<Self> {
206        let input_count = subscripts.inputs.len();
207        let required_steps = input_count.saturating_sub(1);
208        if pairs.len() != required_steps {
209            return Err(Error::InvalidArgument(format!(
210                "explicit contraction path for {input_count} operands must have {required_steps} steps, got {}",
211                pairs.len()
212            )));
213        }
214        let size_dict = build_size_dict(subscripts, shapes, None)?;
215
216        let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
217        let mut live = vec![false; input_count + pairs.len()];
218        for slot in live.iter_mut().take(input_count) {
219            *slot = true;
220        }
221        let mut steps = Vec::new();
222
223        for (step_idx, &(left, right)) in pairs.iter().enumerate() {
224            let next_idx = input_count + step_idx;
225            if left == right {
226                return Err(Error::InvalidArgument(format!(
227                    "pair ({left}, {right}) must reference two distinct live operands"
228                )));
229            }
230            if left >= next_idx || right >= next_idx {
231                return Err(Error::InvalidArgument(format!(
232                    "pair ({left}, {right}) references non-existent operand"
233                )));
234            }
235            if !live[left] || !live[right] {
236                return Err(Error::InvalidArgument(format!(
237                    "pair ({left}, {right}) references an operand or intermediate that is no longer live"
238                )));
239            }
240
241            // Labels needed by other live operands + final output
242            let mut needed: HashSet<u32> = subscripts.output.iter().copied().collect();
243            for (idx, subs) in operand_subs.iter().enumerate() {
244                if idx != left && idx != right && live[idx] {
245                    needed.extend(subs.iter().copied());
246                }
247            }
248
249            let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
250            operand_subs.push(new_subs);
251            live[left] = false;
252            live[right] = false;
253            live[next_idx] = true;
254            steps.push(ContractionStep { left, right });
255        }
256
257        let live_count = live.iter().filter(|&&is_live| is_live).count();
258        if live_count != 1 {
259            return Err(Error::InvalidArgument(format!(
260                "explicit contraction path must leave exactly one live result, got {live_count}"
261            )));
262        }
263
264        let mut tree = Self {
265            subscripts: subscripts.clone(),
266            steps,
267            size_dict,
268            operand_subs,
269            step_plans: Vec::new(),
270        };
271        tree.step_plans = compile_step_plans(&tree)?;
272        Ok(tree)
273    }
274
275    /// Return the number of pairwise contraction steps in this tree.
276    ///
277    /// # Examples
278    ///
279    /// ```rust
280    /// use tenferro_einsum::{ContractionTree, Subscripts};
281    ///
282    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
283    /// let tree = ContractionTree::from_pairs(
284    ///     &subs,
285    ///     &[&[2, 2], &[2, 2], &[2, 2]],
286    ///     &[(1, 2), (0, 3)],
287    /// )
288    /// .unwrap();
289    /// assert_eq!(tree.step_count(), 2);
290    /// ```
291    #[must_use]
292    pub fn step_count(&self) -> usize {
293        self.steps.len()
294    }
295
296    /// Return the operand indices for a pairwise contraction step.
297    ///
298    /// The returned indices refer to the original inputs (`0..input_count`) and
299    /// then to intermediates (`input_count..`) produced by earlier steps.
300    ///
301    /// # Examples
302    ///
303    /// ```rust
304    /// use tenferro_einsum::{ContractionTree, Subscripts};
305    ///
306    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
307    /// let tree = ContractionTree::from_pairs(
308    ///     &subs,
309    ///     &[&[2, 2], &[2, 2], &[2, 2]],
310    ///     &[(1, 2), (0, 3)],
311    /// )
312    /// .unwrap();
313    /// assert_eq!(tree.step_pair(0), Some((1, 2)));
314    /// ```
315    #[must_use]
316    pub fn step_pair(&self, step_idx: usize) -> Option<(usize, usize)> {
317        self.steps.get(step_idx).map(|step| (step.left, step.right))
318    }
319
320    /// Return the `(lhs, rhs, output)` subscripts for a pairwise step.
321    ///
322    /// The output subscripts are the intermediate labels preserved after the
323    /// contraction, or the final output labels on the last step.
324    ///
325    /// # Examples
326    ///
327    /// ```rust
328    /// use tenferro_einsum::{ContractionTree, Subscripts};
329    ///
330    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
331    /// let tree = ContractionTree::from_pairs(
332    ///     &subs,
333    ///     &[&[2, 2], &[2, 2], &[2, 2]],
334    ///     &[(1, 2), (0, 3)],
335    /// )
336    /// .unwrap();
337    /// let (lhs, rhs, out) = tree.step_subscripts(0).unwrap();
338    /// assert_eq!(lhs, &[1, 2]);
339    /// assert_eq!(rhs, &[2, 3]);
340    /// assert_eq!(out, &[1, 3]);
341    /// ```
342    #[must_use]
343    pub fn step_subscripts(&self, step_idx: usize) -> Option<(&[u32], &[u32], &[u32])> {
344        let input_count = self.subscripts.inputs.len();
345        let step = self.steps.get(step_idx)?;
346        let result_idx = input_count + step_idx;
347        let output_subs = if step_idx + 1 == self.steps.len() {
348            &self.subscripts.output
349        } else {
350            &self.operand_subs[result_idx]
351        };
352        Some((
353            &self.operand_subs[step.left],
354            &self.operand_subs[step.right],
355            output_subs,
356        ))
357    }
358
359    /// Return the precomputed lowering plan for one pairwise contraction step.
360    ///
361    /// # Examples
362    ///
363    /// ```rust
364    /// use tenferro_einsum::{ContractionTree, Subscripts};
365    ///
366    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
367    /// let tree = ContractionTree::from_pairs(&subs, &[&[2, 3], &[3, 4]], &[(0, 1)]).unwrap();
368    ///
369    /// assert_eq!(tree.step_plan(0).unwrap().gemm().m(), 2);
370    /// ```
371    #[must_use]
372    pub fn step_plan(&self, step_idx: usize) -> Option<crate::lowering::PairwiseStepPlan<'_>> {
373        self.step_plans
374            .get(step_idx)
375            .map(crate::lowering::PairwiseStepPlan::new)
376    }
377
378    #[doc(hidden)]
379    #[must_use]
380    pub(crate) fn retained_bytes_for_cache_stats(&self) -> usize {
381        saturating_sum([
382            size_of::<Self>(),
383            subscripts_retained_bytes(&self.subscripts),
384            self.steps
385                .capacity()
386                .saturating_mul(size_of::<ContractionStep>()),
387            self.size_dict
388                .capacity()
389                .saturating_mul(size_of::<u32>().saturating_add(size_of::<usize>())),
390            vec_of_vec_retained_bytes(&self.operand_subs),
391            self.step_plans
392                .capacity()
393                .saturating_mul(size_of::<StepPlan>()),
394            saturating_sum(self.step_plans.iter().map(step_plan_retained_bytes)),
395        ])
396    }
397}
398
399fn subscripts_retained_bytes(subscripts: &Subscripts) -> usize {
400    saturating_sum([
401        vec_of_vec_retained_bytes(&subscripts.inputs),
402        vec_retained_bytes(&subscripts.output),
403    ])
404}
405
406fn reduce_plan_retained_bytes(plan: &ReducePlan) -> usize {
407    saturating_sum([
408        vec_retained_bytes(&plan.original_subs),
409        vec_retained_bytes(&plan.kept_subs),
410        vec_retained_bytes(&plan.out_shape),
411    ])
412}
413
414fn diag_plan_retained_bytes(plan: &DiagPlan) -> usize {
415    saturating_sum([
416        vec_retained_bytes(&plan.stages),
417        saturating_sum(plan.stages.iter().map(|stage| {
418            saturating_sum([
419                vec_retained_bytes(&stage.axis_pairs),
420                vec_retained_bytes(&stage.result_subs),
421            ])
422        })),
423        vec_retained_bytes(&plan.result_subs),
424    ])
425}
426
427fn gemm_plan_retained_bytes(plan: &GemmPlan) -> usize {
428    saturating_sum([
429        plan.reduce_a.as_ref().map_or(0, reduce_plan_retained_bytes),
430        plan.reduce_b.as_ref().map_or(0, reduce_plan_retained_bytes),
431        vec_retained_bytes(&plan.subs_a),
432        vec_retained_bytes(&plan.subs_b),
433        vec_retained_bytes(&plan.lo_modes),
434        vec_retained_bytes(&plan.ro_modes),
435        vec_retained_bytes(&plan.sum_modes),
436        vec_retained_bytes(&plan.lo_sizes),
437        vec_retained_bytes(&plan.ro_sizes),
438        vec_retained_bytes(&plan.sum_sizes),
439        vec_retained_bytes(&plan.batch_sizes),
440        vec_retained_bytes(&plan.target_a),
441        vec_retained_bytes(&plan.target_b),
442        vec_retained_bytes(&plan.c_gemm_shape),
443        vec_retained_bytes(&plan.expanded_shape),
444        vec_retained_bytes(&plan.canonical_modes),
445        vec_retained_bytes(&plan.a_gemm_shape),
446        vec_retained_bytes(&plan.b_gemm_shape),
447    ])
448}
449
450fn step_plan_retained_bytes(plan: &StepPlan) -> usize {
451    saturating_sum([
452        plan.diag_a.as_ref().map_or(0, diag_plan_retained_bytes),
453        plan.diag_b.as_ref().map_or(0, diag_plan_retained_bytes),
454        plan.strict_binary.as_ref().map_or(0, size_of_val),
455        gemm_plan_retained_bytes(&plan.gemm),
456    ])
457}
458
459fn optimize_omeco_pairs(
460    subscripts: &Subscripts,
461    size_dict: &HashMap<u32, usize>,
462    options: &ContractionOptimizerOptions,
463) -> Result<Option<Vec<(usize, usize)>>> {
464    let code = OmecoEinCode::new(subscripts.inputs.clone(), subscripts.output.clone());
465    let optimizer = options.to_treesa();
466    let Some(nested) = optimizer.optimize(&code, size_dict) else {
467        return Ok(None);
468    };
469
470    let mut next_operand = subscripts.inputs.len();
471    let mut pairs = Vec::with_capacity(subscripts.inputs.len().saturating_sub(1));
472    nested_to_pairs(&nested, &mut next_operand, &mut pairs)?;
473    Ok(Some(pairs))
474}
475
476fn nested_to_pairs(
477    nested: &NestedEinsum<u32>,
478    next_operand: &mut usize,
479    pairs: &mut Vec<(usize, usize)>,
480) -> Result<usize> {
481    match nested {
482        NestedEinsum::Leaf { tensor_index } => Ok(*tensor_index),
483        NestedEinsum::Node { args, .. } => {
484            if args.len() != 2 {
485                return Err(Error::InvalidArgument(format!(
486                    "omeco returned non-binary contraction node with {} children",
487                    args.len()
488                )));
489            }
490            let left = nested_to_pairs(&args[0], next_operand, pairs)?;
491            let right = nested_to_pairs(&args[1], next_operand, pairs)?;
492            pairs.push((left, right));
493            let result_idx = *next_operand;
494            *next_operand += 1;
495            Ok(result_idx)
496        }
497    }
498}
499
/// Converts every operand's subscript list into the set of labels it uses.
///
/// The result has one set per operand, in the same order; repeated labels
/// within an operand collapse into a single set entry.
fn build_operand_label_sets(operand_subs: &[Vec<u32>]) -> Vec<HashSet<u32>> {
    let mut label_sets = Vec::with_capacity(operand_subs.len());
    for subs in operand_subs {
        let labels: HashSet<u32> = subs.iter().copied().collect();
        label_sets.push(labels);
    }
    label_sets
}
506
507fn build_needed_label_counts(
508    output_subs: &[u32],
509    available: &[usize],
510    operand_label_sets: &[HashSet<u32>],
511) -> HashMap<u32, usize> {
512    let mut counts = HashMap::new();
513    for &label in output_subs {
514        counts.entry(label).or_insert(1);
515    }
516    for &idx in available {
517        add_labels_to_counts(&mut counts, &operand_label_sets[idx]);
518    }
519    counts
520}
521
/// Adds one reference for every label in `labels`, creating missing entries.
fn add_labels_to_counts(counts: &mut HashMap<u32, usize>, labels: &HashSet<u32>) {
    for label in labels.iter().copied() {
        let slot = counts.entry(label).or_default();
        *slot += 1;
    }
}
527
/// Drops one reference for every label in `labels`.
///
/// An entry whose count is exactly one is removed from the map; labels that
/// are not present are ignored.
fn remove_labels_from_counts(counts: &mut HashMap<u32, usize>, labels: &HashSet<u32>) {
    for label in labels {
        let Some(count) = counts.get_mut(label) else {
            continue;
        };
        if *count == 1 {
            // Last reference: forget the label entirely.
            counts.remove(label);
        } else {
            *count -= 1;
        }
    }
}
541
/// Decides whether `label` must survive contracting operands `left` and `right`.
///
/// A label is needed when its total reference count is strictly greater than
/// the number of the two selected operands that contain it, i.e. when some
/// reference comes from outside the pair. Labels without a recorded count are
/// treated as having zero references.
fn candidate_label_is_needed(
    label: u32,
    left: usize,
    right: usize,
    operand_label_sets: &[HashSet<u32>],
    needed_label_counts: &HashMap<u32, usize>,
) -> bool {
    let held_by_pair = [left, right]
        .iter()
        .filter(|&&operand| operand_label_sets[operand].contains(&label))
        .count();
    let total_references = match needed_label_counts.get(&label) {
        Some(&count) => count,
        None => 0,
    };
    total_references > held_by_pair
}
558
559fn collect_candidate_intermediate_subs(
560    subs_left: &[u32],
561    subs_right: &[u32],
562    left: usize,
563    right: usize,
564    operand_label_sets: &[HashSet<u32>],
565    needed_label_counts: &HashMap<u32, usize>,
566    output: &mut Vec<u32>,
567) {
568    output.clear();
569    for &label in subs_left.iter().chain(subs_right.iter()) {
570        if candidate_label_is_needed(label, left, right, operand_label_sets, needed_label_counts)
571            && !output.contains(&label)
572        {
573            output.push(label);
574        }
575    }
576}
577
/// Read-only lookup tables shared by every candidate-pair cost evaluation.
///
/// The struct only holds borrows, so it is `Copy` and is passed by value.
#[derive(Clone, Copy)]
struct CandidateCostContext<'a> {
    /// Label set of every operand, indexed by operand index.
    operand_label_sets: &'a [HashSet<u32>],
    /// Current reference count of every label.
    needed_label_counts: &'a HashMap<u32, usize>,
    /// Dimension size for every label.
    size_dict: &'a HashMap<u32, usize>,
}
584
585fn candidate_contraction_cost(
586    subs_left: &[u32],
587    subs_right: &[u32],
588    left: usize,
589    right: usize,
590    context: CandidateCostContext<'_>,
591    candidate_subs: &mut Vec<u32>,
592) -> Result<usize> {
593    collect_candidate_intermediate_subs(
594        subs_left,
595        subs_right,
596        left,
597        right,
598        context.operand_label_sets,
599        context.needed_label_counts,
600        candidate_subs,
601    );
602    let mut cost = 1usize;
603    for &label in candidate_subs.iter() {
604        let size = context.size_dict.get(&label).copied().ok_or_else(|| {
605            Error::InvalidArgument(format!(
606                "unknown size for label {label} in contraction cost"
607            ))
608        })?;
609        cost = cost.saturating_mul(size);
610    }
611    Ok(cost.max(1))
612}
613
614fn optimize_self_greedy_pairs(
615    subscripts: &Subscripts,
616    size_dict: &HashMap<u32, usize>,
617) -> Result<Vec<(usize, usize)>> {
618    let input_count = subscripts.inputs.len();
619    let mut available: Vec<usize> = (0..input_count).collect();
620    let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
621    let mut operand_label_sets = build_operand_label_sets(&operand_subs);
622    let mut needed_label_counts =
623        build_needed_label_counts(&subscripts.output, &available, &operand_label_sets);
624    let mut candidate_subs = Vec::new();
625    let mut pairs: Vec<(usize, usize)> = Vec::new();
626
627    while available.len() > 1 {
628        let mut best_i = 0;
629        let mut best_j = 1;
630        let mut best_cost = usize::MAX;
631
632        for i in 0..available.len() {
633            for j in (i + 1)..available.len() {
634                let li = available[i];
635                let lj = available[j];
636                let cost = candidate_contraction_cost(
637                    &operand_subs[li],
638                    &operand_subs[lj],
639                    li,
640                    lj,
641                    CandidateCostContext {
642                        operand_label_sets: &operand_label_sets,
643                        needed_label_counts: &needed_label_counts,
644                        size_dict,
645                    },
646                    &mut candidate_subs,
647                )?;
648                if cost < best_cost {
649                    best_cost = cost;
650                    best_i = i;
651                    best_j = j;
652                }
653            }
654        }
655
656        let left = available[best_i];
657        let right = available[best_j];
658        pairs.push((left, right));
659
660        let mut new_subs = Vec::new();
661        collect_candidate_intermediate_subs(
662            &operand_subs[left],
663            &operand_subs[right],
664            left,
665            right,
666            &operand_label_sets,
667            &needed_label_counts,
668            &mut new_subs,
669        );
670        let new_idx = operand_subs.len();
671        let new_label_set: HashSet<u32> = new_subs.iter().copied().collect();
672        remove_labels_from_counts(&mut needed_label_counts, &operand_label_sets[left]);
673        remove_labels_from_counts(&mut needed_label_counts, &operand_label_sets[right]);
674        add_labels_to_counts(&mut needed_label_counts, &new_label_set);
675        operand_subs.push(new_subs);
676        operand_label_sets.push(new_label_set);
677        available.remove(best_j);
678        available.remove(best_i);
679        available.push(new_idx);
680    }
681
682    Ok(pairs)
683}
684
685#[cfg(test)]
686mod tests;