Skip to main content

strided_einsum2/
backend.rs

//! Backend abstraction for batched GEMM dispatch.
//!
//! This module defines the [`Backend`] trait, a marker struct for each backend,
//! and the `ActiveBackend` type alias, which is the single point where the
//! backend is selected from the enabled Cargo features.
6
7use strided_view::ElementOp;
8
9/// Trait for backends that can execute batched GEMM on contiguous operands.
10///
11/// Each backend declares its configuration (conjugation materialization,
12/// stride requirements) and provides a GEMM implementation.
13///
14/// Implementations are provided by each backend module (faer, blas).
15/// External crates can implement this trait for custom scalar types
16/// (e.g., tropical semiring) and pass the backend to [`einsum2_with_backend_into`].
17///
18/// [`einsum2_with_backend_into`]: crate::einsum2_with_backend_into
19pub trait Backend<T: crate::ScalarBase> {
20    /// Whether the backend needs conjugation materialized into the data
21    /// before GEMM (e.g., CBLAS has no conjugation flag for `?gemm`).
22    const MATERIALIZES_CONJ: bool;
23
24    /// Whether the backend requires at least one unit stride per matrix
25    /// dimension (row or column stride must be 1). CBLAS `?gemm` requires
26    /// this; faer does not.
27    const REQUIRES_UNIT_STRIDE: bool;
28
29    /// Execute batched GEMM: `C = alpha * A * B + beta * C` for each batch.
30    ///
31    /// - `c`: mutable output operand (batch x m x n)
32    /// - `a`: input operand (batch x m x k)
33    /// - `b`: input operand (batch x k x n)
34    /// - `batch_dims`: sizes of the batch dimensions
35    /// - `m`, `n`, `k`: fused matrix dimensions
36    /// - `alpha`, `beta`: scaling factors
37    fn bgemm_contiguous_into(
38        c: &mut crate::contiguous::ContiguousOperandMut<T>,
39        a: &crate::contiguous::ContiguousOperand<T>,
40        b: &crate::contiguous::ContiguousOperand<T>,
41        batch_dims: &[usize],
42        m: usize,
43        n: usize,
44        k: usize,
45        alpha: T,
46        beta: T,
47    ) -> strided_view::Result<()>;
48}
49
50// ---------------------------------------------------------------------------
51// Marker structs
52// ---------------------------------------------------------------------------
53
/// Batched GEMM backend using the [`faer`] library.
///
/// `Backend<T>` is implemented in `bgemm_faer.rs`. This is a stateless,
/// zero-sized marker type, so it derives the usual value traits and can be
/// debug-printed, copied and default-constructed freely.
#[cfg(feature = "faer")]
#[derive(Debug, Clone, Copy, Default)]
pub struct FaerBackend;
59
/// Batched GEMM backend using CBLAS (via `cblas-sys` or `cblas-inject`).
///
/// `Backend<T>` is implemented in `bgemm_blas.rs`. This is a stateless,
/// zero-sized marker type, so it derives the usual value traits and can be
/// debug-printed, copied and default-constructed freely.
#[cfg(any(feature = "blas", feature = "blas-inject"))]
#[derive(Debug, Clone, Copy, Default)]
pub struct BlasBackend;
65
/// Fallback batched GEMM backend built from explicit loops (no external library).
///
/// This is the `ActiveBackend` when no GEMM feature is enabled. In that
/// configuration `einsum2_into` calls `bgemm_naive` directly rather than going
/// through the [`Backend`] trait.
///
/// The type is not feature-gated, and its `Backend<T>` implementation is a
/// complete reference kernel. It can therefore also be handed to
/// [`einsum2_with_backend_into`] like any other backend.
///
/// [`einsum2_with_backend_into`]: crate::einsum2_with_backend_into
// NOTE(review): presumably silences `dead_code` in builds where another
// backend is selected and nothing in the crate names this type; confirm.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default)]
pub struct NaiveBackend;
74
75impl<T> Backend<T> for NaiveBackend
76where
77    T: crate::ScalarBase + strided_view::ElementOpApply,
78{
79    const MATERIALIZES_CONJ: bool = false;
80    const REQUIRES_UNIT_STRIDE: bool = false;
81
82    fn bgemm_contiguous_into(
83        c: &mut crate::contiguous::ContiguousOperandMut<T>,
84        a: &crate::contiguous::ContiguousOperand<T>,
85        b: &crate::contiguous::ContiguousOperand<T>,
86        batch_dims: &[usize],
87        m: usize,
88        n: usize,
89        k: usize,
90        alpha: T,
91        beta: T,
92    ) -> strided_view::Result<()> {
93        let a_ptr = a.ptr();
94        let b_ptr = b.ptr();
95        let c_ptr = c.ptr();
96        let a_rs = a.row_stride();
97        let a_cs = a.col_stride();
98        let b_rs = b.row_stride();
99        let b_cs = b.col_stride();
100        let c_rs = c.row_stride();
101        let c_cs = c.col_stride();
102
103        let mut batch_iter = crate::util::MultiIndex::new(batch_dims);
104        while batch_iter.next().is_some() {
105            let a_base = batch_iter.offset(a.batch_strides());
106            let b_base = batch_iter.offset(b.batch_strides());
107            let c_base = batch_iter.offset(c.batch_strides());
108
109            for i in 0..m {
110                for j in 0..n {
111                    let mut acc = T::zero();
112                    for l in 0..k {
113                        let mut a_val = unsafe {
114                            *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
115                        };
116                        let mut b_val = unsafe {
117                            *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
118                        };
119                        if a.conj() {
120                            a_val = strided_view::Conj::apply(a_val);
121                        }
122                        if b.conj() {
123                            b_val = strided_view::Conj::apply(b_val);
124                        }
125                        acc = acc + a_val * b_val;
126                    }
127                    unsafe {
128                        let c_elem = c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
129                        if beta == T::zero() {
130                            *c_elem = alpha * acc;
131                        } else {
132                            *c_elem = alpha * acc + beta * (*c_elem);
133                        }
134                    }
135                }
136            }
137        }
138        Ok(())
139    }
140}
141
142// ---------------------------------------------------------------------------
143// ActiveBackend type alias -- the SINGLE point of backend selection
144// ---------------------------------------------------------------------------
145
146/// The active GEMM backend, selected by Cargo features.
147///
148/// - `blas` or `blas-inject` -> `BlasBackend`
149/// - `faer` without BLAS -> `FaerBackend`
150/// - no backend feature -> `NaiveBackend`
151/// - invalid combos -> `NaiveBackend` (placeholder; `compile_error!` fires first)
152#[cfg(any(
153    all(feature = "blas", not(feature = "blas-inject")),
154    all(feature = "blas-inject", not(feature = "blas"))
155))]
156pub type ActiveBackend = BlasBackend;
157
158#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
159pub type ActiveBackend = FaerBackend;
160
161#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
162pub type ActiveBackend = NaiveBackend;
163
164/// Placeholder for invalid mutually-exclusive feature combinations.
165///
166/// The crate emits `compile_error!` for these combinations (in `lib.rs`), so this
167/// alias only suppresses cascading type-resolution errors.
168#[cfg(all(feature = "blas", feature = "blas-inject"))]
169pub type ActiveBackend = NaiveBackend;