// strided_einsum2/bgemm_faer.rs

1//! faer-backed batched GEMM kernel on strided views.
2//!
3//! Uses `faer::linalg::matmul::matmul` for SIMD-optimized matrix multiplication.
4//! When dimension groups cannot be fused into 2D matrices (non-contiguous strides),
5//! copies operands to contiguous column-major buffers before calling faer.
6
7use crate::contiguous::{alloc_col_major_uninit, ContiguousOperand, ContiguousOperandMut};
8use crate::util::{try_fuse_group, MultiIndex};
9use faer::linalg::matmul::matmul_with_conj;
10use faer::mat::{MatMut, MatRef};
11use faer::{Accum, Conj, Par};
12use faer_traits::ComplexField;
13use strided_view::{StridedArray, StridedView, StridedViewMut};
14
/// Batched strided GEMM using faer: C = alpha * A * B + beta * C
///
/// Same interface as `bgemm_naive::bgemm_strided_into`. Uses faer's optimized
/// matmul for all cases. When dimension groups have non-contiguous strides,
/// copies operands to contiguous column-major buffers first using `strided_kernel::copy_into`.
///
/// Dimension layout is batch-last canonical order:
/// A: `[lo..., sum..., batch...]`, B: `[sum..., ro..., batch...]`,
/// C: `[lo..., ro..., batch...]`, with `n_lo`/`n_ro`/`n_sum` giving the axis
/// count of each group. `_n_batch` is accepted for interface parity with the
/// naive kernel but is unused here: the batch axes are recovered as the
/// trailing axes of `a`.
///
/// # Errors
/// Propagates any error from `strided_kernel::copy_into` when an operand must
/// be staged through (or copied back from) a temporary contiguous buffer.
pub fn bgemm_strided_into<T>(
    c: &mut StridedViewMut<T>,
    a: &StridedView<T>,
    b: &StridedView<T>,
    _n_batch: usize,
    n_lo: usize,
    n_ro: usize,
    n_sum: usize,
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> strided_view::Result<()>
where
    T: ComplexField
        + Copy
        + strided_view::ElementOpApply
        + Send
        + Sync
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>
        + num_traits::Zero
        + num_traits::One
        + PartialEq,
{
    let a_dims = a.dims();
    let b_dims = b.dims();
    let a_strides = a.strides();
    let b_strides = b.strides();
    let c_strides = c.strides();

    // Extract dimension groups (batch-last canonical order)
    // A: [lo, sum, batch], B: [sum, ro, batch], C: [lo, ro, batch]
    let lo_dims = &a_dims[..n_lo];
    let sum_dims = &a_dims[n_lo..n_lo + n_sum];
    let batch_dims = &a_dims[n_lo + n_sum..];
    let ro_dims = &b_dims[n_sum..n_sum + n_ro];

    // Fused sizes for the matrix multiply. `.max(1)` maps an empty group
    // (e.g. an outer product with n_sum == 0) to a size-1 dimension, so
    // m, k, n are always >= 1.
    let m: usize = lo_dims.iter().product::<usize>().max(1);
    let k: usize = sum_dims.iter().product::<usize>().max(1);
    let n: usize = ro_dims.iter().product::<usize>().max(1);

    // Extract stride groups (batch-last)
    let a_lo_strides = &a_strides[..n_lo];
    let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
    let b_sum_strides = &b_strides[..n_sum];
    let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
    let c_lo_strides = &c_strides[..n_lo];
    let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];

    // Try to fuse each dimension group down to a single axis. Based on how the
    // results are destructured below, `try_fuse_group` appears to return
    // `Some((fused_size, stride))` when the group is collapsible — confirm in
    // `crate::util`.
    let fused_a_lo = try_fuse_group(lo_dims, a_lo_strides);
    let fused_a_sum = try_fuse_group(sum_dims, a_sum_strides);
    let fused_b_sum = try_fuse_group(sum_dims, b_sum_strides);
    let fused_b_ro = try_fuse_group(ro_dims, b_ro_strides);
    let fused_c_lo = try_fuse_group(lo_dims, c_lo_strides);
    let fused_c_ro = try_fuse_group(ro_dims, c_ro_strides);

    // An operand must be staged to a temporary buffer if either of its inner
    // groups cannot be expressed as a single stride.
    let a_needs_copy = fused_a_lo.is_none() || fused_a_sum.is_none();
    let b_needs_copy = fused_b_sum.is_none() || fused_b_ro.is_none();
    let c_needs_copy = fused_c_lo.is_none() || fused_c_ro.is_none();

    // Number of non-batch (inner) axes per operand; the batch strides start
    // right after these in each operand's stride list.
    let n_a_inner = n_lo + n_sum;
    let n_b_inner = n_sum + n_ro;
    let n_c_inner = n_lo + n_ro;

    // Copy A to contiguous column-major if inner dims aren't fusable
    let a_contig_buf: Option<StridedArray<T>>;
    let (a_ptr, a_row_stride, a_col_stride);
    if a_needs_copy {
        let mut buf = alloc_col_major_uninit(a.dims());
        strided_kernel::copy_into(&mut buf.view_mut(), a)?;
        a_ptr = buf.view().ptr();
        // Col-major inner A [lo..., sum...]: lo stride = 1, sum stride = m
        // (m >= 1 by construction, so the 0 arm is defensive dead code).
        a_row_stride = if m == 0 { 0 } else { 1isize };
        a_col_stride = m as isize;
        // Keep the buffer alive for the duration of the matmuls; a_ptr
        // points into it.
        a_contig_buf = Some(buf);
    } else {
        let (_, rs) = fused_a_lo.unwrap();
        let (_, cs) = fused_a_sum.unwrap();
        a_ptr = a.ptr();
        a_row_stride = rs;
        a_col_stride = cs;
        a_contig_buf = None;
    }
    // Batch strides come from the staging buffer when one was made, otherwise
    // from the original view.
    let a_batch_strides: &[isize] = match a_contig_buf.as_ref() {
        Some(buf) => &buf.strides()[n_a_inner..],
        None => &a_strides[n_a_inner..],
    };

    // Copy B to contiguous column-major if inner dims aren't fusable
    let b_contig_buf: Option<StridedArray<T>>;
    let (b_ptr, b_row_stride, b_col_stride);
    if b_needs_copy {
        let mut buf = alloc_col_major_uninit(b.dims());
        strided_kernel::copy_into(&mut buf.view_mut(), b)?;
        b_ptr = buf.view().ptr();
        // Col-major inner B [sum..., ro...]: sum stride = 1, ro stride = k
        // (k >= 1 by construction, so the 0 arm is defensive dead code).
        b_row_stride = if k == 0 { 0 } else { 1isize };
        b_col_stride = k as isize;
        b_contig_buf = Some(buf);
    } else {
        let (_, rs) = fused_b_sum.unwrap();
        let (_, cs) = fused_b_ro.unwrap();
        b_ptr = b.ptr();
        b_row_stride = rs;
        b_col_stride = cs;
        b_contig_buf = None;
    }
    let b_batch_strides: &[isize] = match b_contig_buf.as_ref() {
        Some(buf) => &buf.strides()[n_b_inner..],
        None => &b_strides[n_b_inner..],
    };

    // Copy C to contiguous column-major if inner dims aren't fusable
    let c_contig_buf: Option<StridedArray<T>>;
    let (c_ptr, c_row_stride, c_col_stride);
    if c_needs_copy {
        let mut buf = alloc_col_major_uninit(c.dims());
        // Only seed the staging buffer with C's current contents when beta
        // contributes; with beta == 0 the matmul below runs in Replace mode
        // and overwrites every element, so the uninitialized contents are
        // never read.
        if beta != T::zero() {
            strided_kernel::copy_into(&mut buf.view_mut(), &c.as_view())?;
        }
        c_ptr = buf.view_mut().as_mut_ptr();
        // Col-major inner C [lo..., ro...]: lo stride = 1, ro stride = m
        c_row_stride = if m == 0 { 0 } else { 1isize };
        c_col_stride = m as isize;
        c_contig_buf = Some(buf);
    } else {
        let (_, rs) = fused_c_lo.unwrap();
        let (_, cs) = fused_c_ro.unwrap();
        c_ptr = c.as_mut_ptr();
        c_row_stride = rs;
        c_col_stride = cs;
        c_contig_buf = None;
    }
    let c_batch_strides: &[isize] = match c_contig_buf.as_ref() {
        Some(buf) => &buf.strides()[n_c_inner..],
        None => &c_strides[n_c_inner..],
    };

    let is_beta_zero = beta == T::zero();
    let is_beta_one = beta == T::one();

    // Determine accumulation mode: faer only models beta ∈ {0, 1} through
    // `Accum`; general beta is handled by pre-scaling C below.
    let accum = if is_beta_zero {
        Accum::Replace
    } else {
        Accum::Add
    };

    let cj_a = if conj_a { Conj::Yes } else { Conj::No };
    let cj_b = if conj_b { Conj::Yes } else { Conj::No };

    // Inline closure for per-batch GEMM (shared between fast and slow paths)
    let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
        // Pre-scale C by beta if beta is not 0 or 1
        if !is_beta_zero && !is_beta_one {
            let c_base = unsafe { c_ptr.offset(c_batch_off) };
            for i in 0..m {
                for j in 0..n {
                    let offset = i as isize * c_row_stride + j as isize * c_col_stride;
                    unsafe {
                        // SAFETY: (i, j) ranges over the fused (m, n) C matrix
                        // and the strides/offset were derived from C's own
                        // layout (or its staging buffer), so the element is
                        // in-bounds provided the caller's group sizes describe
                        // the views — TODO confirm this is validated upstream.
                        let elem = c_base.offset(offset);
                        *elem = beta * *elem;
                    }
                }
            }
        }

        unsafe {
            // SAFETY: the (m, k)/(k, n)/(m, n) shapes and row/col strides were
            // derived from (or staged to match) each operand's dims and
            // strides, and the batch offsets come from the matching batch
            // stride groups, so every addressed element lies inside the
            // corresponding allocation (same caller-side precondition as
            // above).
            let a_mat: MatRef<'_, T> =
                MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
            let b_mat: MatRef<'_, T> =
                MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
            let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
                c_ptr.offset(c_batch_off),
                m,
                n,
                c_row_stride,
                c_col_stride,
            );

            matmul_with_conj(c_mat, accum, a_mat, cj_a, b_mat, cj_b, alpha, Par::rayon(0));
        }
    };

    // Fast path: when batch dims are contiguous for all operands, use pointer
    // increments instead of MultiIndex carry-based iteration.
    let fused_a = try_fuse_group(batch_dims, a_batch_strides);
    let fused_b = try_fuse_group(batch_dims, b_batch_strides);
    let fused_c = try_fuse_group(batch_dims, c_batch_strides);

    if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
        (fused_a, fused_b, fused_c)
    {
        let mut a_off = 0isize;
        let mut b_off = 0isize;
        let mut c_off = 0isize;
        for _ in 0..total {
            do_batch(a_off, b_off, c_off);
            a_off += a_step;
            b_off += b_step;
            c_off += c_step;
        }
    } else {
        // Slow path: carry-based multi-index walk over the batch axes, with
        // each operand's offset computed from its own batch stride group.
        let mut batch_iter = MultiIndex::new(batch_dims);
        while batch_iter.next().is_some() {
            let a_batch_off = batch_iter.offset(a_batch_strides);
            let b_batch_off = batch_iter.offset(b_batch_strides);
            let c_batch_off = batch_iter.offset(c_batch_strides);
            do_batch(a_batch_off, b_batch_off, c_batch_off);
        }
    }

    // If C was copied to a temp buffer, copy the result back
    if let Some(ref c_buf) = c_contig_buf {
        strided_kernel::copy_into(c, &c_buf.view())?;
    }

    Ok(())
}
242
243/// Batched GEMM on pre-contiguous operands.
244///
245/// Operands must already have contiguous inner dimensions (prepared via
246/// `prepare_input_*` and `prepare_output_*` in the `contiguous` module).
247///
248/// - `batch_dims`: sizes of the batch dimensions
249/// - `m`: fused lo dimension size (number of rows of A/C)
250/// - `n`: fused ro dimension size (number of cols of B/C)
251/// - `k`: fused sum dimension size (inner dimension)
252pub fn bgemm_contiguous_into<T>(
253    c: &mut ContiguousOperandMut<T>,
254    a: &ContiguousOperand<T>,
255    b: &ContiguousOperand<T>,
256    batch_dims: &[usize],
257    m: usize,
258    n: usize,
259    k: usize,
260    alpha: T,
261    beta: T,
262) -> strided_view::Result<()>
263where
264    T: ComplexField
265        + Copy
266        + strided_view::ElementOpApply
267        + Send
268        + Sync
269        + std::ops::Mul<Output = T>
270        + std::ops::Add<Output = T>
271        + num_traits::Zero
272        + num_traits::One
273        + PartialEq,
274{
275    let is_beta_zero = beta == T::zero();
276    let is_beta_one = beta == T::one();
277
278    let accum = if is_beta_zero {
279        Accum::Replace
280    } else {
281        Accum::Add
282    };
283
284    let a_batch_strides = a.batch_strides();
285    let b_batch_strides = b.batch_strides();
286    let c_batch_strides = c.batch_strides();
287
288    let a_ptr = a.ptr();
289    let b_ptr = b.ptr();
290    let c_ptr = c.ptr();
291    let a_row_stride = a.row_stride();
292    let a_col_stride = a.col_stride();
293    let b_row_stride = b.row_stride();
294    let b_col_stride = b.col_stride();
295    let c_row_stride = c.row_stride();
296    let c_col_stride = c.col_stride();
297
298    let conj_a = if a.conj() { Conj::Yes } else { Conj::No };
299    let conj_b = if b.conj() { Conj::Yes } else { Conj::No };
300
301    // Inline closure for per-batch GEMM (shared between fast and slow paths)
302    let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
303        // Pre-scale C by beta if beta is not 0 or 1
304        if !is_beta_zero && !is_beta_one {
305            let c_base = unsafe { c_ptr.offset(c_batch_off) };
306            for i in 0..m {
307                for j in 0..n {
308                    let offset = i as isize * c_row_stride + j as isize * c_col_stride;
309                    unsafe {
310                        let elem = c_base.offset(offset);
311                        *elem = beta * *elem;
312                    }
313                }
314            }
315        }
316
317        unsafe {
318            let a_mat: MatRef<'_, T> =
319                MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
320            let b_mat: MatRef<'_, T> =
321                MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
322            let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
323                c_ptr.offset(c_batch_off),
324                m,
325                n,
326                c_row_stride,
327                c_col_stride,
328            );
329
330            matmul_with_conj(
331                c_mat,
332                accum,
333                a_mat,
334                conj_a,
335                b_mat,
336                conj_b,
337                alpha,
338                Par::rayon(0),
339            );
340        }
341    };
342
343    // Fast path: when batch dims are contiguous for all operands, use pointer
344    // increments instead of MultiIndex carry-based iteration.
345    let fused_a = try_fuse_group(batch_dims, a_batch_strides);
346    let fused_b = try_fuse_group(batch_dims, b_batch_strides);
347    let fused_c = try_fuse_group(batch_dims, c_batch_strides);
348
349    if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
350        (fused_a, fused_b, fused_c)
351    {
352        let mut a_off = 0isize;
353        let mut b_off = 0isize;
354        let mut c_off = 0isize;
355        for _ in 0..total {
356            do_batch(a_off, b_off, c_off);
357            a_off += a_step;
358            b_off += b_step;
359            c_off += c_step;
360        }
361    } else {
362        let mut batch_iter = MultiIndex::new(batch_dims);
363        while batch_iter.next().is_some() {
364            let a_batch_off = batch_iter.offset(a_batch_strides);
365            let b_batch_off = batch_iter.offset(b_batch_strides);
366            let c_batch_off = batch_iter.offset(c_batch_strides);
367            do_batch(a_batch_off, b_batch_off, c_batch_off);
368        }
369    }
370
371    Ok(())
372}
373
374use crate::backend::{Backend, FaerBackend};
375
376impl<T> Backend<T> for FaerBackend
377where
378    T: crate::Scalar + ComplexField,
379{
380    const MATERIALIZES_CONJ: bool = false;
381    const REQUIRES_UNIT_STRIDE: bool = false;
382
383    fn bgemm_contiguous_into(
384        c: &mut ContiguousOperandMut<T>,
385        a: &ContiguousOperand<T>,
386        b: &ContiguousOperand<T>,
387        batch_dims: &[usize],
388        m: usize,
389        n: usize,
390        k: usize,
391        alpha: T,
392        beta: T,
393    ) -> strided_view::Result<()> {
394        // Delegate to the existing free function in this module
395        bgemm_contiguous_into(c, a, b, batch_dims, m, n, k, alpha, beta)
396    }
397}
398
#[cfg(test)]
mod tests {
    //! Unit tests for the faer-backed GEMM kernels: plain, rectangular,
    //! batched, alpha/beta scaling, non-contiguous (copy-path) operands, and
    //! the pre-contiguous `bgemm_contiguous_into` entry point.
    use super::*;
    use strided_view::StridedArray;

