tenferro_linalg_prims/
lib.rs

1//! Backend-facing linalg kernel contracts for the tenferro workspace.
2//!
3//! This crate holds the low-level tensor linalg protocol that device backends
4//! implement. High-level composite APIs remain in `tenferro-linalg`.
5//!
6//! # Examples
7//!
8//! ```ignore
9//! use tenferro_linalg_prims::{TensorLinalgPrims, QrTensorResult};
10//! use tenferro_tensor::Tensor;
11//!
12//! fn accepts_backend<B: TensorLinalgPrims<f64>>() {}
13//! let _: Option<QrTensorResult<f64>> = None;
14//! let _: Option<Tensor<f64>> = None;
15//! ```
16
// Build-feature sanity checks: each gate below turns one invalid Cargo
// feature combination into a hard compile error.

// The `provider-*` features are only meaningful on top of the LAPACK backend.
#[cfg(all(feature = "provider-src", not(feature = "linalg-lapack")))]
compile_error!("provider-src requires linalg-lapack");
#[cfg(all(feature = "provider-inject", not(feature = "linalg-lapack")))]
compile_error!("provider-inject requires linalg-lapack");
// The two CPU linalg backends cannot be enabled together ...
#[cfg(all(feature = "linalg-faer", feature = "linalg-lapack"))]
compile_error!(
    "Features `linalg-faer` and `linalg-lapack` are mutually exclusive. Enable exactly one."
);
// ... and one of them has to be enabled.
#[cfg(not(any(feature = "linalg-faer", feature = "linalg-lapack")))]
compile_error!("No CPU linalg provider selected. Enable `linalg-faer` or `linalg-lapack`.");
// Any `src-*` vendor selection without the LAPACK backend is rejected.
// NOTE(review): the message also names `provider-src`, but this gate only
// tests `linalg-lapack`; the provider/src pairing itself is asserted in the
// `linalg-lapack` const block that follows these gates.
#[cfg(all(
    any(
        feature = "src-openblas",
        feature = "src-netlib",
        feature = "src-accelerate",
        feature = "src-r",
        feature = "src-intel-mkl-dynamic-sequential",
        feature = "src-intel-mkl-dynamic-parallel",
        feature = "src-intel-mkl-static-sequential",
        feature = "src-intel-mkl-static-parallel"
    ),
    not(feature = "linalg-lapack")
))]
compile_error!("src-* features require linalg-lapack and provider-src");
41
// Compile-time validation of the provider / `src-*` feature combination.
// Only compiled when the LAPACK backend is selected; a failing assertion
// while evaluating this anonymous constant aborts the build.
#[cfg(feature = "linalg-lapack")]
const _: () = {
    /// Number of `true` entries in `flags` (usable in const evaluation).
    const fn enabled(flags: &[bool]) -> usize {
        let mut total = 0;
        let mut idx = 0;
        while idx < flags.len() {
            if flags[idx] {
                total += 1;
            }
            idx += 1;
        }
        total
    }

    let uses_src = cfg!(feature = "provider-src");
    let uses_inject = cfg!(feature = "provider-inject");

    // Exactly one way of obtaining the LAPACK symbols must be chosen.
    assert!(
        enabled(&[uses_src, uses_inject]) == 1,
        "linalg-lapack requires exactly one provider: provider-src or provider-inject"
    );

    let src_count = enabled(&[
        cfg!(feature = "src-openblas"),
        cfg!(feature = "src-netlib"),
        cfg!(feature = "src-accelerate"),
        cfg!(feature = "src-r"),
        cfg!(feature = "src-intel-mkl-dynamic-sequential"),
        cfg!(feature = "src-intel-mkl-dynamic-parallel"),
        cfg!(feature = "src-intel-mkl-static-sequential"),
        cfg!(feature = "src-intel-mkl-static-parallel"),
    ]);

    // `provider-src` needs one concrete vendor; `provider-inject` must not
    // be combined with any.
    if uses_src {
        assert!(
            src_count == 1,
            "provider-src requires exactly one src-* feature"
        );
    }
    if uses_inject {
        assert!(src_count == 0, "provider-inject forbids src-* features");
    }
};
70
71#[cfg(feature = "provider-src")]
72extern crate blas_src as _;
73#[cfg(feature = "provider-src")]
74extern crate cblas_src as _;
75#[cfg(feature = "provider-src")]
76extern crate lapack_src as _;
77
78#[cfg(feature = "provider-inject")]
79extern crate cblas_inject as _;
80#[cfg(feature = "provider-inject")]
81extern crate lapack_inject as _;
82
83use num_complex::{Complex32, Complex64};
84use num_traits::Zero;
85use tenferro_algebra::Scalar;
86use tenferro_device::Result;
87use tenferro_tensor::Tensor;
88
89pub mod backend;
90
91/// Scalar types supported by linalg kernel contracts.
92///
93/// # Examples
94///
95/// ```
96/// use tenferro_linalg_prims::LinalgScalar;
97///
98/// fn needs_linalg_scalar<T: LinalgScalar>(x: T) -> T { x }
99/// assert_eq!(needs_linalg_scalar(1.0_f64), 1.0);
100/// ```
101pub trait LinalgScalar:
102    Scalar
103    + std::ops::Sub<Output = Self>
104    + std::ops::Neg<Output = Self>
105    + std::ops::Div<Output = Self>
106    + num_traits::NumCast
107    + std::fmt::Debug
108    + 'static
109{
110    type Real: LinalgScalar<Real = Self::Real, Complex = Self::Complex> + num_traits::Float;
111    type Complex: LinalgScalar<Real = Self::Real, Complex = Self::Complex>;
112
113    /// Return the scalar magnitude in the associated real field.
114    fn abs_real(&self) -> Self::Real;
115    /// Return a reasonable machine epsilon for the associated real field.
116    fn real_epsilon() -> Self::Real;
117    /// Return the algebraic conjugate.
118    fn conj(&self) -> Self;
119    /// Build a scalar from explicit real/imaginary parts.
120    fn from_parts(real: Self::Real, imag: Self::Real) -> Self;
121    /// Build a scalar from the associated real field.
122    fn from_real(real: Self::Real) -> Self {
123        Self::from_parts(real, Self::Real::zero())
124    }
125    /// Return the real part in the associated real field.
126    fn real_part(&self) -> Self::Real;
127    /// Return the imaginary part in the associated real field.
128    fn imag_part(&self) -> Self::Real;
129}
130
/// Scalar types with concrete backend kernel support in the current workspace.
///
/// This marker keeps public/high-level linalg bounds generic over backends
/// without leaking provider-specific names such as `Cpu*` into higher layers.
/// It adds no methods of its own; in this file it is implemented for `f32`,
/// `f64`, `Complex32` and `Complex64`, and it is the bound placed on the
/// scalar parameter of [`TensorLinalgPrims`].
///
/// # Examples
///
/// ```
/// use tenferro_linalg_prims::KernelLinalgScalar;
///
/// fn needs_kernel_scalar<T: KernelLinalgScalar>(x: T) -> T { x }
/// assert_eq!(needs_kernel_scalar(1.0_f64), 1.0);
/// ```
pub trait KernelLinalgScalar: LinalgScalar {}
145
/// LAPACK-oriented eigen helper contract for CPU eigendecomposition paths.
///
/// This trait is intentionally narrower than [`LinalgScalar`]. It exists so
/// CPU eigensolver glue can request the real/imag buffer conversion helpers it
/// needs without forcing every backend-generic scalar contract to carry LAPACK
/// details.
///
/// The implementations in this file use two layouts:
///
/// * real scalars report `(2 * n, 2 * n * n)` and treat the temporary
///   buffers as interleaved `(re, im)` pairs;
/// * complex scalars report `(n, n * n)` and copy the buffers through
///   unchanged.
///
/// # Examples
///
/// ```
/// use tenferro_linalg_prims::LapackEigScalar;
///
/// let (vals, vecs) = <f64 as LapackEigScalar>::eig_buffer_sizes(2);
/// assert_eq!((vals, vecs), (4, 8));
/// ```
pub trait LapackEigScalar: LinalgScalar {
    /// Return the temporary value/vector buffer sizes used by the CPU eig path.
    ///
    /// The tuple is `(values_len, vectors_len)`, counted in elements of `Self`,
    /// for an `n x n` problem.
    fn eig_buffer_sizes(n: usize) -> (usize, usize);

