tenferro_einsum/planning/
tree.rs1use std::collections::{HashMap, HashSet};
2
3use omeco::{
4 CodeOptimizer, EinCode as OmecoEinCode, Initializer, NestedEinsum, ScoreFunction, TreeSA,
5};
6use tenferro_device::{Error, Result};
7
8use crate::execution::util::{
9 build_size_dict, compute_output_shape, contraction_cost, intermediate_subs,
10};
11use crate::planning::plan::{compile_step_plans, StepPlan};
12use crate::syntax::subscripts::Subscripts;
13
/// One pairwise contraction in a contraction tree.
///
/// `left` and `right` are operand indices: values below the number of inputs
/// refer to input tensors, larger values refer to earlier step results (step
/// `k` produces operand `n_inputs + k`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ContractionStep {
    /// Operand index used as the left side of this contraction.
    pub(crate) left: usize,
    /// Operand index used as the right side of this contraction.
    pub(crate) right: usize,
}
19
20#[derive(Debug, Clone)]
26pub struct ContractionOptimizerOptions {
27 pub betas: Vec<f64>,
29 pub ntrials: usize,
31 pub niters: usize,
33 pub score: ScoreFunction,
35}
36
37impl Default for ContractionOptimizerOptions {
38 fn default() -> Self {
39 Self {
40 betas: Vec::new(),
41 ntrials: 1,
42 niters: 0,
43 score: ScoreFunction::default(),
44 }
45 }
46}
47
48impl ContractionOptimizerOptions {
49 fn to_treesa(&self) -> TreeSA {
50 TreeSA::new(
51 self.betas.clone(),
52 self.ntrials,
53 self.niters,
54 Initializer::Greedy,
55 self.score.clone(),
56 )
57 }
58
59 fn validate(&self) -> Result<()> {
60 if self.ntrials == 0 {
61 return Err(Error::InvalidArgument(
62 "contraction optimizer ntrials must be at least 1".into(),
63 ));
64 }
65 Ok(())
66 }
67}
68
/// One step of a linear contraction chain.
///
/// At this step the running intermediate is contracted with the input operand
/// `operand`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct ChainAttachment {
    /// `true` when the running intermediate is the left operand of the step
    /// (so `operand` is on the right), `false` for the mirrored case.
    pub(crate) prev_on_left: bool,
    /// Index of the input operand attached at this step.
    pub(crate) operand: usize,
}
74
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub(crate) struct LinearChainPlan {
77 pub(crate) first_pair: (usize, usize),
78 pub(crate) attachments: Vec<ChainAttachment>,
79}
80
81pub struct ContractionTree {
93 pub(crate) subscripts: Subscripts,
95 pub(crate) steps: Vec<ContractionStep>,
97 pub(crate) size_dict: HashMap<u32, usize>,
99 pub(crate) operand_subs: Vec<Vec<u32>>,
101 pub(crate) step_output_shapes: Vec<Vec<usize>>,
103 pub(crate) step_plans: Vec<StepPlan>,
105}
106
107impl ContractionTree {
108 pub fn optimize(subscripts: &Subscripts, shapes: &[&[usize]]) -> Result<Self> {
122 Self::optimize_with_options(subscripts, shapes, &ContractionOptimizerOptions::default())
123 }
124
125 pub fn optimize_with_options(
136 subscripts: &Subscripts,
137 shapes: &[&[usize]],
138 options: &ContractionOptimizerOptions,
139 ) -> Result<Self> {
140 let n_inputs = subscripts.inputs.len();
141 if n_inputs <= 1 {
142 return Self::from_pairs(subscripts, shapes, &[]);
143 }
144
145 options.validate()?;
146 let size_dict = build_size_dict(subscripts, shapes, None)?;
147 let pairs =
148 if let Some(omeco_pairs) = optimize_omeco_pairs(subscripts, &size_dict, options)? {
149 omeco_pairs
150 } else {
151 optimize_self_greedy_pairs(subscripts, &size_dict)
152 };
153 Self::from_pairs(subscripts, shapes, &pairs)
154 }
155
156 pub fn from_pairs(
186 subscripts: &Subscripts,
187 shapes: &[&[usize]],
188 pairs: &[(usize, usize)],
189 ) -> Result<Self> {
190 let n_inputs = subscripts.inputs.len();
191 let required_steps = n_inputs.saturating_sub(1);
192 if pairs.len() != required_steps {
193 return Err(Error::InvalidArgument(format!(
194 "explicit contraction path for {n_inputs} operands must have {required_steps} steps, got {}",
195 pairs.len()
196 )));
197 }
198 let size_dict = build_size_dict(subscripts, shapes, None)?;
199
200 let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
201 let mut live = vec![false; n_inputs + pairs.len()];
202 for slot in live.iter_mut().take(n_inputs) {
203 *slot = true;
204 }
205 let mut steps = Vec::new();
206
207 for (step_idx, &(left, right)) in pairs.iter().enumerate() {
208 let next_idx = n_inputs + step_idx;
209 if left == right {
210 return Err(Error::InvalidArgument(format!(
211 "pair ({left}, {right}) must reference two distinct live operands"
212 )));
213 }
214 if left >= next_idx || right >= next_idx {
215 return Err(Error::InvalidArgument(format!(
216 "pair ({left}, {right}) references non-existent operand"
217 )));
218 }
219 if !live[left] || !live[right] {
220 return Err(Error::InvalidArgument(format!(
221 "pair ({left}, {right}) references an operand or intermediate that is no longer live"
222 )));
223 }
224
225 let mut needed: HashSet<u32> = subscripts.output.iter().copied().collect();
227 for (idx, subs) in operand_subs.iter().enumerate() {
228 if idx != left && idx != right && live[idx] {
229 needed.extend(subs.iter().copied());
230 }
231 }
232
233 let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
234 operand_subs.push(new_subs);
235 live[left] = false;
236 live[right] = false;
237 live[next_idx] = true;
238 steps.push(ContractionStep { left, right });
239 }
240
241 let live_count = live.iter().filter(|&&is_live| is_live).count();
242 if live_count != 1 {
243 return Err(Error::InvalidArgument(format!(
244 "explicit contraction path must leave exactly one live result, got {live_count}"
245 )));
246 }
247
248 let step_output_shapes: Vec<Vec<usize>> = (0..steps.len())
250 .map(|step_idx| {
251 let result_idx = n_inputs + step_idx;
252 compute_output_shape(&operand_subs[result_idx], &size_dict)
253 })
254 .collect::<Result<Vec<_>>>()?;
255
256 let mut tree = Self {
257 subscripts: subscripts.clone(),
258 steps,
259 size_dict,
260 operand_subs,
261 step_output_shapes,
262 step_plans: Vec::new(),
263 };
264 tree.step_plans = compile_step_plans(&tree).map_err(Error::InvalidArgument)?;
265 Ok(tree)
266 }
267
268 #[must_use]
285 pub fn step_count(&self) -> usize {
286 self.steps.len()
287 }
288
289 #[must_use]
309 pub fn step_pair(&self, step_idx: usize) -> Option<(usize, usize)> {
310 self.steps.get(step_idx).map(|step| (step.left, step.right))
311 }
312
313 #[must_use]
336 pub fn step_subscripts(&self, step_idx: usize) -> Option<(&[u32], &[u32], &[u32])> {
337 let n_inputs = self.subscripts.inputs.len();
338 let step = self.steps.get(step_idx)?;
339 let result_idx = n_inputs + step_idx;
340 Some((
341 &self.operand_subs[step.left],
342 &self.operand_subs[step.right],
343 &self.operand_subs[result_idx],
344 ))
345 }
346
347 pub(crate) fn linear_chain_plan(&self) -> Option<LinearChainPlan> {
348 if self.steps.is_empty() {
349 return Some(LinearChainPlan {
350 first_pair: (0, 0),
351 attachments: Vec::new(),
352 });
353 }
354
355 let n_inputs = self.subscripts.inputs.len();
356 let first = self.steps.first()?;
357 if first.left >= n_inputs || first.right >= n_inputs {
358 return None;
359 }
360
361 let mut seen_inputs = vec![false; n_inputs];
362 seen_inputs[first.left] = true;
363 seen_inputs[first.right] = true;
364 let mut attachments = Vec::with_capacity(self.steps.len().saturating_sub(1));
365 let mut prev_result_idx = n_inputs;
366
367 for (step_idx, step) in self.steps.iter().enumerate().skip(1) {
368 let (prev_on_left, operand) = if step.left == prev_result_idx && step.right < n_inputs {
369 (true, step.right)
370 } else if step.right == prev_result_idx && step.left < n_inputs {
371 (false, step.left)
372 } else {
373 return None;
374 };
375
376 if seen_inputs[operand] {
377 return None;
378 }
379 seen_inputs[operand] = true;
380 attachments.push(ChainAttachment {
381 prev_on_left,
382 operand,
383 });
384 prev_result_idx = n_inputs + step_idx;
385 }
386
387 Some(LinearChainPlan {
388 first_pair: (first.left, first.right),
389 attachments,
390 })
391 }
392}
393
394fn optimize_omeco_pairs(
395 subscripts: &Subscripts,
396 size_dict: &HashMap<u32, usize>,
397 options: &ContractionOptimizerOptions,
398) -> Result<Option<Vec<(usize, usize)>>> {
399 let code = OmecoEinCode::new(subscripts.inputs.clone(), subscripts.output.clone());
400 let optimizer = options.to_treesa();
401 let Some(nested) = optimizer.optimize(&code, size_dict) else {
402 return Ok(None);
403 };
404
405 let mut next_operand = subscripts.inputs.len();
406 let mut pairs = Vec::with_capacity(subscripts.inputs.len().saturating_sub(1));
407 nested_to_pairs(&nested, &mut next_operand, &mut pairs)?;
408 Ok(Some(pairs))
409}
410
411fn nested_to_pairs(
412 nested: &NestedEinsum<u32>,
413 next_operand: &mut usize,
414 pairs: &mut Vec<(usize, usize)>,
415) -> Result<usize> {
416 match nested {
417 NestedEinsum::Leaf { tensor_index } => Ok(*tensor_index),
418 NestedEinsum::Node { args, .. } => {
419 if args.len() != 2 {
420 return Err(Error::InvalidArgument(format!(
421 "omeco returned non-binary contraction node with {} children",
422 args.len()
423 )));
424 }
425 let left = nested_to_pairs(&args[0], next_operand, pairs)?;
426 let right = nested_to_pairs(&args[1], next_operand, pairs)?;
427 pairs.push((left, right));
428 let result_idx = *next_operand;
429 *next_operand += 1;
430 Ok(result_idx)
431 }
432 }
433}
434
435fn optimize_self_greedy_pairs(
436 subscripts: &Subscripts,
437 size_dict: &HashMap<u32, usize>,
438) -> Vec<(usize, usize)> {
439 let n_inputs = subscripts.inputs.len();
440 let mut available: Vec<usize> = (0..n_inputs).collect();
441 let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
442 let mut pairs: Vec<(usize, usize)> = Vec::new();
443
444 while available.len() > 1 {
445 let mut best_i = 0;
446 let mut best_j = 1;
447 let mut best_cost = usize::MAX;
448
449 for i in 0..available.len() {
450 for j in (i + 1)..available.len() {
451 let li = available[i];
452 let lj = available[j];
453 let mut needed = HashSet::new();
454 needed.extend(subscripts.output.iter().copied());
455 for &idx in &available {
456 if idx != li && idx != lj {
457 needed.extend(operand_subs[idx].iter().copied());
458 }
459 }
460 let cost =
461 contraction_cost(&operand_subs[li], &operand_subs[lj], &needed, size_dict);
462 if cost < best_cost {
463 best_cost = cost;
464 best_i = i;
465 best_j = j;
466 }
467 }
468 }
469
470 let left = available[best_i];
471 let right = available[best_j];
472 pairs.push((left, right));
473
474 let mut needed = HashSet::new();
475 needed.extend(subscripts.output.iter().copied());
476 for &idx in &available {
477 if idx != left && idx != right {
478 needed.extend(operand_subs[idx].iter().copied());
479 }
480 }
481 let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
482 let new_idx = operand_subs.len();
483 operand_subs.push(new_subs);
484 available.remove(best_j);
485 available.remove(best_i);
486 available.push(new_idx);
487 }
488
489 pairs
490}
491
492#[cfg(test)]
493mod tests;