Skip to main content

strided_einsum2/
bgemm_faer.rs

//! faer-backed batched GEMM kernel on strided views.
//!
//! Uses `faer::linalg::matmul::matmul_with_conj` for SIMD-optimized matrix multiplication,
//! forwarding operand conjugation to faer instead of materializing it.
//! When dimension groups cannot be fused into 2D matrices (non-contiguous strides),
//! copies operands to contiguous column-major buffers before calling faer.
6
7use crate::contiguous::{alloc_col_major_uninit, ContiguousOperand, ContiguousOperandMut};
8use crate::util::{try_fuse_group, MultiIndex};
9use faer::linalg::matmul::matmul_with_conj;
10use faer::mat::{MatMut, MatRef};
11use faer::{Accum, Conj, Par};
12use faer_traits::ComplexField;
13use strided_view::{RawStridedMut, RawStridedRef, StridedArray, StridedView, StridedViewMut};
14
15/// Batched strided GEMM using faer: C = alpha * A * B + beta * C
16///
17/// Same interface as `bgemm_naive::bgemm_strided_into`. Uses faer's optimized
18/// matmul for all cases. When dimension groups have non-contiguous strides,
19/// copies operands to contiguous column-major buffers first using `strided_kernel::copy_into`.
20pub fn bgemm_strided_into<T>(
21    c: &mut StridedViewMut<T>,
22    a: &StridedView<T>,
23    b: &StridedView<T>,
24    _n_batch: usize,
25    n_lo: usize,
26    n_ro: usize,
27    n_sum: usize,
28    alpha: T,
29    beta: T,
30    conj_a: bool,
31    conj_b: bool,
32) -> strided_view::Result<()>
33where
34    T: ComplexField
35        + Copy
36        + strided_view::ElementOpApply
37        + Send
38        + Sync
39        + std::ops::Mul<Output = T>
40        + std::ops::Add<Output = T>
41        + num_traits::Zero
42        + num_traits::One
43        + PartialEq,
44{
45    let a = unsafe { RawStridedRef::new_unchecked(a.data(), a.dims(), a.strides(), a.offset()) };
46    let b = unsafe { RawStridedRef::new_unchecked(b.data(), b.dims(), b.strides(), b.offset()) };
47    let c_dims = c.dims().to_vec();
48    let c_strides = c.strides().to_vec();
49    let c_offset = c.offset();
50    let c = unsafe { RawStridedMut::new_unchecked(c.data_mut(), &c_dims, &c_strides, c_offset) };
51    bgemm_raw_strided_into(
52        c, a, b, _n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
53    )
54}
55
56/// Batched strided GEMM on raw borrowed layout metadata.
57///
58/// This is equivalent to [`bgemm_strided_into`], but avoids constructing
59/// [`StridedView`] wrappers when the caller already has borrowed
60/// dims/strides/offset metadata.
61#[allow(clippy::too_many_arguments)]
62pub(crate) fn bgemm_raw_strided_into<T>(
63    c: RawStridedMut<'_, T>,
64    a: RawStridedRef<'_, T>,
65    b: RawStridedRef<'_, T>,
66    n_batch: usize,
67    n_lo: usize,
68    n_ro: usize,
69    n_sum: usize,
70    alpha: T,
71    beta: T,
72    conj_a: bool,
73    conj_b: bool,
74) -> strided_view::Result<()>
75where
76    T: ComplexField
77        + Copy
78        + strided_view::ElementOpApply
79        + Send
80        + Sync
81        + std::ops::Mul<Output = T>
82        + std::ops::Add<Output = T>
83        + num_traits::Zero
84        + num_traits::One
85        + PartialEq,
86{
87    validate_bgemm_shapes(&c, &a, &b, n_batch, n_lo, n_ro, n_sum)?;
88    unsafe {
89        bgemm_raw_strided_into_unchecked(
90            c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
91        )
92    }
93}
94
95/// Batched strided GEMM on raw borrowed layout metadata without validation.
96///
97/// # Safety
98/// The caller must ensure:
99/// - all raw strided operands are in bounds,
100/// - `n_lo`, `n_ro`, `n_sum`, and `n_batch` partition operand ranks as
101///   `[lo, sum, batch]`, `[sum, ro, batch]`, and `[lo, ro, batch]`,
102/// - matching dimension groups have identical extents,
103/// - `c` does not alias `a` or `b` in a way that violates Rust's mutable access.
104#[allow(clippy::too_many_arguments)]
105pub(crate) unsafe fn bgemm_raw_strided_into_unchecked<T>(
106    mut c: RawStridedMut<'_, T>,
107    a: RawStridedRef<'_, T>,
108    b: RawStridedRef<'_, T>,
109    _n_batch: usize,
110    n_lo: usize,
111    n_ro: usize,
112    n_sum: usize,
113    alpha: T,
114    beta: T,
115    conj_a: bool,
116    conj_b: bool,
117) -> strided_view::Result<()>
118where
119    T: ComplexField
120        + Copy
121        + strided_view::ElementOpApply
122        + Send
123        + Sync
124        + std::ops::Mul<Output = T>
125        + std::ops::Add<Output = T>
126        + num_traits::Zero
127        + num_traits::One
128        + PartialEq,
129{
130    let a_dims = a.dims();
131    let b_dims = b.dims();
132    let a_strides = a.strides();
133    let b_strides = b.strides();
134    let c_strides = c.strides();
135
136    // Extract dimension groups (batch-last canonical order)
137    // A: [lo, sum, batch], B: [sum, ro, batch], C: [lo, ro, batch]
138    let lo_dims = &a_dims[..n_lo];
139    let sum_dims = &a_dims[n_lo..n_lo + n_sum];
140    let batch_dims = &a_dims[n_lo + n_sum..];
141    let ro_dims = &b_dims[n_sum..n_sum + n_ro];
142
143    if c.dims().iter().any(|&dim| dim == 0) {
144        return Ok(());
145    }
146
147    // Fused sizes for the matrix multiply
148    let m: usize = lo_dims.iter().product::<usize>().max(1);
149    let k: usize = sum_dims.iter().product::<usize>().max(1);
150    let n: usize = ro_dims.iter().product::<usize>().max(1);
151
152    // Extract stride groups (batch-last)
153    let a_lo_strides = &a_strides[..n_lo];
154    let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
155    let b_sum_strides = &b_strides[..n_sum];
156    let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
157    let c_lo_strides = &c_strides[..n_lo];
158    let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
159
160    // Try to fuse each dimension group
161    let fused_a_lo = try_fuse_group(lo_dims, a_lo_strides);
162    let fused_a_sum = try_fuse_group(sum_dims, a_sum_strides);
163    let fused_b_sum = try_fuse_group(sum_dims, b_sum_strides);
164    let fused_b_ro = try_fuse_group(ro_dims, b_ro_strides);
165    let fused_c_lo = try_fuse_group(lo_dims, c_lo_strides);
166    let fused_c_ro = try_fuse_group(ro_dims, c_ro_strides);
167
168    let a_needs_copy = fused_a_lo.is_none() || fused_a_sum.is_none();
169    let b_needs_copy = fused_b_sum.is_none() || fused_b_ro.is_none();
170    let c_needs_copy = fused_c_lo.is_none() || fused_c_ro.is_none();
171
172    let n_a_inner = n_lo + n_sum;
173    let n_b_inner = n_sum + n_ro;
174    let n_c_inner = n_lo + n_ro;
175
176    // Copy A to contiguous column-major if inner dims aren't fusable
177    let a_contig_buf: Option<StridedArray<T>>;
178    let (a_ptr, a_row_stride, a_col_stride);
179    if a_needs_copy {
180        let mut buf = alloc_col_major_uninit(a.dims());
181        strided_kernel::copy_into(&mut buf.view_mut(), &a.as_view())?;
182        a_ptr = buf.view().ptr();
183        // Col-major inner A [lo..., sum...]: lo stride = 1, sum stride = m
184        a_row_stride = if m == 0 { 0 } else { 1isize };
185        a_col_stride = m as isize;
186        a_contig_buf = Some(buf);
187    } else {
188        let (_, rs) = fused_a_lo.unwrap();
189        let (_, cs) = fused_a_sum.unwrap();
190        a_ptr = a.ptr();
191        a_row_stride = rs;
192        a_col_stride = cs;
193        a_contig_buf = None;
194    }
195    let a_batch_strides: &[isize] = match a_contig_buf.as_ref() {
196        Some(buf) => &buf.strides()[n_a_inner..],
197        None => &a_strides[n_a_inner..],
198    };
199
200    // Copy B to contiguous column-major if inner dims aren't fusable
201    let b_contig_buf: Option<StridedArray<T>>;
202    let (b_ptr, b_row_stride, b_col_stride);
203    if b_needs_copy {
204        let mut buf = alloc_col_major_uninit(b.dims());
205        strided_kernel::copy_into(&mut buf.view_mut(), &b.as_view())?;
206        b_ptr = buf.view().ptr();
207        // Col-major inner B [sum..., ro...]: sum stride = 1, ro stride = k
208        b_row_stride = if k == 0 { 0 } else { 1isize };
209        b_col_stride = k as isize;
210        b_contig_buf = Some(buf);
211    } else {
212        let (_, rs) = fused_b_sum.unwrap();
213        let (_, cs) = fused_b_ro.unwrap();
214        b_ptr = b.ptr();
215        b_row_stride = rs;
216        b_col_stride = cs;
217        b_contig_buf = None;
218    }
219    let b_batch_strides: &[isize] = match b_contig_buf.as_ref() {
220        Some(buf) => &buf.strides()[n_b_inner..],
221        None => &b_strides[n_b_inner..],
222    };
223
224    // Copy C to contiguous column-major if inner dims aren't fusable
225    let c_contig_buf: Option<StridedArray<T>>;
226    let (c_ptr, c_row_stride, c_col_stride);
227    if c_needs_copy {
228        let mut buf = alloc_col_major_uninit(c.dims());
229        if beta != T::zero() {
230            let c_view: StridedView<'_, T> = c.as_view();
231            strided_kernel::copy_into(&mut buf.view_mut(), &c_view)?;
232        }
233        c_ptr = buf.view_mut().as_mut_ptr();
234        // Col-major inner C [lo..., ro...]: lo stride = 1, ro stride = m
235        c_row_stride = if m == 0 { 0 } else { 1isize };
236        c_col_stride = m as isize;
237        c_contig_buf = Some(buf);
238    } else {
239        let (_, rs) = fused_c_lo.unwrap();
240        let (_, cs) = fused_c_ro.unwrap();
241        c_ptr = c.as_mut_ptr();
242        c_row_stride = rs;
243        c_col_stride = cs;
244        c_contig_buf = None;
245    }
246    let c_batch_strides: &[isize] = match c_contig_buf.as_ref() {
247        Some(buf) => &buf.strides()[n_c_inner..],
248        None => &c_strides[n_c_inner..],
249    };
250
251    let is_beta_zero = beta == T::zero();
252    let is_beta_one = beta == T::one();
253
254    // Determine accumulation mode
255    let accum = if is_beta_zero {
256        Accum::Replace
257    } else {
258        Accum::Add
259    };
260
261    let cj_a = if conj_a { Conj::Yes } else { Conj::No };
262    let cj_b = if conj_b { Conj::Yes } else { Conj::No };
263
264    // Inline closure for per-batch GEMM (shared between fast and slow paths)
265    let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
266        // Pre-scale C by beta if beta is not 0 or 1
267        if !is_beta_zero && !is_beta_one {
268            let c_base = unsafe { c_ptr.offset(c_batch_off) };
269            for i in 0..m {
270                for j in 0..n {
271                    let offset = i as isize * c_row_stride + j as isize * c_col_stride;
272                    unsafe {
273                        let elem = c_base.offset(offset);
274                        *elem = beta * *elem;
275                    }
276                }
277            }
278        }
279
280        unsafe {
281            let a_mat: MatRef<'_, T> =
282                MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
283            let b_mat: MatRef<'_, T> =
284                MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
285            let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
286                c_ptr.offset(c_batch_off),
287                m,
288                n,
289                c_row_stride,
290                c_col_stride,
291            );
292
293            matmul_with_conj(c_mat, accum, a_mat, cj_a, b_mat, cj_b, alpha, Par::rayon(0));
294        }
295    };
296
297    // Fast path: when batch dims are contiguous for all operands, use pointer
298    // increments instead of MultiIndex carry-based iteration.
299    let fused_a = try_fuse_group(batch_dims, a_batch_strides);
300    let fused_b = try_fuse_group(batch_dims, b_batch_strides);
301    let fused_c = try_fuse_group(batch_dims, c_batch_strides);
302
303    if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
304        (fused_a, fused_b, fused_c)
305    {
306        let mut a_off = 0isize;
307        let mut b_off = 0isize;
308        let mut c_off = 0isize;
309        for _ in 0..total {
310            do_batch(a_off, b_off, c_off);
311            a_off += a_step;
312            b_off += b_step;
313            c_off += c_step;
314        }
315    } else {
316        let mut batch_iter = MultiIndex::new(batch_dims);
317        while batch_iter.next().is_some() {
318            let a_batch_off = batch_iter.offset(a_batch_strides);
319            let b_batch_off = batch_iter.offset(b_batch_strides);
320            let c_batch_off = batch_iter.offset(c_batch_strides);
321            do_batch(a_batch_off, b_batch_off, c_batch_off);
322        }
323    }
324
325    // If C was copied to a temp buffer, copy the result back
326    if let Some(ref c_buf) = c_contig_buf {
327        let mut c_view = c.as_view_mut();
328        strided_kernel::copy_into(&mut c_view, &c_buf.view())?;
329    }
330
331    Ok(())
332}
333
334fn validate_bgemm_shapes<T>(
335    c: &RawStridedMut<'_, T>,
336    a: &RawStridedRef<'_, T>,
337    b: &RawStridedRef<'_, T>,
338    n_batch: usize,
339    n_lo: usize,
340    n_ro: usize,
341    n_sum: usize,
342) -> strided_view::Result<()> {
343    let a_rank = n_lo + n_sum + n_batch;
344    let b_rank = n_sum + n_ro + n_batch;
345    let c_rank = n_lo + n_ro + n_batch;
346    if a.dims().len() != a_rank {
347        return Err(strided_view::StridedError::RankMismatch(
348            a_rank,
349            a.dims().len(),
350        ));
351    }
352    if b.dims().len() != b_rank {
353        return Err(strided_view::StridedError::RankMismatch(
354            b_rank,
355            b.dims().len(),
356        ));
357    }
358    if c.dims().len() != c_rank {
359        return Err(strided_view::StridedError::RankMismatch(
360            c_rank,
361            c.dims().len(),
362        ));
363    }
364
365    let lo_dims = &a.dims()[..n_lo];
366    let sum_dims = &a.dims()[n_lo..n_lo + n_sum];
367    let batch_dims = &a.dims()[n_lo + n_sum..];
368    let ro_dims = &b.dims()[n_sum..n_sum + n_ro];
369
370    if &b.dims()[..n_sum] != sum_dims {
371        return Err(strided_view::StridedError::ShapeMismatch(
372            sum_dims.to_vec(),
373            b.dims()[..n_sum].to_vec(),
374        ));
375    }
376    if &b.dims()[n_sum + n_ro..] != batch_dims {
377        return Err(strided_view::StridedError::ShapeMismatch(
378            batch_dims.to_vec(),
379            b.dims()[n_sum + n_ro..].to_vec(),
380        ));
381    }
382    if &c.dims()[..n_lo] != lo_dims {
383        return Err(strided_view::StridedError::ShapeMismatch(
384            lo_dims.to_vec(),
385            c.dims()[..n_lo].to_vec(),
386        ));
387    }
388    if &c.dims()[n_lo..n_lo + n_ro] != ro_dims {
389        return Err(strided_view::StridedError::ShapeMismatch(
390            ro_dims.to_vec(),
391            c.dims()[n_lo..n_lo + n_ro].to_vec(),
392        ));
393    }
394    if &c.dims()[n_lo + n_ro..] != batch_dims {
395        return Err(strided_view::StridedError::ShapeMismatch(
396            batch_dims.to_vec(),
397            c.dims()[n_lo + n_ro..].to_vec(),
398        ));
399    }
400    Ok(())
401}
402
403/// Batched GEMM on pre-contiguous operands.
404///
405/// Operands must already have contiguous inner dimensions (prepared via
406/// `prepare_input_*` and `prepare_output_*` in the `contiguous` module).
407///
408/// - `batch_dims`: sizes of the batch dimensions
409/// - `m`: fused lo dimension size (number of rows of A/C)
410/// - `n`: fused ro dimension size (number of cols of B/C)
411/// - `k`: fused sum dimension size (inner dimension)
412pub fn bgemm_contiguous_into<T>(
413    c: &mut ContiguousOperandMut<T>,
414    a: &ContiguousOperand<T>,
415    b: &ContiguousOperand<T>,
416    batch_dims: &[usize],
417    m: usize,
418    n: usize,
419    k: usize,
420    alpha: T,
421    beta: T,
422) -> strided_view::Result<()>
423where
424    T: ComplexField
425        + Copy
426        + strided_view::ElementOpApply
427        + Send
428        + Sync
429        + std::ops::Mul<Output = T>
430        + std::ops::Add<Output = T>
431        + num_traits::Zero
432        + num_traits::One
433        + PartialEq,
434{
435    let is_beta_zero = beta == T::zero();
436    let is_beta_one = beta == T::one();
437
438    let accum = if is_beta_zero {
439        Accum::Replace
440    } else {
441        Accum::Add
442    };
443
444    let a_batch_strides = a.batch_strides();
445    let b_batch_strides = b.batch_strides();
446    let c_batch_strides = c.batch_strides();
447
448    let a_ptr = a.ptr();
449    let b_ptr = b.ptr();
450    let c_ptr = c.ptr();
451    let a_row_stride = a.row_stride();
452    let a_col_stride = a.col_stride();
453    let b_row_stride = b.row_stride();
454    let b_col_stride = b.col_stride();
455    let c_row_stride = c.row_stride();
456    let c_col_stride = c.col_stride();
457
458    let conj_a = if a.conj() { Conj::Yes } else { Conj::No };
459    let conj_b = if b.conj() { Conj::Yes } else { Conj::No };
460
461    // Inline closure for per-batch GEMM (shared between fast and slow paths)
462    let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
463        // Pre-scale C by beta if beta is not 0 or 1
464        if !is_beta_zero && !is_beta_one {
465            let c_base = unsafe { c_ptr.offset(c_batch_off) };
466            for i in 0..m {
467                for j in 0..n {
468                    let offset = i as isize * c_row_stride + j as isize * c_col_stride;
469                    unsafe {
470                        let elem = c_base.offset(offset);
471                        *elem = beta * *elem;
472                    }
473                }
474            }
475        }
476
477        unsafe {
478            let a_mat: MatRef<'_, T> =
479                MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
480            let b_mat: MatRef<'_, T> =
481                MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
482            let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
483                c_ptr.offset(c_batch_off),
484                m,
485                n,
486                c_row_stride,
487                c_col_stride,
488            );
489
490            matmul_with_conj(
491                c_mat,
492                accum,
493                a_mat,
494                conj_a,
495                b_mat,
496                conj_b,
497                alpha,
498                Par::rayon(0),
499            );
500        }
501    };
502
503    // Fast path: when batch dims are contiguous for all operands, use pointer
504    // increments instead of MultiIndex carry-based iteration.
505    let fused_a = try_fuse_group(batch_dims, a_batch_strides);
506    let fused_b = try_fuse_group(batch_dims, b_batch_strides);
507    let fused_c = try_fuse_group(batch_dims, c_batch_strides);
508
509    if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
510        (fused_a, fused_b, fused_c)
511    {
512        let mut a_off = 0isize;
513        let mut b_off = 0isize;
514        let mut c_off = 0isize;
515        for _ in 0..total {
516            do_batch(a_off, b_off, c_off);
517            a_off += a_step;
518            b_off += b_step;
519            c_off += c_step;
520        }
521    } else {
522        let mut batch_iter = MultiIndex::new(batch_dims);
523        while batch_iter.next().is_some() {
524            let a_batch_off = batch_iter.offset(a_batch_strides);
525            let b_batch_off = batch_iter.offset(b_batch_strides);
526            let c_batch_off = batch_iter.offset(c_batch_strides);
527            do_batch(a_batch_off, b_batch_off, c_batch_off);
528        }
529    }
530
531    Ok(())
532}
533
534use crate::backend::{Backend, FaerBackend};
535
536impl<T> Backend<T> for FaerBackend
537where
538    T: crate::ScalarBase + strided_view::ElementOpApply + ComplexField,
539{
540    const MATERIALIZES_CONJ: bool = false;
541    const REQUIRES_UNIT_STRIDE: bool = false;
542
543    fn bgemm_contiguous_into(
544        c: &mut ContiguousOperandMut<T>,
545        a: &ContiguousOperand<T>,
546        b: &ContiguousOperand<T>,
547        batch_dims: &[usize],
548        m: usize,
549        n: usize,
550        k: usize,
551        alpha: T,
552        beta: T,
553    ) -> strided_view::Result<()> {
554        // Delegate to the existing free function in this module
555        bgemm_contiguous_into(c, a, b, batch_dims, m, n, k, alpha, beta)
556    }
557}
558
559#[cfg(test)]
560mod tests {
561    use super::*;
562    use strided_view::StridedArray;
563
564    #[test]
565    fn test_faer_bgemm_2x2() {
566        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
567            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
568        });
569        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
570            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
571        });
572        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
573
574        bgemm_strided_into(
575            &mut c.view_mut(),
576            &a.view(),
577            &b.view(),
578            0,
579            1,
580            1,
581            1,
582            1.0,
583            0.0,
584            false,
585            false,
586        )
587        .unwrap();
588
589        assert_eq!(c.get(&[0, 0]), 19.0);
590        assert_eq!(c.get(&[0, 1]), 22.0);
591        assert_eq!(c.get(&[1, 0]), 43.0);
592        assert_eq!(c.get(&[1, 1]), 50.0);
593    }
594
595    fn raw_bgemm_2x2<T>(one: T, zero: T) -> Vec<T>
596    where
597        T: ComplexField
598            + Copy
599            + strided_view::ElementOpApply
600            + Send
601            + Sync
602            + std::ops::Mul<Output = T>
603            + std::ops::Add<Output = T>
604            + num_traits::Zero
605            + num_traits::One
606            + PartialEq
607            + From<f32>,
608    {
609        let dims = [2, 2];
610        let strides = [2, 1];
611        let a_data = [T::from(1.0), T::from(2.0), T::from(3.0), T::from(4.0)];
612        let b_data = [T::from(5.0), T::from(6.0), T::from(7.0), T::from(8.0)];
613        let mut c_data = vec![zero; 4];
614        let a = RawStridedRef::new(&a_data, &dims, &strides, 0).unwrap();
615        let b = RawStridedRef::new(&b_data, &dims, &strides, 0).unwrap();
616        let c = RawStridedMut::new(&mut c_data, &dims, &strides, 0).unwrap();
617        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, one, zero, false, false).unwrap();
618        c_data
619    }
620
621    #[test]
622    fn test_faer_raw_bgemm_f64() {
623        assert_eq!(raw_bgemm_2x2(1.0f64, 0.0), vec![19.0, 22.0, 43.0, 50.0]);
624    }
625
626    #[test]
627    fn test_faer_raw_bgemm_f32() {
628        assert_eq!(raw_bgemm_2x2(1.0f32, 0.0), vec![19.0f32, 22.0, 43.0, 50.0]);
629    }
630
631    #[test]
632    fn test_faer_raw_bgemm_complex_conj() {
633        use num_complex::Complex64;
634
635        let i = Complex64::i();
636        let dims = [2, 2];
637        let strides = [2, 1];
638        let a_data = [
639            Complex64::new(1.0, 0.0) + i,
640            Complex64::new(2.0, 0.0),
641            Complex64::new(3.0, 0.0),
642            Complex64::new(4.0, 0.0) - i,
643        ];
644        let b_data = [
645            Complex64::new(1.0, 0.0),
646            Complex64::new(0.0, 0.0),
647            Complex64::new(0.0, 0.0),
648            Complex64::new(1.0, 0.0),
649        ];
650        let mut c_data = vec![Complex64::new(0.0, 0.0); 4];
651        let a = RawStridedRef::new(&a_data, &dims, &strides, 0).unwrap();
652        let b = RawStridedRef::new(&b_data, &dims, &strides, 0).unwrap();
653        let c = RawStridedMut::new(&mut c_data, &dims, &strides, 0).unwrap();
654        bgemm_raw_strided_into(
655            c,
656            a,
657            b,
658            0,
659            1,
660            1,
661            1,
662            Complex64::new(1.0, 0.0),
663            Complex64::new(0.0, 0.0),
664            true,
665            false,
666        )
667        .unwrap();
668        assert_eq!(
669            c_data,
670            vec![
671                Complex64::new(1.0, -1.0),
672                Complex64::new(2.0, 0.0),
673                Complex64::new(3.0, 0.0),
674                Complex64::new(4.0, 1.0),
675            ]
676        );
677    }
678
679    #[test]
680    fn test_faer_raw_bgemm_checked_shape_mismatch() {
681        let a_dims = [2, 2];
682        let b_dims = [3, 2];
683        let c_dims = [2, 2];
684        let a_strides = [2, 1];
685        let b_strides = [2, 1];
686        let c_strides = [2, 1];
687        let a_data = [1.0, 2.0, 3.0, 4.0];
688        let b_data = [0.0; 6];
689        let mut c_data = [0.0; 4];
690        let a = RawStridedRef::new(&a_data, &a_dims, &a_strides, 0).unwrap();
691        let b = RawStridedRef::new(&b_data, &b_dims, &b_strides, 0).unwrap();
692        let c = RawStridedMut::new(&mut c_data, &c_dims, &c_strides, 0).unwrap();
693        let err = bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false).unwrap_err();
694        assert!(matches!(
695            err,
696            strided_view::StridedError::ShapeMismatch(_, _)
697        ));
698    }
699
700    #[test]
701    fn test_faer_bgemm_rect() {
702        let a =
703            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
704        let b =
705            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
706        let mut c = StridedArray::<f64>::row_major(&[2, 4]);
707
708        bgemm_strided_into(
709            &mut c.view_mut(),
710            &a.view(),
711            &b.view(),
712            0,
713            1,
714            1,
715            1,
716            1.0,
717            0.0,
718            false,
719            false,
720        )
721        .unwrap();
722
723        assert_eq!(c.get(&[0, 0]), 38.0);
724        assert_eq!(c.get(&[1, 3]), 128.0);
725    }
726
727    #[test]
728    fn test_faer_bgemm_batched() {
729        // Batch-last: A: [lo, sum, batch]=[2,3,2], B: [sum, ro, batch]=[3,2,2], C: [lo, ro, batch]=[2,2,2]
730        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
731            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
732        });
733        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
734            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
735        });
736        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
737
738        bgemm_strided_into(
739            &mut c.view_mut(),
740            &a.view(),
741            &b.view(),
742            1,
743            1,
744            1,
745            1,
746            1.0,
747            0.0,
748            false,
749            false,
750        )
751        .unwrap();
752
753        // C: [lo, ro, batch]
754        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
755        // C0[0,0] = 1*1+2*3+3*5 = 22
756        assert_eq!(c.get(&[0, 0, 0]), 22.0);
757    }
758
759    #[test]
760    fn test_faer_bgemm_beta_zero() {
761        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
762            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
763        });
764        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
765            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
766        });
767        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
768            [[100.0, 200.0], [300.0, 400.0]][idx[0]][idx[1]]
769        });
770
771        bgemm_strided_into(
772            &mut c.view_mut(),
773            &a.view(),
774            &b.view(),
775            0,
776            1,
777            1,
778            1,
779            1.0,
780            0.0, // beta=0: C_old should be ignored
781            false,
782            false,
783        )
784        .unwrap();
785
786        assert_eq!(c.get(&[0, 0]), 19.0);
787        assert_eq!(c.get(&[1, 1]), 50.0);
788    }
789
790    #[test]
791    fn test_faer_bgemm_beta_one() {
792        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
793            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
794        });
795        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
796            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
797        });
798        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
799            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
800        });
801
802        bgemm_strided_into(
803            &mut c.view_mut(),
804            &a.view(),
805            &b.view(),
806            0,
807            1,
808            1,
809            1,
810            1.0,
811            1.0, // beta=1: C = A*B + C_old
812            false,
813            false,
814        )
815        .unwrap();
816
817        // C[0,0] = 1*1+0*3 + 10 = 11
818        assert_eq!(c.get(&[0, 0]), 11.0);
819        // C[1,1] = 0*2+1*4 + 40 = 44
820        assert_eq!(c.get(&[1, 1]), 44.0);
821    }
822
823    #[test]
824    fn test_faer_bgemm_alpha_beta() {
825        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
826            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
827        });
828        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
829            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
830        });
831        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
832            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
833        });
834
835        bgemm_strided_into(
836            &mut c.view_mut(),
837            &a.view(),
838            &b.view(),
839            0,
840            1,
841            1,
842            1,
843            2.0,
844            3.0, // C = 2*I*B + 3*C_old
845            false,
846            false,
847        )
848        .unwrap();
849
850        // C[0,0] = 2*1 + 3*10 = 32
851        assert_eq!(c.get(&[0, 0]), 32.0);
852        // C[1,1] = 2*4 + 3*40 = 128
853        assert_eq!(c.get(&[1, 1]), 128.0);
854    }
855
856    #[test]
857    fn test_faer_bgemm_outer_product() {
858        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
859        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
860        let mut c = StridedArray::<f64>::row_major(&[3, 4]);
861
862        bgemm_strided_into(
863            &mut c.view_mut(),
864            &a.view(),
865            &b.view(),
866            0,
867            1,
868            1,
869            0, // no sum
870            1.0,
871            0.0,
872            false,
873            false,
874        )
875        .unwrap();
876
877        assert_eq!(c.get(&[0, 0]), 1.0);
878        assert_eq!(c.get(&[2, 3]), 12.0);
879    }
880
881    #[test]
882    fn test_faer_bgemm_f32() {
883        let a = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
884            [[1.0f32, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
885        });
886        let b = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
887            [[5.0f32, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
888        });
889        let mut c = StridedArray::<f32>::row_major(&[2, 2]);
890
891        bgemm_strided_into(
892            &mut c.view_mut(),
893            &a.view(),
894            &b.view(),
895            0,
896            1,
897            1,
898            1,
899            1.0f32,
900            0.0f32,
901            false,
902            false,
903        )
904        .unwrap();
905
906        assert_eq!(c.get(&[0, 0]), 19.0f32);
907        assert_eq!(c.get(&[1, 1]), 50.0f32);
908    }
909
910    #[test]
911    fn test_faer_bgemm_col_major_input() {
912        // A is col-major (non-contiguous for row-major fusion) → triggers copy path
913        let a_data = vec![1.0, 3.0, 2.0, 4.0]; // col-major [[1,2],[3,4]]
914        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();
915
916        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
917            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
918        });
919        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
920
921        bgemm_strided_into(
922            &mut c.view_mut(),
923            &a.view(),
924            &b.view(),
925            0,
926            1,
927            1,
928            1,
929            1.0,
930            0.0,
931            false,
932            false,
933        )
934        .unwrap();
935
936        // Same A=[[1,2],[3,4]], B=[[5,6],[7,8]]
937        assert_eq!(c.get(&[0, 0]), 19.0);
938        assert_eq!(c.get(&[0, 1]), 22.0);
939        assert_eq!(c.get(&[1, 0]), 43.0);
940        assert_eq!(c.get(&[1, 1]), 50.0);
941    }
942
943    #[test]
944    fn test_faer_bgemm_col_major_output() {
945        // C is col-major → triggers C copy path
946        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
947            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
948        });
949        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
950            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
951        });
952        let mut c = StridedArray::<f64>::col_major(&[2, 2]);
953
954        bgemm_strided_into(
955            &mut c.view_mut(),
956            &a.view(),
957            &b.view(),
958            0,
959            1,
960            1,
961            1,
962            1.0,
963            0.0,
964            false,
965            false,
966        )
967        .unwrap();
968
969        assert_eq!(c.get(&[0, 0]), 19.0);
970        assert_eq!(c.get(&[0, 1]), 22.0);
971        assert_eq!(c.get(&[1, 0]), 43.0);
972        assert_eq!(c.get(&[1, 1]), 50.0);
973    }
974
975    #[test]
976    fn test_faer_bgemm_col_major_with_beta() {
977        // C is col-major with beta != 0 → copy C in, matmul, copy back
978        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
979            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
980        });
981        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
982            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
983        });
984        // C col-major with initial values
985        let c_data = vec![10.0, 30.0, 20.0, 40.0]; // col-major [[10,20],[30,40]]
986        let mut c = StridedArray::<f64>::from_parts(c_data, &[2, 2], &[1, 2], 0).unwrap();
987
988        bgemm_strided_into(
989            &mut c.view_mut(),
990            &a.view(),
991            &b.view(),
992            0,
993            1,
994            1,
995            1,
996            2.0,
997            3.0, // C = 2*I*B + 3*C_old
998            false,
999            false,
1000        )
1001        .unwrap();
1002
1003        // C[0,0] = 2*1 + 3*10 = 32
1004        assert_eq!(c.get(&[0, 0]), 32.0);
1005        // C[1,1] = 2*4 + 3*40 = 128
1006        assert_eq!(c.get(&[1, 1]), 128.0);
1007    }
1008
1009    // ---- bgemm_contiguous_into tests ----
1010
1011    use crate::backend::{ActiveBackend, Backend};
1012    use crate::contiguous::{prepare_input_view, prepare_output_view};
1013
    /// The active `f64` backend's `REQUIRES_UNIT_STRIDE` flag, passed through to
    /// `prepare_input_view` / `prepare_output_view` in the contiguous tests so that
    /// operand preparation uses the same setting the configured backend does.
    const US: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;
