strided_einsum2/
bgemm_faer.rs

1//! faer-backed batched GEMM kernel on strided views.
2//!
3//! Uses `faer::linalg::matmul::matmul` for SIMD-optimized matrix multiplication.
4//! When dimension groups cannot be fused into 2D matrices (non-contiguous strides),
5//! copies operands to contiguous column-major buffers before calling faer.
6
7use crate::contiguous::{alloc_col_major_uninit, ContiguousOperand, ContiguousOperandMut};
8use crate::util::{try_fuse_group, MultiIndex};
9use faer::linalg::matmul::matmul_with_conj;
10use faer::mat::{MatMut, MatRef};
11use faer::{Accum, Conj, Par};
12use faer_traits::ComplexField;
13use strided_view::{StridedArray, StridedView, StridedViewMut};
14
/// Batched strided GEMM using faer: C = alpha * A * B + beta * C
///
/// Same interface as `bgemm_naive::bgemm_strided_into`. Uses faer's optimized
/// matmul for all cases. When dimension groups have non-contiguous strides,
/// copies operands to contiguous column-major buffers first using `strided_kernel::copy_into`.
///
/// Dimension layout is batch-last canonical order:
/// - `a`: `[lo..., sum..., batch...]`
/// - `b`: `[sum..., ro..., batch...]`
/// - `c`: `[lo..., ro..., batch...]`
///
/// `n_lo`, `n_ro`, `n_sum` are the number of dims in each group. `_n_batch`
/// is unused here; the batch dims are taken as the trailing dims of `a`.
///
/// # Errors
/// Propagates errors from `strided_kernel::copy_into` when staging a
/// non-fusable operand into a temporary buffer, or when copying the staged
/// C result back into `c`.
pub fn bgemm_strided_into<T>(
    c: &mut StridedViewMut<T>,
    a: &StridedView<T>,
    b: &StridedView<T>,
    _n_batch: usize,
    n_lo: usize,
    n_ro: usize,
    n_sum: usize,
    alpha: T,
    beta: T,
    conj_a: bool,
    conj_b: bool,
) -> strided_view::Result<()>
where
    T: ComplexField
        + Copy
        + strided_view::ElementOpApply
        + Send
        + Sync
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>
        + num_traits::Zero
        + num_traits::One
        + PartialEq,
{
    let a_dims = a.dims();
    let b_dims = b.dims();
    let a_strides = a.strides();
    let b_strides = b.strides();
    let c_strides = c.strides();

    // Extract dimension groups (batch-last canonical order)
    // A: [lo, sum, batch], B: [sum, ro, batch], C: [lo, ro, batch]
    let lo_dims = &a_dims[..n_lo];
    let sum_dims = &a_dims[n_lo..n_lo + n_sum];
    let batch_dims = &a_dims[n_lo + n_sum..];
    let ro_dims = &b_dims[n_sum..n_sum + n_ro];

    // Fused sizes for the matrix multiply.
    // `.max(1)` makes an empty group behave as a size-1 dimension (e.g. the
    // outer-product case n_sum == 0 gives k == 1).
    let m: usize = lo_dims.iter().product::<usize>().max(1);
    let k: usize = sum_dims.iter().product::<usize>().max(1);
    let n: usize = ro_dims.iter().product::<usize>().max(1);

    // Extract stride groups (batch-last)
    let a_lo_strides = &a_strides[..n_lo];
    let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
    let b_sum_strides = &b_strides[..n_sum];
    let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
    let c_lo_strides = &c_strides[..n_lo];
    let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];

    // Try to fuse each dimension group into a single (size, stride) pair.
    // `None` means the group cannot be viewed as one 2D matrix axis.
    let fused_a_lo = try_fuse_group(lo_dims, a_lo_strides);
    let fused_a_sum = try_fuse_group(sum_dims, a_sum_strides);
    let fused_b_sum = try_fuse_group(sum_dims, b_sum_strides);
    let fused_b_ro = try_fuse_group(ro_dims, b_ro_strides);
    let fused_c_lo = try_fuse_group(lo_dims, c_lo_strides);
    let fused_c_ro = try_fuse_group(ro_dims, c_ro_strides);

    let a_needs_copy = fused_a_lo.is_none() || fused_a_sum.is_none();
    let b_needs_copy = fused_b_sum.is_none() || fused_b_ro.is_none();
    let c_needs_copy = fused_c_lo.is_none() || fused_c_ro.is_none();

    let n_a_inner = n_lo + n_sum;
    let n_b_inner = n_sum + n_ro;
    let n_c_inner = n_lo + n_ro;

    // Copy A to contiguous column-major if inner dims aren't fusable.
    // The staging buffer (if any) is kept alive in `a_contig_buf` so that
    // `a_ptr` remains valid for the rest of the function.
    let a_contig_buf: Option<StridedArray<T>>;
    let (a_ptr, a_row_stride, a_col_stride);
    if a_needs_copy {
        let mut buf = alloc_col_major_uninit(a.dims());
        strided_kernel::copy_into(&mut buf.view_mut(), a)?;
        a_ptr = buf.view().ptr();
        // Col-major inner A [lo..., sum...]: lo stride = 1, sum stride = m
        a_row_stride = if m == 0 { 0 } else { 1isize };
        a_col_stride = m as isize;
        a_contig_buf = Some(buf);
    } else {
        let (_, rs) = fused_a_lo.unwrap();
        let (_, cs) = fused_a_sum.unwrap();
        a_ptr = a.ptr();
        a_row_stride = rs;
        a_col_stride = cs;
        a_contig_buf = None;
    }
    // Batch strides come from whichever storage `a_ptr` actually points at.
    let a_batch_strides: &[isize] = match a_contig_buf.as_ref() {
        Some(buf) => &buf.strides()[n_a_inner..],
        None => &a_strides[n_a_inner..],
    };

    // Copy B to contiguous column-major if inner dims aren't fusable
    let b_contig_buf: Option<StridedArray<T>>;
    let (b_ptr, b_row_stride, b_col_stride);
    if b_needs_copy {
        let mut buf = alloc_col_major_uninit(b.dims());
        strided_kernel::copy_into(&mut buf.view_mut(), b)?;
        b_ptr = buf.view().ptr();
        // Col-major inner B [sum..., ro...]: sum stride = 1, ro stride = k
        b_row_stride = if k == 0 { 0 } else { 1isize };
        b_col_stride = k as isize;
        b_contig_buf = Some(buf);
    } else {
        let (_, rs) = fused_b_sum.unwrap();
        let (_, cs) = fused_b_ro.unwrap();
        b_ptr = b.ptr();
        b_row_stride = rs;
        b_col_stride = cs;
        b_contig_buf = None;
    }
    let b_batch_strides: &[isize] = match b_contig_buf.as_ref() {
        Some(buf) => &buf.strides()[n_b_inner..],
        None => &b_strides[n_b_inner..],
    };

    // Copy C to contiguous column-major if inner dims aren't fusable.
    // When beta == 0 the old contents of C are irrelevant, so the staging
    // buffer is left uninitialized: every element is written by matmul with
    // `Accum::Replace` before the copy-back below.
    let c_contig_buf: Option<StridedArray<T>>;
    let (c_ptr, c_row_stride, c_col_stride);
    if c_needs_copy {
        let mut buf = alloc_col_major_uninit(c.dims());
        if beta != T::zero() {
            strided_kernel::copy_into(&mut buf.view_mut(), &c.as_view())?;
        }
        c_ptr = buf.view_mut().as_mut_ptr();
        // Col-major inner C [lo..., ro...]: lo stride = 1, ro stride = m
        c_row_stride = if m == 0 { 0 } else { 1isize };
        c_col_stride = m as isize;
        c_contig_buf = Some(buf);
    } else {
        let (_, rs) = fused_c_lo.unwrap();
        let (_, cs) = fused_c_ro.unwrap();
        c_ptr = c.as_mut_ptr();
        c_row_stride = rs;
        c_col_stride = cs;
        c_contig_buf = None;
    }
    let c_batch_strides: &[isize] = match c_contig_buf.as_ref() {
        Some(buf) => &buf.strides()[n_c_inner..],
        None => &c_strides[n_c_inner..],
    };

    let is_beta_zero = beta == T::zero();
    let is_beta_one = beta == T::one();

    // Determine accumulation mode: faer's `Accum` only expresses beta = 0
    // (Replace) or beta = 1 (Add); any other beta is handled by pre-scaling
    // C inside `do_batch` and then using Add.
    let accum = if is_beta_zero {
        Accum::Replace
    } else {
        Accum::Add
    };

    let cj_a = if conj_a { Conj::Yes } else { Conj::No };
    let cj_b = if conj_b { Conj::Yes } else { Conj::No };

    // Inline closure for per-batch GEMM (shared between fast and slow paths)
    let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
        // Pre-scale C by beta if beta is not 0 or 1 (see `accum` above)
        if !is_beta_zero && !is_beta_one {
            let c_base = unsafe { c_ptr.offset(c_batch_off) };
            for i in 0..m {
                for j in 0..n {
                    let offset = i as isize * c_row_stride + j as isize * c_col_stride;
                    // SAFETY: the offset is built from the row/col strides and
                    // dims of the same storage `c_ptr` points into (caller's C
                    // or the staged buffer), so it stays within that allocation.
                    unsafe {
                        let elem = c_base.offset(offset);
                        *elem = beta * *elem;
                    }
                }
            }
        }

        // SAFETY: each (ptr, strides) triple comes from the corresponding
        // operand's own storage (original view or staged col-major buffer)
        // together with its dims, so the constructed matrix views cover only
        // valid memory; A/B are read-only and C is the only mutable view.
        unsafe {
            let a_mat: MatRef<'_, T> =
                MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
            let b_mat: MatRef<'_, T> =
                MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
            let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
                c_ptr.offset(c_batch_off),
                m,
                n,
                c_row_stride,
                c_col_stride,
            );

            matmul_with_conj(c_mat, accum, a_mat, cj_a, b_mat, cj_b, alpha, Par::rayon(0));
        }
    };

    // Fast path: when batch dims are contiguous for all operands, use pointer
    // increments instead of MultiIndex carry-based iteration.
    let fused_a = try_fuse_group(batch_dims, a_batch_strides);
    let fused_b = try_fuse_group(batch_dims, b_batch_strides);
    let fused_c = try_fuse_group(batch_dims, c_batch_strides);

    if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
        (fused_a, fused_b, fused_c)
    {
        let mut a_off = 0isize;
        let mut b_off = 0isize;
        let mut c_off = 0isize;
        for _ in 0..total {
            do_batch(a_off, b_off, c_off);
            a_off += a_step;
            b_off += b_step;
            c_off += c_step;
        }
    } else {
        // Slow path: enumerate every batch multi-index and recompute offsets.
        let mut batch_iter = MultiIndex::new(batch_dims);
        while batch_iter.next().is_some() {
            let a_batch_off = batch_iter.offset(a_batch_strides);
            let b_batch_off = batch_iter.offset(b_batch_strides);
            let c_batch_off = batch_iter.offset(c_batch_strides);
            do_batch(a_batch_off, b_batch_off, c_batch_off);
        }
    }

    // If C was copied to a temp buffer, copy the result back
    if let Some(ref c_buf) = c_contig_buf {
        strided_kernel::copy_into(c, &c_buf.view())?;
    }

    Ok(())
}
242
/// Batched GEMM on pre-contiguous operands: C = alpha * A * B + beta * C
///
/// Operands must already have contiguous inner dimensions (prepared via
/// `prepare_input_*` and `prepare_output_*` in the `contiguous` module).
///
/// - `batch_dims`: sizes of the batch dimensions
/// - `m`: fused lo dimension size (number of rows of A/C)
/// - `n`: fused ro dimension size (number of cols of B/C)
/// - `k`: fused sum dimension size (inner dimension)
///
/// Conjugation flags are read from the operands themselves (`a.conj()`,
/// `b.conj()`), unlike `bgemm_strided_into` which takes them as arguments.
///
/// # Errors
/// Currently always returns `Ok(())`; the `Result` return type keeps the
/// signature uniform with `bgemm_strided_into`.
pub fn bgemm_contiguous_into<T>(
    c: &mut ContiguousOperandMut<T>,
    a: &ContiguousOperand<T>,
    b: &ContiguousOperand<T>,
    batch_dims: &[usize],
    m: usize,
    n: usize,
    k: usize,
    alpha: T,
    beta: T,
) -> strided_view::Result<()>
where
    T: ComplexField
        + Copy
        + strided_view::ElementOpApply
        + Send
        + Sync
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>
        + num_traits::Zero
        + num_traits::One
        + PartialEq,
{
    let is_beta_zero = beta == T::zero();
    let is_beta_one = beta == T::one();

    // faer's `Accum` only expresses beta = 0 (Replace) or beta = 1 (Add);
    // any other beta is handled by pre-scaling C inside `do_batch`.
    let accum = if is_beta_zero {
        Accum::Replace
    } else {
        Accum::Add
    };

    let a_batch_strides = a.batch_strides();
    let b_batch_strides = b.batch_strides();
    let c_batch_strides = c.batch_strides();

    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let c_ptr = c.ptr();
    let a_row_stride = a.row_stride();
    let a_col_stride = a.col_stride();
    let b_row_stride = b.row_stride();
    let b_col_stride = b.col_stride();
    let c_row_stride = c.row_stride();
    let c_col_stride = c.col_stride();

    let conj_a = if a.conj() { Conj::Yes } else { Conj::No };
    let conj_b = if b.conj() { Conj::Yes } else { Conj::No };

    // Inline closure for per-batch GEMM (shared between fast and slow paths)
    let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
        // Pre-scale C by beta if beta is not 0 or 1 (see `accum` above)
        if !is_beta_zero && !is_beta_one {
            let c_base = unsafe { c_ptr.offset(c_batch_off) };
            for i in 0..m {
                for j in 0..n {
                    let offset = i as isize * c_row_stride + j as isize * c_col_stride;
                    // SAFETY: offset is built from C's own row/col strides and
                    // the m/n dims the caller prepared for this operand, so it
                    // stays within C's allocation.
                    unsafe {
                        let elem = c_base.offset(offset);
                        *elem = beta * *elem;
                    }
                }
            }
        }

        // SAFETY: pointers and strides were obtained from the prepared
        // contiguous operands themselves, so the constructed m×k / k×n / m×n
        // views cover only their valid storage; A/B are read-only and C is
        // the only mutable view.
        unsafe {
            let a_mat: MatRef<'_, T> =
                MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
            let b_mat: MatRef<'_, T> =
                MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
            let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
                c_ptr.offset(c_batch_off),
                m,
                n,
                c_row_stride,
                c_col_stride,
            );

            matmul_with_conj(
                c_mat,
                accum,
                a_mat,
                conj_a,
                b_mat,
                conj_b,
                alpha,
                Par::rayon(0),
            );
        }
    };

    // Fast path: when batch dims are contiguous for all operands, use pointer
    // increments instead of MultiIndex carry-based iteration.
    let fused_a = try_fuse_group(batch_dims, a_batch_strides);
    let fused_b = try_fuse_group(batch_dims, b_batch_strides);
    let fused_c = try_fuse_group(batch_dims, c_batch_strides);

    if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
        (fused_a, fused_b, fused_c)
    {
        let mut a_off = 0isize;
        let mut b_off = 0isize;
        let mut c_off = 0isize;
        for _ in 0..total {
            do_batch(a_off, b_off, c_off);
            a_off += a_step;
            b_off += b_step;
            c_off += c_step;
        }
    } else {
        // Slow path: enumerate every batch multi-index and recompute offsets.
        let mut batch_iter = MultiIndex::new(batch_dims);
        while batch_iter.next().is_some() {
            let a_batch_off = batch_iter.offset(a_batch_strides);
            let b_batch_off = batch_iter.offset(b_batch_strides);
            let c_batch_off = batch_iter.offset(c_batch_strides);
            do_batch(a_batch_off, b_batch_off, c_batch_off);
        }
    }

    Ok(())
}
373
374use crate::backend::{BgemmBackend, FaerBackend};
375
376impl<T> BgemmBackend<T> for FaerBackend
377where
378    T: crate::Scalar + ComplexField,
379{
380    fn bgemm_contiguous_into(
381        c: &mut ContiguousOperandMut<T>,
382        a: &ContiguousOperand<T>,
383        b: &ContiguousOperand<T>,
384        batch_dims: &[usize],
385        m: usize,
386        n: usize,
387        k: usize,
388        alpha: T,
389        beta: T,
390    ) -> strided_view::Result<()> {
391        // Delegate to the existing free function in this module
392        bgemm_contiguous_into(c, a, b, batch_dims, m, n, k, alpha, beta)
393    }
394}
395
#[cfg(test)]
mod tests {
    // Coverage: the all-fusable fast path, the copy paths for non-contiguous
    // A and C, alpha/beta combinations, batching, f32, and the
    // `bgemm_contiguous_into` entry point via prepare_*_view.
    use super::*;
    use strided_view::StridedArray;

