strided_einsum2/backend.rs
1//! Backend abstraction for batched GEMM dispatch.
2//!
3//! This module defines the [`Backend`] trait, marker structs for each backend,
4//! and the [`ActiveBackend`] type alias that serves as the single point of
5//! backend selection based on Cargo features.
6
7/// Trait for backends that can execute batched GEMM on contiguous operands.
8///
9/// Each backend declares its configuration (conjugation materialization,
10/// stride requirements) and provides a GEMM implementation.
11///
12/// Implementations are provided by each backend module (faer, blas).
13/// External crates can implement this trait for custom scalar types
14/// (e.g., tropical semiring) and pass the backend to [`einsum2_with_backend_into`].
15///
16/// [`einsum2_with_backend_into`]: crate::einsum2_with_backend_into
17pub trait Backend<T: crate::ScalarBase> {
18 /// Whether the backend needs conjugation materialized into the data
19 /// before GEMM (e.g., CBLAS has no conjugation flag for `?gemm`).
20 const MATERIALIZES_CONJ: bool;
21
22 /// Whether the backend requires at least one unit stride per matrix
23 /// dimension (row or column stride must be 1). CBLAS `?gemm` requires
24 /// this; faer does not.
25 const REQUIRES_UNIT_STRIDE: bool;
26
27 /// Execute batched GEMM: `C = alpha * A * B + beta * C` for each batch.
28 ///
29 /// - `c`: mutable output operand (batch x m x n)
30 /// - `a`: input operand (batch x m x k)
31 /// - `b`: input operand (batch x k x n)
32 /// - `batch_dims`: sizes of the batch dimensions
33 /// - `m`, `n`, `k`: fused matrix dimensions
34 /// - `alpha`, `beta`: scaling factors
35 fn bgemm_contiguous_into(
36 c: &mut crate::contiguous::ContiguousOperandMut<T>,
37 a: &crate::contiguous::ContiguousOperand<T>,
38 b: &crate::contiguous::ContiguousOperand<T>,
39 batch_dims: &[usize],
40 m: usize,
41 n: usize,
42 k: usize,
43 alpha: T,
44 beta: T,
45 ) -> strided_view::Result<()>;
46}
47
48// ---------------------------------------------------------------------------
49// Marker structs
50// ---------------------------------------------------------------------------
51
/// Batched GEMM backend using the [`faer`] library.
///
/// `Backend<T>` is implemented in `bgemm_faer.rs`.
///
/// This is a zero-sized marker type. It derives the common standard traits so
/// callers can freely create (`FaerBackend` / `Default`), copy and debug-print
/// backend values, for example when passing one to an explicit-backend entry point.
#[cfg(feature = "faer")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FaerBackend;
57
/// Batched GEMM backend using CBLAS (via `cblas-sys` or `cblas-inject`).
///
/// `Backend<T>` is implemented in `bgemm_blas.rs`.
///
/// This is a zero-sized marker type. It derives the common standard traits so
/// callers can freely create (`BlasBackend` / `Default`), copy and debug-print
/// backend values, for example when passing one to an explicit-backend entry point.
#[cfg(any(feature = "blas", feature = "blas-inject"))]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BlasBackend;
63
/// Fallback batched GEMM backend using explicit loops (no external library).
///
/// This is the `ActiveBackend` when no GEMM feature is enabled. It is also the
/// placeholder `ActiveBackend` for invalid feature combinations.
///
/// The GEMM dispatch in `einsum2_into` calls `bgemm_naive` directly rather than
/// going through the `Backend` trait. As a result, this type's
/// `bgemm_contiguous_into` is not expected to be reached, and it panics if called.
///
/// NOTE(review): this type is `pub`. Confirm that generic entry points such as
/// `einsum2_with_backend_into` never call `bgemm_contiguous_into` when given
/// `NaiveBackend`; otherwise that call path panics.
// `dead_code` is allowed because, in builds where a GEMM feature selects a
// different `ActiveBackend`, nothing in the crate may construct this type.
// Presumably the lint fires in those builds; confirm before removing the allow.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NaiveBackend;
72
73impl<T: crate::ScalarBase> Backend<T> for NaiveBackend {
74 const MATERIALIZES_CONJ: bool = false;
75 const REQUIRES_UNIT_STRIDE: bool = false;
76
77 fn bgemm_contiguous_into(
78 _c: &mut crate::contiguous::ContiguousOperandMut<T>,
79 _a: &crate::contiguous::ContiguousOperand<T>,
80 _b: &crate::contiguous::ContiguousOperand<T>,
81 _batch_dims: &[usize],
82 _m: usize,
83 _n: usize,
84 _k: usize,
85 _alpha: T,
86 _beta: T,
87 ) -> strided_view::Result<()> {
88 unreachable!("NaiveBackend GEMM is dispatched directly, not through Backend trait")
89 }
90}
91
92// ---------------------------------------------------------------------------
93// ActiveBackend type alias -- the SINGLE point of backend selection
94// ---------------------------------------------------------------------------
95
/// The active GEMM backend, selected by Cargo features.
///
/// In this configuration (`faer` enabled, neither `blas` nor `blas-inject`)
/// it is [`FaerBackend`]. The full selection table is:
///
/// - `faer` (without `blas`/`blas-inject`) -> `FaerBackend`
/// - `blas` or `blas-inject` (without `faer`) -> `BlasBackend`
/// - no backend feature -> [`NaiveBackend`]
/// - invalid combinations -> `NaiveBackend`. This is only a placeholder that
///   keeps type resolution working; the crate rejects these builds with
///   `compile_error!`.
///
/// `BlasBackend` is intentionally not an intra-doc link here. It is compiled
/// out under this feature set, so the link would be unresolved.
#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
pub type ActiveBackend = FaerBackend;
104
/// The active GEMM backend, selected by Cargo features.
///
/// In this configuration (exactly one of `blas` / `blas-inject`, without
/// `faer`) it is [`BlasBackend`]. Other configurations select:
///
/// - `FaerBackend` for `faer` alone;
/// - [`NaiveBackend`] when no backend feature is enabled;
/// - `NaiveBackend` as a placeholder for invalid combinations, which the crate
///   rejects with `compile_error!`.
#[cfg(all(
    not(feature = "faer"),
    any(
        all(feature = "blas", not(feature = "blas-inject")),
        all(feature = "blas-inject", not(feature = "blas"))
    )
))]
pub type ActiveBackend = BlasBackend;
113
114#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
115pub type ActiveBackend = NaiveBackend;
116
/// Stand-in alias for the forbidden feature combinations, i.e. any build with
/// two or more of `faer`, `blas`, and `blas-inject` enabled.
///
/// `lib.rs` raises `compile_error!` for such builds. This alias exists only so
/// that code naming `ActiveBackend` still resolves, which keeps that error from
/// being buried under a cascade of type-resolution failures.
#[cfg(any(
    all(feature = "faer", feature = "blas"),
    all(feature = "faer", feature = "blas-inject"),
    all(feature = "blas", feature = "blas-inject")
))]
pub type ActiveBackend = NaiveBackend;