// tenferro_linalg_prims/backend/cpu.rs

//! CPU tensor linalg backend.
//!
//! The actual provider implementation is selected at compile time via
//! the `linalg-faer` or `linalg-lapack` feature.
5
6use std::any::TypeId;
7
8use num_complex::{Complex32, Complex64};
9
10use super::TensorLinalgContextFor;
11use crate::{
12    CholeskyTensorExResult, EigTensorResult, EigenTensorResult, KernelLinalgScalar,
13    LapackEigScalar, LinalgCapabilityOp, LuTensorExResult, LuTensorResult, QrTensorResult,
14    SolveTensorExResult, SvdTensorResult, TensorLinalgPrims,
15};
16use tenferro_device::Result;
17use tenferro_tensor::Tensor;
18
19#[cfg(feature = "linalg-faer")]
20use super::cpu_faer as cpu_impl;
21#[cfg(feature = "linalg-lapack")]
22use super::cpu_lapack as cpu_impl;
23
// Slice-level backend type that `private::CpuLinalgOps` instantiates with
// `SelectedCpuSliceBackend::new()` for every kernel call.
//
// As far as this file shows, exactly one of the two features is expected to be
// enabled: with neither, the alias does not exist; with both, it is defined
// twice. Either case is a compile error here.
#[cfg(feature = "linalg-faer")]
type SelectedCpuSliceBackend = super::faer_backend::FaerBackend;
#[cfg(feature = "linalg-lapack")]
type SelectedCpuSliceBackend = super::cpu_lapack::LapackBackend;
28
29mod private {
30    use num_complex::{Complex32, Complex64};
31
32    use super::{super::LinalgBackend, SelectedCpuSliceBackend};
33    use crate::{KernelLinalgScalar, LapackEigScalar};
34    use tenferro_device::Result;
35
36    pub trait CpuLinalgOps: KernelLinalgScalar + LapackEigScalar {
37        fn solve_slices(
38            a: &[Self],
39            b: &[Self],
40            n: usize,
41            nrhs: usize,
42            x: &mut [Self],
43        ) -> Result<()>;
44        fn solve_triangular_slices(
45            a: &[Self],
46            b: &[Self],
47            n: usize,
48            nrhs: usize,
49            upper: bool,
50            x: &mut [Self],
51        ) -> Result<()>;
52        fn thin_svd_slices(
53            a: &[Self],
54            m: usize,
55            n: usize,
56            u: &mut [Self],
57            s: &mut [Self::Real],
58            vt: &mut [Self],
59        ) -> Result<()>;
60        fn qr_slices(a: &[Self], m: usize, n: usize, q: &mut [Self], r: &mut [Self]) -> Result<()>;
61        fn lu_slices(
62            a: &[Self],
63            m: usize,
64            n: usize,
65            perm: &mut [usize],
66            l: &mut [Self],
67            u_out: &mut [Self],
68        ) -> Result<()>;
69        fn cholesky_slices(a: &[Self], n: usize, l: &mut [Self]) -> Result<()>;
70        fn eigen_sym_slices(
71            a: &[Self],
72            n: usize,
73            values: &mut [Self::Real],
74            vectors: &mut [Self],
75        ) -> Result<()>;
76        fn eig_slices(
77            a: &[Self],
78            n: usize,
79            values_ri: &mut [Self],
80            vectors_ri: &mut [Self],
81        ) -> Result<()>;
82    }
83
84    macro_rules! impl_cpu_linalg_ops {
85        ($ty:ty) => {
86            impl CpuLinalgOps for $ty {
87                fn solve_slices(
88                    a: &[Self],
89                    b: &[Self],
90                    n: usize,
91                    nrhs: usize,
92                    x: &mut [Self],
93                ) -> Result<()> {
94                    SelectedCpuSliceBackend::new().solve(a, b, n, nrhs, x)
95                }
96
97                fn solve_triangular_slices(
98                    a: &[Self],
99                    b: &[Self],
100                    n: usize,
101                    nrhs: usize,
102                    upper: bool,
103                    x: &mut [Self],
104                ) -> Result<()> {
105                    SelectedCpuSliceBackend::new().solve_triangular(a, b, n, nrhs, upper, x)
106                }
107
108                fn thin_svd_slices(
109                    a: &[Self],
110                    m: usize,
111                    n: usize,
112                    u: &mut [Self],
113                    s: &mut [Self::Real],
114                    vt: &mut [Self],
115                ) -> Result<()> {
116                    SelectedCpuSliceBackend::new().thin_svd(a, m, n, u, s, vt)
117                }
118
119                fn qr_slices(
120                    a: &[Self],
121                    m: usize,
122                    n: usize,
123                    q: &mut [Self],
124                    r: &mut [Self],
125                ) -> Result<()> {
126                    SelectedCpuSliceBackend::new().qr(a, m, n, q, r)
127                }
128
129                fn lu_slices(
130                    a: &[Self],
131                    m: usize,
132                    n: usize,
133                    perm: &mut [usize],
134                    l: &mut [Self],
135                    u_out: &mut [Self],
136                ) -> Result<()> {
137                    SelectedCpuSliceBackend::new().lu(a, m, n, perm, l, u_out)
138                }
139
140                fn cholesky_slices(a: &[Self], n: usize, l: &mut [Self]) -> Result<()> {
141                    SelectedCpuSliceBackend::new().cholesky(a, n, l)
142                }
143
144                fn eigen_sym_slices(
145                    a: &[Self],
146                    n: usize,
147                    values: &mut [Self::Real],
148                    vectors: &mut [Self],
149                ) -> Result<()> {
150                    SelectedCpuSliceBackend::new().eigen_sym(a, n, values, vectors)
151                }
152
153                fn eig_slices(
154                    a: &[Self],
155                    n: usize,
156                    values_ri: &mut [Self],
157                    vectors_ri: &mut [Self],
158                ) -> Result<()> {
159                    SelectedCpuSliceBackend::new().eig_general(a, n, values_ri, vectors_ri)
160                }
161            }
162        };
163    }
164
165    impl_cpu_linalg_ops!(f64);
166    impl_cpu_linalg_ops!(f32);
167    impl_cpu_linalg_ops!(Complex64);
168    impl_cpu_linalg_ops!(Complex32);
169}
170
/// Reinterprets a shared slice `&[$from]` as `&[$to]` without copying.
///
/// `$slice` is evaluated exactly once. The result keeps the element count of
/// the input and has an unconstrained lifetime, so it must not be held longer
/// than the borrow behind `$slice`.
///
/// This is only sound when `$from` and `$to` are the same type. In debug builds
/// the macro asserts that both types at least agree in size and alignment.
macro_rules! cast_slice {
    ($slice:expr, $from:ty, $to:ty) => {{
        // Debug-only tripwire: a size or alignment mismatch would make the
        // reinterpreted slice cover the wrong bytes.
        debug_assert_eq!(
            ::std::mem::size_of::<$from>(),
            ::std::mem::size_of::<$to>(),
            "cast_slice!: element size mismatch"
        );
        debug_assert_eq!(
            ::std::mem::align_of::<$from>(),
            ::std::mem::align_of::<$to>(),
            "cast_slice!: element alignment mismatch"
        );
        // SAFETY: sound only when `$from` and `$to` are the same type; every call
        // site visible in this file sits behind a `TypeId` equality check in
        // `dispatch_kernel_linalg_scalar_type!`. The pointer is derived from a
        // live reference, so it is non-null, aligned, and valid for the whole
        // slice length.
        unsafe { &*($slice as *const [$from] as *const [$to]) }
    }};
}
176
/// Reinterprets an exclusive slice `&mut [$from]` as `&mut [$to]` without copying.
///
/// `$slice` is evaluated exactly once. The result keeps the element count of
/// the input and has an unconstrained lifetime, so it must not be held longer
/// than the borrow behind `$slice`, nor used alongside that original borrow.
///
/// This is only sound when `$from` and `$to` are the same type. In debug builds
/// the macro asserts that both types at least agree in size and alignment.
macro_rules! cast_slice_mut {
    ($slice:expr, $from:ty, $to:ty) => {{
        // Debug-only tripwire: a size or alignment mismatch would make the
        // reinterpreted slice cover the wrong bytes.
        debug_assert_eq!(
            ::std::mem::size_of::<$from>(),
            ::std::mem::size_of::<$to>(),
            "cast_slice_mut!: element size mismatch"
        );
        debug_assert_eq!(
            ::std::mem::align_of::<$from>(),
            ::std::mem::align_of::<$to>(),
            "cast_slice_mut!: element alignment mismatch"
        );
        // SAFETY: sound only when `$from` and `$to` are the same type; every call
        // site visible in this file sits behind a `TypeId` equality check in
        // `dispatch_kernel_linalg_scalar_type!`. The pointer is derived from a
        // live exclusive reference, so it is non-null, aligned, valid for the
        // whole slice length, and not aliased while the result is in use.
        unsafe { &mut *($slice as *mut [$from] as *mut [$to]) }
    }};
}
182
/// Evaluates `$body` with the type alias `$concrete` bound to whichever
/// concrete scalar type `$generic` turns out to be at run time.
///
/// `$body` is instantiated once per candidate (`f64`, `f32`, `Complex64`,
/// `Complex32`), each copy guarded by a `TypeId` equality test. All copies must
/// therefore type-check, although only the matching one ever executes.
///
/// # Panics
///
/// Reaches `unreachable!` when `$generic` is none of the candidate types.
macro_rules! dispatch_kernel_linalg_scalar_type {
    // Internal rule: unroll one `if … else` rung per candidate type and close
    // the ladder with the `unreachable!` fallback.
    (@ladder $generic:ty, $concrete:ident, $body:block, [$($candidate:ty),+ $(,)?]) => {{
        let tid = TypeId::of::<$generic>();
        $(
            if tid == TypeId::of::<$candidate>() {
                type $concrete = $candidate;
                $body
            } else
        )+
        {
            unreachable!("KernelLinalgScalar must be one of the standard linalg dtypes")
        }
    }};
    // Entry point: supplies the fixed candidate list, in test order.
    ($generic:ty, $concrete:ident, $body:block) => {
        dispatch_kernel_linalg_scalar_type!(
            @ladder $generic, $concrete, $body, [f64, f32, Complex64, Complex32]
        )
    };
}
203
204pub(crate) fn solve_slices<T: KernelLinalgScalar>(
205    a: &[T],
206    b: &[T],
207    n: usize,
208    nrhs: usize,
209    x: &mut [T],
210) -> Result<()> {
211    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
212        <Concrete as private::CpuLinalgOps>::solve_slices(
213            cast_slice!(a, T, Concrete),
214            cast_slice!(b, T, Concrete),
215            n,
216            nrhs,
217            cast_slice_mut!(x, T, Concrete),
218        )
219    })
220}
221
222pub(crate) fn solve_triangular_slices<T: KernelLinalgScalar>(
223    a: &[T],
224    b: &[T],
225    n: usize,
226    nrhs: usize,
227    upper: bool,
228    x: &mut [T],
229) -> Result<()> {
230    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
231        <Concrete as private::CpuLinalgOps>::solve_triangular_slices(
232            cast_slice!(a, T, Concrete),
233            cast_slice!(b, T, Concrete),
234            n,
235            nrhs,
236            upper,
237            cast_slice_mut!(x, T, Concrete),
238        )
239    })
240}
241
242pub(crate) fn thin_svd_slices<T: KernelLinalgScalar>(
243    a: &[T],
244    m: usize,
245    n: usize,
246    u: &mut [T],
247    s: &mut [T::Real],
248    vt: &mut [T],
249) -> Result<()> {
250    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
251        type ConcreteReal = <Concrete as crate::LinalgScalar>::Real;
252        <Concrete as private::CpuLinalgOps>::thin_svd_slices(
253            cast_slice!(a, T, Concrete),
254            m,
255            n,
256            cast_slice_mut!(u, T, Concrete),
257            cast_slice_mut!(s, T::Real, ConcreteReal),
258            cast_slice_mut!(vt, T, Concrete),
259        )
260    })
261}
262
263pub(crate) fn qr_slices<T: KernelLinalgScalar>(
264    a: &[T],
265    m: usize,
266    n: usize,
267    q: &mut [T],
268    r: &mut [T],
269) -> Result<()> {
270    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
271        <Concrete as private::CpuLinalgOps>::qr_slices(
272            cast_slice!(a, T, Concrete),
273            m,
274            n,
275            cast_slice_mut!(q, T, Concrete),
276            cast_slice_mut!(r, T, Concrete),
277        )
278    })
279}
280
281pub(crate) fn lu_slices<T: KernelLinalgScalar>(
282    a: &[T],
283    m: usize,
284    n: usize,
285    perm: &mut [usize],
286    l: &mut [T],
287    u_out: &mut [T],
288) -> Result<()> {
289    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
290        <Concrete as private::CpuLinalgOps>::lu_slices(
291            cast_slice!(a, T, Concrete),
292            m,
293            n,
294            perm,
295            cast_slice_mut!(l, T, Concrete),
296            cast_slice_mut!(u_out, T, Concrete),
297        )
298    })
299}
300
301pub(crate) fn cholesky_slices<T: KernelLinalgScalar>(a: &[T], n: usize, l: &mut [T]) -> Result<()> {
302    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
303        <Concrete as private::CpuLinalgOps>::cholesky_slices(
304            cast_slice!(a, T, Concrete),
305            n,
306            cast_slice_mut!(l, T, Concrete),
307        )
308    })
309}
310
311pub(crate) fn eigen_sym_slices<T: KernelLinalgScalar>(
312    a: &[T],
313    n: usize,
314    values: &mut [T::Real],
315    vectors: &mut [T],
316) -> Result<()> {
317    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
318        type ConcreteReal = <Concrete as crate::LinalgScalar>::Real;
319        <Concrete as private::CpuLinalgOps>::eigen_sym_slices(
320            cast_slice!(a, T, Concrete),
321            n,
322            cast_slice_mut!(values, T::Real, ConcreteReal),
323            cast_slice_mut!(vectors, T, Concrete),
324        )
325    })
326}
327
328pub(crate) fn eig_slices<T: KernelLinalgScalar>(
329    a: &[T],
330    n: usize,
331    values_ri: &mut [T],
332    vectors_ri: &mut [T],
333) -> Result<()> {
334    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
335        <Concrete as private::CpuLinalgOps>::eig_slices(
336            cast_slice!(a, T, Concrete),
337            n,
338            cast_slice_mut!(values_ri, T, Concrete),
339            cast_slice_mut!(vectors_ri, T, Concrete),
340        )
341    })
342}
343
344pub(crate) fn eig_buffer_sizes<T: KernelLinalgScalar>(n: usize) -> (usize, usize) {
345    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
346        <Concrete as LapackEigScalar>::eig_buffer_sizes(n)
347    })
348}
349
350pub(crate) fn eig_ri_to_complex<T: KernelLinalgScalar>(
351    n: usize,
352    val_ri: &[T],
353    vec_ri: &[T],
354    values_out: &mut [T::Complex],
355    vectors_out: &mut [T::Complex],
356) {
357    dispatch_kernel_linalg_scalar_type!(T, Concrete, {
358        type ConcreteComplex = <Concrete as crate::LinalgScalar>::Complex;
359        <Concrete as LapackEigScalar>::eig_ri_to_complex(
360            n,
361            cast_slice!(val_ri, T, Concrete),
362            cast_slice!(vec_ri, T, Concrete),
363            cast_slice_mut!(values_out, T::Complex, ConcreteComplex),
364            cast_slice_mut!(vectors_out, T::Complex, ConcreteComplex),
365        )
366    })
367}
368
/// Zero-sized handle that selects the CPU implementation of [`TensorLinalgPrims`].
///
/// The type has no fields; it only exists to carry trait implementations.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::backend::CpuTensorLinalgBackend;
///
/// let _backend = CpuTensorLinalgBackend::default();
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct CpuTensorLinalgBackend;
378
379impl<T> TensorLinalgPrims<T> for CpuTensorLinalgBackend
380where
381    T: KernelLinalgScalar,
382{
383    type Context = tenferro_prims::CpuContext;
384
385    fn has_linalg_support(_op: LinalgCapabilityOp) -> bool {
386        true
387    }
388
389    fn solve_ex(
390        ctx: &mut Self::Context,
391        a: &Tensor<T>,
392        b: &Tensor<T>,
393    ) -> Result<SolveTensorExResult<T>> {
394        super::cpu_tensor_impl::solve_ex(ctx, a, b)
395    }
396
397    fn solve(ctx: &mut Self::Context, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>> {
398        cpu_impl::solve(ctx, a, b)
399    }
400
401    fn lu_solve(
402        ctx: &mut Self::Context,
403        factors: &Tensor<T>,
404        pivots: &Tensor<i32>,
405        b: &Tensor<T>,
406    ) -> Result<Tensor<T>> {
407        super::cpu_tensor_impl::lu_solve(ctx, factors, pivots, b)
408    }
409
410    fn solve_triangular(
411        ctx: &mut Self::Context,
412        a: &Tensor<T>,
413        b: &Tensor<T>,
414        upper: bool,
415    ) -> Result<Tensor<T>> {
416        cpu_impl::solve_triangular(ctx, a, b, upper)
417    }
418
419    fn qr(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<QrTensorResult<T>> {
420        cpu_impl::qr(ctx, a)
421    }
422
423    fn thin_svd(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<SvdTensorResult<T>> {
424        cpu_impl::thin_svd(ctx, a)
425    }
426
427    fn svdvals(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T::Real>> {
428        Ok(cpu_impl::thin_svd(ctx, a)?.s)
429    }
430
431    fn lu_factor_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorExResult<T>> {
432        super::cpu_tensor_impl::lu_factor_ex(ctx, a)
433    }
434
435    fn lu_factor(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>> {
436        cpu_impl::lu_factor(ctx, a)
437    }
438
439    fn lu_factor_no_pivot(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>> {
440        super::cpu_tensor_impl::lu_factor_no_pivot(ctx, a)
441    }
442
443    fn cholesky_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<CholeskyTensorExResult<T>> {
444        super::cpu_tensor_impl::cholesky_ex(ctx, a)
445    }
446
447    fn cholesky(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T>> {
448        cpu_impl::cholesky(ctx, a)
449    }
450
451    fn eigen_sym(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigenTensorResult<T>> {
452        cpu_impl::eigen_sym(ctx, a)
453    }
454
455    fn eig(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigTensorResult<T>> {
456        cpu_impl::eig(ctx, a)
457    }
458}
459
460impl<T> TensorLinalgContextFor<T> for tenferro_prims::CpuContext
461where
462    T: KernelLinalgScalar,
463{
464    type Backend = CpuTensorLinalgBackend;
465}