strided_einsum2/
lib.rs

1//! Binary Einstein summation on strided views.
2//!
3//! Provides `einsum2_into` for computing binary tensor contractions with
4//! accumulation semantics: `C = alpha * A * B + beta * C`.
5//!
6//! # Example
7//!
8//! ```
9//! use strided_view::StridedArray;
10//! use strided_einsum2::einsum2_into;
11//!
12//! // Matrix multiply: C_ik = A_ij * B_jk
13//! let a = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
14//! let b = StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
15//! let mut c = StridedArray::<f64>::row_major(&[2, 2]);
16//!
17//! einsum2_into(
18//!     c.view_mut(), &a.view(), &b.view(),
19//!     &['i', 'k'], &['i', 'j'], &['j', 'k'],
20//!     1.0, 0.0,
21//! ).unwrap();
22//! ```
23
24#[cfg(all(feature = "faer", feature = "blas"))]
25compile_error!("Features `faer` and `blas` are mutually exclusive. Use one or the other.");
26
27#[cfg(all(feature = "faer", feature = "blas-inject"))]
28compile_error!("Features `faer` and `blas-inject` are mutually exclusive.");
29
30#[cfg(all(feature = "blas", feature = "blas-inject"))]
31compile_error!("Features `blas` and `blas-inject` are mutually exclusive.");
32
33#[cfg(all(feature = "blas-inject", not(feature = "blas")))]
34extern crate cblas_inject as cblas_sys;
35#[cfg(all(feature = "blas", not(feature = "blas-inject")))]
36extern crate cblas_sys;
37
38#[cfg(all(
39    not(feature = "faer"),
40    any(
41        all(feature = "blas", not(feature = "blas-inject")),
42        all(feature = "blas-inject", not(feature = "blas"))
43    )
44))]
45pub mod bgemm_blas;
46
47#[cfg(feature = "faer")]
48/// Batched GEMM backend using the [`faer`] library.
49pub mod bgemm_faer;
50/// Batched GEMM fallback using explicit loops.
51pub mod bgemm_naive;
52/// GEMM-ready operand types and preparation functions for contiguous data.
53pub mod contiguous;
54/// Contraction planning: axis classification and permutation computation.
55pub mod plan;
56/// Trace-axis reduction (summing axes that appear only in one operand).
57pub mod trace;
58/// Shared helpers (permutation inversion, multi-index iteration, dimension fusion).
59pub mod util;
60
61/// Backend abstraction for batched GEMM dispatch.
62pub mod backend;
63
64use std::any::TypeId;
65use std::fmt::Debug;
66use std::hash::Hash;
67
68use strided_kernel::zip_map2_into;
69#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
70use strided_view::StridedArray;
71use strided_view::{Adjoint, Conj, ElementOp, ElementOpApply, StridedView, StridedViewMut};
72
73pub use strided_traits::ScalarBase;
74
75pub use backend::{BackendConfig, BgemmBackend};
76pub use plan::Einsum2Plan;
77
/// Marker trait for types that can be used as axis labels.
///
/// It is blanket-implemented for every type that is
/// `Clone + Eq + Hash + Debug`, so it never has to be implemented by hand.
pub trait AxisId: Clone + Eq + Hash + Debug {}

