1mod cholesky;
2mod lu;
3mod qr;
4mod runtime;
5mod scalar_type;
6mod solve;
7mod solve_triangular;
8mod svdvals;
9mod thin_svd;
10#[cfg(any(feature = "cuda", test))]
11mod wrappers;
12
13use num_complex::{Complex32, Complex64};
14use tenferro_algebra::Standard;
15use tenferro_device::Result;
16use tenferro_prims::{
17 AnalyticPrimsDescriptor, AnalyticUnaryOp, ComplexRealPrimsDescriptor, ComplexRealUnaryOp,
18 ComplexScalePrimsDescriptor, ScalarBinaryOp, ScalarPrimsDescriptor, ScalarReductionOp,
19 ScalarUnaryOp, TensorAnalyticPrims, TensorComplexRealPrims, TensorComplexScalePrims,
20 TensorScalarPrims,
21};
22use tenferro_tensor::Tensor;
23
24use super::TensorLinalgContextFor;
25use crate::{
26 CholeskyTensorExResult, EigTensorResult, EigenTensorResult, LinalgCapabilityOp,
27 LuTensorExResult, LuTensorResult, QrTensorResult, SolveTensorExResult, SvdTensorResult,
28 TensorLinalgPrims,
29};
30pub use scalar_type::{CudaDataType, CudaLinalgScalar};
31
/// Zero-sized CUDA implementation of [`TensorLinalgPrims`].
///
/// Operations run on a `tenferro_prims::CudaContext`. The
/// `TensorLinalgContextFor` impl in this module selects this type as the
/// linalg backend for that context.
#[derive(Clone, Copy, Debug, Default)]
pub struct CudaTensorLinalgBackend;
41
42fn unsupported<T, S: CudaLinalgScalar>(op: &str) -> Result<T> {
43 let _ = S::cuda_data_type();
44 runtime::unsupported(op)
45}
46
47fn has_real_det_support_f32() -> bool {
48 <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
49 ScalarPrimsDescriptor::PointwiseBinary {
50 op: ScalarBinaryOp::Mul,
51 },
52 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
53 ScalarPrimsDescriptor::Reduction {
54 modes_a: vec![0],
55 modes_c: vec![],
56 op: ScalarReductionOp::Prod,
57 },
58 )
59}
60
61fn has_real_det_support_f64() -> bool {
62 <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
63 ScalarPrimsDescriptor::PointwiseBinary {
64 op: ScalarBinaryOp::Mul,
65 },
66 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
67 ScalarPrimsDescriptor::Reduction {
68 modes_a: vec![0],
69 modes_c: vec![],
70 op: ScalarReductionOp::Prod,
71 },
72 )
73}
74
75fn has_complex_det_support_c32() -> bool {
76 <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex32>>::has_complex_scale_support(
77 ComplexScalePrimsDescriptor::PointwiseMul,
78 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex32>>>::has_scalar_support(
79 ScalarPrimsDescriptor::Reduction {
80 modes_a: vec![0],
81 modes_c: vec![],
82 op: ScalarReductionOp::Prod,
83 },
84 )
85}
86
87fn has_complex_det_support_c64() -> bool {
88 <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex64>>::has_complex_scale_support(
89 ComplexScalePrimsDescriptor::PointwiseMul,
90 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex64>>>::has_scalar_support(
91 ScalarPrimsDescriptor::Reduction {
92 modes_a: vec![0],
93 modes_c: vec![],
94 op: ScalarReductionOp::Prod,
95 },
96 )
97}
98
99fn has_complex_slogdet_support_c32() -> bool {
100 <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
101 ComplexRealPrimsDescriptor::PointwiseUnary {
102 op: ComplexRealUnaryOp::Abs,
103 },
104 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
105 ScalarPrimsDescriptor::PointwiseUnary {
106 op: ScalarUnaryOp::Reciprocal,
107 },
108 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
109 ScalarPrimsDescriptor::PointwiseBinary {
110 op: ScalarBinaryOp::Greater,
111 },
112 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
113 ScalarPrimsDescriptor::PointwiseTernary {
114 op: tenferro_prims::ScalarTernaryOp::Where,
115 },
116 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
117 ScalarPrimsDescriptor::Reduction {
118 modes_a: vec![0],
119 modes_c: vec![],
120 op: ScalarReductionOp::Sum,
121 },
122 ) && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
123 AnalyticPrimsDescriptor::PointwiseUnary {
124 op: AnalyticUnaryOp::Log,
125 },
126 ) && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex32>>::has_complex_scale_support(
127 ComplexScalePrimsDescriptor::PointwiseMul,
128 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex32>>>::has_scalar_support(
129 ScalarPrimsDescriptor::Reduction {
130 modes_a: vec![0],
131 modes_c: vec![],
132 op: ScalarReductionOp::Prod,
133 },
134 )
135}
136
137fn has_complex_slogdet_support_c64() -> bool {
138 <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
139 ComplexRealPrimsDescriptor::PointwiseUnary {
140 op: ComplexRealUnaryOp::Abs,
141 },
142 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
143 ScalarPrimsDescriptor::PointwiseUnary {
144 op: ScalarUnaryOp::Reciprocal,
145 },
146 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
147 ScalarPrimsDescriptor::PointwiseBinary {
148 op: ScalarBinaryOp::Greater,
149 },
150 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
151 ScalarPrimsDescriptor::PointwiseTernary {
152 op: tenferro_prims::ScalarTernaryOp::Where,
153 },
154 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
155 ScalarPrimsDescriptor::Reduction {
156 modes_a: vec![0],
157 modes_c: vec![],
158 op: ScalarReductionOp::Sum,
159 },
160 ) && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
161 AnalyticPrimsDescriptor::PointwiseUnary {
162 op: AnalyticUnaryOp::Log,
163 },
164 ) && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex64>>::has_complex_scale_support(
165 ComplexScalePrimsDescriptor::PointwiseMul,
166 ) && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<Complex64>>>::has_scalar_support(
167 ScalarPrimsDescriptor::Reduction {
168 modes_a: vec![0],
169 modes_c: vec![],
170 op: ScalarReductionOp::Prod,
171 },
172 )
173}
174
175fn has_det_support<T: CudaLinalgScalar>() -> bool {
176 match T::cuda_data_type() {
177 scalar_type::CudaDataType::F32 => lu::has_lu_support::<T>() && has_real_det_support_f32(),
178 scalar_type::CudaDataType::F64 => lu::has_lu_support::<T>() && has_real_det_support_f64(),
179 scalar_type::CudaDataType::Complex32 => {
180 lu::has_lu_support::<T>() && has_complex_det_support_c32()
181 }
182 scalar_type::CudaDataType::Complex64 => {
183 lu::has_lu_support::<T>() && has_complex_det_support_c64()
184 }
185 }
186}
187
188fn has_real_slogdet_support_f32() -> bool {
189 has_real_det_support_f32()
190 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
191 ScalarPrimsDescriptor::PointwiseUnary {
192 op: ScalarUnaryOp::Abs,
193 },
194 )
195 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
196 ScalarPrimsDescriptor::PointwiseBinary {
197 op: ScalarBinaryOp::Greater,
198 },
199 )
200 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
201 ScalarPrimsDescriptor::PointwiseBinary {
202 op: ScalarBinaryOp::Mul,
203 },
204 )
205 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
206 ScalarPrimsDescriptor::PointwiseBinary {
207 op: ScalarBinaryOp::Add,
208 },
209 )
210 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
211 ScalarPrimsDescriptor::PointwiseBinary {
212 op: ScalarBinaryOp::Sub,
213 },
214 )
215 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
216 ScalarPrimsDescriptor::Reduction {
217 modes_a: vec![0],
218 modes_c: vec![],
219 op: ScalarReductionOp::Sum,
220 },
221 )
222 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
223 AnalyticPrimsDescriptor::PointwiseUnary {
224 op: AnalyticUnaryOp::Log,
225 },
226 )
227}
228
229fn has_real_slogdet_support_f64() -> bool {
230 has_real_det_support_f64()
231 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
232 ScalarPrimsDescriptor::PointwiseUnary {
233 op: ScalarUnaryOp::Abs,
234 },
235 )
236 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
237 ScalarPrimsDescriptor::PointwiseBinary {
238 op: ScalarBinaryOp::Greater,
239 },
240 )
241 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
242 ScalarPrimsDescriptor::PointwiseBinary {
243 op: ScalarBinaryOp::Mul,
244 },
245 )
246 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
247 ScalarPrimsDescriptor::PointwiseBinary {
248 op: ScalarBinaryOp::Add,
249 },
250 )
251 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
252 ScalarPrimsDescriptor::PointwiseBinary {
253 op: ScalarBinaryOp::Sub,
254 },
255 )
256 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
257 ScalarPrimsDescriptor::Reduction {
258 modes_a: vec![0],
259 modes_c: vec![],
260 op: ScalarReductionOp::Sum,
261 },
262 )
263 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
264 AnalyticPrimsDescriptor::PointwiseUnary {
265 op: AnalyticUnaryOp::Log,
266 },
267 )
268}
269
270fn has_slogdet_support<T: CudaLinalgScalar>() -> bool {
271 match T::cuda_data_type() {
272 scalar_type::CudaDataType::F32 => {
273 lu::has_lu_support::<T>() && has_real_slogdet_support_f32()
274 }
275 scalar_type::CudaDataType::F64 => {
276 lu::has_lu_support::<T>() && has_real_slogdet_support_f64()
277 }
278 scalar_type::CudaDataType::Complex32 => {
279 lu::has_lu_support::<T>() && has_complex_slogdet_support_c32()
280 }
281 scalar_type::CudaDataType::Complex64 => {
282 lu::has_lu_support::<T>() && has_complex_slogdet_support_c64()
283 }
284 }
285}
286
287fn has_real_pinv_support_f32() -> bool {
288 thin_svd::has_thin_svd_support::<f32>()
289 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
290 ScalarPrimsDescriptor::PointwiseUnary {
291 op: ScalarUnaryOp::Reciprocal,
292 },
293 )
294 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
295 ScalarPrimsDescriptor::PointwiseBinary {
296 op: ScalarBinaryOp::Greater,
297 },
298 )
299 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
300 ScalarPrimsDescriptor::PointwiseBinary {
301 op: ScalarBinaryOp::Mul,
302 },
303 )
304 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
305 ScalarPrimsDescriptor::PointwiseBinary {
306 op: ScalarBinaryOp::Add,
307 },
308 )
309 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
310 ScalarPrimsDescriptor::PointwiseBinary {
311 op: ScalarBinaryOp::Sub,
312 },
313 )
314 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
315 ScalarPrimsDescriptor::Reduction {
316 modes_a: vec![0],
317 modes_c: vec![],
318 op: ScalarReductionOp::Max,
319 },
320 )
321}
322
323fn has_real_pinv_support_f64() -> bool {
324 thin_svd::has_thin_svd_support::<f64>()
325 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
326 ScalarPrimsDescriptor::PointwiseUnary {
327 op: ScalarUnaryOp::Reciprocal,
328 },
329 )
330 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
331 ScalarPrimsDescriptor::PointwiseBinary {
332 op: ScalarBinaryOp::Greater,
333 },
334 )
335 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
336 ScalarPrimsDescriptor::PointwiseBinary {
337 op: ScalarBinaryOp::Mul,
338 },
339 )
340 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
341 ScalarPrimsDescriptor::PointwiseBinary {
342 op: ScalarBinaryOp::Add,
343 },
344 )
345 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
346 ScalarPrimsDescriptor::PointwiseBinary {
347 op: ScalarBinaryOp::Sub,
348 },
349 )
350 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
351 ScalarPrimsDescriptor::Reduction {
352 modes_a: vec![0],
353 modes_c: vec![],
354 op: ScalarReductionOp::Max,
355 },
356 )
357}
358
359fn has_complex_pinv_support_c32() -> bool {
360 thin_svd::has_thin_svd_support::<Complex32>()
361 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
362 ScalarPrimsDescriptor::PointwiseUnary {
363 op: ScalarUnaryOp::Reciprocal,
364 },
365 )
366 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
367 ScalarPrimsDescriptor::PointwiseBinary {
368 op: ScalarBinaryOp::Greater,
369 },
370 )
371 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
372 ScalarPrimsDescriptor::PointwiseBinary {
373 op: ScalarBinaryOp::Mul,
374 },
375 )
376 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
377 ScalarPrimsDescriptor::PointwiseBinary {
378 op: ScalarBinaryOp::Add,
379 },
380 )
381 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
382 ScalarPrimsDescriptor::PointwiseBinary {
383 op: ScalarBinaryOp::Sub,
384 },
385 )
386 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
387 ScalarPrimsDescriptor::Reduction {
388 modes_a: vec![0],
389 modes_c: vec![],
390 op: ScalarReductionOp::Max,
391 },
392 )
393 && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex32>>::has_complex_scale_support(
394 ComplexScalePrimsDescriptor::PointwiseMul,
395 )
396}
397
398fn has_complex_pinv_support_c64() -> bool {
399 thin_svd::has_thin_svd_support::<Complex64>()
400 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
401 ScalarPrimsDescriptor::PointwiseUnary {
402 op: ScalarUnaryOp::Reciprocal,
403 },
404 )
405 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
406 ScalarPrimsDescriptor::PointwiseBinary {
407 op: ScalarBinaryOp::Greater,
408 },
409 )
410 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
411 ScalarPrimsDescriptor::PointwiseBinary {
412 op: ScalarBinaryOp::Mul,
413 },
414 )
415 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
416 ScalarPrimsDescriptor::PointwiseBinary {
417 op: ScalarBinaryOp::Add,
418 },
419 )
420 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
421 ScalarPrimsDescriptor::PointwiseBinary {
422 op: ScalarBinaryOp::Sub,
423 },
424 )
425 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
426 ScalarPrimsDescriptor::Reduction {
427 modes_a: vec![0],
428 modes_c: vec![],
429 op: ScalarReductionOp::Max,
430 },
431 )
432 && <tenferro_prims::CudaBackend as TensorComplexScalePrims<Complex64>>::has_complex_scale_support(
433 ComplexScalePrimsDescriptor::PointwiseMul,
434 )
435}
436
437fn has_real_norm_support_f32() -> bool {
438 thin_svd::has_thin_svd_support::<f32>()
439 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
440 ScalarPrimsDescriptor::PointwiseUnary {
441 op: ScalarUnaryOp::Abs,
442 },
443 )
444 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
445 ScalarPrimsDescriptor::PointwiseBinary {
446 op: ScalarBinaryOp::Mul,
447 },
448 )
449 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
450 ScalarPrimsDescriptor::Reduction {
451 modes_a: vec![0],
452 modes_c: vec![],
453 op: ScalarReductionOp::Sum,
454 },
455 )
456 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
457 ScalarPrimsDescriptor::Reduction {
458 modes_a: vec![0],
459 modes_c: vec![],
460 op: ScalarReductionOp::Max,
461 },
462 )
463 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
464 AnalyticPrimsDescriptor::PointwiseUnary {
465 op: AnalyticUnaryOp::Sqrt,
466 },
467 )
468 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
469 AnalyticPrimsDescriptor::PointwiseBinary {
470 op: tenferro_prims::AnalyticBinaryOp::Pow,
471 },
472 )
473}
474
475fn has_real_norm_support_f64() -> bool {
476 thin_svd::has_thin_svd_support::<f64>()
477 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
478 ScalarPrimsDescriptor::PointwiseUnary {
479 op: ScalarUnaryOp::Abs,
480 },
481 )
482 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
483 ScalarPrimsDescriptor::PointwiseBinary {
484 op: ScalarBinaryOp::Mul,
485 },
486 )
487 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
488 ScalarPrimsDescriptor::Reduction {
489 modes_a: vec![0],
490 modes_c: vec![],
491 op: ScalarReductionOp::Sum,
492 },
493 )
494 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
495 ScalarPrimsDescriptor::Reduction {
496 modes_a: vec![0],
497 modes_c: vec![],
498 op: ScalarReductionOp::Max,
499 },
500 )
501 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
502 AnalyticPrimsDescriptor::PointwiseUnary {
503 op: AnalyticUnaryOp::Sqrt,
504 },
505 )
506 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
507 AnalyticPrimsDescriptor::PointwiseBinary {
508 op: tenferro_prims::AnalyticBinaryOp::Pow,
509 },
510 )
511}
512
513fn has_complex_norm_support_c32() -> bool {
514 thin_svd::has_thin_svd_support::<Complex32>()
515 && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
516 ComplexRealPrimsDescriptor::PointwiseUnary {
517 op: ComplexRealUnaryOp::Abs,
518 },
519 )
520 && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
521 ComplexRealPrimsDescriptor::Reduction {
522 modes_a: vec![0],
523 modes_c: vec![],
524 unary_op: ComplexRealUnaryOp::Abs,
525 reduction_op: ScalarReductionOp::Sum,
526 },
527 )
528 && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex32>>::has_complex_real_support(
529 ComplexRealPrimsDescriptor::Reduction {
530 modes_a: vec![0],
531 modes_c: vec![],
532 unary_op: ComplexRealUnaryOp::Abs,
533 reduction_op: ScalarReductionOp::Max,
534 },
535 )
536 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
537 ScalarPrimsDescriptor::PointwiseBinary {
538 op: ScalarBinaryOp::Mul,
539 },
540 )
541 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
542 ScalarPrimsDescriptor::Reduction {
543 modes_a: vec![0],
544 modes_c: vec![],
545 op: ScalarReductionOp::Sum,
546 },
547 )
548 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f32>>>::has_scalar_support(
549 ScalarPrimsDescriptor::Reduction {
550 modes_a: vec![0],
551 modes_c: vec![],
552 op: ScalarReductionOp::Max,
553 },
554 )
555 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
556 AnalyticPrimsDescriptor::PointwiseUnary {
557 op: AnalyticUnaryOp::Sqrt,
558 },
559 )
560 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f32>>>::has_analytic_support(
561 AnalyticPrimsDescriptor::PointwiseBinary {
562 op: tenferro_prims::AnalyticBinaryOp::Pow,
563 },
564 )
565}
566
567fn has_complex_norm_support_c64() -> bool {
568 thin_svd::has_thin_svd_support::<Complex64>()
569 && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
570 ComplexRealPrimsDescriptor::PointwiseUnary {
571 op: ComplexRealUnaryOp::Abs,
572 },
573 )
574 && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
575 ComplexRealPrimsDescriptor::Reduction {
576 modes_a: vec![0],
577 modes_c: vec![],
578 unary_op: ComplexRealUnaryOp::Abs,
579 reduction_op: ScalarReductionOp::Sum,
580 },
581 )
582 && <tenferro_prims::CudaBackend as TensorComplexRealPrims<Complex64>>::has_complex_real_support(
583 ComplexRealPrimsDescriptor::Reduction {
584 modes_a: vec![0],
585 modes_c: vec![],
586 unary_op: ComplexRealUnaryOp::Abs,
587 reduction_op: ScalarReductionOp::Max,
588 },
589 )
590 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
591 ScalarPrimsDescriptor::PointwiseBinary {
592 op: ScalarBinaryOp::Mul,
593 },
594 )
595 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
596 ScalarPrimsDescriptor::Reduction {
597 modes_a: vec![0],
598 modes_c: vec![],
599 op: ScalarReductionOp::Sum,
600 },
601 )
602 && <tenferro_prims::CudaBackend as TensorScalarPrims<Standard<f64>>>::has_scalar_support(
603 ScalarPrimsDescriptor::Reduction {
604 modes_a: vec![0],
605 modes_c: vec![],
606 op: ScalarReductionOp::Max,
607 },
608 )
609 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
610 AnalyticPrimsDescriptor::PointwiseUnary {
611 op: AnalyticUnaryOp::Sqrt,
612 },
613 )
614 && <tenferro_prims::CudaBackend as TensorAnalyticPrims<Standard<f64>>>::has_analytic_support(
615 AnalyticPrimsDescriptor::PointwiseBinary {
616 op: tenferro_prims::AnalyticBinaryOp::Pow,
617 },
618 )
619}
620
621fn has_norm_support<T: CudaLinalgScalar>() -> bool {
622 match T::cuda_data_type() {
623 scalar_type::CudaDataType::F32 => has_real_norm_support_f32(),
624 scalar_type::CudaDataType::F64 => has_real_norm_support_f64(),
625 scalar_type::CudaDataType::Complex32 => has_complex_norm_support_c32(),
626 scalar_type::CudaDataType::Complex64 => has_complex_norm_support_c64(),
627 }
628}
629
630fn has_pinv_support<T: CudaLinalgScalar>() -> bool {
631 match T::cuda_data_type() {
632 scalar_type::CudaDataType::F32 => has_real_pinv_support_f32(),
633 scalar_type::CudaDataType::F64 => has_real_pinv_support_f64(),
634 scalar_type::CudaDataType::Complex32 => has_complex_pinv_support_c32(),
635 scalar_type::CudaDataType::Complex64 => has_complex_pinv_support_c64(),
636 }
637}
638
639fn has_matrix_power_support<T: CudaLinalgScalar>() -> bool {
640 solve::has_solve_support::<T>()
641 && matches!(
642 T::cuda_data_type(),
643 scalar_type::CudaDataType::F32
644 | scalar_type::CudaDataType::F64
645 | scalar_type::CudaDataType::Complex32
646 | scalar_type::CudaDataType::Complex64
647 )
648}
649
650fn has_matrix_exp_support<T: CudaLinalgScalar>() -> bool {
651 solve::has_solve_support::<T>()
652 && matches!(
653 T::cuda_data_type(),
654 scalar_type::CudaDataType::F32
655 | scalar_type::CudaDataType::F64
656 | scalar_type::CudaDataType::Complex32
657 | scalar_type::CudaDataType::Complex64
658 )
659}
660
661impl<T: CudaLinalgScalar> TensorLinalgPrims<T> for CudaTensorLinalgBackend {
662 type Context = tenferro_prims::CudaContext;
663
664 fn has_linalg_support(op: LinalgCapabilityOp) -> bool {
665 matches!(
666 op,
667 LinalgCapabilityOp::Solve
668 | LinalgCapabilityOp::LuSolve
669 | LinalgCapabilityOp::SolveEx
670 | LinalgCapabilityOp::Inv
671 | LinalgCapabilityOp::SolveTriangular
672 | LinalgCapabilityOp::Qr
673 | LinalgCapabilityOp::ThinSvd
674 | LinalgCapabilityOp::LuFactor
675 | LinalgCapabilityOp::LuFactorEx
676 | LinalgCapabilityOp::Cholesky
677 | LinalgCapabilityOp::CholeskyEx
678 | LinalgCapabilityOp::Det
679 | LinalgCapabilityOp::Slogdet
680 | LinalgCapabilityOp::Pinv
681 | LinalgCapabilityOp::MatrixPower
682 | LinalgCapabilityOp::MatrixExp
683 | LinalgCapabilityOp::Norm
684 ) && match op {
685 LinalgCapabilityOp::Solve => solve::has_solve_support::<T>(),
686 LinalgCapabilityOp::LuSolve => solve::has_solve_support::<T>(),
687 LinalgCapabilityOp::SolveEx => solve::has_solve_support::<T>(),
688 LinalgCapabilityOp::Inv => solve::has_solve_support::<T>(),
689 LinalgCapabilityOp::SolveTriangular => {
690 solve_triangular::has_solve_triangular_support::<T>()
691 }
692 LinalgCapabilityOp::Qr => qr::has_qr_support::<T>(),
693 LinalgCapabilityOp::LuFactor | LinalgCapabilityOp::LuFactorEx => {
694 lu::has_lu_support::<T>()
695 }
696 LinalgCapabilityOp::Cholesky | LinalgCapabilityOp::CholeskyEx => {
697 cholesky::has_cholesky_support::<T>()
698 }
699 LinalgCapabilityOp::ThinSvd => thin_svd::has_thin_svd_support::<T>(),
700 LinalgCapabilityOp::Det => has_det_support::<T>(),
701 LinalgCapabilityOp::Slogdet => has_slogdet_support::<T>(),
702 LinalgCapabilityOp::Pinv => has_pinv_support::<T>(),
703 LinalgCapabilityOp::MatrixPower => has_matrix_power_support::<T>(),
704 LinalgCapabilityOp::MatrixExp => has_matrix_exp_support::<T>(),
705 LinalgCapabilityOp::Norm => has_norm_support::<T>(),
706 _ => false,
707 }
708 }
709
710 fn solve_ex(
711 ctx: &mut Self::Context,
712 a: &Tensor<T>,
713 b: &Tensor<T>,
714 ) -> Result<SolveTensorExResult<T>> {
715 solve::solve_ex(ctx, a, b)
716 }
717
718 fn solve(ctx: &mut Self::Context, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>> {
719 solve::solve(ctx, a, b)
720 }
721
722 fn lu_solve(
723 ctx: &mut Self::Context,
724 factors: &Tensor<T>,
725 pivots: &Tensor<i32>,
726 b: &Tensor<T>,
727 ) -> Result<Tensor<T>> {
728 solve::lu_solve(ctx, factors, pivots, b)
729 }
730
731 fn solve_triangular(
732 ctx: &mut Self::Context,
733 a: &Tensor<T>,
734 b: &Tensor<T>,
735 upper: bool,
736 ) -> Result<Tensor<T>> {
737 solve_triangular::solve_triangular(ctx, a, b, upper)
738 }
739
740 fn qr(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<QrTensorResult<T>> {
741 qr::qr(_ctx, _a)
742 }
743
744 fn thin_svd(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<SvdTensorResult<T>> {
745 thin_svd::thin_svd(ctx, a)
746 }
747
748 fn svdvals(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<Tensor<T::Real>> {
749 svdvals::svdvals(_ctx, _a)
750 }
751
752 fn lu_factor_ex(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<LuTensorExResult<T>> {
753 lu::lu_factor_ex(_ctx, _a)
754 }
755
756 fn lu_factor(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<LuTensorResult<T>> {
757 lu::lu_factor(_ctx, _a)
758 }
759
760 fn lu_factor_no_pivot(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<LuTensorResult<T>> {
761 unsupported::<LuTensorResult<T>, T>("lu_factor_no_pivot")
762 }
763
764 fn cholesky_ex(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<CholeskyTensorExResult<T>> {
765 cholesky::cholesky_ex(_ctx, _a)
766 }
767
768 fn cholesky(ctx: &mut Self::Context, a: &Tensor<T>) -> Result<Tensor<T>> {
769 cholesky::cholesky(ctx, a)
770 }
771
772 fn eigen_sym(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<EigenTensorResult<T>> {
773 unsupported::<EigenTensorResult<T>, T>("eigen_sym")
774 }
775
776 fn eig(_ctx: &mut Self::Context, _a: &Tensor<T>) -> Result<EigTensorResult<T>> {
777 unsupported::<EigTensorResult<T>, T>("eig")
778 }
779}
780
781impl<T: CudaLinalgScalar> TensorLinalgContextFor<T> for tenferro_prims::CudaContext {
782 type Backend = CudaTensorLinalgBackend;
783}
784
785#[cfg(test)]
786mod tests;