Skip to main content

tenferro_einsum/
optimize.rs

1use std::hash::Hasher;
2
3use omeco::ScoreFunction;
4
5use crate::{
6    ContractionOptimizerOptions, ContractionTree, Error, NestedEinsum, Result, Subscripts,
7};
8
9/// Controls how the contraction path is determined for N-ary einsum.
10///
11/// The traced API resolves this enum into a shape-independent plan
12/// specification and stores that specification in the einsum extension
13/// payload. Concrete traced inputs can resolve a [`ContractionTree`] while the
14/// op is built; symbolic traced inputs carry the plan specification until the
15/// extension runtime sees concrete execution shapes.
16///
17/// Planner options and explicit paths are part of the extension payload
18/// identity. Two otherwise identical traced einsum ops with different
19/// optimizer options, explicit paths, or fixed plan identities are not treated
20/// as the same extension op, and their compile/runtime plan cache entries are
21/// kept separate.
22///
23/// # Examples
24///
25/// ```
26/// use tenferro_einsum::{EinsumOptimize, GraphCompilerEinsumExt};
27/// use tenferro_runtime::{GraphCompiler, TracedTensor};
28///
29/// let lhs = TracedTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]).unwrap();
30/// let rhs = TracedTensor::from_vec_col_major(vec![3, 2], vec![1.0_f64; 6]).unwrap();
31///
32/// let mut compiler = GraphCompiler::new();
33/// let out = compiler.einsum_with(
34///     &[&lhs, &rhs],
35///     "ij,jk->ik",
36///     EinsumOptimize::False,
37/// )
38/// .unwrap();
39///
40/// assert_eq!(out.try_concrete_shape(), Some(vec![2, 2]));
41/// ```
42#[derive(Debug)]
43pub enum EinsumOptimize {
44    /// Automatic optimization via omeco TreeSA.
45    Auto(ContractionOptimizerOptions),
46    /// No optimization: contract operands left-to-right.
47    False,
48    /// Parenthesized notation specifying contraction order.
49    Nested(NestedEinsum),
50    /// JAX-compatible position-based contraction path.
51    ///
52    /// Each pair references positions in a shrinking operand list. After each
53    /// contraction, the two operands are removed and the result is appended.
54    /// Because this representation is independent of concrete dimension
55    /// values, it can be used with symbolic traced inputs.
56    Path(Vec<(usize, usize)>),
57    /// Pre-computed contraction tree.
58    ///
59    /// A tree contains concrete shape-dependent planning results. It is
60    /// accepted when shapes are concrete, then converted into fixed contraction
61    /// pairs for the extension payload. A binary tree with the single pair
62    /// `(0, 1)` or `(1, 0)` may bypass the extension path and lower directly to
63    /// `dot_general`, including with symbolic traced inputs. Use
64    /// [`EinsumOptimize::Path`] instead when building a symbolic traced graph
65    /// for N-ary contraction.
66    Tree(ContractionTree),
67}
68
69impl Default for EinsumOptimize {
70    /// Default: time-optimized automatic planning.
71    fn default() -> Self {
72        Self::Auto(default_auto_options())
73    }
74}
75
76#[derive(Clone, Debug)]
77pub(crate) enum EinsumPlanSpec {
78    Auto(ContractionOptimizerOptions),
79    LeftToRight,
80    Path(Vec<(usize, usize)>),
81    FixedPairs(Vec<(usize, usize)>),
82}
83
84/// Return the default automatic optimizer options.
85#[must_use]
86pub(crate) fn default_auto_options() -> ContractionOptimizerOptions {
87    ContractionOptimizerOptions {
88        score: ScoreFunction::time_optimized(),
89        ..Default::default()
90    }
91}
92
93pub(crate) fn plan_spec_from_optimize(
94    optimize: EinsumOptimize,
95    subscripts: &Subscripts,
96) -> Result<EinsumPlanSpec> {
97    match optimize {
98        EinsumOptimize::Auto(options) => {
99            options.validate()?;
100            Ok(EinsumPlanSpec::Auto(options))
101        }
102        EinsumOptimize::False => Ok(EinsumPlanSpec::LeftToRight),
103        EinsumOptimize::Nested(nested) => {
104            let pairs = nested_to_v1_pairs(&nested, subscripts.inputs.len())?;
105            validate_fixed_pairs(&pairs, subscripts.inputs.len())?;
106            Ok(EinsumPlanSpec::FixedPairs(pairs))
107        }
108        EinsumOptimize::Path(path) => {
109            let _ = jax_path_to_v1_pairs(&path, subscripts.inputs.len())?;
110            Ok(EinsumPlanSpec::Path(path))
111        }
112        EinsumOptimize::Tree(_) => Err(Error::InvalidArgument(
113            "precomputed contraction tree requires concrete input shapes; use Path or parenthesized notation for symbolic traced einsum"
114                .into(),
115        )),
116    }
117}
118
119pub(crate) fn resolve_einsum_strategy_with_spec(
120    optimize: EinsumOptimize,
121    subscripts: &Subscripts,
122    shapes: &[&[usize]],
123) -> Result<(EinsumPlanSpec, ContractionTree)> {
124    match optimize {
125        EinsumOptimize::Tree(tree) => {
126            let pairs = tree_pairs(&tree);
127            let spec = EinsumPlanSpec::FixedPairs(pairs);
128            let tree = resolve_plan_spec(&spec, subscripts, shapes)?;
129            Ok((spec, tree))
130        }
131        optimize => {
132            let spec = plan_spec_from_optimize(optimize, subscripts)?;
133            let tree = resolve_plan_spec(&spec, subscripts, shapes)?;
134            Ok((spec, tree))
135        }
136    }
137}
138
139pub(crate) fn resolve_plan_spec(
140    spec: &EinsumPlanSpec,
141    subscripts: &Subscripts,
142    shapes: &[&[usize]],
143) -> Result<ContractionTree> {
144    match spec {
145        EinsumPlanSpec::Auto(options) => {
146            ContractionTree::optimize_with_options(subscripts, shapes, options)
147        }
148        EinsumPlanSpec::LeftToRight => {
149            let n = subscripts.inputs.len();
150            if n <= 1 {
151                ContractionTree::from_pairs(subscripts, shapes, &[])
152            } else {
153                let path: Vec<(usize, usize)> = (0..n - 1).map(|_| (0, 1)).collect();
154                let pairs = jax_path_to_v1_pairs(&path, n)?;
155                ContractionTree::from_pairs(subscripts, shapes, &pairs)
156            }
157        }
158        EinsumPlanSpec::Path(path) => {
159            let pairs = jax_path_to_v1_pairs(path, subscripts.inputs.len())?;
160            ContractionTree::from_pairs(subscripts, shapes, &pairs)
161        }
162        EinsumPlanSpec::FixedPairs(pairs) => ContractionTree::from_pairs(subscripts, shapes, pairs),
163    }
164}
165
166pub(crate) fn hash_einsum_plan_spec(spec: &EinsumPlanSpec, state: &mut dyn Hasher) {
167    match spec {
168        EinsumPlanSpec::Auto(options) => {
169            state.write_u8(0);
170            hash_optimizer_options(options, state);
171        }
172        EinsumPlanSpec::LeftToRight => state.write_u8(1),
173        EinsumPlanSpec::Path(path) => {
174            state.write_u8(2);
175            hash_pairs(path, state);
176        }
177        EinsumPlanSpec::FixedPairs(pairs) => {
178            state.write_u8(3);
179            hash_pairs(pairs, state);
180        }
181    }
182}
183
184pub(crate) fn plan_specs_equal(lhs: &EinsumPlanSpec, rhs: &EinsumPlanSpec) -> bool {
185    match (lhs, rhs) {
186        (EinsumPlanSpec::Auto(lhs), EinsumPlanSpec::Auto(rhs)) => {
187            optimizer_options_equal_by_bits(lhs, rhs)
188        }
189        (EinsumPlanSpec::LeftToRight, EinsumPlanSpec::LeftToRight) => true,
190        (EinsumPlanSpec::Path(lhs), EinsumPlanSpec::Path(rhs)) => lhs == rhs,
191        (EinsumPlanSpec::FixedPairs(lhs), EinsumPlanSpec::FixedPairs(rhs)) => lhs == rhs,
192        _ => false,
193    }
194}
195
196fn tree_pairs(tree: &ContractionTree) -> Vec<(usize, usize)> {
197    (0..tree.step_count())
198        .filter_map(|step| tree.step_pair(step))
199        .collect()
200}
201
202fn validate_fixed_pairs(pairs: &[(usize, usize)], input_count: usize) -> Result<()> {
203    let required_steps = input_count.saturating_sub(1);
204    if pairs.len() != required_steps {
205        return Err(Error::InvalidArgument(format!(
206            "explicit contraction path for {input_count} operands must have {required_steps} steps, got {}",
207            pairs.len()
208        )));
209    }
210
211    let mut live = vec![false; input_count + pairs.len()];
212    for slot in live.iter_mut().take(input_count) {
213        *slot = true;
214    }
215
216    for (step_idx, &(left, right)) in pairs.iter().enumerate() {
217        let next_idx = input_count + step_idx;
218        if left == right {
219            return Err(Error::InvalidArgument(format!(
220                "pair ({left}, {right}) must reference two distinct live operands"
221            )));
222        }
223        if left >= next_idx || right >= next_idx {
224            return Err(Error::InvalidArgument(format!(
225                "pair ({left}, {right}) references non-existent operand"
226            )));
227        }
228        if !live[left] || !live[right] {
229            return Err(Error::InvalidArgument(format!(
230                "pair ({left}, {right}) references an operand or intermediate that is no longer live"
231            )));
232        }
233
234        live[left] = false;
235        live[right] = false;
236        live[next_idx] = true;
237    }
238
239    let live_count = live.iter().filter(|&&is_live| is_live).count();
240    if live_count != 1 {
241        return Err(Error::InvalidArgument(format!(
242            "explicit contraction path must leave exactly one live result, got {live_count}"
243        )));
244    }
245
246    Ok(())
247}
248
/// Hash a pair list as its length followed by each pair's two IDs in order.
fn hash_pairs(pairs: &[(usize, usize)], state: &mut dyn Hasher) {
    state.write_usize(pairs.len());
    pairs.iter().for_each(|&(first, second)| {
        state.write_usize(first);
        state.write_usize(second);
    });
}
256
257fn hash_optimizer_options(options: &ContractionOptimizerOptions, state: &mut dyn Hasher) {
258    state.write_usize(options.ntrials);
259    state.write_usize(options.niters);
260    state.write_usize(options.betas.len());
261    for value in &options.betas {
262        state.write_u64(value.to_bits());
263    }
264    state.write_u64(options.score.tc_weight.to_bits());
265    state.write_u64(options.score.sc_weight.to_bits());
266    state.write_u64(options.score.rw_weight.to_bits());
267    state.write_u64(options.score.sc_target.to_bits());
268}
269
270fn optimizer_options_equal_by_bits(
271    lhs: &ContractionOptimizerOptions,
272    rhs: &ContractionOptimizerOptions,
273) -> bool {
274    lhs.ntrials == rhs.ntrials
275        && lhs.niters == rhs.niters
276        && f64_slices_equal_by_bits(&lhs.betas, &rhs.betas)
277        && score_functions_equal_by_bits(&lhs.score, &rhs.score)
278}
279
280/// Convert JAX-style position-based path to fixed-ID pairs.
281///
282/// JAX format: each pair `(i, j)` refers to positions in a shrinking list.
283/// After contraction, the two operands are removed and the result is appended.
284/// Fixed-ID format keeps original operands at `0..input_count` and gives
285/// intermediate at step `k` the ID `input_count + k`.
286///
287/// # Errors
288///
289/// Returns an error if a path step references the same position twice or a
290/// position outside the current shrinking list.
291pub(crate) fn jax_path_to_v1_pairs(
292    jax_path: &[(usize, usize)],
293    input_count: usize,
294) -> Result<Vec<(usize, usize)>> {
295    let required_steps = input_count.saturating_sub(1);
296    if jax_path.len() != required_steps {
297        return Err(Error::InvalidArgument(format!(
298            "explicit contraction path for {input_count} operands must have {required_steps} steps, got {}",
299            jax_path.len()
300        )));
301    }
302
303    let mut positions: Vec<usize> = (0..input_count).collect();
304    let mut v1_pairs = Vec::with_capacity(jax_path.len());
305
306    for (step, &(pos_a, pos_b)) in jax_path.iter().enumerate() {
307        if pos_a == pos_b {
308            return Err(Error::InvalidArgument(format!(
309                "path step {step} references the same operand position twice: {pos_a}"
310            )));
311        }
312        let current_len = positions.len();
313        if pos_a >= current_len || pos_b >= current_len {
314            return Err(Error::InvalidArgument(format!(
315                "path step {step} references operand positions ({pos_a}, {pos_b}) with only {current_len} live operands"
316            )));
317        }
318
319        let (lo, hi) = if pos_a < pos_b {
320            (pos_a, pos_b)
321        } else {
322            (pos_b, pos_a)
323        };
324        let id_a = positions[lo];
325        let id_b = positions[hi];
326        v1_pairs.push((id_a, id_b));
327
328        positions.remove(hi);
329        positions.remove(lo);
330        positions.push(input_count + step);
331    }
332
333    Ok(v1_pairs)
334}
335
336/// Convert a [`NestedEinsum`] tree into fixed-ID pairs.
337///
338/// # Errors
339///
340/// Returns an error if a leaf references an input outside `0..input_count` or if
341/// a node has no children.
342pub(crate) fn nested_to_v1_pairs(
343    nested: &NestedEinsum,
344    input_count: usize,
345) -> Result<Vec<(usize, usize)>> {
346    let mut pairs = Vec::with_capacity(input_count.saturating_sub(1));
347    let mut next_id = input_count;
348    let root_id = walk_nested(nested, input_count, &mut pairs, &mut next_id)?;
349    if input_count == 0 || root_id >= next_id {
350        return Err(Error::InvalidArgument(
351            "nested einsum did not produce a valid root operand".into(),
352        ));
353    }
354    Ok(pairs)
355}
356
357fn walk_nested(
358    nested: &NestedEinsum,
359    input_count: usize,
360    pairs: &mut Vec<(usize, usize)>,
361    next_id: &mut usize,
362) -> Result<usize> {
363    match nested {
364        NestedEinsum::Leaf(idx) => {
365            if *idx >= input_count {
366                return Err(Error::InvalidArgument(format!(
367                    "nested einsum leaf {idx} is outside 0..{input_count}"
368                )));
369            }
370            Ok(*idx)
371        }
372        NestedEinsum::Node { children, .. } => {
373            let Some(first) = children.first() else {
374                return Err(Error::InvalidArgument(
375                    "nested einsum node must have at least one child".into(),
376                ));
377            };
378            let mut result_id = walk_nested(first, input_count, pairs, next_id)?;
379            for child in &children[1..] {
380                let child_id = walk_nested(child, input_count, pairs, next_id)?;
381                pairs.push((result_id, child_id));
382                result_id = *next_id;
383                *next_id += 1;
384            }
385            Ok(result_id)
386        }
387    }
388}
389
/// Element-wise equality of two `f64` slices by bit pattern.
///
/// Unlike `==`, an identical `NaN` compares equal and `0.0` differs from
/// `-0.0`. Slices of different lengths are never equal.
fn f64_slices_equal_by_bits(lhs: &[f64], rhs: &[f64]) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter()
        .map(|value| value.to_bits())
        .eq(rhs.iter().map(|value| value.to_bits()))
}
397
398fn score_functions_equal_by_bits(lhs: &ScoreFunction, rhs: &ScoreFunction) -> bool {
399    lhs.tc_weight.to_bits() == rhs.tc_weight.to_bits()
400        && lhs.sc_weight.to_bits() == rhs.sc_weight.to_bits()
401        && lhs.rw_weight.to_bits() == rhs.rw_weight.to_bits()
402        && lhs.sc_target.to_bits() == rhs.sc_target.to_bits()
403}
404
405#[cfg(test)]
406mod tests;