Skip to main content

strided_einsum2/
lib.rs

1//! Binary Einstein summation on strided views.
2//!
3//! Provides `einsum2_into` for computing binary tensor contractions with
4//! accumulation semantics: `C = alpha * A * B + beta * C`.
5//!
6//! # Example
7//!
8//! ```
9//! use strided_view::StridedArray;
10//! use strided_einsum2::einsum2_into;
11//!
12//! // Matrix multiply: C_ik = A_ij * B_jk
13//! let a = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
14//! let b = StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
15//! let mut c = StridedArray::<f64>::row_major(&[2, 2]);
16//!
17//! einsum2_into(
18//!     c.view_mut(), &a.view(), &b.view(),
19//!     &['i', 'k'], &['i', 'j'], &['j', 'k'],
20//!     1.0, 0.0,
21//! ).unwrap();
22//! ```
23
24#[cfg(all(feature = "blas", feature = "blas-inject"))]
25compile_error!("Features `blas` and `blas-inject` are mutually exclusive.");
26
27#[cfg(any(
28    all(feature = "blas-accelerate", feature = "blas-openblas"),
29    all(feature = "blas-accelerate", feature = "blas-mkl"),
30    all(feature = "blas-openblas", feature = "blas-mkl")
31))]
32compile_error!("Select at most one explicit BLAS provider feature.");
33
34#[cfg(all(feature = "blas-inject", not(feature = "blas")))]
35extern crate cblas_inject as cblas_sys;
36#[cfg(all(feature = "blas", not(feature = "blas-inject")))]
37extern crate cblas_sys;
38
39#[cfg(any(
40    all(feature = "blas", not(feature = "blas-inject")),
41    all(feature = "blas-inject", not(feature = "blas"))
42))]
43pub mod bgemm_blas;
44
45#[cfg(feature = "faer")]
46/// Batched GEMM backend using the [`faer`] library.
47pub mod bgemm_faer;
48/// Batched GEMM fallback using explicit loops.
49pub mod bgemm_naive;
50/// GEMM-ready operand types and preparation functions for contiguous data.
51pub mod contiguous;
52/// Axis-based general dot product API.
53pub mod dot_general;
54/// Contraction planning: axis classification and permutation computation.
55pub mod plan;
56/// Raw borrowed-layout batched GEMM entry points.
57pub mod raw_bgemm;
58/// Trace-axis reduction (summing axes that appear only in one operand).
59pub mod trace;
60/// Shared helpers (permutation inversion, multi-index iteration, dimension fusion).
61pub mod util;
62
63/// Backend abstraction for batched GEMM dispatch.
64pub mod backend;
65
66use std::any::TypeId;
67use std::fmt::Debug;
68use std::hash::Hash;
69
70use strided_kernel::zip_map2_into;
71#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
72use strided_view::StridedArray;
73use strided_view::{Adjoint, Conj, ElementOp, ElementOpApply};
74
75pub use strided_traits::ScalarBase;
76pub use strided_view::{
77    col_major_strides, RawStridedMut, RawStridedRef, StridedView, StridedViewMut,
78};
79
80pub use backend::Backend;
81pub use dot_general::{dot_general_into, dot_general_with_backend_into, DotGeneralConfig};
82pub use plan::Einsum2Plan;
83pub use raw_bgemm::{
84    bgemm_raw_strided_into, bgemm_raw_strided_into_unchecked, bgemm_raw_with_backend_into,
85    bgemm_raw_with_backend_into_unchecked,
86};
87
88/// Trait alias for axis label types.
/// Marker trait for anything usable as an einsum axis label.
///
/// A label only needs to be cloneable, comparable for equality, hashable and
/// debug-printable; the blanket impl below grants `AxisId` to every such type
/// (for example `char`, `u32`, `&str` or `String`).
pub trait AxisId: Debug + Clone + Eq + Hash {}

impl<T> AxisId for T where T: Debug + Clone + Eq + Hash {}
91
92// ScalarBase is re-exported from strided_traits (see above).
93// It no longer requires ElementOpApply, enabling custom scalar types
94// (e.g., tropical semiring) to work with Identity-only views.
95
/// Trait alias for element types supported by einsum operations.
///
/// When BLAS is enabled, `Scalar` follows the active backend and requires
/// `BlasGemm`, even if faer is also compiled for explicit backend use.
///
/// This definition is compiled when exactly one of the `blas` and
/// `blas-inject` features is enabled. The four `Scalar` definitions in this
/// file have mutually exclusive `cfg` predicates, so exactly one of them is
/// compiled for any feature combination.
#[cfg(any(
    all(feature = "blas", not(feature = "blas-inject")),
    all(feature = "blas-inject", not(feature = "blas"))
))]
pub trait Scalar: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}

// Blanket impl: every type meeting the bounds is a `Scalar` (BLAS variant).
#[cfg(any(
    all(feature = "blas", not(feature = "blas-inject")),
    all(feature = "blas-inject", not(feature = "blas"))
))]
impl<T> Scalar for T where T: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}

/// Trait alias for element types supported by the faer active backend.
///
/// When only faer is enabled, this additionally requires `faer::ComplexField`
/// so that the faer GEMM backend can be used.
///
/// Compiled when `faer` is enabled and neither `blas` nor `blas-inject` is.
#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
pub trait Scalar: ScalarBase + ElementOpApply + faer_traits::ComplexField {}

// Blanket impl: every type meeting the bounds is a `Scalar` (faer-only variant).
#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
impl<T> Scalar for T where T: ScalarBase + ElementOpApply + faer_traits::ComplexField {}

/// Trait alias for element types (without `faer` or BLAS features).
///
/// With no accelerated backend compiled in, only the base scalar bounds are
/// required.
#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
pub trait Scalar: ScalarBase + ElementOpApply {}

// Blanket impl: every type meeting the bounds is a `Scalar` (no-backend variant).
#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}

/// Placeholder trait definition for invalid mutually-exclusive feature combinations.
///
/// The crate emits `compile_error!` above for these combinations. This trait only
/// avoids cascading type-resolution errors so users see the intended diagnostics.
#[cfg(all(feature = "blas", feature = "blas-inject"))]
pub trait Scalar: ScalarBase + ElementOpApply {}

// Blanket impl for the placeholder; never part of a successful build because
// the same feature combination triggers `compile_error!`.
#[cfg(all(feature = "blas", feature = "blas-inject"))]
impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
138
/// Errors specific to einsum operations.
///
/// Display messages are generated by `thiserror` from the `#[error(...)]`
/// attributes on each variant.
#[derive(Debug, thiserror::Error)]
pub enum EinsumError {
    /// An axis label occurs more than once; the payload is the rendered label.
    #[error("duplicate axis label: {0}")]
    DuplicateAxis(String),
    /// An output axis label does not occur in either input; the payload is the
    /// rendered label.
    #[error("output axis {0} not found in any input")]
    OrphanOutputAxis(String),
    /// Two operands disagree on the extent of the same axis.
    ///
    /// `dim_a` and `dim_b` are the two extents that were compared; depending on
    /// the axis class they may come from A/B, A/C or B/C.
    // NOTE(review): `validate_dimensions` in this file already renders the label
    // with `{:?}`, and the message applies `{axis:?}` again, so labels appear
    // doubly quoted in the output — confirm whether that is intended.
    #[error("dimension mismatch for axis {axis:?}: {dim_a} vs {dim_b}")]
    DimensionMismatch {
        axis: String,
        dim_a: usize,
        dim_b: usize,
    },
    /// A dot-general configuration was rejected; the payload describes why.
    #[error("invalid dot-general config: {0}")]
    InvalidDotGeneralConfig(String),
    /// The output has a different shape than the operation requires.
    #[error("output shape mismatch: expected {expected:?}, got {got:?}")]
    OutputShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },
    /// An error from the underlying `strided_view` layer, forwarded unchanged
    /// (`#[from]` lets `?` convert it automatically).
    #[error(transparent)]
    Strided(#[from] strided_view::StridedError),
}