    /// Convert LAPACK-style real/imag outputs into complex values/vectors.
    ///
    /// `val_ri` / `vec_ri` are the temporary buffers sized according to
    /// [`LapackEigScalar::eig_buffer_sizes`]; the converted eigenvalues and
    /// eigenvectors are written to `values_out` / `vectors_out`.
    ///
    /// # Panics
    ///
    /// The implementations in this file index or copy the slices directly, so
    /// buffers whose lengths do not fit `n` cause a panic.
    fn eig_ri_to_complex(
        n: usize,
        val_ri: &[Self],
        vec_ri: &[Self],
        values_out: &mut [Self::Complex],
        vectors_out: &mut [Self::Complex],
    );
}
174
// Implements `LinalgScalar`, `KernelLinalgScalar` and `LapackEigScalar` for a
// real floating-point type `$ty` whose complex companion type is `$complex`.
macro_rules! impl_real_linalg_scalar {
    ($ty:ty, $complex:ty) => {
        impl LinalgScalar for $ty {
            type Real = $ty;
            type Complex = $complex;

            fn abs_real(&self) -> $ty {
                num_traits::Float::abs(*self)
            }

            fn real_epsilon() -> $ty {
                <$ty as num_traits::Float>::epsilon()
            }

            // A real number is its own conjugate.
            fn conj(&self) -> $ty {
                *self
            }

            // Real scalars have no imaginary component: `_imag` is dropped.
            fn from_parts(real: Self::Real, _imag: Self::Real) -> Self {
                real
            }

            fn real_part(&self) -> Self::Real {
                *self
            }

            fn imag_part(&self) -> Self::Real {
                0.0
            }
        }

        impl KernelLinalgScalar for $ty {}

        impl LapackEigScalar for $ty {
            // Two reals (re, im) per eigenvalue and per eigenvector entry.
            fn eig_buffer_sizes(n: usize) -> (usize, usize) {
                (2 * n, 2 * n * n)
            }

            // Reads `n` interleaved (re, im) pairs from `val_ri` and `n * n`
            // pairs from `vec_ri`, writing one complex number per pair.
            //
            // Panics if any buffer is shorter than that. All buffers are
            // validated before the first write, so a failed call leaves the
            // outputs untouched.
            fn eig_ri_to_complex(
                n: usize,
                val_ri: &[Self],
                vec_ri: &[Self],
                values_out: &mut [$complex],
                vectors_out: &mut [$complex],
            ) {
                let nn = n * n;

                // Slice every buffer to its required extent up front: a short
                // buffer panics here, before any output element is modified,
                // and the loops below run without per-element bounds checks.
                let val_ri = &val_ri[..2 * n];
                let vec_ri = &vec_ri[..2 * nn];
                let values_out = &mut values_out[..n];
                let vectors_out = &mut vectors_out[..nn];

                for (dst, pair) in values_out.iter_mut().zip(val_ri.chunks_exact(2)) {
                    *dst = <$complex>::new(pair[0], pair[1]);
                }
                for (dst, pair) in vectors_out.iter_mut().zip(vec_ri.chunks_exact(2)) {
                    *dst = <$complex>::new(pair[0], pair[1]);
                }
            }
        }
    };
}
230
// Implements `LinalgScalar`, `KernelLinalgScalar` and `LapackEigScalar` for a
// complex type `$ty` whose real companion type is `$real`.
macro_rules! impl_complex_linalg_scalar {
    ($ty:ty, $real:ty) => {
        impl LinalgScalar for $ty {
            type Real = $real;
            // A complex scalar is its own complex companion.
            type Complex = $ty;

            fn abs_real(&self) -> Self::Real {
                self.norm()
            }

            fn real_epsilon() -> Self::Real {
                <$real as num_traits::Float>::epsilon()
            }

            fn conj(&self) -> Self {
                // Inherent conjugate of the complex type.
                <$ty>::conj(self)
            }

            fn from_parts(real: $real, imag: $real) -> $ty {
                Self { re: real, im: imag }
            }

            fn real_part(&self) -> $real {
                self.re
            }

            fn imag_part(&self) -> $real {
                self.im
            }
        }

        impl KernelLinalgScalar for $ty {}

        impl LapackEigScalar for $ty {
            // The buffers already hold complex numbers: one per eigenvalue
            // and one per eigenvector entry.
            fn eig_buffer_sizes(n: usize) -> (usize, usize) {
                (n, n * n)
            }

            // No conversion is needed; values first, then vectors, are copied
            // through verbatim. `copy_from_slice` panics when a source and
            // its destination differ in length.
            fn eig_ri_to_complex(
                _n: usize,
                val_ri: &[Self],
                vec_ri: &[Self],
                values_out: &mut [$ty],
                vectors_out: &mut [$ty],
            ) {
                for (dst, src) in [(values_out, val_ri), (vectors_out, vec_ri)] {
                    dst.copy_from_slice(src);
                }
            }
        }
    };
}
282
// Concrete scalar support: each real type is paired with the complex type of
// the same precision, and each complex type with its real counterpart.
impl_real_linalg_scalar!(f64, Complex64);
impl_real_linalg_scalar!(f32, Complex32);
impl_complex_linalg_scalar!(Complex64, f64);
impl_complex_linalg_scalar!(Complex32, f32);
287
/// Result of a tensor-level QR decomposition.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::QrTensorResult;
/// let _result: Option<QrTensorResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct QrTensorResult<T: LinalgScalar> {
    /// The `Q` factor of the decomposition.
    pub q: Tensor<T>,
    /// The `R` factor of the decomposition.
    pub r: Tensor<T>,
}
301
/// Result of a tensor-level thin SVD.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::SvdTensorResult;
/// let _result: Option<SvdTensorResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct SvdTensorResult<T: LinalgScalar> {
    /// The `U` factor, in the input scalar type.
    pub u: Tensor<T>,
    /// Singular values, stored in the real companion type `T::Real`.
    pub s: Tensor<T::Real>,
    /// The right factor. The name suggests it is stored transposed
    /// (`Vt`); whether complex inputs are also conjugated is not visible
    /// here — TODO confirm with the backend implementations.
    pub vt: Tensor<T>,
}
316
/// Result of a tensor-level LU factorization.
///
/// Carries the same factor and pivot fields as [`LuTensorExResult`], without
/// the numerical status tensor.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::LuTensorResult;
/// let _result: Option<LuTensorResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct LuTensorResult<T: LinalgScalar> {
    /// Lower-triangular factor (documented as unit-lower-triangular on
    /// [`LuTensorExResult`]; presumably the same convention — confirm).
    pub l: Tensor<T>,
    /// Upper-triangular factor.
    pub u: Tensor<T>,
    /// Backend pivot tensor.
    pub pivots: Tensor<i32>,
}
331
/// Result of a tensor-level LU factorization with numerical status.
///
/// Returned by [`TensorLinalgPrims::lu_factor_ex`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::LuTensorExResult;
/// let _result: Option<LuTensorExResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct LuTensorExResult<T: LinalgScalar> {
    /// Unit-lower-triangular factor.
    pub l: Tensor<T>,
    /// Upper-triangular factor.
    pub u: Tensor<T>,
    /// Backend pivot tensor.
    pub pivots: Tensor<i32>,
    /// Per-batch numerical status tensor.
    ///
    /// NOTE(review): the meaning of the `i32` codes is not defined in this
    /// file (presumably LAPACK-style, `0` = success) — confirm.
    pub info: Tensor<i32>,
}
351
/// Result of a tensor-level linear solve with numerical status.
///
/// Returned by [`TensorLinalgPrims::solve_ex`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::SolveTensorExResult;
/// let _result: Option<SolveTensorExResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct SolveTensorExResult<T: LinalgScalar> {
    /// Solution tensor.
    pub solution: Tensor<T>,
    /// Per-batch numerical status tensor.
    ///
    /// NOTE(review): the meaning of the `i32` codes is not defined in this
    /// file (presumably LAPACK-style, `0` = success) — confirm.
    pub info: Tensor<i32>,
}
367
/// Result of a tensor-level Cholesky factorization with numerical status.
///
/// Returned by [`TensorLinalgPrims::cholesky_ex`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::CholeskyTensorExResult;
/// let _result: Option<CholeskyTensorExResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct CholeskyTensorExResult<T: LinalgScalar> {
    /// Lower-triangular Cholesky factor.
    pub l: Tensor<T>,
    /// Per-batch numerical status tensor.
    ///
    /// NOTE(review): the meaning of the `i32` codes is not defined in this
    /// file (presumably LAPACK-style, `0` = success) — confirm.
    pub info: Tensor<i32>,
}
383
/// Result of a tensor-level Hermitian eigendecomposition.
///
/// Returned by [`TensorLinalgPrims::eigen_sym`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::EigenTensorResult;
/// let _result: Option<EigenTensorResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct EigenTensorResult<T: LinalgScalar> {
    /// Eigenvalues, stored in the real companion type `T::Real`.
    pub values: Tensor<T::Real>,
    /// Eigenvectors, in the input scalar type.
    pub vectors: Tensor<T>,
}
397
/// Result of a tensor-level general eigendecomposition.
///
/// Returned by [`TensorLinalgPrims::eig`]. Both outputs use the complex
/// companion type `T::Complex`, including for real input scalars.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::EigTensorResult;
/// let _result: Option<EigTensorResult<f64>> = None;
/// ```
#[derive(Clone)]
pub struct EigTensorResult<T: LinalgScalar> {
    /// Eigenvalues, as complex numbers.
    pub values: Tensor<T::Complex>,
    /// Eigenvectors, as complex numbers.
    pub vectors: Tensor<T::Complex>,
}
411
/// Identifies a linalg operation in backend capability queries.
///
/// A value of this enum is passed to
/// [`TensorLinalgPrims::has_linalg_support`] to ask whether a backend
/// supports the named operation.
///
/// # Examples
///
/// ```
/// use tenferro_linalg_prims::LinalgCapabilityOp;
///
/// assert_ne!(LinalgCapabilityOp::Qr, LinalgCapabilityOp::ThinSvd);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinalgCapabilityOp {
    // These names match kernel entry points declared on `TensorLinalgPrims`
    // (`solve`, `solve_triangular`, `qr`, `thin_svd`, `lu_factor`,
    // `cholesky`, `eigen_sym`, `eig`, `lu_solve`).
    Solve,
    SolveTriangular,
    Qr,
    ThinSvd,
    LuFactor,
    Cholesky,
    EigenSym,
    Eig,
    LuSolve,
    // No same-named kernel is declared in this file; presumably provided by
    // a higher layer — confirm.
    Lstsq,
    // Status-returning variants, matching `lu_factor_ex`, `cholesky_ex` and
    // `solve_ex`.
    LuFactorEx,
    CholeskyEx,
    SolveEx,
    // The remaining names have no same-named kernel in `TensorLinalgPrims`
    // within this file; presumably they identify composite operations built
    // in higher layers — confirm.
    Inv,
    Det,
    Slogdet,
    Pinv,
    MatrixExp,
    MatrixPower,
    Cross,
    HouseholderProduct,
    Vander,
    TensorInv,
    TensorSolve,
    Norm,
}
450
/// Backend-facing tensor linalg protocol.
///
/// Device backends implement this trait for every scalar type `T` they
/// support. All kernels are associated functions that receive the backend's
/// [`TensorLinalgPrims::Context`] mutably and report failure through the
/// workspace `Result` type.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::TensorLinalgPrims;
///
/// fn accepts_backend<B: TensorLinalgPrims<f64>>() {}
/// ```
pub trait TensorLinalgPrims<T: KernelLinalgScalar> {
    /// Backend-specific execution context handed to every kernel.
    type Context;