    // Basic 2x2 matmul with row-major operands (no copy path triggered).
    #[test]
    fn test_faer_bgemm_2x2() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    // Rectangular (2x3)*(3x4) matmul.
    #[test]
    fn test_faer_bgemm_rect() {
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let b =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[2, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 38.0);
        assert_eq!(c.get(&[1, 3]), 128.0);
    }

    #[test]
    fn test_faer_bgemm_batched() {
        // Batch-last: A: [lo, sum, batch]=[2,3,2], B: [sum, ro, batch]=[3,2,2], C: [lo, ro, batch]=[2,2,2]
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            1,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // C: [lo, ro, batch]
        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
        // C0[0,0] = 1*1+2*3+3*5 = 22
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
    }

    // beta=0 must discard C's prior contents (Accum::Replace path).
    #[test]
    fn test_faer_bgemm_beta_zero() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[100.0, 200.0], [300.0, 400.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0, // beta=0: C_old should be ignored
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    // beta=1 must accumulate into C without pre-scaling (Accum::Add path).
    #[test]
    fn test_faer_bgemm_beta_one() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            1.0, // beta=1: C = A*B + C_old
            false,
            false,
        )
        .unwrap();

        // C[0,0] = 1*1+0*3 + 10 = 11
        assert_eq!(c.get(&[0, 0]), 11.0);
        // C[1,1] = 0*2+1*4 + 40 = 44
        assert_eq!(c.get(&[1, 1]), 44.0);
    }

    // General beta (neither 0 nor 1) exercises the pre-scale loop in do_batch.
    #[test]
    fn test_faer_bgemm_alpha_beta() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0, // C = 2*I*B + 3*C_old
            false,
            false,
        )
        .unwrap();

        // C[0,0] = 2*1 + 3*10 = 32
        assert_eq!(c.get(&[0, 0]), 32.0);
        // C[1,1] = 2*4 + 3*40 = 128
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    // n_sum=0: empty sum group fuses to k=1, producing an outer product.
    #[test]
    fn test_faer_bgemm_outer_product() {
        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[3, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            0, // no sum
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 1.0);
        assert_eq!(c.get(&[2, 3]), 12.0);
    }

    // Same as the 2x2 case but with f32 elements.
    #[test]
    fn test_faer_bgemm_f32() {
        let a = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0f32, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0f32, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f32>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0f32,
            0.0f32,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0f32);
        assert_eq!(c.get(&[1, 1]), 50.0f32);
    }

    #[test]
    fn test_faer_bgemm_col_major_input() {
        // A is col-major (non-contiguous for row-major fusion) → triggers copy path
        let a_data = vec![1.0, 3.0, 2.0, 4.0]; // col-major [[1,2],[3,4]]
        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();

        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // Same A=[[1,2],[3,4]], B=[[5,6],[7,8]]
        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_faer_bgemm_col_major_output() {
        // C is col-major → triggers C copy path
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_faer_bgemm_col_major_with_beta() {
        // C is col-major with beta != 0 → copy C in, matmul, copy back
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        // C col-major with initial values
        let c_data = vec![10.0, 30.0, 20.0, 40.0]; // col-major [[10,20],[30,40]]
        let mut c = StridedArray::<f64>::from_parts(c_data, &[2, 2], &[1, 2], 0).unwrap();

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0, // C = 2*I*B + 3*C_old
            false,
            false,
        )
        .unwrap();

        // C[0,0] = 2*1 + 3*10 = 32
        assert_eq!(c.get(&[0, 0]), 32.0);
        // C[1,1] = 2*4 + 3*40 = 128
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    // ---- bgemm_contiguous_into tests ----

    use crate::contiguous::{prepare_input_view, prepare_output_view};

    #[test]
    fn test_bgemm_contiguous_2x2() {
        // Basic 2x2 matmul: C = A * B
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        // n_batch=0, A: n_group1=1 (lo), n_group2=1 (sum)
        //             B: n_group1=1 (sum), n_group2=1 (ro)
        //             C: n_group1=1 (lo), n_group2=1 (ro)
        let a_op = prepare_input_view(&a.view(), 0, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 0, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 0, 1, 1, 0.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_bgemm_contiguous_batched() {
        // Batched: 2 x (2x3) * (3x2) matmul
        // Batch-last: A: [lo, sum, batch]=[2,3,2], B: [sum, ro, batch]=[3,2,2], C: [lo, ro, batch]=[2,2,2]
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        // n_batch=1, A: n_group1=1 (lo), n_group2=1 (sum)
        let a_op = prepare_input_view(&a.view(), 1, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 1, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 1, 0.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[2], 2, 2, 3, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        // C: [lo, ro, batch]
        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
        // C0[0,0] = 1*1+2*3+3*5 = 22
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
        // C0[0,1] = 1*2+2*4+3*6 = 28
        assert_eq!(c.get(&[0, 1, 0]), 28.0);
        // C0[1,0] = 4*1+5*3+6*5 = 49
        assert_eq!(c.get(&[1, 0, 0]), 49.0);
        // C0[1,1] = 4*2+5*4+6*6 = 64
        assert_eq!(c.get(&[1, 1, 0]), 64.0);

        // Batch 1: A1=[[7,8,9],[10,11,12]], B1=[[7,8],[9,10],[11,12]]
        // C1[0,0] = 7*7+8*9+9*11 = 49+72+99 = 220
        assert_eq!(c.get(&[0, 0, 1]), 220.0);
    }

    #[test]
    fn test_bgemm_contiguous_with_beta() {
        // C = 2*I*B + 3*C_old
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        let a_op = prepare_input_view(&a.view(), 0, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 0, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 0, 1, 1, 3.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 2.0, 3.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        // C[0,0] = 2*1 + 3*10 = 32
        assert_eq!(c.get(&[0, 0]), 32.0);
        // C[0,1] = 2*2 + 3*20 = 64
        assert_eq!(c.get(&[0, 1]), 64.0);
        // C[1,0] = 2*3 + 3*30 = 96
        assert_eq!(c.get(&[1, 0]), 96.0);
        // C[1,1] = 2*4 + 3*40 = 128
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    #[test]
    fn test_bgemm_contiguous_non_contiguous_input() {
        // A is col-major (triggers copy in prepare_input_view for row-major grouping)
        let a_data = vec![1.0, 3.0, 2.0, 4.0]; // col-major [[1,2],[3,4]]
        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();

        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        let a_op = prepare_input_view(&a.view(), 0, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 0, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 0, 1, 1, 0.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        // Same A=[[1,2],[3,4]], B=[[5,6],[7,8]]
        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }
}