strided_einsum2/
bgemm_naive.rs

1//! Naive batched GEMM kernel on strided views.
2//!
3//! Operates on N-dimensional permuted views where dimensions are grouped as:
4//! - A: [lo..., sum..., batch...]
5//! - B: [sum..., ro..., batch...]
6//! - C: [lo..., ro..., batch...]
7
8use crate::util::MultiIndex;
9use strided_view::{ElementOp, ElementOpApply, StridedView, StridedViewMut};
10
11/// Batched strided GEMM: C = alpha * A * B + beta * C
12///
13/// The views must be pre-permuted so that their dimensions are grouped as
14/// (batch-last canonical order):
15/// - A: `n_lo` dims, then `n_sum` dims, then `n_batch` batch dims
16/// - B: `n_sum` dims, then `n_ro` dims, then `n_batch` batch dims
17/// - C: `n_lo` dims, then `n_ro` dims, then `n_batch` batch dims
18///
19/// Dimension sizes must match across operands within each group.
20pub fn bgemm_strided_into<T>(
21    c: &mut StridedViewMut<T>,
22    a: &StridedView<T>,
23    b: &StridedView<T>,
24    _n_batch: usize,
25    n_lo: usize,
26    n_ro: usize,
27    n_sum: usize,
28    alpha: T,
29    beta: T,
30    conj_a: bool,
31    conj_b: bool,
32) -> strided_view::Result<()>
33where
34    T: Copy
35        + ElementOpApply
36        + std::ops::Mul<Output = T>
37        + std::ops::Add<Output = T>
38        + num_traits::Zero
39        + num_traits::One
40        + PartialEq,
41{
42    let a_dims = a.dims();
43    let b_dims = b.dims();
44    let c_dims = c.dims();
45    let a_strides = a.strides();
46    let b_strides = b.strides();
47    let c_strides = c.strides();
48
49    // Extract dimension groups (batch-last canonical order)
50    let lo_dims = &a_dims[..n_lo];
51    let sum_dims = &a_dims[n_lo..n_lo + n_sum];
52    let batch_dims = &a_dims[n_lo + n_sum..];
53    let ro_dims = &b_dims[n_sum..n_sum + n_ro];
54
55    // Extract stride groups (batch-last)
56    let a_lo_strides = &a_strides[..n_lo];
57    let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
58    let a_batch_strides = &a_strides[n_lo + n_sum..];
59
60    let b_sum_strides = &b_strides[..n_sum];
61    let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
62    let b_batch_strides = &b_strides[n_sum + n_ro..];
63
64    let c_lo_strides = &c_strides[..n_lo];
65    let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
66    let c_batch_strides = &c_strides[n_lo + n_ro..];
67
68    // Validate dimension consistency
69    debug_assert_eq!(&c_dims[..n_lo], lo_dims);
70    debug_assert_eq!(&c_dims[n_lo..n_lo + n_ro], ro_dims);
71    debug_assert_eq!(&c_dims[n_lo + n_ro..], batch_dims);
72    debug_assert_eq!(&b_dims[..n_sum], sum_dims);
73    debug_assert_eq!(&b_dims[n_sum + n_ro..], batch_dims);
74
75    let a_ptr = a.ptr();
76    let b_ptr = b.ptr();
77    let c_ptr = c.as_mut_ptr();
78
79    let is_beta_zero = beta == T::zero();
80    let is_alpha_one = alpha == T::one();
81
82    let mut batch_iter = MultiIndex::new(batch_dims);
83    let mut lo_iter = MultiIndex::new(lo_dims);
84    let mut ro_iter = MultiIndex::new(ro_dims);
85    let mut sum_iter = MultiIndex::new(sum_dims);
86    while batch_iter.next().is_some() {
87        let a_batch_off = batch_iter.offset(a_batch_strides);
88        let b_batch_off = batch_iter.offset(b_batch_strides);
89        let c_batch_off = batch_iter.offset(c_batch_strides);
90
91        lo_iter.reset();
92        while lo_iter.next().is_some() {
93            let a_lo_off = lo_iter.offset(a_lo_strides);
94            let c_lo_off = lo_iter.offset(c_lo_strides);
95
96            ro_iter.reset();
97            while ro_iter.next().is_some() {
98                let b_ro_off = ro_iter.offset(b_ro_strides);
99                let c_ro_off = ro_iter.offset(c_ro_strides);
100
101                // Accumulate sum over contraction indices
102                let mut acc = T::zero();
103                sum_iter.reset();
104                while sum_iter.next().is_some() {
105                    let a_sum_off = sum_iter.offset(a_sum_strides);
106                    let b_sum_off = sum_iter.offset(b_sum_strides);
107
108                    let a_raw = unsafe { *a_ptr.offset(a_batch_off + a_lo_off + a_sum_off) };
109                    let b_raw = unsafe { *b_ptr.offset(b_batch_off + b_sum_off + b_ro_off) };
110                    let a_val = if conj_a {
111                        strided_view::Conj::apply(a_raw)
112                    } else {
113                        a_raw
114                    };
115                    let b_val = if conj_b {
116                        strided_view::Conj::apply(b_raw)
117                    } else {
118                        b_raw
119                    };
120                    acc = acc + a_val * b_val;
121                }
122
123                // Write: c = alpha * acc + beta * c_old
124                let c_off = c_batch_off + c_lo_off + c_ro_off;
125                unsafe {
126                    let c_elem = c_ptr.offset(c_off);
127                    if is_beta_zero {
128                        if is_alpha_one {
129                            *c_elem = acc;
130                        } else {
131                            *c_elem = alpha * acc;
132                        }
133                    } else {
134                        let old = *c_elem;
135                        if is_alpha_one {
136                            *c_elem = acc + beta * old;
137                        } else {
138                            *c_elem = alpha * acc + beta * old;
139                        }
140                    }
141                }
142            }
143        }
144    }
145
146    Ok(())
147}
148
149/// Batched strided GEMM with closure-based element mapping: C = alpha * map_a(A) * map_b(B) + beta * C
150///
151/// Like [`bgemm_strided_into`] but uses closures instead of conjugation flags,
152/// allowing custom scalar types that don't implement `ElementOpApply`.
153pub fn bgemm_strided_into_with_map<T, MapA, MapB>(
154    c: &mut StridedViewMut<T>,
155    a: &StridedView<T>,
156    b: &StridedView<T>,
157    _n_batch: usize,
158    n_lo: usize,
159    n_ro: usize,
160    n_sum: usize,
161    alpha: T,
162    beta: T,
163    map_a: MapA,
164    map_b: MapB,
165) -> strided_view::Result<()>
166where
167    T: Copy
168        + std::ops::Mul<Output = T>
169        + std::ops::Add<Output = T>
170        + num_traits::Zero
171        + num_traits::One
172        + PartialEq,
173    MapA: Fn(T) -> T,
174    MapB: Fn(T) -> T,
175{
176    let a_dims = a.dims();
177    let b_dims = b.dims();
178    let c_dims = c.dims();
179    let a_strides = a.strides();
180    let b_strides = b.strides();
181    let c_strides = c.strides();
182
183    let lo_dims = &a_dims[..n_lo];
184    let sum_dims = &a_dims[n_lo..n_lo + n_sum];
185    let batch_dims = &a_dims[n_lo + n_sum..];
186    let ro_dims = &b_dims[n_sum..n_sum + n_ro];
187
188    let a_lo_strides = &a_strides[..n_lo];
189    let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
190    let a_batch_strides = &a_strides[n_lo + n_sum..];
191
192    let b_sum_strides = &b_strides[..n_sum];
193    let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
194    let b_batch_strides = &b_strides[n_sum + n_ro..];
195
196    let c_lo_strides = &c_strides[..n_lo];
197    let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
198    let c_batch_strides = &c_strides[n_lo + n_ro..];
199
200    debug_assert_eq!(&c_dims[..n_lo], lo_dims);
201    debug_assert_eq!(&c_dims[n_lo..n_lo + n_ro], ro_dims);
202    debug_assert_eq!(&c_dims[n_lo + n_ro..], batch_dims);
203    debug_assert_eq!(&b_dims[..n_sum], sum_dims);
204    debug_assert_eq!(&b_dims[n_sum + n_ro..], batch_dims);
205
206    let a_ptr = a.ptr();
207    let b_ptr = b.ptr();
208    let c_ptr = c.as_mut_ptr();
209
210    let is_beta_zero = beta == T::zero();
211    let is_alpha_one = alpha == T::one();
212
213    let mut batch_iter = MultiIndex::new(batch_dims);
214    let mut lo_iter = MultiIndex::new(lo_dims);
215    let mut ro_iter = MultiIndex::new(ro_dims);
216    let mut sum_iter = MultiIndex::new(sum_dims);
217    while batch_iter.next().is_some() {
218        let a_batch_off = batch_iter.offset(a_batch_strides);
219        let b_batch_off = batch_iter.offset(b_batch_strides);
220        let c_batch_off = batch_iter.offset(c_batch_strides);
221
222        lo_iter.reset();
223        while lo_iter.next().is_some() {
224            let a_lo_off = lo_iter.offset(a_lo_strides);
225            let c_lo_off = lo_iter.offset(c_lo_strides);
226
227            ro_iter.reset();
228            while ro_iter.next().is_some() {
229                let b_ro_off = ro_iter.offset(b_ro_strides);
230                let c_ro_off = ro_iter.offset(c_ro_strides);
231
232                let mut acc = T::zero();
233                sum_iter.reset();
234                while sum_iter.next().is_some() {
235                    let a_sum_off = sum_iter.offset(a_sum_strides);
236                    let b_sum_off = sum_iter.offset(b_sum_strides);
237
238                    let a_raw = unsafe { *a_ptr.offset(a_batch_off + a_lo_off + a_sum_off) };
239                    let b_raw = unsafe { *b_ptr.offset(b_batch_off + b_sum_off + b_ro_off) };
240                    acc = acc + map_a(a_raw) * map_b(b_raw);
241                }
242
243                let c_off = c_batch_off + c_lo_off + c_ro_off;
244                unsafe {
245                    let c_elem = c_ptr.offset(c_off);
246                    if is_beta_zero {
247                        if is_alpha_one {
248                            *c_elem = acc;
249                        } else {
250                            *c_elem = alpha * acc;
251                        }
252                    } else {
253                        let old = *c_elem;
254                        if is_alpha_one {
255                            *c_elem = acc + beta * old;
256                        } else {
257                            *c_elem = alpha * acc + beta * old;
258                        }
259                    }
260                }
261            }
262        }
263    }
264
265    Ok(())
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::StridedArray;

    /// Plain 2x2 matrix product with alpha = 1, beta = 0.
    #[test]
    fn test_bgemm_2x2() {
        // Simple 2x2 matmul: C = A * B
        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]
        // C = [[19, 22], [43, 50]]
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1, // n_batch=0, n_lo=1(i), n_ro=1(k), n_sum=1(j)
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// Non-square operands: lo, sum and ro extents all differ.
    #[test]
    fn test_bgemm_rect() {
        // A: 2x3, B: 3x4, C: 2x4
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let b =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[2, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // A = [[1,2,3],[4,5,6]]
        // B = [[1,2,3,4],[5,6,7,8],[9,10,11,12]]
        // C[0,0] = 1*1+2*5+3*9 = 38
        assert_eq!(c.get(&[0, 0]), 38.0);
        // C[0,3] = 1*4+2*8+3*12 = 4+16+36 = 56
        assert_eq!(c.get(&[0, 3]), 56.0);
        // C[1,0] = 4*1+5*5+6*9 = 4+25+54 = 83
        assert_eq!(c.get(&[1, 0]), 83.0);
        // C[1,3] = 4*4+5*8+6*12 = 16+40+72 = 128
        assert_eq!(c.get(&[1, 3]), 128.0);
    }

    /// One batch dimension; checks elements from both batches so that a wrong
    /// batch stride cannot go unnoticed.
    #[test]
    fn test_bgemm_batched() {
        // Batch=2, lo=2, sum=3, ro=2
        // Batch-last: A: [lo, sum, batch]=[2,3,2], B: [sum, ro, batch]=[3,2,2], C: [lo, ro, batch]=[2,2,2]
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            // idx=[lo, sum, batch] → same values as batch*6 + lo*3 + sum + 1
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            // idx=[sum, ro, batch] → same values as batch*6 + sum*2 + ro + 1
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            1,
            1,
            1,
            1, // n_batch=1, n_lo=1, n_ro=1, n_sum=1
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // C: [lo, ro, batch]
        // Batch 0: A0=[[1,2,3],[4,5,6]], B0=[[1,2],[3,4],[5,6]]
        // C0[0,0] = 1*1+2*3+3*5 = 22
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
        // C0[0,1] = 1*2+2*4+3*6 = 28
        assert_eq!(c.get(&[0, 1, 0]), 28.0);
        // C0[1,0] = 4*1+5*3+6*5 = 49
        assert_eq!(c.get(&[1, 0, 0]), 49.0);
        // C0[1,1] = 4*2+5*4+6*6 = 64
        assert_eq!(c.get(&[1, 1, 0]), 64.0);

        // Batch 1: A1=[[7,8,9],[10,11,12]], B1=[[7,8],[9,10],[11,12]]
        // C1[0,0] = 7*7+8*9+9*11 = 49+72+99 = 220
        assert_eq!(c.get(&[0, 0, 1]), 220.0);
        // C1[1,1] = 10*8+11*10+12*12 = 80+110+144 = 334
        assert_eq!(c.get(&[1, 1, 1]), 334.0);
    }

    /// General scaling path: alpha != 1 and beta != 0.
    #[test]
    fn test_bgemm_alpha_beta() {
        // C = 2*A*B + 3*C_old
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]] // identity
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0, // alpha=2, beta=3
            false,
            false,
        )
        .unwrap();

        // C = 2 * I * B + 3 * C_old = 2*B + 3*C_old
        // C[0,0] = 2*1 + 3*10 = 32
        assert_eq!(c.get(&[0, 0]), 32.0);
        // C[0,1] = 2*2 + 3*20 = 64
        assert_eq!(c.get(&[0, 1]), 64.0);
        // C[1,0] = 2*3 + 3*30 = 96
        assert_eq!(c.get(&[1, 0]), 96.0);
        // C[1,1] = 2*4 + 3*40 = 128
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    /// Accumulation path: alpha == 1 (no scaling) with beta != 0.
    #[test]
    fn test_bgemm_accumulate() {
        // C = A*B + C_old with A = identity
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            1.0, // alpha=1, beta=1
            false,
            false,
        )
        .unwrap();

        // C = B + C_old = [[11, 22], [33, 44]]
        assert_eq!(c.get(&[0, 0]), 11.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 33.0);
        assert_eq!(c.get(&[1, 1]), 44.0);
    }

    /// No contraction dimensions: the product degenerates to an outer product.
    #[test]
    fn test_bgemm_outer_product() {
        // Outer product: no sum dims
        // a: [3], b: [4], c: [3, 4]
        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[3, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            0, // no batch, no sum
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 1.0);
        // c[1,2] = a[1] * b[2] = 2 * 3
        assert_eq!(c.get(&[1, 2]), 6.0);
        assert_eq!(c.get(&[2, 3]), 12.0);
    }

    /// Closure-based variant: both element maps must be applied before the
    /// product is formed.
    #[test]
    fn test_bgemm_with_map() {
        // C = map_a(A) * map_b(B) with map_a = negate, map_b = double
        // => C = -2 * (A * B) = -2 * [[19, 22], [43, 50]]
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into_with_map(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            |x: f64| -x,
            |x: f64| 2.0 * x,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), -38.0);
        assert_eq!(c.get(&[0, 1]), -44.0);
        assert_eq!(c.get(&[1, 0]), -86.0);
        assert_eq!(c.get(&[1, 1]), -100.0);
    }
}