    /// Report whether this backend supports the given operation.
    fn has_linalg_support(op: LinalgCapabilityOp) -> bool;

    /// Solve a square linear system while returning per-batch numerical status.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_linalg_prims::TensorLinalgPrims;
    ///
    /// fn accepts_backend<B: TensorLinalgPrims<f64>>() {}
    /// let _ = accepts_backend::<todo!()>;
    /// ```
    fn solve_ex(
        ctx: &mut Self::Context,
        a: &Tensor<T>,
        b: &Tensor<T>,
    ) -> Result<SolveTensorExResult<T>>;

    /// Solve a linear system, returning only the solution tensor (no status
    /// tensor, unlike [`TensorLinalgPrims::solve_ex`]).
    fn solve(ctx: &mut Self::Context, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>;
    /// Solve a linear system from an existing LU factorization.
    ///
    /// NOTE(review): `factors` is a single tensor, so it presumably holds the
    /// packed `L`/`U` factors rather than the split form of
    /// [`LuTensorResult`] — confirm with the backend implementations.
    fn lu_solve(
        ctx: &mut Self::Context,
        factors: &Tensor<T>,
        pivots: &Tensor<i32>,
        b: &Tensor<T>,
    ) -> Result<Tensor<T>>;
    /// Solve a triangular linear system.
    ///
    /// `upper` presumably selects whether `a` is treated as upper (`true`)
    /// or lower (`false`) triangular — confirm.
    fn solve_triangular(
        ctx: &mut Self::Context,
        a: &Tensor<T>,
        b: &Tensor<T>,
        upper: bool,
    ) -> Result<Tensor<T>>;
    /// Compute a QR decomposition.
    fn qr(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<QrTensorResult<T>>;
    /// Compute a thin SVD.
    fn thin_svd(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<SvdTensorResult<T>>;
    /// Compute singular values only, in the real companion type `T::Real`.
    fn svdvals(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T::Real>>;
    /// Compute an LU factorization while returning per-batch numerical status.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_linalg_prims::TensorLinalgPrims;
    ///
    /// fn accepts_backend<B: TensorLinalgPrims<f64>>() {}
    /// let _ = accepts_backend::<todo!()>;
    /// ```
    fn lu_factor_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorExResult<T>>;
    /// Compute an LU factorization, without the status tensor returned by
    /// [`TensorLinalgPrims::lu_factor_ex`].
    fn lu_factor(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>>;
    /// Compute an LU factorization without pivoting.
    ///
    /// Backends may support this only on a subset of devices or dtypes.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_linalg_prims::TensorLinalgPrims;
    ///
    /// fn accepts_backend<B: TensorLinalgPrims<f64>>() {}
    /// let _ = accepts_backend::<todo!()>;
    /// ```
    fn lu_factor_no_pivot(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>>;
    /// Compute a Cholesky factorization while returning per-batch numerical status.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_linalg_prims::TensorLinalgPrims;
    ///
    /// fn accepts_backend<B: TensorLinalgPrims<f64>>() {}
    /// let _ = accepts_backend::<todo!()>;
    /// ```
    fn cholesky_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<CholeskyTensorExResult<T>>;
    /// Compute a Cholesky factorization, returning only the factor tensor.
    fn cholesky(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T>>;
    /// Compute a Hermitian eigendecomposition (see [`EigenTensorResult`]).
    fn eigen_sym(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigenTensorResult<T>>;
    /// Compute a general eigendecomposition with complex outputs (see
    /// [`EigTensorResult`]).
    fn eig(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigTensorResult<T>>;
}
528
529#[cfg(test)]
530mod tests;