    // Basic 2x2 matmul through the fully-fused fast path.
    #[test]
    fn test_faer_bgemm_2x2() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    // Rectangular (2x3)·(3x4) product: checks m != n != k handling.
    #[test]
    fn test_faer_bgemm_rect() {
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let b =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[2, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 38.0);
        assert_eq!(c.get(&[1, 3]), 128.0);
    }

    #[test]
    fn test_faer_bgemm_batched() {
        // Batch-last: A: [lo, sum, batch]=[2,3,2], B: [sum, ro, batch]=[3,2,2], C: [lo, ro, batch]=[2,2,2]
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            1,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // C: [lo, ro, batch]
        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
        // C0[0,0] = 1*1+2*3+3*5 = 22
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
    }

    // beta=0 must fully discard C's prior contents (Accum::Replace path).
    #[test]
    fn test_faer_bgemm_beta_zero() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[100.0, 200.0], [300.0, 400.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0, // beta=0: C_old should be ignored
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    // beta=1 accumulates into C without pre-scaling (Accum::Add path).
    #[test]
    fn test_faer_bgemm_beta_one() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            1.0, // beta=1: C = A*B + C_old
            false,
            false,
        )
        .unwrap();

        // C[0,0] = 1*1+0*3 + 10 = 11
        assert_eq!(c.get(&[0, 0]), 11.0);
        // C[1,1] = 0*2+1*4 + 40 = 44
        assert_eq!(c.get(&[1, 1]), 44.0);
    }

    // General alpha/beta exercises the pre-scale-then-accumulate branch.
    #[test]
    fn test_faer_bgemm_alpha_beta() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0, // C = 2*I*B + 3*C_old
            false,
            false,
        )
        .unwrap();

        // C[0,0] = 2*1 + 3*10 = 32
        assert_eq!(c.get(&[0, 0]), 32.0);
        // C[1,1] = 2*4 + 3*40 = 128
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    // n_sum = 0: outer product, where the sum group is empty (k fused to 1).
    #[test]
    fn test_faer_bgemm_outer_product() {
        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[3, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            0, // no sum
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 1.0);
        assert_eq!(c.get(&[2, 3]), 12.0);
    }

    // Same 2x2 case at f32 precision.
    #[test]
    fn test_faer_bgemm_f32() {
        let a = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0f32, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0f32, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f32>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0f32,
            0.0f32,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0f32);
        assert_eq!(c.get(&[1, 1]), 50.0f32);
    }

    #[test]
    fn test_faer_bgemm_col_major_input() {
        // A is col-major (non-contiguous for row-major fusion) → triggers copy path
        let a_data = vec![1.0, 3.0, 2.0, 4.0]; // col-major [[1,2],[3,4]]
        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();

        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // Same A=[[1,2],[3,4]], B=[[5,6],[7,8]]
        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_faer_bgemm_col_major_output() {
        // C is col-major → triggers C copy path
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_faer_bgemm_col_major_with_beta() {
        // C is col-major with beta != 0 → copy C in, matmul, copy back
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        // C col-major with initial values
        let c_data = vec![10.0, 30.0, 20.0, 40.0]; // col-major [[10,20],[30,40]]
        let mut c = StridedArray::<f64>::from_parts(c_data, &[2, 2], &[1, 2], 0).unwrap();

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0, // C = 2*I*B + 3*C_old
            false,
            false,
        )
        .unwrap();

        // C[0,0] = 2*1 + 3*10 = 32
        assert_eq!(c.get(&[0, 0]), 32.0);
        // C[1,1] = 2*4 + 3*40 = 128
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    // ---- bgemm_contiguous_into tests ----

    use crate::backend::{ActiveBackend, Backend};
    use crate::contiguous::{prepare_input_view, prepare_output_view};

    // Whether the active backend needs unit-stride staging; passed through to
    // the prepare_* helpers so these tests match production preparation.
    const US: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    #[test]
    fn test_bgemm_contiguous_2x2() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 0.0, US, true).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_bgemm_contiguous_batched() {
        // Batched: 2 x (2x3) * (3x2) matmul
        // Batch-last: A: [lo, sum, batch]=[2,3,2], B: [sum, ro, batch]=[3,2,2], C: [lo, ro, batch]=[2,2,2]
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        // n_batch=1, A: n_group1=1 (lo), n_group2=1 (sum)
        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 0.0, US, true).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[2], 2, 2, 3, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        // C: [lo, ro, batch]
        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
        // C0[0,0] = 1*1+2*3+3*5 = 22
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
        // C0[0,1] = 1*2+2*4+3*6 = 28
        assert_eq!(c.get(&[0, 1, 0]), 28.0);
        // C0[1,0] = 4*1+5*3+6*5 = 49
        assert_eq!(c.get(&[1, 0, 0]), 49.0);
        // C0[1,1] = 4*2+5*4+6*6 = 64
        assert_eq!(c.get(&[1, 1, 0]), 64.0);

        // Batch 1: A1=[[7,8,9],[10,11,12]], B1=[[7,8],[9,10],[11,12]]
        // C1[0,0] = 7*7+8*9+9*11 = 49+72+99 = 220
        assert_eq!(c.get(&[0, 0, 1]), 220.0);
    }

    #[test]
    fn test_bgemm_contiguous_with_beta() {
        // C = 2*I*B + 3*C_old
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 3.0, US, true).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 2.0, 3.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        // C[0,0] = 2*1 + 3*10 = 32
        assert_eq!(c.get(&[0, 0]), 32.0);
        // C[0,1] = 2*2 + 3*20 = 64
        assert_eq!(c.get(&[0, 1]), 64.0);
        // C[1,0] = 2*3 + 3*30 = 96
        assert_eq!(c.get(&[1, 0]), 96.0);
        // C[1,1] = 2*4 + 3*40 = 128
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    #[test]
    fn test_bgemm_contiguous_non_contiguous_input() {
        // A is col-major (triggers copy in prepare_input_view for row-major grouping)
        let a_data = vec![1.0, 3.0, 2.0, 4.0]; // col-major [[1,2],[3,4]]
        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();

        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 0.0, US, true).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        // Same A=[[1,2],[3,4]], B=[[5,6],[7,8]]
        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }
}