1use std::hash::Hasher;
2
3use omeco::ScoreFunction;
4
5use crate::{
6 ContractionOptimizerOptions, ContractionTree, Error, NestedEinsum, Result, Subscripts,
7};
8
/// Strategy an einsum call uses to choose its pairwise contraction order.
///
/// The `Default` impl yields [`EinsumOptimize::Auto`] with a time-optimized score function.
#[derive(Debug)]
pub enum EinsumOptimize {
    /// Search for a contraction order using these optimizer options (validated before use).
    Auto(ContractionOptimizerOptions),
    /// Skip the optimizer and use the fixed path `(0, 1)` repeated once per contraction step.
    False,
    /// Take the contraction order from a nested einsum expression; it is flattened into
    /// explicit operand-id pairs and validated.
    Nested(NestedEinsum),
    /// Explicit JAX-style path: each `(i, j)` indexes the *current* operand list, and each
    /// step's result is appended to the end of that list.
    Path(Vec<(usize, usize)>),
    /// A precomputed contraction tree. Only usable when concrete input shapes are available;
    /// shape-free (symbolic) planning rejects it with an `InvalidArgument` error.
    Tree(ContractionTree),
}
68
69impl Default for EinsumOptimize {
70 fn default() -> Self {
72 Self::Auto(default_auto_options())
73 }
74}
75
/// Shape-independent contraction plan derived from an [`EinsumOptimize`].
///
/// Every variant can be hashed (`hash_einsum_plan_spec`) and compared
/// (`plan_specs_equal`). A spec is turned into a concrete `ContractionTree` by
/// `resolve_plan_spec` once input shapes are known.
#[derive(Clone, Debug)]
pub(crate) enum EinsumPlanSpec {
    /// Run the contraction-order optimizer with these options against the concrete shapes.
    Auto(ContractionOptimizerOptions),
    /// Unoptimized plan: the JAX-style path `(0, 1)` repeated for every contraction step.
    LeftToRight,
    /// Explicit JAX-style path (positions into the current operand list), kept in that form
    /// and converted to operand-id pairs at resolution time.
    Path(Vec<(usize, usize)>),
    /// Explicit operand-id pairs: inputs are ids `0..n`, and step `k` produces id `n + k`.
    FixedPairs(Vec<(usize, usize)>),
}
83
84#[must_use]
86pub(crate) fn default_auto_options() -> ContractionOptimizerOptions {
87 ContractionOptimizerOptions {
88 score: ScoreFunction::time_optimized(),
89 ..Default::default()
90 }
91}
92
93pub(crate) fn plan_spec_from_optimize(
94 optimize: EinsumOptimize,
95 subscripts: &Subscripts,
96) -> Result<EinsumPlanSpec> {
97 match optimize {
98 EinsumOptimize::Auto(options) => {
99 options.validate()?;
100 Ok(EinsumPlanSpec::Auto(options))
101 }
102 EinsumOptimize::False => Ok(EinsumPlanSpec::LeftToRight),
103 EinsumOptimize::Nested(nested) => {
104 let pairs = nested_to_v1_pairs(&nested, subscripts.inputs.len())?;
105 validate_fixed_pairs(&pairs, subscripts.inputs.len())?;
106 Ok(EinsumPlanSpec::FixedPairs(pairs))
107 }
108 EinsumOptimize::Path(path) => {
109 let _ = jax_path_to_v1_pairs(&path, subscripts.inputs.len())?;
110 Ok(EinsumPlanSpec::Path(path))
111 }
112 EinsumOptimize::Tree(_) => Err(Error::InvalidArgument(
113 "precomputed contraction tree requires concrete input shapes; use Path or parenthesized notation for symbolic traced einsum"
114 .into(),
115 )),
116 }
117}
118
119pub(crate) fn resolve_einsum_strategy_with_spec(
120 optimize: EinsumOptimize,
121 subscripts: &Subscripts,
122 shapes: &[&[usize]],
123) -> Result<(EinsumPlanSpec, ContractionTree)> {
124 match optimize {
125 EinsumOptimize::Tree(tree) => {
126 let pairs = tree_pairs(&tree);
127 let spec = EinsumPlanSpec::FixedPairs(pairs);
128 let tree = resolve_plan_spec(&spec, subscripts, shapes)?;
129 Ok((spec, tree))
130 }
131 optimize => {
132 let spec = plan_spec_from_optimize(optimize, subscripts)?;
133 let tree = resolve_plan_spec(&spec, subscripts, shapes)?;
134 Ok((spec, tree))
135 }
136 }
137}
138
139pub(crate) fn resolve_plan_spec(
140 spec: &EinsumPlanSpec,
141 subscripts: &Subscripts,
142 shapes: &[&[usize]],
143) -> Result<ContractionTree> {
144 match spec {
145 EinsumPlanSpec::Auto(options) => {
146 ContractionTree::optimize_with_options(subscripts, shapes, options)
147 }
148 EinsumPlanSpec::LeftToRight => {
149 let n = subscripts.inputs.len();
150 if n <= 1 {
151 ContractionTree::from_pairs(subscripts, shapes, &[])
152 } else {
153 let path: Vec<(usize, usize)> = (0..n - 1).map(|_| (0, 1)).collect();
154 let pairs = jax_path_to_v1_pairs(&path, n)?;
155 ContractionTree::from_pairs(subscripts, shapes, &pairs)
156 }
157 }
158 EinsumPlanSpec::Path(path) => {
159 let pairs = jax_path_to_v1_pairs(path, subscripts.inputs.len())?;
160 ContractionTree::from_pairs(subscripts, shapes, &pairs)
161 }
162 EinsumPlanSpec::FixedPairs(pairs) => ContractionTree::from_pairs(subscripts, shapes, pairs),
163 }
164}
165
166pub(crate) fn hash_einsum_plan_spec(spec: &EinsumPlanSpec, state: &mut dyn Hasher) {
167 match spec {
168 EinsumPlanSpec::Auto(options) => {
169 state.write_u8(0);
170 hash_optimizer_options(options, state);
171 }
172 EinsumPlanSpec::LeftToRight => state.write_u8(1),
173 EinsumPlanSpec::Path(path) => {
174 state.write_u8(2);
175 hash_pairs(path, state);
176 }
177 EinsumPlanSpec::FixedPairs(pairs) => {
178 state.write_u8(3);
179 hash_pairs(pairs, state);
180 }
181 }
182}
183
184pub(crate) fn plan_specs_equal(lhs: &EinsumPlanSpec, rhs: &EinsumPlanSpec) -> bool {
185 match (lhs, rhs) {
186 (EinsumPlanSpec::Auto(lhs), EinsumPlanSpec::Auto(rhs)) => {
187 optimizer_options_equal_by_bits(lhs, rhs)
188 }
189 (EinsumPlanSpec::LeftToRight, EinsumPlanSpec::LeftToRight) => true,
190 (EinsumPlanSpec::Path(lhs), EinsumPlanSpec::Path(rhs)) => lhs == rhs,
191 (EinsumPlanSpec::FixedPairs(lhs), EinsumPlanSpec::FixedPairs(rhs)) => lhs == rhs,
192 _ => false,
193 }
194}
195
196fn tree_pairs(tree: &ContractionTree) -> Vec<(usize, usize)> {
197 (0..tree.step_count())
198 .filter_map(|step| tree.step_pair(step))
199 .collect()
200}
201
202fn validate_fixed_pairs(pairs: &[(usize, usize)], input_count: usize) -> Result<()> {
203 let required_steps = input_count.saturating_sub(1);
204 if pairs.len() != required_steps {
205 return Err(Error::InvalidArgument(format!(
206 "explicit contraction path for {input_count} operands must have {required_steps} steps, got {}",
207 pairs.len()
208 )));
209 }
210
211 let mut live = vec![false; input_count + pairs.len()];
212 for slot in live.iter_mut().take(input_count) {
213 *slot = true;
214 }
215
216 for (step_idx, &(left, right)) in pairs.iter().enumerate() {
217 let next_idx = input_count + step_idx;
218 if left == right {
219 return Err(Error::InvalidArgument(format!(
220 "pair ({left}, {right}) must reference two distinct live operands"
221 )));
222 }
223 if left >= next_idx || right >= next_idx {
224 return Err(Error::InvalidArgument(format!(
225 "pair ({left}, {right}) references non-existent operand"
226 )));
227 }
228 if !live[left] || !live[right] {
229 return Err(Error::InvalidArgument(format!(
230 "pair ({left}, {right}) references an operand or intermediate that is no longer live"
231 )));
232 }
233
234 live[left] = false;
235 live[right] = false;
236 live[next_idx] = true;
237 }
238
239 let live_count = live.iter().filter(|&&is_live| is_live).count();
240 if live_count != 1 {
241 return Err(Error::InvalidArgument(format!(
242 "explicit contraction path must leave exactly one live result, got {live_count}"
243 )));
244 }
245
246 Ok(())
247}
248
/// Hashes a pair list as its length followed by each `(left, right)` in order.
/// The length prefix keeps adjacent variable-length data from colliding.
fn hash_pairs(pairs: &[(usize, usize)], state: &mut dyn Hasher) {
    state.write_usize(pairs.len());
    pairs.iter().for_each(|&(left, right)| {
        state.write_usize(left);
        state.write_usize(right);
    });
}
256
257fn hash_optimizer_options(options: &ContractionOptimizerOptions, state: &mut dyn Hasher) {
258 state.write_usize(options.ntrials);
259 state.write_usize(options.niters);
260 state.write_usize(options.betas.len());
261 for value in &options.betas {
262 state.write_u64(value.to_bits());
263 }
264 state.write_u64(options.score.tc_weight.to_bits());
265 state.write_u64(options.score.sc_weight.to_bits());
266 state.write_u64(options.score.rw_weight.to_bits());
267 state.write_u64(options.score.sc_target.to_bits());
268}
269
270fn optimizer_options_equal_by_bits(
271 lhs: &ContractionOptimizerOptions,
272 rhs: &ContractionOptimizerOptions,
273) -> bool {
274 lhs.ntrials == rhs.ntrials
275 && lhs.niters == rhs.niters
276 && f64_slices_equal_by_bits(&lhs.betas, &rhs.betas)
277 && score_functions_equal_by_bits(&lhs.score, &rhs.score)
278}
279
280pub(crate) fn jax_path_to_v1_pairs(
292 jax_path: &[(usize, usize)],
293 input_count: usize,
294) -> Result<Vec<(usize, usize)>> {
295 let required_steps = input_count.saturating_sub(1);
296 if jax_path.len() != required_steps {
297 return Err(Error::InvalidArgument(format!(
298 "explicit contraction path for {input_count} operands must have {required_steps} steps, got {}",
299 jax_path.len()
300 )));
301 }
302
303 let mut positions: Vec<usize> = (0..input_count).collect();
304 let mut v1_pairs = Vec::with_capacity(jax_path.len());
305
306 for (step, &(pos_a, pos_b)) in jax_path.iter().enumerate() {
307 if pos_a == pos_b {
308 return Err(Error::InvalidArgument(format!(
309 "path step {step} references the same operand position twice: {pos_a}"
310 )));
311 }
312 let current_len = positions.len();
313 if pos_a >= current_len || pos_b >= current_len {
314 return Err(Error::InvalidArgument(format!(
315 "path step {step} references operand positions ({pos_a}, {pos_b}) with only {current_len} live operands"
316 )));
317 }
318
319 let (lo, hi) = if pos_a < pos_b {
320 (pos_a, pos_b)
321 } else {
322 (pos_b, pos_a)
323 };
324 let id_a = positions[lo];
325 let id_b = positions[hi];
326 v1_pairs.push((id_a, id_b));
327
328 positions.remove(hi);
329 positions.remove(lo);
330 positions.push(input_count + step);
331 }
332
333 Ok(v1_pairs)
334}
335
336pub(crate) fn nested_to_v1_pairs(
343 nested: &NestedEinsum,
344 input_count: usize,
345) -> Result<Vec<(usize, usize)>> {
346 let mut pairs = Vec::with_capacity(input_count.saturating_sub(1));
347 let mut next_id = input_count;
348 let root_id = walk_nested(nested, input_count, &mut pairs, &mut next_id)?;
349 if input_count == 0 || root_id >= next_id {
350 return Err(Error::InvalidArgument(
351 "nested einsum did not produce a valid root operand".into(),
352 ));
353 }
354 Ok(pairs)
355}
356
357fn walk_nested(
358 nested: &NestedEinsum,
359 input_count: usize,
360 pairs: &mut Vec<(usize, usize)>,
361 next_id: &mut usize,
362) -> Result<usize> {
363 match nested {
364 NestedEinsum::Leaf(idx) => {
365 if *idx >= input_count {
366 return Err(Error::InvalidArgument(format!(
367 "nested einsum leaf {idx} is outside 0..{input_count}"
368 )));
369 }
370 Ok(*idx)
371 }
372 NestedEinsum::Node { children, .. } => {
373 let Some(first) = children.first() else {
374 return Err(Error::InvalidArgument(
375 "nested einsum node must have at least one child".into(),
376 ));
377 };
378 let mut result_id = walk_nested(first, input_count, pairs, next_id)?;
379 for child in &children[1..] {
380 let child_id = walk_nested(child, input_count, pairs, next_id)?;
381 pairs.push((result_id, child_id));
382 result_id = *next_id;
383 *next_id += 1;
384 }
385 Ok(result_id)
386 }
387 }
388}
389
/// Returns `true` when both slices have the same length and identical `f64` bit patterns at
/// every position. Unlike `==`, this treats `NaN` as equal to an identical `NaN` and
/// distinguishes `0.0` from `-0.0`.
fn f64_slices_equal_by_bits(lhs: &[f64], rhs: &[f64]) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter()
        .copied()
        .map(f64::to_bits)
        .eq(rhs.iter().copied().map(f64::to_bits))
}
397
398fn score_functions_equal_by_bits(lhs: &ScoreFunction, rhs: &ScoreFunction) -> bool {
399 lhs.tc_weight.to_bits() == rhs.tc_weight.to_bits()
400 && lhs.sc_weight.to_bits() == rhs.sc_weight.to_bits()
401 && lhs.rw_weight.to_bits() == rhs.rw_weight.to_bits()
402 && lhs.sc_target.to_bits() == rhs.sc_target.to_bits()
403}
404
405#[cfg(test)]
406mod tests;