strided_opteinsum/
expr.rs

1use std::collections::{BTreeMap, HashMap};
2
3use num_complex::Complex64;
4#[cfg(test)]
5use num_traits::Zero;
6use strided_einsum2::{einsum2_into, einsum2_into_owned};
7use strided_kernel::copy_scale;
8use strided_view::{StridedArray, StridedViewMut};
9
10use crate::operand::{EinsumOperand, EinsumScalar, StridedData};
11use crate::parse::{EinsumCode, EinsumNode};
12use crate::single_tensor::single_tensor_einsum;
13
14// ---------------------------------------------------------------------------
15// Buffer pool for intermediate reuse
16// ---------------------------------------------------------------------------
17
/// Reusable buffer pool for intermediate tensor allocations.
///
/// Tracks freed `Vec<f64>` and `Vec<Complex64>` buffers indexed by length.
/// When a contraction step completes, its input buffers are returned to the
/// pool. Subsequent steps can reuse these buffers instead of allocating fresh.
///
/// Uses `BTreeMap` so `pool_acquire` can find the **smallest buffer ≥ requested
/// size** (best-fit), improving reuse when intermediate sizes vary slightly.
///
/// # Usage
///
/// Create a pool and pass it to [`EinsumCode::evaluate_with_pool`] to reuse
/// buffers across multiple einsum calls:
///
/// ```ignore
/// let mut pool = BufferPool::new();
/// code.evaluate_with_pool(operands1, None, Some(&mut pool))?;
/// code.evaluate_with_pool(operands2, None, Some(&mut pool))?;
/// ```
///
/// Pass `None` to let each call allocate and free independently (no pooling).
pub struct BufferPool {
    // Freed f64 buffers keyed by `Vec::len`, so best-fit lookup is a range query.
    f64_pool: BTreeMap<usize, Vec<Vec<f64>>>,
    // Freed Complex64 buffers, keyed the same way.
    c64_pool: BTreeMap<usize, Vec<Vec<Complex64>>>,
}
43
44impl BufferPool {
45    /// Create an empty buffer pool.
46    pub fn new() -> Self {
47        Self {
48            f64_pool: BTreeMap::new(),
49            c64_pool: BTreeMap::new(),
50        }
51    }
52}
53
54// ---------------------------------------------------------------------------
55// PoolOps — private trait for type-dispatched buffer pool access
56// ---------------------------------------------------------------------------
57
/// Private sealed trait for buffer pool acquire/release, implemented for f64
/// and Complex64. Enables generic `eval_pair_alloc` without exposing
/// `BufferPool` in the public API.
trait PoolOps: EinsumScalar {
    /// Acquire a col-major buffer from the pool (or allocate fresh).
    ///
    /// `dims` is the requested shape; a reused best-fit buffer may be larger
    /// than `dims.iter().product()`.
    ///
    /// # Safety contract
    /// The returned array may contain uninitialized data. Callers must write
    /// every element before reading (e.g. via `einsum2_into` with `beta=0`).
    fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Self>;