impl<T> AxisId for T where T: Clone + Eq + Hash + Debug {}
81
82// ScalarBase is re-exported from strided_traits (see above).
83// It no longer requires ElementOpApply, enabling custom scalar types
84// (e.g., tropical semiring) to work with Identity-only views.
85
86/// Trait alias for element types supported by einsum operations.
87///
88/// When the `faer` feature is enabled, this additionally requires `faer::ComplexField`
89/// so that the faer GEMM backend can be used.
90#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
91pub trait Scalar: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
92
93#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
94impl<T> Scalar for T where T: ScalarBase + ElementOpApply + faer_traits::ComplexField {}
95
96/// Trait alias for element types (with `blas` or `blas-inject` feature).
97///
98/// Includes `BlasGemm` so that all `Scalar` types can be dispatched to CBLAS.
99#[cfg(all(
100    not(feature = "faer"),
101    any(
102        all(feature = "blas", not(feature = "blas-inject")),
103        all(feature = "blas-inject", not(feature = "blas"))
104    )
105))]
106pub trait Scalar: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
107
108#[cfg(all(
109    not(feature = "faer"),
110    any(
111        all(feature = "blas", not(feature = "blas-inject")),
112        all(feature = "blas-inject", not(feature = "blas"))
113    )
114))]
115impl<T> Scalar for T where T: ScalarBase + ElementOpApply + bgemm_blas::BlasGemm {}
116
117/// Trait alias for element types (without `faer` or BLAS features).
118#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
119pub trait Scalar: ScalarBase + ElementOpApply {}
120
121#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
122impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
123
124/// Placeholder trait definition for invalid mutually-exclusive feature combinations.
125///
126/// The crate emits `compile_error!` above for these combinations. This trait only
127/// avoids cascading type-resolution errors so users see the intended diagnostics.
128#[cfg(any(
129    all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
130    all(feature = "blas", feature = "blas-inject")
131))]
132pub trait Scalar: ScalarBase + ElementOpApply {}
133
134#[cfg(any(
135    all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
136    all(feature = "blas", feature = "blas-inject")
137))]
138impl<T> Scalar for T where T: ScalarBase + ElementOpApply {}
139
140/// Errors specific to einsum operations.
141#[derive(Debug, thiserror::Error)]
142pub enum EinsumError {
143    #[error("duplicate axis label: {0}")]
144    DuplicateAxis(String),
145    #[error("output axis {0} not found in any input")]
146    OrphanOutputAxis(String),
147    #[error("dimension mismatch for axis {axis:?}: {dim_a} vs {dim_b}")]
148    DimensionMismatch {
149        axis: String,
150        dim_a: usize,
151        dim_b: usize,
152    },
153    #[error(transparent)]
154    Strided(#[from] strided_view::StridedError),
155}
156
157/// Convenience alias for `Result<T, EinsumError>`.
158pub type Result<T> = std::result::Result<T, EinsumError>;
159
160/// Returns `true` if the given `ElementOp` type represents conjugation.
161///
162/// - `Identity` / `Transpose` → `false` (no per-element conjugation)
163/// - `Conj` / `Adjoint` → `true` (per-element conjugation needed)
164///
165/// For scalar types, `Transpose::apply(x) = x` (identity) and the dimension
166/// swap is already reflected in the view's strides/dims.  Similarly,
167/// `Adjoint::apply(x) = x.conj()` with the dimension swap in the view.
168fn op_is_conj<Op: 'static>() -> bool {
169    TypeId::of::<Op>() == TypeId::of::<Conj>() || TypeId::of::<Op>() == TypeId::of::<Adjoint>()
170}
171
172/// Binary einsum contraction: `C = alpha * contract(A, B) + beta * C`.
173///
174/// `ic`, `ia`, `ib` are axis labels for C, A, B respectively.
175/// Axes are classified as:
176/// - **batch**: in A, B, and C
177/// - **lo** (left-output): in A and C, not B
178/// - **ro** (right-output): in B and C, not A
179/// - **sum** (contraction): in A and B, not C
180/// - **left_trace**: only in A (summed out before contraction)
181/// - **right_trace**: only in B (summed out before contraction)
182pub fn einsum2_into<T: Scalar, OpA, OpB, ID: AxisId>(
183    c: StridedViewMut<T>,
184    a: &StridedView<T, OpA>,
185    b: &StridedView<T, OpB>,
186    ic: &[ID],
187    ia: &[ID],
188    ib: &[ID],
189    alpha: T,
190    beta: T,
191) -> Result<()>
192where
193    OpA: ElementOp<T> + 'static,
194    OpB: ElementOp<T> + 'static,
195{
196    // 1. Build plan
197    let plan = Einsum2Plan::new(ia, ib, ic)?;
198
199    // 2. Validate dimension consistency across operands
200    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
201
202    // 3. Reduce trace axes if present; determine conjugation flags.
203    //    When trace reduction occurs, Op is already applied during the reduction,
204    //    so conj flag is false. Otherwise, we strip the Op and pass a conj flag
205    //    to the GEMM kernel (avoiding materialization).
206    let (a_buf, conj_a) = if !plan.left_trace.is_empty() {
207        let trace_indices = plan.left_trace_indices(ia);
208        (Some(trace::reduce_trace_axes(a, &trace_indices)?), false)
209    } else {
210        (None, op_is_conj::<OpA>())
211    };
212
213    let a_view: StridedView<T> = match a_buf.as_ref() {
214        Some(buf) => buf.view(),
215        None => StridedView::new(a.data(), a.dims(), a.strides(), a.offset())
216            .expect("strip_op_view: metadata already validated"),
217    };
218
219    let (b_buf, conj_b) = if !plan.right_trace.is_empty() {
220        let trace_indices = plan.right_trace_indices(ib);
221        (Some(trace::reduce_trace_axes(b, &trace_indices)?), false)
222    } else {
223        (None, op_is_conj::<OpB>())
224    };
225
226    let b_view: StridedView<T> = match b_buf.as_ref() {
227        Some(buf) => buf.view(),
228        None => StridedView::new(b.data(), b.dims(), b.strides(), b.offset())
229            .expect("strip_op_view: metadata already validated"),
230    };
231
232    // 4. Dispatch to GEMM
233    #[cfg(any(
234        all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
235        all(
236            not(feature = "faer"),
237            any(
238                all(feature = "blas", not(feature = "blas-inject")),
239                all(feature = "blas-inject", not(feature = "blas"))
240            )
241        )
242    ))]
243    einsum2_gemm_dispatch(c, &a_view, &b_view, &plan, alpha, beta, conj_a, conj_b)?;
244
245    #[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
246    {
247        let a_perm = a_view.permute(&plan.left_perm)?;
248        let b_perm = b_view.permute(&plan.right_perm)?;
249        let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
250
251        if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
252            let mul_fn = move |a_val: T, b_val: T| -> T {
253                let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
254                let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
255                alpha * a_c * b_c
256            };
257            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
258            return Ok(());
259        }
260
261        bgemm_naive::bgemm_strided_into(
262            &mut c_perm,
263            &a_perm,
264            &b_perm,
265            plan.batch.len(),
266            plan.lo.len(),
267            plan.ro.len(),
268            plan.sum.len(),
269            alpha,
270            beta,
271            conj_a,
272            conj_b,
273        )?;
274    }
275
276    Ok(())
277}
278
279/// Binary einsum for custom scalar types using naive GEMM.
280///
281/// Like [`einsum2_into`] but works with any `ScalarBase` type (no `ElementOpApply` required).
282/// Uses closures `map_a` and `map_b` for per-element transformation instead of
283/// conjugation flags. Always dispatches to the naive GEMM kernel.
284///
285/// Views must use `Identity` element operations (the default).
286pub fn einsum2_naive_into<T, ID, MapA, MapB>(
287    c: StridedViewMut<T>,
288    a: &StridedView<T>,
289    b: &StridedView<T>,
290    ic: &[ID],
291    ia: &[ID],
292    ib: &[ID],
293    alpha: T,
294    beta: T,
295    map_a: MapA,
296    map_b: MapB,
297) -> Result<()>
298where
299    T: ScalarBase,
300    ID: AxisId,
301    MapA: Fn(T) -> T,
302    MapB: Fn(T) -> T,
303{
304    let plan = Einsum2Plan::new(ia, ib, ic)?;
305    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
306
307    // Reduce trace axes if present.
308    // When trace reduction occurs, map is applied via map_into before reduction,
309    // so we use identity map for GEMM. Otherwise, pass through the original map.
310    let (a_buf, use_map_a) = if !plan.left_trace.is_empty() {
311        let trace_indices = plan.left_trace_indices(ia);
312        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(a.dims()) };
313        strided_kernel::map_into(&mut mapped.view_mut(), a, &map_a)?;
314        let reduced = trace::reduce_trace_axes(&mapped.view(), &trace_indices)?;
315        (Some(reduced), false)
316    } else {
317        (None, true)
318    };
319    let a_view: StridedView<T> = match a_buf.as_ref() {
320        Some(buf) => buf.view(),
321        None => a.clone(),
322    };
323
324    let (b_buf, use_map_b) = if !plan.right_trace.is_empty() {
325        let trace_indices = plan.right_trace_indices(ib);
326        let mut mapped = unsafe { strided_view::StridedArray::<T>::col_major_uninit(b.dims()) };
327        strided_kernel::map_into(&mut mapped.view_mut(), b, &map_b)?;
328        let reduced = trace::reduce_trace_axes(&mapped.view(), &trace_indices)?;
329        (Some(reduced), false)
330    } else {
331        (None, true)
332    };
333    let b_view: StridedView<T> = match b_buf.as_ref() {
334        Some(buf) => buf.view(),
335        None => b.clone(),
336    };
337
338    let a_perm = a_view.permute(&plan.left_perm)?;
339    let b_perm = b_view.permute(&plan.right_perm)?;
340    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
341
342    // Element-wise fast path
343    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
344        let mul_fn = move |a_val: T, b_val: T| -> T {
345            let a_c = if use_map_a { map_a(a_val) } else { a_val };
346            let b_c = if use_map_b { map_b(b_val) } else { b_val };
347            alpha * a_c * b_c
348        };
349        zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
350        return Ok(());
351    }
352
353    let final_map_a: Box<dyn Fn(T) -> T> = if use_map_a {
354        Box::new(map_a)
355    } else {
356        Box::new(|x| x)
357    };
358    let final_map_b: Box<dyn Fn(T) -> T> = if use_map_b {
359        Box::new(map_b)
360    } else {
361        Box::new(|x| x)
362    };
363
364    bgemm_naive::bgemm_strided_into_with_map(
365        &mut c_perm,
366        &a_perm,
367        &b_perm,
368        plan.batch.len(),
369        plan.lo.len(),
370        plan.ro.len(),
371        plan.sum.len(),
372        alpha,
373        beta,
374        final_map_a,
375        final_map_b,
376    )?;
377
378    Ok(())
379}
380
381/// Binary einsum with a pluggable GEMM backend.
382///
383/// Like [`einsum2_into`] but works with any `ScalarBase` type and dispatches
384/// to the caller-provided GEMM backend `B`. Views must use `Identity` element
385/// operations (the default).
386///
387/// External crates can implement [`BgemmBackend`] for custom scalar types
388/// (e.g., tropical semiring) and pass the backend here.
389pub fn einsum2_with_backend_into<T, B, ID>(
390    c: StridedViewMut<T>,
391    a: &StridedView<T>,
392    b: &StridedView<T>,
393    ic: &[ID],
394    ia: &[ID],
395    ib: &[ID],
396    alpha: T,
397    beta: T,
398) -> Result<()>
399where
400    T: ScalarBase,
401    B: BgemmBackend<T> + BackendConfig,
402    ID: AxisId,
403{
404    let plan = Einsum2Plan::new(ia, ib, ic)?;
405    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;
406
407    // Trace reduction (plain sum, Identity views)
408    let a_buf = if !plan.left_trace.is_empty() {
409        let trace_indices = plan.left_trace_indices(ia);
410        Some(trace::reduce_trace_axes(a, &trace_indices)?)
411    } else {
412        None
413    };
414    let a_view: StridedView<T> = match a_buf.as_ref() {
415        Some(buf) => buf.view(),
416        None => a.clone(),
417    };
418
419    let b_buf = if !plan.right_trace.is_empty() {
420        let trace_indices = plan.right_trace_indices(ib);
421        Some(trace::reduce_trace_axes(b, &trace_indices)?)
422    } else {
423        None
424    };
425    let b_view: StridedView<T> = match b_buf.as_ref() {
426        Some(buf) => buf.view(),
427        None => b.clone(),
428    };
429
430    // Permute to canonical order
431    let a_perm = a_view.permute(&plan.left_perm)?;
432    let b_perm = b_view.permute(&plan.right_perm)?;
433    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;
434
435    let n_batch = plan.batch.len();
436    let n_lo = plan.lo.len();
437    let n_ro = plan.ro.len();
438    let n_sum = plan.sum.len();
439
440    // Element-wise fast path (all batch, no contraction)
441    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
442        if alpha == T::one() {
443            zip_map2_into(&mut c_perm, &a_perm, &b_perm, |a_val, b_val| a_val * b_val)?;
444        } else {
445            let mul_fn = move |a_val: T, b_val: T| -> T { alpha * a_val * b_val };
446            zip_map2_into(&mut c_perm, &a_perm, &b_perm, mul_fn)?;
447        }
448        return Ok(());
449    }
450
451    // Prepare contiguous operands for GEMM
452    let a_op = contiguous::prepare_input_view_for_backend::<T, B>(&a_perm, n_batch, n_lo, n_sum)?;
453    let b_op = contiguous::prepare_input_view_for_backend::<T, B>(&b_perm, n_batch, n_sum, n_ro)?;
454    let mut c_op = contiguous::prepare_output_view_for_backend::<T, B>(
455        &mut c_perm,
456        n_batch,
457        n_lo,
458        n_ro,
459        beta,
460    )?;
461
462    // Dimension sizes (batch-last canonical order)
463    let lo_dims = &a_perm.dims()[..n_lo];
464    let sum_dims = &a_perm.dims()[n_lo..n_lo + n_sum];
465    let batch_dims = &a_perm.dims()[n_lo + n_sum..];
466    let ro_dims = &b_perm.dims()[n_sum..n_sum + n_ro];
467    let m: usize = lo_dims.iter().product::<usize>().max(1);
468    let k: usize = sum_dims.iter().product::<usize>().max(1);
469    let n: usize = ro_dims.iter().product::<usize>().max(1);
470
471    // GEMM via backend
472    B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
473
474    // Finalize (copy-back if needed)
475    c_op.finalize_into(&mut c_perm)?;
476
477    Ok(())
478}
479
/// Internal GEMM dispatch built on the `ContiguousOperand` types.
///
/// Expects inputs whose trace axes are already reduced and whose element op
/// has been stripped into `conj_a` / `conj_b`. Steps:
/// 1. permute all operands to the plan's canonical order
/// 2. take the element-wise fast path when it applies
/// 3. prepare contiguous operands via `prepare_input_view`
/// 4. run `ActiveBackend::bgemm_contiguous_into`
/// 5. finalize (copy-back if needed)
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
fn einsum2_gemm_dispatch<T: Scalar>(
    c: StridedViewMut<T>,
    a: &StridedView<T>,
    b: &StridedView<T>,
    plan: &Einsum2Plan<impl AxisId>,
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: BgemmBackend<T>,
{
    // Step 1: canonical axis order.
    let lhs = a.permute(&plan.left_perm)?;
    let rhs = b.permute(&plan.right_perm)?;
    let mut out = c.permute(&plan.c_to_internal_perm)?;

    let n_batch = plan.batch.len();
    let n_lo = plan.lo.len();
    let n_ro = plan.ro.len();
    let n_sum = plan.sum.len();

    // Step 2: only batch axes and beta == 0 means a plain element-wise
    // product. The closures are specialised so the common unit-alpha,
    // no-conjugation case stays a bare multiplication.
    if n_sum == 0 && n_lo == 0 && n_ro == 0 && beta == T::zero() {
        let conj_if = |flag: bool, v: T| -> T {
            if flag {
                Conj::apply(v)
            } else {
                v
            }
        };
        let unit_alpha = alpha == T::one();
        if unit_alpha && !(conj_a || conj_b) {
            zip_map2_into(&mut out, &lhs, &rhs, |x, y| x * y)?;
        } else if unit_alpha {
            let product = move |x: T, y: T| -> T { conj_if(conj_a, x) * conj_if(conj_b, y) };
            zip_map2_into(&mut out, &lhs, &rhs, product)?;
        } else {
            let scaled_product =
                move |x: T, y: T| -> T { alpha * conj_if(conj_a, x) * conj_if(conj_b, y) };
            zip_map2_into(&mut out, &lhs, &rhs, scaled_product)?;
        }
        return Ok(());
    }

    // Step 3: operands in the form the backend expects.
    let a_op = contiguous::prepare_input_view(&lhs, n_batch, n_lo, n_sum, conj_a)?;
    let b_op = contiguous::prepare_input_view(&rhs, n_batch, n_sum, n_ro, conj_b)?;
    let mut c_op = contiguous::prepare_output_view(&mut out, n_batch, n_lo, n_ro, beta)?;

    // Fused GEMM extents. After permutation the left operand is laid out as
    // [lo.., sum.., batch..] and the right operand starts with [sum.., ro..].
    // NOTE(review): `.max(1)` also turns a zero-sized extent into 1; how the
    // backend treats empty operands is not visible here — confirm.
    let lhs_dims = lhs.dims();
    let rhs_dims = rhs.dims();
    let m: usize = lhs_dims[..n_lo].iter().product::<usize>().max(1);
    let k: usize = lhs_dims[n_lo..n_lo + n_sum].iter().product::<usize>().max(1);
    let n: usize = rhs_dims[n_sum..n_sum + n_ro].iter().product::<usize>().max(1);
    let batch_dims = &lhs_dims[n_lo + n_sum..];

    // Step 4: batched GEMM through the backend trait.
    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta,
    )?;

    // Step 5: copy back into the caller's view if a temporary was used.
    c_op.finalize_into(&mut out)?;

    Ok(())
}
567
/// Binary einsum accepting owned inputs for zero-copy optimization.
///
/// Same semantics as [`einsum2_into`] but takes owned `StridedArray` inputs.
/// When inputs have non-contiguous strides after permutation, ownership
/// transfer avoids allocating separate buffers. For contiguous inputs,
/// the behavior is identical.
///
/// `conj_a` and `conj_b` indicate whether to conjugate the elements of A/B.
/// The flags are honoured on every path, including when trace axes (labels
/// that occur in only one input) have to be summed out first.
///
/// # Errors
///
/// Returns an error when the plan cannot be built, when operand extents
/// disagree, or when trace reduction, permutation, operand preparation, the
/// backend GEMM or the final copy-back fails.
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
pub fn einsum2_into_owned<T: Scalar, ID: AxisId>(
    c: StridedViewMut<T>,
    a: StridedArray<T>,
    b: StridedArray<T>,
    ic: &[ID],
    ia: &[ID],
    ib: &[ID],
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> Result<()>
where
    backend::ActiveBackend: BgemmBackend<T>,
{
    // 1. Build plan
    let plan = Einsum2Plan::new(ia, ib, ic)?;

    // 2. Validate dimensions
    validate_dimensions::<ID>(&plan, a.dims(), b.dims(), c.dims(), ia, ib, ic)?;

    // 3. Trace reduction. The owned arrays are reduced through plain views,
    //    so the reduction is an unconjugated sum. Conjugation commutes with
    //    summation (conj(x + y) == conj(x) + conj(y)), therefore the caller's
    //    flag must stay attached to the reduced operand. Clearing it here
    //    would silently drop the requested conjugation.
    let a_for_gemm = if plan.left_trace.is_empty() {
        a
    } else {
        let trace_indices = plan.left_trace_indices(ia);
        trace::reduce_trace_axes(&a.view(), &trace_indices)?
    };

    let b_for_gemm = if plan.right_trace.is_empty() {
        b
    } else {
        let trace_indices = plan.right_trace_indices(ib);
        trace::reduce_trace_axes(&b.view(), &trace_indices)?
    };

    // 4. Permute to canonical order (metadata-only on owned arrays)
    let a_perm = a_for_gemm.permuted(&plan.left_perm)?;
    let b_perm = b_for_gemm.permuted(&plan.right_perm)?;
    let mut c_perm = c.permute(&plan.c_to_internal_perm)?;

    let n_batch = plan.batch.len();
    let n_lo = plan.lo.len();
    let n_ro = plan.ro.len();
    let n_sum = plan.sum.len();

    // 5. Fast path: element-wise (all batch, no contraction)
    if plan.sum.is_empty() && plan.lo.is_empty() && plan.ro.is_empty() && beta == T::zero() {
        let mul_fn = move |a_val: T, b_val: T| -> T {
            let a_c = if conj_a { Conj::apply(a_val) } else { a_val };
            let b_c = if conj_b { Conj::apply(b_val) } else { b_val };
            alpha * a_c * b_c
        };
        zip_map2_into(&mut c_perm, &a_perm.view(), &b_perm.view(), mul_fn)?;
        return Ok(());
    }

    // 6. Extract the fused extents BEFORE the arrays are consumed by
    //    `prepare_input_owned`. After permutation the left operand is laid
    //    out as [lo.., sum.., batch..] and the right one starts with
    //    [sum.., ro..].
    //    NOTE(review): `.max(1)` also turns a zero-sized extent into 1; how
    //    the backend treats empty operands is not visible here — confirm.
    let a_dims_perm = a_perm.dims().to_vec();
    let b_dims_perm = b_perm.dims().to_vec();

    let lo_dims = &a_dims_perm[..n_lo];
    let sum_dims = &a_dims_perm[n_lo..n_lo + n_sum];
    let batch_dims = &a_dims_perm[n_lo + n_sum..];
    let ro_dims = &b_dims_perm[n_sum..n_sum + n_ro];
    let m: usize = lo_dims.iter().product::<usize>().max(1);
    let k: usize = sum_dims.iter().product::<usize>().max(1);
    let n: usize = ro_dims.iter().product::<usize>().max(1);

    // 7. Prepare contiguous operands (owned path -- avoids extra copies)
    let a_op = contiguous::prepare_input_owned(a_perm, n_batch, n_lo, n_sum, conj_a)?;
    let b_op = contiguous::prepare_input_owned(b_perm, n_batch, n_sum, n_ro, conj_b)?;
    let mut c_op = contiguous::prepare_output_view(&mut c_perm, n_batch, n_lo, n_ro, beta)?;

    // 8. GEMM — dispatched through trait
    backend::ActiveBackend::bgemm_contiguous_into(
        &mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta,
    )?;

    // 9. Finalize (copy-back if needed)
    c_op.finalize_into(&mut c_perm)?;

    Ok(())
}
688
689/// Validate that dimensions match across operands for each axis group.
690fn validate_dimensions<ID: AxisId>(
691    plan: &Einsum2Plan<ID>,
692    a_dims: &[usize],
693    b_dims: &[usize],
694    c_dims: &[usize],
695    ia: &[ID],
696    ib: &[ID],
697    ic: &[ID],
698) -> Result<()> {
699    let find_dim = |labels: &[ID], dims: &[usize], id: &ID| -> usize {
700        labels
701            .iter()
702            .position(|x| x == id)
703            .map(|i| dims[i])
704            .unwrap()
705    };
706
707    // Batch: must match in A, B, and C
708    for id in &plan.batch {
709        let da = find_dim(ia, a_dims, id);
710        let db = find_dim(ib, b_dims, id);
711        let dc = find_dim(ic, c_dims, id);
712        if da != db || da != dc {
713            return Err(EinsumError::DimensionMismatch {
714                axis: format!("{:?}", id),
715                dim_a: da,
716                dim_b: db,
717            });
718        }
719    }
720
721    // Sum: must match in A and B
722    for id in &plan.sum {
723        let da = find_dim(ia, a_dims, id);
724        let db = find_dim(ib, b_dims, id);
725        if da != db {
726            return Err(EinsumError::DimensionMismatch {
727                axis: format!("{:?}", id),
728                dim_a: da,
729                dim_b: db,
730            });
731        }
732    }
733
734    // LO: must match in A and C
735    for id in &plan.lo {
736        let da = find_dim(ia, a_dims, id);
737        let dc = find_dim(ic, c_dims, id);
738        if da != dc {
739            return Err(EinsumError::DimensionMismatch {
740                axis: format!("{:?}", id),
741                dim_a: da,
742                dim_b: dc,
743            });
744        }
745    }
746
747    // RO: must match in B and C
748    for id in &plan.ro {
749        let db = find_dim(ib, b_dims, id);
750        let dc = find_dim(ic, c_dims, id);
751        if db != dc {
752            return Err(EinsumError::DimensionMismatch {
753                axis: format!("{:?}", id),
754                dim_a: db,
755                dim_b: dc,
756            });
757        }
758    }
759
760    Ok(())
761}
762
763#[cfg(test)]
764mod tests {
765    use super::*;
766    use strided_view::StridedArray;
767
768    #[test]
769    fn test_matmul_ij_jk_ik() {
770        // C_ik = A_ij * B_jk
771        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
772            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
773        });
774        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
775            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
776        });
777        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
778
779        einsum2_into(
780            c.view_mut(),
781            &a.view(),
782            &b.view(),
783            &['i', 'k'],
784            &['i', 'j'],
785            &['j', 'k'],
786            1.0,
787            0.0,
788        )
789        .unwrap();
790
791        assert_eq!(c.get(&[0, 0]), 19.0);
792        assert_eq!(c.get(&[0, 1]), 22.0);
793        assert_eq!(c.get(&[1, 0]), 43.0);
794        assert_eq!(c.get(&[1, 1]), 50.0);
795    }
796
797    #[test]
798    fn test_matmul_rect() {
799        // A: 2x3, B: 3x4, C: 2x4
800        let a =
801            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
802        let b =
803            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
804        let mut c = StridedArray::<f64>::row_major(&[2, 4]);
805
806        einsum2_into(
807            c.view_mut(),
808            &a.view(),
809            &b.view(),
810            &['i', 'k'],
811            &['i', 'j'],
812            &['j', 'k'],
813            1.0,
814            0.0,
815        )
816        .unwrap();
817
818        // A = [[1,2,3],[4,5,6]], B = [[1,2,3,4],[5,6,7,8],[9,10,11,12]]
819        assert_eq!(c.get(&[0, 0]), 38.0);
820        assert_eq!(c.get(&[1, 3]), 128.0);
821    }
822
823    #[test]
824    fn test_batched_matmul() {
825        // C_bik = A_bij * B_bjk
826        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
827            (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
828        });
829        let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
830            (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
831        });
832        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
833
834        einsum2_into(
835            c.view_mut(),
836            &a.view(),
837            &b.view(),
838            &['b', 'i', 'k'],
839            &['b', 'i', 'j'],
840            &['b', 'j', 'k'],
841            1.0,
842            0.0,
843        )
844        .unwrap();
845
846        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
847        // C0[0,0] = 1*1+2*3+3*5 = 22
848        assert_eq!(c.get(&[0, 0, 0]), 22.0);
849    }
850
851    #[test]
852    fn test_batched_matmul_col_major_output() {
853        // C_bik = A_bij * B_bjk with col-major output (same layout as opteinsum)
854        let a_data = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
855        let b_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
856        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
857        let b = StridedArray::<f64>::from_parts(b_data, &[2, 2, 2], &[4, 2, 1], 0).unwrap();
858        let mut c = StridedArray::<f64>::col_major(&[2, 2, 2]);
859
860        einsum2_into(
861            c.view_mut(),
862            &a.view(),
863            &b.view(),
864            &['b', 'i', 'k'],
865            &['b', 'i', 'j'],
866            &['b', 'j', 'k'],
867            1.0,
868            0.0,
869        )
870        .unwrap();
871
872        // batch 0: I * [[1,2],[3,4]] = [[1,2],[3,4]]
873        assert_eq!(c.get(&[0, 0, 0]), 1.0);
874        assert_eq!(c.get(&[0, 0, 1]), 2.0);
875        assert_eq!(c.get(&[0, 1, 0]), 3.0);
876        assert_eq!(c.get(&[0, 1, 1]), 4.0);
877        // batch 1: 2I * [[5,6],[7,8]] = [[10,12],[14,16]]
878        assert_eq!(c.get(&[1, 0, 0]), 10.0);
879        assert_eq!(c.get(&[1, 1, 1]), 16.0);
880    }
881
882    #[test]
883    fn test_outer_product() {
884        // C_ij = A_i * B_j
885        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
886        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
887        let mut c = StridedArray::<f64>::row_major(&[3, 4]);
888
889        einsum2_into(
890            c.view_mut(),
891            &a.view(),
892            &b.view(),
893            &['i', 'j'],
894            &['i'],
895            &['j'],
896            1.0,
897            0.0,
898        )
899        .unwrap();
900
901        assert_eq!(c.get(&[0, 0]), 1.0);
902        assert_eq!(c.get(&[2, 3]), 12.0);
903    }
904
905    #[test]
906    fn test_dot_product() {
907        // C = A_i * B_i (scalar output)
908        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
909        let b = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
910        let mut c = StridedArray::<f64>::row_major(&[]);
911
912        einsum2_into(
913            c.view_mut(),
914            &a.view(),
915            &b.view(),
916            &[] as &[char],
917            &['i'],
918            &['i'],
919            1.0,
920            0.0,
921        )
922        .unwrap();
923
924        // 1*1 + 2*2 + 3*3 = 14
925        assert_eq!(c.get(&[]), 14.0);
926    }
927
928    #[test]
929    fn test_alpha_beta() {
930        // C = 2*A*B + 3*C_old
931        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
932            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
933        });
934        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
935            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
936        });
937        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
938            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
939        });
940
941        einsum2_into(
942            c.view_mut(),
943            &a.view(),
944            &b.view(),
945            &['i', 'k'],
946            &['i', 'j'],
947            &['j', 'k'],
948            2.0,
949            3.0,
950        )
951        .unwrap();
952
953        // C = 2*I*B + 3*C_old
954        assert_eq!(c.get(&[0, 0]), 32.0); // 2*1 + 3*10
955        assert_eq!(c.get(&[1, 1]), 128.0); // 2*4 + 3*40
956    }
957
958    #[test]
959    fn test_transposed_output() {
960        // C_ki = A_ij * B_jk (output transposed)
961        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
962            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
963        });
964        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
965            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
966        });
967        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
968
969        einsum2_into(
970            c.view_mut(),
971            &a.view(),
972            &b.view(),
973            &['k', 'i'], // C indexed as (k, i) instead of (i, k)
974            &['i', 'j'],
975            &['j', 'k'],
976            1.0,
977            0.0,
978        )
979        .unwrap();
980
981        // Normal matmul result: C_ik = [[19,22],[43,50]]
982        // But C is indexed as (k, i), so C[k, i] = (A*B)[i, k]
983        assert_eq!(c.get(&[0, 0]), 19.0); // C[k=0, i=0]
984        assert_eq!(c.get(&[0, 1]), 43.0); // C[k=0, i=1]
985        assert_eq!(c.get(&[1, 0]), 22.0); // C[k=1, i=0]
986        assert_eq!(c.get(&[1, 1]), 50.0); // C[k=1, i=1]
987    }
988
989    #[test]
990    fn test_left_trace() {
991        // C_k = sum_j (sum_i A_ij) * B_jk
992        // left_trace=[i], sum=[j], ro=[k]
993        let a =
994            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
995        // A = [[1,2,3],[4,5,6]]
996        // sum over i: [5, 7, 9]
997        let b =
998            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
999        // B = [[1,2],[3,4],[5,6]]
1000        let mut c = StridedArray::<f64>::row_major(&[2]);
1001
1002        einsum2_into(
1003            c.view_mut(),
1004            &a.view(),
1005            &b.view(),
1006            &['k'],
1007            &['i', 'j'],
1008            &['j', 'k'],
1009            1.0,
1010            0.0,
1011        )
1012        .unwrap();
1013
1014        // C_k = sum_j [5,7,9][j] * B[j,k]
1015        // C[0] = 5*1 + 7*3 + 9*5 = 5 + 21 + 45 = 71
1016        // C[1] = 5*2 + 7*4 + 9*6 = 10 + 28 + 54 = 92
1017        assert_eq!(c.get(&[0]), 71.0);
1018        assert_eq!(c.get(&[1]), 92.0);
1019    }
1020
1021    #[test]
1022    fn test_u32_labels() {
1023        // Same as matmul but with u32 labels
1024        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1025            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1026        });
1027        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1028            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1029        });
1030        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1031
1032        einsum2_into(
1033            c.view_mut(),
1034            &a.view(),
1035            &b.view(),
1036            &[0u32, 2],
1037            &[0u32, 1],
1038            &[1u32, 2],
1039            1.0,
1040            0.0,
1041        )
1042        .unwrap();
1043
1044        assert_eq!(c.get(&[0, 0]), 19.0);
1045        assert_eq!(c.get(&[1, 1]), 50.0);
1046    }
1047
1048    #[test]
1049    fn test_complex_matmul() {
1050        use num_complex::Complex64;
1051        let i = Complex64::i();
1052
1053        // A = [[1+i, 2], [3, 4-i]]
1054        let a_vals = [
1055            [1.0 + i, Complex64::new(2.0, 0.0)],
1056            [Complex64::new(3.0, 0.0), 4.0 - i],
1057        ];
1058        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1059
1060        // B = [[1, i], [0, 1]]
1061        let b_vals = [
1062            [Complex64::new(1.0, 0.0), i],
1063            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
1064        ];
1065        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1066
1067        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1068
1069        einsum2_into(
1070            c.view_mut(),
1071            &a.view(),
1072            &b.view(),
1073            &['i', 'k'],
1074            &['i', 'j'],
1075            &['j', 'k'],
1076            Complex64::new(1.0, 0.0),
1077            Complex64::new(0.0, 0.0),
1078        )
1079        .unwrap();
1080
1081        // C = A * B
1082        // C[0,0] = (1+i)*1 + 2*0 = 1+i
1083        // C[0,1] = (1+i)*i + 2*1 = i+i²+2 = i-1+2 = 1+i
1084        // C[1,0] = 3*1 + (4-i)*0 = 3
1085        // C[1,1] = 3*i + (4-i)*1 = 3i+4-i = 4+2i
1086        assert_eq!(c.get(&[0, 0]), 1.0 + i);
1087        assert_eq!(c.get(&[0, 1]), 1.0 + i);
1088        assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1089        assert_eq!(c.get(&[1, 1]), 4.0 + 2.0 * i);
1090    }
1091
1092    #[test]
1093    fn test_complex_matmul_with_conj() {
1094        use num_complex::Complex64;
1095        let i = Complex64::i();
1096
1097        // A = [[1+i, 2i], [3, 4-i]]
1098        let a_vals = [[1.0 + i, 2.0 * i], [Complex64::new(3.0, 0.0), 4.0 - i]];
1099        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1100
1101        // B = identity
1102        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| {
1103            if idx[0] == idx[1] {
1104                Complex64::new(1.0, 0.0)
1105            } else {
1106                Complex64::new(0.0, 0.0)
1107            }
1108        });
1109
1110        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1111
1112        // C = conj(A) * B = conj(A)
1113        let a_conj = a.view().conj();
1114        einsum2_into(
1115            c.view_mut(),
1116            &a_conj,
1117            &b.view(),
1118            &['i', 'k'],
1119            &['i', 'j'],
1120            &['j', 'k'],
1121            Complex64::new(1.0, 0.0),
1122            Complex64::new(0.0, 0.0),
1123        )
1124        .unwrap();
1125
1126        // conj(A) = [[1-i, -2i], [3, 4+i]]
1127        assert_eq!(c.get(&[0, 0]), 1.0 - i);
1128        assert_eq!(c.get(&[0, 1]), -2.0 * i);
1129        assert_eq!(c.get(&[1, 0]), Complex64::new(3.0, 0.0));
1130        assert_eq!(c.get(&[1, 1]), 4.0 + i);
1131    }
1132
1133    #[test]
1134    fn test_complex_matmul_with_conj_both() {
1135        use num_complex::Complex64;
1136        let i = Complex64::i();
1137
1138        // A = [[1+i, 0], [0, 2-i]]
1139        let a_vals = [
1140            [1.0 + i, Complex64::new(0.0, 0.0)],
1141            [Complex64::new(0.0, 0.0), 2.0 - i],
1142        ];
1143        let a = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| a_vals[idx[0]][idx[1]]);
1144
1145        // B = [[1, i], [0, 1+i]]
1146        let b_vals = [
1147            [Complex64::new(1.0, 0.0), i],
1148            [Complex64::new(0.0, 0.0), 1.0 + i],
1149        ];
1150        let b = StridedArray::<Complex64>::from_fn_row_major(&[2, 2], |idx| b_vals[idx[0]][idx[1]]);
1151
1152        let mut c = StridedArray::<Complex64>::row_major(&[2, 2]);
1153
1154        // C = conj(A) * conj(B)
1155        let a_conj = a.view().conj();
1156        let b_conj = b.view().conj();
1157        einsum2_into(
1158            c.view_mut(),
1159            &a_conj,
1160            &b_conj,
1161            &['i', 'k'],
1162            &['i', 'j'],
1163            &['j', 'k'],
1164            Complex64::new(1.0, 0.0),
1165            Complex64::new(0.0, 0.0),
1166        )
1167        .unwrap();
1168
1169        // conj(A) = [[1-i, 0], [0, 2+i]]
1170        // conj(B) = [[1, -i], [0, 1-i]]
1171        // C = conj(A) * conj(B)
1172        // C[0,0] = (1-i)*1 + 0*0 = 1-i
1173        // C[0,1] = (1-i)*(-i) + 0*(1-i) = -i+i² = -i-1 = -(1+i)
1174        // C[1,0] = 0*1 + (2+i)*0 = 0
1175        // C[1,1] = 0*(-i) + (2+i)*(1-i) = 2-2i+i-i² = 2-i+1 = 3-i
1176        assert_eq!(c.get(&[0, 0]), 1.0 - i);
1177        assert_eq!(c.get(&[0, 1]), -(1.0 + i));
1178        assert_eq!(c.get(&[1, 0]), Complex64::new(0.0, 0.0));
1179        assert_eq!(c.get(&[1, 1]), 3.0 - i);
1180    }
1181
1182    #[test]
1183    fn test_elementwise_hadamard() {
1184        // C_ijk = A_ijk * B_ijk — all batch, no contraction
1185        let a = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1186            (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64
1187        });
1188        let b = StridedArray::<f64>::from_fn_row_major(&[3, 4, 5], |idx| {
1189            (idx[0] * 20 + idx[1] * 5 + idx[2] + 1) as f64 * 0.1
1190        });
1191        let mut c = StridedArray::<f64>::row_major(&[3, 4, 5]);
1192
1193        einsum2_into(
1194            c.view_mut(),
1195            &a.view(),
1196            &b.view(),
1197            &['i', 'j', 'k'],
1198            &['i', 'j', 'k'],
1199            &['i', 'j', 'k'],
1200            1.0,
1201            0.0,
1202        )
1203        .unwrap();
1204
1205        // Spot check: C[0,0,0] = 1 * 0.1 = 0.1
1206        assert!((c.get(&[0, 0, 0]) - 0.1).abs() < 1e-12);
1207        // C[2,3,4] = 60 * 6.0 = 360
1208        assert!((c.get(&[2, 3, 4]) - 360.0).abs() < 1e-10);
1209    }
1210
1211    #[test]
1212    fn test_elementwise_hadamard_with_alpha() {
1213        let a =
1214            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1215        let b =
1216            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1217        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1218
1219        einsum2_into(
1220            c.view_mut(),
1221            &a.view(),
1222            &b.view(),
1223            &['i', 'j'],
1224            &['i', 'j'],
1225            &['i', 'j'],
1226            2.0,
1227            0.0,
1228        )
1229        .unwrap();
1230
1231        // C[0,0] = 2.0 * 1 * 1 = 2.0
1232        assert_eq!(c.get(&[0, 0]), 2.0);
1233        // C[1,2] = 2.0 * 6 * 6 = 72.0
1234        assert_eq!(c.get(&[1, 2]), 72.0);
1235    }
1236
1237    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1238    #[test]
1239    fn test_einsum2_owned_matmul() {
1240        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1241            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1242        });
1243        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1244            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1245        });
1246        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1247
1248        einsum2_into_owned(
1249            c.view_mut(),
1250            a,
1251            b,
1252            &['i', 'k'],
1253            &['i', 'j'],
1254            &['j', 'k'],
1255            1.0,
1256            0.0,
1257            false,
1258            false,
1259        )
1260        .unwrap();
1261
1262        assert_eq!(c.get(&[0, 0]), 19.0);
1263        assert_eq!(c.get(&[0, 1]), 22.0);
1264        assert_eq!(c.get(&[1, 0]), 43.0);
1265        assert_eq!(c.get(&[1, 1]), 50.0);
1266    }
1267
1268    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1269    #[test]
1270    fn test_einsum2_owned_batched() {
1271        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2, 3], |idx| {
1272            (idx[0] * 6 + idx[1] * 3 + idx[2] + 1) as f64
1273        });
1274        let b = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
1275            (idx[0] * 6 + idx[1] * 2 + idx[2] + 1) as f64
1276        });
1277        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
1278
1279        einsum2_into_owned(
1280            c.view_mut(),
1281            a,
1282            b,
1283            &['b', 'i', 'k'],
1284            &['b', 'i', 'j'],
1285            &['b', 'j', 'k'],
1286            1.0,
1287            0.0,
1288            false,
1289            false,
1290        )
1291        .unwrap();
1292
1293        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
1294        // C0[0,0] = 1*1+2*3+3*5 = 22
1295        assert_eq!(c.get(&[0, 0, 0]), 22.0);
1296    }
1297
1298    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1299    #[test]
1300    fn test_einsum2_owned_alpha_beta() {
1301        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1302            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
1303        });
1304        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1305            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1306        });
1307        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1308            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
1309        });
1310
1311        einsum2_into_owned(
1312            c.view_mut(),
1313            a,
1314            b,
1315            &['i', 'k'],
1316            &['i', 'j'],
1317            &['j', 'k'],
1318            2.0,
1319            3.0,
1320            false,
1321            false,
1322        )
1323        .unwrap();
1324
1325        // C = 2*I*B + 3*C_old
1326        assert_eq!(c.get(&[0, 0]), 32.0); // 2*1 + 3*10
1327        assert_eq!(c.get(&[1, 1]), 128.0); // 2*4 + 3*40
1328    }
1329
1330    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1331    #[test]
1332    fn test_einsum2_owned_elementwise() {
1333        // All batch, no contraction -- element-wise fast path
1334        let a =
1335            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1336        let b =
1337            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1338        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1339
1340        einsum2_into_owned(
1341            c.view_mut(),
1342            a,
1343            b,
1344            &['i', 'j'],
1345            &['i', 'j'],
1346            &['i', 'j'],
1347            2.0,
1348            0.0,
1349            false,
1350            false,
1351        )
1352        .unwrap();
1353
1354        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1355        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1356    }
1357
1358    #[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
1359    #[test]
1360    fn test_einsum2_owned_left_trace() {
1361        // C_k = sum_j (sum_i A_ij) * B_jk
1362        // left_trace=[i], sum=[j], ro=[k]
1363        let a =
1364            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1365        // A = [[1,2,3],[4,5,6]]
1366        // sum over i: [5, 7, 9]
1367        let b =
1368            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1369        // B = [[1,2],[3,4],[5,6]]
1370        let mut c = StridedArray::<f64>::row_major(&[2]);
1371
1372        einsum2_into_owned(
1373            c.view_mut(),
1374            a,
1375            b,
1376            &['k'],
1377            &['i', 'j'],
1378            &['j', 'k'],
1379            1.0,
1380            0.0,
1381            false,
1382            false,
1383        )
1384        .unwrap();
1385
1386        // C[0] = 5*1 + 7*3 + 9*5 = 5 + 21 + 45 = 71
1387        // C[1] = 5*2 + 7*4 + 9*6 = 10 + 28 + 54 = 92
1388        assert_eq!(c.get(&[0]), 71.0);
1389        assert_eq!(c.get(&[1]), 92.0);
1390    }
1391
1392    #[test]
1393    fn test_einsum2_naive_matmul() {
1394        // einsum2_naive_into with identity maps = regular matmul
1395        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1396            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1397        });
1398        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1399            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1400        });
1401        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1402
1403        einsum2_naive_into(
1404            c.view_mut(),
1405            &a.view(),
1406            &b.view(),
1407            &['i', 'k'],
1408            &['i', 'j'],
1409            &['j', 'k'],
1410            1.0,
1411            0.0,
1412            |x| x,
1413            |x| x,
1414        )
1415        .unwrap();
1416
1417        assert_eq!(c.get(&[0, 0]), 19.0);
1418        assert_eq!(c.get(&[0, 1]), 22.0);
1419        assert_eq!(c.get(&[1, 0]), 43.0);
1420        assert_eq!(c.get(&[1, 1]), 50.0);
1421    }
1422
1423    #[test]
1424    fn test_einsum2_naive_elementwise() {
1425        // Element-wise (all batch) with identity maps
1426        let a =
1427            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1428        let b =
1429            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1430        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1431
1432        einsum2_naive_into(
1433            c.view_mut(),
1434            &a.view(),
1435            &b.view(),
1436            &['i', 'j'],
1437            &['i', 'j'],
1438            &['i', 'j'],
1439            2.0,
1440            0.0,
1441            |x| x,
1442            |x| x,
1443        )
1444        .unwrap();
1445
1446        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1447        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1448    }
1449
1450    #[test]
1451    fn test_einsum2_naive_custom_type() {
1452        // Custom scalar type that does NOT implement ElementOpApply.
1453        // This demonstrates that einsum2_naive_into works with custom types.
1454        use num_traits::{One, Zero};
1455
1456        #[derive(Debug, Clone, Copy, PartialEq)]
1457        struct MyVal(f64);
1458
1459        impl Default for MyVal {
1460            fn default() -> Self {
1461                MyVal(0.0)
1462            }
1463        }
1464
1465        impl std::ops::Add for MyVal {
1466            type Output = Self;
1467            fn add(self, rhs: Self) -> Self {
1468                MyVal(self.0 + rhs.0)
1469            }
1470        }
1471
1472        impl std::ops::Mul for MyVal {
1473            type Output = Self;
1474            fn mul(self, rhs: Self) -> Self {
1475                MyVal(self.0 * rhs.0)
1476            }
1477        }
1478
1479        impl Zero for MyVal {
1480            fn zero() -> Self {
1481                MyVal(0.0)
1482            }
1483            fn is_zero(&self) -> bool {
1484                self.0 == 0.0
1485            }
1486        }
1487
1488        impl One for MyVal {
1489            fn one() -> Self {
1490                MyVal(1.0)
1491            }
1492        }
1493
1494        // C_ik = A_ij * B_jk (matrix multiply with custom type)
1495        let a = StridedArray::from_parts(
1496            vec![MyVal(1.0), MyVal(2.0), MyVal(3.0), MyVal(4.0)],
1497            &[2, 2],
1498            &[2, 1],
1499            0,
1500        )
1501        .unwrap();
1502        let b = StridedArray::from_parts(
1503            vec![MyVal(5.0), MyVal(6.0), MyVal(7.0), MyVal(8.0)],
1504            &[2, 2],
1505            &[2, 1],
1506            0,
1507        )
1508        .unwrap();
1509        let mut c = StridedArray::<MyVal>::col_major(&[2, 2]);
1510
1511        einsum2_naive_into(
1512            c.view_mut(),
1513            &a.view(),
1514            &b.view(),
1515            &['i', 'k'],
1516            &['i', 'j'],
1517            &['j', 'k'],
1518            MyVal(1.0),
1519            MyVal(0.0),
1520            |x| x,
1521            |x| x,
1522        )
1523        .unwrap();
1524
1525        // C = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
1526        //   = [[19, 22], [43, 50]]
1527        assert_eq!(c.get(&[0, 0]), MyVal(19.0));
1528        assert_eq!(c.get(&[0, 1]), MyVal(22.0));
1529        assert_eq!(c.get(&[1, 0]), MyVal(43.0));
1530        assert_eq!(c.get(&[1, 1]), MyVal(50.0));
1531    }
1532
1533    #[test]
1534    fn test_einsum2_naive_left_trace() {
1535        // C_k = sum_j (sum_i A_ij) * B_jk with map_a = identity
1536        let a =
1537            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1538        let b =
1539            StridedArray::<f64>::from_fn_row_major(&[3, 2], |idx| (idx[0] * 2 + idx[1] + 1) as f64);
1540        let mut c = StridedArray::<f64>::row_major(&[2]);
1541
1542        einsum2_naive_into(
1543            c.view_mut(),
1544            &a.view(),
1545            &b.view(),
1546            &['k'],
1547            &['i', 'j'],
1548            &['j', 'k'],
1549            1.0,
1550            0.0,
1551            |x| x,
1552            |x| x,
1553        )
1554        .unwrap();
1555
1556        assert_eq!(c.get(&[0]), 71.0);
1557        assert_eq!(c.get(&[1]), 92.0);
1558    }
1559
1560    /// Test backend that delegates to naive GEMM loops.
1561    struct TestNaiveBackend;
1562
1563    impl BackendConfig for TestNaiveBackend {
1564        const MATERIALIZES_CONJ: bool = false;
1565        const REQUIRES_UNIT_STRIDE: bool = false;
1566    }
1567
1568    impl BgemmBackend<f64> for TestNaiveBackend {
1569        fn bgemm_contiguous_into(
1570            c: &mut contiguous::ContiguousOperandMut<f64>,
1571            a: &contiguous::ContiguousOperand<f64>,
1572            b: &contiguous::ContiguousOperand<f64>,
1573            batch_dims: &[usize],
1574            m: usize,
1575            n: usize,
1576            k: usize,
1577            alpha: f64,
1578            beta: f64,
1579        ) -> strided_view::Result<()> {
1580            // Delegate to the contiguous naive GEMM loop
1581            let a_ptr = a.ptr();
1582            let b_ptr = b.ptr();
1583            let c_ptr = c.ptr();
1584            let a_rs = a.row_stride();
1585            let a_cs = a.col_stride();
1586            let b_rs = b.row_stride();
1587            let b_cs = b.col_stride();
1588            let c_rs = c.row_stride();
1589            let c_cs = c.col_stride();
1590
1591            let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1592            while batch_idx.next().is_some() {
1593                let a_base = batch_idx.offset(a.batch_strides());
1594                let b_base = batch_idx.offset(b.batch_strides());
1595                let c_base = batch_idx.offset(c.batch_strides());
1596
1597                for i in 0..m {
1598                    for j in 0..n {
1599                        let mut acc = 0.0f64;
1600                        for l in 0..k {
1601                            let a_val = unsafe {
1602                                *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1603                            };
1604                            let b_val = unsafe {
1605                                *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1606                            };
1607                            acc += a_val * b_val;
1608                        }
1609                        unsafe {
1610                            let c_elem =
1611                                c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1612                            if beta == 0.0 {
1613                                *c_elem = alpha * acc;
1614                            } else {
1615                                *c_elem = alpha * acc + beta * (*c_elem);
1616                            }
1617                        }
1618                    }
1619                }
1620            }
1621            Ok(())
1622        }
1623    }
1624
1625    #[test]
1626    fn test_einsum2_with_backend_matmul() {
1627        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1628            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1629        });
1630        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1631            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1632        });
1633        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1634
1635        einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1636            c.view_mut(),
1637            &a.view(),
1638            &b.view(),
1639            &['i', 'k'],
1640            &['i', 'j'],
1641            &['j', 'k'],
1642            1.0,
1643            0.0,
1644        )
1645        .unwrap();
1646
1647        assert_eq!(c.get(&[0, 0]), 19.0);
1648        assert_eq!(c.get(&[0, 1]), 22.0);
1649        assert_eq!(c.get(&[1, 0]), 43.0);
1650        assert_eq!(c.get(&[1, 1]), 50.0);
1651    }
1652
1653    #[test]
1654    fn test_einsum2_with_backend_elementwise() {
1655        let a =
1656            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1657        let b =
1658            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
1659        let mut c = StridedArray::<f64>::row_major(&[2, 3]);
1660
1661        einsum2_with_backend_into::<_, TestNaiveBackend, _>(
1662            c.view_mut(),
1663            &a.view(),
1664            &b.view(),
1665            &['i', 'j'],
1666            &['i', 'j'],
1667            &['i', 'j'],
1668            2.0,
1669            0.0,
1670        )
1671        .unwrap();
1672
1673        assert_eq!(c.get(&[0, 0]), 2.0); // 2 * 1 * 1
1674        assert_eq!(c.get(&[1, 2]), 72.0); // 2 * 6 * 6
1675    }
1676
1677    #[test]
1678    fn test_einsum2_with_backend_custom_type() {
1679        use num_traits::{One, Zero};
1680
1681        #[derive(Debug, Clone, Copy, PartialEq)]
1682        struct Tropical(f64);
1683
1684        impl Default for Tropical {
1685            fn default() -> Self {
1686                Tropical(0.0)
1687            }
1688        }
1689
1690        impl std::ops::Add for Tropical {
1691            type Output = Self;
1692            fn add(self, rhs: Self) -> Self {
1693                Tropical(self.0 + rhs.0)
1694            }
1695        }
1696
1697        impl std::ops::Mul for Tropical {
1698            type Output = Self;
1699            fn mul(self, rhs: Self) -> Self {
1700                Tropical(self.0 * rhs.0)
1701            }
1702        }
1703
1704        impl Zero for Tropical {
1705            fn zero() -> Self {
1706                Tropical(0.0)
1707            }
1708            fn is_zero(&self) -> bool {
1709                self.0 == 0.0
1710            }
1711        }
1712
1713        impl One for Tropical {
1714            fn one() -> Self {
1715                Tropical(1.0)
1716            }
1717        }
1718
1719        struct TropicalBackend;
1720
1721        impl BackendConfig for TropicalBackend {
1722            const MATERIALIZES_CONJ: bool = false;
1723            const REQUIRES_UNIT_STRIDE: bool = false;
1724        }
1725
1726        impl BgemmBackend<Tropical> for TropicalBackend {
1727            fn bgemm_contiguous_into(
1728                c: &mut contiguous::ContiguousOperandMut<Tropical>,
1729                a: &contiguous::ContiguousOperand<Tropical>,
1730                b: &contiguous::ContiguousOperand<Tropical>,
1731                batch_dims: &[usize],
1732                m: usize,
1733                n: usize,
1734                k: usize,
1735                alpha: Tropical,
1736                beta: Tropical,
1737            ) -> strided_view::Result<()> {
1738                // Simple naive GEMM for Tropical type
1739                let a_ptr = a.ptr();
1740                let b_ptr = b.ptr();
1741                let c_ptr = c.ptr();
1742                let a_rs = a.row_stride();
1743                let a_cs = a.col_stride();
1744                let b_rs = b.row_stride();
1745                let b_cs = b.col_stride();
1746                let c_rs = c.row_stride();
1747                let c_cs = c.col_stride();
1748
1749                let mut batch_idx = crate::util::MultiIndex::new(batch_dims);
1750                while batch_idx.next().is_some() {
1751                    let a_base = batch_idx.offset(a.batch_strides());
1752                    let b_base = batch_idx.offset(b.batch_strides());
1753                    let c_base = batch_idx.offset(c.batch_strides());
1754
1755                    for i in 0..m {
1756                        for j in 0..n {
1757                            let mut acc = Tropical::zero();
1758                            for l in 0..k {
1759                                let a_val = unsafe {
1760                                    *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
1761                                };
1762                                let b_val = unsafe {
1763                                    *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
1764                                };
1765                                acc = acc + a_val * b_val;
1766                            }
1767                            unsafe {
1768                                let c_elem =
1769                                    c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
1770                                if beta == Tropical::zero() {
1771                                    *c_elem = alpha * acc;
1772                                } else {
1773                                    *c_elem = alpha * acc + beta * (*c_elem);
1774                                }
1775                            }
1776                        }
1777                    }
1778                }
1779                Ok(())
1780            }
1781        }
1782
1783        let a = StridedArray::from_parts(
1784            vec![Tropical(1.0), Tropical(2.0), Tropical(3.0), Tropical(4.0)],
1785            &[2, 2],
1786            &[2, 1],
1787            0,
1788        )
1789        .unwrap();
1790        let b = StridedArray::from_parts(
1791            vec![Tropical(5.0), Tropical(6.0), Tropical(7.0), Tropical(8.0)],
1792            &[2, 2],
1793            &[2, 1],
1794            0,
1795        )
1796        .unwrap();
1797        let mut c = StridedArray::<Tropical>::col_major(&[2, 2]);
1798
1799        einsum2_with_backend_into::<_, TropicalBackend, _>(
1800            c.view_mut(),
1801            &a.view(),
1802            &b.view(),
1803            &['i', 'k'],
1804            &['i', 'j'],
1805            &['j', 'k'],
1806            Tropical(1.0),
1807            Tropical(0.0),
1808        )
1809        .unwrap();
1810
1811        // C = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
1812        //   = [[19, 22], [43, 50]]
1813        assert_eq!(c.get(&[0, 0]), Tropical(19.0));
1814        assert_eq!(c.get(&[0, 1]), Tropical(22.0));
1815        assert_eq!(c.get(&[1, 0]), Tropical(43.0));
1816        assert_eq!(c.get(&[1, 1]), Tropical(50.0));
1817    }
1818}