// tenferro_prims/lib.rs

//! Tensor primitive execution families for the tenferro workspace.
//!
//! The crate is organized around focused backend contracts instead of a single
//! monolithic descriptor surface:
//!
//! - [`TensorSemiringCore`] for the minimal semiring substrate used by
//!   `tenferro-einsum`
//! - [`TensorSemiringFastPath`] for optional semiring performance paths such as
//!   contraction fast paths
//! - [`TensorScalarPrims`] for standard scalar pointwise and reduction families
//! - [`TensorAnalyticPrims`] for analytic pointwise and reduction families
//! - [`TensorComplexRealPrims`] for cross-dtype complex-to-real unary families
//! - [`TensorComplexScalePrims`] for complex payload scaled by real-valued tensors
//! - [`TensorMetadataPrims`] for integer/bool metadata tensor families with
//!   overwrite-based execution and erased metadata tensor handles
//! - [`TensorMetadataCastPrims`] for metadata-to-scalar bridge families such as
//!   bool/int casts and `where`
//! - [`TensorRngPrims`] for dense eager RNG constructors such as `rand` and
//!   `randn`
//! - [`TensorIndexingPrims`] for index-based selection, gathering, and
//!   scattering
//! - [`TensorSortPrims`] for sort, argsort, and top-k operations
//!
//! Most families follow the same plan/execute pattern:
//!
//! 1. Create a family descriptor
//! 2. Build a backend plan for concrete tensor shapes
//! 3. Execute the plan with BLAS-style `alpha`/`beta` scaling
//!
//! [`TensorMetadataPrims`] is the exception: it uses overwrite-based execution
//! over erased integer/bool metadata tensor handles instead of scalar-family
//! scaling.
//!
//! # CPU GEMM backend selection
//!
//! `BatchedGemm` on [`CpuBackend`] requires exactly one CPU GEMM backend feature:
//! - `gemm-faer` (default): pure-Rust faer matmul backend
//! - `gemm-blas`: CBLAS backend (`cblas-sys`) with selectable symbol provider
//!
//! If `gemm-blas` is selected, choose exactly one provider:
//! - `provider-src`: link BLAS source crates (`blas-src` + `cblas-src`)
//! - `provider-inject`: link runtime-injected symbols (`cblas-inject`)
//!
//! With `provider-src`, choose exactly one `src-*` implementation:
//! `src-openblas`, `src-netlib`, `src-accelerate`, `src-r`,
//! `src-intel-mkl-dynamic-sequential`, `src-intel-mkl-dynamic-parallel`,
//! `src-intel-mkl-static-sequential`, `src-intel-mkl-static-parallel`.
//!
//! Example (OpenBLAS source provider):
//! `cargo test -p tenferro-prims --no-default-features --features "gemm-blas,provider-src,src-openblas"`
//!
//! Example (runtime-injected provider):
//! `cargo test -p tenferro-prims --no-default-features --features "gemm-blas,provider-inject"`
//!
//! On [`CpuBackend`], semiring-core `BatchedGemm` supports `f32`, `f64`,
//! `Complex32`, and `Complex64`.
//!
//! # Examples
//!
//! ## Semiring core planning
//!
//! ```ignore
//! use tenferro_algebra::Standard;
//! use tenferro_device::LogicalMemorySpace;
//! use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
//! use tenferro_tensor::{MemoryOrder, Tensor};
//!
//! let mut ctx = CpuContext::new(4);
//! let col = MemoryOrder::ColumnMajor;
//! let mem = LogicalMemorySpace::MainMemory;
//! let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
//! let b = Tensor::<f64>::zeros(&[4, 5], mem, col).unwrap();
//! let mut c = Tensor::<f64>::zeros(&[3, 5], mem, col).unwrap();
//!
//! let desc = SemiringCoreDescriptor::BatchedGemm {
//!     batch_dims: vec![],
//!     m: 3,
//!     n: 5,
//!     k: 4,
//! };
//! let plan = <CpuBackend as TensorSemiringCore<Standard<f64>>>::plan(
//!     &mut ctx,
//!     &desc,
//!     &[&[3, 4], &[4, 5], &[3, 5]],
//! )
//! .unwrap();
//! <CpuBackend as TensorSemiringCore<Standard<f64>>>::execute(
//!     &mut ctx,
//!     &plan,
//!     1.0,
//!     &[&a, &b],
//!     0.0,
//!     &mut c,
//! )
//! .unwrap();
//! ```
//!
//! ## Scalar family planning
//!
//! ```ignore
//! use tenferro_algebra::Standard;
//! use tenferro_device::LogicalMemorySpace;
//! use tenferro_prims::{
//!     CpuBackend, CpuContext, ScalarPrimsDescriptor, ScalarReductionOp, TensorScalarPrims,
//! };
//! use tenferro_tensor::{MemoryOrder, Tensor};
//!
//! let mut ctx = CpuContext::new(4);
//! let col = MemoryOrder::ColumnMajor;
//! let mem = LogicalMemorySpace::MainMemory;
//! let a = Tensor::<f64>::zeros(&[3, 4], mem, col).unwrap();
//! let mut c = Tensor::<f64>::zeros(&[3], mem, col).unwrap();
//!
//! let desc = ScalarPrimsDescriptor::Reduction {
//!     modes_a: vec![0, 1],
//!     modes_c: vec![0],
//!     op: ScalarReductionOp::Sum,
//! };
//! let plan = <CpuBackend as TensorScalarPrims<Standard<f64>>>::plan(
//!     &mut ctx,
//!     &desc,
//!     &[&[3, 4], &[3]],
//! )
//! .unwrap();
//! ```