/// Convenience alias for `Result<T, EinsumError>`.
pub type Result<T> = std::result::Result<T, EinsumError>;
165
166/// Returns `true` if the given `ElementOp` type represents conjugation.
167///
168/// - `Identity` / `Transpose` → `false` (no per-element conjugation)
169/// - `Conj` / `Adjoint` → `true` (per-element conjugation needed)
170///
171/// For scalar types, `Transpose::apply(x) = x` (identity) and the dimension
172/// swap is already reflected in the view's strides/dims.  Similarly,
173/// `Adjoint::apply(x) = x.conj()` with the dimension swap in the view.
174fn op_is_conj<Op: 'static>() -> bool {
175    TypeId::of::<Op>() == TypeId::of::<Conj>() || TypeId::of::<Op>() == TypeId::of::<Adjoint>()
176}
177
178/// Binary einsum contraction: `C = alpha * contract(A, B) + beta * C`.
179///
180/// `ic`, `ia`, `ib` are axis labels for C, A, B respectively.
181/// Axes are classified as:
182/// - **batch**: in A, B, and C
183/// - **lo** (left-output): in A and C, not B
184/// - **ro** (right-output): in B and C, not A
185/// - **sum** (contraction): in A and B, not C
186///
187/// Trace axes (only in A or only in B) are detected and reduced lazily.
188pub fn einsum2_into<T: Scalar, OpA, OpB, ID: AxisId>(
189    c: StridedViewMut<T>,
190    a: &StridedView<T, OpA>,
191    b: &StridedView<T, OpB>,
192    ic: &[ID],
193    ia: &[ID],
194    ib: &[ID],
195    alpha: T,
196    beta: T,
197) -> Result<()>
198where
199    OpA: ElementOp<T> + 'static,
200    OpB: ElementOp<T> + 'static,
201{
202    // 1. Build plan
203    let plan = Einsum2Plan::new(ia, ib, ic)?;
204
205    // 2. Validate dimension consistency across operands
206    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
207
208    // 3. Reduce trace axes if present; determine conjugation flags.
209    //    When trace reduction occurs, Op is already applied during the reduction,
210    //    so conj flag is false. Otherwise, we strip the Op and pass a conj flag
211    //    to the GEMM kernel (avoiding materialization).
212    let left_trace = trace::find_trace_indices(ia, ib, ic);
213    let (a_buf, conj_a) = if !left_trace.is_empty() {
214        (Some(trace::reduce_trace_axes(a, &left_trace)?), false)
215    } else {
216        (None, op_is_conj::<OpA>())
217    };
218
219    let a_view: StridedView<T> = match a_buf.as_ref() {
220        Some(buf) => buf.view(),
221        None => StridedView::new(a.data(), a.dims(), a.strides(), a.offset())
222            .expect("strip_op_view: metadata already validated"),
223    };
224
225    let right_trace = trace::find_trace_indices(ib, ia, ic);
226    let (b_buf, conj_b) = if !right_trace.is_empty() {
227        (Some(trace::reduce_trace_axes(b, &right_trace)?), false)
228    } else {
229        (None, op_is_conj::<OpB>())
230    };
231
232    let b_view: StridedView<T> = match b_buf.as_ref() {
233        Some(buf) => buf.view(),
234        None => StridedView::new(b.data(), b.dims(), b.strides(), b.offset())
235            .expect("strip_op_view: metadata already validated"),
236    };
237
238    // 4. Dispatch to GEMM
239    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
240    {
241        let conj_fn = make_conj_fn::<T>();
242        einsum2_dispatch::<T, backend::ActiveBackend, _>(
243            c, &a_view, &b_view, &plan, alpha, beta, conj_a, conj_b, conj_fn,
244        )?;
245    }
246
247    #[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
248    {
249        let a_perm = a_view.permute(&plan.left_perm)?;
250        let b_perm = b_view.permute(&plan.right_perm)?;
251        let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
252
253        if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
254            let mul_fn = move |a_val: T, b_val: T| -> T {
255                let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
256                let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
257                alpha * a_c * b_c
258            };
259            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
260            return Ok(());
261        }
262
263        bgemm_naive::bgemm_strided_into(
264            &mut c_perm,
265            &a_perm,
266            &b_perm,
267            plan.batch.len(),
268            plan.lo.len(),
269            plan.ro.len(),
270            plan.sum.len(),
271            alpha,
272            beta,
273            conj_a,
274            conj_b,
275        )?;
276    }
277
278    Ok(())
279}
280
281/// Binary einsum for custom scalar types using naive GEMM.
282///
283/// Like [`einsum2_into`] but works with any `ScalarBase` type (no `ElementOpApply` required).
284/// Uses closures `map_a` and `map_b` for per-element transformation instead of
285/// conjugation flags. Always dispatches to the naive GEMM kernel.
286///
287/// Views must use `Identity` element operations (the default).
288pub fn einsum2_naive_into<T, ID, MapA, MapB>(
289    c: StridedViewMut<T>,
290    a: &StridedView<T>,
291    b: &StridedView<T>,
292    ic: &[ID],
293    ia: &[ID],
294    ib: &[ID],
295    alpha: T,
296    beta: T,
297    map_a: MapA,
298    map_b: MapB,
299) -> Result<()>
300where
301    T: ScalarBase,
302    ID: AxisId,
303    MapA: Fn(T) -> T + strided_kernel::MaybeSync,
304    MapB: Fn(T) -> T + strided_kernel::MaybeSync,
305{
306    let plan = Einsum2Plan::new(ia, ib, ic)?;
307    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
308
309    // Reduce trace axes if present.
310    // When trace reduction occurs, map is applied via map_into before reduction,
311    // so we use identity map for GEMM. Otherwise, pass through the original map.
312    let left_trace = trace::find_trace_indices(ia, ib, ic);
313    let (a_buf, use_map_a) = if !left_trace.is_empty() {
314        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(a.dims()) };
315        strided_kernel::map_into(&mut mapped.view_mut(), a, &map_a)?;
316        let reduced = trace::reduce_trace_axes(&mapped.view(), &left_trace)?;
317        (Some(reduced), false)
318    } else {
319        (None, true)
320    };
321    let a_view: StridedView<T> = match a_buf.as_ref() {
322        Some(buf) => buf.view(),
323        None => a.clone(),
324    };
325
326    let right_trace = trace::find_trace_indices(ib, ia, ic);
327    let (b_buf, use_map_b) = if !right_trace.is_empty() {
328        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(b.dims()) };
329        strided_kernel::map_into(&mut mapped.view_mut(), b, &map_b)?;
330        let reduced = trace::reduce_trace_axes(&mapped.view(), &right_trace)?;
331        (Some(reduced), false)
332    } else {
333        (None, true)
334    };
335    let b_view: StridedView<T> = match b_buf.as_ref() {
336        Some(buf) => buf.view(),
337        None => b.clone(),
338    };
339
340    let a_perm = a_view.permute(&plan.left_perm)?;
341    let b_perm = b_view.permute(&plan.right_perm)?;
342    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
343
344    // Element-wise fast path
345    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
346        let mul_fn = move |a_val: T, b_val: T| -> T {
347            let a_c = if use_map_a { map_a(a_val) } else { a_val };
348            let b_c = if use_map_b { map_b(b_val) } else { b_val };
349            alpha * a_c * b_c
350        };
351        zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
352        return Ok(());
353    }
354
355    let final_map_a: Box<dyn Fn(T) -> T> = if use_map_a {
356        Box::new(map_a)
357    } else {
358        Box::new(|x| x)
359    };
360    let final_map_b: Box<dyn Fn(T) -> T> = if use_map_b {
361        Box::new(map_b)
362    } else {
363        Box::new(|x| x)
364    };
365
366    bgemm_naive::bgemm_strided_into_with_map(
367        &mut c_perm,
368        &a_perm,
369        &b_perm,
370        plan.batch.len(),
371        plan.lo.len(),
372        plan.ro.len(),
373        plan.sum.len(),
374        alpha,
375        beta,
376        final_map_a,
377        final_map_b,
378    )?;
379
380    Ok(())
381}
382
383/// Binary einsum with a pluggable GEMM backend.
384///
385/// Like [`einsum2_into`] but works with any `ScalarBase` type and dispatches
386/// to the caller-provided GEMM backend `B`. Views must use `Identity` element
387/// operations (the default).
388///
389/// External crates can implement [`Backend`] for custom scalar types
390/// (e.g., tropical semiring) and pass the backend here.
391pub fn einsum2_with_backend_into<T, B, ID>(
392    c: StridedViewMut<T>,
393    a: &StridedView<T>,
394    b: &StridedView<T>,
395    ic: &[ID],
396    ia: &[ID],
397    ib: &[ID],
398    alpha: T,
399    beta: T,
400) -> Result<()>
401where
402    T: ScalarBase,
403    B: Backend<T>,
404    ID: AxisId,
405{
406    let plan = Einsum2Plan::new(ia, ib, ic)?;
407    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
408
409    // Trace reduction (plain sum, Identity views)
410    let left_trace = trace::find_trace_indices(ia, ib, ic);
411    let a_buf = if !left_trace.is_empty() {
412        Some(trace::reduce_trace_axes(a, &left_trace)?)
413    } else {
414        None
415    };
416    let a_view: StridedView<T> = match a_buf.as_ref() {
417        Some(buf) => buf.view(),
418        None => a.clone(),
419    };
420
421    let right_trace = trace::find_trace_indices(ib, ia, ic);
422    let b_buf = if !right_trace.is_empty() {
423        Some(trace::reduce_trace_axes(b, &right_trace)?)
424    } else {
425        None
426    };
427    let b_view: StridedView<T> = match b_buf.as_ref() {
428        Some(buf) => buf.view(),
429        None => b.clone(),
430    };
431
432    // No conjugation for generic backend path
433    einsum2_dispatch::<T, B, _>(c, &a_view, &b_view, &plan, alpha, beta, false, false, None)
434}
435
/// Returns the function used to materialize conjugation for the active backend.
///
/// If the active backend declares `MATERIALIZES_CONJ` (conjugation has to be
/// written into the data before GEMM, e.g. CBLAS), the result is
/// `Some(f)` where `f` applies [`Conj`] to one element. Otherwise the result is
/// `None`, meaning the backend is expected to honor conjugation flags itself.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
fn make_conj_fn<T: Scalar>() -> Option<fn(T) -> T> {
    if !<backend::ActiveBackend as Backend<T>>::MATERIALIZES_CONJ {
        return None;
    }
    Some(|x| Conj::apply(x))
}
449
450fn scale_or_zero_strided_mut<T: ScalarBase>(c: &mut StridedViewMut<T>, beta: T) {
451    if c.is_empty() {
452        return;
453    }
454
455    let dims = c.dims().to_vec();
456    let strides = c.strides().to_vec();
457    let ptr = c.as_mut_ptr();
458    let zero = T::zero();
459
460    fn visit<T: ScalarBase>(
461        ptr: *mut T,
462        dims: &[usize],
463        strides: &[isize],
464        axis: usize,
465        offset: isize,
466        beta: T,
467        zero: T,
468    ) {
469        if axis == dims.len() {
470            unsafe {
471                let dst = ptr.offset(offset);
472                if beta == zero {
473                    *dst = zero;
474                } else {
475                    *dst = beta * *dst;
476                }
477            }
478            return;
479        }
480
481        for i in 0..dims[axis] {
482            visit(
483                ptr,
484                dims,
485                strides,
486                axis + 1,
487                offset + i as isize * strides[axis],
488                beta,
489                zero,
490            );
491        }
492    }
493
494    visit(ptr, &dims, &strides, 0, 0, beta, zero);
495}
496
497/// Internal GEMM dispatch, generic over backend.
498///
499/// Called after trace reduction with plain Identity views. Handles:
500/// 1. Permutation to canonical order
501/// 2. Element-wise fast path (if applicable)
502/// 3. Contiguous preparation via `prepare_input_view`
503/// 4. GEMM via `B::bgemm_contiguous_into`
504/// 5. Finalize (copy-back if needed)
505///
506/// `conj_fn` is the materialization function for backends that need conj
507/// applied to data before GEMM. Pass `None` when conj_a/conj_b are both false
508/// or when the backend handles conj natively (via flags).
509pub(crate) fn einsum2_dispatch<T, B, ID>(
510    c: StridedViewMut<T>,
511    a: &StridedView<T>,
512    b: &StridedView<T>,
513    plan: &Einsum2Plan<ID>,
514    alpha: T,
515    beta: T,
516    conj_a: bool,
517    conj_b: bool,
518    conj_fn: Option<fn(T) -> T>,
519) -> Result<()>
520where
521    T: ScalarBase,
522    B: Backend<T>,
523    ID: AxisId,
524{
525    // 1. Permute to canonical order
526    let a_perm = a.permute(&plan.left_perm)?;
527    let b_perm = b.permute(&plan.right_perm)?;
528    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
529
530    if c_perm.is_empty() {
531        return Ok(());
532    }
533
534    // 2. Fast path: element-wise (all batch, no contraction)
535    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
536        if !conj_a && !conj_b && alpha == T::one() {
537            zip_map2_into(&mut c_perm, &a_perm, &b_perm, |a_val, b_val| a_val * b_val)?;
538        } else if !conj_a && !conj_b {
539            let mul_fn = move |a_val: T, b_val: T| -> T { alpha * a_val * b_val };
540            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
541        } else {
542            let conj_fn = conj_fn.unwrap_or(|x| x);
543            let mul_fn = move |a_val: T, b_val: T| -> T {
544                let a_c = if conj_a { conj_fn(a_val) } else { a_val };
545                let b_c = if conj_b { conj_fn(b_val) } else { b_val };
546                alpha * a_c * b_c
547            };
548            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
549        }
550        return Ok(());
551    }
552
553    // 3. Prepare contiguous operands
554    let n_lo = plan.lo.len();
555    let n_ro = plan.ro.len();
556    let n_sum = plan.sum.len();
557    let use_pool = true;
558    let materialize = if B::MATERIALIZES_CONJ { conj_fn } else { None };
559
560    let a_op = contiguous::prepare_input_view(
561        &a_perm,
562        n_lo,
563        n_sum,
564        conj_a,
565        B::REQUIRES_UNIT_STRIDE,
566        use_pool,
567        materialize,
568    )?;
569    let b_op = contiguous::prepare_input_view(
570        &b_perm,
571        n_sum,
572        n_ro,
573        conj_b,
574        B::REQUIRES_UNIT_STRIDE,
575        use_pool,
576        materialize,
577    )?;
578    let mut c_op = contiguous::prepare_output_view(
579        &mut c_perm,
580        n_lo,
581        n_ro,
582        beta,
583        B::REQUIRES_UNIT_STRIDE,
584        use_pool,
585    )?;
586
587    // Compute fused dimension sizes
588    let lo_dims = &a_perm.dims()[..n_lo];
589    let sum_dims = &a_perm.dims()[n_lo..n_lo + n_sum];
590    let batch_dims = &a_perm.dims()[n_lo + n_sum..];
591    let ro_dims = &b_perm.dims()[n_sum..n_sum + n_ro];
592    if sum_dims.iter().any(|&dim| dim == 0) {
593        scale_or_zero_strided_mut(&mut c_perm, beta);
594        return Ok(());
595    }
596    let m: usize = lo_dims.iter().product::<usize>().max(1);
597    let k: usize = sum_dims.iter().product::<usize>().max(1);
598    let n: usize = ro_dims.iter().product::<usize>().max(1);
599
600    // 4. GEMM — dispatched through trait
601    B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
602
603    // 5. Finalize
604    c_op.finalize_into(&mut c_perm)?;
605
606    Ok(())
607}
608
/// Binary einsum accepting owned inputs for zero-copy optimization.
///
/// Same semantics as [`einsum2_into`] but accepts owned `StridedArray` inputs.
/// When inputs have non-contiguous strides after permutation, ownership
/// transfer avoids allocating separate buffers. For contiguous inputs,
/// the behavior is identical.
///
/// `conj_a` and `conj_b` indicate whether to conjugate elements of A/B.
/// The flags stay in force even when trace axes are summed out first: the
/// owned arrays are reduced as plain (Identity) data, and conjugation commutes
/// with summation, so it is still applied afterwards.
///
/// Degenerate shapes are handled up front, matching `einsum2_dispatch`: an
/// empty output returns immediately, and a contracted axis of extent zero
/// reduces the operation to `C = beta * C`.
///
/// # Errors
///
/// Fails if the plan cannot be built, if operand extents disagree, or if any
/// view/backend operation reports an error.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
pub fn einsum2_into_owned<T: Scalar, ID: AxisId>(
    c: StridedViewMut<T>,
    a: StridedArray<T>,
    b: StridedArray<T>,
    ic: &[ID],
    ia: &[ID],
    ib: &[ID],
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: Backend<T>,
{
    // 1. Build plan
    let plan = Einsum2Plan::new(ia, ib, ic)?;

    // 2. Validate dimensions
    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;

    // 3. Trace reduction: sum out axes that appear in only one input.
    //    `a.view()` / `b.view()` are Identity views, so the reduction is a plain
    //    sum and does NOT conjugate. Since conj(sum x) == sum conj(x), the
    //    caller's flags must be kept and applied to the reduced operand.
    let left_trace = trace::find_trace_indices(ia, ib, ic);
    let a_for_gemm = if left_trace.is_empty() {
        a
    } else {
        trace::reduce_trace_axes(&a.view(), &left_trace)?
    };

    let right_trace = trace::find_trace_indices(ib, ia, ic);
    let b_for_gemm = if right_trace.is_empty() {
        b
    } else {
        trace::reduce_trace_axes(&b.view(), &right_trace)?
    };

    // 4. Permute to canonical order (metadata-only on owned arrays)
    let a_perm = a_for_gemm.permuted(&plan.left_perm)?;
    let b_perm = b_for_gemm.permuted(&plan.right_perm)?;
    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;

    // Nothing to write for an empty output.
    if c_perm.is_empty() {
        return Ok(());
    }

    let n_lo = plan.lo.len();
    let n_ro = plan.ro.len();
    let n_sum = plan.sum.len();

    // 5. Fast path: element-wise (all batch, no contraction)
    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
        let mul_fn = move |a_val: T, b_val: T| -> T {
            let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
            let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
            alpha * a_c * b_c
        };
        zip_map2_into(&mut c_perm, &a_perm.view(), &b_perm.view(), mul_fn)?;
        return Ok(());
    }

    // 6. Extract dimension sizes BEFORE consuming arrays via prepare_input_owned
    let a_dims_perm = a_perm.dims().to_vec();
    let b_dims_perm = b_perm.dims().to_vec();

    let lo_dims = &a_dims_perm[..n_lo];
    let sum_dims = &a_dims_perm[n_lo..n_lo + n_sum];
    let batch_dims = a_dims_perm[n_lo + n_sum..].to_vec();
    let ro_dims = &b_dims_perm[n_sum..n_sum + n_ro];

    // A contracted axis of extent zero makes the product an empty sum, leaving
    // only `beta * C`. Without this guard the `.max(1)` below would report
    // k = 1 to the backend although the inputs hold no elements.
    if sum_dims.iter().any(|&dim| dim == 0) {
        scale_or_zero_strided_mut(&mut c_perm, beta);
        return Ok(());
    }
    let m: usize = lo_dims.iter().product::<usize>().max(1);
    let k: usize = sum_dims.iter().product::<usize>().max(1);
    let n: usize = ro_dims.iter().product::<usize>().max(1);

    // 7. Prepare contiguous operands (owned path -- avoids extra copies)
    let conj_fn = make_conj_fn::<T>();
    let materialize = if <backend::ActiveBackend as Backend<T>>::MATERIALIZES_CONJ {
        conj_fn
    } else {
        None
    };
    let use_pool = true;
    let unit_stride = <backend::ActiveBackend as Backend<T>>::REQUIRES_UNIT_STRIDE;
    let a_op = contiguous::prepare_input_owned(
        a_perm,
        n_lo,
        n_sum,
        conj_a,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let b_op = contiguous::prepare_input_owned(
        b_perm,
        n_sum,
        n_ro,
        conj_b,
        unit_stride,
        use_pool,
        materialize,
    )?;
    let mut c_op =
        contiguous::prepare_output_view(&mut c_perm, n_lo, n_ro, beta, unit_stride, use_pool)?;

    // 8. GEMM — dispatched through trait
    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op,
        &a_op,
        &b_op,
        &batch_dims,
        m,
        n,
        k,
        alpha,
        beta,
    )?;

    // 9. Finalize
    c_op.finalize_into(&mut c_perm)?;

    Ok(())
}
744
745/// Validate that dimensions match across operands for each axis group.
746fn validate_dimensions<ID: AxisId>(
747    plan: &Einsum2Plan<ID>,
748    a_dims: &[usize],
749    b_dims: &[usize],
750    c_dims: &[usize],
751    ia: &[ID],
752    ib: &[ID],
753    ic: &[ID],
754) -> Result<()> {
755    let find_dim = |labels: &[ID], dims: &[usize], id: &ID| -> usize {
756        labels
757            .iter()
758            .position(|x| x == id)
759            .map(|i| dims[i])
760            .unwrap()
761    };
762
763    // Batch: must match in A, B, and C
764    for id in &plan.batch {
765        let da = find_dim(ia, a_dims, id);
766        let db = find_dim(ib, b_dims, id);
767        let dc = find_dim(ic, c_dims, id);
768        if da != db || da != dc {
769            return Err(EinsumError::DimensionMismatch {
770                axis: format!("{:?}", id),
771                dim_a: da,
772                dim_b: db,
773            });
774        }
775    }
776
777    // Sum: must match in A and B
778    for id in &plan.sum {
779        let da = find_dim(ia, a_dims, id);
780        let db = find_dim(ib, b_dims, id);
781        if da != db {
782            return Err(EinsumError::DimensionMismatch {
783                axis: format!("{:?}", id),
784                dim_a: da,
785                dim_b: db,
786            });
787        }
788    }
789
790    // LO: must match in A and C
791    for id in &plan.lo {
792        let da = find_dim(ia, a_dims, id);
793        let dc = find_dim(ic, c_dims, id);
794        if da != dc {
795            return Err(EinsumError::DimensionMismatch {
796                axis: format!("{:?}", id),
797                dim_a: da,
798                dim_b: dc,
799            });
800        }
801    }
802
803    // RO: must match in B and C
804    for id in &plan.ro {
805        let db = find_dim(ib, b_dims, id);
806        let dc = find_dim(ic, c_dims, id);
807        if db != dc {
808            return Err(EinsumError::DimensionMismatch {
809                axis: format!("{:?}", id),
810                dim_a: db,
811                dim_b: dc,
812            });
813        }
814    }
815
816    Ok(())
817}
818
819#[cfg(test)]
820mod tests {
821    use super::*;
822    use strided_view::StridedArray;
823
824    #[test]
825    fn test_matmul_ij_jk_ik() {
826        // C_ik = A_ij * B_jk
827        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
828            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
829        });
830        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
831            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
832        });
833        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
834
835        einsum2_into(
836            c.view_mut(),
837            &a.view(),
838            &b.view(),
839            &['i', 'k'],
840            &['i', 'j'],
841            &['j', 'k'],
842            1.0,
843            0.0,
844        )
845        .unwrap();
846
847        assert_eq!(c.get(&[0, 0]), 19.0);
848        assert_eq!(c.get(&[0, 1]), 22.0);
849        assert_eq!(c.get(&[1, 0]), 43.0);
850        assert_eq!(c.get(&[1, 1]), 50.0);
851    }
852
853    #[test]
854    fn test_matmul_rect() {
855        // A: 2x3, B: 3x4, C: 2x4
856        let a =
857            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
858        let b =
859            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
860        let mut c = StridedArray::<f64>::row_major(&[2, 4]);
861
862        einsum2_into(
863            c.view_mut(),
864            &a.view(),
865            &b.view(),
866            &['i', 'k'],
867            &['i', 'j'],
868            &['j', 'k'],
869            1.0,
870            0.0,
871        )
872        .unwrap();
873
874        // A = [[1,2,3],[4,5,6]], B = [[1,2,3,4],[5,6,7,8],[9,10,11,12]]
875        assert_eq!(c.get(&[0, 0]), 38.0);
876        assert_eq!(c.get(&[1, 3]), 128.0);
877    }
878
879    #[test]
880    fn test_batched_matmul() {
881        // C_bik = A_bij * B_bjk
882        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
883            (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
884        });
885        let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
886            (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
887        });
888        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
889
890        einsum2_into(
891            c.view_mut(),
892            &a.view(),
893            &b.view(),
894            &['b', 'i', 'k'],
895            &['b', 'i', 'j'],
896            &['b', 'j', 'k'],
897            1.0,
898            0.0,
899        )
900        .unwrap();
901
902        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
903        // C0[0,0] = 1*1+2*3+3*5 = 22
904        assert_eq!(c.get(&[0, 0, 0]), 22.0);
905    }
906
907    #[test]
908    fn test_batched_matmul_col_major_output() {
909        // C_bik = A_bij * B_bjk with col-major output (same layout as opteinsum)
910        let a_data = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
911        let b_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
912        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
913        let b = StridedArray::<f64>::from_parts(b_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
914        let mut c = StridedArray::<f64>::col_major(&[2, 2, 2]);
915
916        einsum2_into(
917            c.view_mut(),
918            &a.view(),
919            &b.view(),
920            &['b', 'i', 'k'],
921            &['b', 'i', 'j'],
922            &['b', 'j', 'k'],
923            1.0,
924            0.0,
925        )
926        .unwrap();
927
928        // batch 0: I * [[1,2],[3,4]] = [[1,2],[3,4]]
929        assert_eq!(c.get(&[0, 0, 0]), 1.0);
930        assert_eq!(c.get(&[0, 0, 1]), 2.0);
931        assert_eq!(c.get(&[0, 1, 0]), 3.0);
932        assert_eq!(c.get(&[0, 1, 1]), 4.0);
933        // batch 1: 2I * [[5,6],[7,8]] = [[10,12],[14,16]]
934        assert_eq!(c.get(&[1, 0, 0]), 10.0);
935        assert_eq!(c.get(&[1, 1, 1]), 16.0);
936    }
937
938    #[test]
939    fn test_outer_product() {
940        // C_ij = A_i * B_j
941        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
942        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
943        let mut c = StridedArray::<f64>::row_major(&[3, 4]);
944
945        einsum2_into(
946            c.view_mut(),
947            &a.view(),
948            &b.view(),
949            &['i', 'j'],
950            &['i'],
951            &['j'],
952            1.0,
953            0.0,
954        )
955        .unwrap();
956
957        assert_eq!(c.get(&[0, 0]), 1.0);
958        assert_eq!(c.get(&[2, 3]), 12.0);
959    }
960
961    #[test]
962    fn test_dot_product() {
963        // C = A_i * B_i (scalar output)
964        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
965        let b = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
966        let mut c = StridedArray::<f64>::row_major(&[]);
967
968        einsum2_into(
969            c.view_mut(),
970            &a.view(),
971            &b.view(),
972            &[] as &[char],
973            &['i'],
974            &['i'],
975            1.0,
976            0.0,
977        )
978        .unwrap();
979
980        // 1*1 + 2*2 + 3*3 = 14
981        assert_eq!(c.get(&[]), 14.0);
982    }
983
984    #[test]
985    fn test_alpha_beta() {
986        // C = 2*A*B + 3*C_old
987        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
988            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
989        });
990        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
991            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
992        });
993        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
994            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
995        });
996
997        einsum2_into(
998            c.view_mut(),
999            &a.view(),
1000            &b.view(),
1001            &['i', 'k'],
1002            &['i', 'j'],
1003            &['j', 'k'],
1004            2.0,
1005            3.0,
1006        )
1007        .unwrap();
1008
1009        // C = 2*I*B + 3*C_old
1010        assert_eq!(c.get(&[0, 0]), 32.0); // 2*1 + 3*10
1011        assert_eq!(c.get(&[1, 1]), 128.0); // 2*4 + 3*40
1012    }
1013
1014    #[test]
1015    fn test_transposed_output() {
1016        // C_ki = A_ij * B_jk (output transposed)
1017        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1018            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1019        });
1020        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1021            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1022        });
1023        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1024
1025        einsum2_into(
1026            c.view_mut(),
1027            &a.view(),
1028            &b.view(),
1029            &['k', 'i'], // C indexed as (k, i) instead of (i, k)
1030            &['i', 'j'],
1031            &['j', 'k'],
1032            1.0,
1033            0.0,
1034        )
1035        .unwrap();
1036
1037        // Normal matmul result: C_ik = [[19,22],[43,50]]
1038        // But C is indexed as (k, i), so C[k, i] = (A*B)[i, k]
1039        assert_eq!(c.get(&[0, 0]), 19.0); // C[k=0, i=0]
1040        assert_eq!(c.get(&[0, 1]), 43.0); // C[k=0, i=1]
1041        assert_eq!(c.get(&[1, 0]), 22.0); // C[k=1, i=0]
1042        assert_eq!(c.get(&[1, 1]), 50.0); // C[k=1, i=1]
1043    }
1044
1045    #[test]
1046    fn test_left_trace() {
1047        // C_k = sum_j (sum_i A_ij) * B_jk
1048        // left_trace=[i], sum=[j], ro=[k]
1049        let a =
1050            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1051        // A = [[1,2,3],[4,5,6]]
1052        // sum over i: [5, 7, 9]
1053        let b =
1054            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1055        // B = [[1,2],[3,4],[5,6]]
1056        let mut c = StridedArray::<f64>::row_major(&[2]);
1057
1058        einsum2_into(
1059            c.view_mut(),
1060            &a.view(),
1061            &b.view(),
1062            &['k'],
1063            &['i', 'j'],
1064            &['j', 'k'],
1065            1.0,
1066            0.0,
1067        )
1068        .unwrap();
1069
1070        // C_k = sum_j [5,7,9][j] * B[j,k]
1071        // C[0] = 5*1 + 7*3 + 9*5 = 5 + 21 + 45 = 71
1072        // C[1] = 5*2 + 7*4 + 9*6 = 10 + 28 + 54 = 92
1073        assert_eq!(c.get(&[0]), 71.0);
1074        assert_eq!(c.get(&[1]), 92.0);
1075    }
1076
1077    #[test]
1078    fn test_u32_labels() {
1079        // Same as matmul but with u32 labels
1080        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1081            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1082        });
1083        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1084            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1085        });
1086        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1087
1088        einsum2_into(
1089            c.view_mut(),
1090            &a.view(),
1091            &b.view(),
1092            &[0u32, 2],
1093            &[0u32, 1],
1094            &[1u32, 2],
1095            1.0,
1096            0.0,
1097        )
1098        .unwrap();
1099
1100        assert_eq!(c.get(&[0, 0]), 19.0);
1101        assert_eq!(c.get(&[1, 1]), 50.0);
1102    }
1103
1104    #[test]
1105    fn test_complex_matmul() {
1106        use num_complex::Complex64;
1107        let i = Complex64::i();
1108
1109        // A = [[1+i, 2], [3, 4-i]]
1110        let a_vals = [
1111            [1.0 + i, Complex64::new(2.0, 0.0)],
1112            [Complex64::new(3.0, 0.0), 4.0 - i],
1113        ];
1114        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1115
1116        // B = [[1, i], [0, 1]]
1117        let b_vals = [
1118            [Complex64::new(1.0, 0.0), i],
1119            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
1120        ];
1121        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1122
1123        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1124
1125        einsum2_into(
1126            c.view_mut(),
1127            &a.view(),
1128            &b.view(),
1129            &['i', 'k'],
1130            &['i', 'j'],
1131            &['j', 'k'],
1132            Complex64::new(1.0, 0.0),
1133            Complex64::new(0.0, 0.0),
1134        )
1135        .unwrap();
1136
1137        // C = A * B
1138        // C[0,0] = (1+i)*1 + 2*0 = 1+i
1139        // C[0,1] = (1+i)*i + 2*1 = i+i²+2 = i-1+2 = 1+i
1140        // C[1,0] = 3*1 + (4-i)*0 = 3
1141        // C[1,1] = 3*i + (4-i)*1 = 3i+4-i = 4+2i
1142        assert_eq!(c.get(&[0, 0]), 1.0 + i);
1143        assert_eq!(c.get(&[0, 1]), 1.0 + i);
1144        assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1145        assert_eq!(c.get(&[1, 1]), 4.0 + 2.0 * i);
1146    }
1147
1148    #[test]
1149    fn test_complex_matmul_with_conj() {
1150        use num_complex::Complex64;
1151        let i = Complex64::i();
1152
1153        // A = [[1+i, 2i], [3, 4-i]]
1154        let a_vals = [[1.0 + i, 2.0 * i], [Complex64::new(3.0, 0.0), 4.0 - i]];
1155        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1156
1157        // B = identity
1158        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| {
1159            if idx[0] == idx[1] {
1160                Complex64::new(1.0, 0.0)
1161            } else {
1162                Complex64::new(0.0, 0.0)
1163            }
1164        });
1165
1166        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1167
1168        // C = conj(A) * B = conj(A)
1169        let a_conj = a.view().conj();
1170        einsum2_into(
1171            c.view_mut(),
1172            &a_conj,
1173            &b.view(),
1174            &['i', 'k'],
1175            &['i', 'j'],
1176            &['j', 'k'],
1177            Complex64::new(1.0, 0.0),
1178            Complex64::new(0.0, 0.0),
1179        )
1180        .unwrap();
1181
1182        // conj(A) = [[1-i, -2i], [3, 4+i]]
1183        assert_eq!(c.get(&[0, 0]), 1.0 - i);
1184        assert_eq!(c.get(&[0, 1]), -2.0 * i);
1185        assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1186        assert_eq!(c.get(&[1, 1]), 4.0 + i);
1187    }
1188
1189    #[test]
1190    fn test_complex_matmul_with_conj_both() {
1191        use num_complex::Complex64;
1192        let i = Complex64::i();
1193
1194        // A = [[1+i, 0], [0, 2-i]]
1195        let a_vals = [
1196            [1.0 + i, Complex64::new(0.0, 0.0)],
1197            [Complex64::new(0.0, 0.0), 2.0 - i],
1198        ];
1199        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1200
1201        // B = [[1, i], [0, 1+i]]
1202        let b_vals = [
1203            [Complex64::new(1.0, 0.0), i],
1204            [Complex64::new(0.0, 0.0), 1.0 + i],
1205        ];
1206        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1207
1208        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1209
1210        // C = conj(A) * conj(B)
1211        let a_conj = a.view().conj();
1212        let b_conj = b.view().conj();
1213        einsum2_into(
1214            c.view_mut(),
1215            &a_conj,
1216            &b_conj,
1217            &['i', 'k'],
1218            &['i', 'j'],
1219            &['j', 'k'],
1220            Complex64::new(1.0, 0.0),
1221            Complex64::new(0.0, 0.0),
1222        )
1223        .unwrap();
1224
1225        // conj(A) = [[1-i, 0], [0, 2+i]]
1226        // conj(B) = [[1, -i], [0, 1-i]]
1227        // C = conj(A) * conj(B)
1228        // C[0,0] = (1-i)*1 + 0*0 = 1-i
1229        // C[0,1] = (1-i)*(-i) + 0*(1-i) = -i+i² = -i-1 = -(1+i)
1230        // C[1,0] = 0*1 + (2+i)*0 = 0
1231        // C[1,1] = 0*(-i) + (2+i)*(1-i) = 2-2i+i-i² = 2-i+1 = 3-i
1232        assert_eq!(c.get(&[0, 0]), 1.0 - i);
1233        assert_eq!(c.get(&[0, 1]), -(1.0 + i));
1234        assert_eq!(c.get(&[1, 0]), Complex64::new(0.0, 0.0));
1235        assert_eq!(c.get(&[1, 1]), 3.0 - i);
1236    }
1237
1238    #[test]
1239    fn test_elementwise_hadamard() {
1240        // C_ijk = A_ijk * B_ijk — all batch, no contraction
1241        let a = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1242            (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64
1243        });
1244        let b = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1245            (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64 * 0.1
1246        });
1247        let mut c = StridedArray::<f64>::row_major(&[3, 4, 5]);
1248
1249        einsum2_into(
1250            c.view_mut(),
1251            &a.view(),
1252            &b.view(),
1253            &['i', 'j', 'k'],
1254            &['i', 'j', 'k'],
1255            &['i', 'j', 'k'],
1256            1.0,
1257            0.0,
1258        )
1259        .unwrap();
1260
1261        // Spot check: C[0,0,0] = 1 * 0.1 = 0.1
1262        assert!((c.get(&[0, 0, 0]) - 0.1).abs() < 1e-12);
1263        // C[2,3,4] = 60 * 6.0 = 360
1264        assert!((c.get(&[2, 3, 4]) - 360.0).abs() < 1e-10);
1265    }
1266
1267    #[test]
1268    fn test_elementwise_hadamard_with_alpha() {
1269        let a =
1270            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1271        let b =
1272            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1273        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1274
1275        einsum2_into(
1276            c.view_mut(),
1277            &a.view(),
1278            &b.view(),
1279            &['i', 'j'],
1280            &['i', 'j'],
1281            &['i', 'j'],
1282            2.0,
1283            0.0,
1284        )
1285        .unwrap();
1286
1287        // C[0,0] = 2.0 * 1 * 1 = 2.0
1288        assert_eq!(c.get(&[0, 0]), 2.0);
1289        // C[1,2] = 2.0 * 6 * 6 = 72.0
1290        assert_eq!(c.get(&[1, 2]), 72.0);
1291    }
1292
1293    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1294    #[test]
1295    fn test_einsum2_owned_matmul() {
1296        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1297            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1298        });
1299        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1300            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1301        });
1302        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1303
1304        einsum2_into_owned(
1305            c.view_mut(),
1306            a,
1307            b,
1308            &['i', 'k'],
1309            &['i', 'j'],
1310            &['j', 'k'],
1311            1.0,
1312            0.0,
1313            false,
1314            false,
1315        )
1316        .unwrap();
1317
1318        assert_eq!(c.get(&[0, 0]), 19.0);
1319        assert_eq!(c.get(&[0, 1]), 22.0);
1320        assert_eq!(c.get(&[1, 0]), 43.0);
1321        assert_eq!(c.get(&[1, 1]), 50.0);
1322    }
1323
1324    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1325    #[test]
1326    fn test_einsum2_owned_batched() {
1327        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
1328            (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
1329        });
1330        let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
1331            (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
1332        });
1333        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
1334
1335        einsum2_into_owned(
1336            c.view_mut(),
1337            a,
1338            b,
1339            &['b', 'i', 'k'],
1340            &['b', 'i', 'j'],
1341            &['b', 'j', 'k'],
1342            1.0,
1343            0.0,
1344            false,
1345            false,
1346        )
1347        .unwrap();
1348
1349        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
1350        // C0[0,0] = 1*1+2*3+3*5 = 22
1351        assert_eq!(c.get(&[0, 0, 0]), 22.0);
1352    }
1353
1354    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1355    #[test]
1356    fn test_einsum2_owned_alpha_beta() {
1357        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1358            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
1359        });
1360        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1361            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1362        });
1363        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1364            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
1365        });
1366
1367        einsum2_into_owned(
1368            c.view_mut(),
1369            a,
1370            b,
1371            &['i', 'k'],
1372            &['i', 'j'],
1373            &['j', 'k'],
1374            2.0,
1375            3.0,
1376            false,
1377            false,
1378        )
1379        .unwrap();
1380
1381        // C = 2*I*B + 3*C_old
1382        assert_eq!(c.get(&[0, 0]), 32.0); // 2*1 + 3*10
1383        assert_eq!(c.get(&[1, 1]), 128.0); // 2*4 + 3*40
1384    }
1385
1386    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1387    #[test]
1388    fn test_einsum2_owned_elementwise() {
1389        // All batch, no contraction -- element-wise fast path
1390        let a =
1391            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1392        let b =
1393            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1394        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1395
1396        einsum2_into_owned(
1397            c.view_mut(),
1398            a,
1399            b,
1400            &['i', 'j'],
1401            &['i', 'j'],
1402            &['i', 'j'],
1403            2.0,
1404            0.0,
1405            false,
1406            false,
1407        )
1408        .unwrap();
1409
1410        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1411        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1412    }
1413
1414    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1415    #[test]
1416    fn test_einsum2_owned_left_trace() {
1417        // C_k = sum_j (sum_i A_ij) * B_jk
1418        // left_trace=[i], sum=[j], ro=[k]
1419        let a =
1420            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1421        // A = [[1,2,3],[4,5,6]]
1422        // sum over i: [5, 7, 9]
1423        let b =
1424            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1425        // B = [[1,2],[3,4],[5,6]]
1426        let mut c = StridedArray::<f64>::row_major(&[2]);
1427
1428        einsum2_into_owned(
1429            c.view_mut(),
1430            a,
1431            b,
1432            &['k'],
1433            &['i', 'j'],
1434            &['j', 'k'],
1435            1.0,
1436            0.0,
1437            false,
1438            false,
1439        )
1440        .unwrap();
1441
1442        // C[0] = 5*1 + 7*3 + 9*5 = 5 + 21 + 45 = 71
1443        // C[1] = 5*2 + 7*4 + 9*6 = 10 + 28 + 54 = 92
1444        assert_eq!(c.get(&[0]), 71.0);
1445        assert_eq!(c.get(&[1]), 92.0);
1446    }
1447
1448    #[test]
1449    fn test_einsum2_naive_matmul() {
1450        // einsum2_naive_into with identity maps = regular matmul
1451        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1452            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1453        });
1454        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1455            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1456        });
1457        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1458
1459        einsum2_naive_into(
1460            c.view_mut(),
1461            &a.view(),
1462            &b.view(),
1463            &['i', 'k'],
1464            &['i', 'j'],
1465            &['j', 'k'],
1466            1.0,
1467            0.0,
1468            |x| x,
1469            |x| x,
1470        )
1471        .unwrap();
1472
1473        assert_eq!(c.get(&[0, 0]), 19.0);
1474        assert_eq!(c.get(&[0, 1]), 22.0);
1475        assert_eq!(c.get(&[1, 0]), 43.0);
1476        assert_eq!(c.get(&[1, 1]), 50.0);
1477    }
1478
1479    #[test]
1480    fn test_einsum2_naive_elementwise() {
1481        // Element-wise (all batch) with identity maps
1482        let a =
1483            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1484        let b =
1485            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1486        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1487
1488        einsum2_naive_into(
1489            c.view_mut(),
1490            &a.view(),
1491            &b.view(),
1492            &['i', 'j'],
1493            &['i', 'j'],
1494            &['i', 'j'],
1495            2.0,
1496            0.0,
1497            |x| x,
1498            |x| x,
1499        )
1500        .unwrap();
1501
1502        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1503        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1504    }
1505
1506    #[test]
1507    fn test_einsum2_naive_custom_type() {
1508        // Custom scalar type that does NOT implement ElementOpApply.
1509        // This demonstrates that einsum2_naive_into works with custom types.
1510        use num_traits::{One, Zero};
1511
1512        #[derive(Debug, Clone, Copy, PartialEq)]
1513        struct MyVal(f64);
1514
1515        impl Default for MyVal {
1516            fn default() -> Self {
1517                MyVal(0.0)
1518            }
1519        }
1520
1521        impl std::ops::Add for MyVal {
1522            type Output = Self;
1523            fn add(self, rhs: Self) -> Self {
1524                MyVal(self.0 + rhs.0)
1525            }
1526        }
1527
1528        impl std::ops::Mul for MyVal {
1529            type Output = Self;
1530            fn mul(self, rhs: Self) -> Self {
1531                MyVal(self.0 * rhs.0)
1532            }
1533        }
1534
1535        impl Zero for MyVal {
1536            fn zero() -> Self {
1537                MyVal(0.0)
1538            }
1539            fn is_zero(&self) -> bool {
1540                self.0 == 0.0
1541            }
1542        }
1543
1544        impl One for MyVal {
1545            fn one() -> Self {
1546                MyVal(1.0)
1547            }
1548        }
1549
1550        // C_ik = A_ij * B_jk (matrix multiply with custom type)
1551        let a = StridedArray::from_parts(
1552            vec![MyVal(1.0), MyVal(2.0), MyVal(3.0), MyVal(4.0)],
1553            &[2, 2],
1554            &[2, 1],
1555            0,
1556        )
1557        .unwrap();
1558        let b = StridedArray::from_parts(
1559            vec![MyVal(5.0), MyVal(6.0), MyVal(7.0), MyVal(8.0)],
1560            &[2, 2],
1561            &[2, 1],
1562            0,
1563        )
1564        .unwrap();
1565        let mut c = StridedArray::<MyVal>::col_major(&[2, 2]);
1566
1567        einsum2_naive_into(
1568            c.view_mut(),
1569            &a.view(),
1570            &b.view(),
1571            &['i', 'k'],
1572            &['i', 'j'],
1573            &['j', 'k'],
1574            MyVal(1.0),
1575            MyVal(0.0),
1576            |x| x,
1577            |x| x,
1578        )
1579        .unwrap();
1580
1581        // C = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
1582        //   = [[19, 22], [43, 50]]
1583        assert_eq!(c.get(&[0, 0]), MyVal(19.0));
1584        assert_eq!(c.get(&[0, 1]), MyVal(22.0));
1585        assert_eq!(c.get(&[1, 0]), MyVal(43.0));
1586        assert_eq!(c.get(&[1, 1]), MyVal(50.0));
1587    }
1588
1589    #[test]
1590    fn test_einsum2_naive_left_trace() {
1591        // C_k = sum_j (sum_i A_ij) * B_jk with map_a = identity
1592        let a =
1593            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1594        let b =
1595            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1596        let mut c = StridedArray::<f64>::row_major(&[2]);
1597
1598        einsum2_naive_into(
1599            c.view_mut(),
1600            &a.view(),
1601            &b.view(),
1602            &['k'],
1603            &['i', 'j'],
1604            &['j', 'k'],
1605            1.0,
1606            0.0,
1607            |x| x,
1608            |x| x,
1609        )
1610        .unwrap();
1611
1612        assert_eq!(c.get(&[0]), 71.0);
1613        assert_eq!(c.get(&[1]), 92.0);
1614    }
1615
1616    /// Test backend that delegates to naive GEMM loops.
1617    struct TestNaiveBackend;
1618
1619    impl Backend<f64> for TestNaiveBackend {
1620        const MATERIALIZES_CONJ: bool = false;
1621        const REQUIRES_UNIT_STRIDE: bool = false;
1622
1623        fn bgemm_contiguous_into(
1624            c: &mut contiguous::ContiguousOperandMut<f64>,
1625            a: &contiguous::ContiguousOperand<f64>,
1626            b: &contiguous::ContiguousOperand<f64>,
1627            batch_dims: &[usize],
1628            m: usize,
1629            n: usize,
1630            k: usize,
1631            alpha: f64,
1632            beta: f64,
1633        ) -> strided_view::Result<()> {
1634            // Delegate to the contiguous naive GEMM loop
1635            let a_ptr = a.ptr();
1636            let b_ptr = b.ptr();
1637            let c_ptr = c.ptr();
1638            let a_rs = a.row_stride();
1639            let a_cs = a.col_stride();
1640            let b_rs = b.row_stride();
1641            let b_cs = b.col_stride();
1642            let c_rs = c.row_stride();
1643            let c_cs = c.col_stride();
1644
1645            let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1646            while batch_idx.next().is_some() {
1647                let a_base = batch_idx.offset(a.batch_strides());
1648                let b_base = batch_idx.offset(b.batch_strides());
1649                let c_base = batch_idx.offset(c.batch_strides());
1650
1651                for i in 0..m {
1652                    for j in 0..n {
1653                        let mut acc = 0.0f64;
1654                        for l in 0..k {
1655                            let a_val = unsafe {
1656                                *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1657                            };
1658                            let b_val = unsafe {
1659                                *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1660                            };
1661                            acc += a_val * b_val;
1662                        }
1663                        unsafe {
1664                            let c_elem =
1665                                c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1666                            if beta == 0.0 {
1667                                *c_elem = alpha * acc;
1668                            } else {
1669                                *c_elem = alpha * acc + beta * (*c_elem);
1670                            }
1671                        }
1672                    }
1673                }
1674            }
1675            Ok(())
1676        }
1677    }
1678
1679    #[test]
1680    fn test_einsum2_with_backend_matmul() {
1681        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1682            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1683        });
1684        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1685            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1686        });
1687        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1688
1689        einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1690            c.view_mut(),
1691            &a.view(),
1692            &b.view(),
1693            &['i', 'k'],
1694            &['i', 'j'],
1695            &['j', 'k'],
1696            1.0,
1697            0.0,
1698        )
1699        .unwrap();
1700
1701        assert_eq!(c.get(&[0, 0]), 19.0);
1702        assert_eq!(c.get(&[0, 1]), 22.0);
1703        assert_eq!(c.get(&[1, 0]), 43.0);
1704        assert_eq!(c.get(&[1, 1]), 50.0);
1705    }
1706
1707    #[test]
1708    fn test_einsum2_with_backend_elementwise() {
1709        let a =
1710            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1711        let b =
1712            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1713        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1714
1715        einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1716            c.view_mut(),
1717            &a.view(),
1718            &b.view(),
1719            &['i', 'j'],
1720            &['i', 'j'],
1721            &['i', 'j'],
1722            2.0,
1723            0.0,
1724        )
1725        .unwrap();
1726
1727        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1728        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1729    }
1730
1731    #[test]
1732    fn test_einsum2_with_backend_custom_type() {
1733        use num_traits::{One, Zero};
1734
1735        #[derive(Debug, Clone, Copy, PartialEq)]
1736        struct Tropical(f64);
1737
1738        impl Default for Tropical {
1739            fn default() -> Self {
1740                Tropical(0.0)
1741            }
1742        }
1743
1744        impl std::ops::Add for Tropical {
1745            type Output = Self;
1746            fn add(self, rhs: Self) -> Self {
1747                Tropical(self.0 + rhs.0)
1748            }
1749        }
1750
1751        impl std::ops::Mul for Tropical {
1752            type Output = Self;
1753            fn mul(self, rhs: Self) -> Self {
1754                Tropical(self.0 * rhs.0)
1755            }
1756        }
1757
1758        impl Zero for Tropical {
1759            fn zero() -> Self {
1760                Tropical(0.0)
1761            }
1762            fn is_zero(&self) -> bool {
1763                self.0 == 0.0
1764            }
1765        }
1766
1767        impl One for Tropical {
1768            fn one() -> Self {
1769                Tropical(1.0)
1770            }
1771        }
1772
1773        struct TropicalBackend;
1774
1775        impl Backend<Tropical> for TropicalBackend {
1776            const MATERIALIZES_CONJ: bool = false;
1777            const REQUIRES_UNIT_STRIDE: bool = false;
1778
1779            fn bgemm_contiguous_into(
1780                c: &mut contiguous::ContiguousOperandMut<Tropical>,
1781                a: &contiguous::ContiguousOperand<Tropical>,
1782                b: &contiguous::ContiguousOperand<Tropical>,
1783                batch_dims: &[usize],
1784                m: usize,
1785                n: usize,
1786                k: usize,
1787                alpha: Tropical,
1788                beta: Tropical,
1789            ) -> strided_view::Result<()> {
1790                // Simple naive GEMM for Tropical type
1791                let a_ptr = a.ptr();
1792                let b_ptr = b.ptr();
1793                let c_ptr = c.ptr();
1794                let a_rs = a.row_stride();
1795                let a_cs = a.col_stride();
1796                let b_rs = b.row_stride();
1797                let b_cs = b.col_stride();
1798                let c_rs = c.row_stride();
1799                let c_cs = c.col_stride();
1800
1801                let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1802                while batch_idx.next().is_some() {
1803                    let a_base = batch_idx.offset(a.batch_strides());
1804                    let b_base = batch_idx.offset(b.batch_strides());
1805                    let c_base = batch_idx.offset(c.batch_strides());
1806
1807                    for i in 0..m {
1808                        for j in 0..n {
1809                            let mut acc = Tropical::zero();
1810                            for l in 0..k {
1811                                let a_val = unsafe {
1812                                    *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1813                                };
1814                                let b_val = unsafe {
1815                                    *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1816                                };
1817                                acc = acc + a_val * b_val;
1818                            }
1819                            unsafe {
1820                                let c_elem =
1821                                    c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1822                                if beta == Tropical::zero() {
1823                                    *c_elem = alpha * acc;
1824                                } else {
1825                                    *c_elem = alpha * acc + beta * (*c_elem);
1826                                }
1827                            }
1828                        }
1829                    }
1830                }
1831                Ok(())
1832            }
1833        }
1834
1835        let a = StridedArray::from_parts(
1836            vec![Tropical(1.0), Tropical(2.0), Tropical(3.0), Tropical(4.0)],
1837            &[2, 2],
1838            &[2, 1],
1839            0,
1840        )
1841        .unwrap();
1842        let b = StridedArray::from_parts(
1843            vec![Tropical(5.0), Tropical(6.0), Tropical(7.0), Tropical(8.0)],
1844            &[2, 2],
1845            &[2, 1],
1846            0,
1847        )
1848        .unwrap();
1849        let mut c = StridedArray::<Tropical>::col_major(&[2, 2]);
1850
1851        einsum2_with_backend_into::<_, TropicalBackend, _>(
1852            c.view_mut(),
1853            &a.view(),
1854            &b.view(),
1855            &['i', 'k'],
1856            &['i', 'j'],
1857            &['j', 'k'],
1858            Tropical(1.0),
1859            Tropical(0.0),
1860        )
1861        .unwrap();
1862
1863        // C = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
1864        //   = [[19, 22], [43, 50]]
1865        assert_eq!(c.get(&[0, 0]), Tropical(19.0));
1866        assert_eq!(c.get(&[0, 1]), Tropical(22.0));
1867        assert_eq!(c.get(&[1, 0]), Tropical(43.0));
1868        assert_eq!(c.get(&[1, 1]), Tropical(50.0));
1869    }
1870}