strided_einsum2/
lib.rs

1//! Binary Einstein summation on strided views.
2//!
3//! Provides `einsum2_into` for computing binary tensor contractions with
4//! accumulation semantics: `C = alpha * A * B + beta * C`.
5//!
6//! # Example
7//!
8//! ```
9//! use strided_view::StridedArray;
10//! use strided_einsum2::einsum2_into;
11//!
12//! // Matrix multiply: C_ik = A_ij * B_jk
13//! let a = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
14//! let b = StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
15//! let mut c = StridedArray::<f64>::row_major(&[2, 2]);
16//!
17//! einsum2_into(
18//!     c.view_mut(), &a.view(), &b.view(),
19//!     &['i', 'k'], &['i', 'j'], &['j', 'k'],
20//!     1.0, 0.0,
21//! ).unwrap();
22//! ```
23
24#[cfg(all(feature = "faer", feature = "blas"))]
25compile_error!("Features `faer` and `blas` are mutually exclusive. Use one or the other.");
26
27#[cfg(all(feature = "faer", feature = "blas-inject"))]
28compile_error!("Features `faer` and `blas-inject` are mutually exclusive.");
29
30#[cfg(all(feature = "blas", feature = "blas-inject"))]
31compile_error!("Features `blas` and `blas-inject` are mutually exclusive.");
32
33#[cfg(all(feature = "blas-inject", not(feature = "blas")))]
34extern crate cblas_inject as cblas_sys;
35#[cfg(all(feature = "blas", not(feature = "blas-inject")))]
36extern crate cblas_sys;
37
38#[cfg(all(
39    not(feature = "faer"),
40    any(
41        all(feature = "blas", not(feature = "blas-inject")),
42        all(feature = "blas-inject", not(feature = "blas"))
43    )
44))]
45pub mod bgemm_blas;
46
47#[cfg(feature = "faer")]
48/// Batched GEMM backend using the [`faer`] library.
49pub mod bgemm_faer;
50/// Batched GEMM fallback using explicit loops.
51pub mod bgemm_naive;
52/// GEMM-ready operand types and preparation functions for contiguous data.
53pub mod contiguous;
54/// Contraction planning: axis classification and permutation computation.
55pub mod plan;
56/// Trace-axis reduction (summing axes that appear only in one operand).
57pub mod trace;
58/// Shared helpers (permutation inversion, multi-index iteration, dimension fusion).
59pub mod util;
60
61/// Backend abstraction for batched GEMM dispatch.
62pub mod backend;
63
64use std::any::TypeId;
65use std::fmt::Debug;
66use std::hash::Hash;
67
68use strided_kernel::zip_map2_into;
69#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
70use strided_view::StridedArray;
71use strided_view::{Adjoint, Conj, ElementOp, ElementOpApply, StridedView, StridedViewMut};
72
73pub use strided_traits::ScalarBase;
74
75pub use backend::Backend;
76pub use plan::Einsum2Plan;
77
/// Marker trait for values that can be used as einsum axis labels.
///
/// It is never implemented by hand: the blanket impl below covers every type
/// that is `Clone + Eq + Hash + Debug` (for example `char`, integers, `String`).
pub trait AxisId: Clone + Eq + Hash + Debug {}

