strided_opteinsum/
expr.rs

use std::collections::HashMap;

use num_complex::Complex64;
#[cfg(test)]
use num_traits::Zero;
use strided_einsum2::{einsum2_into, einsum2_into_owned};
use strided_kernel::copy_scale;
use strided_view::{StridedArray, StridedViewMut};

use crate::operand::{EinsumOperand, EinsumScalar, StridedData};
use crate::parse::{EinsumCode, EinsumNode};
use crate::single_tensor::single_tensor_einsum;

// ---------------------------------------------------------------------------
// Buffer pool for intermediate reuse
// ---------------------------------------------------------------------------

/// Reusable buffer pool for intermediate tensor allocations.
///
/// Tracks freed `Vec<f64>` and `Vec<Complex64>` buffers indexed by length.
/// When a contraction step completes, its input buffers are returned to the
/// pool. Subsequent steps can reuse these buffers instead of allocating fresh.
struct BufferPool {
    f64_pool: HashMap<usize, Vec<Vec<f64>>>,
    c64_pool: HashMap<usize, Vec<Vec<Complex64>>>,
}

impl BufferPool {
    fn new() -> Self {
        Self {
            f64_pool: HashMap::new(),
            c64_pool: HashMap::new(),
        }
    }
}

// ---------------------------------------------------------------------------
// PoolOps — private trait for type-dispatched buffer pool access
// ---------------------------------------------------------------------------

/// Private sealed trait for buffer pool acquire/release, implemented for f64
/// and Complex64. Enables generic `eval_pair_alloc` without exposing
/// `BufferPool` in the public API.
trait PoolOps: EinsumScalar {
    /// Acquire a col-major buffer from the pool (or allocate fresh).
    ///
    /// # Safety contract
    /// The returned array may contain uninitialized data. Callers must write
    /// every element before reading (e.g. via `einsum2_into` with `beta=0`).
    fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Self>;

    /// Release an owned buffer back to the pool for reuse.
    /// Views are silently dropped (nothing to recycle).
    fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Self>);
}
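
// A minimal sketch of the intended acquire/write/release cycle; the concrete
// implementation is `eval_pair_alloc` below, and the variable names here are
// illustrative only:
//
//     let mut out = f64::pool_acquire(&mut pool, &[2, 4]);
//     einsum2_into(out.view_mut(), &a_view, &b_view, out_ids, a_ids, b_ids, 1.0, 0.0)?;
//     f64::pool_release(&mut pool, left_data);   // Owned buffers are recycled
//     f64::pool_release(&mut pool, right_data);  // Views are simply dropped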

impl PoolOps for f64 {
    fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<f64> {
        let total: usize = dims.iter().product();
        // SAFETY: einsum2_into with beta=0 writes every output element before reading.
        match pool.f64_pool.get_mut(&total).and_then(|v| v.pop()) {
            Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
            None => unsafe { StridedArray::col_major_uninit(dims) },
        }
    }

    fn pool_release(pool: &mut BufferPool, data: StridedData<'_, f64>) {
        if let StridedData::Owned(arr) = data {
            let buf = arr.into_data();
            pool.f64_pool.entry(buf.len()).or_default().push(buf);
        }
    }
}

impl PoolOps for Complex64 {
    fn pool_acquire(pool: &mut BufferPool, dims: &[usize]) -> StridedArray<Complex64> {
        let total: usize = dims.iter().product();
        // SAFETY: einsum2_into with beta=0 writes every output element before reading.
        match pool.c64_pool.get_mut(&total).and_then(|v| v.pop()) {
            Some(buf) => unsafe { StridedArray::col_major_from_buffer_uninit(buf, dims) },
            None => unsafe { StridedArray::col_major_uninit(dims) },
        }
    }

    fn pool_release(pool: &mut BufferPool, data: StridedData<'_, Complex64>) {
        if let StridedData::Owned(arr) = data {
            let buf = arr.into_data();
            pool.c64_pool.entry(buf.len()).or_default().push(buf);
        }
    }
}

// ---------------------------------------------------------------------------
// Helper: collect all index chars from a subtree, preserving first-seen order
// ---------------------------------------------------------------------------

fn collect_all_ids(node: &EinsumNode) -> Vec<char> {
    let mut result = Vec::new();
    collect_all_ids_inner(node, &mut result);
    result
}

fn collect_all_ids_inner(node: &EinsumNode, result: &mut Vec<char>) {
    match node {
        EinsumNode::Leaf { ids, .. } => {
            for &id in ids {
                if !result.contains(&id) {
                    result.push(id);
                }
            }
        }
        EinsumNode::Contract { args } => {
            for arg in args {
                collect_all_ids_inner(arg, result);
            }
        }
    }
}

// ---------------------------------------------------------------------------
// Helper: compute which output indices a Contract node should keep
// ---------------------------------------------------------------------------

/// For a Contract node with `args`, decide which index chars to keep.
///
/// An index is kept if it appears in `needed_ids` (what the parent/caller
/// needs from this node) AND it is actually present in at least one child
/// subtree.
fn compute_contract_output_ids(args: &[EinsumNode], needed_ids: &[char]) -> Vec<char> {
    // Walk args in order and collect ids preserving first-seen order
    let mut all_ids_ordered = Vec::new();
    for arg in args {
        for id in collect_all_ids(arg) {
            if !all_ids_ordered.contains(&id) {
                all_ids_ordered.push(id);
            }
        }
    }

    // Keep only ids that the parent needs
    all_ids_ordered
        .into_iter()
        .filter(|id| needed_ids.contains(id))
        .collect()
}

/// Compute the set of ids that a child of a Contract node needs to provide.
///
/// A child needs to keep an index if:
///   - It is in the Contract's own `output_ids` (parent needs it), OR
///   - It is shared with at least one sibling subtree (contraction index).
///
/// The child is only responsible for indices in its own subtree, but this
/// function returns the full needed set; the child will naturally intersect
/// with its own ids.
fn compute_child_needed_ids(
    output_ids: &[char],
    child_idx: usize,
    args: &[EinsumNode],
) -> Vec<char> {
    let mut needed: Vec<char> = output_ids.to_vec();

    // Add indices shared between this child and any sibling
    let child_ids = collect_all_ids(&args[child_idx]);
    for (j, arg) in args.iter().enumerate() {
        if j == child_idx {
            continue;
        }
        let sibling_ids = collect_all_ids(arg);
        for &id in &child_ids {
            if sibling_ids.contains(&id) && !needed.contains(&id) {
                needed.push(id);
            }
        }
    }

    needed
}
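
// Worked example (illustrative): for a Contract node over leaves `ij` and `jk`
// evaluated with needed_ids = ['i', 'k'] (i.e. "ij,jk->ik"):
//
//   compute_contract_output_ids(args, &['i', 'k'])  == ['i', 'k']
//   compute_child_needed_ids(&['i', 'k'], 0, args)  == ['i', 'k', 'j']
//   compute_child_needed_ids(&['i', 'k'], 1, args)  == ['i', 'k', 'j']
//
// 'j' is kept by both children because it is shared with the sibling and is
// therefore the contraction index.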

// ---------------------------------------------------------------------------
// Pairwise contraction (consumes operands)
// ---------------------------------------------------------------------------

fn out_dims_from_map(
    dim_map: &HashMap<char, usize>,
    output_ids: &[char],
    size_dict: &HashMap<char, usize>,
) -> crate::Result<Vec<usize>> {
    let mut out_dims = Vec::with_capacity(output_ids.len());
    for &id in output_ids {
        if let Some(&dim) = dim_map.get(&id) {
            out_dims.push(dim);
        } else if let Some(&dim) = size_dict.get(&id) {
            out_dims.push(dim);
        } else {
            return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
        }
    }
    Ok(out_dims)
}

/// Look up output dimensions directly from left/right ids without HashMap allocation.
fn out_dims_from_ids(
    left_ids: &[char],
    left_dims: &[usize],
    right_ids: &[char],
    right_dims: &[usize],
    output_ids: &[char],
    size_dict: &HashMap<char, usize>,
) -> crate::Result<Vec<usize>> {
    let mut out_dims = Vec::with_capacity(output_ids.len());
    for &id in output_ids {
        if let Some(pos) = left_ids.iter().position(|&c| c == id) {
            out_dims.push(left_dims[pos]);
        } else if let Some(pos) = right_ids.iter().position(|&c| c == id) {
            out_dims.push(right_dims[pos]);
        } else if let Some(&dim) = size_dict.get(&id) {
            out_dims.push(dim);
        } else {
            return Err(crate::EinsumError::OrphanOutputAxis(id.to_string()));
        }
    }
    Ok(out_dims)
}
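
// Example (illustrative): contracting "ij,jk->ik" with a 2x3 left operand and
// a 3x4 right operand gives
//   out_dims_from_ids(&['i','j'], &[2, 3], &['j','k'], &[3, 4], &['i','k'], &sd)
//     == Ok(vec![2, 4])
// where `sd` is only consulted for output ids absent from both inputs.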

/// Generic inner function for pairwise contraction with buffer pool.
///
/// Acquires an output buffer, runs `einsum2_into`, and releases input buffers
/// back to the pool.
fn eval_pair_alloc<T: PoolOps>(
    ld: StridedData<'_, T>,
    left_ids: &[char],
    rd: StridedData<'_, T>,
    right_ids: &[char],
    output_ids: &[char],
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<EinsumOperand<'static>> {
    let out_dims = out_dims_from_ids(
        left_ids,
        ld.dims(),
        right_ids,
        rd.dims(),
        output_ids,
        size_dict,
    )?;
    let mut c_arr = T::pool_acquire(pool, &out_dims);
    {
        let a_view = ld.as_view();
        let b_view = rd.as_view();
        einsum2_into(
            c_arr.view_mut(),
            &a_view,
            &b_view,
            output_ids,
            left_ids,
            right_ids,
            T::one(),
            T::zero(),
        )?;
    }
    T::pool_release(pool, ld);
    T::pool_release(pool, rd);
    Ok(T::wrap_array(c_arr))
}

/// Contract two operands, consuming them by value. Promotes to c64 if types are mixed.
///
/// Uses view-based `einsum2_into` so that owned input buffers can be released
/// back to the `pool` for reuse by subsequent contraction steps.
fn eval_pair(
    left: EinsumOperand<'_>,
    left_ids: &[char],
    right: EinsumOperand<'_>,
    right_ids: &[char],
    output_ids: &[char],
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<EinsumOperand<'static>> {
    match (left, right) {
        (EinsumOperand::F64(ld), EinsumOperand::F64(rd)) => {
            eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
        }
        (EinsumOperand::C64(ld), EinsumOperand::C64(rd)) => {
            eval_pair_alloc(ld, left_ids, rd, right_ids, output_ids, pool, size_dict)
        }
        (left, right) => {
            // Mixed types: promote both to c64 by consuming, then recurse (hits C64/C64 branch)
            let left_c64 = left.to_c64_owned();
            let right_c64 = right.to_c64_owned();
            eval_pair(
                left_c64, left_ids, right_c64, right_ids, output_ids, pool, size_dict,
            )
        }
    }
}

// ---------------------------------------------------------------------------
// Pairwise contraction into user-provided output (zero-copy for final step)
// ---------------------------------------------------------------------------

/// Contract two operands directly into a user-provided output buffer.
///
/// Unlike `eval_pair`, this writes into `output` with alpha/beta scaling
/// instead of allocating a fresh array. Used for the final contraction in
/// `evaluate_into`.
fn eval_pair_into<T: EinsumScalar>(
    left: EinsumOperand<'_>,
    left_ids: &[char],
    right: EinsumOperand<'_>,
    right_ids: &[char],
    output: StridedViewMut<T>,
    output_ids: &[char],
    alpha: T,
    beta: T,
) -> crate::Result<()> {
    let left_data = T::extract_data(left)?;
    let right_data = T::extract_data(right)?;

    match (left_data, right_data) {
        (StridedData::Owned(a), StridedData::Owned(b)) => {
            einsum2_into_owned(
                output, a, b, output_ids, left_ids, right_ids, alpha, beta, false, false,
            )?;
        }
        (StridedData::Owned(a), StridedData::View(b)) => {
            einsum2_into(
                output,
                &a.view(),
                &b,
                output_ids,
                left_ids,
                right_ids,
                alpha,
                beta,
            )?;
        }
        (StridedData::View(a), StridedData::Owned(b)) => {
            einsum2_into(
                output,
                &a,
                &b.view(),
                output_ids,
                left_ids,
                right_ids,
                alpha,
                beta,
            )?;
        }
        (StridedData::View(a), StridedData::View(b)) => {
            einsum2_into(output, &a, &b, output_ids, left_ids, right_ids, alpha, beta)?;
        }
    }
    Ok(())
}

// ---------------------------------------------------------------------------
// Accumulate helper for single-tensor results
// ---------------------------------------------------------------------------

/// Write `output = alpha * result + beta * output`.
///
/// `result` must already have the same shape as `output`.
fn accumulate_into<T: EinsumScalar>(
    output: &mut StridedViewMut<T>,
    result: &StridedArray<T>,
    alpha: T,
    beta: T,
) -> crate::Result<()> {
    let result_view = result.view();
    if beta == T::zero() {
        if alpha == T::one() {
            strided_kernel::copy_into(output, &result_view)?;
        } else {
            copy_scale(output, &result_view, alpha)?;
        }
    } else {
        // General case: output = alpha * result + beta * output.
        // zip_map2_into cannot use `output` as both a source and the
        // destination, so both sides are staged in temporaries to avoid
        // aliasing:
        //   temp        <- result (densified to col-major)
        //   output_copy <- current contents of output
        //   output      <- alpha * temp + beta * output_copy
        // This path only runs when beta != 0, so the extra copies are an
        // acceptable cost.
        let dims = output.dims().to_vec();
        let mut temp = StridedArray::<T>::col_major(&dims);
        strided_kernel::copy_into(&mut temp.view_mut(), &result_view)?;
        let mut output_copy = StridedArray::<T>::col_major(&dims);
        strided_kernel::copy_into(&mut output_copy.view_mut(), &output.as_view())?;
        strided_kernel::zip_map2_into(output, &temp.view(), &output_copy.view(), |r, o| {
            alpha * r + beta * o
        })?;
    }
    Ok(())
}
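
// Numeric example (illustrative): with alpha = 2, beta = 3,
// result = [1, 2] and output = [10, 20], accumulate_into leaves
// output == [2*1 + 3*10, 2*2 + 3*20] == [32, 64].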

// ---------------------------------------------------------------------------
// Single-tensor dispatch (borrows operand)
// ---------------------------------------------------------------------------

/// Generic inner function for single-tensor einsum.
fn eval_single_typed<T: EinsumScalar>(
    data: &StridedData<'_, T>,
    input_ids: &[char],
    output_ids: &[char],
    size_dict: &HashMap<char, usize>,
) -> crate::Result<EinsumOperand<'static>> {
    let view = data.as_view();
    let result = single_tensor_einsum(&view, input_ids, output_ids, Some(size_dict))?;
    Ok(T::wrap_array(result))
}

fn eval_single(
    operand: &EinsumOperand<'_>,
    input_ids: &[char],
    output_ids: &[char],
    size_dict: &HashMap<char, usize>,
) -> crate::Result<EinsumOperand<'static>> {
    match operand {
        EinsumOperand::F64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
        EinsumOperand::C64(data) => eval_single_typed(data, input_ids, output_ids, size_dict),
    }
}

// ---------------------------------------------------------------------------
// Permutation helpers
// ---------------------------------------------------------------------------

/// Check if the transformation from input_ids to output_ids is a pure
/// permutation (same set of chars, same length, no repeated indices).
fn is_permutation_only(input_ids: &[char], output_ids: &[char]) -> bool {
    if input_ids.len() != output_ids.len() {
        return false;
    }
    // Check no repeated indices in input (linear scan)
    for (i, &id) in input_ids.iter().enumerate() {
        if input_ids[..i].contains(&id) {
            return false; // repeated index = trace, not permutation
        }
    }
    // Check all output ids appear in input
    for &id in output_ids {
        if !input_ids.contains(&id) {
            return false;
        }
    }
    true
}

/// Compute the permutation that maps input_ids ordering to output_ids ordering.
fn compute_permutation(input_ids: &[char], output_ids: &[char]) -> Vec<usize> {
    output_ids
        .iter()
        .map(|oid| input_ids.iter().position(|iid| iid == oid).unwrap())
        .collect()
}
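
// Example (illustrative): compute_permutation(&['i','j','k'], &['k','i','j'])
// returns [2, 0, 1]: output axis 0 comes from input axis 2, output axis 1
// from input axis 0, and output axis 2 from input axis 1.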

// ---------------------------------------------------------------------------
// Execute omeco NestedEinsum tree
// ---------------------------------------------------------------------------

/// Execute an omeco-optimized contraction tree by contracting pairs according
/// to the tree structure.
///
/// `children` is a Vec of Option-wrapped (operand, ids) pairs. The omeco
/// `NestedEinsum::Leaf` variant references children by index; we `.take()`
/// each entry to move ownership out exactly once.
fn execute_nested<'a>(
    nested: &omeco::NestedEinsum<char>,
    children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
    match nested {
        omeco::NestedEinsum::Leaf { tensor_index } => {
            let slot = children.get_mut(*tensor_index).ok_or_else(|| {
                crate::EinsumError::Internal(format!(
                    "optimizer referenced child index {} out of bounds",
                    tensor_index
                ))
            })?;
            let (op, ids) = slot.take().ok_or_else(|| {
                crate::EinsumError::Internal(format!(
                    "child operand {} was already consumed",
                    tensor_index
                ))
            })?;
            Ok((op, ids))
        }
        omeco::NestedEinsum::Node { args, eins } => {
            if args.len() != 2 {
                return Err(crate::EinsumError::Internal(format!(
                    "optimizer produced non-binary node with {} children",
                    args.len()
                )));
            }
            let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
            let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
            let output_ids: Vec<char> = eins.iy.clone();
            let result = eval_pair(
                left,
                &left_ids,
                right,
                &right_ids,
                &output_ids,
                pool,
                size_dict,
            )?;
            Ok((result, output_ids))
        }
    }
}

/// Execute an omeco-optimized contraction tree, writing the root contraction
/// directly into a user-provided output buffer.
///
/// Inner (non-root) contractions use normal `execute_nested` / `eval_pair`.
/// Only the root `Node`'s contraction is written directly into `output`.
fn execute_nested_into<'a, T: EinsumScalar>(
    nested: &omeco::NestedEinsum<char>,
    children: &mut Vec<Option<(EinsumOperand<'a>, Vec<char>)>>,
    output: StridedViewMut<T>,
    output_ids: &[char],
    alpha: T,
    beta: T,
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<()> {
    match nested {
        omeco::NestedEinsum::Node { args, eins: _ } => {
            if args.len() != 2 {
                return Err(crate::EinsumError::Internal(format!(
                    "optimizer produced non-binary node with {} children",
                    args.len()
                )));
            }
            // Evaluate children normally (they allocate temporaries)
            let (left, left_ids) = execute_nested(&args[0], children, pool, size_dict)?;
            let (right, right_ids) = execute_nested(&args[1], children, pool, size_dict)?;
            // Root contraction writes directly into user's output
            eval_pair_into(
                left, &left_ids, right, &right_ids, output, output_ids, alpha, beta,
            )
        }
        omeco::NestedEinsum::Leaf { tensor_index } => {
            // Root is a single leaf — extract and accumulate into output
            let slot = children.get_mut(*tensor_index).ok_or_else(|| {
                crate::EinsumError::Internal(format!(
                    "optimizer referenced child index {} out of bounds",
                    tensor_index
                ))
            })?;
            let (op, op_ids) = slot.take().ok_or_else(|| {
                crate::EinsumError::Internal(format!(
                    "child operand {} was already consumed",
                    tensor_index
                ))
            })?;
            let data = T::extract_data(op)?;
            let arr = data.into_array();
            // Permute if needed
            if op_ids != output_ids {
                let perm = compute_permutation(&op_ids, output_ids);
                let permuted = arr.permuted(&perm)?;
                accumulate_into(&mut { output }, &permuted, alpha, beta)?;
            } else {
                accumulate_into(&mut { output }, &arr, alpha, beta)?;
            }
            Ok(())
        }
    }
}

// ---------------------------------------------------------------------------
// Recursive evaluation
// ---------------------------------------------------------------------------

/// Recursively evaluate an `EinsumNode`, returning the result operand and
/// the index chars labelling its axes.
///
/// Leaf nodes return borrowed views directly (no copy). Contract nodes
/// always produce freshly allocated results (`'static` coerced to `'a`).
///
/// `needed_ids` tells this node which indices the caller needs in the result.
/// For the root call this is the final output indices of the einsum.
fn eval_node<'a>(
    node: &EinsumNode,
    operands: &mut Vec<Option<EinsumOperand<'a>>>,
    needed_ids: &[char],
    pool: &mut BufferPool,
    size_dict: &HashMap<char, usize>,
) -> crate::Result<(EinsumOperand<'a>, Vec<char>)> {
    match node {
        EinsumNode::Leaf { ids, tensor_index } => {
            let found = operands.len();
            let slot = operands.get_mut(*tensor_index).ok_or_else(|| {
                crate::EinsumError::OperandCountMismatch {
                    expected: tensor_index + 1,
                    found,
                }
            })?;
            let op = slot.take().ok_or_else(|| {
                crate::EinsumError::Internal(format!(
                    "operand {} was already consumed",
                    tensor_index
                ))
            })?;
            // Return borrowed view directly — no to_owned_static() copy.
            Ok((op, ids.clone()))
        }
        EinsumNode::Contract { args } => {
            // Determine which indices this Contract node should output.
            let node_output_ids = compute_contract_output_ids(args, needed_ids);

            match args.len() {
                0 => unreachable!("empty Contract node"),
                1 => {
                    // Single-tensor operation.
                    let child_needed = compute_child_needed_ids(&node_output_ids, 0, args);
                    let (child_op, child_ids) =
                        eval_node(&args[0], operands, &child_needed, pool, size_dict)?;

                    // Identity passthrough: no allocation needed.
                    if child_ids == node_output_ids {
                        return Ok((child_op, node_output_ids));
                    }

                    // Permutation-only passthrough: metadata reorder, no data copy.
                    if is_permutation_only(&child_ids, &node_output_ids) {
                        let perm = compute_permutation(&child_ids, &node_output_ids);
                        return Ok((child_op.permuted(&perm)?, node_output_ids));
                    }

                    // General case: trace, reduction, repeat, duplicate, etc.
                    let result = eval_single(&child_op, &child_ids, &node_output_ids, size_dict)?;
                    Ok((result, node_output_ids))
                }
                2 => {
                    // Binary contraction.
                    let left_needed = compute_child_needed_ids(&node_output_ids, 0, args);
                    let right_needed = compute_child_needed_ids(&node_output_ids, 1, args);
                    let (left, left_ids) =
                        eval_node(&args[0], operands, &left_needed, pool, size_dict)?;
                    let (right, right_ids) =
                        eval_node(&args[1], operands, &right_needed, pool, size_dict)?;
                    let result = eval_pair(
                        left,
                        &left_ids,
                        right,
                        &right_ids,
                        &node_output_ids,
                        pool,
                        size_dict,
                    )?;
                    Ok((result, node_output_ids))
                }
                _ => {
                    // 3+ children: use omeco greedy optimizer to find
                    // an efficient pairwise contraction order.

                    // 1. Evaluate all children to get their operands and ids
                    let mut children: Vec<Option<(EinsumOperand<'a>, Vec<char>)>> = Vec::new();
                    for (i, arg) in args.iter().enumerate() {
                        let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
                        let (op, ids) = eval_node(arg, operands, &child_needed, pool, size_dict)?;
                        children.push(Some((op, ids)));
                    }

                    // 2. Build dimension sizes map from evaluated operands
                    let mut dim_sizes: HashMap<char, usize> = HashMap::new();
                    for child_opt in &children {
                        if let Some((op, ids)) = child_opt {
                            for (j, &id) in ids.iter().enumerate() {
                                dim_sizes.insert(id, op.dims()[j]);
                            }
                        }
                    }

                    // 3. Build omeco EinCode from child ids and node output ids
                    let input_ids: Vec<Vec<char>> = children
                        .iter()
                        .map(|c| c.as_ref().unwrap().1.clone())
                        .collect();
                    let code = omeco::EinCode::new(input_ids, node_output_ids.clone());

                    // 4. Optimize using omeco greedy method
                    let optimizer = omeco::GreedyMethod::default();
                    let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
                        .ok_or_else(|| {
                            crate::EinsumError::Internal(
                                "optimizer failed to produce a plan".into(),
                            )
                        })?;

                    // 5. Execute the nested contraction tree
                    let (result, result_ids) =
                        execute_nested(&nested, &mut children, pool, size_dict)?;
                    Ok((result, result_ids))
                }
            }
        }
    }
}

// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------

impl EinsumCode {
    /// Evaluate the einsum contraction tree with the given operands.
    ///
    /// Borrowed view operands are propagated through the tree without copying.
    /// The result lifetime matches the input operand lifetime: if all inputs
    /// are owned (`'static`), the result is also `'static`.
    ///
    /// Pass `size_dict` to specify sizes for output indices not present in any
    /// input (generative outputs like `"->ii"` or `"i->ij"`).
    pub fn evaluate<'a>(
        &self,
        operands: Vec<EinsumOperand<'a>>,
        size_dict: Option<&HashMap<char, usize>>,
    ) -> crate::Result<EinsumOperand<'a>> {
        let expected = leaf_count(&self.root);
        if operands.len() != expected {
            return Err(crate::EinsumError::OperandCountMismatch {
                expected,
                found: operands.len(),
            });
        }

        let mut ops: Vec<Option<EinsumOperand<'a>>> = operands.into_iter().map(Some).collect();
        let mut pool = BufferPool::new();

        // Build unified size_dict: operand-inferred sizes + user-provided overrides
        let mut unified = build_dim_map(&self.root, &ops);
        if let Some(sd) = size_dict {
            merge_size_dict(&mut unified, sd)?;
        }

        let (result, result_ids) =
            eval_node(&self.root, &mut ops, &self.output_ids, &mut pool, &unified)?;

        // If the result ids already match the desired output, we're done.
        if result_ids == self.output_ids {
            return Ok(result);
        }

        // Permutation-only: reorder metadata, no data copy.
        if is_permutation_only(&result_ids, &self.output_ids) {
            let perm = compute_permutation(&result_ids, &self.output_ids);
            return Ok(result.permuted(&perm)?);
        }

        // General fallback: reduce/trace/repeat/duplicate to match the final output_ids.
        let adjusted = eval_single(&result, &result_ids, &self.output_ids, &unified)?;
        Ok(adjusted)
    }
}

fn leaf_count(node: &EinsumNode) -> usize {
    match node {
        EinsumNode::Leaf { .. } => 1,
        EinsumNode::Contract { args } => args.iter().map(leaf_count).sum(),
    }
}

/// Build a dimension map from operands and the parsed tree.
///
/// Maps each index char to its dimension size by walking the tree's Leaf nodes
/// and matching them to the corresponding operands.
fn build_dim_map(
    node: &EinsumNode,
    operands: &[Option<EinsumOperand<'_>>],
) -> HashMap<char, usize> {
    let mut dim_map = HashMap::new();
    build_dim_map_inner(node, operands, &mut dim_map);
    dim_map
}

/// Merge user-provided size_dict into the unified map.
///
/// Returns an error if a label appears in both maps with different sizes.
fn merge_size_dict(
    unified: &mut HashMap<char, usize>,
    user: &HashMap<char, usize>,
) -> crate::Result<()> {
    for (&label, &size) in user {
        if let Some(&existing) = unified.get(&label) {
            if existing != size {
                return Err(crate::EinsumError::DimensionMismatch {
                    axis: label.to_string(),
                    dim_a: existing,
                    dim_b: size,
                });
            }
        } else {
            unified.insert(label, size);
        }
    }
    Ok(())
}

fn build_dim_map_inner(
    node: &EinsumNode,
    operands: &[Option<EinsumOperand<'_>>],
    dim_map: &mut HashMap<char, usize>,
) {
    match node {
        EinsumNode::Leaf { ids, tensor_index } => {
            if let Some(Some(op)) = operands.get(*tensor_index) {
                for (i, &id) in ids.iter().enumerate() {
                    dim_map.insert(id, op.dims()[i]);
                }
            }
        }
        EinsumNode::Contract { args } => {
            for arg in args {
                build_dim_map_inner(arg, operands, dim_map);
            }
        }
    }
}

impl EinsumCode {
    /// Evaluate the einsum contraction tree, writing the result directly into
    /// a user-provided output buffer with alpha/beta scaling.
    ///
    /// `output = alpha * einsum(operands) + beta * output`
    ///
    /// The output element type `T` must match the computation: use `f64` when
    /// all operands are real, `Complex64` when any operand is complex.
840    /// If `T = Complex64`, real operands are promoted automatically.
841    ///
842    /// Pass `size_dict` to specify sizes for output indices not present in any
843    /// input (generative outputs like `"->ii"` or `"i->ij"`).
844    pub fn evaluate_into<T: EinsumScalar>(
845        &self,
846        operands: Vec<EinsumOperand<'_>>,
847        mut output: StridedViewMut<T>,
848        alpha: T,
849        beta: T,
850        size_dict: Option<&HashMap<char, usize>>,
851    ) -> crate::Result<()> {
852        let expected = leaf_count(&self.root);
853        if operands.len() != expected {
854            return Err(crate::EinsumError::OperandCountMismatch {
855                expected,
856                found: operands.len(),
857            });
858        }
859
860        // Validate output type compatibility
861        let mut ops: Vec<Option<EinsumOperand<'_>>> = operands.into_iter().map(Some).collect();
862        T::validate_operands(&ops)?;
863
864        // Build unified size_dict: operand-inferred sizes + user-provided overrides
865        let mut unified = build_dim_map(&self.root, &ops);
866        if let Some(sd) = size_dict {
867            merge_size_dict(&mut unified, sd)?;
868        }
869
870        // Compute expected output shape
871        let expected_dims = out_dims_from_map(&unified, &self.output_ids, &unified)?;
872        if output.dims() != expected_dims.as_slice() {
873            return Err(crate::EinsumError::OutputShapeMismatch {
874                expected: expected_dims,
875                got: output.dims().to_vec(),
876            });
877        }
878
879        let mut pool = BufferPool::new();
880
881        match &self.root {
882            EinsumNode::Leaf { ids, tensor_index } => {
883                // Single operand: extract, permute/trace, accumulate
884                let op = ops[*tensor_index].take().ok_or_else(|| {
885                    crate::EinsumError::Internal("operand already consumed".into())
886                })?;
887                let single_result = eval_single(&op, ids, &self.output_ids, &unified)?;
888                let data = T::extract_data(single_result)?;
889                accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
890            }
891            EinsumNode::Contract { args } => match args.len() {
892                0 => unreachable!("empty Contract node"),
893                1 => {
894                    // Single child: evaluate, then accumulate
895                    let child_needed = compute_child_needed_ids(&self.output_ids, 0, args);
896                    let (child_op, child_ids) =
897                        eval_node(&args[0], &mut ops, &child_needed, &mut pool, &unified)?;
898
899                    if child_ids == self.output_ids {
900                        // Identity: just accumulate
901                        let data = T::extract_data(child_op)?;
902                        accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
903                    } else if is_permutation_only(&child_ids, &self.output_ids) {
904                        // Permutation: permute the data, then accumulate
905                        let perm = compute_permutation(&child_ids, &self.output_ids);
906                        let data = T::extract_data(child_op)?;
907                        let arr = data.into_array();
908                        let permuted = arr.permuted(&perm)?;
909                        accumulate_into(&mut output, &permuted, alpha, beta)?;
910                    } else {
911                        // General: trace/reduction/repeat/duplicate
912                        let result =
913                            eval_single(&child_op, &child_ids, &self.output_ids, &unified)?;
914                        let data = T::extract_data(result)?;
915                        accumulate_into(&mut output, &data.into_array(), alpha, beta)?;
916                    }
917                }
918                2 => {
919                    // Binary contraction: write directly into output
920                    let left_needed = compute_child_needed_ids(&self.output_ids, 0, args);
921                    let right_needed = compute_child_needed_ids(&self.output_ids, 1, args);
922                    let (left, left_ids) =
923                        eval_node(&args[0], &mut ops, &left_needed, &mut pool, &unified)?;
924                    let (right, right_ids) =
925                        eval_node(&args[1], &mut ops, &right_needed, &mut pool, &unified)?;
926                    eval_pair_into(
927                        left,
928                        &left_ids,
929                        right,
930                        &right_ids,
931                        output,
932                        &self.output_ids,
933                        alpha,
934                        beta,
935                    )?;
936                }
937                _ => {
938                    // 3+ children: use omeco, final contraction into output
939                    let node_output_ids = compute_contract_output_ids(args, &self.output_ids);
940
941                    let mut children: Vec<Option<(EinsumOperand<'_>, Vec<char>)>> = Vec::new();
942                    for (i, arg) in args.iter().enumerate() {
943                        let child_needed = compute_child_needed_ids(&node_output_ids, i, args);
944                        let (op, ids) =
945                            eval_node(arg, &mut ops, &child_needed, &mut pool, &unified)?;
946                        children.push(Some((op, ids)));
947                    }
948
949                    let mut dim_sizes: HashMap<char, usize> = HashMap::new();
950                    for child_opt in &children {
951                        if let Some((op, ids)) = child_opt {
952                            for (j, &id) in ids.iter().enumerate() {
953                                dim_sizes.insert(id, op.dims()[j]);
954                            }
955                        }
956                    }
957
958                    let input_ids: Vec<Vec<char>> = children
959                        .iter()
960                        .map(|c| c.as_ref().unwrap().1.clone())
961                        .collect();
962                    let code = omeco::EinCode::new(input_ids, self.output_ids.clone());
963
964                    let optimizer = omeco::GreedyMethod::default();
965                    let nested = omeco::CodeOptimizer::optimize(&optimizer, &code, &dim_sizes)
966                        .ok_or_else(|| {
967                            crate::EinsumError::Internal(
968                                "optimizer failed to produce a plan".into(),
969                            )
970                        })?;
971
972                    execute_nested_into(
973                        &nested,
974                        &mut children,
975                        output,
976                        &self.output_ids,
977                        alpha,
978                        beta,
979                        &mut pool,
980                        &unified,
981                    )?;
982                }
983            },
984        }
985
986        Ok(())
987    }
988}
989
990// ---------------------------------------------------------------------------
991// Tests
992// ---------------------------------------------------------------------------
993
994#[cfg(test)]
995mod tests {
996    use super::*;
997    use crate::parse::parse_einsum;
998    use approx::assert_abs_diff_eq;
999    use strided_view::{row_major_strides, StridedArray};
1000
1001    fn make_f64(dims: &[usize], data: Vec<f64>) -> EinsumOperand<'static> {
1002        let strides = row_major_strides(dims);
1003        StridedArray::from_parts(data, dims, &strides, 0)
1004            .unwrap()
1005            .into()
1006    }
1007
1008    #[test]
1009    fn test_matmul() {
1010        let code = parse_einsum("ij,jk->ik").unwrap();
1011        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1012        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1013        let result = code.evaluate(vec![a, b], None).unwrap();
1014        match result {
1015            EinsumOperand::F64(data) => {
1016                let arr = data.as_array();
1017                assert_eq!(arr.dims(), &[2, 2]);
1018                assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
1019                assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
1020                assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
1021                assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
1022            }
1023            _ => panic!("expected F64"),
1024        }
1025    }
1026
1027    #[test]
1028    fn test_nested_three_tensor() {
1029        let code = parse_einsum("(ij,jk),kl->il").unwrap();
1030        // A = [[1,0],[0,1]] (identity), B = [[1,2],[3,4]], C = [[5,6],[7,8]]
1031        let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
1032        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
1033        let c = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
1034        // A*B = B, B*C = [[19,22],[43,50]]
1035        let result = code.evaluate(vec![a, b, c], None).unwrap();
1036        match result {
1037            EinsumOperand::F64(data) => {
1038                let arr = data.as_array();
1039                assert_eq!(arr.dims(), &[2, 2]);
1040                assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0);
1041                assert_abs_diff_eq!(arr.get(&[0, 1]), 22.0);
1042                assert_abs_diff_eq!(arr.get(&[1, 0]), 43.0);
1043                assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0);
1044            }
1045            _ => panic!("expected F64"),
1046        }
1047    }
1048
1049    #[test]
1050    fn test_outer_product() {
1051        let code = parse_einsum("i,j->ij").unwrap();
1052        let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
1053        let b = make_f64(&[2], vec![10.0, 20.0]);
1054        let result = code.evaluate(vec![a, b], None).unwrap();
1055        match result {
1056            EinsumOperand::F64(data) => {
1057                let arr = data.as_array();
1058                assert_eq!(arr.dims(), &[3, 2]);
1059                assert_abs_diff_eq!(arr.get(&[0, 0]), 10.0);
1060                assert_abs_diff_eq!(arr.get(&[2, 1]), 60.0);
1061            }
1062            _ => panic!("expected F64"),
1063        }
1064    }
1065
1066    #[test]
1067    fn test_dot_product() {
1068        let code = parse_einsum("i,i->").unwrap();
1069        let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
1070        let b = make_f64(&[3], vec![4.0, 5.0, 6.0]);
1071        let result = code.evaluate(vec![a, b], None).unwrap();
1072        match result {
1073            EinsumOperand::F64(data) => {
1074                // 1*4 + 2*5 + 3*6 = 32
1075                assert_abs_diff_eq!(data.as_array().data()[0], 32.0);
1076            }
1077            _ => panic!("expected F64"),
1078        }
1079    }
1080
1081    #[test]
1082    fn test_single_tensor_permute() {
1083        let code = parse_einsum("ij->ji").unwrap();
1084        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1085        let result = code.evaluate(vec![a], None).unwrap();
1086        match result {
1087            EinsumOperand::F64(data) => {
1088                let arr = data.as_array();
1089                assert_eq!(arr.dims(), &[3, 2]);
1090                assert_abs_diff_eq!(arr.get(&[0, 0]), 1.0);
1091                assert_abs_diff_eq!(arr.get(&[0, 1]), 4.0);
1092            }
1093            _ => panic!("expected F64"),
1094        }
1095    }
1096
1097    #[test]
1098    fn test_single_tensor_trace() {
1099        let code = parse_einsum("ii->").unwrap();
1100        let a = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
1101        let result = code.evaluate(vec![a], None).unwrap();
1102        match result {
1103            EinsumOperand::F64(data) => {
1104                assert_abs_diff_eq!(data.as_array().data()[0], 6.0);
1105            }
1106            _ => panic!("expected F64"),
1107        }
1108    }
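
    // Illustrative sketch: a pure reduction "ij->j" is expected to take the
    // single-tensor fallback in `evaluate` rather than the permutation fast
    // path. Expected values assume row-major input as in `make_f64`.
    #[test]
    fn test_single_tensor_sum_reduction() {
        let code = parse_einsum("ij->j").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = code.evaluate(vec![a], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[3]);
                // Column sums of [[1,2,3],[4,5,6]]
                assert_abs_diff_eq!(arr.get(&[0]), 5.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1]), 7.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[2]), 9.0, epsilon = 1e-10);
            }
            _ => panic!("expected F64"),
        }
    }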

    #[test]
    fn test_three_tensor_flat_omeco() {
        // ij,jk,kl->il -- flat 3-tensor, should use omeco
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
        // c = identity; AB = [[4,2],[10,5]], AB*I = [[4,2],[10,5]]
        let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let result = code.evaluate(vec![a, b, c], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[2, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 4.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[0, 1]), 2.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1, 0]), 10.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1, 1]), 5.0, epsilon = 1e-10);
            }
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn test_four_tensor_flat_omeco() {
        // ij,jk,kl,lm->im -- 4-tensor chain
        let code = parse_einsum("ij,jk,kl,lm->im").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]); // identity
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]); // identity
        let d = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        // I*B = B, B*I = B, B*D = [[19,22],[43,50]]
        let result = code.evaluate(vec![a, b, c, d], None).unwrap();
        match result {
            EinsumOperand::F64(data) => {
                let arr = data.as_array();
                assert_eq!(arr.dims(), &[2, 2]);
                assert_abs_diff_eq!(arr.get(&[0, 0]), 19.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1, 1]), 50.0, epsilon = 1e-10);
            }
            _ => panic!("expected F64"),
        }
    }
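
    // Illustrative sketch: mixed f64/c64 operands should be promoted to c64
    // by `eval_pair`, so `evaluate` is expected to return a C64 result here.
    #[test]
    fn test_mixed_types_promote_to_c64() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c64 = |r| Complex64::new(r, 0.0);
        let b_data = vec![c64(5.0), c64(6.0), c64(7.0), c64(8.0)];
        let strides = row_major_strides(&[2, 2]);
        let b = EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(b_data, &[2, 2], &strides, 0).unwrap(),
        ));
        let result = code.evaluate(vec![a, b], None).unwrap();
        match result {
            EinsumOperand::C64(data) => {
                let arr = data.as_array();
                assert_abs_diff_eq!(arr.get(&[0, 0]).re, 19.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[1, 1]).re, 50.0, epsilon = 1e-10);
                assert_abs_diff_eq!(arr.get(&[0, 0]).im, 0.0, epsilon = 1e-10);
            }
            _ => panic!("expected C64"),
        }
    }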

    #[test]
    fn test_orphan_output_axis_returns_error() {
        let code = parse_einsum("ij,jk->iz").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let err = code.evaluate(vec![a, b], None).unwrap_err();
        assert!(matches!(err, crate::EinsumError::OrphanOutputAxis(ref s) if s == "z"));
    }

    #[test]
    fn test_operand_count_mismatch_too_few() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let err = code.evaluate(vec![a], None).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OperandCountMismatch {
                expected: 2,
                found: 1
            }
        ));
    }

    #[test]
    fn test_operand_count_mismatch_too_many() {
        let code = parse_einsum("ij->ji").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let err = code.evaluate(vec![a, b], None).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OperandCountMismatch {
                expected: 1,
                found: 2
            }
        ));
    }
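
    // Illustrative unit tests for the private id/permutation helpers above;
    // the constructed `EinsumNode` values assume the Leaf/Contract fields are
    // visible from this module (they are pattern-matched in `eval_node`).
    #[test]
    fn test_permutation_helpers() {
        assert!(is_permutation_only(&['i', 'j', 'k'], &['k', 'i', 'j']));
        // Repeated input index (trace) is not a pure permutation.
        assert!(!is_permutation_only(&['i', 'i'], &['i', 'i']));
        // Length mismatch (reduction) is not a pure permutation.
        assert!(!is_permutation_only(&['i', 'j'], &['i']));
        assert_eq!(compute_permutation(&['i', 'j'], &['j', 'i']), vec![1, 0]);
        assert_eq!(
            compute_permutation(&['i', 'j', 'k'], &['k', 'i', 'j']),
            vec![2, 0, 1]
        );
    }

    #[test]
    fn test_child_needed_ids_keeps_shared_index() {
        // Two leaves "ij" and "jk" contracted to "ik": each child must keep
        // the shared index 'j' in addition to the requested output indices.
        let args = vec![
            EinsumNode::Leaf {
                ids: vec!['i', 'j'],
                tensor_index: 0,
            },
            EinsumNode::Leaf {
                ids: vec!['j', 'k'],
                tensor_index: 1,
            },
        ];
        assert_eq!(
            compute_contract_output_ids(&args, &['i', 'k']),
            vec!['i', 'k']
        );
        assert_eq!(
            compute_child_needed_ids(&['i', 'k'], 0, &args),
            vec!['i', 'k', 'j']
        );
        assert_eq!(
            compute_child_needed_ids(&['i', 'k'], 1, &args),
            vec!['i', 'k', 'j']
        );
    }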

    // -----------------------------------------------------------------------
    // evaluate_into tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_into_matmul() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![a, b], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.get(&[0, 0]), 19.0);
        assert_abs_diff_eq!(c.get(&[0, 1]), 22.0);
        assert_abs_diff_eq!(c.get(&[1, 0]), 43.0);
        assert_abs_diff_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_into_matmul_alpha_beta() {
        // C = 2 * A*B + 3 * C_old
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        // A*B = [[19,22],[43,50]]
        // C_old = [[1,1],[1,1]]
        // result = 2*[[19,22],[43,50]] + 3*[[1,1],[1,1]] = [[41,47],[89,103]]
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);
        for v in c.data_mut().iter_mut() {
            *v = 1.0;
        }
        code.evaluate_into(vec![a, b], c.view_mut(), 2.0, 3.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.get(&[0, 0]), 41.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c.get(&[0, 1]), 47.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c.get(&[1, 0]), 89.0, epsilon = 1e-10);
        assert_abs_diff_eq!(c.get(&[1, 1]), 103.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_single_tensor_permute() {
        let code = parse_einsum("ij->ji").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut c = StridedArray::<f64>::col_major(&[3, 2]);
        code.evaluate_into(vec![a], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_eq!(c.dims(), &[3, 2]);
        assert_abs_diff_eq!(c.get(&[0, 0]), 1.0);
        assert_abs_diff_eq!(c.get(&[0, 1]), 4.0);
        assert_abs_diff_eq!(c.get(&[1, 0]), 2.0);
        assert_abs_diff_eq!(c.get(&[2, 1]), 6.0);
    }
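
    // Illustrative sketch: alpha/beta accumulation through the single-operand
    // path, which exercises the beta != 0 branch of `accumulate_into`.
    // Expected values assume row-major input as in `make_f64`.
    #[test]
    fn test_into_permute_alpha_beta() {
        // out = 2 * transpose(A) + 1 * ones
        let code = parse_einsum("ij->ji").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut out = StridedArray::<f64>::col_major(&[3, 2]);
        for v in out.data_mut().iter_mut() {
            *v = 1.0;
        }
        code.evaluate_into(vec![a], out.view_mut(), 2.0, 1.0, None)
            .unwrap();
        // transpose(A) = [[1,4],[2,5],[3,6]]
        assert_abs_diff_eq!(out.get(&[0, 0]), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 9.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[2, 1]), 13.0, epsilon = 1e-10);
    }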

    #[test]
    fn test_into_single_tensor_trace() {
        let code = parse_einsum("ii->").unwrap();
        let a = make_f64(&[3, 3], vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
        let mut c = StridedArray::<f64>::col_major(&[]);
        code.evaluate_into(vec![a], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.data()[0], 6.0);
    }

    #[test]
    fn test_into_three_tensor_omeco() {
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        let a = make_f64(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = make_f64(&[3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
        let c_op = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![a, b, c_op], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 1]), 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_nested() {
        let code = parse_einsum("(ij,jk),kl->il").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 0.0, 0.0, 1.0]);
        let b = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c_op = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        code.evaluate_into(vec![a, b, c_op], out.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]), 19.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[0, 1]), 22.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 0]), 43.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out.get(&[1, 1]), 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_into_dot_product() {
        let code = parse_einsum("i,i->").unwrap();
        let a = make_f64(&[3], vec![1.0, 2.0, 3.0]);
        let b = make_f64(&[3], vec![4.0, 5.0, 6.0]);
        let mut c = StridedArray::<f64>::col_major(&[]);
        code.evaluate_into(vec![a, b], c.view_mut(), 1.0, 0.0, None)
            .unwrap();
        assert_abs_diff_eq!(c.data()[0], 32.0);
    }

    #[test]
    fn test_into_type_mismatch_f64_output_c64_input() {
        let code = parse_einsum("ij->ji").unwrap();
        let c64_data = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ];
        let strides = row_major_strides(&[2, 2]);
        let arr = StridedArray::from_parts(c64_data, &[2, 2], &strides, 0).unwrap();
        let op = EinsumOperand::C64(StridedData::Owned(arr));
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);
        let err = code
            .evaluate_into(vec![op], out.view_mut(), 1.0, 0.0, None)
            .unwrap_err();
        assert!(matches!(err, crate::EinsumError::TypeMismatch { .. }));
    }

    #[test]
    fn test_into_shape_mismatch() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = make_f64(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let mut out = StridedArray::<f64>::col_major(&[3, 3]); // wrong shape
        let err = code
            .evaluate_into(vec![a, b], out.view_mut(), 1.0, 0.0, None)
            .unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::OutputShapeMismatch { .. }
        ));
    }

    #[test]
    fn test_into_c64_output() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        let c64 = |r| Complex64::new(r, 0.0);
        let a_data = vec![c64(1.0), c64(2.0), c64(3.0), c64(4.0)];
        let b_data = vec![c64(5.0), c64(6.0), c64(7.0), c64(8.0)];
        let strides = row_major_strides(&[2, 2]);
        let a = EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(a_data, &[2, 2], &strides, 0).unwrap(),
        ));
        let b = EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(b_data, &[2, 2], &strides, 0).unwrap(),
        ));
        let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
        code.evaluate_into(
            vec![a, b],
            out.view_mut(),
            c64(1.0),
            Complex64::zero(),
            None,
        )
        .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
        assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
    }

    #[test]
    fn test_into_mixed_types_c64_output() {
        // f64 + c64 operands -> c64 output (f64 gets promoted)
        let code = parse_einsum("ij,jk->ik").unwrap();
        let a = make_f64(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let c64 = |r| Complex64::new(r, 0.0);
        let b_data = vec![c64(5.0), c64(6.0), c64(7.0), c64(8.0)];
        let strides = row_major_strides(&[2, 2]);
        let b = EinsumOperand::C64(StridedData::Owned(
            StridedArray::from_parts(b_data, &[2, 2], &strides, 0).unwrap(),
        ));
        let mut out = StridedArray::<Complex64>::col_major(&[2, 2]);
        code.evaluate_into(
            vec![a, b],
            out.view_mut(),
            c64(1.0),
            Complex64::zero(),
            None,
        )
        .unwrap();
        assert_abs_diff_eq!(out.get(&[0, 0]).re, 19.0);
        assert_abs_diff_eq!(out.get(&[1, 1]).re, 50.0);
    }
}