    /// Release an owned buffer back to the pool for reuse.
    /// Views are silently dropped (nothing to recycle).
    fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Self>);
}
73
/// Take the best-fit buffer (smallest capacity >= `total`) from a BTreeMap pool.
fn take_best_fit<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>, total: usize) -> Option<Vec<T>> {
    // The first key in ascending order within `total..` is the tightest fit.
    let key = pool.range(total..).next().map(|(k, _)| *k)?;
    let bucket = pool.get_mut(&key)?;
    let taken = bucket.pop();
    // Drop drained buckets so the map only holds keys with available buffers.
    if bucket.is_empty() {
        pool.remove(&key);
    }
    taken
}
85
86impl PoolOps for f64 {
87    fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<f64> {
88        let total: usize = dims.iter().product();
89        // SAFETY: einsum2_into with beta=0 writes every output element before reading.
90        match take_best_fit(&mut pool.f64_pool, total) {
91            Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
92            None => unsafe { StridedArray::col_major_uninit(dims) },
93        }
94    }
95
96    fn pool_release(pool: &mut BufferPool, data: StridedData<'_, f64>) {
97        if let StridedData::Owned(arr) = data {
98            let buf = arr.into_data();
99            pool.f64_pool.entry(buf.len()).or_default().push(buf);
100        }
101    }
102}
103
104impl PoolOps for Complex64 {
105    fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Complex64> {
106        let total: usize = dims.iter().product();
107        // SAFETY: einsum2_into with beta=0 writes every output element before reading.
108        match take_best_fit(&mut pool.c64_pool, total) {
109            Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
110            None => unsafe { StridedArray::col_major_uninit(dims) },
111        }
112    }
113
114    fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Complex64>) {
115        if let StridedData::Owned(arr) = data {
116            let buf = arr.into_data();
117            pool.c64_pool.entry(buf.len()).or_default().push(buf);
118        }
119    }
120}
121
122// ---------------------------------------------------------------------------
123// Helper: collect all index chars from a subtree, preserving first-seen order
124// ---------------------------------------------------------------------------
125
126fn collect_all_ids(node: &EinsumNode) -> Vec<char> {
127    let mut result = Vec::new();
128    collect_all_ids_inner(node, &mut result);
129    result
130}
131
132fn collect_all_ids_inner(node: &EinsumNode, result: &mut Vec<char>) {
133    match node {
134        EinsumNode::Leaf { ids, .. } => {
135            for &id in ids {
136                if !result.contains(&id) {
137                    result.push(id);
138                }
139            }
140        }
141        EinsumNode::Contract { args } => {
142            for arg in args {
143                collect_all_ids_inner(arg, result);
144            }
145        }
146    }
147}
148
149// ---------------------------------------------------------------------------
150// Helper: compute which output indices a Contract node should keep
151// ---------------------------------------------------------------------------
152
153/// For a Contract node with `args`, decide which index chars to keep.
154///
155/// An index is kept if it appears in `needed_ids` (what the parent/caller
156/// needs from this node) AND it is actually present in at least one child
157/// subtree.
158fn compute_contract_output_ids(args: &[EinsumNode], needed_ids: &[char]) -> Vec<char> {
159    if args.len() == 2 {
160        let left_ids = collect_all_ids(&args[0]);
161        let right_ids = collect_all_ids(&args[1]);
162        return compute_binary_output_ids(&left_ids, &right_ids, needed_ids);
163    }
164
165    // Walk args in order and collect ids preserving first-seen order
166    let mut all_ids_ordered = Vec::new();
167    for arg in args {
168        for id in collect_all_ids(arg) {
169            if !all_ids_ordered.contains(&id) {
170                all_ids_ordered.push(id);
171            }
172        }
173    }
174
175    // Keep only ids that the parent needs
176    all_ids_ordered
177        .into_iter()
178        .filter(|id| needed_ids.contains(id))
179        .collect()
180}
181
182/// Compute the set of ids that a child of a Contract node needs to provide.
183///
184/// A child needs to keep an index if:
185///   - It is in the Contract's own `output_ids` (parent needs it), OR
186///   - It is shared with at least one sibling subtree (contraction index).
187///
188/// The child is only responsible for indices in its own subtree, but this
189/// function returns the full needed set; the child will naturally intersect
190/// with its own ids.
191fn compute_child_needed_ids(
192    output_ids: &[char],
193    child_idx: usize,
194    args: &[EinsumNode],
195) -> Vec<char> {
196    let mut needed: Vec<char> = output_ids.to_vec();
197
198    // Add indices shared between this child and any sibling
199    let child_ids = collect_all_ids(&args[child_idx]);
200    for (j, arg) in args.iter().enumerate() {
201        if j == child_idx {
202            continue;
203        }
204        let sibling_ids = collect_all_ids(arg);
205        for &id in &child_ids {
206            if sibling_ids.contains(&id) && !needed.contains(&id) {
207                needed.push(id);
208            }
209        }
210    }
211
212    needed
213}
214
215// ---------------------------------------------------------------------------
216// Pairwise contraction (borrows operands)
217// ---------------------------------------------------------------------------
218
219fn out_dims_from_map(
220    dim_map: &HashMap<char, usize>,
221    output_ids: &[char],
222    size_dict: &HashMap<char, usize>,
223) -> crate::Result<Vec<usize>> {
224    let mut out_dims = Vec::with_capacity(output_ids.len());
225    for &id in output_ids {
226        if let Some(&dim) = dim_map.get(&id) {
227            out_dims.push(dim);
228        } else if let Some(&dim) = size_dict.get(&id) {
229            out_dims.push(dim);
230        } else {
231            return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
232        }
233    }
234    Ok(out_dims)
235}
236
237/// Look up output dimensions directly from left/right ids without HashMap allocation.
238fn out_dims_from_ids(
239    left_ids: &[char],
240    left_dims: &[usize],
241    right_ids: &[char],
242    right_dims: &[usize],
243    output_ids: &[char],
244    size_dict: &HashMap<char, usize>,
245) -> crate::Result<Vec<usize>> {
246    let mut out_dims = Vec::with_capacity(output_ids.len());
247    for &id in output_ids {
248        if let Some(pos) = left_ids.iter().position(|&c| c == id) {
249            out_dims.push(left_dims[pos]);
250        } else if let Some(pos) = right_ids.iter().position(|&c| c == id) {
251            out_dims.push(right_dims[pos]);
252        } else if let Some(&dim) = size_dict.get(&id) {
253            out_dims.push(dim);
254        } else {
255            return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
256        }
257    }
258    Ok(out_dims)
259}
260
/// Compute binary contraction output id order.
///
/// Uses canonical `[lo, ro, batch]` order:
/// - lo: ids only in left and needed
/// - ro: ids only in right and needed
/// - batch: ids in both and needed
fn compute_binary_output_ids(
    left_ids: &[char],
    right_ids: &[char],
    needed_ids: &[char],
) -> Vec<char> {
    // The three groups are disjoint by construction; the `contains` check on
    // `out` only deduplicates repeats within left_ids / right_ids themselves.
    let lo = left_ids
        .iter()
        .filter(|&&id| needed_ids.contains(&id) && !right_ids.contains(&id));
    let ro = right_ids
        .iter()
        .filter(|&&id| needed_ids.contains(&id) && !left_ids.contains(&id));
    let batch = left_ids
        .iter()
        .filter(|&&id| needed_ids.contains(&id) && right_ids.contains(&id));

    let mut out: Vec<char> = Vec::new();
    for &id in lo.chain(ro).chain(batch) {
        if !out.contains(&id) {
            out.push(id);
        }
    }
    out
}
290
/// Generic inner function for pairwise contraction with buffer pool.
///
/// Acquires an output buffer, runs `einsum2_into`, and releases input buffers
/// back to the pool.
fn eval_pair_alloc<T: PoolOps>(
    ld: StridedData<'_, T>,
    left_ids: &[char],
    rd: StridedData<'_, T>,
    right_ids: &[char],
    output_ids: &[char],
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<EinsumOperand<'static>> {
    // Resolve the output shape from the operand ids, falling back to
    // `size_dict` for output axes absent from both inputs.
    let out_dims = out_dims_from_ids(
        left_ids,
        ld.dims(),
        right_ids,
        rd.dims(),
        output_ids,
        size_dict,
    )?;
    // May contain uninitialized data; beta = T::zero() below guarantees every
    // element is written before any read, satisfying the PoolOps contract.
    let mut c_arr = T::pool_acquire(pool, &out_dims);
    match (ld, rd) {
        // Preserve ownership so strided-einsum2 can use prepare_input_owned
        // and avoid extra materialization in prepare_input_view. The input
        // buffers are consumed by the kernel, so they cannot be pooled here.
        (StridedData::Owned(a), StridedData::Owned(b)) => {
            einsum2_into_owned(
                c_arr.view_mut(),
                a,
                b,
                output_ids,
                left_ids,
                right_ids,
                T::one(),
                T::zero(),
                false,
                false,
            )?;
        }
        // At least one side is a borrowed view: contract via views, then
        // recycle whichever inputs were owned (views are dropped silently).
        (ld, rd) => {
            let a_view = ld.as_view();
            let b_view = rd.as_view();
            einsum2_into(
                c_arr.view_mut(),
                &a_view,
                &b_view,
                output_ids,
                left_ids,
                right_ids,
                T::one(),
                T::zero(),
            )?;
            T::pool_release(pool, ld);
            T::pool_release(pool, rd);
        }
    }
    Ok(T::wrap_array(c_arr))
}
349
350/// Contract two operands, consuming them by value. Promotes to c64 if types are mixed.
351///
352/// Uses view-based `einsum2_into` so that owned input buffers can be released
353/// back to the `pool` for reuse by subsequent contraction steps.
354fn eval_pair(
355    left: EinsumOperand<'_>,
356    left_ids: &[char],
357    right: EinsumOperand<'_>,
358    right_ids: &[char],
359    output_ids: &[char],
360    pool: &mut BufferPool,
361    size_dict: &HashMap<char, usize>,
362) -> crate::Result<EinsumOperand<'static>> {
363    match (left, right) {
364        (EinsumOperand::F64(ld), EinsumOperand::F64(rd)) => {
365            eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
366        }
367        (EinsumOperand::C64(ld), EinsumOperand::C64(rd)) => {
368            eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
369        }
370        (left, right) => {
371            // Mixed types: promote both to c64 by consuming, then recurse (hits C64/C64 branch)
372            let left_c64 = left.to_c64_owned();
373            let right_c64 = right.to_c64_owned();
374            eval_pair(
375                left_c64, left_ids, right_c64, right_ids, output_ids, pool, size_dict,
376            )
377        }
378    }
379}
380
381// ---------------------------------------------------------------------------
382// Pairwise contraction into user-provided output (zero-copy for final step)
383// ---------------------------------------------------------------------------
384
385/// Contract two operands directly into a user-provided output buffer.
386///
387/// Unlike `eval_pair`, this writes into `output` with alpha/beta scaling
388/// instead of allocating a fresh array. Used for the final contraction in
389/// `evaluate_into`.
390fn eval_pair_into<T: EinsumScalar>(
391    left: EinsumOperand<'_>,
392    left_ids: &[char],
393    right: EinsumOperand<'_>,
394    right_ids: &[char],
395    output: StridedViewMut<T>,
396    output_ids: &[char],
397    alpha: T,
398    beta: T,
399) -> crate::Result<()> {
400    let left_data = T::extract_data(left)?;
401    let right_data = T::extract_data(right)?;
402
403    match (left_data, right_data) {
404        (StridedData::Owned(a), StridedData::Owned(b)) => {
405            einsum2_into_owned(
406                output, a, b, output_ids, left_ids, right_ids, alpha, beta, false, false,
407            )?;
408        }
409        (StridedData::Owned(a), StridedData::View(b)) => {
410            einsum2_into(
411                output,
412                &a.view(),
413                &b,
414                output_ids,
415                left_ids,
416                right_ids,
417                alpha,
418                beta,
419            )?;
420        }
421        (StridedData::View(a), StridedData::Owned(b)) => {
422            einsum2_into(
423                output,
424                &a,
425                &b.view(),
426                output_ids,
427                left_ids,
428                right_ids,
429                alpha,
430                beta,
431            )?;
432        }
433        (StridedData::View(a), StridedData::View(b)) => {
434            einsum2_into(output, &a, &b, output_ids, left_ids, right_ids, alpha, beta)?;
435        }
436    }
437    Ok(())
438}
439
440// ---------------------------------------------------------------------------
441// Accumulate helper for single-tensor results
442// ---------------------------------------------------------------------------
443
/// Write `output = alpha * result + beta * output`.
///
/// `result` must already have the same shape as `output`.
fn accumulate_into<T: EinsumScalar>(
    output: &mut StridedViewMut<T>,
    result: &StridedArray<T>,
    alpha: T,
    beta: T,
) -> crate::Result<()> {
    let result_view = result.view();
    if beta == T::zero() {
        // beta == 0: output is fully overwritten, so a (scaled) copy suffices.
        if alpha == T::one() {
            strided_kernel::copy_into(output, &result_view)?;
        } else {
            copy_scale(output, &result_view, alpha)?;
        }
    } else {
        // General case: output = alpha * result + beta * output.
        // `zip_map2_into` cannot use `output` as both a source and the
        // destination (aliasing), so both operands are first staged in
        // freshly allocated col-major temporaries. beta != 0 is a rare path,
        // so the two extra allocations are an acceptable simplicity tradeoff.
        let dims = output.dims().to_vec();
        let mut temp = StridedArray::<T>::col_major(&dims);
        strided_kernel::copy_into(&mut temp.view_mut(), &result_view)?;
        let mut output_copy = StridedArray::<T>::col_major(&dims);
        strided_kernel::copy_into(&mut output_copy.view_mut(), &output.as_view())?;
        strided_kernel::zip_map2_into(output, &temp.view(), &output_copy.view(), |r, o| {
            alpha * r + beta * o
        })?;
    }
    Ok(())
}
486
487// ---------------------------------------------------------------------------
488// Single-tensor dispatch (borrows operand)
489// ---------------------------------------------------------------------------
490
491/// Generic inner function for single-tensor einsum.
492fn eval_single_typed<T: EinsumScalar>(
493    data: &StridedData<'_, T>,
494    input_ids: &[char],
495    output_ids: &[char],
496    size_dict: &HashMap<char, usize>,
497) -> crate::Result<EinsumOperand<'static>> {
498    let view = data.as_view();
499    let result = single_tensor_einsum(&view, input_ids, output_ids, Some(size_dict))?;
500    Ok(T::wrap_array(result))
501}
502
503fn eval_single(
504    operand: &EinsumOperand<'_>,
505    input_ids: &[char],
506    output_ids: &[char],
507    size_dict: &HashMap<char, usize>,
508) -> crate::Result<EinsumOperand<'static>> {
509    match operand {
510        EinsumOperand::F64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
511        EinsumOperand::C64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
512    }
513}
514
515// ---------------------------------------------------------------------------
516// Permutation helpers
517// ---------------------------------------------------------------------------
518
/// Check if the transformation from input_ids to output_ids is a pure
/// permutation (same set of chars, same length, no repeated indices on
/// either side).
///
/// Returns `false` for traces (repeated input ids), reductions / generative
/// outputs (length mismatch), and diagonal-style outputs (repeated output
/// ids), all of which require the general single-tensor path.
fn is_permutation_only(input_ids: &[char], output_ids: &[char]) -> bool {
    if input_ids.len() != output_ids.len() {
        return false;
    }
    // Reject repeated indices in input (a repeat means trace, not permutation).
    for (i, &id) in input_ids.iter().enumerate() {
        if input_ids[..i].contains(&id) {
            return false;
        }
    }
    // Every (distinct) input id must appear in output. Combined with the
    // equal-length check above this forces output to be a true permutation.
    // (Checking only `output ⊆ input`, as before, wrongly accepted repeated
    // output ids like "ij" -> "ii", which is a diagonal, not a permutation.)
    for &id in input_ids {
        if !output_ids.contains(&id) {
            return false;
        }
    }
    true
}
539
/// Compute the permutation that maps input_ids ordering to output_ids ordering.
///
/// Panics if an output id is absent from `input_ids`; callers are expected to
/// have checked this (e.g. via `is_permutation_only`).
fn compute_permutation(input_ids: &[char], output_ids: &[char]) -> Vec<usize> {
    let mut perm = Vec::with_capacity(output_ids.len());
    for target in output_ids {
        perm.push(input_ids.iter().position(|id| id == target).unwrap());
    }
    perm
}
547
548// ---------------------------------------------------------------------------
549// Execute omeco NestedEinsum tree
550// ---------------------------------------------------------------------------
551
552/// Execute an omeco-optimized contraction tree by contracting pairs according
553/// to the tree structure.
554///
555/// `children` is a Vec of Option-wrapped (operand, ids) pairs. The omeco
556/// `NestedEinsum::Leaf` variant references children by index; we `.take()`
557/// each entry to move ownership out exactly once.
558fn execute_nested<'a>(
559    nested: &omeco::NestedEinsum<char>,
560    children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
561    pool: &mut BufferPool,
562    size_dict: &HashMap<char, usize>,
563) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
564    match nested {
565        omeco::NestedEinsum::Leaf { tensor_index } => {
566            let slot = children.get_mut(*tensor_index).ok_or_else(|| {
567                crate::EinsumError::Internal(format!(
568                    "optimizer referenced child index {} out of bounds",
569                    tensor_index
570                ))
571            })?;
572            let (op, ids) = slot.take().ok_or_else(|| {
573                crate::EinsumError::Internal(format!(
574                    "child operand {} was already consumed",
575                    tensor_index
576                ))
577            })?;
578            Ok((op, ids))
579        }
580        omeco::NestedEinsum::Node { args, eins } => {
581            if args.len() != 2 {
582                return Err(crate::EinsumError::Internal(format!(
583                    "optimizer produced non-binary node with {} children",
584                    args.len()
585                )));
586            }
587            let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
588            let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
589            let output_ids: Vec<char> = eins.iy.clone();
590            let result = eval_pair(
591                left,
592                &left_ids,
593                right,
594                &right_ids,
595                &output_ids,
596                pool,
597                size_dict,
598            )?;
599            Ok((result, output_ids))
600        }
601    }
602}
603
604/// Execute an omeco-optimized contraction tree, writing the root contraction
605/// directly into a user-provided output buffer.
606///
607/// Inner (non-root) contractions use normal `execute_nested` / `eval_pair`.
608/// Only the root `Node`'s contraction is written directly into `output`.
609fn execute_nested_into<'a, T: EinsumScalar>(
610    nested: &omeco::NestedEinsum<char>,
611    children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
612    output: StridedViewMut<T>,
613    output_ids: &[char],
614    alpha: T,
615    beta: T,
616    pool: &mut BufferPool,
617    size_dict: &HashMap<char, usize>,
618) -> crate::Result<()> {
619    match nested {
620        omeco::NestedEinsum::Node { args, eins: _ } => {
621            if args.len() != 2 {
622                return Err(crate::EinsumError::Internal(format!(
623                    "optimizer produced non-binary node with {} children",
624                    args.len()
625                )));
626            }
627            // Evaluate children normally (they allocate temporaries)
628            let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
629            let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
630            // Root contraction writes directly into user's output
631            eval_pair_into(
632                left, &left_ids, right, &right_ids, output, output_ids, alpha, beta,
633            )
634        }
635        omeco::NestedEinsum::Leaf { tensor_index } => {
636            // Root is a single leaf — extract and accumulate into output
637            let slot = children.get_mut(*tensor_index).ok_or_else(|| {
638                crate::EinsumError::Internal(format!(
639                    "optimizer referenced child index {} out of bounds",
640                    tensor_index
641                ))
642            })?;
643            let (op, op_ids) = slot.take().ok_or_else(|| {
644                crate::EinsumError::Internal(format!(
645                    "child operand {} was already consumed",
646                    tensor_index
647                ))
648            })?;
649            let data = T::extract_data(op)?;
650            let arr = data.into_array();
651            // Permute if needed
652            if op_ids != output_ids {
653                let perm = compute_permutation(&op_ids, output_ids);
654                let permuted = arr.permuted(&perm)?;
655                accumulate_into(&mut { output }, &permuted, alpha, beta)?;
656            } else {
657                accumulate_into(&mut { output }, &arr, alpha, beta)?;
658            }
659            Ok(())
660        }
661    }
662}
663
664// ---------------------------------------------------------------------------
665// Recursive evaluation
666// ---------------------------------------------------------------------------
667
/// Recursively evaluate an `EinsumNode`, returning the result operand and
/// the index chars labelling its axes.
///
/// Leaf nodes return borrowed views directly (no copy). Contract nodes
/// always produce freshly allocated results (`'static` coerced to `'a`).
///
/// `needed_ids` tells this node which indices the caller needs in the result.
/// For the root call this is the final output indices of the einsum.
fn eval_node<'a>(
    node: &EinsumNode,
    operands: &mut Vec<Option<EinsumOperand<'a>>>,
    needed_ids: &[char],
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
    match node {
        EinsumNode::Leaf { ids, tensor_index } => {
            // NOTE(review): `found` counts total slots, including already
            // consumed ones, so the error may overstate live operands.
            let found = operands.len();
            let slot = operands.get_mut(*tensor_index).ok_or_else(|| {
                crate::EinsumError::OperandCountMismatch {
                    expected: tensor_index + 1,
                    found,
                }
            })?;
            // Each leaf may be consumed exactly once; a repeat is an
            // internal tree-construction error.
            let op = slot.take().ok_or_else(|| {
                crate::EinsumError::Internal(format!(
                    "operand {} was already consumed",
                    tensor_index
                ))
            })?;
            // Return borrowed view directly — no to_owned_static() copy.
            Ok((op, ids.clone()))
        }
        EinsumNode::Contract { args } => {
            // Determine which indices this Contract node should output.
            let node_output_ids = compute_contract_output_ids(args, needed_ids);

            match args.len() {
                0 => unreachable!("empty Contract node"),
                1 => {
                    // Single-tensor operation.
                    let child_needed = compute_child_needed_ids(&node_output_ids, 0, args);
                    let (child_op, child_ids) =
                        eval_node(&args[0], operands, &child_needed, pool, size_dict)?;

                    // Identity passthrough: no allocation needed.
                    if child_ids == node_output_ids {
                        return Ok((child_op, node_output_ids));
                    }

                    // Permutation-only passthrough: metadata reorder, no data copy.
                    if is_permutation_only(&child_ids, &node_output_ids) {
                        let perm = compute_permutation(&child_ids, &node_output_ids);
                        return Ok((child_op.permuted(&perm)?, node_output_ids));
                    }

                    // General case: trace, reduction, repeat, duplicate, etc.
                    let result = eval_single(&child_op, &child_ids, &node_output_ids, size_dict)?;
                    Ok((result, node_output_ids))
                }
                2 => {
                    // Binary contraction: each child keeps its contribution to
                    // the output plus any index shared with its sibling.
                    let left_needed = compute_child_needed_ids(&node_output_ids, 0, args);
                    let right_needed = compute_child_needed_ids(&node_output_ids, 1, args);
                    let (left, left_ids) =
                        eval_node(&args[0], operands, &left_needed, pool, size_dict)?;
                    let (right, right_ids) =
                        eval_node(&args[1], operands, &right_needed, pool, size_dict)?;
                    let result = eval_pair(
                        left,
                        &left_ids,
                        right,
                        &right_ids,
                        &node_output_ids,
                        pool,
                        size_dict,
                    )?;
                    Ok((result, node_output_ids))
                }
                _ => {
                    // 3+ children: use omeco greedy optimizer to find
                    // an efficient pairwise contraction order.

                    // 1. Evaluate all children to get their operands and ids
                    let mut children: Vec<Option<(EinsumOperand<'a>, Vec<char>)>> = Vec::new();
                    for (i, arg) in args.iter().enumerate() {
                        let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
                        let (op, ids) = eval_node(arg, operands, &child_needed, pool, size_dict)?;
                        children.push(Some((op, ids)));
                    }

                    // 2. Build dimension sizes map from evaluated operands
                    let mut dim_sizes: HashMap<char, usize> = HashMap::new();
                    for child_opt in &children {
                        if let Some((op, ids)) = child_opt {
                            for (j, &id) in ids.iter().enumerate() {
                                dim_sizes.insert(id, op.dims()[j]);
                            }
                        }
                    }

                    // 3. Build omeco EinCode from child ids and node output ids
                    // (unwrap is safe: every slot was just pushed as Some above)
                    let input_ids: Vec<Vec<char>> = children
                        .iter()
                        .map(|c| c.as_ref().unwrap().1.clone())
                        .collect();
                    let code = omeco::EinCode::new(input_ids, node_output_ids.clone());

                    // 4. Optimize using omeco greedy method
                    let optimizer = omeco::GreedyMethod::default();
                    let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
                        .ok_or_else(|| {
                            crate::EinsumError::Internal(
                                "optimizer failed to produce a plan".into(),
                            )
                        })?;

                    // 5. Execute the nested contraction tree.
                    // NOTE(review): `result_ids` is the root `iy` of the plan —
                    // presumably equal to `node_output_ids`; confirm omeco
                    // preserves the requested output order.
                    let (result, result_ids) =
                        execute_nested(&nested, &mut children, pool, size_dict)?;
                    Ok((result, result_ids))
                }
            }
        }
    }
}
794
795// ---------------------------------------------------------------------------
796// Public API
797// ---------------------------------------------------------------------------
798
impl EinsumCode {
    /// Evaluate the einsum contraction tree with the given operands.
    ///
    /// Borrowed view operands are propagated through the tree without copying.
    /// The result lifetime matches the input operand lifetime: if all inputs
    /// are owned (`'static`), the result is also `'static`.
    ///
    /// Pass `size_dict` to specify sizes for output indices not present in any
    /// input (generative outputs like `"->ii"` or `"i->ij"`).
    ///
    /// Equivalent to `evaluate_with_pool(operands, size_dict, None)`.
    pub fn evaluate<'a>(
        &self,
        operands: Vec<EinsumOperand<'a>>,
        size_dict: Option<&HashMap<char, usize>>,
    ) -> crate::Result<EinsumOperand<'a>> {
        self.evaluate_with_pool(operands, size_dict, None)
    }

    /// Evaluate the einsum contraction tree, optionally reusing a buffer pool.
    ///
    /// Pass `Some(&mut pool)` to reuse intermediate buffers across calls.
    /// Pass `None` to use a fresh temporary pool (buffers freed on return).
    pub fn evaluate_with_pool<'a>(
        &self,
        operands: Vec<EinsumOperand<'a>>,
        size_dict: Option<&HashMap<char, usize>>,
        pool: Option<&mut BufferPool>,
    ) -> crate::Result<EinsumOperand<'a>> {
        // One operand per Leaf node in the parsed tree; anything else is a caller error.
        let expected = leaf_count(&self.root);
        if operands.len() != expected {
            return Err(crate::EinsumError::OperandCountMismatch {
                expected,
                found: operands.len(),
            });
        }

        // Wrap each operand in Option so eval_node can consume (take) it exactly once.
        let mut ops: Vec<Option<EinsumOperand<'a>>> = operands.into_iter().map(Some).collect();
        // temp_pool lives in this scope so the `&mut` taken in the None arm
        // outlives the match and can be used for the rest of the evaluation.
        let mut temp_pool;
        let pool = match pool {
            Some(p) => p,
            None => {
                temp_pool = BufferPool::new();
                &mut temp_pool
            }
        };

        // Build unified size_dict: operand-inferred sizes + user-provided overrides
        let mut unified = build_dim_map(&self.root, &ops);
        if let Some(sd) = size_dict {
            merge_size_dict(&mut unified, sd)?;
        }

        let (result, result_ids) =
            eval_node(&self.root, &mut ops, &self.output_ids, pool, &unified)?;

        // If the result ids already match the desired output, we're done.
        if result_ids == self.output_ids {
            return Ok(result);
        }

        // Permutation-only: reorder metadata, no data copy.
        if is_permutation_only(&result_ids, &self.output_ids) {
            let perm = compute_permutation(&result_ids, &self.output_ids);
            return Ok(result.permuted(&perm)?);
        }

        // General fallback: reduce/trace/repeat/duplicate to match the final output_ids.
        let adjusted = eval_single(&result, &result_ids, &self.output_ids, &unified)?;
        Ok(adjusted)
    }
}
872
873fn leaf_count(node: &EinsumNode) -> usize {
874    match node {
875        EinsumNode::Leaf { .. } => 1,
876        EinsumNode::Contract { args } => args.iter().map(leaf_count).sum(),
877    }
878}
879
880/// Build a dimension map from operands and the parsed tree.
881///
882/// Maps each index char to its dimension size by walking the tree's Leaf nodes
883/// and matching them to the corresponding operands.
884fn build_dim_map(
885    node: &EinsumNode,
886    operands: &[Option<EinsumOperand<'_>>],
887) -> HashMap<char, usize> {
888    let mut dim_map = HashMap::new();
889    build_dim_map_inner(node, operands, &mut dim_map);
890    dim_map
891}
892
893/// Merge user-provided size_dict into the unified map.
894///
895/// Returns an error if a label appears in both maps with different sizes.
896fn merge_size_dict(
897    unified: &mut HashMap<char, usize>,
898    user: &HashMap<char, usize>,
899) -> crate::Result<()> {
900    for (&label, &size) in user {
901        if let Some(&existing) = unified.get(&label) {
902            if existing != size {
903                return Err(crate::EinsumError::DimensionMismatch {
904                    axis: label.to_string(),
905                    dim_a: existing,
906                    dim_b: size,
907                });
908            }
909        } else {
910            unified.insert(label, size);
911        }
912    }
913    Ok(())
914}
915
916fn build_dim_map_inner(
917    node: &EinsumNode,
918    operands: &[Option<EinsumOperand<'_>>],
919    dim_map: &mut HashMap<char, usize>,
920) {
921    match node {
922        EinsumNode::Leaf { ids, tensor_index } => {
923            if let Some(Some(op)) = operands.get(*tensor_index) {
924                for (i, &id) in ids.iter().enumerate() {
925                    dim_map.insert(id, op.dims()[i]);
926                }
927            }
928        }
929        EinsumNode::Contract { args } => {
930            for arg in args {
931                build_dim_map_inner(arg, operands, dim_map);
932            }
933        }
934    }
935}
936
impl EinsumCode {
    /// Evaluate the einsum contraction tree, writing the result directly into
    /// a user-provided output buffer with alpha/beta scaling.
    ///
    /// `output = alpha * einsum(operands) + beta * output`
    ///
    /// The output element type `T` must match the computation: use `f64` when
    /// all operands are real, `Complex64` when any operand is complex.
    /// If `T = f64` but any operand is complex, returns `TypeMismatch` error.
    /// If `T = Complex64`, real operands are promoted automatically.
    ///
    /// Pass `size_dict` to specify sizes for output indices not present in any
    /// input (generative outputs like `"->ii"` or `"i->ij"`).
    pub fn evaluate_into<T: EinsumScalar>(
        &self,
        operands: Vec<EinsumOperand<'_>>,
        output: StridedViewMut<T>,
        alpha: T,
        beta: T,
        size_dict: Option<&HashMap<char, usize>>,
    ) -> crate::Result<()> {
        self.evaluate_into_with_pool(operands, output, alpha, beta, size_dict, None)
    }

    /// Evaluate the einsum contraction tree, writing the result into a
    /// pre-allocated output buffer with alpha/beta scaling and optional
    /// buffer pool reuse.
    ///
    /// `output = alpha * einsum(operands) + beta * output`
    ///
    /// Pass `Some(&mut pool)` to reuse intermediate buffers across calls.
    /// Pass `None` to use a fresh temporary pool (buffers freed on return).
    pub fn evaluate_into_with_pool<T: EinsumScalar>(
        &self,
        operands: Vec<EinsumOperand<'_>>,
        mut output: StridedViewMut<T>,
        alpha: T,
        beta: T,
        size_dict: Option<&HashMap<char, usize>>,
        pool: Option<&mut BufferPool>,
    ) -> crate::Result<()> {
        // One operand per Leaf node in the parsed tree.
        let expected = leaf_count(&self.root);
        if operands.len() != expected {
            return Err(crate::EinsumError::OperandCountMismatch {
                expected,
                found: operands.len(),
            });
        }

        // Validate output type compatibility
        let mut ops: Vec<Option<EinsumOperand<'_>>> = operands.into_iter().map(Some).collect();
        T::validate_operands(&ops)?;

        // Build unified size_dict: operand-inferred sizes + user-provided overrides
        let mut unified = build_dim_map(&self.root, &ops);
        if let Some(sd) = size_dict {
            merge_size_dict(&mut unified, sd)?;
        }

        // Compute expected output shape
        // NOTE(review): `unified` is passed as both the first and last argument;
        // presumably one slot is the operand-derived dim map and the other the
        // size_dict fallback — confirm the duplicated argument is intentional.
        let expected_dims = out_dims_from_map(&unified, &self.output_ids, &unified)?;
        if output.dims() != expected_dims.as_slice() {
            return Err(crate::EinsumError::OutputShapeMismatch {
                expected: expected_dims,
                got: output.dims().to_vec(),
            });
        }

        // temp_pool lives in this scope so the `&mut` borrowed in the None arm
        // remains valid for the rest of the evaluation.
        let mut temp_pool;
        let pool = match pool {
            Some(p) => p,
            None => {
                temp_pool = BufferPool::new();
                &mut temp_pool
            }
        };

        // Dispatch on the shape of the root node; each arm ends by accumulating
        // into `output` with the alpha/beta scaling.
        match &self.root {
            EinsumNode::Leaf { ids, tensor_index } => {
                // Single operand: extract, permute/trace, accumulate
                let op = ops[*tensor_index].take().ok_or_else(|| {
                    crate::EinsumError::Internal("operand already consumed".into())
                })?;
                let single_result = eval_single(&op, ids, &self.output_ids, &unified)?;
                let data = T::extract_data(single_result)?;
                accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
            }
            EinsumNode::Contract { args } => match args.len() {
                0 => unreachable!("empty Contract node"),
                1 => {
                    // Single child: evaluate, then accumulate
                    let child_needed = compute_child_needed_ids(&self.output_ids, 0, args);
                    let (child_op, child_ids) =
                        eval_node(&args[0], &mut ops, &child_needed, pool, &unified)?;

                    if child_ids == self.output_ids {
                        // Identity: just accumulate
                        let data = T::extract_data(child_op)?;
                        accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
                    } else if is_permutation_only(&child_ids, &self.output_ids) {
                        // Permutation: permute the data, then accumulate
                        let perm = compute_permutation(&child_ids, &self.output_ids);
                        let data = T::extract_data(child_op)?;
                        let arr = data.into_array();
                        let permuted = arr.permuted(&perm)?;
                        accumulate_into(&mut output, &permuted, alpha, beta)?;
                    } else {
                        // General: trace/reduction/repeat/duplicate
                        let result =
                            eval_single(&child_op, &child_ids, &self.output_ids, &unified)?;
                        let data = T::extract_data(result)?;
                        accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
                    }
                }
                2 => {
                    // Binary contraction: write directly into output
                    let left_needed = compute_child_needed_ids(&self.output_ids, 0, args);
                    let right_needed = compute_child_needed_ids(&self.output_ids, 1, args);
                    let (left, left_ids) =
                        eval_node(&args[0], &mut ops, &left_needed, pool, &unified)?;
                    let (right, right_ids) =
                        eval_node(&args[1], &mut ops, &right_needed, pool, &unified)?;
                    eval_pair_into(
                        left,
                        &left_ids,
                        right,
                        &right_ids,
                        output,
                        &self.output_ids,
                        alpha,
                        beta,
                    )?;
                }
                _ => {
                    // 3+ children: use omeco, final contraction into output
                    let node_output_ids = compute_contract_output_ids(args, &self.output_ids);

                    // Evaluate every child first; children are Option so the
                    // nested executor can consume them one at a time.
                    let mut children: Vec<Option<(EinsumOperand<'_>, Vec<char>)>> = Vec::new();
                    for (i, arg) in args.iter().enumerate() {
                        let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
                        let (op, ids) = eval_node(arg, &mut ops, &child_needed, pool, &unified)?;
                        children.push(Some((op, ids)));
                    }

                    // Sizes come from the evaluated children, not `unified`, so
                    // the optimizer sees the actual intermediate shapes.
                    let mut dim_sizes: HashMap<char, usize> = HashMap::new();
                    for child_opt in &children {
                        if let Some((op, ids)) = child_opt {
                            for (j, &id) in ids.iter().enumerate() {
                                dim_sizes.insert(id, op.dims()[j]);
                            }
                        }
                    }

                    // unwrap is safe: every slot was just pushed as Some above.
                    let input_ids: Vec<Vec<char>> = children
                        .iter()
                        .map(|c| c.as_ref().unwrap().1.clone())
                        .collect();
                    let code = omeco::EinCode::new(input_ids, self.output_ids.clone());

                    let optimizer = omeco::GreedyMethod::default();
                    let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
                        .ok_or_else(|| {
                            crate::EinsumError::Internal(
                                "optimizer failed to produce a plan".into(),
                            )
                        })?;

                    execute_nested_into(
                        &nested,
                        &mut children,
                        output,
                        &self.output_ids,
                        alpha,
                        beta,
                        pool,
                        &unified,
                    )?;
                }
            },
        }

        Ok(())
    }
}
1121
1122// ---------------------------------------------------------------------------
1123// Tests
1124// ---------------------------------------------------------------------------
1125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::parse_einsum;
    use approx::assert_abs_diff_eq;
    use strided_view::{row_major_strides, StridedArray};

    /// Build an owned, row-major `f64` operand from flat data.
    fn make_f64(dims: &[usize], data: Vec<f64>) -> EinsumOperand<'static> {
        let strides = row_major_strides(dims);
        StridedArray::from_parts(data, dims, &strides, 0)
            .unwrap()
            .into()
    }

    // -----------------------------------------------------------------------
    // evaluate tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_binary_output_ids_canonical_lo_ro_batch_order() {
        let out = compute_binary_output_ids(&['b', 'a', 'x'], &['x', 'c', 'a'], &['b', 'c', 'a']);
        assert_eq!(out, vec!['b', 'c', 'a']);
    }

    #[test]
    fn test_matmul() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let result = code.evaluate(vec![a, b], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[2, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
                assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
                assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
                assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_nested_three_tensor() {
        // Explicit parenthesization: (A*B) first, then *C.
        let code = parse_einsum("(ij,jk),kl->il").unwrap();
        // A = [[1,0],[0,1]] (identity), B = [[1,2],[3,4]], C = [[5,6],[7,8]]
        let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        // A*B = B, B*C = [[19,22],[43,50]]
        let result = code.evaluate(vec![a, b, c], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[2, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
                assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
                assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
                assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_outer_product() {
        let code = parse_einsum("i,j->ij").unwrap();
        let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
        let b = make_f64(&[2], vec![10.0, 20.0]);
        let result = code.evaluate(vec![a, b], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[3, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 10.0);
                assert_abs_diff_eq!(arr.get(&[2, 1]), 60.0);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_dot_product() {
        let code = parse_einsum("i,i->").unwrap();
        let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
        let b = make_f64(&[3], vec![4.0, 5.0, 6.0]);
        let result = code.evaluate(vec![a, b], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                // 1*4 + 2*5 + 3*6 = 32
                assert_abs_diff_eq!(data.as_array().data()[0], 32.0);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_single_tensor_permute() {
        let code = parse_einsum("ij->ji").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = code.evaluate(vec![a], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[3, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 1.0);
                assert_abs_diff_eq!(arr.get(&[0, 1]), 4.0);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_single_tensor_trace() {
        let code = parse_einsum("ii->").unwrap();
        let a = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
        let result = code.evaluate(vec![a], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                // trace = 1 + 2 + 3 = 6
                assert_abs_diff_eq!(data.as_array().data()[0], 6.0);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_three_tensor_flat_omeco() {
        // ij,jk,kl->il -- flat 3-tensor, should use omeco
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
        // c = identity; AB = [[4,2],[10,5]], AB*I = [[4,2],[10,5]]
        let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let result = code.evaluate(vec![a, b, c], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[2, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 4.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[0, 1]), 2.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1, 0]), 10.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1, 1]), 5.0, epsilon = 1e-10);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_four_tensor_flat_omeco() {
        // ij,jk,kl,lm->im -- 4-tensor chain
        let code = parse_einsum("ij,jk,kl,lm->im").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]); // identity
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]); // identity
        let d = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        // I*B = B, B*I = B, B*D = [[19,22],[43,50]]
        let result = code.evaluate(vec![a, b, c, d], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[2, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0, epsilon = 1e-10);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_orphan_output_axis_returns_error() {
        // 'z' appears only in the output spec -> orphan axis error.
        let code = parse_einsum("ij,jk->iz").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let err = code.evaluate(vec![a, b], None).unwrap_err();
        assert!(matches!(err, crate::EinsumError::OrphanOutputAxis(ref s) if s == "z"));
    }

    #[test]
    fn test_operand_count_mismatch_too_few() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let err = code.evaluate(vec![a], None).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OperandCountMismatch {
                expected: 2,
                found: 1
            }
        ));
    }

    #[test]
    fn test_operand_count_mismatch_too_many() {
        let code = parse_einsum("ij->ji").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let err = code.evaluate(vec![a, b], None).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OperandCountMismatch {
                expected: 1,
                found: 2
            }
        ));
    }

    // -----------------------------------------------------------------------
    // evaluate_into tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_into_matmul() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        // Column-major output exercises the non-contiguous accumulate path.
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![a, b], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.get(&[0, 0]), 19.0);
        assert_abs_diff_eq!(c.get(&[0, 1]), 22.0);
        assert_abs_diff_eq!(c.get(&[1, 0]), 43.0);
        assert_abs_diff_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_into_matmul_alpha_beta() {
        // C = 2 * A*B + 3 * C_old
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        // A*B = [[19,22],[43,50]]
        // C_old = [[1,1],[1,1]]
        // result = 2*[[19,22],[43,50]] + 3*[[1,1],[1,1]] = [[41,47],[89,103]]
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);
        for v in c.data_mut().iter_mut() {
            *v = 1.0;
        }
        code.evaluate_into(vec![a, b], c.view_mut(), 2.0, 3.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.get(&[0, 0]), 41.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c.get(&[0, 1]), 47.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c.get(&[1, 0]), 89.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c.get(&[1, 1]), 103.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_single_tensor_permute() {
        let code = parse_einsum("ij->ji").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut c = StridedArray::<f64>::col_major(&[3, 2]);
        code.evaluate_into(vec![a], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_eq!(c.dims(), &[3, 2]);
        assert_abs_diff_eq!(c.get(&[0, 0]), 1.0);
        assert_abs_diff_eq!(c.get(&[0, 1]), 4.0);
        assert_abs_diff_eq!(c.get(&[1, 0]), 2.0);
        assert_abs_diff_eq!(c.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_into_single_tensor_trace() {
        let code = parse_einsum("ii->").unwrap();
        let a = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
        // Scalar output: zero-dimensional array.
        let mut c = StridedArray::<f64>::col_major(&[]);
        code.evaluate_into(vec![a], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.data()[0], 6.0);
    }

    #[test]
    fn test_into_three_tensor_omeco() {
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
        let c_op = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![a, b, c_op], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 1]), 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_nested() {
        let code = parse_einsum("(ij,jk),kl->il").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c_op = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![a, b, c_op], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 19.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 22.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 43.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 1]), 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_dot_product() {
        let code = parse_einsum("i,i->").unwrap();
        let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
        let b = make_f64(&[3], vec![4.0, 5.0, 6.0]);
        let mut c = StridedArray::<f64>::col_major(&[]);
        code.evaluate_into(vec![a, b], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.data()[0], 32.0);
    }

    #[test]
    fn test_into_type_mismatch_f64_output_c64_input() {
        // Complex input into a real output buffer must be rejected.
        let code = parse_einsum("ij->ji").unwrap();
        let c64_data = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ];
        let strides = row_major_strides(&[2, 2]);
        let arr = StridedArray::from_parts(c64_data, &[2, 2], &strides, 0).unwrap();
        let op = EinsumOperand::C64(StridedData::Owned(arr));
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        let err = code
            .evaluate_into(vec![op], out.view_mut(), 1.0, 0.0, None)
            .unwrap_err();
        assert!(matches!(err, crate::EinsumError::TypeMismatch { .. }));
    }

    #[test]
    fn test_into_shape_mismatch() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<f64>::col_major(&[3, 3]); // wrong shape
        let err = code
            .evaluate_into(vec![a, b], out.view_mut(), 1.0, 0.0, None)
            .unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OutputShapeMismatch { .. }
        ));
    }

    #[test]
    fn test_into_c64_output() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let c64 = |r| Complex64::new(r, 0.0);
        let a_data = vec![c64(1.0), c64(2.0), c64(3.0), c64(4.0)];
        let b_data = vec![c64(5.0), c64(6.0), c64(7.0), c64(8.0)];
        let strides = row_major_strides(&[2, 2]);
        let a = EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(a_data, &[2, 2], &strides, 0).unwrap(),
        ));
        let b = EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(b_data, &[2, 2], &strides, 0).unwrap(),
        ));
        let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
        code.evaluate_into(
            vec![a, b],
            out.view_mut(),
            c64(1.0),
            Complex64::zero(),
            None,
        )
        .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
        assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
    }

    #[test]
    fn test_into_mixed_types_c64_output() {
        // f64 + c64 operands -> c64 output (f64 gets promoted)
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c64 = |r| Complex64::new(r, 0.0);
        let b_data = vec![c64(5.0), c64(6.0), c64(7.0), c64(8.0)];
        let strides = row_major_strides(&[2, 2]);
        let b = EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(b_data, &[2, 2], &strides, 0).unwrap(),
        ));
        let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
        code.evaluate_into(
            vec![a, b],
            out.view_mut(),
            c64(1.0),
            Complex64::zero(),
            None,
        )
        .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
        assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
    }
}