impl<T> AxisId for T where T: Clone + Eq + Hash + Debug {}
81
82// ScalarBase is re-exported from strided_traits (see above).
83// It no longer requires ElementOpApply, enabling custom scalar types
84// (e.g., tropical semiring) to work with Identity-only views.
85
86/// Trait alias for element types supported by einsum operations.
87///
88/// When the `faer` feature is enabled, this additionally requires `faer::ComplexField`
89/// so that the faer GEMM backend can be used.
90#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
91pub trait Scalar: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
92
93#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
94impl<T> Scalar for T where T: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
95
96/// Trait alias for element types (with `blas` or `blas-inject` feature).
97///
98/// Includes `BlasGemm` so that all `Scalar` types can be dispatched to CBLAS.
99#[cfg(all(
100    not(feature = "faer"),
101    any(
102        all(feature = "blas", not(feature = "blas-inject")),
103        all(feature = "blas-inject", not(feature = "blas"))
104    )
105))]
106pub trait Scalar: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
107
108#[cfg(all(
109    not(feature = "faer"),
110    any(
111        all(feature = "blas", not(feature = "blas-inject")),
112        all(feature = "blas-inject", not(feature = "blas"))
113    )
114))]
115impl<T> Scalar for T where T: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
116
117/// Trait alias for element types (without `faer` or BLAS features).
118#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
119pub trait Scalar: ScalarBase + ElementOpApply {}
120
121#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
122impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
123
124/// Placeholder trait definition for invalid mutually-exclusive feature combinations.
125///
126/// The crate emits `compile_error!` above for these combinations. This trait only
127/// avoids cascading type-resolution errors so users see the intended diagnostics.
128#[cfg(any(
129    all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
130    all(feature = "blas", feature = "blas-inject")
131))]
132pub trait Scalar: ScalarBase + ElementOpApply {}
133
134#[cfg(any(
135    all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
136    all(feature = "blas", feature = "blas-inject")
137))]
138impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
139
140/// Errors specific to einsum operations.
141#[derive(Debug, thiserror::Error)]
142pub enum EinsumError {
143    #[error("duplicate axis label: {0}")]
144    DuplicateAxis(String),
145    #[error("output axis {0} not found in any input")]
146    OrphanOutputAxis(String),
147    #[error("dimension mismatch for axis {axis:?}: {dim_a} vs {dim_b}")]
148    DimensionMismatch {
149        axis: String,
150        dim_a: usize,
151        dim_b: usize,
152    },
153    #[error(transparent)]
154    Strided(#[from] strided_view::StridedError),
155}
156
157/// Convenience alias for `Result<T, EinsumError>`.
158pub type Result<T> = std::result::Result<T, EinsumError>;
159
160/// Returns `true` if the given `ElementOp` type represents conjugation.
161///
162/// - `Identity` / `Transpose` → `false` (no per-element conjugation)
163/// - `Conj` / `Adjoint` → `true` (per-element conjugation needed)
164///
165/// For scalar types, `Transpose::apply(x) = x` (identity) and the dimension
166/// swap is already reflected in the view's strides/dims.  Similarly,
167/// `Adjoint::apply(x) = x.conj()` with the dimension swap in the view.
168fn op_is_conj<Op: 'static>() -> bool {
169    TypeId::of::<Op>() == TypeId::of::<Conj>() || TypeId::of::<Op>() == TypeId::of::<Adjoint>()
170}
171
172/// Binary einsum contraction: `C = alpha * contract(A, B) + beta * C`.
173///
174/// `ic`, `ia`, `ib` are axis labels for C, A, B respectively.
175/// Axes are classified as:
176/// - **batch**: in A, B, and C
177/// - **lo** (left-output): in A and C, not B
178/// - **ro** (right-output): in B and C, not A
179/// - **sum** (contraction): in A and B, not C
180///
181/// Trace axes (only in A or only in B) are detected and reduced lazily.
182pub fn einsum2_into<T: Scalar, OpA, OpB, ID: AxisId>(
183    c: StridedViewMut<T>,
184    a: &StridedView<T, OpA>,
185    b: &StridedView<T, OpB>,
186    ic: &[ID],
187    ia: &[ID],
188    ib: &[ID],
189    alpha: T,
190    beta: T,
191) -> Result<()>
192where
193    OpA: ElementOp<T> + 'static,
194    OpB: ElementOp<T> + 'static,
195{
196    // 1. Build plan
197    let plan = Einsum2Plan::new(ia, ib, ic)?;
198
199    // 2. Validate dimension consistency across operands
200    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
201
202    // 3. Reduce trace axes if present; determine conjugation flags.
203    //    When trace reduction occurs, Op is already applied during the reduction,
204    //    so conj flag is false. Otherwise, we strip the Op and pass a conj flag
205    //    to the GEMM kernel (avoiding materialization).
206    let left_trace = trace::find_trace_indices(ia, ib, ic);
207    let (a_buf, conj_a) = if !left_trace.is_empty() {
208        (Some(trace::reduce_trace_axes(a, &left_trace)?), false)
209    } else {
210        (None, op_is_conj::<OpA>())
211    };
212
213    let a_view: StridedView<T> = match a_buf.as_ref() {
214        Some(buf) => buf.view(),
215        None => StridedView::new(a.data(), a.dims(), a.strides(), a.offset())
216            .expect("strip_op_view: metadata already validated"),
217    };
218
219    let right_trace = trace::find_trace_indices(ib, ia, ic);
220    let (b_buf, conj_b) = if !right_trace.is_empty() {
221        (Some(trace::reduce_trace_axes(b, &right_trace)?), false)
222    } else {
223        (None, op_is_conj::<OpB>())
224    };
225
226    let b_view: StridedView<T> = match b_buf.as_ref() {
227        Some(buf) => buf.view(),
228        None => StridedView::new(b.data(), b.dims(), b.strides(), b.offset())
229            .expect("strip_op_view: metadata already validated"),
230    };
231
232    // 4. Dispatch to GEMM
233    #[cfg(any(
234        all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
235        all(
236            not(feature = "faer"),
237            any(
238                all(feature = "blas", not(feature = "blas-inject")),
239                all(feature = "blas-inject", not(feature = "blas"))
240            )
241        )
242    ))]
243    {
244        let conj_fn = make_conj_fn::<T>();
245        einsum2_dispatch::<T, backend::ActiveBackend, _>(
246            c, &a_view, &b_view, &plan, alpha, beta, conj_a, conj_b, conj_fn,
247        )?;
248    }
249
250    #[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
251    {
252        let a_perm = a_view.permute(&plan.left_perm)?;
253        let b_perm = b_view.permute(&plan.right_perm)?;
254        let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
255
256        if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
257            let mul_fn = move |a_val: T, b_val: T| -> T {
258                let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
259                let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
260                alpha * a_c * b_c
261            };
262            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
263            return Ok(());
264        }
265
266        bgemm_naive::bgemm_strided_into(
267            &mut c_perm,
268            &a_perm,
269            &b_perm,
270            plan.batch.len(),
271            plan.lo.len(),
272            plan.ro.len(),
273            plan.sum.len(),
274            alpha,
275            beta,
276            conj_a,
277            conj_b,
278        )?;
279    }
280
281    Ok(())
282}
283
284/// Binary einsum for custom scalar types using naive GEMM.
285///
286/// Like [`einsum2_into`] but works with any `ScalarBase` type (no `ElementOpApply` required).
287/// Uses closures `map_a` and `map_b` for per-element transformation instead of
288/// conjugation flags. Always dispatches to the naive GEMM kernel.
289///
290/// Views must use `Identity` element operations (the default).
291pub fn einsum2_naive_into<T, ID, MapA, MapB>(
292    c: StridedViewMut<T>,
293    a: &StridedView<T>,
294    b: &StridedView<T>,
295    ic: &[ID],
296    ia: &[ID],
297    ib: &[ID],
298    alpha: T,
299    beta: T,
300    map_a: MapA,
301    map_b: MapB,
302) -> Result<()>
303where
304    T: ScalarBase,
305    ID: AxisId,
306    MapA: Fn(T) -> T + strided_kernel::MaybeSync,
307    MapB: Fn(T) -> T + strided_kernel::MaybeSync,
308{
309    let plan = Einsum2Plan::new(ia, ib, ic)?;
310    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
311
312    // Reduce trace axes if present.
313    // When trace reduction occurs, map is applied via map_into before reduction,
314    // so we use identity map for GEMM. Otherwise, pass through the original map.
315    let left_trace = trace::find_trace_indices(ia, ib, ic);
316    let (a_buf, use_map_a) = if !left_trace.is_empty() {
317        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(a.dims()) };
318        strided_kernel::map_into(&mut mapped.view_mut(), a, &map_a)?;
319        let reduced = trace::reduce_trace_axes(&mapped.view(), &left_trace)?;
320        (Some(reduced), false)
321    } else {
322        (None, true)
323    };
324    let a_view: StridedView<T> = match a_buf.as_ref() {
325        Some(buf) => buf.view(),
326        None => a.clone(),
327    };
328
329    let right_trace = trace::find_trace_indices(ib, ia, ic);
330    let (b_buf, use_map_b) = if !right_trace.is_empty() {
331        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(b.dims()) };
332        strided_kernel::map_into(&mut mapped.view_mut(), b, &map_b)?;
333        let reduced = trace::reduce_trace_axes(&mapped.view(), &right_trace)?;
334        (Some(reduced), false)
335    } else {
336        (None, true)
337    };
338    let b_view: StridedView<T> = match b_buf.as_ref() {
339        Some(buf) => buf.view(),
340        None => b.clone(),
341    };
342
343    let a_perm = a_view.permute(&plan.left_perm)?;
344    let b_perm = b_view.permute(&plan.right_perm)?;
345    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
346
347    // Element-wise fast path
348    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
349        let mul_fn = move |a_val: T, b_val: T| -> T {
350            let a_c = if use_map_a { map_a(a_val) } else { a_val };
351            let b_c = if use_map_b { map_b(b_val) } else { b_val };
352            alpha * a_c * b_c
353        };
354        zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
355        return Ok(());
356    }
357
358    let final_map_a: Box<dyn Fn(T) -> T> = if use_map_a {
359        Box::new(map_a)
360    } else {
361        Box::new(|x| x)
362    };
363    let final_map_b: Box<dyn Fn(T) -> T> = if use_map_b {
364        Box::new(map_b)
365    } else {
366        Box::new(|x| x)
367    };
368
369    bgemm_naive::bgemm_strided_into_with_map(
370        &mut c_perm,
371        &a_perm,
372        &b_perm,
373        plan.batch.len(),
374        plan.lo.len(),
375        plan.ro.len(),
376        plan.sum.len(),
377        alpha,
378        beta,
379        final_map_a,
380        final_map_b,
381    )?;
382
383    Ok(())
384}
385
386/// Binary einsum with a pluggable GEMM backend.
387///
388/// Like [`einsum2_into`] but works with any `ScalarBase` type and dispatches
389/// to the caller-provided GEMM backend `B`. Views must use `Identity` element
390/// operations (the default).
391///
392/// External crates can implement [`Backend`] for custom scalar types
393/// (e.g., tropical semiring) and pass the backend here.
394pub fn einsum2_with_backend_into<T, B, ID>(
395    c: StridedViewMut<T>,
396    a: &StridedView<T>,
397    b: &StridedView<T>,
398    ic: &[ID],
399    ia: &[ID],
400    ib: &[ID],
401    alpha: T,
402    beta: T,
403) -> Result<()>
404where
405    T: ScalarBase,
406    B: Backend<T>,
407    ID: AxisId,
408{
409    let plan = Einsum2Plan::new(ia, ib, ic)?;
410    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
411
412    // Trace reduction (plain sum, Identity views)
413    let left_trace = trace::find_trace_indices(ia, ib, ic);
414    let a_buf = if !left_trace.is_empty() {
415        Some(trace::reduce_trace_axes(a, &left_trace)?)
416    } else {
417        None
418    };
419    let a_view: StridedView<T> = match a_buf.as_ref() {
420        Some(buf) => buf.view(),
421        None => a.clone(),
422    };
423
424    let right_trace = trace::find_trace_indices(ib, ia, ic);
425    let b_buf = if !right_trace.is_empty() {
426        Some(trace::reduce_trace_axes(b, &right_trace)?)
427    } else {
428        None
429    };
430    let b_view: StridedView<T> = match b_buf.as_ref() {
431        Some(buf) => buf.view(),
432        None => b.clone(),
433    };
434
435    // No conjugation for generic backend path
436    einsum2_dispatch::<T, B, _>(c, &a_view, &b_view, &plan, alpha, beta, false, false, None)
437}
438
/// Returns the conjugation function to materialize for the active backend.
///
/// Yields `Some(f)` with `f(x) = Conj::apply(x)` when the active backend's
/// `MATERIALIZES_CONJ` constant is `true`, i.e. conjugation has to be written
/// into the operand data before GEMM (e.g. CBLAS). Yields `None` otherwise,
/// which is meant for backends that handle conjugation themselves.
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
fn make_conj_fn<T: Scalar>() -> Option<fn(T) -> T> {
    if !<backend::ActiveBackend as Backend<T>>::MATERIALIZES_CONJ {
        return None;
    }
    Some(|x| Conj::apply(x))
}
461
462/// Internal GEMM dispatch, generic over backend.
463///
464/// Called after trace reduction with plain Identity views. Handles:
465/// 1. Permutation to canonical order
466/// 2. Element-wise fast path (if applicable)
467/// 3. Contiguous preparation via `prepare_input_view`
468/// 4. GEMM via `B::bgemm_contiguous_into`
469/// 5. Finalize (copy-back if needed)
470///
471/// `conj_fn` is the materialization function for backends that need conj
472/// applied to data before GEMM. Pass `None` when conj_a/conj_b are both false
473/// or when the backend handles conj natively (via flags).
474fn einsum2_dispatch<T, B, ID>(
475    c: StridedViewMut<T>,
476    a: &StridedView<T>,
477    b: &StridedView<T>,
478    plan: &Einsum2Plan<ID>,
479    alpha: T,
480    beta: T,
481    conj_a: bool,
482    conj_b: bool,
483    conj_fn: Option<fn(T) -> T>,
484) -> Result<()>
485where
486    T: ScalarBase,
487    B: Backend<T>,
488    ID: AxisId,
489{
490    // 1. Permute to canonical order
491    let a_perm = a.permute(&plan.left_perm)?;
492    let b_perm = b.permute(&plan.right_perm)?;
493    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
494
495    // 2. Fast path: element-wise (all batch, no contraction)
496    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
497        if !conj_a && !conj_b && alpha == T::one() {
498            zip_map2_into(&mut c_perm, &a_perm, &b_perm, |a_val, b_val| a_val * b_val)?;
499        } else if !conj_a && !conj_b {
500            let mul_fn = move |a_val: T, b_val: T| -> T { alpha * a_val * b_val };
501            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
502        } else {
503            let conj_fn = conj_fn.unwrap_or(|x| x);
504            let mul_fn = move |a_val: T, b_val: T| -> T {
505                let a_c = if conj_a { conj_fn(a_val) } else { a_val };
506                let b_c = if conj_b { conj_fn(b_val) } else { b_val };
507                alpha * a_c * b_c
508            };
509            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
510        }
511        return Ok(());
512    }
513
514    // 3. Prepare contiguous operands
515    let n_lo = plan.lo.len();
516    let n_ro = plan.ro.len();
517    let n_sum = plan.sum.len();
518    let use_pool = true;
519    let materialize = if B::MATERIALIZES_CONJ { conj_fn } else { None };
520
521    let a_op = contiguous::prepare_input_view(
522        &a_perm,
523        n_lo,
524        n_sum,
525        conj_a,
526        B::REQUIRES_UNIT_STRIDE,
527        use_pool,
528        materialize,
529    )?;
530    let b_op = contiguous::prepare_input_view(
531        &b_perm,
532        n_sum,
533        n_ro,
534        conj_b,
535        B::REQUIRES_UNIT_STRIDE,
536        use_pool,
537        materialize,
538    )?;
539    let mut c_op = contiguous::prepare_output_view(
540        &mut c_perm,
541        n_lo,
542        n_ro,
543        beta,
544        B::REQUIRES_UNIT_STRIDE,
545        use_pool,
546    )?;
547
548    // Compute fused dimension sizes
549    let lo_dims = &a_perm.dims()[..n_lo];
550    let sum_dims = &a_perm.dims()[n_lo..n_lo + n_sum];
551    let batch_dims = &a_perm.dims()[n_lo + n_sum..];
552    let ro_dims = &b_perm.dims()[n_sum..n_sum + n_ro];
553    let m: usize = lo_dims.iter().product::<usize>().max(1);
554    let k: usize = sum_dims.iter().product::<usize>().max(1);
555    let n: usize = ro_dims.iter().product::<usize>().max(1);
556
557    // 4. GEMM — dispatched through trait
558    B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
559
560    // 5. Finalize
561    c_op.finalize_into(&mut c_perm)?;
562
563    Ok(())
564}
565
/// Binary einsum accepting owned inputs for zero-copy optimization.
///
/// Same semantics as [`einsum2_into`] but accepts owned `StridedArray` inputs.
/// When inputs have non-contiguous strides after permutation, ownership
/// transfer avoids allocating separate buffers. For contiguous inputs,
/// the behavior is identical.
///
/// `conj_a` and `conj_b` indicate whether to conjugate elements of A/B.
/// The flags are honoured on every path, including when an operand has trace
/// axes that are summed away first.
///
/// # Errors
///
/// Returns an error when planning, dimension validation, trace reduction,
/// permutation, operand preparation or the GEMM itself fails.
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
pub fn einsum2_into_owned<T: Scalar, ID: AxisId>(
    c: StridedViewMut<T>,
    a: StridedArray<T>,
    b: StridedArray<T>,
    ic: &[ID],
    ia: &[ID],
    ib: &[ID],
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: Backend<T>,
{
    // 1. Build plan
    let plan = Einsum2Plan::new(ia, ib, ic)?;

    // 2. Validate dimensions
    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;

    // 3. Trace reduction: sum away axes that occur in only one operand.
    //    The reduction runs on the plain view of the owned array, so it does
    //    NOT conjugate anything. Conjugation commutes with summation
    //    (conj(x + y) == conj(x) + conj(y)), so the caller's flag is simply
    //    kept and applied to the reduced array afterwards. Clearing the flag
    //    here would silently drop the requested conjugation.
    let left_trace = trace::find_trace_indices(ia, ib, ic);
    let a_for_gemm = if left_trace.is_empty() {
        a
    } else {
        trace::reduce_trace_axes(&a.view(), &left_trace)?
    };

    let right_trace = trace::find_trace_indices(ib, ia, ic);
    let b_for_gemm = if right_trace.is_empty() {
        b
    } else {
        trace::reduce_trace_axes(&b.view(), &right_trace)?
    };

    // 4. Permute to canonical order (metadata-only on owned arrays)
    let a_perm = a_for_gemm.permuted(&plan.left_perm)?;
    let b_perm = b_for_gemm.permuted(&plan.right_perm)?;
    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;

    let n_lo = plan.lo.len();
    let n_ro = plan.ro.len();
    let n_sum = plan.sum.len();

    // 5. Fast path: element-wise (all batch, no contraction)
    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
        let mul_fn = move |a_val: T, b_val: T| -> T {
            let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
            let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
            alpha * a_c * b_c
        };
        zip_map2_into(&mut c_perm, &a_perm.view(), &b_perm.view(), mul_fn)?;
        return Ok(());
    }

    // 6. Extract dimension sizes BEFORE consuming arrays via prepare_input_owned
    let a_dims_perm = a_perm.dims().to_vec();
    let b_dims_perm = b_perm.dims().to_vec();

    let lo_dims = &a_dims_perm[..n_lo];
    let sum_dims = &a_dims_perm[n_lo..n_lo + n_sum];
    let batch_dims = &a_dims_perm[n_lo + n_sum..];
    let ro_dims = &b_dims_perm[n_sum..n_sum + n_ro];
    let m: usize = lo_dims.iter().product::<usize>().max(1);
    let k: usize = sum_dims.iter().product::<usize>().max(1);
    let n: usize = ro_dims.iter().product::<usize>().max(1);

    // 7. Prepare contiguous operands (owned path -- avoids extra copies).
    //    `make_conj_fn` already yields `None` unless the active backend
    //    materializes conjugation, so no further gating is needed.
    let materialize = make_conj_fn::<T>();
    let use_pool = true;
    let unit_stride = <backend::ActiveBackend as Backend<T>>::REQUIRES_UNIT_STRIDE;
    let a_op = contiguous::prepare_input_owned(
        a_perm,
        n_lo,
        n_sum,
        conj_a,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let b_op = contiguous::prepare_input_owned(
        b_perm,
        n_sum,
        n_ro,
        conj_b,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let mut c_op =
        contiguous::prepare_output_view(&mut c_perm, n_lo, n_ro, beta, unit_stride, use_pool)?;

    // 8. GEMM — dispatched through trait
    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op,
        &a_op,
        &b_op,
        batch_dims,
        m,
        n,
        k,
        alpha,
        beta,
    )?;

    // 9. Finalize
    c_op.finalize_into(&mut c_perm)?;

    Ok(())
}
710
711/// Validate that dimensions match across operands for each axis group.
712fn validate_dimensions<ID: AxisId>(
713    plan: &Einsum2Plan<ID>,
714    a_dims: &[usize],
715    b_dims: &[usize],
716    c_dims: &[usize],
717    ia: &[ID],
718    ib: &[ID],
719    ic: &[ID],
720) -> Result<()> {
721    let find_dim = |labels: &[ID], dims: &[usize], id: &ID| -> usize {
722        labels
723            .iter()
724            .position(|x| x == id)
725            .map(|i| dims[i])
726            .unwrap()
727    };
728
729    // Batch: must match in A, B, and C
730    for id in &plan.batch {
731        let da = find_dim(ia, a_dims, id);
732        let db = find_dim(ib, b_dims, id);
733        let dc = find_dim(ic, c_dims, id);
734        if da != db || da != dc {
735            return Err(EinsumError::DimensionMismatch {
736                axis: format!("{:?}", id),
737                dim_a: da,
738                dim_b: db,
739            });
740        }
741    }
742
743    // Sum: must match in A and B
744    for id in &plan.sum {
745        let da = find_dim(ia, a_dims, id);
746        let db = find_dim(ib, b_dims, id);
747        if da != db {
748            return Err(EinsumError::DimensionMismatch {
749                axis: format!("{:?}", id),
750                dim_a: da,
751                dim_b: db,
752            });
753        }
754    }
755
756    // LO: must match in A and C
757    for id in &plan.lo {
758        let da = find_dim(ia, a_dims, id);
759        let dc = find_dim(ic, c_dims, id);
760        if da != dc {
761            return Err(EinsumError::DimensionMismatch {
762                axis: format!("{:?}", id),
763                dim_a: da,
764                dim_b: dc,
765            });
766        }
767    }
768
769    // RO: must match in B and C
770    for id in &plan.ro {
771        let db = find_dim(ib, b_dims, id);
772        let dc = find_dim(ic, c_dims, id);
773        if db != dc {
774            return Err(EinsumError::DimensionMismatch {
775                axis: format!("{:?}", id),
776                dim_a: db,
777                dim_b: dc,
778            });
779        }
780    }
781
782    Ok(())
783}
784
785#[cfg(test)]
786mod tests {
787    use super::*;
788    use strided_view::StridedArray;
789
790    #[test]
791    fn test_matmul_ij_jk_ik() {
792        // C_ik = A_ij * B_jk
793        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
794            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
795        });
796        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
797            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
798        });
799        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
800
801        einsum2_into(
802            c.view_mut(),
803            &a.view(),
804            &b.view(),
805            &['i', 'k'],
806            &['i', 'j'],
807            &['j', 'k'],
808            1.0,
809            0.0,
810        )
811        .unwrap();
812
813        assert_eq!(c.get(&[0, 0]), 19.0);
814        assert_eq!(c.get(&[0, 1]), 22.0);
815        assert_eq!(c.get(&[1, 0]), 43.0);
816        assert_eq!(c.get(&[1, 1]), 50.0);
817    }
818
819    #[test]
820    fn test_matmul_rect() {
821        // A: 2x3, B: 3x4, C: 2x4
822        let a =
823            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
824        let b =
825            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
826        let mut c = StridedArray::<f64>::row_major(&[2, 4]);
827
828        einsum2_into(
829            c.view_mut(),
830            &a.view(),
831            &b.view(),
832            &['i', 'k'],
833            &['i', 'j'],
834            &['j', 'k'],
835            1.0,
836            0.0,
837        )
838        .unwrap();
839
840        // A = [[1,2,3],[4,5,6]], B = [[1,2,3,4],[5,6,7,8],[9,10,11,12]]
841        assert_eq!(c.get(&[0, 0]), 38.0);
842        assert_eq!(c.get(&[1, 3]), 128.0);
843    }
844
845    #[test]
846    fn test_batched_matmul() {
847        // C_bik = A_bij * B_bjk
848        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
849            (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
850        });
851        let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
852            (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
853        });
854        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
855
856        einsum2_into(
857            c.view_mut(),
858            &a.view(),
859            &b.view(),
860            &['b', 'i', 'k'],
861            &['b', 'i', 'j'],
862            &['b', 'j', 'k'],
863            1.0,
864            0.0,
865        )
866        .unwrap();
867
868        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
869        // C0[0,0] = 1*1+2*3+3*5 = 22
870        assert_eq!(c.get(&[0, 0, 0]), 22.0);
871    }
872
873    #[test]
874    fn test_batched_matmul_col_major_output() {
875        // C_bik = A_bij * B_bjk with col-major output (same layout as opteinsum)
876        let a_data = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
877        let b_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
878        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
879        let b = StridedArray::<f64>::from_parts(b_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
880        let mut c = StridedArray::<f64>::col_major(&[2, 2, 2]);
881
882        einsum2_into(
883            c.view_mut(),
884            &a.view(),
885            &b.view(),
886            &['b', 'i', 'k'],
887            &['b', 'i', 'j'],
888            &['b', 'j', 'k'],
889            1.0,
890            0.0,
891        )
892        .unwrap();
893
894        // batch 0: I * [[1,2],[3,4]] = [[1,2],[3,4]]
895        assert_eq!(c.get(&[0, 0, 0]), 1.0);
896        assert_eq!(c.get(&[0, 0, 1]), 2.0);
897        assert_eq!(c.get(&[0, 1, 0]), 3.0);
898        assert_eq!(c.get(&[0, 1, 1]), 4.0);
899        // batch 1: 2I * [[5,6],[7,8]] = [[10,12],[14,16]]
900        assert_eq!(c.get(&[1, 0, 0]), 10.0);
901        assert_eq!(c.get(&[1, 1, 1]), 16.0);
902    }
903
904    #[test]
905    fn test_outer_product() {
906        // C_ij = A_i * B_j
907        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
908        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
909        let mut c = StridedArray::<f64>::row_major(&[3, 4]);
910
911        einsum2_into(
912            c.view_mut(),
913            &a.view(),
914            &b.view(),
915            &['i', 'j'],
916            &['i'],
917            &['j'],
918            1.0,
919            0.0,
920        )
921        .unwrap();
922
923        assert_eq!(c.get(&[0, 0]), 1.0);
924        assert_eq!(c.get(&[2, 3]), 12.0);
925    }
926
927    #[test]
928    fn test_dot_product() {
929        // C = A_i * B_i (scalar output)
930        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
931        let b = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
932        let mut c = StridedArray::<f64>::row_major(&[]);
933
934        einsum2_into(
935            c.view_mut(),
936            &a.view(),
937            &b.view(),
938            &[] as &[char],
939            &['i'],
940            &['i'],
941            1.0,
942            0.0,
943        )
944        .unwrap();
945
946        // 1*1 + 2*2 + 3*3 = 14
947        assert_eq!(c.get(&[]), 14.0);
948    }
949
950    #[test]
951    fn test_alpha_beta() {
952        // C = 2*A*B + 3*C_old
953        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
954            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
955        });
956        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
957            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
958        });
959        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
960            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
961        });
962
963        einsum2_into(
964            c.view_mut(),
965            &a.view(),
966            &b.view(),
967            &['i', 'k'],
968            &['i', 'j'],
969            &['j', 'k'],
970            2.0,
971            3.0,
972        )
973        .unwrap();
974
975        // C = 2*I*B + 3*C_old
976        assert_eq!(c.get(&[0, 0]), 32.0); // 2*1 + 3*10
977        assert_eq!(c.get(&[1, 1]), 128.0); // 2*4 + 3*40
978    }
979
980    #[test]
981    fn test_transposed_output() {
982        // C_ki = A_ij * B_jk (output transposed)
983        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
984            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
985        });
986        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
987            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
988        });
989        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
990
991        einsum2_into(
992            c.view_mut(),
993            &a.view(),
994            &b.view(),
995            &['k', 'i'], // C indexed as (k, i) instead of (i, k)
996            &['i', 'j'],
997            &['j', 'k'],
998            1.0,
999            0.0,
1000        )
1001        .unwrap();
1002
1003        // Normal matmul result: C_ik = [[19,22],[43,50]]
1004        // But C is indexed as (k, i), so C[k, i] = (A*B)[i, k]
1005        assert_eq!(c.get(&[0, 0]), 19.0); // C[k=0, i=0]
1006        assert_eq!(c.get(&[0, 1]), 43.0); // C[k=0, i=1]
1007        assert_eq!(c.get(&[1, 0]), 22.0); // C[k=1, i=0]
1008        assert_eq!(c.get(&[1, 1]), 50.0); // C[k=1, i=1]
1009    }
1010
1011    #[test]
1012    fn test_left_trace() {
1013        // C_k = sum_j (sum_i A_ij) * B_jk
1014        // left_trace=[i], sum=[j], ro=[k]
1015        let a =
1016            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1017        // A = [[1,2,3],[4,5,6]]
1018        // sum over i: [5, 7, 9]
1019        let b =
1020            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1021        // B = [[1,2],[3,4],[5,6]]
1022        let mut c = StridedArray::<f64>::row_major(&[2]);
1023
1024        einsum2_into(
1025            c.view_mut(),
1026            &a.view(),
1027            &b.view(),
1028            &['k'],
1029            &['i', 'j'],
1030            &['j', 'k'],
1031            1.0,
1032            0.0,
1033        )
1034        .unwrap();
1035
1036        // C_k = sum_j [5,7,9][j] * B[j,k]
1037        // C[0] = 5*1 + 7*3 + 9*5 = 5 + 21 + 45 = 71
1038        // C[1] = 5*2 + 7*4 + 9*6 = 10 + 28 + 54 = 92
1039        assert_eq!(c.get(&[0]), 71.0);
1040        assert_eq!(c.get(&[1]), 92.0);
1041    }
1042
1043    #[test]
1044    fn test_u32_labels() {
1045        // Same as matmul but with u32 labels
1046        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1047            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1048        });
1049        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1050            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1051        });
1052        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1053
1054        einsum2_into(
1055            c.view_mut(),
1056            &a.view(),
1057            &b.view(),
1058            &[0u32, 2],
1059            &[0u32, 1],
1060            &[1u32, 2],
1061            1.0,
1062            0.0,
1063        )
1064        .unwrap();
1065
1066        assert_eq!(c.get(&[0, 0]), 19.0);
1067        assert_eq!(c.get(&[1, 1]), 50.0);
1068    }
1069
1070    #[test]
1071    fn test_complex_matmul() {
1072        use num_complex::Complex64;
1073        let i = Complex64::i();
1074
1075        // A = [[1+i, 2], [3, 4-i]]
1076        let a_vals = [
1077            [1.0 + i, Complex64::new(2.0, 0.0)],
1078            [Complex64::new(3.0, 0.0), 4.0 - i],
1079        ];
1080        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1081
1082        // B = [[1, i], [0, 1]]
1083        let b_vals = [
1084            [Complex64::new(1.0, 0.0), i],
1085            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
1086        ];
1087        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1088
1089        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1090
1091        einsum2_into(
1092            c.view_mut(),
1093            &a.view(),
1094            &b.view(),
1095            &['i', 'k'],
1096            &['i', 'j'],
1097            &['j', 'k'],
1098            Complex64::new(1.0, 0.0),
1099            Complex64::new(0.0, 0.0),
1100        )
1101        .unwrap();
1102
1103        // C = A * B
1104        // C[0,0] = (1+i)*1 + 2*0 = 1+i
1105        // C[0,1] = (1+i)*i + 2*1 = i+i²+2 = i-1+2 = 1+i
1106        // C[1,0] = 3*1 + (4-i)*0 = 3
1107        // C[1,1] = 3*i + (4-i)*1 = 3i+4-i = 4+2i
1108        assert_eq!(c.get(&[0, 0]), 1.0 + i);
1109        assert_eq!(c.get(&[0, 1]), 1.0 + i);
1110        assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1111        assert_eq!(c.get(&[1, 1]), 4.0 + 2.0 * i);
1112    }
1113
1114    #[test]
1115    fn test_complex_matmul_with_conj() {
1116        use num_complex::Complex64;
1117        let i = Complex64::i();
1118
1119        // A = [[1+i, 2i], [3, 4-i]]
1120        let a_vals = [[1.0 + i, 2.0 * i], [Complex64::new(3.0, 0.0), 4.0 - i]];
1121        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1122
1123        // B = identity
1124        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| {
1125            if idx[0] == idx[1] {
1126                Complex64::new(1.0, 0.0)
1127            } else {
1128                Complex64::new(0.0, 0.0)
1129            }
1130        });
1131
1132        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1133
1134        // C = conj(A) * B = conj(A)
1135        let a_conj = a.view().conj();
1136        einsum2_into(
1137            c.view_mut(),
1138            &a_conj,
1139            &b.view(),
1140            &['i', 'k'],
1141            &['i', 'j'],
1142            &['j', 'k'],
1143            Complex64::new(1.0, 0.0),
1144            Complex64::new(0.0, 0.0),
1145        )
1146        .unwrap();
1147
1148        // conj(A) = [[1-i, -2i], [3, 4+i]]
1149        assert_eq!(c.get(&[0, 0]), 1.0 - i);
1150        assert_eq!(c.get(&[0, 1]), -2.0 * i);
1151        assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1152        assert_eq!(c.get(&[1, 1]), 4.0 + i);
1153    }
1154
1155    #[test]
1156    fn test_complex_matmul_with_conj_both() {
1157        use num_complex::Complex64;
1158        let i = Complex64::i();
1159
1160        // A = [[1+i, 0], [0, 2-i]]
1161        let a_vals = [
1162            [1.0 + i, Complex64::new(0.0, 0.0)],
1163            [Complex64::new(0.0, 0.0), 2.0 - i],
1164        ];
1165        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1166
1167        // B = [[1, i], [0, 1+i]]
1168        let b_vals = [
1169            [Complex64::new(1.0, 0.0), i],
1170            [Complex64::new(0.0, 0.0), 1.0 + i],
1171        ];
1172        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1173
1174        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1175
1176        // C = conj(A) * conj(B)
1177        let a_conj = a.view().conj();
1178        let b_conj = b.view().conj();
1179        einsum2_into(
1180            c.view_mut(),
1181            &a_conj,
1182            &b_conj,
1183            &['i', 'k'],
1184            &['i', 'j'],
1185            &['j', 'k'],
1186            Complex64::new(1.0, 0.0),
1187            Complex64::new(0.0, 0.0),
1188        )
1189        .unwrap();
1190
1191        // conj(A) = [[1-i, 0], [0, 2+i]]
1192        // conj(B) = [[1, -i], [0, 1-i]]
1193        // C = conj(A) * conj(B)
1194        // C[0,0] = (1-i)*1 + 0*0 = 1-i
1195        // C[0,1] = (1-i)*(-i) + 0*(1-i) = -i+i² = -i-1 = -(1+i)
1196        // C[1,0] = 0*1 + (2+i)*0 = 0
1197        // C[1,1] = 0*(-i) + (2+i)*(1-i) = 2-2i+i-i² = 2-i+1 = 3-i
1198        assert_eq!(c.get(&[0, 0]), 1.0 - i);
1199        assert_eq!(c.get(&[0, 1]), -(1.0 + i));
1200        assert_eq!(c.get(&[1, 0]), Complex64::new(0.0, 0.0));
1201        assert_eq!(c.get(&[1, 1]), 3.0 - i);
1202    }
1203
1204    #[test]
1205    fn test_elementwise_hadamard() {
1206        // C_ijk = A_ijk * B_ijk — all batch, no contraction
1207        let a = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1208            (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64
1209        });
1210        let b = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1211            (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64 * 0.1
1212        });
1213        let mut c = StridedArray::<f64>::row_major(&[3, 4, 5]);
1214
1215        einsum2_into(
1216            c.view_mut(),
1217            &a.view(),
1218            &b.view(),
1219            &['i', 'j', 'k'],
1220            &['i', 'j', 'k'],
1221            &['i', 'j', 'k'],
1222            1.0,
1223            0.0,
1224        )
1225        .unwrap();
1226
1227        // Spot check: C[0,0,0] = 1 * 0.1 = 0.1
1228        assert!((c.get(&[0, 0, 0]) - 0.1).abs() < 1e-12);
1229        // C[2,3,4] = 60 * 6.0 = 360
1230        assert!((c.get(&[2, 3, 4]) - 360.0).abs() < 1e-10);
1231    }
1232
1233    #[test]
1234    fn test_elementwise_hadamard_with_alpha() {
1235        let a =
1236            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1237        let b =
1238            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1239        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1240
1241        einsum2_into(
1242            c.view_mut(),
1243            &a.view(),
1244            &b.view(),
1245            &['i', 'j'],
1246            &['i', 'j'],
1247            &['i', 'j'],
1248            2.0,
1249            0.0,
1250        )
1251        .unwrap();
1252
1253        // C[0,0] = 2.0 * 1 * 1 = 2.0
1254        assert_eq!(c.get(&[0, 0]), 2.0);
1255        // C[1,2] = 2.0 * 6 * 6 = 72.0
1256        assert_eq!(c.get(&[1, 2]), 72.0);
1257    }
1258
1259    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1260    #[test]
1261    fn test_einsum2_owned_matmul() {
1262        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1263            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1264        });
1265        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1266            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1267        });
1268        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1269
1270        einsum2_into_owned(
1271            c.view_mut(),
1272            a,
1273            b,
1274            &['i', 'k'],
1275            &['i', 'j'],
1276            &['j', 'k'],
1277            1.0,
1278            0.0,
1279            false,
1280            false,
1281        )
1282        .unwrap();
1283
1284        assert_eq!(c.get(&[0, 0]), 19.0);
1285        assert_eq!(c.get(&[0, 1]), 22.0);
1286        assert_eq!(c.get(&[1, 0]), 43.0);
1287        assert_eq!(c.get(&[1, 1]), 50.0);
1288    }
1289
1290    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1291    #[test]
1292    fn test_einsum2_owned_batched() {
1293        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
1294            (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
1295        });
1296        let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
1297            (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
1298        });
1299        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
1300
1301        einsum2_into_owned(
1302            c.view_mut(),
1303            a,
1304            b,
1305            &['b', 'i', 'k'],
1306            &['b', 'i', 'j'],
1307            &['b', 'j', 'k'],
1308            1.0,
1309            0.0,
1310            false,
1311            false,
1312        )
1313        .unwrap();
1314
1315        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
1316        // C0[0,0] = 1*1+2*3+3*5 = 22
1317        assert_eq!(c.get(&[0, 0, 0]), 22.0);
1318    }
1319
1320    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1321    #[test]
1322    fn test_einsum2_owned_alpha_beta() {
1323        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1324            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
1325        });
1326        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1327            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1328        });
1329        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1330            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
1331        });
1332
1333        einsum2_into_owned(
1334            c.view_mut(),
1335            a,
1336            b,
1337            &['i', 'k'],
1338            &['i', 'j'],
1339            &['j', 'k'],
1340            2.0,
1341            3.0,
1342            false,
1343            false,
1344        )
1345        .unwrap();
1346
1347        // C = 2*I*B + 3*C_old
1348        assert_eq!(c.get(&[0, 0]), 32.0); // 2*1 + 3*10
1349        assert_eq!(c.get(&[1, 1]), 128.0); // 2*4 + 3*40
1350    }
1351
1352    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1353    #[test]
1354    fn test_einsum2_owned_elementwise() {
1355        // All batch, no contraction -- element-wise fast path
1356        let a =
1357            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1358        let b =
1359            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1360        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1361
1362        einsum2_into_owned(
1363            c.view_mut(),
1364            a,
1365            b,
1366            &['i', 'j'],
1367            &['i', 'j'],
1368            &['i', 'j'],
1369            2.0,
1370            0.0,
1371            false,
1372            false,
1373        )
1374        .unwrap();
1375
1376        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1377        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1378    }
1379
1380    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1381    #[test]
1382    fn test_einsum2_owned_left_trace() {
1383        // C_k = sum_j (sum_i A_ij) * B_jk
1384        // left_trace=[i], sum=[j], ro=[k]
1385        let a =
1386            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1387        // A = [[1,2,3],[4,5,6]]
1388        // sum over i: [5, 7, 9]
1389        let b =
1390            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1391        // B = [[1,2],[3,4],[5,6]]
1392        let mut c = StridedArray::<f64>::row_major(&[2]);
1393
1394        einsum2_into_owned(
1395            c.view_mut(),
1396            a,
1397            b,
1398            &['k'],
1399            &['i', 'j'],
1400            &['j', 'k'],
1401            1.0,
1402            0.0,
1403            false,
1404            false,
1405        )
1406        .unwrap();
1407
1408        // C[0] = 5*1 + 7*3 + 9*5 = 5 + 21 + 45 = 71
1409        // C[1] = 5*2 + 7*4 + 9*6 = 10 + 28 + 54 = 92
1410        assert_eq!(c.get(&[0]), 71.0);
1411        assert_eq!(c.get(&[1]), 92.0);
1412    }
1413
1414    #[test]
1415    fn test_einsum2_naive_matmul() {
1416        // einsum2_naive_into with identity maps = regular matmul
1417        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1418            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1419        });
1420        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1421            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1422        });
1423        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1424
1425        einsum2_naive_into(
1426            c.view_mut(),
1427            &a.view(),
1428            &b.view(),
1429            &['i', 'k'],
1430            &['i', 'j'],
1431            &['j', 'k'],
1432            1.0,
1433            0.0,
1434            |x| x,
1435            |x| x,
1436        )
1437        .unwrap();
1438
1439        assert_eq!(c.get(&[0, 0]), 19.0);
1440        assert_eq!(c.get(&[0, 1]), 22.0);
1441        assert_eq!(c.get(&[1, 0]), 43.0);
1442        assert_eq!(c.get(&[1, 1]), 50.0);
1443    }
1444
1445    #[test]
1446    fn test_einsum2_naive_elementwise() {
1447        // Element-wise (all batch) with identity maps
1448        let a =
1449            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1450        let b =
1451            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1452        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1453
1454        einsum2_naive_into(
1455            c.view_mut(),
1456            &a.view(),
1457            &b.view(),
1458            &['i', 'j'],
1459            &['i', 'j'],
1460            &['i', 'j'],
1461            2.0,
1462            0.0,
1463            |x| x,
1464            |x| x,
1465        )
1466        .unwrap();
1467
1468        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1469        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1470    }
1471
1472    #[test]
1473    fn test_einsum2_naive_custom_type() {
1474        // Custom scalar type that does NOT implement ElementOpApply.
1475        // This demonstrates that einsum2_naive_into works with custom types.
1476        use num_traits::{One, Zero};
1477
1478        #[derive(Debug, Clone, Copy, PartialEq)]
1479        struct MyVal(f64);
1480
1481        impl Default for MyVal {
1482            fn default() -> Self {
1483                MyVal(0.0)
1484            }
1485        }
1486
1487        impl std::ops::Add for MyVal {
1488            type Output = Self;
1489            fn add(self, rhs: Self) -> Self {
1490                MyVal(self.0 + rhs.0)
1491            }
1492        }
1493
1494        impl std::ops::Mul for MyVal {
1495            type Output = Self;
1496            fn mul(self, rhs: Self) -> Self {
1497                MyVal(self.0 * rhs.0)
1498            }
1499        }
1500
1501        impl Zero for MyVal {
1502            fn zero() -> Self {
1503                MyVal(0.0)
1504            }
1505            fn is_zero(&self) -> bool {
1506                self.0 == 0.0
1507            }
1508        }
1509
1510        impl One for MyVal {
1511            fn one() -> Self {
1512                MyVal(1.0)
1513            }
1514        }
1515
1516        // C_ik = A_ij * B_jk (matrix multiply with custom type)
1517        let a = StridedArray::from_parts(
1518            vec![MyVal(1.0), MyVal(2.0), MyVal(3.0), MyVal(4.0)],
1519            &[2, 2],
1520            &[2, 1],
1521            0,
1522        )
1523        .unwrap();
1524        let b = StridedArray::from_parts(
1525            vec![MyVal(5.0), MyVal(6.0), MyVal(7.0), MyVal(8.0)],
1526            &[2, 2],
1527            &[2, 1],
1528            0,
1529        )
1530        .unwrap();
1531        let mut c = StridedArray::<MyVal>::col_major(&[2, 2]);
1532
1533        einsum2_naive_into(
1534            c.view_mut(),
1535            &a.view(),
1536            &b.view(),
1537            &['i', 'k'],
1538            &['i', 'j'],
1539            &['j', 'k'],
1540            MyVal(1.0),
1541            MyVal(0.0),
1542            |x| x,
1543            |x| x,
1544        )
1545        .unwrap();
1546
1547        // C = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
1548        //   = [[19, 22], [43, 50]]
1549        assert_eq!(c.get(&[0, 0]), MyVal(19.0));
1550        assert_eq!(c.get(&[0, 1]), MyVal(22.0));
1551        assert_eq!(c.get(&[1, 0]), MyVal(43.0));
1552        assert_eq!(c.get(&[1, 1]), MyVal(50.0));
1553    }
1554
1555    #[test]
1556    fn test_einsum2_naive_left_trace() {
1557        // C_k = sum_j (sum_i A_ij) * B_jk with map_a = identity
1558        let a =
1559            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1560        let b =
1561            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1562        let mut c = StridedArray::<f64>::row_major(&[2]);
1563
1564        einsum2_naive_into(
1565            c.view_mut(),
1566            &a.view(),
1567            &b.view(),
1568            &['k'],
1569            &['i', 'j'],
1570            &['j', 'k'],
1571            1.0,
1572            0.0,
1573            |x| x,
1574            |x| x,
1575        )
1576        .unwrap();
1577
1578        assert_eq!(c.get(&[0]), 71.0);
1579        assert_eq!(c.get(&[1]), 92.0);
1580    }
1581
1582    /// Test backend that delegates to naive GEMM loops.
1583    struct TestNaiveBackend;
1584
1585    impl Backend<f64> for TestNaiveBackend {
1586        const MATERIALIZES_CONJ: bool = false;
1587        const REQUIRES_UNIT_STRIDE: bool = false;
1588
1589        fn bgemm_contiguous_into(
1590            c: &mut contiguous::ContiguousOperandMut<f64>,
1591            a: &contiguous::ContiguousOperand<f64>,
1592            b: &contiguous::ContiguousOperand<f64>,
1593            batch_dims: &[usize],
1594            m: usize,
1595            n: usize,
1596            k: usize,
1597            alpha: f64,
1598            beta: f64,
1599        ) -> strided_view::Result<()> {
1600            // Delegate to the contiguous naive GEMM loop
1601            let a_ptr = a.ptr();
1602            let b_ptr = b.ptr();
1603            let c_ptr = c.ptr();
1604            let a_rs = a.row_stride();
1605            let a_cs = a.col_stride();
1606            let b_rs = b.row_stride();
1607            let b_cs = b.col_stride();
1608            let c_rs = c.row_stride();
1609            let c_cs = c.col_stride();
1610
1611            let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1612            while batch_idx.next().is_some() {
1613                let a_base = batch_idx.offset(a.batch_strides());
1614                let b_base = batch_idx.offset(b.batch_strides());
1615                let c_base = batch_idx.offset(c.batch_strides());
1616
1617                for i in 0..m {
1618                    for j in 0..n {
1619                        let mut acc = 0.0f64;
1620                        for l in 0..k {
1621                            let a_val = unsafe {
1622                                *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1623                            };
1624                            let b_val = unsafe {
1625                                *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1626                            };
1627                            acc += a_val * b_val;
1628                        }
1629                        unsafe {
1630                            let c_elem =
1631                                c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1632                            if beta == 0.0 {
1633                                *c_elem = alpha * acc;
1634                            } else {
1635                                *c_elem = alpha * acc + beta * (*c_elem);
1636                            }
1637                        }
1638                    }
1639                }
1640            }
1641            Ok(())
1642        }
1643    }
1644
1645    #[test]
1646    fn test_einsum2_with_backend_matmul() {
1647        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1648            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1649        });
1650        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1651            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1652        });
1653        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1654
1655        einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1656            c.view_mut(),
1657            &a.view(),
1658            &b.view(),
1659            &['i', 'k'],
1660            &['i', 'j'],
1661            &['j', 'k'],
1662            1.0,
1663            0.0,
1664        )
1665        .unwrap();
1666
1667        assert_eq!(c.get(&[0, 0]), 19.0);
1668        assert_eq!(c.get(&[0, 1]), 22.0);
1669        assert_eq!(c.get(&[1, 0]), 43.0);
1670        assert_eq!(c.get(&[1, 1]), 50.0);
1671    }
1672
1673    #[test]
1674    fn test_einsum2_with_backend_elementwise() {
1675        let a =
1676            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1677        let b =
1678            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1679        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1680
1681        einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1682            c.view_mut(),
1683            &a.view(),
1684            &b.view(),
1685            &['i', 'j'],
1686            &['i', 'j'],
1687            &['i', 'j'],
1688            2.0,
1689            0.0,
1690        )
1691        .unwrap();
1692
1693        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1694        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1695    }
1696
1697    #[test]
1698    fn test_einsum2_with_backend_custom_type() {
1699        use num_traits::{One, Zero};
1700
1701        #[derive(Debug, Clone, Copy, PartialEq)]
1702        struct Tropical(f64);
1703
1704        impl Default for Tropical {
1705            fn default() -> Self {
1706                Tropical(0.0)
1707            }
1708        }
1709
1710        impl std::ops::Add for Tropical {
1711            type Output = Self;
1712            fn add(self, rhs: Self) -> Self {
1713                Tropical(self.0 + rhs.0)
1714            }
1715        }
1716
1717        impl std::ops::Mul for Tropical {
1718            type Output = Self;
1719            fn mul(self, rhs: Self) -> Self {
1720                Tropical(self.0 * rhs.0)
1721            }
1722        }
1723
1724        impl Zero for Tropical {
1725            fn zero() -> Self {
1726                Tropical(0.0)
1727            }
1728            fn is_zero(&self) -> bool {
1729                self.0 == 0.0
1730            }
1731        }
1732
1733        impl One for Tropical {
1734            fn one() -> Self {
1735                Tropical(1.0)
1736            }
1737        }
1738
1739        struct TropicalBackend;
1740
1741        impl Backend<Tropical> for TropicalBackend {
1742            const MATERIALIZES_CONJ: bool = false;
1743            const REQUIRES_UNIT_STRIDE: bool = false;
1744
1745            fn bgemm_contiguous_into(
1746                c: &mut contiguous::ContiguousOperandMut<Tropical>,
1747                a: &contiguous::ContiguousOperand<Tropical>,
1748                b: &contiguous::ContiguousOperand<Tropical>,
1749                batch_dims: &[usize],
1750                m: usize,
1751                n: usize,
1752                k: usize,
1753                alpha: Tropical,
1754                beta: Tropical,
1755            ) -> strided_view::Result<()> {
1756                // Simple naive GEMM for Tropical type
1757                let a_ptr = a.ptr();
1758                let b_ptr = b.ptr();
1759                let c_ptr = c.ptr();
1760                let a_rs = a.row_stride();
1761                let a_cs = a.col_stride();
1762                let b_rs = b.row_stride();
1763                let b_cs = b.col_stride();
1764                let c_rs = c.row_stride();
1765                let c_cs = c.col_stride();
1766
1767                let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1768                while batch_idx.next().is_some() {
1769                    let a_base = batch_idx.offset(a.batch_strides());
1770                    let b_base = batch_idx.offset(b.batch_strides());
1771                    let c_base = batch_idx.offset(c.batch_strides());
1772
1773                    for i in 0..m {
1774                        for j in 0..n {
1775                            let mut acc = Tropical::zero();
1776                            for l in 0..k {
1777                                let a_val = unsafe {
1778                                    *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1779                                };
1780                                let b_val = unsafe {
1781                                    *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1782                                };
1783                                acc = acc + a_val * b_val;
1784                            }
1785                            unsafe {
1786                                let c_elem =
1787                                    c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1788                                if beta == Tropical::zero() {
1789                                    *c_elem = alpha * acc;
1790                                } else {
1791                                    *c_elem = alpha * acc + beta * (*c_elem);
1792                                }
1793                            }
1794                        }
1795                    }
1796                }
1797                Ok(())
1798            }
1799        }
1800
1801        let a = StridedArray::from_parts(
1802            vec![Tropical(1.0), Tropical(2.0), Tropical(3.0), Tropical(4.0)],
1803            &[2, 2],
1804            &[2, 1],
1805            0,
1806        )
1807        .unwrap();
1808        let b = StridedArray::from_parts(
1809            vec![Tropical(5.0), Tropical(6.0), Tropical(7.0), Tropical(8.0)],
1810            &[2, 2],
1811            &[2, 1],
1812            0,
1813        )
1814        .unwrap();
1815        let mut c = StridedArray::<Tropical>::col_major(&[2, 2]);
1816
1817        einsum2_with_backend_into::<_, TropicalBackend, _>(
1818            c.view_mut(),
1819            &a.view(),
1820            &b.view(),
1821            &['i', 'k'],
1822            &['i', 'j'],
1823            &['j', 'k'],
1824            Tropical(1.0),
1825            Tropical(0.0),
1826        )
1827        .unwrap();
1828
1829        // C = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
1830        //   = [[19, 22], [43, 50]]
1831        assert_eq!(c.get(&[0, 0]), Tropical(19.0));
1832        assert_eq!(c.get(&[0, 1]), Tropical(22.0));
1833        assert_eq!(c.get(&[1, 0]), Tropical(43.0));
1834        assert_eq!(c.get(&[1, 1]), Tropical(50.0));
1835    }
1836}