// Compile-time validation of Cargo feature combinations.
//
// The provider flavours (`provider-src`, `provider-inject`) are only meaningful on
// the LAPACK-backed CPU path, so enabling either without `linalg-lapack` is rejected.
#[cfg(all(feature = "provider-src", not(feature = "linalg-lapack")))]
compile_error!("provider-src requires linalg-lapack");
#[cfg(all(feature = "provider-inject", not(feature = "linalg-lapack")))]
compile_error!("provider-inject requires linalg-lapack");
// Exactly one CPU linalg backend must be chosen: reject enabling both...
#[cfg(all(feature = "linalg-faer", feature = "linalg-lapack"))]
compile_error!(
    "Features `linalg-faer` and `linalg-lapack` are mutually exclusive. Enable exactly one."
);
// ...and reject enabling neither.
#[cfg(not(any(feature = "linalg-faer", feature = "linalg-lapack")))]
compile_error!("No CPU linalg provider selected. Enable `linalg-faer` or `linalg-lapack`.");
// Any concrete `src-*` backend selection implies the LAPACK path. This gate only
// fires when `linalg-lapack` is off; the provider/src pairing rules that apply when
// it is on are enforced by the `const _` assertion block that follows.
#[cfg(all(
    any(
        feature = "src-openblas",
        feature = "src-netlib",
        feature = "src-accelerate",
        feature = "src-r",
        feature = "src-intel-mkl-dynamic-sequential",
        feature = "src-intel-mkl-dynamic-parallel",
        feature = "src-intel-mkl-static-sequential",
        feature = "src-intel-mkl-static-parallel"
    ),
    not(feature = "linalg-lapack")
))]
compile_error!("src-* features require linalg-lapack and provider-src");
41
// Feature-combination rules that only apply on the LAPACK path. This initializer is
// evaluated at compile time, so a failing `assert!` aborts the build with its message.
#[cfg(feature = "linalg-lapack")]
const _: () = {
    // Number of `true` entries in `flags`. Written as a `while` loop because
    // iterators are not available in `const` evaluation.
    const fn enabled_count(flags: &[bool]) -> usize {
        let mut count = 0;
        let mut idx = 0;
        while idx < flags.len() {
            if flags[idx] {
                count += 1;
            }
            idx += 1;
        }
        count
    }

    let providers = enabled_count(&[
        cfg!(feature = "provider-src"),
        cfg!(feature = "provider-inject"),
    ]);
    assert!(
        providers == 1,
        "linalg-lapack requires exactly one provider: provider-src or provider-inject"
    );

    let sources = enabled_count(&[
        cfg!(feature = "src-openblas"),
        cfg!(feature = "src-netlib"),
        cfg!(feature = "src-accelerate"),
        cfg!(feature = "src-r"),
        cfg!(feature = "src-intel-mkl-dynamic-sequential"),
        cfg!(feature = "src-intel-mkl-dynamic-parallel"),
        cfg!(feature = "src-intel-mkl-static-sequential"),
        cfg!(feature = "src-intel-mkl-static-parallel"),
    ]);

    // `provider-src` implies exactly one concrete `src-*` selection.
    assert!(
        !cfg!(feature = "provider-src") || sources == 1,
        "provider-src requires exactly one src-* feature"
    );
    // `provider-inject` implies no `src-*` selection at all.
    assert!(
        !cfg!(feature = "provider-inject") || sources == 0,
        "provider-inject forbids src-* features"
    );
};
70
71#[cfg(feature = "provider-src")]
72extern crate blas_src as _;
73#[cfg(feature = "provider-src")]
74extern crate cblas_src as _;
75#[cfg(feature = "provider-src")]
76extern crate lapack_src as _;
77
78#[cfg(feature = "provider-inject")]
79extern crate cblas_inject as _;
80#[cfg(feature = "provider-inject")]
81extern crate lapack_inject as _;
82
83use num_complex::{Complex32, Complex64};
84use num_traits::Zero;
85use tenferro_algebra::Scalar;
86use tenferro_device::Result;
87use tenferro_tensor::Tensor;
88
89pub mod backend;
90
91pub trait LinalgScalar:
102 Scalar
103 + std::ops::Sub<Output = Self>
104 + std::ops::Neg<Output = Self>
105 + std::ops::Div<Output = Self>
106 + num_traits::NumCast
107 + std::fmt::Debug
108 + 'static
109{
110 type Real: LinalgScalar<Real = Self::Real, Complex = Self::Complex> + num_traits::Float;
111 type Complex: LinalgScalar<Real = Self::Real, Complex = Self::Complex>;
112
113 fn abs_real(&self) -> Self::Real;
115 fn real_epsilon() -> Self::Real;
117 fn conj(&self) -> Self;
119 fn from_parts(real: Self::Real, imag: Self::Real) -> Self;
121 fn from_real(real: Self::Real) -> Self {
123 Self::from_parts(real, Self::Real::zero())
124 }
125 fn real_part(&self) -> Self::Real;
127 fn imag_part(&self) -> Self::Real;
129}
130
131pub trait KernelLinalgScalar: LinalgScalar {}
145
146pub trait LapackEigScalar: LinalgScalar {
162 fn eig_buffer_sizes(n: usize) -> (usize, usize);
164
165 fn eig_ri_to_complex(
167 n: usize,
168 val_ri: &[Self],
169 vec_ri: &[Self],
170 values_out: &mut [Self::Complex],
171 vectors_out: &mut [Self::Complex],
172 );
173}
174
// Implements `LinalgScalar`, `KernelLinalgScalar`, and `LapackEigScalar` for a real
// floating-point type `$ty` whose complex counterpart is `$complex`.
//
// Requirements: `$ty: num_traits::Float` (and therefore `Copy`), and `$complex` must
// expose a `new(re, im)` constructor, as `num_complex::Complex` does.
macro_rules! impl_real_linalg_scalar {
    ($ty:ty, $complex:ty) => {
        impl LinalgScalar for $ty {
            type Real = $ty;
            type Complex = $complex;

            fn abs_real(&self) -> $ty {
                num_traits::Float::abs(*self)
            }

            fn real_epsilon() -> $ty {
                <$ty as num_traits::Float>::epsilon()
            }

            // A real scalar is its own conjugate.
            fn conj(&self) -> $ty {
                *self
            }

            // A real scalar has nowhere to store an imaginary component, so it is dropped.
            fn from_parts(real: Self::Real, _imag: Self::Real) -> Self {
                real
            }

            fn real_part(&self) -> Self::Real {
                *self
            }

            // `Zero::zero()` rather than a `0.0` literal keeps the macro usable for any
            // `Float` type, not only those that accept float literals.
            fn imag_part(&self) -> Self::Real {
                <$ty as num_traits::Zero>::zero()
            }
        }

        impl KernelLinalgScalar for $ty {}

        impl LapackEigScalar for $ty {
            // Values and vectors are staged as interleaved (re, im) pairs, hence twice
            // the complex element counts.
            fn eig_buffer_sizes(n: usize) -> (usize, usize) {
                (2 * n, 2 * n * n)
            }

            // Packs the first `n` eigenvalue pairs and the first `n * n` eigenvector
            // pairs into complex outputs.
            //
            // Panics if `val_ri` / `vec_ri` hold fewer than `2 * n` / `2 * n * n`
            // entries, or if `values_out` / `vectors_out` are shorter than `n` / `n * n`.
            // Slicing to those extents once, up front, lets the loops run without
            // per-element bounds checks.
            fn eig_ri_to_complex(
                n: usize,
                val_ri: &[Self],
                vec_ri: &[Self],
                values_out: &mut [$complex],
                vectors_out: &mut [$complex],
            ) {
                let vals = &val_ri[..2 * n];
                let vals_out = &mut values_out[..n];
                for (dst, pair) in vals_out.iter_mut().zip(vals.chunks_exact(2)) {
                    *dst = <$complex>::new(pair[0], pair[1]);
                }

                let vecs = &vec_ri[..2 * n * n];
                let vecs_out = &mut vectors_out[..n * n];
                for (dst, pair) in vecs_out.iter_mut().zip(vecs.chunks_exact(2)) {
                    *dst = <$complex>::new(pair[0], pair[1]);
                }
            }
        }
    };
}
230
// Implements `LinalgScalar`, `KernelLinalgScalar`, and `LapackEigScalar` for a complex
// type `$ty` (e.g. `num_complex::Complex<$real>`) whose component type is `$real`.
//
// Requirements: `$ty` provides inherent `norm`, `conj`, `new`, and public `re` / `im`
// fields, and `$real: num_traits::Float`.
macro_rules! impl_complex_linalg_scalar {
    ($ty:ty, $real:ty) => {
        impl LinalgScalar for $ty {
            type Real = $real;
            type Complex = $ty;

            // Complex modulus.
            fn abs_real(&self) -> $real {
                self.norm()
            }

            fn real_epsilon() -> $real {
                <$real as num_traits::Float>::epsilon()
            }

            // Method resolution prefers the inherent `Complex::conj` over this trait
            // method, so this call delegates rather than recursing.
            fn conj(&self) -> $ty {
                self.conj()
            }

            fn from_parts(real: Self::Real, imag: Self::Real) -> Self {
                <$ty>::new(real, imag)
            }

            fn real_part(&self) -> Self::Real {
                self.re
            }

            fn imag_part(&self) -> Self::Real {
                self.im
            }
        }

        impl KernelLinalgScalar for $ty {}

        impl LapackEigScalar for $ty {
            // Complex scratch buffers hold one complex value per output element.
            fn eig_buffer_sizes(n: usize) -> (usize, usize) {
                (n, n * n)
            }

            // Copies the first `n` eigenvalues and the first `n * n` eigenvector
            // entries, which matches the real impl and accepts oversized buffers.
            // Panics if any buffer is shorter than those extents.
            fn eig_ri_to_complex(
                n: usize,
                val_ri: &[Self],
                vec_ri: &[Self],
                values_out: &mut [$ty],
                vectors_out: &mut [$ty],
            ) {
                values_out[..n].copy_from_slice(&val_ri[..n]);
                let entries = n * n;
                vectors_out[..entries].copy_from_slice(&vec_ri[..entries]);
            }
        }
    };
}
282
283impl_real_linalg_scalar!(f64, Complex64);
284impl_real_linalg_scalar!(f32, Complex32);
285impl_complex_linalg_scalar!(Complex64, f64);
286impl_complex_linalg_scalar!(Complex32, f32);
287
288#[derive(Clone)]
297pub struct QrTensorResult<T: LinalgScalar> {
298 pub q: Tensor<T>,
299 pub r: Tensor<T>,
300}
301
302#[derive(Clone)]
311pub struct SvdTensorResult<T: LinalgScalar> {
312 pub u: Tensor<T>,
313 pub s: Tensor<T::Real>,
314 pub vt: Tensor<T>,
315}
316
317#[derive(Clone)]
326pub struct LuTensorResult<T: LinalgScalar> {
327 pub l: Tensor<T>,
328 pub u: Tensor<T>,
329 pub pivots: Tensor<i32>,
330}
331
332#[derive(Clone)]
341pub struct LuTensorExResult<T: LinalgScalar> {
342 pub l: Tensor<T>,
344 pub u: Tensor<T>,
346 pub pivots: Tensor<i32>,
348 pub info: Tensor<i32>,
350}
351
352#[derive(Clone)]
361pub struct SolveTensorExResult<T: LinalgScalar> {
362 pub solution: Tensor<T>,
364 pub info: Tensor<i32>,
366}
367
368#[derive(Clone)]
377pub struct CholeskyTensorExResult<T: LinalgScalar> {
378 pub l: Tensor<T>,
380 pub info: Tensor<i32>,
382}
383
384#[derive(Clone)]
393pub struct EigenTensorResult<T: LinalgScalar> {
394 pub values: Tensor<T::Real>,
395 pub vectors: Tensor<T>,
396}
397
398#[derive(Clone)]
407pub struct EigTensorResult<T: LinalgScalar> {
408 pub values: Tensor<T::Complex>,
409 pub vectors: Tensor<T::Complex>,
410}
411
/// Linear-algebra operations whose availability can be queried through
/// [`TensorLinalgPrims::has_linalg_support`].
// NOTE(review): variants use implicit discriminants in declaration order, and the
// derived `Hash` (and any `as` cast) depends on them. Append new variants rather
// than reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinalgCapabilityOp {
    Solve,
    SolveTriangular,
    Qr,
    ThinSvd,
    LuFactor,
    Cholesky,
    EigenSym,
    Eig,
    LuSolve,
    // No method of this name exists on `TensorLinalgPrims` in this file.
    Lstsq,
    LuFactorEx,
    CholeskyEx,
    SolveEx,
    // The remaining variants likewise have no same-named method on
    // `TensorLinalgPrims` in this file.
    Inv,
    Det,
    Slogdet,
    Pinv,
    MatrixExp,
    MatrixPower,
    Cross,
    HouseholderProduct,
    Vander,
    TensorInv,
    TensorSolve,
    Norm,
}
450
451pub trait TensorLinalgPrims<T: KernelLinalgScalar> {
452 type Context;
453
454 fn has_linalg_support(op: LinalgCapabilityOp) -> bool;
455
456 fn solve_ex(
467 ctx: &mut Self::Context,
468 a: &Tensor<T>,
469 b: &Tensor<T>,
470 ) -> Result<SolveTensorExResult<T>>;
471
472 fn solve(ctx: &mut Self::Context, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>;
473 fn lu_solve(
474 ctx: &mut Self::Context,
475 factors: &Tensor<T>,
476 pivots: &Tensor<i32>,
477 b: &Tensor<T>,
478 ) -> Result<Tensor<T>>;
479 fn solve_triangular(
480 ctx: &mut Self::Context,
481 a: &Tensor<T>,
482 b: &Tensor<T>,
483 upper: bool,
484 ) -> Result<Tensor<T>>;
485 fn qr(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<QrTensorResult<T>>;
486 fn thin_svd(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<SvdTensorResult<T>>;
487 fn svdvals(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T::Real>>;
488 fn lu_factor_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorExResult<T>>;
499 fn lu_factor(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>>;
500 fn lu_factor_no_pivot(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>>;
513 fn cholesky_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<CholeskyTensorExResult<T>>;
524 fn cholesky(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T>>;
525 fn eigen_sym(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigenTensorResult<T>>;
526 fn eig(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigTensorResult<T>>;
527}
528
529#[cfg(test)]
530mod tests;