strided_einsum2/backend.rs
1//! Backend abstraction for batched GEMM dispatch.
2//!
3//! This module defines the [`Backend`] trait, marker structs for each backend,
4//! and the `ActiveBackend` type alias that serves as the single point of
5//! backend selection based on Cargo features.
6
7use strided_view::ElementOp;
8
/// Trait for backends that can execute batched GEMM on contiguous operands.
///
/// Each backend declares its configuration (conjugation materialization,
/// stride requirements) and provides a GEMM implementation.
///
/// Every item of the trait is attached to the implementing *type*: two
/// associated constants and one associated function without a `self`
/// receiver. A backend can therefore be invoked as
/// `B::bgemm_contiguous_into(..)` without constructing a value of `B`.
///
/// Implementations are provided by each backend module (faer, blas).
/// External crates can implement this trait for custom scalar types
/// (e.g., tropical semiring) and pass the backend to [`einsum2_with_backend_into`].
///
/// [`einsum2_with_backend_into`]: crate::einsum2_with_backend_into
pub trait Backend<T: crate::ScalarBase> {
    /// Whether the backend needs conjugation materialized into the data
    /// before GEMM (e.g., CBLAS has no conjugation flag for `?gemm`).
    ///
    /// NOTE(review): the code that reads this constant is outside this file;
    /// presumably the caller conjugates the operand data up front when this
    /// is `true` -- confirm against the dispatch code.
    const MATERIALIZES_CONJ: bool;

    /// Whether the backend requires at least one unit stride per matrix
    /// dimension (row or column stride must be 1). CBLAS `?gemm` requires
    /// this; faer does not.
    const REQUIRES_UNIT_STRIDE: bool;