127#[cfg(all(feature = "gemm-faer", feature = "gemm-blas"))]
128compile_error!("enable exactly one GEMM backend: gemm-faer or gemm-blas");
129
130#[cfg(all(not(feature = "gemm-faer"), not(feature = "gemm-blas")))]
131compile_error!("enable exactly one GEMM backend: gemm-faer or gemm-blas");
132
133#[cfg(all(feature = "provider-src", not(feature = "gemm-blas")))]
134compile_error!("provider-src requires gemm-blas");
135#[cfg(all(feature = "provider-inject", not(feature = "gemm-blas")))]
136compile_error!("provider-inject requires gemm-blas");
137#[cfg(all(
138    any(
139        feature = "src-openblas",
140        feature = "src-netlib",
141        feature = "src-accelerate",
142        feature = "src-r",
143        feature = "src-intel-mkl-dynamic-sequential",
144        feature = "src-intel-mkl-dynamic-parallel",
145        feature = "src-intel-mkl-static-sequential",
146        feature = "src-intel-mkl-static-parallel"
147    ),
148    not(feature = "gemm-blas")
149))]
150compile_error!("src-* features require gemm-blas and provider-src");
151
// With `gemm-blas` enabled, enforce provider/src-* coherence via const
// evaluation: `cfg!` expands to a const bool, so `assert!` here fires at
// compile time with a readable message.
#[cfg(feature = "gemm-blas")]
const _: () = {
    /// Number of provider features enabled (must be exactly one).
    const PROVIDER_COUNT: usize =
        cfg!(feature = "provider-src") as usize + cfg!(feature = "provider-inject") as usize;
    assert!(
        PROVIDER_COUNT == 1,
        "gemm-blas requires exactly one provider: provider-src or provider-inject"
    );

    /// Number of concrete `src-*` implementation features enabled.
    const SRC_COUNT: usize = cfg!(feature = "src-openblas") as usize
        + cfg!(feature = "src-netlib") as usize
        + cfg!(feature = "src-accelerate") as usize
        + cfg!(feature = "src-r") as usize
        + cfg!(feature = "src-intel-mkl-dynamic-sequential") as usize
        + cfg!(feature = "src-intel-mkl-dynamic-parallel") as usize
        + cfg!(feature = "src-intel-mkl-static-sequential") as usize
        + cfg!(feature = "src-intel-mkl-static-parallel") as usize;

    // `!p || q` encodes the implication "if p then q" without branching.
    assert!(
        !cfg!(feature = "provider-src") || SRC_COUNT == 1,
        "provider-src requires exactly one src-* feature"
    );
    assert!(
        !cfg!(feature = "provider-inject") || SRC_COUNT == 0,
        "provider-inject forbids src-* features"
    );
};
181#[cfg(feature = "provider-src")]
182extern crate blas_src as _;
183#[cfg(feature = "provider-inject")]
184extern crate cblas_inject as _;
185#[cfg(feature = "provider-src")]
186extern crate cblas_src as _;
187
188mod cpu;
189mod families;
190mod infra;
191#[cfg(all(feature = "gemm-blas", feature = "provider-inject"))]
192pub mod inject;
193mod shape_helpers;
194pub mod tensor_ops;
195
196// CUDA backend: real implementation when `cuda` feature is enabled,
197// otherwise stub types that return errors.
198#[cfg(feature = "cuda")]
199mod cuda;
200#[cfg(feature = "cuda")]
201mod cuda_ffi;
202
203mod gpu_stubs;
204
205#[doc(hidden)]
206pub use cpu::CpuAnalyticPlan;
207#[doc(hidden)]
208pub use cpu::CpuComplexRealPlan;
209#[doc(hidden)]
210pub use cpu::CpuComplexScalePlan;
211#[doc(hidden)]
212pub use cpu::CpuIndexingPlan;
213#[doc(hidden)]
214pub use cpu::CpuScalarPlan;
215#[doc(hidden)]
216pub use cpu::CpuSortPlan;
217pub use cpu::*;
218pub use families::*;
219pub use infra::*;
220
221#[doc(hidden)]
222pub fn print_and_reset_contract_profile() {}
223
224#[cfg(feature = "cuda")]
225pub use cuda::*;
226#[cfg(feature = "cuda")]
227pub use cuda_ffi::*;
228
229#[cfg(not(feature = "cuda"))]
230pub use gpu_stubs::CudaBackend;
231#[cfg(not(feature = "cuda"))]
232pub use gpu_stubs::CudaContext;
233#[cfg(not(feature = "cuda"))]
234pub use gpu_stubs::CudaPlan;
235
236// ROCm stubs are always from gpu_stubs (no real ROCm backend yet)
237pub use gpu_stubs::RocmBackend;
238pub use gpu_stubs::RocmContext;
239pub use gpu_stubs::RocmPlan;
240
241use tenferro_algebra::Scalar;
242use tenferro_device::{Error, Result};
243use tenferro_tensor::Tensor;
244
245/// Reusable typed temporary vector pool exposed through backend contexts.
246#[doc(hidden)]
247pub trait TensorTempPoolContext {
248    fn take_temp_vec<T: Send + 'static>(&mut self, len: usize) -> Vec<T>;
249    fn put_temp_vec<T: Send + 'static>(&mut self, vec: Vec<T>);
250}
251
252// ===========================================================================
253// Helpers for multi-index iteration
254// ===========================================================================
255
/// Iterate over all index combinations for the given dimensions (column-major order).
///
/// For `dims == []` the callback is invoked exactly once with an empty index;
/// if any dimension is zero the callback is never invoked.
pub(crate) fn for_each_index(dims: &[usize], mut f: impl FnMut(&[usize])) {
    if dims.is_empty() {
        f(&[]);
        return;
    }
    if dims.contains(&0) {
        return;
    }
    let mut index = vec![0usize; dims.len()];
    loop {
        f(&index);
        // Advance like an odometer with the first dimension fastest
        // (column-major); stop once the carry runs off the last digit.
        let mut dim = 0;
        loop {
            if dim == dims.len() {
                return;
            }
            index[dim] += 1;
            if index[dim] < dims[dim] {
                break;
            }
            index[dim] = 0;
            dim += 1;
        }
    }
}
280
281/// Find the position of a mode label in a mode list, returning an error if not found.
282pub(crate) fn mode_position(modes: &[u32], label: u32) -> Result<usize> {
283    modes
284        .iter()
285        .position(|&m| m == label)
286        .ok_or_else(|| Error::InvalidArgument(format!("mode label {label} not found")))
287}
288
289/// Validate that the number of shapes matches expectations for an operation.
290pub(crate) fn validate_shape_count(
291    shapes: &[&[usize]],
292    expected: usize,
293    op_name: &str,
294) -> Result<()> {
295    if shapes.len() != expected {
296        return Err(Error::InvalidArgument(format!(
297            "{op_name} expects {expected} shapes (got {})",
298            shapes.len()
299        )));
300    }
301    Ok(())
302}
303
304/// Validate that a shape has the expected rank.
305pub(crate) fn validate_rank(shape: &[usize], expected: usize, _operand_name: &str) -> Result<()> {
306    if shape.len() != expected {
307        return Err(Error::RankMismatch {
308            expected,
309            got: shape.len(),
310        });
311    }
312    Ok(())
313}
314
315/// Validate that a shape exactly matches the expected shape.
316pub(crate) fn validate_shape_eq(
317    got: &[usize],
318    expected: &[usize],
319    _operand_name: &str,
320) -> Result<()> {
321    if got != expected {
322        return Err(Error::ShapeMismatch {
323            expected: expected.to_vec(),
324            got: got.to_vec(),
325        });
326    }
327    Ok(())
328}
329
330/// Validate the number of input operands for execute.
331pub(crate) fn validate_execute_inputs<T: Scalar>(
332    inputs: &[&Tensor<T>],
333    expected: usize,
334    op_name: &str,
335) -> Result<()> {
336    if inputs.len() != expected {
337        return Err(Error::InvalidArgument(format!(
338            "{op_name} expects {expected} input(s) (got {})",
339            inputs.len()
340        )));
341    }
342    Ok(())
343}
// Unit tests live in the sibling `tests` module, compiled only for `cargo test`.
#[cfg(test)]
mod tests;