1use std::any::TypeId;
7
8use num_complex::{Complex32, Complex64};
9
10use super::TensorLinalgContextFor;
11use crate::{
12 CholeskyTensorExResult, EigTensorResult, EigenTensorResult, KernelLinalgScalar,
13 LapackEigScalar, LinalgCapabilityOp, LuTensorExResult, LuTensorResult, QrTensorResult,
14 SolveTensorExResult, SvdTensorResult, TensorLinalgPrims,
15};
16use tenferro_device::Result;
17use tenferro_tensor::Tensor;
18
19#[cfg(feature = "linalg-faer")]
20use super::cpu_faer as cpu_impl;
21#[cfg(feature = "linalg-lapack")]
22use super::cpu_lapack as cpu_impl;
23
24#[cfg(feature = "linalg-faer")]
25type SelectedCpuSliceBackend = super::faer_backend::FaerBackend;
26#[cfg(feature = "linalg-lapack")]
27type SelectedCpuSliceBackend = super::cpu_lapack::LapackBackend;
28
29mod private {
30 use num_complex::{Complex32, Complex64};
31
32 use super::{super::LinalgBackend, SelectedCpuSliceBackend};
33 use crate::{KernelLinalgScalar, LapackEigScalar};
34 use tenferro_device::Result;
35
36 pub trait CpuLinalgOps: KernelLinalgScalar + LapackEigScalar {
37 fn solve_slices(
38 a: &[Self],
39 b: &[Self],
40 n: usize,
41 nrhs: usize,
42 x: &mut [Self],
43 ) -> Result<()>;
44 fn solve_triangular_slices(
45 a: &[Self],
46 b: &[Self],
47 n: usize,
48 nrhs: usize,
49 upper: bool,
50 x: &mut [Self],
51 ) -> Result<()>;
52 fn thin_svd_slices(
53 a: &[Self],
54 m: usize,
55 n: usize,
56 u: &mut [Self],
57 s: &mut [Self::Real],
58 vt: &mut [Self],
59 ) -> Result<()>;
60 fn qr_slices(a: &[Self], m: usize, n: usize, q: &mut [Self], r: &mut [Self]) -> Result<()>;
61 fn lu_slices(
62 a: &[Self],
63 m: usize,
64 n: usize,
65 perm: &mut [usize],
66 l: &mut [Self],
67 u_out: &mut [Self],
68 ) -> Result<()>;
69 fn cholesky_slices(a: &[Self], n: usize, l: &mut [Self]) -> Result<()>;
70 fn eigen_sym_slices(
71 a: &[Self],
72 n: usize,
73 values: &mut [Self::Real],
74 vectors: &mut [Self],
75 ) -> Result<()>;
76 fn eig_slices(
77 a: &[Self],
78 n: usize,
79 values_ri: &mut [Self],
80 vectors_ri: &mut [Self],
81 ) -> Result<()>;
82 }
83
84 macro_rules! impl_cpu_linalg_ops {
85 ($ty:ty) => {
86 impl CpuLinalgOps for $ty {
87 fn solve_slices(
88 a: &[Self],
89 b: &[Self],
90 n: usize,
91 nrhs: usize,
92 x: &mut [Self],
93 ) -> Result<()> {
94 SelectedCpuSliceBackend::new().solve(a, b, n, nrhs, x)
95 }
96
97 fn solve_triangular_slices(
98 a: &[Self],
99 b: &[Self],
100 n: usize,
101 nrhs: usize,
102 upper: bool,
103 x: &mut [Self],
104 ) -> Result<()> {
105 SelectedCpuSliceBackend::new().solve_triangular(a, b, n, nrhs, upper, x)
106 }
107
108 fn thin_svd_slices(
109 a: &[Self],
110 m: usize,
111 n: usize,
112 u: &mut [Self],
113 s: &mut [Self::Real],
114 vt: &mut [Self],
115 ) -> Result<()> {
116 SelectedCpuSliceBackend::new().thin_svd(a, m, n, u, s, vt)
117 }
118
119 fn qr_slices(
120 a: &[Self],
121 m: usize,
122 n: usize,
123 q: &mut [Self],
124 r: &mut [Self],
125 ) -> Result<()> {
126 SelectedCpuSliceBackend::new().qr(a, m, n, q, r)
127 }
128
129 fn lu_slices(
130 a: &[Self],
131 m: usize,
132 n: usize,
133 perm: &mut [usize],
134 l: &mut [Self],
135 u_out: &mut [Self],
136 ) -> Result<()> {
137 SelectedCpuSliceBackend::new().lu(a, m, n, perm, l, u_out)
138 }
139
140 fn cholesky_slices(a: &[Self], n: usize, l: &mut [Self]) -> Result<()> {
141 SelectedCpuSliceBackend::new().cholesky(a, n, l)
142 }
143
144 fn eigen_sym_slices(
145 a: &[Self],
146 n: usize,
147 values: &mut [Self::Real],
148 vectors: &mut [Self],
149 ) -> Result<()> {
150 SelectedCpuSliceBackend::new().eigen_sym(a, n, values, vectors)
151 }
152
153 fn eig_slices(
154 a: &[Self],
155 n: usize,
156 values_ri: &mut [Self],
157 vectors_ri: &mut [Self],
158 ) -> Result<()> {
159 SelectedCpuSliceBackend::new().eig_general(a, n, values_ri, vectors_ri)
160 }
161 }
162 };
163 }
164
165 impl_cpu_linalg_ops!(f64);
166 impl_cpu_linalg_ops!(f32);
167 impl_cpu_linalg_ops!(Complex64);
168 impl_cpu_linalg_ops!(Complex32);
169}
170
/// Reinterprets a shared slice `&[$from]` as `&[$to]` without copying.
///
/// Meant to be used inside an arm of `dispatch_kernel_linalg_scalar_type!`,
/// where the `TypeId` comparison has already established that `$from` and
/// `$to` are the same type. The element count is carried over unchanged by
/// the slice-pointer cast.
///
/// Debug builds additionally check that both element types have the same size
/// and alignment, so an accidental use with mismatched types is caught before
/// the reinterpretation happens.
macro_rules! cast_slice {
    ($slice:expr, $from:ty, $to:ty) => {{
        debug_assert_eq!(
            ::std::mem::size_of::<$from>(),
            ::std::mem::size_of::<$to>(),
            "cast_slice!: element sizes differ"
        );
        debug_assert_eq!(
            ::std::mem::align_of::<$from>(),
            ::std::mem::align_of::<$to>(),
            "cast_slice!: element alignments differ"
        );
        // SAFETY: callers only invoke this when `$from` and `$to` are the same
        // type (established by a `TypeId` match), so pointer, length, layout
        // and validity of every element are unchanged by the cast.
        unsafe { &*($slice as *const [$from] as *const [$to]) }
    }};
}
176
/// Reinterprets an exclusive slice `&mut [$from]` as `&mut [$to]` without
/// copying.
///
/// Meant to be used inside an arm of `dispatch_kernel_linalg_scalar_type!`,
/// where the `TypeId` comparison has already established that `$from` and
/// `$to` are the same type. The element count is carried over unchanged by
/// the slice-pointer cast.
///
/// Debug builds additionally check that both element types have the same size
/// and alignment, so an accidental use with mismatched types is caught before
/// the reinterpretation happens.
macro_rules! cast_slice_mut {
    ($slice:expr, $from:ty, $to:ty) => {{
        debug_assert_eq!(
            ::std::mem::size_of::<$from>(),
            ::std::mem::size_of::<$to>(),
            "cast_slice_mut!: element sizes differ"
        );
        debug_assert_eq!(
            ::std::mem::align_of::<$from>(),
            ::std::mem::align_of::<$to>(),
            "cast_slice_mut!: element alignments differ"
        );
        // SAFETY: callers only invoke this when `$from` and `$to` are the same
        // type (established by a `TypeId` match), so pointer, length, layout
        // and validity of every element are unchanged by the cast. The source
        // is an exclusive borrow, so no other reference aliases the result.
        unsafe { &mut *($slice as *mut [$from] as *mut [$to]) }
    }};
}
182
/// Runs `$body` with the type alias `$concrete` bound to whichever of `f64`,
/// `f32`, `Complex64` or `Complex32` has the same `TypeId` as `$generic`.
///
/// `$body` is expanded once per candidate type, so it must type-check for all
/// four; only the arm whose `TypeId` matches is executed. If none matches the
/// expansion panics via `unreachable!`.
macro_rules! dispatch_kernel_linalg_scalar_type {
    // Internal: test the head of the candidate list, recurse on the tail.
    (@probe $requested:ident, $concrete:ident, $body:block; $head:ty $(, $tail:ty)*) => {
        if $requested == TypeId::of::<$head>() {
            type $concrete = $head;
            $body
        } else {
            dispatch_kernel_linalg_scalar_type!(@probe $requested, $concrete, $body; $($tail),*)
        }
    };
    // Internal: candidate list exhausted without a match.
    (@probe $requested:ident, $concrete:ident, $body:block;) => {
        unreachable!("KernelLinalgScalar must be one of the standard linalg dtypes")
    };
    // Public entry point.
    ($generic:ty, $concrete:ident, $body:block) => {{
        let requested = TypeId::of::<$generic>();
        dispatch_kernel_linalg_scalar_type!(
            @probe requested, $concrete, $body; f64, f32, Complex64, Complex32
        )
    }};
}
203
204pub(crate) fn solve_slices<T: KernelLinalgScalar>(
205 a: &[T],
206 b: &[T],
207 n: usize,
208 nrhs: usize,
209 x: &mut [T],
210) -> Result<()> {
211 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
212 <Concrete as private::CpuLinalgOps>::solve_slices(
213 cast_slice!(a, T, Concrete),
214 cast_slice!(b, T, Concrete),
215 n,
216 nrhs,
217 cast_slice_mut!(x, T, Concrete),
218 )
219 })
220}
221
222pub(crate) fn solve_triangular_slices<T: KernelLinalgScalar>(
223 a: &[T],
224 b: &[T],
225 n: usize,
226 nrhs: usize,
227 upper: bool,
228 x: &mut [T],
229) -> Result<()> {
230 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
231 <Concrete as private::CpuLinalgOps>::solve_triangular_slices(
232 cast_slice!(a, T, Concrete),
233 cast_slice!(b, T, Concrete),
234 n,
235 nrhs,
236 upper,
237 cast_slice_mut!(x, T, Concrete),
238 )
239 })
240}
241
242pub(crate) fn thin_svd_slices<T: KernelLinalgScalar>(
243 a: &[T],
244 m: usize,
245 n: usize,
246 u: &mut [T],
247 s: &mut [T::Real],
248 vt: &mut [T],
249) -> Result<()> {
250 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
251 type ConcreteReal = <Concrete as crate::LinalgScalar>::Real;
252 <Concrete as private::CpuLinalgOps>::thin_svd_slices(
253 cast_slice!(a, T, Concrete),
254 m,
255 n,
256 cast_slice_mut!(u, T, Concrete),
257 cast_slice_mut!(s, T::Real, ConcreteReal),
258 cast_slice_mut!(vt, T, Concrete),
259 )
260 })
261}
262
263pub(crate) fn qr_slices<T: KernelLinalgScalar>(
264 a: &[T],
265 m: usize,
266 n: usize,
267 q: &mut [T],
268 r: &mut [T],
269) -> Result<()> {
270 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
271 <Concrete as private::CpuLinalgOps>::qr_slices(
272 cast_slice!(a, T, Concrete),
273 m,
274 n,
275 cast_slice_mut!(q, T, Concrete),
276 cast_slice_mut!(r, T, Concrete),
277 )
278 })
279}
280
281pub(crate) fn lu_slices<T: KernelLinalgScalar>(
282 a: &[T],
283 m: usize,
284 n: usize,
285 perm: &mut [usize],
286 l: &mut [T],
287 u_out: &mut [T],
288) -> Result<()> {
289 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
290 <Concrete as private::CpuLinalgOps>::lu_slices(
291 cast_slice!(a, T, Concrete),
292 m,
293 n,
294 perm,
295 cast_slice_mut!(l, T, Concrete),
296 cast_slice_mut!(u_out, T, Concrete),
297 )
298 })
299}
300
301pub(crate) fn cholesky_slices<T: KernelLinalgScalar>(a: &[T], n: usize, l: &mut [T]) -> Result<()> {
302 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
303 <Concrete as private::CpuLinalgOps>::cholesky_slices(
304 cast_slice!(a, T, Concrete),
305 n,
306 cast_slice_mut!(l, T, Concrete),
307 )
308 })
309}
310
311pub(crate) fn eigen_sym_slices<T: KernelLinalgScalar>(
312 a: &[T],
313 n: usize,
314 values: &mut [T::Real],
315 vectors: &mut [T],
316) -> Result<()> {
317 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
318 type ConcreteReal = <Concrete as crate::LinalgScalar>::Real;
319 <Concrete as private::CpuLinalgOps>::eigen_sym_slices(
320 cast_slice!(a, T, Concrete),
321 n,
322 cast_slice_mut!(values, T::Real, ConcreteReal),
323 cast_slice_mut!(vectors, T, Concrete),
324 )
325 })
326}
327
328pub(crate) fn eig_slices<T: KernelLinalgScalar>(
329 a: &[T],
330 n: usize,
331 values_ri: &mut [T],
332 vectors_ri: &mut [T],
333) -> Result<()> {
334 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
335 <Concrete as private::CpuLinalgOps>::eig_slices(
336 cast_slice!(a, T, Concrete),
337 n,
338 cast_slice_mut!(values_ri, T, Concrete),
339 cast_slice_mut!(vectors_ri, T, Concrete),
340 )
341 })
342}
343
344pub(crate) fn eig_buffer_sizes<T: KernelLinalgScalar>(n: usize) -> (usize, usize) {
345 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
346 <Concrete as LapackEigScalar>::eig_buffer_sizes(n)
347 })
348}
349
350pub(crate) fn eig_ri_to_complex<T: KernelLinalgScalar>(
351 n: usize,
352 val_ri: &[T],
353 vec_ri: &[T],
354 values_out: &mut [T::Complex],
355 vectors_out: &mut [T::Complex],
356) {
357 dispatch_kernel_linalg_scalar_type!(T, Concrete, {
358 type ConcreteComplex = <Concrete as crate::LinalgScalar>::Complex;
359 <Concrete as LapackEigScalar>::eig_ri_to_complex(
360 n,
361 cast_slice!(val_ri, T, Concrete),
362 cast_slice!(vec_ri, T, Concrete),
363 cast_slice_mut!(values_out, T::Complex, ConcreteComplex),
364 cast_slice_mut!(vectors_out, T::Complex, ConcreteComplex),
365 )
366 })
367}
368
/// Stateless marker type that names the CPU tensor linalg backend.
///
/// It has no fields; the tensor-level operations are attached to it through
/// its `TensorLinalgPrims` implementation.
#[derive(Clone, Copy, Debug, Default)]
pub struct CpuTensorLinalgBackend;
378
379impl<T> TensorLinalgPrims<T> for CpuTensorLinalgBackend
380where
381 T: KernelLinalgScalar,
382{
383 type Context = tenferro_prims::CpuContext;
384
385 fn has_linalg_support(_op: LinalgCapabilityOp) -> bool {
386 true
387 }
388
389 fn solve_ex(
390 ctx: &mut Self::Context,
391 a: &Tensor<T>,
392 b: &Tensor<T>,
393 ) -> Result<SolveTensorExResult<T>> {
394 super::cpu_tensor_impl::solve_ex(ctx, a, b)
395 }
396
397 fn solve(ctx: &mut Self::Context, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>> {
398 cpu_impl::solve(ctx, a, b)
399 }
400
401 fn lu_solve(
402 ctx: &mut Self::Context,
403 factors: &Tensor<T>,
404 pivots: &Tensor<i32>,
405 b: &Tensor<T>,
406 ) -> Result<Tensor<T>> {
407 super::cpu_tensor_impl::lu_solve(ctx, factors, pivots, b)
408 }
409
410 fn solve_triangular(
411 ctx: &mut Self::Context,
412 a: &Tensor<T>,
413 b: &Tensor<T>,
414 upper: bool,
415 ) -> Result<Tensor<T>> {
416 cpu_impl::solve_triangular(ctx, a, b, upper)
417 }
418
419 fn qr(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<QrTensorResult<T>> {
420 cpu_impl::qr(ctx, a)
421 }
422
423 fn thin_svd(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<SvdTensorResult<T>> {
424 cpu_impl::thin_svd(ctx, a)
425 }
426
427 fn svdvals(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T::Real>> {
428 Ok(cpu_impl::thin_svd(ctx, a)?.s)
429 }
430
431 fn lu_factor_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorExResult<T>> {
432 super::cpu_tensor_impl::lu_factor_ex(ctx, a)
433 }
434
435 fn lu_factor(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>> {
436 cpu_impl::lu_factor(ctx, a)
437 }
438
439 fn lu_factor_no_pivot(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<LuTensorResult<T>> {
440 super::cpu_tensor_impl::lu_factor_no_pivot(ctx, a)
441 }
442
443 fn cholesky_ex(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<CholeskyTensorExResult<T>> {
444 super::cpu_tensor_impl::cholesky_ex(ctx, a)
445 }
446
447 fn cholesky(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T>> {
448 cpu_impl::cholesky(ctx, a)
449 }
450
451 fn eigen_sym(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigenTensorResult<T>> {
452 cpu_impl::eigen_sym(ctx, a)
453 }
454
455 fn eig(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<EigTensorResult<T>> {
456 cpu_impl::eig(ctx, a)
457 }
458}
459
460impl<T> TensorLinalgContextFor<T> for tenferro_prims::CpuContext
461where
462 T: KernelLinalgScalar,
463{
464 type Backend = CpuTensorLinalgBackend;
465}