1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::mem::{size_of, size_of_val};
4
5use omeco::{
6 CodeOptimizer, EinCode as OmecoEinCode, Initializer, NestedEinsum, ScoreFunction, TreeSA,
7};
8
9use crate::cache::{saturating_sum, vec_of_vec_retained_bytes, vec_retained_bytes};
10use crate::planning::plan::{compile_step_plans, DiagPlan, GemmPlan, ReducePlan, StepPlan};
11use crate::syntax::subscripts::Subscripts;
12use crate::util::{build_size_dict, intermediate_subs};
13use crate::{Error, Result};
14
/// One pairwise contraction in a contraction path.
///
/// Both fields are indices into the operand list kept by `ContractionTree`:
/// the original inputs come first, followed by one intermediate per earlier step.
pub(crate) struct ContractionStep {
    /// Index of the left operand consumed by this step.
    pub(crate) left: usize,
    /// Index of the right operand consumed by this step.
    pub(crate) right: usize,
}
20
21#[derive(Debug, Clone)]
27pub struct ContractionOptimizerOptions {
28 pub betas: Vec<f64>,
30 pub ntrials: usize,
32 pub niters: usize,
34 pub score: ScoreFunction,
36}
37
38impl Default for ContractionOptimizerOptions {
39 fn default() -> Self {
40 Self {
41 betas: Vec::new(),
42 ntrials: 1,
43 niters: 0,
44 score: ScoreFunction::default(),
45 }
46 }
47}
48
49impl ContractionOptimizerOptions {
50 fn to_treesa(&self) -> TreeSA {
51 TreeSA::new(
52 self.betas.clone(),
53 self.ntrials,
54 self.niters,
55 Initializer::Greedy,
56 self.score.clone(),
57 )
58 }
59
60 pub(crate) fn validate(&self) -> Result<()> {
61 if self.ntrials == 0 {
62 return Err(Error::InvalidArgument(
63 "contraction optimizer ntrials must be at least 1".into(),
64 ));
65 }
66 if self.betas.iter().any(|value| value.is_nan()) {
67 return Err(Error::InvalidArgument(
68 "contraction optimizer betas must not contain NaN".into(),
69 ));
70 }
71 if self.score.tc_weight.is_nan()
72 || self.score.sc_weight.is_nan()
73 || self.score.rw_weight.is_nan()
74 || self.score.sc_target.is_nan()
75 {
76 return Err(Error::InvalidArgument(
77 "contraction optimizer score fields must not contain NaN".into(),
78 ));
79 }
80 Ok(())
81 }
82}
83
84pub struct ContractionTree {
96 pub(crate) subscripts: Subscripts,
98 pub(crate) steps: Vec<ContractionStep>,
100 pub(crate) size_dict: HashMap<u32, usize>,
102 pub(crate) operand_subs: Vec<Vec<u32>>,
104 pub(crate) step_plans: Vec<StepPlan>,
106}
107
108impl fmt::Debug for ContractionTree {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 f.debug_struct("ContractionTree")
111 .field("input_count", &self.subscripts.inputs.len())
112 .field("output_rank", &self.subscripts.output.len())
113 .field("steps_len", &self.steps.len())
114 .field("size_dict_len", &self.size_dict.len())
115 .field("operand_subs_len", &self.operand_subs.len())
116 .field("step_plans_len", &self.step_plans.len())
117 .finish_non_exhaustive()
118 }
119}
120
121impl ContractionTree {
122 pub fn optimize(subscripts: &Subscripts, shapes: &[&[usize]]) -> Result<Self> {
136 Self::optimize_with_options(subscripts, shapes, &ContractionOptimizerOptions::default())
137 }
138
139 pub fn optimize_with_options(
150 subscripts: &Subscripts,
151 shapes: &[&[usize]],
152 options: &ContractionOptimizerOptions,
153 ) -> Result<Self> {
154 options.validate()?;
155 let input_count = subscripts.inputs.len();
156 if input_count <= 1 {
157 return Self::from_pairs(subscripts, shapes, &[]);
158 }
159
160 let size_dict = build_size_dict(subscripts, shapes, None)?;
161 let pairs =
162 if let Some(omeco_pairs) = optimize_omeco_pairs(subscripts, &size_dict, options)? {
163 omeco_pairs
164 } else {
165 optimize_self_greedy_pairs(subscripts, &size_dict)?
166 };
167 Self::from_pairs(subscripts, shapes, &pairs)
168 }
169
170 pub fn from_pairs(
202 subscripts: &Subscripts,
203 shapes: &[&[usize]],
204 pairs: &[(usize, usize)],
205 ) -> Result<Self> {
206 let input_count = subscripts.inputs.len();
207 let required_steps = input_count.saturating_sub(1);
208 if pairs.len() != required_steps {
209 return Err(Error::InvalidArgument(format!(
210 "explicit contraction path for {input_count} operands must have {required_steps} steps, got {}",
211 pairs.len()
212 )));
213 }
214 let size_dict = build_size_dict(subscripts, shapes, None)?;
215
216 let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
217 let mut live = vec![false; input_count + pairs.len()];
218 for slot in live.iter_mut().take(input_count) {
219 *slot = true;
220 }
221 let mut steps = Vec::new();
222
223 for (step_idx, &(left, right)) in pairs.iter().enumerate() {
224 let next_idx = input_count + step_idx;
225 if left == right {
226 return Err(Error::InvalidArgument(format!(
227 "pair ({left}, {right}) must reference two distinct live operands"
228 )));
229 }
230 if left >= next_idx || right >= next_idx {
231 return Err(Error::InvalidArgument(format!(
232 "pair ({left}, {right}) references non-existent operand"
233 )));
234 }
235 if !live[left] || !live[right] {
236 return Err(Error::InvalidArgument(format!(
237 "pair ({left}, {right}) references an operand or intermediate that is no longer live"
238 )));
239 }
240
241 let mut needed: HashSet<u32> = subscripts.output.iter().copied().collect();
243 for (idx, subs) in operand_subs.iter().enumerate() {
244 if idx != left && idx != right && live[idx] {
245 needed.extend(subs.iter().copied());
246 }
247 }
248
249 let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
250 operand_subs.push(new_subs);
251 live[left] = false;
252 live[right] = false;
253 live[next_idx] = true;
254 steps.push(ContractionStep { left, right });
255 }
256
257 let live_count = live.iter().filter(|&&is_live| is_live).count();
258 if live_count != 1 {
259 return Err(Error::InvalidArgument(format!(
260 "explicit contraction path must leave exactly one live result, got {live_count}"
261 )));
262 }
263
264 let mut tree = Self {
265 subscripts: subscripts.clone(),
266 steps,
267 size_dict,
268 operand_subs,
269 step_plans: Vec::new(),
270 };
271 tree.step_plans = compile_step_plans(&tree)?;
272 Ok(tree)
273 }
274
275 #[must_use]
292 pub fn step_count(&self) -> usize {
293 self.steps.len()
294 }
295
296 #[must_use]
316 pub fn step_pair(&self, step_idx: usize) -> Option<(usize, usize)> {
317 self.steps.get(step_idx).map(|step| (step.left, step.right))
318 }
319
320 #[must_use]
343 pub fn step_subscripts(&self, step_idx: usize) -> Option<(&[u32], &[u32], &[u32])> {
344 let input_count = self.subscripts.inputs.len();
345 let step = self.steps.get(step_idx)?;
346 let result_idx = input_count + step_idx;
347 let output_subs = if step_idx + 1 == self.steps.len() {
348 &self.subscripts.output
349 } else {
350 &self.operand_subs[result_idx]
351 };
352 Some((
353 &self.operand_subs[step.left],
354 &self.operand_subs[step.right],
355 output_subs,
356 ))
357 }
358
359 #[must_use]
372 pub fn step_plan(&self, step_idx: usize) -> Option<crate::lowering::PairwiseStepPlan<'_>> {
373 self.step_plans
374 .get(step_idx)
375 .map(crate::lowering::PairwiseStepPlan::new)
376 }
377
378 #[doc(hidden)]
379 #[must_use]
380 pub(crate) fn retained_bytes_for_cache_stats(&self) -> usize {
381 saturating_sum([
382 size_of::<Self>(),
383 subscripts_retained_bytes(&self.subscripts),
384 self.steps
385 .capacity()
386 .saturating_mul(size_of::<ContractionStep>()),
387 self.size_dict
388 .capacity()
389 .saturating_mul(size_of::<u32>().saturating_add(size_of::<usize>())),
390 vec_of_vec_retained_bytes(&self.operand_subs),
391 self.step_plans
392 .capacity()
393 .saturating_mul(size_of::<StepPlan>()),
394 saturating_sum(self.step_plans.iter().map(step_plan_retained_bytes)),
395 ])
396 }
397}
398
399fn subscripts_retained_bytes(subscripts: &Subscripts) -> usize {
400 saturating_sum([
401 vec_of_vec_retained_bytes(&subscripts.inputs),
402 vec_retained_bytes(&subscripts.output),
403 ])
404}
405
406fn reduce_plan_retained_bytes(plan: &ReducePlan) -> usize {
407 saturating_sum([
408 vec_retained_bytes(&plan.original_subs),
409 vec_retained_bytes(&plan.kept_subs),
410 vec_retained_bytes(&plan.out_shape),
411 ])
412}
413
414fn diag_plan_retained_bytes(plan: &DiagPlan) -> usize {
415 saturating_sum([
416 vec_retained_bytes(&plan.stages),
417 saturating_sum(plan.stages.iter().map(|stage| {
418 saturating_sum([
419 vec_retained_bytes(&stage.axis_pairs),
420 vec_retained_bytes(&stage.result_subs),
421 ])
422 })),
423 vec_retained_bytes(&plan.result_subs),
424 ])
425}
426
427fn gemm_plan_retained_bytes(plan: &GemmPlan) -> usize {
428 saturating_sum([
429 plan.reduce_a.as_ref().map_or(0, reduce_plan_retained_bytes),
430 plan.reduce_b.as_ref().map_or(0, reduce_plan_retained_bytes),
431 vec_retained_bytes(&plan.subs_a),
432 vec_retained_bytes(&plan.subs_b),
433 vec_retained_bytes(&plan.lo_modes),
434 vec_retained_bytes(&plan.ro_modes),
435 vec_retained_bytes(&plan.sum_modes),
436 vec_retained_bytes(&plan.lo_sizes),
437 vec_retained_bytes(&plan.ro_sizes),
438 vec_retained_bytes(&plan.sum_sizes),
439 vec_retained_bytes(&plan.batch_sizes),
440 vec_retained_bytes(&plan.target_a),
441 vec_retained_bytes(&plan.target_b),
442 vec_retained_bytes(&plan.c_gemm_shape),
443 vec_retained_bytes(&plan.expanded_shape),
444 vec_retained_bytes(&plan.canonical_modes),
445 vec_retained_bytes(&plan.a_gemm_shape),
446 vec_retained_bytes(&plan.b_gemm_shape),
447 ])
448}
449
450fn step_plan_retained_bytes(plan: &StepPlan) -> usize {
451 saturating_sum([
452 plan.diag_a.as_ref().map_or(0, diag_plan_retained_bytes),
453 plan.diag_b.as_ref().map_or(0, diag_plan_retained_bytes),
454 plan.strict_binary.as_ref().map_or(0, size_of_val),
455 gemm_plan_retained_bytes(&plan.gemm),
456 ])
457}
458
459fn optimize_omeco_pairs(
460 subscripts: &Subscripts,
461 size_dict: &HashMap<u32, usize>,
462 options: &ContractionOptimizerOptions,
463) -> Result<Option<Vec<(usize, usize)>>> {
464 let code = OmecoEinCode::new(subscripts.inputs.clone(), subscripts.output.clone());
465 let optimizer = options.to_treesa();
466 let Some(nested) = optimizer.optimize(&code, size_dict) else {
467 return Ok(None);
468 };
469
470 let mut next_operand = subscripts.inputs.len();
471 let mut pairs = Vec::with_capacity(subscripts.inputs.len().saturating_sub(1));
472 nested_to_pairs(&nested, &mut next_operand, &mut pairs)?;
473 Ok(Some(pairs))
474}
475
476fn nested_to_pairs(
477 nested: &NestedEinsum<u32>,
478 next_operand: &mut usize,
479 pairs: &mut Vec<(usize, usize)>,
480) -> Result<usize> {
481 match nested {
482 NestedEinsum::Leaf { tensor_index } => Ok(*tensor_index),
483 NestedEinsum::Node { args, .. } => {
484 if args.len() != 2 {
485 return Err(Error::InvalidArgument(format!(
486 "omeco returned non-binary contraction node with {} children",
487 args.len()
488 )));
489 }
490 let left = nested_to_pairs(&args[0], next_operand, pairs)?;
491 let right = nested_to_pairs(&args[1], next_operand, pairs)?;
492 pairs.push((left, right));
493 let result_idx = *next_operand;
494 *next_operand += 1;
495 Ok(result_idx)
496 }
497 }
498}
499
/// Converts each operand's subscript list into the set of distinct labels it
/// contains, preserving operand order.
fn build_operand_label_sets(operand_subs: &[Vec<u32>]) -> Vec<HashSet<u32>> {
    let mut label_sets = Vec::with_capacity(operand_subs.len());
    for subs in operand_subs {
        let labels: HashSet<u32> = subs.iter().copied().collect();
        label_sets.push(labels);
    }
    label_sets
}
506
507fn build_needed_label_counts(
508 output_subs: &[u32],
509 available: &[usize],
510 operand_label_sets: &[HashSet<u32>],
511) -> HashMap<u32, usize> {
512 let mut counts = HashMap::new();
513 for &label in output_subs {
514 counts.entry(label).or_insert(1);
515 }
516 for &idx in available {
517 add_labels_to_counts(&mut counts, &operand_label_sets[idx]);
518 }
519 counts
520}
521
/// Increments the count of every label in `labels`, inserting labels that
/// are not yet present with a count of one.
fn add_labels_to_counts(counts: &mut HashMap<u32, usize>, labels: &HashSet<u32>) {
    labels.iter().for_each(|&label| {
        let slot = counts.entry(label).or_default();
        *slot += 1;
    });
}
527
/// Decrements the count of every label in `labels`.
///
/// A label whose count is exactly one is removed from the map; labels that
/// are not in the map are ignored.
fn remove_labels_from_counts(counts: &mut HashMap<u32, usize>, labels: &HashSet<u32>) {
    for label in labels {
        let drop_entry = match counts.get_mut(label) {
            None => continue,
            Some(count) if *count == 1 => true,
            Some(count) => {
                *count -= 1;
                false
            }
        };
        // Removal happens after the mutable borrow of the slot has ended.
        if drop_entry {
            counts.remove(label);
        }
    }
}
541
/// Tells whether `label` must be kept when operands `left` and `right` are
/// contracted together.
///
/// The label is needed when its total count in `needed_label_counts` exceeds
/// the number of the two selected operands that contain it, i.e. when some
/// holder other than the selected pair still accounts for it.
fn candidate_label_is_needed(
    label: u32,
    left: usize,
    right: usize,
    operand_label_sets: &[HashSet<u32>],
    needed_label_counts: &HashMap<u32, usize>,
) -> bool {
    let selected_count = [left, right]
        .iter()
        .filter(|&&operand| operand_label_sets[operand].contains(&label))
        .count();

    match needed_label_counts.get(&label) {
        Some(&total) => total > selected_count,
        // An untracked label has a count of zero and can never be needed.
        None => false,
    }
}
558
559fn collect_candidate_intermediate_subs(
560 subs_left: &[u32],
561 subs_right: &[u32],
562 left: usize,
563 right: usize,
564 operand_label_sets: &[HashSet<u32>],
565 needed_label_counts: &HashMap<u32, usize>,
566 output: &mut Vec<u32>,
567) {
568 output.clear();
569 for &label in subs_left.iter().chain(subs_right.iter()) {
570 if candidate_label_is_needed(label, left, right, operand_label_sets, needed_label_counts)
571 && !output.contains(&label)
572 {
573 output.push(label);
574 }
575 }
576}
577
/// Borrowed lookup tables shared by every candidate-pair cost evaluation
/// within one round of the greedy search.
#[derive(Clone, Copy)]
struct CandidateCostContext<'a> {
    /// Distinct labels of every operand, indexed by operand index.
    operand_label_sets: &'a [HashSet<u32>],
    /// Per-label count of the holders that still need the label.
    needed_label_counts: &'a HashMap<u32, usize>,
    /// Size associated with each label.
    size_dict: &'a HashMap<u32, usize>,
}
584
585fn candidate_contraction_cost(
586 subs_left: &[u32],
587 subs_right: &[u32],
588 left: usize,
589 right: usize,
590 context: CandidateCostContext<'_>,
591 candidate_subs: &mut Vec<u32>,
592) -> Result<usize> {
593 collect_candidate_intermediate_subs(
594 subs_left,
595 subs_right,
596 left,
597 right,
598 context.operand_label_sets,
599 context.needed_label_counts,
600 candidate_subs,
601 );
602 let mut cost = 1usize;
603 for &label in candidate_subs.iter() {
604 let size = context.size_dict.get(&label).copied().ok_or_else(|| {
605 Error::InvalidArgument(format!(
606 "unknown size for label {label} in contraction cost"
607 ))
608 })?;
609 cost = cost.saturating_mul(size);
610 }
611 Ok(cost.max(1))
612}
613
614fn optimize_self_greedy_pairs(
615 subscripts: &Subscripts,
616 size_dict: &HashMap<u32, usize>,
617) -> Result<Vec<(usize, usize)>> {
618 let input_count = subscripts.inputs.len();
619 let mut available: Vec<usize> = (0..input_count).collect();
620 let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
621 let mut operand_label_sets = build_operand_label_sets(&operand_subs);
622 let mut needed_label_counts =
623 build_needed_label_counts(&subscripts.output, &available, &operand_label_sets);
624 let mut candidate_subs = Vec::new();
625 let mut pairs: Vec<(usize, usize)> = Vec::new();
626
627 while available.len() > 1 {
628 let mut best_i = 0;
629 let mut best_j = 1;
630 let mut best_cost = usize::MAX;
631
632 for i in 0..available.len() {
633 for j in (i + 1)..available.len() {
634 let li = available[i];
635 let lj = available[j];
636 let cost = candidate_contraction_cost(
637 &operand_subs[li],
638 &operand_subs[lj],
639 li,
640 lj,
641 CandidateCostContext {
642 operand_label_sets: &operand_label_sets,
643 needed_label_counts: &needed_label_counts,
644 size_dict,
645 },
646 &mut candidate_subs,
647 )?;
648 if cost < best_cost {
649 best_cost = cost;
650 best_i = i;
651 best_j = j;
652 }
653 }
654 }
655
656 let left = available[best_i];
657 let right = available[best_j];
658 pairs.push((left, right));
659
660 let mut new_subs = Vec::new();
661 collect_candidate_intermediate_subs(
662 &operand_subs[left],
663 &operand_subs[right],
664 left,
665 right,
666 &operand_label_sets,
667 &needed_label_counts,
668 &mut new_subs,
669 );
670 let new_idx = operand_subs.len();
671 let new_label_set: HashSet<u32> = new_subs.iter().copied().collect();
672 remove_labels_from_counts(&mut needed_label_counts, &operand_label_sets[left]);
673 remove_labels_from_counts(&mut needed_label_counts, &operand_label_sets[right]);
674 add_labels_to_counts(&mut needed_label_counts, &new_label_set);
675 operand_subs.push(new_subs);
676 operand_label_sets.push(new_label_set);
677 available.remove(best_j);
678 available.remove(best_i);
679 available.push(new_idx);
680 }
681
682 Ok(pairs)
683}
684
685#[cfg(test)]
686mod tests;