1015
1016    #[test]
1017    fn test_bgemm_contiguous_2x2() {
1018        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1019            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1020        });
1021        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1022            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1023        });
1024        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1025
1026        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
1027        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
1028        let mut c_view = c.view_mut();
1029        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 0.0, US, true).unwrap();
1030
1031        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();
1032
1033        c_op.finalize_into(&mut c_view).unwrap();
1034
1035        assert_eq!(c.get(&[0, 0]), 19.0);
1036        assert_eq!(c.get(&[0, 1]), 22.0);
1037        assert_eq!(c.get(&[1, 0]), 43.0);
1038        assert_eq!(c.get(&[1, 1]), 50.0);
1039    }
1040
1041    #[test]
1042    fn test_bgemm_contiguous_batched() {
1043        // Batched: 2 x (2x3) * (3x2) matmul
1044        // Batch-last: A: [lo, sum, batch]=[2,3,2], B: [sum, ro, batch]=[3,2,2], C: [lo, ro, batch]=[2,2,2]
1045        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
1046            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
1047        });
1048        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
1049            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
1050        });
1051        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);
1052
1053        // n_batch=1, A: n_group1=1 (lo), n_group2=1 (sum)
1054        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
1055        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
1056        let mut c_view = c.view_mut();
1057        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 0.0, US, true).unwrap();
1058
1059        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[2], 2, 2, 3, 1.0, 0.0).unwrap();
1060
1061        c_op.finalize_into(&mut c_view).unwrap();
1062
1063        // C: [lo, ro, batch]
1064        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
1065        // C0[0,0] = 1*1+2*3+3*5 = 22
1066        assert_eq!(c.get(&[0, 0, 0]), 22.0);
1067        // C0[0,1] = 1*2+2*4+3*6 = 28
1068        assert_eq!(c.get(&[0, 1, 0]), 28.0);
1069        // C0[1,0] = 4*1+5*3+6*5 = 49
1070        assert_eq!(c.get(&[1, 0, 0]), 49.0);
1071        // C0[1,1] = 4*2+5*4+6*6 = 64
1072        assert_eq!(c.get(&[1, 1, 0]), 64.0);
1073
1074        // Batch 1: A1=[[7,8,9],[10,11,12]], B1=[[7,8],[9,10],[11,12]]
1075        // C1[0,0] = 7*7+8*9+9*11 = 49+72+99 = 220
1076        assert_eq!(c.get(&[0, 0, 1]), 220.0);
1077    }
1078
1079    #[test]
1080    fn test_bgemm_contiguous_with_beta() {
1081        // C = 2*I*B + 3*C_old
1082        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1083            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
1084        });
1085        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1086            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
1087        });
1088        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1089            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
1090        });
1091
1092        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
1093        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
1094        let mut c_view = c.view_mut();
1095        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 3.0, US, true).unwrap();
1096
1097        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 2.0, 3.0).unwrap();
1098
1099        c_op.finalize_into(&mut c_view).unwrap();
1100
1101        // C[0,0] = 2*1 + 3*10 = 32
1102        assert_eq!(c.get(&[0, 0]), 32.0);
1103        // C[0,1] = 2*2 + 3*20 = 64
1104        assert_eq!(c.get(&[0, 1]), 64.0);
1105        // C[1,0] = 2*3 + 3*30 = 96
1106        assert_eq!(c.get(&[1, 0]), 96.0);
1107        // C[1,1] = 2*4 + 3*40 = 128
1108        assert_eq!(c.get(&[1, 1]), 128.0);
1109    }
1110
1111    #[test]
1112    fn test_bgemm_contiguous_non_contiguous_input() {
1113        // A is col-major (triggers copy in prepare_input_view for row-major grouping)
1114        let a_data = vec![1.0, 3.0, 2.0, 4.0]; // col-major [[1,2],[3,4]]
1115        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();
1116
1117        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
1118            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
1119        });
1120        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
1121
1122        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
1123        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
1124        let mut c_view = c.view_mut();
1125        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 0.0, US, true).unwrap();
1126
1127        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();
1128
1129        c_op.finalize_into(&mut c_view).unwrap();
1130
1131        // Same A=[[1,2],[3,4]], B=[[5,6],[7,8]]
1132        assert_eq!(c.get(&[0, 0]), 19.0);
1133        assert_eq!(c.get(&[0, 1]), 22.0);
1134        assert_eq!(c.get(&[1, 0]), 43.0);
1135        assert_eq!(c.get(&[1, 1]), 50.0);
1136    }
1137}