// strided_einsum2/backend.rs

//! Backend abstraction for batched GEMM dispatch.
//!
//! This module defines the [`Backend`] trait, marker structs for each backend,
//! and the [`ActiveBackend`] type alias that serves as the single point of
//! backend selection based on Cargo features.

7/// Trait for backends that can execute batched GEMM on contiguous operands.
8///
9/// Each backend declares its configuration (conjugation materialization,
10/// stride requirements) and provides a GEMM implementation.
11///
12/// Implementations are provided by each backend module (faer, blas).
13/// External crates can implement this trait for custom scalar types
14/// (e.g., tropical semiring) and pass the backend to [`einsum2_with_backend_into`].
15///
16/// [`einsum2_with_backend_into`]: crate::einsum2_with_backend_into
17pub trait Backend<T: crate::ScalarBase> {
18    /// Whether the backend needs conjugation materialized into the data
19    /// before GEMM (e.g., CBLAS has no conjugation flag for `?gemm`).
20    const MATERIALIZES_CONJ: bool;
21
22    /// Whether the backend requires at least one unit stride per matrix
23    /// dimension (row or column stride must be 1). CBLAS `?gemm` requires
24    /// this; faer does not.
25    const REQUIRES_UNIT_STRIDE: bool;
26
27    /// Execute batched GEMM: `C = alpha * A * B + beta * C` for each batch.
28    ///
29    /// - `c`: mutable output operand (batch x m x n)
30    /// - `a`: input operand (batch x m x k)
31    /// - `b`: input operand (batch x k x n)
32    /// - `batch_dims`: sizes of the batch dimensions
33    /// - `m`, `n`, `k`: fused matrix dimensions
34    /// - `alpha`, `beta`: scaling factors
35    fn bgemm_contiguous_into(
36        c: &mut crate::contiguous::ContiguousOperandMut<T>,
37        a: &crate::contiguous::ContiguousOperand<T>,
38        b: &crate::contiguous::ContiguousOperand<T>,
39        batch_dims: &[usize],
40        m: usize,
41        n: usize,
42        k: usize,
43        alpha: T,
44        beta: T,
45    ) -> strided_view::Result<()>;
46}

// ---------------------------------------------------------------------------
// Marker structs
// ---------------------------------------------------------------------------

/// Batched GEMM backend using the [`faer`] library.
///
/// `Backend<T>` is implemented in `bgemm_faer.rs`. This is a zero-sized
/// marker type with no state, so it is freely copyable and default-constructible.
#[cfg(feature = "faer")]
#[derive(Debug, Clone, Copy, Default)]
pub struct FaerBackend;

/// Batched GEMM backend using CBLAS (via `cblas-sys` or `cblas-inject`).
///
/// `Backend<T>` is implemented in `bgemm_blas.rs`. This is a zero-sized
/// marker type with no state, so it is freely copyable and default-constructible.
#[cfg(any(feature = "blas", feature = "blas-inject"))]
#[derive(Debug, Clone, Copy, Default)]
pub struct BlasBackend;

/// Fallback batched GEMM backend using explicit loops (no external library).
///
/// This backend is used as `ActiveBackend` when no GEMM feature is enabled.
/// The GEMM dispatch in `einsum2_into` calls `bgemm_naive` directly rather
/// than going through the `Backend` trait, so this type's
/// `Backend::bgemm_contiguous_into` implementation panics if it is ever called.
// Only referenced (via `ActiveBackend`) when no GEMM feature is enabled, so in
// every other feature configuration it would trigger the `dead_code` lint.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default)]
pub struct NaiveBackend;

73impl<T: crate::ScalarBase> Backend<T> for NaiveBackend {
74    const MATERIALIZES_CONJ: bool = false;
75    const REQUIRES_UNIT_STRIDE: bool = false;
76
77    fn bgemm_contiguous_into(
78        _c: &mut crate::contiguous::ContiguousOperandMut<T>,
79        _a: &crate::contiguous::ContiguousOperand<T>,
80        _b: &crate::contiguous::ContiguousOperand<T>,
81        _batch_dims: &[usize],
82        _m: usize,
83        _n: usize,
84        _k: usize,
85        _alpha: T,
86        _beta: T,
87    ) -> strided_view::Result<()> {
88        unreachable!("NaiveBackend GEMM is dispatched directly, not through Backend trait")
89    }
90}

// ---------------------------------------------------------------------------
// ActiveBackend type alias -- the SINGLE point of backend selection
// ---------------------------------------------------------------------------

96/// The active GEMM backend, selected by Cargo features.
97///
98/// - `faer` (without blas/blas-inject) -> [`FaerBackend`]
99/// - `blas` or `blas-inject` (without faer) -> [`BlasBackend`]
100/// - no backend feature -> [`NaiveBackend`]
101/// - invalid combos -> [`NaiveBackend`] (placeholder; `compile_error!` fires first)
102#[cfg(all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))))]
103pub type ActiveBackend = FaerBackend;
104
105#[cfg(all(
106    not(feature = "faer"),
107    any(
108        all(feature = "blas", not(feature = "blas-inject")),
109        all(feature = "blas-inject", not(feature = "blas"))
110    )
111))]
112pub type ActiveBackend = BlasBackend;
113
114#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
115pub type ActiveBackend = NaiveBackend;
116
117/// Placeholder for invalid mutually-exclusive feature combinations.
118///
119/// The crate emits `compile_error!` for these combinations (in `lib.rs`), so this
120/// alias only suppresses cascading type-resolution errors.
121#[cfg(any(
122    all(feature = "faer", any(feature = "blas", feature = "blas-inject")),
123    all(feature = "blas", feature = "blas-inject")
124))]
125pub type ActiveBackend = NaiveBackend;