    /// Execute batched GEMM: `C = alpha * A * B + beta * C` for each batch.
    ///
    /// - `c`: mutable output operand (batch x m x n)
    /// - `a`: input operand (batch x m x k)
    /// - `b`: input operand (batch x k x n)
    /// - `batch_dims`: sizes of the batch dimensions
    /// - `m`, `n`, `k`: fused matrix dimensions
    /// - `alpha`, `beta`: scaling factors
    ///
    /// # Errors
    ///
    /// Returns a `strided_view::Result<()>`. Which failures are reported is
    /// left to each implementation; the trait itself does not specify any.
    fn bgemm_contiguous_into(
        c: &mut crate::contiguous::ContiguousOperandMut<T>,
        a: &crate::contiguous::ContiguousOperand<T>,
        b: &crate::contiguous::ContiguousOperand<T>,
        batch_dims: &[usize],
        m: usize,
        n: usize,
        k: usize,
        alpha: T,
        beta: T,
    ) -> strided_view::Result<()>;
}
49
50// ---------------------------------------------------------------------------
51// Marker structs
52// ---------------------------------------------------------------------------
53
/// Batched GEMM backend using the [`faer`] library.
///
/// This is a field-less marker type: it holds no data and exists only so
/// that the backend can be named as a type.
///
/// `Backend<T>` is implemented in `bgemm_faer.rs`.
///
/// Only compiled when the `faer` Cargo feature is enabled.
#[cfg(feature = "faer")]
pub struct FaerBackend;
59
/// Batched GEMM backend using CBLAS (via `cblas-sys` or `cblas-inject`).
///
/// This is a field-less marker type: it holds no data and exists only so
/// that the backend can be named as a type.
///
/// `Backend<T>` is implemented in `bgemm_blas.rs`.
///
/// Compiled when either the `blas` or the `blas-inject` Cargo feature is
/// enabled.
#[cfg(any(feature = "blas", feature = "blas-inject"))]
pub struct BlasBackend;
65
/// Fallback batched GEMM backend using explicit loops (no external library).
///
/// This backend is used as `ActiveBackend` when no GEMM feature is enabled.
/// The GEMM dispatch in `einsum2_into` calls `bgemm_naive` directly rather
/// than going through the `Backend` trait, so `bgemm_contiguous_into` is
/// not reached from that dispatch path.
///
/// NOTE(review): the `Backend` impl for this type in this file is a complete
/// loop-based GEMM rather than a stub, so it still works if called directly
/// through the trait.
///
/// Unlike the other marker structs, this type is not feature-gated and is
/// compiled in every configuration.
// `dead_code` is allowed presumably because the type goes unused in builds
// where another backend is active -- TODO confirm.
#[allow(dead_code)]
pub struct NaiveBackend;
74
75impl<T> Backend<T> for NaiveBackend
76where
77 T: crate::ScalarBase + strided_view::ElementOpApply,
78{
79 const MATERIALIZES_CONJ: bool = false;
80 const REQUIRES_UNIT_STRIDE: bool = false;
81
82 fn bgemm_contiguous_into(
83 c: &mut crate::contiguous::ContiguousOperandMut<T>,
84 a: &crate::contiguous::ContiguousOperand<T>,
85 b: &crate::contiguous::ContiguousOperand<T>,
86 batch_dims: &[usize],
87 m: usize,
88 n: usize,
89 k: usize,
90 alpha: T,
91 beta: T,
92 ) -> strided_view::Result<()> {
93 let a_ptr = a.ptr();
94 let b_ptr = b.ptr();
95 let c_ptr = c.ptr();
96 let a_rs = a.row_stride();
97 let a_cs = a.col_stride();
98 let b_rs = b.row_stride();
99 let b_cs = b.col_stride();
100 let c_rs = c.row_stride();
101 let c_cs = c.col_stride();
102
103 let mut batch_iter = crate::util::MultiIndex::new(batch_dims);
104 while batch_iter.next().is_some() {
105 let a_base = batch_iter.offset(a.batch_strides());
106 let b_base = batch_iter.offset(b.batch_strides());
107 let c_base = batch_iter.offset(c.batch_strides());
108
109 for i in 0..m {
110 for j in 0..n {
111 let mut acc = T::zero();
112 for l in 0..k {
113 let mut a_val = unsafe {
114 *a_ptr.offset(a_base + i as isize * a_rs + l as isize * a_cs)
115 };
116 let mut b_val = unsafe {
117 *b_ptr.offset(b_base + l as isize * b_rs + j as isize * b_cs)
118 };
119 if a.conj() {
120 a_val = strided_view::Conj::apply(a_val);
121 }
122 if b.conj() {
123 b_val = strided_view::Conj::apply(b_val);
124 }
125 acc = acc + a_val * b_val;
126 }
127 unsafe {
128 let c_elem = c_ptr.offset(c_base + i as isize * c_rs + j as isize * c_cs);
129 if beta == T::zero() {
130 *c_elem = alpha * acc;
131 } else {
132 *c_elem = alpha * acc + beta * (*c_elem);
133 }
134 }
135 }
136 }
137 }
138 Ok(())
139 }
140}
141
142// ---------------------------------------------------------------------------
143// ActiveBackend type alias -- the SINGLE point of backend selection
144// ---------------------------------------------------------------------------
145
/// The active GEMM backend, selected by Cargo features.
///
/// Exactly one `ActiveBackend` alias is compiled for any feature combination:
///
/// - exactly one of `blas` / `blas-inject` -> `BlasBackend`
/// - `faer` without either BLAS feature -> `FaerBackend`
/// - no backend feature -> `NaiveBackend`
/// - both `blas` and `blas-inject` -> `NaiveBackend` (placeholder; `compile_error!` fires first)
///
/// This definition is the `BlasBackend` case. Its predicate does not mention
/// `faer`, so it also applies when `faer` is enabled alongside a single BLAS
/// feature.
#[cfg(any(
    all(feature = "blas", not(feature = "blas-inject")),
    all(feature = "blas-inject", not(feature = "blas"))
))]
pub type ActiveBackend = BlasBackend;
157
/// The active GEMM backend when `faer` is enabled and neither `blas` nor
/// `blas-inject` is: resolves to `FaerBackend`.
#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
pub type ActiveBackend = FaerBackend;
160
/// The active GEMM backend when none of `faer`, `blas` or `blas-inject` is
/// enabled: resolves to the loop-based `NaiveBackend`.
#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
pub type ActiveBackend = NaiveBackend;
163
/// Placeholder for invalid mutually-exclusive feature combinations.
///
/// Applies when `blas` and `blas-inject` are enabled together, with or
/// without `faer`.
///
/// The crate emits `compile_error!` for these combinations (in `lib.rs`), so this
/// alias only suppresses cascading type-resolution errors.
#[cfg(all(feature = "blas", feature = "blas-inject"))]
pub type ActiveBackend = NaiveBackend;