tenferro_linalg/backend/mod.rs
1//! Backend abstraction for linear algebra operations.
2//!
3//! This module provides both the slice-level [`LinalgBackend`] trait (used
4//! internally by the CPU provider) and the tensor-level
5//! [`TensorLinalgBackend`] trait (the public backend boundary).
6//!
7//! # CPU provider selection
8//!
9//! Exactly one of the following features must be enabled:
10//!
11//! - `linalg-faer`: Pure-Rust via [`faer`](https://crates.io/crates/faer) (default)
12//! - `linalg-lapack`: LAPACK + CBLAS backend with provider selection
13//! (`provider-src` or `provider-inject`)
14//!
15//! Enabling both or neither is a compile error.
16//!
17//! # Device backends
18//!
19//! - **CPU**: [`CpuTensorLinalgBackend`] with [`tenferro_prims::CpuContext`]
20//! - **CUDA**: [`CudaTensorLinalgBackend`] with [`tenferro_prims::CudaContext`] (stub)
21//! - **HIP**: [`HipTensorLinalgBackend`] with [`tenferro_prims::RocmContext`] (stub)
22//!
23//! # Examples
24//!
25//! ```ignore
26//! use tenferro_linalg::backend::{TensorLinalgBackend, CpuTensorLinalgBackend};
27//! use tenferro_tensor::Tensor;
28//!
29//! let mut ctx = tenferro_prims::CpuContext::new(1);
30//! let a: Tensor<f64> = todo!();
31//! let b: Tensor<f64> = todo!();
32//! let _x = <CpuTensorLinalgBackend as TensorLinalgBackend<f64>>::solve(&mut ctx, &a, &b).unwrap();
33//! ```
34
35// ============================================================================
36// Feature policy: exactly one CPU linalg provider must be enabled
37// ============================================================================
38
39#[cfg(all(feature = "linalg-faer", feature = "linalg-lapack"))]
40compile_error!(
41 "Features `linalg-faer` and `linalg-lapack` are mutually exclusive. Enable exactly one."
42);
43
44#[cfg(not(any(feature = "linalg-faer", feature = "linalg-lapack")))]
45compile_error!("No CPU linalg provider selected. Enable `linalg-faer` or `linalg-lapack`.");
46
47// ============================================================================
48// Submodules
49// ============================================================================
50
51// Tensor-level API and types
52pub mod tensor_api;
53pub mod tensor_context;
54pub(crate) mod tensor_helpers;
55
56// ============================================================================
57// Re-exports
58// ============================================================================
59
60// Tensor-level API (public)
61pub use tensor_api::{
62 EigTensorResult, EigenTensorResult, LinalgCapabilityOp, LuTensorResult, QrTensorResult,
63 SvdTensorResult, TensorLinalgBackend,
64};
65pub use tensor_context::TensorLinalgContextFor;
66
67// CPU backend (public)
68#[cfg(feature = "linalg-lapack")]
69pub use tenferro_linalg_prims::backend::BlasLapackBackend;
70pub use tenferro_linalg_prims::backend::CpuTensorLinalgBackend;
71#[cfg(feature = "linalg-faer")]
72pub use tenferro_linalg_prims::backend::FaerBackend;
73
74// GPU backend stubs (public)
75pub(crate) use tenferro_linalg_prims::backend::col_major_strides;
76pub use tenferro_linalg_prims::backend::{
77 CudaDataType, CudaLinalgScalar, CudaTensorLinalgBackend, HipTensorLinalgBackend,
78};
79
80use tenferro_device::Result;
81
82/// Slice-level backend interface for matrix linear algebra operations.
83///
84/// All input/output slices use **column-major** layout. The trait is
85/// parameterized by scalar type `T` (e.g., `f64`, `f32`).
86///
87/// Implementations take `&mut self` to allow internal workspace reuse.
88///
89/// This trait is used internally by CPU provider implementations.
90/// The public API boundary is [`TensorLinalgBackend`].
91///
92/// # Examples
93///
94/// ```
95/// use tenferro_linalg::backend::LinalgBackend;
96///
97/// fn do_svd<B: LinalgBackend<f64, Real = f64>>(backend: &mut B) {
98/// let a = [1.0, 0.0, 0.0, 1.0]; // 2x2 identity
99/// let mut u = [0.0; 4];
100/// let mut s = [0.0; 2];
101/// let mut vt = [0.0; 4];
102/// backend.thin_svd(&a, 2, 2, &mut u, &mut s, &mut vt).unwrap();
103/// }
104/// ```
105pub trait LinalgBackend<T: Copy + 'static> {
106 /// The real-valued scalar type for singular/eigenvalues.
107 type Real: Copy + 'static;
108
109 /// Thin SVD: `A = U diag(S) Vt`.
110 fn thin_svd(
111 &mut self,
112 a: &[T],
113 m: usize,
114 n: usize,
115 u: &mut [T],
116 s: &mut [Self::Real],
117 vt: &mut [T],
118 ) -> Result<()>;
119
120 /// Thin QR decomposition: `A = Q R`.
121 fn qr(&mut self, a: &[T], m: usize, n: usize, q: &mut [T], r: &mut [T]) -> Result<()>;
122
123 /// LU decomposition with partial pivoting: `P A = L U`.
124 fn lu(
125 &mut self,
126 a: &[T],
127 m: usize,
128 n: usize,
129 perm: &mut [usize],
130 l: &mut [T],
131 u_out: &mut [T],
132 ) -> Result<()>;
133
134 /// Cholesky decomposition: `A = L L^H`.
135 fn cholesky(&mut self, a: &[T], n: usize, l: &mut [T]) -> Result<()>;
136
137 /// Symmetric eigendecomposition: `A = V diag(lambda) V^H`.
138 fn eigen_sym(
139 &mut self,
140 a: &[T],
141 n: usize,
142 values: &mut [Self::Real],
143 vectors: &mut [T],
144 ) -> Result<()>;
145
146 /// Matrix multiplication: `C = A * B`.
147 fn mat_mul(
148 &mut self,
149 a: &[T],
150 m: usize,
151 k: usize,
152 b: &[T],
153 n: usize,
154 c: &mut [T],
155 ) -> Result<()>;
156
157 /// Solve linear system: `A x = b`.
158 fn solve(&mut self, a: &[T], b: &[T], n: usize, nrhs: usize, x: &mut [T]) -> Result<()>;
159
160 /// Solve triangular system: `A x = b`.
161 fn solve_triangular(
162 &mut self,
163 a: &[T],
164 b: &[T],
165 n: usize,
166 nrhs: usize,
167 upper: bool,
168 x: &mut [T],
169 ) -> Result<()>;
170
171 /// General eigendecomposition: `A V = V diag(lambda)`.
172 ///
173 /// For real `T`: output uses interleaved re/im pairs (`2*n` values, `2*n*n` vectors).
174 /// For complex `T`: output uses direct complex elements (`n` values, `n*n` vectors).
175 fn eig_general(
176 &mut self,
177 a: &[T],
178 n: usize,
179 values_ri: &mut [T],
180 vectors_ri: &mut [T],
181 ) -> Result<()>;
182}