strided_einsum2/backend.rs
1//! Backend abstraction for batched GEMM dispatch.
2//!
3//! This module defines the [`BackendConfig`] and [`BgemmBackend`] traits, marker
4//! structs for each backend, and the [`ActiveBackend`] type alias that serves as
5//! the single point of backend selection based on Cargo features.
6
/// Compile-time capabilities and requirements of a GEMM backend.
///
/// Operand preparation consults these associated constants to decide how data
/// must be laid out for a given backend, instead of scattering `cfg` checks
/// across every call site.
pub trait BackendConfig {
    /// `true` when conjugation has to be applied to the operand data itself
    /// before GEMM runs (e.g., CBLAS `?gemm` offers no conjugation flag).
    const MATERIALIZES_CONJ: bool;

    /// `true` when each matrix operand must have a unit stride along at least
    /// one of its two dimensions (row stride or column stride equal to 1).
    /// CBLAS `?gemm` imposes this requirement; faer does not.
    const REQUIRES_UNIT_STRIDE: bool;
}
21
22/// Trait for backends that can execute batched GEMM on contiguous operands.
23///
24/// Implementations are provided by each backend module (faer, blas).
25/// External crates can implement this trait for custom scalar types
26/// (e.g., tropical semiring) and pass the backend to [`einsum2_with_backend_into`].
27///
28/// [`einsum2_with_backend_into`]: crate::einsum2_with_backend_into
29pub trait BgemmBackend<T> {
30 /// Execute batched GEMM: `C = alpha * A * B + beta * C` for each batch.
31 ///
32 /// - `c`: mutable output operand (batch x m x n)
33 /// - `a`: input operand (batch x m x k)
34 /// - `b`: input operand (batch x k x n)
35 /// - `batch_dims`: sizes of the batch dimensions
36 /// - `m`, `n`, `k`: fused matrix dimensions
37 /// - `alpha`, `beta`: scaling factors
38 fn bgemm_contiguous_into(
39 c: &mut crate::contiguous::ContiguousOperandMut<T>,
40 a: &crate::contiguous::ContiguousOperand<T>,
41 b: &crate::contiguous::ContiguousOperand<T>,
42 batch_dims: &[usize],
43 m: usize,
44 n: usize,
45 k: usize,
46 alpha: T,
47 beta: T,
48 ) -> strided_view::Result<()>;
49}
50
51// ---------------------------------------------------------------------------
52// Marker structs and BackendConfig implementations
53// ---------------------------------------------------------------------------
54
/// Batched GEMM backend using the [`faer`] library.
///
/// Zero-sized marker type; the GEMM implementation itself is provided by the
/// faer backend module.
#[cfg(feature = "faer")]
#[derive(Debug, Clone, Copy, Default)]
pub struct FaerBackend;

#[cfg(feature = "faer")]
impl BackendConfig for FaerBackend {
    // Conjugation is not materialized into the operand data before GEMM.
    const MATERIALIZES_CONJ: bool = false;
    // faer accepts arbitrary strides, so no unit-stride copy is required.
    const REQUIRES_UNIT_STRIDE: bool = false;
}
64
/// Batched GEMM backend using CBLAS (via `cblas-sys` or `cblas-inject`).
///
/// Zero-sized marker type; the GEMM implementation itself is provided by the
/// blas backend module.
#[cfg(any(feature = "blas", feature = "blas-inject"))]
#[derive(Debug, Clone, Copy, Default)]
pub struct BlasBackend;

#[cfg(any(feature = "blas", feature = "blas-inject"))]
impl BackendConfig for BlasBackend {
    // CBLAS `?gemm` has no conjugation flag, so conjugation is applied to the
    // data before the call.
    const MATERIALIZES_CONJ: bool = true;
    // CBLAS `?gemm` needs a unit row or column stride for every matrix.
    const REQUIRES_UNIT_STRIDE: bool = true;
}
74
75/// Fallback batched GEMM backend using explicit loops (no external library).
76#[allow(dead_code)]
77pub struct NaiveBackend;
78
79impl BackendConfig for NaiveBackend {
80 const MATERIALIZES_CONJ: bool = false;
81 const REQUIRES_UNIT_STRIDE: bool = false;
82}
83
84// ---------------------------------------------------------------------------
85// ActiveBackend type alias -- the SINGLE point of backend selection
86// ---------------------------------------------------------------------------
87
/// The active GEMM backend, selected by Cargo features.
///
/// In this configuration (`faer` enabled, neither `blas` nor `blas-inject`)
/// it is [`FaerBackend`]. Full selection table:
///
/// - `faer` (without blas/blas-inject) -> [`FaerBackend`]
/// - `blas` or `blas-inject` (without faer) -> `BlasBackend`
/// - no backend feature -> [`NaiveBackend`]
/// - invalid combos -> [`NaiveBackend`] (placeholder; `compile_error!` fires first)
// `BlasBackend` is deliberately not an intra-doc link: it is cfg'd out in this
// configuration, so rustdoc would report the link as broken.
#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
pub type ActiveBackend = FaerBackend;
96
/// The active GEMM backend, selected by Cargo features.
///
/// In this configuration (exactly one of `blas` / `blas-inject` enabled and
/// `faer` disabled) it is [`BlasBackend`].
#[cfg(all(
    not(feature = "faer"),
    any(
        all(feature = "blas", not(feature = "blas-inject")),
        all(feature = "blas-inject", not(feature = "blas"))
    )
))]
pub type ActiveBackend = BlasBackend;
105
106#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
107pub type ActiveBackend = NaiveBackend;
108
/// Stand-in `ActiveBackend` used when mutually exclusive backend features are
/// enabled together.
///
/// `lib.rs` rejects these feature combinations with `compile_error!`. This
/// alias exists only so that type resolution does not pile a cascade of
/// secondary errors on top of that one.
#[cfg(any(
    all(feature = "faer", feature = "blas"),
    all(feature = "faer", feature = "blas-inject"),
    all(feature = "blas", feature = "blas-inject")
))]
pub type ActiveBackend = NaiveBackend;
117pub type ActiveBackend = NaiveBackend;