tenferro_linalg_prims/backend/cuda.rs

1mod cholesky;
2mod lu;
3mod qr;
4mod runtime;
5mod scalar_type;
6mod solve;
7mod solve_triangular;
8mod svdvals;
9mod thin_svd;
10#[cfg(any(feature = "cuda", test))]
11mod wrappers;
12
13use num_complex::{Complex32, Complex64};
14use tenferro_algebra::Standard;
15use tenferro_device::Result;
16use tenferro_prims::{
17    AnalyticPrimsDescriptor, AnalyticUnaryOp, ComplexRealPrimsDescriptor, ComplexRealUnaryOp,
18    ComplexScalePrimsDescriptor, ScalarBinaryOp, ScalarPrimsDescriptor, ScalarReductionOp,
19    ScalarUnaryOp, TensorAnalyticPrims, TensorComplexRealPrims, TensorComplexScalePrims,
20    TensorScalarPrims,
21};
22use tenferro_tensor::Tensor;
23
24use super::TensorLinalgContextFor;
25use crate::{
26    CholeskyTensorExResult, EigTensorResult, EigenTensorResult, LinalgCapabilityOp,
27    LuTensorExResult, LuTensorResult, QrTensorResult, SolveTensorExResult, SvdTensorResult,
28    TensorLinalgPrims,
29};
30pub use scalar_type::{CudaDataType, CudaLinalgScalar};
31
/// Zero-sized marker that selects the CUDA implementation of the tensor
/// linear-algebra primitives.
///
/// The type carries no state of its own; operations receive their execution
/// context through `TensorLinalgPrims::Context`, which for this backend is
/// `tenferro_prims::CudaContext`.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg_prims::backend::CudaTensorLinalgBackend;
///
/// let _backend = CudaTensorLinalgBackend::default();
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct CudaTensorLinalgBackend;
41
42fn unsupported<T, S: CudaLinalgScalar>(op: &str) -> Result<T> {
43    let _ = S::cuda_data_type();
44    runtime::unsupported(op)
45}
46
47fn has_real_det_support_f32() -> bool {
48    <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
49        ScalarPrimsDescriptor::PointwiseBinary {
50            op: ScalarBinaryOp::Mul,
51        },
52    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
53        ScalarPrimsDescriptor::Reduction {
54            modes_a: vec![0],
55            modes_c: vec![],
56            op: ScalarReductionOp::Prod,
57        },
58    )
59}
60
61fn has_real_det_support_f64() -> bool {
62    <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
63        ScalarPrimsDescriptor::PointwiseBinary {
64            op: ScalarBinaryOp::Mul,
65        },
66    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
67        ScalarPrimsDescriptor::Reduction {
68            modes_a: vec![0],
69            modes_c: vec![],
70            op: ScalarReductionOp::Prod,
71        },
72    )
73}
74
75fn has_complex_det_support_c32() -> bool {
76    <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex32>>::has_complex_scale_support(
77        ComplexScalePrimsDescriptor::PointwiseMul,
78    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex32>>>::has_scalar_support(
79        ScalarPrimsDescriptor::Reduction {
80            modes_a: vec![0],
81            modes_c: vec![],
82            op: ScalarReductionOp::Prod,
83        },
84    )
85}
86
87fn has_complex_det_support_c64() -> bool {
88    <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex64>>::has_complex_scale_support(
89        ComplexScalePrimsDescriptor::PointwiseMul,
90    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex64>>>::has_scalar_support(
91        ScalarPrimsDescriptor::Reduction {
92            modes_a: vec![0],
93            modes_c: vec![],
94            op: ScalarReductionOp::Prod,
95        },
96    )
97}
98
99fn has_complex_slogdet_support_c32() -> bool {
100    <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
101        ComplexRealPrimsDescriptor::PointwiseUnary {
102            op: ComplexRealUnaryOp::Abs,
103        },
104    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
105        ScalarPrimsDescriptor::PointwiseUnary {
106            op: ScalarUnaryOp::Reciprocal,
107        },
108    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
109        ScalarPrimsDescriptor::PointwiseBinary {
110            op: ScalarBinaryOp::Greater,
111        },
112    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
113        ScalarPrimsDescriptor::PointwiseTernary {
114            op: tenferro_prims::ScalarTernaryOp::Where,
115        },
116    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
117        ScalarPrimsDescriptor::Reduction {
118            modes_a: vec![0],
119            modes_c: vec![],
120            op: ScalarReductionOp::Sum,
121        },
122    ) && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
123        AnalyticPrimsDescriptor::PointwiseUnary {
124            op: AnalyticUnaryOp::Log,
125        },
126    ) && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex32>>::has_complex_scale_support(
127        ComplexScalePrimsDescriptor::PointwiseMul,
128    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex32>>>::has_scalar_support(
129        ScalarPrimsDescriptor::Reduction {
130            modes_a: vec![0],
131            modes_c: vec![],
132            op: ScalarReductionOp::Prod,
133        },
134    )
135}
136
137fn has_complex_slogdet_support_c64() -> bool {
138    <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
139        ComplexRealPrimsDescriptor::PointwiseUnary {
140            op: ComplexRealUnaryOp::Abs,
141        },
142    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
143        ScalarPrimsDescriptor::PointwiseUnary {
144            op: ScalarUnaryOp::Reciprocal,
145        },
146    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
147        ScalarPrimsDescriptor::PointwiseBinary {
148            op: ScalarBinaryOp::Greater,
149        },
150    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
151        ScalarPrimsDescriptor::PointwiseTernary {
152            op: tenferro_prims::ScalarTernaryOp::Where,
153        },
154    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
155        ScalarPrimsDescriptor::Reduction {
156            modes_a: vec![0],
157            modes_c: vec![],
158            op: ScalarReductionOp::Sum,
159        },
160    ) && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
161        AnalyticPrimsDescriptor::PointwiseUnary {
162            op: AnalyticUnaryOp::Log,
163        },
164    ) && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex64>>::has_complex_scale_support(
165        ComplexScalePrimsDescriptor::PointwiseMul,
166    ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex64>>>::has_scalar_support(
167        ScalarPrimsDescriptor::Reduction {
168            modes_a: vec![0],
169            modes_c: vec![],
170            op: ScalarReductionOp::Prod,
171        },
172    )
173}
174
175fn has_det_support<T: CudaLinalgScalar>() -> bool {
176    match T::cuda_data_type() {
177        scalar_type::CudaDataType::F32 => lu::has_lu_support::<T>() && has_real_det_support_f32(),
178        scalar_type::CudaDataType::F64 => lu::has_lu_support::<T>() && has_real_det_support_f64(),
179        scalar_type::CudaDataType::Complex32 => {
180            lu::has_lu_support::<T>() && has_complex_det_support_c32()
181        }
182        scalar_type::CudaDataType::Complex64 => {
183            lu::has_lu_support::<T>() && has_complex_det_support_c64()
184        }
185    }
186}
187
188fn has_real_slogdet_support_f32() -> bool {
189    has_real_det_support_f32()
190        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
191            ScalarPrimsDescriptor::PointwiseUnary {
192                op: ScalarUnaryOp::Abs,
193            },
194        )
195        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
196            ScalarPrimsDescriptor::PointwiseBinary {
197                op: ScalarBinaryOp::Greater,
198            },
199        )
200        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
201            ScalarPrimsDescriptor::PointwiseBinary {
202                op: ScalarBinaryOp::Mul,
203            },
204        )
205        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
206            ScalarPrimsDescriptor::PointwiseBinary {
207                op: ScalarBinaryOp::Add,
208            },
209        )
210        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
211            ScalarPrimsDescriptor::PointwiseBinary {
212                op: ScalarBinaryOp::Sub,
213            },
214        )
215        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
216            ScalarPrimsDescriptor::Reduction {
217                modes_a: vec![0],
218                modes_c: vec![],
219                op: ScalarReductionOp::Sum,
220            },
221        )
222        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
223            AnalyticPrimsDescriptor::PointwiseUnary {
224                op: AnalyticUnaryOp::Log,
225            },
226        )
227}
228
229fn has_real_slogdet_support_f64() -> bool {
230    has_real_det_support_f64()
231        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
232            ScalarPrimsDescriptor::PointwiseUnary {
233                op: ScalarUnaryOp::Abs,
234            },
235        )
236        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
237            ScalarPrimsDescriptor::PointwiseBinary {
238                op: ScalarBinaryOp::Greater,
239            },
240        )
241        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
242            ScalarPrimsDescriptor::PointwiseBinary {
243                op: ScalarBinaryOp::Mul,
244            },
245        )
246        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
247            ScalarPrimsDescriptor::PointwiseBinary {
248                op: ScalarBinaryOp::Add,
249            },
250        )
251        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
252            ScalarPrimsDescriptor::PointwiseBinary {
253                op: ScalarBinaryOp::Sub,
254            },
255        )
256        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
257            ScalarPrimsDescriptor::Reduction {
258                modes_a: vec![0],
259                modes_c: vec![],
260                op: ScalarReductionOp::Sum,
261            },
262        )
263        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
264            AnalyticPrimsDescriptor::PointwiseUnary {
265                op: AnalyticUnaryOp::Log,
266            },
267        )
268}
269
270fn has_slogdet_support<T: CudaLinalgScalar>() -> bool {
271    match T::cuda_data_type() {
272        scalar_type::CudaDataType::F32 => {
273            lu::has_lu_support::<T>() && has_real_slogdet_support_f32()
274        }
275        scalar_type::CudaDataType::F64 => {
276            lu::has_lu_support::<T>() && has_real_slogdet_support_f64()
277        }
278        scalar_type::CudaDataType::Complex32 => {
279            lu::has_lu_support::<T>() && has_complex_slogdet_support_c32()
280        }
281        scalar_type::CudaDataType::Complex64 => {
282            lu::has_lu_support::<T>() && has_complex_slogdet_support_c64()
283        }
284    }
285}
286
287fn has_real_pinv_support_f32() -> bool {
288    thin_svd::has_thin_svd_support::<f32>()
289        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
290            ScalarPrimsDescriptor::PointwiseUnary {
291                op: ScalarUnaryOp::Reciprocal,
292            },
293        )
294        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
295            ScalarPrimsDescriptor::PointwiseBinary {
296                op: ScalarBinaryOp::Greater,
297            },
298        )
299        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
300            ScalarPrimsDescriptor::PointwiseBinary {
301                op: ScalarBinaryOp::Mul,
302            },
303        )
304        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
305            ScalarPrimsDescriptor::PointwiseBinary {
306                op: ScalarBinaryOp::Add,
307            },
308        )
309        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
310            ScalarPrimsDescriptor::PointwiseBinary {
311                op: ScalarBinaryOp::Sub,
312            },
313        )
314        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
315            ScalarPrimsDescriptor::Reduction {
316                modes_a: vec![0],
317                modes_c: vec![],
318                op: ScalarReductionOp::Max,
319            },
320        )
321}
322
323fn has_real_pinv_support_f64() -> bool {
324    thin_svd::has_thin_svd_support::<f64>()
325        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
326            ScalarPrimsDescriptor::PointwiseUnary {
327                op: ScalarUnaryOp::Reciprocal,
328            },
329        )
330        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
331            ScalarPrimsDescriptor::PointwiseBinary {
332                op: ScalarBinaryOp::Greater,
333            },
334        )
335        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
336            ScalarPrimsDescriptor::PointwiseBinary {
337                op: ScalarBinaryOp::Mul,
338            },
339        )
340        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
341            ScalarPrimsDescriptor::PointwiseBinary {
342                op: ScalarBinaryOp::Add,
343            },
344        )
345        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
346            ScalarPrimsDescriptor::PointwiseBinary {
347                op: ScalarBinaryOp::Sub,
348            },
349        )
350        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
351            ScalarPrimsDescriptor::Reduction {
352                modes_a: vec![0],
353                modes_c: vec![],
354                op: ScalarReductionOp::Max,
355            },
356        )
357}
358
359fn has_complex_pinv_support_c32() -> bool {
360    thin_svd::has_thin_svd_support::<Complex32>()
361        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
362            ScalarPrimsDescriptor::PointwiseUnary {
363                op: ScalarUnaryOp::Reciprocal,
364            },
365        )
366        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
367            ScalarPrimsDescriptor::PointwiseBinary {
368                op: ScalarBinaryOp::Greater,
369            },
370        )
371        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
372            ScalarPrimsDescriptor::PointwiseBinary {
373                op: ScalarBinaryOp::Mul,
374            },
375        )
376        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
377            ScalarPrimsDescriptor::PointwiseBinary {
378                op: ScalarBinaryOp::Add,
379            },
380        )
381        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
382            ScalarPrimsDescriptor::PointwiseBinary {
383                op: ScalarBinaryOp::Sub,
384            },
385        )
386        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
387            ScalarPrimsDescriptor::Reduction {
388                modes_a: vec![0],
389                modes_c: vec![],
390                op: ScalarReductionOp::Max,
391            },
392        )
393        && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex32>>::has_complex_scale_support(
394            ComplexScalePrimsDescriptor::PointwiseMul,
395        )
396}
397
398fn has_complex_pinv_support_c64() -> bool {
399    thin_svd::has_thin_svd_support::<Complex64>()
400        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
401            ScalarPrimsDescriptor::PointwiseUnary {
402                op: ScalarUnaryOp::Reciprocal,
403            },
404        )
405        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
406            ScalarPrimsDescriptor::PointwiseBinary {
407                op: ScalarBinaryOp::Greater,
408            },
409        )
410        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
411            ScalarPrimsDescriptor::PointwiseBinary {
412                op: ScalarBinaryOp::Mul,
413            },
414        )
415        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
416            ScalarPrimsDescriptor::PointwiseBinary {
417                op: ScalarBinaryOp::Add,
418            },
419        )
420        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
421            ScalarPrimsDescriptor::PointwiseBinary {
422                op: ScalarBinaryOp::Sub,
423            },
424        )
425        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
426            ScalarPrimsDescriptor::Reduction {
427                modes_a: vec![0],
428                modes_c: vec![],
429                op: ScalarReductionOp::Max,
430            },
431        )
432        && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex64>>::has_complex_scale_support(
433            ComplexScalePrimsDescriptor::PointwiseMul,
434        )
435}
436
437fn has_real_norm_support_f32() -> bool {
438    thin_svd::has_thin_svd_support::<f32>()
439        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
440            ScalarPrimsDescriptor::PointwiseUnary {
441                op: ScalarUnaryOp::Abs,
442            },
443        )
444        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
445            ScalarPrimsDescriptor::PointwiseBinary {
446                op: ScalarBinaryOp::Mul,
447            },
448        )
449        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
450            ScalarPrimsDescriptor::Reduction {
451                modes_a: vec![0],
452                modes_c: vec![],
453                op: ScalarReductionOp::Sum,
454            },
455        )
456        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
457            ScalarPrimsDescriptor::Reduction {
458                modes_a: vec![0],
459                modes_c: vec![],
460                op: ScalarReductionOp::Max,
461            },
462        )
463        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
464            AnalyticPrimsDescriptor::PointwiseUnary {
465                op: AnalyticUnaryOp::Sqrt,
466            },
467        )
468        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
469            AnalyticPrimsDescriptor::PointwiseBinary {
470                op: tenferro_prims::AnalyticBinaryOp::Pow,
471            },
472        )
473}
474
475fn has_real_norm_support_f64() -> bool {
476    thin_svd::has_thin_svd_support::<f64>()
477        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
478            ScalarPrimsDescriptor::PointwiseUnary {
479                op: ScalarUnaryOp::Abs,
480            },
481        )
482        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
483            ScalarPrimsDescriptor::PointwiseBinary {
484                op: ScalarBinaryOp::Mul,
485            },
486        )
487        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
488            ScalarPrimsDescriptor::Reduction {
489                modes_a: vec![0],
490                modes_c: vec![],
491                op: ScalarReductionOp::Sum,
492            },
493        )
494        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
495            ScalarPrimsDescriptor::Reduction {
496                modes_a: vec![0],
497                modes_c: vec![],
498                op: ScalarReductionOp::Max,
499            },
500        )
501        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
502            AnalyticPrimsDescriptor::PointwiseUnary {
503                op: AnalyticUnaryOp::Sqrt,
504            },
505        )
506        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
507            AnalyticPrimsDescriptor::PointwiseBinary {
508                op: tenferro_prims::AnalyticBinaryOp::Pow,
509            },
510        )
511}
512
513fn has_complex_norm_support_c32() -> bool {
514    thin_svd::has_thin_svd_support::<Complex32>()
515        && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
516            ComplexRealPrimsDescriptor::PointwiseUnary {
517                op: ComplexRealUnaryOp::Abs,
518            },
519        )
520        && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
521            ComplexRealPrimsDescriptor::Reduction {
522                modes_a: vec![0],
523                modes_c: vec![],
524                unary_op: ComplexRealUnaryOp::Abs,
525                reduction_op: ScalarReductionOp::Sum,
526            },
527        )
528        && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
529            ComplexRealPrimsDescriptor::Reduction {
530                modes_a: vec![0],
531                modes_c: vec![],
532                unary_op: ComplexRealUnaryOp::Abs,
533                reduction_op: ScalarReductionOp::Max,
534            },
535        )
536        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
537            ScalarPrimsDescriptor::PointwiseBinary {
538                op: ScalarBinaryOp::Mul,
539            },
540        )
541        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
542            ScalarPrimsDescriptor::Reduction {
543                modes_a: vec![0],
544                modes_c: vec![],
545                op: ScalarReductionOp::Sum,
546            },
547        )
548        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
549            ScalarPrimsDescriptor::Reduction {
550                modes_a: vec![0],
551                modes_c: vec![],
552                op: ScalarReductionOp::Max,
553            },
554        )
555        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
556            AnalyticPrimsDescriptor::PointwiseUnary {
557                op: AnalyticUnaryOp::Sqrt,
558            },
559        )
560        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
561            AnalyticPrimsDescriptor::PointwiseBinary {
562                op: tenferro_prims::AnalyticBinaryOp::Pow,
563            },
564        )
565}
566
567fn has_complex_norm_support_c64() -> bool {
568    thin_svd::has_thin_svd_support::<Complex64>()
569        && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
570            ComplexRealPrimsDescriptor::PointwiseUnary {
571                op: ComplexRealUnaryOp::Abs,
572            },
573        )
574        && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
575            ComplexRealPrimsDescriptor::Reduction {
576                modes_a: vec![0],
577                modes_c: vec![],
578                unary_op: ComplexRealUnaryOp::Abs,
579                reduction_op: ScalarReductionOp::Sum,
580            },
581        )
582        && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
583            ComplexRealPrimsDescriptor::Reduction {
584                modes_a: vec![0],
585                modes_c: vec![],
586                unary_op: ComplexRealUnaryOp::Abs,
587                reduction_op: ScalarReductionOp::Max,
588            },
589        )
590        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
591            ScalarPrimsDescriptor::PointwiseBinary {
592                op: ScalarBinaryOp::Mul,
593            },
594        )
595        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
596            ScalarPrimsDescriptor::Reduction {
597                modes_a: vec![0],
598                modes_c: vec![],
599                op: ScalarReductionOp::Sum,
600            },
601        )
602        && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
603            ScalarPrimsDescriptor::Reduction {
604                modes_a: vec![0],
605                modes_c: vec![],
606                op: ScalarReductionOp::Max,
607            },
608        )
609        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
610            AnalyticPrimsDescriptor::PointwiseUnary {
611                op: AnalyticUnaryOp::Sqrt,
612            },
613        )
614        && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
615            AnalyticPrimsDescriptor::PointwiseBinary {
616                op: tenferro_prims::AnalyticBinaryOp::Pow,
617            },
618        )
619}
620
621fn has_norm_support<T: CudaLinalgScalar>() -> bool {
622    match T::cuda_data_type() {
623        scalar_type::CudaDataType::F32 => has_real_norm_support_f32(),
624        scalar_type::CudaDataType::F64 => has_real_norm_support_f64(),
625        scalar_type::CudaDataType::Complex32 => has_complex_norm_support_c32(),
626        scalar_type::CudaDataType::Complex64 => has_complex_norm_support_c64(),
627    }
628}
629
630fn has_pinv_support<T: CudaLinalgScalar>() -> bool {
631    match T::cuda_data_type() {
632        scalar_type::CudaDataType::F32 => has_real_pinv_support_f32(),
633        scalar_type::CudaDataType::F64 => has_real_pinv_support_f64(),
634        scalar_type::CudaDataType::Complex32 => has_complex_pinv_support_c32(),
635        scalar_type::CudaDataType::Complex64 => has_complex_pinv_support_c64(),
636    }
637}
638
639fn has_matrix_power_support<T: CudaLinalgScalar>() -> bool {
640    solve::has_solve_support::<T>()
641        && matches!(
642            T::cuda_data_type(),
643            scalar_type::CudaDataType::F32
644                | scalar_type::CudaDataType::F64
645                | scalar_type::CudaDataType::Complex32
646                | scalar_type::CudaDataType::Complex64
647        )
648}
649
650fn has_matrix_exp_support<T: CudaLinalgScalar>() -> bool {
651    solve::has_solve_support::<T>()
652        && matches!(
653            T::cuda_data_type(),
654            scalar_type::CudaDataType::F32
655                | scalar_type::CudaDataType::F64
656                | scalar_type::CudaDataType::Complex32
657                | scalar_type::CudaDataType::Complex64
658        )
659}
660
661impl<T: CudaLinalgScalar> TensorLinalgPrims<T> for CudaTensorLinalgBackend {
662    type Context = tenferro_prims::CudaContext;
663
664    fn has_linalg_support(op: LinalgCapabilityOp) -> bool {
665        matches!(
666            op,
667            LinalgCapabilityOp::Solve
668                | LinalgCapabilityOp::LuSolve
669                | LinalgCapabilityOp::SolveEx
670                | LinalgCapabilityOp::Inv
671                | LinalgCapabilityOp::SolveTriangular
672                | LinalgCapabilityOp::Qr
673                | LinalgCapabilityOp::ThinSvd
674                | LinalgCapabilityOp::LuFactor
675                | LinalgCapabilityOp::LuFactorEx
676                | LinalgCapabilityOp::Cholesky
677                | LinalgCapabilityOp::CholeskyEx
678                | LinalgCapabilityOp::Det
679                | LinalgCapabilityOp::Slogdet
680                | LinalgCapabilityOp::Pinv
681                | LinalgCapabilityOp::MatrixPower
682                | LinalgCapabilityOp::MatrixExp
683                | LinalgCapabilityOp::Norm
684        ) && match op {
685            LinalgCapabilityOp::Solve => solve::has_solve_support::<T>(),
686            LinalgCapabilityOp::LuSolve => solve::has_solve_support::<T>(),
687            LinalgCapabilityOp::SolveEx => solve::has_solve_support::<T>(),
688            LinalgCapabilityOp::Inv => solve::has_solve_support::<T>(),
689            LinalgCapabilityOp::SolveTriangular => {
690                solve_triangular::has_solve_triangular_support::<T>()
691            }
692            LinalgCapabilityOp::Qr => qr::has_qr_support::<T>(),
693            LinalgCapabilityOp::LuFactor | LinalgCapabilityOp::LuFactorEx => {
694                lu::has_lu_support::<T>()
695            }
696            LinalgCapabilityOp::Cholesky | LinalgCapabilityOp::CholeskyEx => {
697                cholesky::has_cholesky_support::<T>()
698            }
699            LinalgCapabilityOp::ThinSvd => thin_svd::has_thin_svd_support::<T>(),
700            LinalgCapabilityOp::Det => has_det_support::<T>(),
701            LinalgCapabilityOp::Slogdet => has_slogdet_support::<T>(),
702            LinalgCapabilityOp::Pinv => has_pinv_support::<T>(),
703            LinalgCapabilityOp::MatrixPower => has_matrix_power_support::<T>(),
704            LinalgCapabilityOp::MatrixExp => has_matrix_exp_support::<T>(),
705            LinalgCapabilityOp::Norm => has_norm_support::<T>(),
706            _ => false,
707        }
708    }
709
710    fn solve_ex(
711        ctx: &mut Self::Context,
712        a: &Tensor<T>,
713        b: &Tensor<T>,
714    ) -> Result<SolveTensorExResult<T>> {
715        solve::solve_ex(ctx, a, b)
716    }
717
718    fn solve(ctx: &mut Self::Context, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>> {
719        solve::solve(ctx, a, b)
720    }
721
722    fn lu_solve(
723        ctx: &mut Self::Context,
724        factors: &Tensor<T>,
725        pivots: &Tensor<i32>,
726        b: &Tensor<T>,
727    ) -> Result<Tensor<T>> {
728        solve::lu_solve(ctx, factors, pivots, b)
729    }
730
731    fn solve_triangular(
732        ctx: &mut Self::Context,
733        a: &Tensor<T>,
734        b: &Tensor<T>,
735        upper: bool,
736    ) -> Result<Tensor<T>> {
737        solve_triangular::solve_triangular(ctx, a, b, upper)
738    }
739
740    fn qr(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<QrTensorResult<T>> {
741        qr::qr(_ctx, _a)
742    }
743
744    fn thin_svd(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<SvdTensorResult<T>> {
745        thin_svd::thin_svd(ctx, a)
746    }
747
748    fn svdvals(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<Tensor<T::Real>> {
749        svdvals::svdvals(_ctx, _a)
750    }
751
752    fn lu_factor_ex(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<LuTensorExResult<T>> {
753        lu::lu_factor_ex(_ctx, _a)
754    }
755
756    fn lu_factor(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<LuTensorResult<T>> {
757        lu::lu_factor(_ctx, _a)
758    }
759
760    fn lu_factor_no_pivot(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<LuTensorResult<T>> {
761        unsupported::<LuTensorResult<T>, T>("lu_factor_no_pivot")
762    }
763
764    fn cholesky_ex(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<CholeskyTensorExResult<T>> {
765        cholesky::cholesky_ex(_ctx, _a)
766    }
767
768    fn cholesky(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T>> {
769        cholesky::cholesky(ctx, a)
770    }
771
772    fn eigen_sym(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<EigenTensorResult<T>> {
773        unsupported::<EigenTensorResult<T>, T>("eigen_sym")
774    }
775
776    fn eig(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<EigTensorResult<T>> {
777        unsupported::<EigTensorResult<T>, T>("eig")
778    }
779}
780
781impl<T: CudaLinalgScalar> TensorLinalgContextFor<T> for tenferro_prims::CudaContext {
782    type Backend = CudaTensorLinalgBackend;
783}
784
785#[cfg(test)]
786mod tests;