tenferro_einsum/planning/
tree.rs1use std::collections::{HashMap, HashSet};
2
3use omeco::{
4 CodeOptimizer, EinCode as OmecoEinCode, Initializer, NestedEinsum, ScoreFunction, TreeSA,
5};
6use tenferro_device::{Error, Result};
7
8use crate::planning::plan::{compile_step_plans, StepPlan};
9use crate::syntax::subscripts::Subscripts;
10use crate::util::{build_size_dict, compute_output_shape, contraction_cost, intermediate_subs};
11
/// One pairwise contraction: combines operands `left` and `right` into a new
/// intermediate operand.
///
/// Indices below the input count refer to the original inputs; larger indices
/// refer to the results of earlier steps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ContractionStep {
    /// Operand used as the left side of this contraction.
    pub(crate) left: usize,
    /// Operand used as the right side of this contraction.
    pub(crate) right: usize,
}
17
/// Tuning knobs for the omeco `TreeSA` contraction-order optimizer used by
/// `ContractionTree::optimize_with_options`.
///
/// Every field is forwarded to `TreeSA::new`, which is always started from a
/// greedy initial tree. The `Default` value uses `ntrials = 1`, `niters = 0`,
/// an empty `betas` list, and `ScoreFunction::default()`.
#[derive(Debug, Clone)]
pub struct ContractionOptimizerOptions {
    /// Forwarded to `TreeSA` as its `betas` argument (presumably the annealing
    /// inverse-temperature schedule — confirm against omeco).
    pub betas: Vec<f64>,
    /// Forwarded to `TreeSA` as its `ntrials` argument; options with `0` here
    /// are rejected by validation.
    pub ntrials: usize,
    /// Forwarded to `TreeSA` as its `niters` argument (presumably iterations
    /// per temperature — confirm against omeco).
    pub niters: usize,
    /// Forwarded to `TreeSA` as its score function.
    pub score: ScoreFunction,
}
34
35impl Default for ContractionOptimizerOptions {
36 fn default() -> Self {
37 Self {
38 betas: Vec::new(),
39 ntrials: 1,
40 niters: 0,
41 score: ScoreFunction::default(),
42 }
43 }
44}
45
46impl ContractionOptimizerOptions {
47 fn to_treesa(&self) -> TreeSA {
48 TreeSA::new(
49 self.betas.clone(),
50 self.ntrials,
51 self.niters,
52 Initializer::Greedy,
53 self.score.clone(),
54 )
55 }
56
57 fn validate(&self) -> Result<()> {
58 if self.ntrials == 0 {
59 return Err(Error::InvalidArgument(
60 "contraction optimizer ntrials must be at least 1".into(),
61 ));
62 }
63 Ok(())
64 }
65}
66
/// One step of a linear contraction chain: the previous step's result is
/// contracted with a not-yet-used input operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ChainAttachment {
    /// `true` when the previous result is this step's left operand (and the
    /// input is on the right); `false` for the opposite order.
    pub(crate) prev_on_left: bool,
    /// Index of the input operand attached at this step.
    pub(crate) operand: usize,
}
72
/// A contraction path in which the first step contracts two inputs and every
/// later step extends the running result with one more input, as recognized
/// by `ContractionTree::linear_chain_plan`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct LinearChainPlan {
    /// The two inputs contracted by the first step (`(0, 0)` when the tree has
    /// no steps).
    pub(crate) first_pair: (usize, usize),
    /// The remaining steps, in execution order.
    pub(crate) attachments: Vec<ChainAttachment>,
}
78
/// A pairwise contraction order for an einsum, together with the per-step
/// labels, shapes, and compiled plans needed to execute it.
///
/// Operands are numbered so that `0..n_inputs` are the original inputs and
/// `n_inputs + k` is the result of step `k`.
pub struct ContractionTree {
    /// The einsum subscripts this tree was built for.
    pub(crate) subscripts: Subscripts,
    /// Contraction steps in execution order; each names two operand indices.
    pub(crate) steps: Vec<ContractionStep>,
    /// Label -> extent map, as built by `build_size_dict`.
    pub(crate) size_dict: HashMap<u32, usize>,
    /// Labels of every operand: the inputs, followed by one entry per step result.
    pub(crate) operand_subs: Vec<Vec<u32>>,
    /// Shape of each step's result, computed from its labels and `size_dict`.
    pub(crate) step_output_shapes: Vec<Vec<usize>>,
    /// Per-step plans produced by `compile_step_plans`.
    pub(crate) step_plans: Vec<StepPlan>,
}
104
105impl ContractionTree {
106 pub fn optimize(subscripts: &Subscripts, shapes: &[&[usize]]) -> Result<Self> {
120 Self::optimize_with_options(subscripts, shapes, &ContractionOptimizerOptions::default())
121 }
122
123 pub fn optimize_with_options(
134 subscripts: &Subscripts,
135 shapes: &[&[usize]],
136 options: &ContractionOptimizerOptions,
137 ) -> Result<Self> {
138 let n_inputs = subscripts.inputs.len();
139 if n_inputs <= 1 {
140 return Self::from_pairs(subscripts, shapes, &[]);
141 }
142
143 options.validate()?;
144 let size_dict = build_size_dict(subscripts, shapes, None)?;
145 let pairs =
146 if let Some(omeco_pairs) = optimize_omeco_pairs(subscripts, &size_dict, options)? {
147 omeco_pairs
148 } else {
149 optimize_self_greedy_pairs(subscripts, &size_dict)
150 };
151 Self::from_pairs(subscripts, shapes, &pairs)
152 }
153
154 pub fn from_pairs(
184 subscripts: &Subscripts,
185 shapes: &[&[usize]],
186 pairs: &[(usize, usize)],
187 ) -> Result<Self> {
188 let n_inputs = subscripts.inputs.len();
189 let required_steps = n_inputs.saturating_sub(1);
190 if pairs.len() != required_steps {
191 return Err(Error::InvalidArgument(format!(
192 "explicit contraction path for {n_inputs} operands must have {required_steps} steps, got {}",
193 pairs.len()
194 )));
195 }
196 let size_dict = build_size_dict(subscripts, shapes, None)?;
197
198 let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
199 let mut live = vec![false; n_inputs + pairs.len()];
200 for slot in live.iter_mut().take(n_inputs) {
201 *slot = true;
202 }
203 let mut steps = Vec::new();
204
205 for (step_idx, &(left, right)) in pairs.iter().enumerate() {
206 let next_idx = n_inputs + step_idx;
207 if left == right {
208 return Err(Error::InvalidArgument(format!(
209 "pair ({left}, {right}) must reference two distinct live operands"
210 )));
211 }
212 if left >= next_idx || right >= next_idx {
213 return Err(Error::InvalidArgument(format!(
214 "pair ({left}, {right}) references non-existent operand"
215 )));
216 }
217 if !live[left] || !live[right] {
218 return Err(Error::InvalidArgument(format!(
219 "pair ({left}, {right}) references an operand or intermediate that is no longer live"
220 )));
221 }
222
223 let mut needed: HashSet<u32> = subscripts.output.iter().copied().collect();
225 for (idx, subs) in operand_subs.iter().enumerate() {
226 if idx != left && idx != right && live[idx] {
227 needed.extend(subs.iter().copied());
228 }
229 }
230
231 let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
232 operand_subs.push(new_subs);
233 live[left] = false;
234 live[right] = false;
235 live[next_idx] = true;
236 steps.push(ContractionStep { left, right });
237 }
238
239 let live_count = live.iter().filter(|&&is_live| is_live).count();
240 if live_count != 1 {
241 return Err(Error::InvalidArgument(format!(
242 "explicit contraction path must leave exactly one live result, got {live_count}"
243 )));
244 }
245
246 let step_output_shapes: Vec<Vec<usize>> = (0..steps.len())
248 .map(|step_idx| {
249 let result_idx = n_inputs + step_idx;
250 compute_output_shape(&operand_subs[result_idx], &size_dict)
251 })
252 .collect::<Result<Vec<_>>>()?;
253
254 let mut tree = Self {
255 subscripts: subscripts.clone(),
256 steps,
257 size_dict,
258 operand_subs,
259 step_output_shapes,
260 step_plans: Vec::new(),
261 };
262 tree.step_plans = compile_step_plans(&tree).map_err(Error::InvalidArgument)?;
263 Ok(tree)
264 }
265
266 #[must_use]
283 pub fn step_count(&self) -> usize {
284 self.steps.len()
285 }
286
287 #[must_use]
307 pub fn step_pair(&self, step_idx: usize) -> Option<(usize, usize)> {
308 self.steps.get(step_idx).map(|step| (step.left, step.right))
309 }
310
311 #[must_use]
334 pub fn step_subscripts(&self, step_idx: usize) -> Option<(&[u32], &[u32], &[u32])> {
335 let n_inputs = self.subscripts.inputs.len();
336 let step = self.steps.get(step_idx)?;
337 let result_idx = n_inputs + step_idx;
338 Some((
339 &self.operand_subs[step.left],
340 &self.operand_subs[step.right],
341 &self.operand_subs[result_idx],
342 ))
343 }
344
345 pub(crate) fn linear_chain_plan(&self) -> Option<LinearChainPlan> {
346 if self.steps.is_empty() {
347 return Some(LinearChainPlan {
348 first_pair: (0, 0),
349 attachments: Vec::new(),
350 });
351 }
352
353 let n_inputs = self.subscripts.inputs.len();
354 let first = self.steps.first()?;
355 if first.left >= n_inputs || first.right >= n_inputs {
356 return None;
357 }
358
359 let mut seen_inputs = vec![false; n_inputs];
360 seen_inputs[first.left] = true;
361 seen_inputs[first.right] = true;
362 let mut attachments = Vec::with_capacity(self.steps.len().saturating_sub(1));
363 let mut prev_result_idx = n_inputs;
364
365 for (step_idx, step) in self.steps.iter().enumerate().skip(1) {
366 let (prev_on_left, operand) = if step.left == prev_result_idx && step.right < n_inputs {
367 (true, step.right)
368 } else if step.right == prev_result_idx && step.left < n_inputs {
369 (false, step.left)
370 } else {
371 return None;
372 };
373
374 if seen_inputs[operand] {
375 return None;
376 }
377 seen_inputs[operand] = true;
378 attachments.push(ChainAttachment {
379 prev_on_left,
380 operand,
381 });
382 prev_result_idx = n_inputs + step_idx;
383 }
384
385 Some(LinearChainPlan {
386 first_pair: (first.left, first.right),
387 attachments,
388 })
389 }
390}
391
392fn optimize_omeco_pairs(
393 subscripts: &Subscripts,
394 size_dict: &HashMap<u32, usize>,
395 options: &ContractionOptimizerOptions,
396) -> Result<Option<Vec<(usize, usize)>>> {
397 let code = OmecoEinCode::new(subscripts.inputs.clone(), subscripts.output.clone());
398 let optimizer = options.to_treesa();
399 let Some(nested) = optimizer.optimize(&code, size_dict) else {
400 return Ok(None);
401 };
402
403 let mut next_operand = subscripts.inputs.len();
404 let mut pairs = Vec::with_capacity(subscripts.inputs.len().saturating_sub(1));
405 nested_to_pairs(&nested, &mut next_operand, &mut pairs)?;
406 Ok(Some(pairs))
407}
408
409fn nested_to_pairs(
410 nested: &NestedEinsum<u32>,
411 next_operand: &mut usize,
412 pairs: &mut Vec<(usize, usize)>,
413) -> Result<usize> {
414 match nested {
415 NestedEinsum::Leaf { tensor_index } => Ok(*tensor_index),
416 NestedEinsum::Node { args, .. } => {
417 if args.len() != 2 {
418 return Err(Error::InvalidArgument(format!(
419 "omeco returned non-binary contraction node with {} children",
420 args.len()
421 )));
422 }
423 let left = nested_to_pairs(&args[0], next_operand, pairs)?;
424 let right = nested_to_pairs(&args[1], next_operand, pairs)?;
425 pairs.push((left, right));
426 let result_idx = *next_operand;
427 *next_operand += 1;
428 Ok(result_idx)
429 }
430 }
431}
432
433fn optimize_self_greedy_pairs(
434 subscripts: &Subscripts,
435 size_dict: &HashMap<u32, usize>,
436) -> Vec<(usize, usize)> {
437 let n_inputs = subscripts.inputs.len();
438 let mut available: Vec<usize> = (0..n_inputs).collect();
439 let mut operand_subs: Vec<Vec<u32>> = subscripts.inputs.clone();
440 let mut pairs: Vec<(usize, usize)> = Vec::new();
441
442 while available.len() > 1 {
443 let mut best_i = 0;
444 let mut best_j = 1;
445 let mut best_cost = usize::MAX;
446
447 for i in 0..available.len() {
448 for j in (i + 1)..available.len() {
449 let li = available[i];
450 let lj = available[j];
451 let mut needed = HashSet::new();
452 needed.extend(subscripts.output.iter().copied());
453 for &idx in &available {
454 if idx != li && idx != lj {
455 needed.extend(operand_subs[idx].iter().copied());
456 }
457 }
458 let cost =
459 contraction_cost(&operand_subs[li], &operand_subs[lj], &needed, size_dict);
460 if cost < best_cost {
461 best_cost = cost;
462 best_i = i;
463 best_j = j;
464 }
465 }
466 }
467
468 let left = available[best_i];
469 let right = available[best_j];
470 pairs.push((left, right));
471
472 let mut needed = HashSet::new();
473 needed.extend(subscripts.output.iter().copied());
474 for &idx in &available {
475 if idx != left && idx != right {
476 needed.extend(operand_subs[idx].iter().copied());
477 }
478 }
479 let new_subs = intermediate_subs(&operand_subs[left], &operand_subs[right], &needed);
480 let new_idx = operand_subs.len();
481 operand_subs.push(new_subs);
482 available.remove(best_j);
483 available.remove(best_i);
484 available.push(new_idx);
485 }
486
487 pairs
488}
489
490#[cfg(test)]
491mod tests;