Skip to main content

strided_einsum2/
raw_bgemm.rs

//! Raw borrowed-layout batched GEMM entry points.
//!
//! This module is the prepared-replay boundary for callers that already own
//! validated layout metadata. It keeps the public API independent of a concrete
//! GEMM backend while still allowing backend modules to provide specialized raw
//! implementations.
7
8use crate::backend::Backend;
9use crate::{contiguous, Scalar, ScalarBase};
10use strided_view::{Conj, ElementOp, ElementOpApply, RawStridedMut, RawStridedRef};
11
12/// Batched strided GEMM on raw borrowed layout metadata using the active backend.
13///
14/// This is the raw-layout counterpart to backend-specific `bgemm_strided_into`
15/// functions. It avoids constructing owned-metadata `StridedView` wrappers when
16/// a caller already has borrowed `dims`/`strides`/`offset` descriptors.
17#[allow(clippy::too_many_arguments)]
18pub fn bgemm_raw_strided_into<T>(
19    c: RawStridedMut<'_, T>,
20    a: RawStridedRef<'_, T>,
21    b: RawStridedRef<'_, T>,
22    n_batch: usize,
23    n_lo: usize,
24    n_ro: usize,
25    n_sum: usize,
26    alpha: T,
27    beta: T,
28    conj_a: bool,
29    conj_b: bool,
30) -> crate::Result<()>
31where
32    T: Scalar,
33    crate::backend::ActiveBackend: Backend<T>,
34{
35    validate_bgemm_shapes(&c, &a, &b, n_batch, n_lo, n_ro, n_sum)?;
36    unsafe {
37        bgemm_raw_strided_into_unchecked(
38            c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
39        )
40    }
41}
42
43/// Batched strided GEMM on raw borrowed layout metadata without validation.
44///
45/// # Safety
46/// The caller must ensure:
47/// - all raw strided operands are in bounds,
48/// - `n_lo`, `n_ro`, `n_sum`, and `n_batch` partition operand ranks as
49///   `[lo, sum, batch]`, `[sum, ro, batch]`, and `[lo, ro, batch]`,
50/// - matching dimension groups have identical extents,
51/// - `c` does not alias `a` or `b` in a way that violates mutable access.
52#[allow(clippy::too_many_arguments)]
53pub unsafe fn bgemm_raw_strided_into_unchecked<T>(
54    c: RawStridedMut<'_, T>,
55    a: RawStridedRef<'_, T>,
56    b: RawStridedRef<'_, T>,
57    n_batch: usize,
58    n_lo: usize,
59    n_ro: usize,
60    n_sum: usize,
61    alpha: T,
62    beta: T,
63    conj_a: bool,
64    conj_b: bool,
65) -> crate::Result<()>
66where
67    T: Scalar,
68    crate::backend::ActiveBackend: Backend<T>,
69{
70    bgemm_raw_with_backend_into_unchecked::<T, crate::backend::ActiveBackend>(
71        c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
72    )
73}
74
75/// Batched strided GEMM on raw borrowed metadata using an explicit backend.
76///
77/// Backend implementations that do not provide a specialized raw path use the
78/// same preparation pipeline as `einsum2_dispatch`: materialize/copy only when
79/// the backend requires it, call `Backend::bgemm_contiguous_into`, then finalize
80/// the destination.
81#[allow(clippy::too_many_arguments)]
82pub fn bgemm_raw_with_backend_into<T, B>(
83    c: RawStridedMut<'_, T>,
84    a: RawStridedRef<'_, T>,
85    b: RawStridedRef<'_, T>,
86    n_batch: usize,
87    n_lo: usize,
88    n_ro: usize,
89    n_sum: usize,
90    alpha: T,
91    beta: T,
92    conj_a: bool,
93    conj_b: bool,
94) -> crate::Result<()>
95where
96    T: ScalarBase + ElementOpApply,
97    B: Backend<T>,
98{
99    validate_bgemm_shapes(&c, &a, &b, n_batch, n_lo, n_ro, n_sum)?;
100    unsafe {
101        bgemm_raw_with_backend_into_unchecked::<T, B>(
102            c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
103        )
104    }
105}
106
107/// Unchecked variant of [`bgemm_raw_with_backend_into`].
108///
109/// # Safety
110/// The caller must uphold the same layout and aliasing invariants as
111/// [`bgemm_raw_strided_into_unchecked`].
112#[allow(clippy::too_many_arguments)]
113pub unsafe fn bgemm_raw_with_backend_into_unchecked<T, B>(
114    mut c: RawStridedMut<'_, T>,
115    a: RawStridedRef<'_, T>,
116    b: RawStridedRef<'_, T>,
117    _n_batch: usize,
118    n_lo: usize,
119    n_ro: usize,
120    n_sum: usize,
121    alpha: T,
122    beta: T,
123    conj_a: bool,
124    conj_b: bool,
125) -> crate::Result<()>
126where
127    T: ScalarBase + ElementOpApply,
128    B: Backend<T>,
129{
130    let a_dims = a.dims();
131    let b_dims = b.dims();
132    let lo_dims = &a_dims[..n_lo];
133    let sum_dims = &a_dims[n_lo..n_lo + n_sum];
134    let batch_dims = &a_dims[n_lo + n_sum..];
135    let ro_dims = &b_dims[n_sum..n_sum + n_ro];
136
137    if c.dims().iter().any(|&dim| dim == 0) {
138        return Ok(());
139    }
140    if sum_dims.iter().any(|&dim| dim == 0) {
141        scale_or_zero_raw_mut(&mut c, beta);
142        return Ok(());
143    }
144
145    let use_pool = true;
146    let materialize = if B::MATERIALIZES_CONJ {
147        Some(Conj::apply as fn(T) -> T)
148    } else {
149        None
150    };
151
152    let a_op = contiguous::prepare_input_raw(
153        &a,
154        n_lo,
155        n_sum,
156        conj_a,
157        B::REQUIRES_UNIT_STRIDE,
158        use_pool,
159        materialize,
160    )?;
161    let b_op = contiguous::prepare_input_raw(
162        &b,
163        n_sum,
164        n_ro,
165        conj_b,
166        B::REQUIRES_UNIT_STRIDE,
167        use_pool,
168        materialize,
169    )?;
170    let mut c_op = contiguous::prepare_output_raw(
171        &mut c,
172        n_lo,
173        n_ro,
174        beta,
175        B::REQUIRES_UNIT_STRIDE,
176        use_pool,
177    )?;
178
179    let m: usize = lo_dims.iter().product::<usize>().max(1);
180    let k: usize = sum_dims.iter().product::<usize>().max(1);
181    let n: usize = ro_dims.iter().product::<usize>().max(1);
182
183    B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
184    c_op.finalize_raw_into(&mut c)?;
185
186    Ok(())
187}
188
189pub(crate) fn validate_bgemm_shapes<T>(
190    c: &RawStridedMut<'_, T>,
191    a: &RawStridedRef<'_, T>,
192    b: &RawStridedRef<'_, T>,
193    n_batch: usize,
194    n_lo: usize,
195    n_ro: usize,
196    n_sum: usize,
197) -> crate::Result<()> {
198    let a_rank = n_lo + n_sum + n_batch;
199    let b_rank = n_sum + n_ro + n_batch;
200    let c_rank = n_lo + n_ro + n_batch;
201    if a.dims().len() != a_rank {
202        return Err(strided_view::StridedError::RankMismatch(a_rank, a.dims().len()).into());
203    }
204    if b.dims().len() != b_rank {
205        return Err(strided_view::StridedError::RankMismatch(b_rank, b.dims().len()).into());
206    }
207    if c.dims().len() != c_rank {
208        return Err(strided_view::StridedError::RankMismatch(c_rank, c.dims().len()).into());
209    }
210
211    let lo_dims = &a.dims()[..n_lo];
212    let sum_dims = &a.dims()[n_lo..n_lo + n_sum];
213    let batch_dims = &a.dims()[n_lo + n_sum..];
214    let ro_dims = &b.dims()[n_sum..n_sum + n_ro];
215
216    if &b.dims()[..n_sum] != sum_dims {
217        return Err(strided_view::StridedError::ShapeMismatch(
218            sum_dims.to_vec(),
219            b.dims()[..n_sum].to_vec(),
220        )
221        .into());
222    }
223    if &b.dims()[n_sum + n_ro..] != batch_dims {
224        return Err(strided_view::StridedError::ShapeMismatch(
225            batch_dims.to_vec(),
226            b.dims()[n_sum + n_ro..].to_vec(),
227        )
228        .into());
229    }
230    if &c.dims()[..n_lo] != lo_dims {
231        return Err(strided_view::StridedError::ShapeMismatch(
232            lo_dims.to_vec(),
233            c.dims()[..n_lo].to_vec(),
234        )
235        .into());
236    }
237    if &c.dims()[n_lo..n_lo + n_ro] != ro_dims {
238        return Err(strided_view::StridedError::ShapeMismatch(
239            ro_dims.to_vec(),
240            c.dims()[n_lo..n_lo + n_ro].to_vec(),
241        )
242        .into());
243    }
244    if &c.dims()[n_lo + n_ro..] != batch_dims {
245        return Err(strided_view::StridedError::ShapeMismatch(
246            batch_dims.to_vec(),
247            c.dims()[n_lo + n_ro..].to_vec(),
248        )
249        .into());
250    }
251    Ok(())
252}
253
254pub(crate) fn scale_or_zero_raw_mut<T: ScalarBase>(c: &mut RawStridedMut<'_, T>, beta: T) {
255    if c.dims().iter().any(|&dim| dim == 0) {
256        return;
257    }
258
259    fn visit<T: ScalarBase>(
260        ptr: *mut T,
261        dims: &[usize],
262        strides: &[isize],
263        axis: usize,
264        offset: isize,
265        beta: T,
266        zero: T,
267    ) {
268        if axis == dims.len() {
269            unsafe {
270                let dst = ptr.offset(offset);
271                if beta == zero {
272                    *dst = zero;
273                } else {
274                    *dst = beta * *dst;
275                }
276            }
277            return;
278        }
279
280        for i in 0..dims[axis] {
281            visit(
282                ptr,
283                dims,
284                strides,
285                axis + 1,
286                offset + i as isize * strides[axis],
287                beta,
288                zero,
289            );
290        }
291    }
292
293    visit(c.as_mut_ptr(), c.dims(), c.strides(), 0, 0, beta, T::zero());
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplies `[[1, 2], [3, 4]]` by `[[5, 6], [7, 8]]` (both row-major)
    /// through the active backend and returns the destination buffer.
    ///
    /// `one` is passed as `alpha`; `zero` is both the initial fill of the
    /// destination and `beta`.
    fn raw_bgemm_2x2<T>(one: T, zero: T) -> Vec<T>
    where
        T: Scalar,
        crate::backend::ActiveBackend: Backend<T>,
        T: From<f32>,
    {
        let shape = [2, 2];
        let row_major = [2, 1];
        let lhs = [T::from(1.0), T::from(2.0), T::from(3.0), T::from(4.0)];
        let rhs = [T::from(5.0), T::from(6.0), T::from(7.0), T::from(8.0)];
        let mut out = vec![zero; 4];

        let a = RawStridedRef::new(&lhs, &shape, &row_major, 0).unwrap();
        let b = RawStridedRef::new(&rhs, &shape, &row_major, 0).unwrap();
        let c = RawStridedMut::new(&mut out, &shape, &row_major, 0).unwrap();
        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, one, zero, false, false).unwrap();

        out
    }

    /// Runs a GEMM whose contracted extent is zero against a destination
    /// pre-filled with `[1, 2, 3, 4]` and returns the destination afterwards.
    fn run_zero_sum(beta: f64) -> [f64; 4] {
        let lhs_shape = [2, 0];
        let rhs_shape = [0, 2];
        let out_shape = [2, 2];
        let flat_strides = [0, 0];
        let out_strides = [2, 1];
        let lhs = [0.0; 1];
        let rhs = [0.0; 1];
        let mut out = [1.0, 2.0, 3.0, 4.0];

        let a = RawStridedRef::new(&lhs, &lhs_shape, &flat_strides, 0).unwrap();
        let b = RawStridedRef::new(&rhs, &rhs_shape, &flat_strides, 0).unwrap();
        let c = RawStridedMut::new(&mut out, &out_shape, &out_strides, 0).unwrap();
        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, beta, false, false).unwrap();

        out
    }

    #[test]
    fn raw_bgemm_active_backend_f64() {
        let product = raw_bgemm_2x2(1.0f64, 0.0);
        assert_eq!(product, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn raw_bgemm_active_backend_f32() {
        let product = raw_bgemm_2x2(1.0f32, 0.0);
        assert_eq!(product, vec![19.0f32, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn raw_bgemm_active_backend_complex_conj() {
        use num_complex::Complex64;

        let real = |re: f64| Complex64::new(re, 0.0);
        let shape = [2, 2];
        let row_major = [2, 1];
        // Only the diagonal of `a` has an imaginary part; `b` is the identity,
        // so the result must be the element-wise conjugate of `a`.
        let lhs = [
            Complex64::new(1.0, 1.0),
            real(2.0),
            real(3.0),
            Complex64::new(4.0, -1.0),
        ];
        let identity = [real(1.0), real(0.0), real(0.0), real(1.0)];
        let mut out = vec![real(0.0); 4];

        let a = RawStridedRef::new(&lhs, &shape, &row_major, 0).unwrap();
        let b = RawStridedRef::new(&identity, &shape, &row_major, 0).unwrap();
        let c = RawStridedMut::new(&mut out, &shape, &row_major, 0).unwrap();
        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, real(1.0), real(0.0), true, false).unwrap();

        let conjugated = vec![
            Complex64::new(1.0, -1.0),
            real(2.0),
            real(3.0),
            Complex64::new(4.0, 1.0),
        ];
        assert_eq!(out, conjugated);
    }

    #[test]
    fn raw_bgemm_active_backend_checked_shape_mismatch() {
        // The contracted extent is 2 in `a` but 3 in `b`.
        let square = [2, 2];
        let tall = [3, 2];
        let row_major = [2, 1];
        let lhs = [1.0, 2.0, 3.0, 4.0];
        let rhs = [0.0; 6];
        let mut out = [0.0; 4];

        let a = RawStridedRef::new(&lhs, &square, &row_major, 0).unwrap();
        let b = RawStridedRef::new(&rhs, &tall, &row_major, 0).unwrap();
        let c = RawStridedMut::new(&mut out, &square, &row_major, 0).unwrap();
        let outcome = bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false);

        assert!(matches!(
            outcome,
            Err(crate::EinsumError::Strided(
                strided_view::StridedError::ShapeMismatch(..)
            ))
        ));
    }

    #[test]
    fn raw_bgemm_explicit_backend_checked_rank_mismatch() {
        // `c` is rank 1 although `lo + ro` requires rank 2.
        let square = [2, 2];
        let row_major = [2, 1];
        let out_shape = [2];
        let out_strides = [1];
        let lhs = [1.0, 2.0, 3.0, 4.0];
        let rhs = [5.0, 6.0, 7.0, 8.0];
        let mut out = [0.0; 2];

        let a = RawStridedRef::new(&lhs, &square, &row_major, 0).unwrap();
        let b = RawStridedRef::new(&rhs, &square, &row_major, 0).unwrap();
        let c = RawStridedMut::new(&mut out, &out_shape, &out_strides, 0).unwrap();
        let outcome = bgemm_raw_with_backend_into::<f64, crate::backend::ActiveBackend>(
            c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false,
        );

        assert!(matches!(
            outcome,
            Err(crate::EinsumError::Strided(
                strided_view::StridedError::RankMismatch(2, 1)
            ))
        ));
    }

    #[test]
    fn raw_bgemm_zero_sum_scales_destination() {
        assert_eq!(run_zero_sum(2.0), [2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn raw_bgemm_zero_sum_beta_zero_clears_destination() {
        assert_eq!(run_zero_sum(0.0), [0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn raw_bgemm_empty_output_is_noop() {
        let empty_rows = [0, 2];
        let square = [2, 2];
        let row_major = [2, 1];
        let lhs = [1.0, 2.0];
        let rhs = [3.0, 4.0, 5.0, 6.0];
        let mut out = [7.0, 8.0, 9.0, 10.0];
        let before = out;

        let a = RawStridedRef::new(&lhs, &empty_rows, &row_major, 0).unwrap();
        let b = RawStridedRef::new(&rhs, &square, &row_major, 0).unwrap();
        let c = RawStridedMut::new(&mut out, &empty_rows, &row_major, 0).unwrap();
        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 1.0, false, false).unwrap();

        assert_eq!(out, before);
    }

    #[test]
    fn raw_bgemm_noncontiguous_output_writes_back() {
        let square = [2, 2];
        let row_major = [2, 1];
        // Element (i, j) of the destination lives at buffer index 1 + i + 3 * j.
        let out_strides = [1, 3];
        let lhs = [1.0, 2.0, 3.0, 4.0];
        let rhs = [5.0, 6.0, 7.0, 8.0];
        let mut out = [0.0; 8];

        let a = RawStridedRef::new(&lhs, &square, &row_major, 0).unwrap();
        let b = RawStridedRef::new(&rhs, &square, &row_major, 0).unwrap();
        let c = RawStridedMut::new(&mut out, &square, &out_strides, 1).unwrap();
        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false).unwrap();

        // (buffer index, expected value); indices 0 and 3 lie outside the view
        // and must keep their initial zero.
        let expectations = [
            (1, 19.0),
            (4, 22.0),
            (2, 43.0),
            (5, 50.0),
            (0, 0.0),
            (3, 0.0),
        ];
        for &(slot, want) in expectations.iter() {
            assert_eq!(out[slot], want);
        }
    }
}