1use crate::backend::LinalgBackend;
2
3use super::linalg;
4
5use num_complex::{Complex32, Complex64};
6use tenferro_cpu::{CpuBackend, CpuBackendKind};
7use tenferro_tensor::{
8 validate::validate_nonsingular_u, DType, Error, Tensor, TensorElementwise, TensorStructural,
9 TensorView, TensorViewCanonicalization, TypedTensor,
10};
11
/// CPU implementation of [`LinalgBackend`].
///
/// Every operation follows the same pattern:
/// 1. reject inputs that carry a backend buffer (`ensure_host_tensor`);
/// 2. dispatch on `self.kind()` to either the faer kernels (`linalg::faer`, which also
///    receive `self.linalg_context()`) or the BLAS kernels (`linalg::blas`), running them
///    inside `self.with_linalg_pool`, which hands the kernel a `buffers` argument;
/// 3. match on the dtype variant so each kernel receives a concrete `TypedTensor`, then
///    re-wrap the typed result in the corresponding `Tensor` variant.
///
/// Only F32/F64/C32/C64 inputs are dispatched; other dtypes yield `unsupported_dtype` (or
/// `unsupported_pair` for two-operand ops). If the provider selected at runtime by
/// `self.kind()` was not compiled in (Cargo features `cpu-faer` / `cpu-blas`), the method
/// returns `unsupported_provider` instead.
impl LinalgBackend for CpuBackend {
    /// Cholesky factorization of an F32/F64/C32/C64 tensor; the result keeps the input dtype.
    fn cholesky(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("cholesky", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => {
                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::F32)
                        }
                        Tensor::F64(t) => {
                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::F64)
                        }
                        Tensor::C32(t) => {
                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::C32)
                        }
                        Tensor::C64(t) => {
                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::C64)
                        }
                        _ => Err(unsupported_dtype("cholesky", input.dtype())),
                    })
                }
                // The provider was chosen at runtime but its feature is not compiled in.
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("cholesky", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::cholesky(buffers, t).map(Tensor::F32),
                        Tensor::F64(t) => linalg::blas::cholesky(buffers, t).map(Tensor::F64),
                        Tensor::C32(t) => linalg::blas::cholesky(buffers, t).map(Tensor::C32),
                        Tensor::C64(t) => linalg::blas::cholesky(buffers, t).map(Tensor::C64),
                        _ => Err(unsupported_dtype("cholesky", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("cholesky", self.kind()))
                }
            }
        }
    }

    /// Triangular solve with `a` as the triangular operand and `b` as the right-hand side.
    ///
    /// `a` and `b` must share one of F32/F64/C32/C64; any other combination is routed to
    /// `unsupported_pair`. The four flags are forwarded unchanged to the provider kernel;
    /// `lu_solve_prepared` relies on `lower`/`unit_diagonal` to address the two halves of a
    /// packed LU matrix. Unlike `solve`, no empty-shape short-circuit is done here.
    fn triangular_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("triangular_solve", a)?;
        ensure_host_tensor("triangular_solve", b)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match (a, b) {
                        (Tensor::F32(a), Tensor::F32(b)) => linalg::faer::triangular_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::F32),
                        (Tensor::F64(a), Tensor::F64(b)) => linalg::faer::triangular_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::F64),
                        (Tensor::C32(a), Tensor::C32(b)) => linalg::faer::triangular_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::C32),
                        (Tensor::C64(a), Tensor::C64(b)) => linalg::faer::triangular_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::C64),
                        // Here `a`/`b` refer to the outer `&Tensor` parameters again.
                        _ => unsupported_pair("triangular_solve", a, b),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("triangular_solve", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match (a, b) {
                        (Tensor::F32(a), Tensor::F32(b)) => linalg::blas::triangular_solve(
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::F32),
                        (Tensor::F64(a), Tensor::F64(b)) => linalg::blas::triangular_solve(
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::F64),
                        (Tensor::C32(a), Tensor::C32(b)) => linalg::blas::triangular_solve(
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::C32),
                        (Tensor::C64(a), Tensor::C64(b)) => linalg::blas::triangular_solve(
                            buffers,
                            a,
                            b,
                            left_side,
                            lower,
                            transpose_a,
                            unit_diagonal,
                        )
                        .map(Tensor::C64),
                        _ => unsupported_pair("triangular_solve", a, b),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("triangular_solve", self.kind()))
                }
            }
        }
    }

    /// LU decomposition; every output the kernel returns is re-wrapped in the input's dtype.
    fn lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("lu", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
                        Tensor::C64(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
                        _ => Err(unsupported_dtype("lu", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("lu", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::lu(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::blas::lu(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::blas::lu(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
                        Tensor::C64(t) => linalg::blas::lu(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
                        _ => Err(unsupported_dtype("lu", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("lu", self.kind()))
                }
            }
        }
    }

    /// Packed LU factorization returning `[lu, pivots, parity]`.
    ///
    /// `pivots` is always an `I32` tensor. `lu` and `parity` keep the input dtype. For
    /// complex inputs `parity` therefore stays complex, unlike `full_piv_lu`, which converts
    /// its complex parity output to a real tensor.
    fn lu_factor(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("lu_factor", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
                            |(lu, pivots, parity)| {
                                vec![Tensor::F32(lu), Tensor::I32(pivots), Tensor::F32(parity)]
                            },
                        ),
                        Tensor::F64(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
                            |(lu, pivots, parity)| {
                                vec![Tensor::F64(lu), Tensor::I32(pivots), Tensor::F64(parity)]
                            },
                        ),
                        Tensor::C32(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
                            |(lu, pivots, parity)| {
                                vec![Tensor::C32(lu), Tensor::I32(pivots), Tensor::C32(parity)]
                            },
                        ),
                        Tensor::C64(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
                            |(lu, pivots, parity)| {
                                vec![Tensor::C64(lu), Tensor::I32(pivots), Tensor::C64(parity)]
                            },
                        ),
                        _ => Err(unsupported_dtype("lu_factor", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("lu_factor", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => {
                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
                                vec![Tensor::F32(lu), Tensor::I32(pivots), Tensor::F32(parity)]
                            })
                        }
                        Tensor::F64(t) => {
                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
                                vec![Tensor::F64(lu), Tensor::I32(pivots), Tensor::F64(parity)]
                            })
                        }
                        Tensor::C32(t) => {
                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
                                vec![Tensor::C32(lu), Tensor::I32(pivots), Tensor::C32(parity)]
                            })
                        }
                        Tensor::C64(t) => {
                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
                                vec![Tensor::C64(lu), Tensor::I32(pivots), Tensor::C64(parity)]
                            })
                        }
                        _ => Err(unsupported_dtype("lu_factor", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("lu_factor", self.kind()))
                }
            }
        }
    }

    /// Fully pivoted LU decomposition.
    ///
    /// Real outputs are re-wrapped as-is. Complex outputs go through
    /// `full_piv_lu_c32_outputs_to_public_tensors` / `full_piv_lu_c64_outputs_to_public_tensors`.
    /// These require exactly five kernel outputs and turn the last one (parity) into a
    /// real tensor holding the real parts.
    fn full_piv_lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("full_piv_lu", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
                            .and_then(full_piv_lu_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
                            .and_then(full_piv_lu_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("full_piv_lu", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("full_piv_lu", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::full_piv_lu(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::blas::full_piv_lu(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::blas::full_piv_lu(buffers, t)
                            .and_then(full_piv_lu_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::blas::full_piv_lu(buffers, t)
                            .and_then(full_piv_lu_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("full_piv_lu", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("full_piv_lu", self.kind()))
                }
            }
        }
    }

    /// Solves a linear system using a fully pivoted LU of `a`.
    ///
    /// Steps:
    /// - Validate that `a` and `b` share a float/complex dtype.
    /// - Return zeros shaped like `b` when either operand has a zero-length dimension.
    /// - Temporarily reshape a vector (or batched-vector) right-hand side to a single-column
    ///   matrix (see `batched_vector_rhs_shape`), and restore `b`'s original shape on the result.
    ///
    /// `transpose_a` is forwarded to the kernel.
    fn full_piv_lu_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        transpose_a: bool,
    ) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("full_piv_lu_solve", a)?;
        ensure_host_tensor("full_piv_lu_solve", b)?;
        ensure_supported_linalg_pair("full_piv_lu_solve", a, b)?;
        // Empty problem: nothing to factor, so the answer is an empty/zero tensor like `b`.
        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
            return zeros_like_tensor(b);
        }

        // Vector right-hand sides are solved as one-column matrices; remember the shape so
        // the result can be reshaped back afterwards.
        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
            (
                self.reshape(b, &matrix_rhs_shape)?,
                Some(b.shape().to_vec()),
            )
        } else {
            (b.clone(), None)
        };

        let result = match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match (a, &rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => linalg::faer::full_piv_lu_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            transpose_a,
                        )
                        .map(Tensor::F32),
                        (Tensor::F64(a), Tensor::F64(b)) => linalg::faer::full_piv_lu_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            transpose_a,
                        )
                        .map(Tensor::F64),
                        (Tensor::C32(a), Tensor::C32(b)) => linalg::faer::full_piv_lu_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            transpose_a,
                        )
                        .map(Tensor::C32),
                        (Tensor::C64(a), Tensor::C64(b)) => linalg::faer::full_piv_lu_solve(
                            ctx.as_ref(),
                            buffers,
                            a,
                            b,
                            transpose_a,
                        )
                        .map(Tensor::C64),
                        _ => unsupported_pair("full_piv_lu_solve", a, &rhs),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("full_piv_lu_solve", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match (a, &rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => {
                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
                                .map(Tensor::F32)
                        }
                        (Tensor::F64(a), Tensor::F64(b)) => {
                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
                                .map(Tensor::F64)
                        }
                        (Tensor::C32(a), Tensor::C32(b)) => {
                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
                                .map(Tensor::C32)
                        }
                        (Tensor::C64(a), Tensor::C64(b)) => {
                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
                                .map(Tensor::C64)
                        }
                        _ => unsupported_pair("full_piv_lu_solve", a, &rhs),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("full_piv_lu_solve", self.kind()))
                }
            }
        }?;

        // Undo the vector -> matrix reshape performed above.
        if let Some(shape) = restore_shape {
            self.reshape(&result, &shape)
        } else {
            Ok(result)
        }
    }

    /// Singular value decomposition.
    ///
    /// Real outputs keep the input dtype. Complex outputs go through
    /// `svd_c32_outputs_to_public_tensors` / `svd_c64_outputs_to_public_tensors`. These expect
    /// exactly three kernel outputs and return the middle one (the singular values) as a
    /// real tensor holding the real parts.
    fn svd(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("svd", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
                            .and_then(svd_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
                            .and_then(svd_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("svd", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("svd", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::svd(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::blas::svd(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::blas::svd(buffers, t)
                            .and_then(svd_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::blas::svd(buffers, t)
                            .and_then(svd_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("svd", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("svd", self.kind()))
                }
            }
        }
    }

    /// Singular values only.
    ///
    /// The result is always real: F32 for F32/C32 inputs and F64 for F64/C64 inputs.
    fn svd_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("svd_values", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => {
                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
                        }
                        Tensor::F64(t) => {
                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
                        }
                        Tensor::C32(t) => {
                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
                        }
                        Tensor::C64(t) => {
                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
                        }
                        _ => Err(unsupported_dtype("svd_values", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("svd_values", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F32),
                        Tensor::F64(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F64),
                        Tensor::C32(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F32),
                        Tensor::C64(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F64),
                        _ => Err(unsupported_dtype("svd_values", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("svd_values", self.kind()))
                }
            }
        }
    }

    /// SVD of a borrowed view.
    ///
    /// The view is first materialized into an owned tensor via `to_contiguous`, then passed
    /// to `svd`, which performs the host check and the provider dispatch. Integer and bool
    /// views are rejected with the op name `"svd"`.
    fn svd_read(&mut self, input: TensorView<'_>) -> tenferro_tensor::Result<Vec<Tensor>> {
        match input {
            TensorView::F32(view) => {
                let compact = self.to_contiguous(&view)?;
                let input = Tensor::F32(compact);
                self.svd(&input)
            }
            TensorView::F64(view) => {
                let compact = self.to_contiguous(&view)?;
                let input = Tensor::F64(compact);
                self.svd(&input)
            }
            TensorView::C32(view) => {
                let compact = self.to_contiguous(&view)?;
                let input = Tensor::C32(compact);
                self.svd(&input)
            }
            TensorView::C64(view) => {
                let compact = self.to_contiguous(&view)?;
                let input = Tensor::C64(compact);
                self.svd(&input)
            }
            TensorView::I32(_) | TensorView::I64(_) | TensorView::Bool(_) => {
                Err(unsupported_dtype("svd", input.dtype()))
            }
        }
    }

    /// QR decomposition; every output the kernel returns is re-wrapped in the input's dtype.
    fn qr(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("qr", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
                        Tensor::C64(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
                        _ => Err(unsupported_dtype("qr", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("qr", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::qr(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::blas::qr(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::blas::qr(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
                        Tensor::C64(t) => linalg::blas::qr(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
                        _ => Err(unsupported_dtype("qr", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("qr", self.kind()))
                }
            }
        }
    }

    /// Eigendecomposition via the `eigh` kernels.
    ///
    /// Real outputs keep the input dtype. Complex outputs are post-processed by
    /// `eigh_c32_outputs_to_public_tensors` / `eigh_c64_outputs_to_public_tensors`.
    /// NOTE(review): those helpers are not visible here; presumably they convert the
    /// eigenvalues to a real tensor, as the SVD helpers do. Confirm.
    fn eigh(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("eigh", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .and_then(eigh_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .and_then(eigh_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("eigh", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("eigh", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::eigh(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::blas::eigh(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        Tensor::C32(t) => linalg::blas::eigh(buffers, t)
                            .and_then(eigh_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::blas::eigh(buffers, t)
                            .and_then(eigh_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("eigh", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("eigh", self.kind()))
                }
            }
        }
    }

    /// Eigenvalues only via the `eigh_values` kernels.
    ///
    /// The result is always real: F32 for F32/C32 inputs and F64 for F64/C64 inputs.
    fn eigh_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("eigh_values", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
                        }
                        Tensor::F64(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
                        }
                        Tensor::C32(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
                        }
                        Tensor::C64(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
                        }
                        _ => Err(unsupported_dtype("eigh_values", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("eigh_values", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F32),
                        Tensor::F64(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F64),
                        Tensor::C32(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F32),
                        Tensor::C64(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F64),
                        _ => Err(unsupported_dtype("eigh_values", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("eigh_values", self.kind()))
                }
            }
        }
    }

    /// General eigendecomposition.
    ///
    /// Unlike the other ops, the kernel receives the untyped `Tensor` and builds the
    /// output tensors itself. The dtype whitelist is therefore enforced up front here.
    fn eig(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("eig", input)?;
        if !matches!(
            input,
            Tensor::F32(_) | Tensor::F64(_) | Tensor::C32(_) | Tensor::C64(_)
        ) {
            return Err(unsupported_dtype("eig", input.dtype()));
        }
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| linalg::faer::eig(ctx.as_ref(), buffers, input))
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("eig", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| linalg::blas::eig(buffers, input))
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("eig", self.kind()))
                }
            }
        }
    }

    /// General eigenvalues only.
    ///
    /// As with `eig`, the kernel takes the untyped `Tensor`, so the dtype whitelist is
    /// checked before dispatch.
    fn eig_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("eig_values", input)?;
        if !matches!(
            input,
            Tensor::F32(_) | Tensor::F64(_) | Tensor::C32(_) | Tensor::C64(_)
        ) {
            return Err(unsupported_dtype("eig_values", input.dtype()));
        }
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| {
                        linalg::faer::eig_values(ctx.as_ref(), buffers, input)
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("eig_values", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| linalg::blas::eig_values(buffers, input))
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("eig_values", self.kind()))
                }
            }
        }
    }

    /// Solves a linear system from a previously computed packed LU factorization.
    ///
    /// `a` is only consulted for validation and to detect a vector right-hand side. The
    /// validation covers host placement, dtype agreement with `b` and `packed_lu`, and the
    /// empty-shape short-circuit. The solve itself uses `packed_lu` and `pivots`, and
    /// `pivots` must be an `I32` tensor.
    ///
    /// The solve combines two `triangular_solve` calls with a pivot permutation
    /// (`apply_lu_pivots_cpu`):
    /// - the lower triangle of `packed_lu` is treated as a factor with an implicit unit
    ///   diagonal;
    /// - the upper triangle is treated as a non-unit factor.
    ///
    /// With `conjugate_a`, the packed factors are conjugated first, so enabling it together
    /// with `transpose_a` gives a conjugate-transpose solve.
    fn lu_solve_prepared(
        &mut self,
        a: &Tensor,
        packed_lu: &Tensor,
        pivots: &Tensor,
        b: &Tensor,
        transpose_a: bool,
        conjugate_a: bool,
    ) -> tenferro_tensor::Result<Tensor> {
        const OP: &str = "lu_solve_prepared";

        ensure_host_tensor(OP, a)?;
        ensure_host_tensor(OP, packed_lu)?;
        ensure_host_tensor(OP, pivots)?;
        ensure_host_tensor(OP, b)?;
        ensure_supported_linalg_pair(OP, a, b)?;
        ensure_supported_linalg_pair(OP, a, packed_lu)?;
        // Pivot indices must be 32-bit integers; report the expected dtype as `lhs`.
        if !matches!(pivots, Tensor::I32(_)) {
            return Err(Error::DTypeMismatch {
                op: OP,
                lhs: DType::I32,
                rhs: pivots.dtype(),
            });
        }
        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
            return zeros_like_tensor(b);
        }

        // Vector right-hand sides are solved as one-column matrices.
        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
            (
                self.reshape(b, &matrix_rhs_shape)?,
                Some(b.shape().to_vec()),
            )
        } else {
            (b.clone(), None)
        };

        validate_lu_solve_prepared_shapes(packed_lu.shape(), pivots.shape(), rhs.shape())?;
        // Reject factorizations whose U factor cannot be back-substituted.
        validate_nonsingular_u(packed_lu)?;
        let lu_op = if conjugate_a {
            self.conj(packed_lu)?
        } else {
            packed_lu.clone()
        };
        // triangular_solve flag order: (left_side, lower, transpose_a, unit_diagonal).
        let result = if transpose_a {
            // Transposed system: solve with U^T (upper, transposed, non-unit), then with
            // L^T (lower, transposed, unit diagonal), then undo the row pivoting.
            // NOTE(review): the `true` flag to `apply_lu_pivots_cpu` presumably selects the
            // inverse permutation; the helper is not visible here.
            let z = self.triangular_solve(&lu_op, &rhs, true, false, true, false)?;
            let y = self.triangular_solve(&lu_op, &z, true, true, true, true)?;
            apply_lu_pivots_cpu(&y, pivots, true)?
        } else {
            // Plain system: permute the rhs by the pivots, forward-substitute with the
            // unit-lower factor, then back-substitute with the upper factor.
            let pb = apply_lu_pivots_cpu(&rhs, pivots, false)?;
            let y = self.triangular_solve(&lu_op, &pb, true, true, false, true)?;
            self.triangular_solve(&lu_op, &y, true, false, false, false)?
        };

        if let Some(shape) = restore_shape {
            self.reshape(&result, &shape)
        } else {
            Ok(result)
        }
    }

    /// Solves `a x = b` for float/complex tensors of matching dtype.
    ///
    /// Empty inputs and vector right-hand sides are handled exactly as in
    /// `full_piv_lu_solve`. The trailing `false` passed to the kernels is presumably a
    /// "no transpose" flag. NOTE(review): confirm against `linalg::*::solve`.
    fn solve(&mut self, a: &Tensor, b: &Tensor) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("solve", a)?;
        ensure_host_tensor("solve", b)?;
        ensure_supported_linalg_pair("solve", a, b)?;
        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
            return zeros_like_tensor(b);
        }

        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
            (
                self.reshape(b, &matrix_rhs_shape)?,
                Some(b.shape().to_vec()),
            )
        } else {
            (b.clone(), None)
        };

        let result = match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match (a, &rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::F32)
                        }
                        (Tensor::F64(a), Tensor::F64(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::F64)
                        }
                        (Tensor::C32(a), Tensor::C32(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::C32)
                        }
                        (Tensor::C64(a), Tensor::C64(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::C64)
                        }
                        _ => unsupported_pair("solve", a, &rhs),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("solve", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match (a, &rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::F32)
                        }
                        (Tensor::F64(a), Tensor::F64(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::F64)
                        }
                        (Tensor::C32(a), Tensor::C32(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::C32)
                        }
                        (Tensor::C64(a), Tensor::C64(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::C64)
                        }
                        _ => unsupported_pair("solve", a, &rhs),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("solve", self.kind()))
                }
            }
        }?;

        if let Some(shape) = restore_shape {
            self.reshape(&result, &shape)
        } else {
            Ok(result)
        }
    }
}
922
923fn ensure_host_tensor(op: &'static str, input: &Tensor) -> tenferro_tensor::Result<()> {
924 match input {
925 Tensor::F32(t) => ensure_host_typed_tensor(op, t),
926 Tensor::F64(t) => ensure_host_typed_tensor(op, t),
927 Tensor::I32(t) => ensure_host_typed_tensor(op, t),
928 Tensor::I64(t) => ensure_host_typed_tensor(op, t),
929 Tensor::Bool(t) => ensure_host_typed_tensor(op, t),
930 Tensor::C32(t) => ensure_host_typed_tensor(op, t),
931 Tensor::C64(t) => ensure_host_typed_tensor(op, t),
932 }
933}
934
935fn ensure_host_typed_tensor<T: 'static>(
936 op: &'static str,
937 input: &TypedTensor<T>,
938) -> tenferro_tensor::Result<()> {
939 if input.as_view().backend_buffer().is_some() {
940 return Err(Error::backend_failure(
941 op,
942 "CPU linalg backend received a backend buffer; download the tensor to host before CPU execution",
943 ));
944 }
945 Ok(())
946}
947
948fn ensure_supported_linalg_pair(
949 op: &'static str,
950 lhs: &Tensor,
951 rhs: &Tensor,
952) -> tenferro_tensor::Result<()> {
953 if lhs.dtype() != rhs.dtype() {
954 return Err(Error::DTypeMismatch {
955 op,
956 lhs: lhs.dtype(),
957 rhs: rhs.dtype(),
958 });
959 }
960 match lhs {
961 Tensor::F32(_) | Tensor::F64(_) | Tensor::C32(_) | Tensor::C64(_) => Ok(()),
962 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => {
963 Err(unsupported_dtype(op, lhs.dtype()))
964 }
965 }
966}
967
/// Returns `true` when any extent in `shape` is zero, i.e. the tensor holds no elements.
fn has_zero_dim(shape: &[usize]) -> bool {
    shape.iter().any(|&extent| extent == 0)
}
971
/// Number of batches described by `batch_shape`, never less than 1.
///
/// An empty shape counts as a single batch. A shape containing a zero extent also reports
/// 1, because a zero product is clamped up to 1.
fn batch_count(batch_shape: &[usize]) -> usize {
    let total: usize = batch_shape.iter().product();
    if total == 0 {
        1
    } else {
        total
    }
}
975
976fn batched_vector_rhs_shape(a: &Tensor, b: &Tensor) -> Option<Vec<usize>> {
977 if b.shape().len() == 1 {
978 return Some(vec![b.shape()[0], 1]);
979 }
980
981 let is_batched_vector_rhs = a.shape().len() == b.shape().len() + 1
982 && !b.shape().is_empty()
983 && b.shape()[0] == a.shape()[0]
984 && b.shape()[1..] == a.shape()[2..];
985 if !is_batched_vector_rhs {
986 return None;
987 }
988
989 let mut rhs_shape = vec![b.shape()[0], 1];
990 rhs_shape.extend_from_slice(&b.shape()[1..]);
991 Some(rhs_shape)
992}
993
994fn zeros_like_tensor(input: &Tensor) -> tenferro_tensor::Result<Tensor> {
995 Ok(match input {
996 Tensor::F32(t) => Tensor::F32(TypedTensor::zeros(t.shape().to_vec())?),
997 Tensor::F64(t) => Tensor::F64(TypedTensor::zeros(t.shape().to_vec())?),
998 Tensor::I32(t) => Tensor::I32(TypedTensor::zeros(t.shape().to_vec())?),
999 Tensor::I64(t) => Tensor::I64(TypedTensor::zeros(t.shape().to_vec())?),
1000 Tensor::Bool(t) => Tensor::Bool(TypedTensor::from_vec_col_major(
1001 t.shape().to_vec(),
1002 vec![false; t.n_elements()],
1003 )?),
1004 Tensor::C32(t) => Tensor::C32(TypedTensor::zeros(t.shape().to_vec())?),
1005 Tensor::C64(t) => Tensor::C64(TypedTensor::zeros(t.shape().to_vec())?),
1006 })
1007}
1008
1009fn complex32_real_part_tensor(
1010 values: TypedTensor<Complex32>,
1011) -> tenferro_tensor::Result<TypedTensor<f32>> {
1012 let mut out = TypedTensor::from_vec_col_major(
1013 values.shape().to_vec(),
1014 values.host_data()?.iter().map(|value| value.re).collect(),
1015 )?;
1016 out.set_placement(values.placement().clone());
1017 Ok(out)
1018}
1019
1020fn complex64_real_part_tensor(
1021 values: TypedTensor<Complex64>,
1022) -> tenferro_tensor::Result<TypedTensor<f64>> {
1023 let mut out = TypedTensor::from_vec_col_major(
1024 values.shape().to_vec(),
1025 values.host_data()?.iter().map(|value| value.re).collect(),
1026 )?;
1027 out.set_placement(values.placement().clone());
1028 Ok(out)
1029}
1030
1031fn svd_output_count_error(count: usize) -> Error {
1032 Error::backend_failure("svd", format!("expected 3 outputs, got {count}"))
1033}
1034
1035fn full_piv_lu_output_count_error(count: usize) -> Error {
1036 Error::backend_failure("full_piv_lu", format!("expected 5 outputs, got {count}"))
1037}
1038
1039fn eigh_output_count_error(count: usize) -> Error {
1040 Error::backend_failure("eigh", format!("expected 2 outputs, got {count}"))
1041}
1042
1043fn full_piv_lu_c32_outputs_to_public_tensors(
1044 outputs: Vec<TypedTensor<Complex32>>,
1045) -> tenferro_tensor::Result<Vec<Tensor>> {
1046 let count = outputs.len();
1047 let mut outputs = outputs.into_iter();
1048 match (
1049 outputs.next(),
1050 outputs.next(),
1051 outputs.next(),
1052 outputs.next(),
1053 outputs.next(),
1054 outputs.next(),
1055 ) {
1056 (Some(p), Some(l), Some(u), Some(q), Some(parity), None) => Ok(vec![
1057 Tensor::C32(p),
1058 Tensor::C32(l),
1059 Tensor::C32(u),
1060 Tensor::C32(q),
1061 Tensor::F32(complex32_real_part_tensor(parity)?),
1062 ]),
1063 _ => Err(full_piv_lu_output_count_error(count)),
1064 }
1065}
1066
1067fn full_piv_lu_c64_outputs_to_public_tensors(
1068 outputs: Vec<TypedTensor<Complex64>>,
1069) -> tenferro_tensor::Result<Vec<Tensor>> {
1070 let count = outputs.len();
1071 let mut outputs = outputs.into_iter();
1072 match (
1073 outputs.next(),
1074 outputs.next(),
1075 outputs.next(),
1076 outputs.next(),
1077 outputs.next(),
1078 outputs.next(),
1079 ) {
1080 (Some(p), Some(l), Some(u), Some(q), Some(parity), None) => Ok(vec![
1081 Tensor::C64(p),
1082 Tensor::C64(l),
1083 Tensor::C64(u),
1084 Tensor::C64(q),
1085 Tensor::F64(complex64_real_part_tensor(parity)?),
1086 ]),
1087 _ => Err(full_piv_lu_output_count_error(count)),
1088 }
1089}
1090
1091fn svd_c32_outputs_to_public_tensors(
1092 outputs: Vec<TypedTensor<Complex32>>,
1093) -> tenferro_tensor::Result<Vec<Tensor>> {
1094 let count = outputs.len();
1095 let mut outputs = outputs.into_iter();
1096 match (
1097 outputs.next(),
1098 outputs.next(),
1099 outputs.next(),
1100 outputs.next(),
1101 ) {
1102 (Some(u), Some(values), Some(vt), None) => Ok(vec![
1103 Tensor::C32(u),
1104 Tensor::F32(complex32_real_part_tensor(values)?),
1105 Tensor::C32(vt),
1106 ]),
1107 _ => Err(svd_output_count_error(count)),
1108 }
1109}
1110
1111fn svd_c64_outputs_to_public_tensors(
1112 outputs: Vec<TypedTensor<Complex64>>,
1113) -> tenferro_tensor::Result<Vec<Tensor>> {
1114 let count = outputs.len();
1115 let mut outputs = outputs.into_iter();
1116 match (
1117 outputs.next(),
1118 outputs.next(),
1119 outputs.next(),
1120 outputs.next(),
1121 ) {
1122 (Some(u), Some(values), Some(vt), None) => Ok(vec![
1123 Tensor::C64(u),
1124 Tensor::F64(complex64_real_part_tensor(values)?),
1125 Tensor::C64(vt),
1126 ]),
1127 _ => Err(svd_output_count_error(count)),
1128 }
1129}
1130
1131fn eigh_c32_outputs_to_public_tensors(
1132 outputs: Vec<TypedTensor<Complex32>>,
1133) -> tenferro_tensor::Result<Vec<Tensor>> {
1134 let count = outputs.len();
1135 let mut outputs = outputs.into_iter();
1136 match (outputs.next(), outputs.next(), outputs.next()) {
1137 (Some(values), Some(vectors), None) => Ok(vec![
1138 Tensor::F32(complex32_real_part_tensor(values)?),
1139 Tensor::C32(vectors),
1140 ]),
1141 _ => Err(eigh_output_count_error(count)),
1142 }
1143}
1144
1145fn eigh_c64_outputs_to_public_tensors(
1146 outputs: Vec<TypedTensor<Complex64>>,
1147) -> tenferro_tensor::Result<Vec<Tensor>> {
1148 let count = outputs.len();
1149 let mut outputs = outputs.into_iter();
1150 match (outputs.next(), outputs.next(), outputs.next()) {
1151 (Some(values), Some(vectors), None) => Ok(vec![
1152 Tensor::F64(complex64_real_part_tensor(values)?),
1153 Tensor::C64(vectors),
1154 ]),
1155 _ => Err(eigh_output_count_error(count)),
1156 }
1157}
1158
1159fn apply_lu_pivots_cpu(
1160 input: &Tensor,
1161 pivots: &Tensor,
1162 inverse: bool,
1163) -> tenferro_tensor::Result<Tensor> {
1164 let Tensor::I32(pivots) = pivots else {
1165 return Err(Error::DTypeMismatch {
1166 op: "lu_solve_prepared",
1167 lhs: DType::I32,
1168 rhs: pivots.dtype(),
1169 });
1170 };
1171 match input {
1172 Tensor::F32(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::F32),
1173 Tensor::F64(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::F64),
1174 Tensor::C32(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::C32),
1175 Tensor::C64(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::C64),
1176 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => {
1177 Err(unsupported_dtype("lu_solve_prepared", input.dtype()))
1178 }
1179 }
1180}
1181
1182fn apply_lu_pivots_typed<T: Clone>(
1183 input: &TypedTensor<T>,
1184 pivots: &TypedTensor<i32>,
1185 inverse: bool,
1186) -> tenferro_tensor::Result<TypedTensor<T>> {
1187 let shape = input.shape();
1188 if shape.len() < 2 {
1189 return Err(Error::RankMismatch {
1190 op: "lu_solve_prepared",
1191 expected: 2,
1192 actual: shape.len(),
1193 });
1194 }
1195 let rows = shape[0];
1196 let cols = shape[1];
1197 let k = pivots.shape()[0];
1198 if k > rows || pivots.shape()[1..] != shape[2..] {
1199 return Err(Error::ShapeMismatch {
1200 op: "lu_solve_prepared",
1201 lhs: pivots.shape().to_vec(),
1202 rhs: shape.to_vec(),
1203 });
1204 }
1205 let batch_total = batch_count(&shape[2..]);
1206 let matrix_stride = rows * cols;
1207 let pivot_stride = k;
1208 let input_data = input.host_data()?;
1209 let pivot_data = pivots.host_data()?;
1210 let mut data = Vec::with_capacity(input_data.len());
1211
1212 for batch in 0..batch_total {
1213 let mut perm: Vec<usize> = (0..rows).collect();
1214 let pivot_offset = batch * pivot_stride;
1215 for step in 0..k {
1216 let pivot_one_based = pivot_data[pivot_offset + step];
1217 if pivot_one_based <= 0 {
1218 return Err(Error::backend_failure(
1219 "lu_solve_prepared",
1220 "LU pivot index must be 1-based and positive",
1221 ));
1222 }
1223 let pivot = usize::try_from(pivot_one_based - 1).map_err(|_| {
1224 Error::backend_failure("lu_solve_prepared", "LU pivot index is invalid")
1225 })?;
1226 if pivot >= rows {
1227 return Err(Error::backend_failure(
1228 "lu_solve_prepared",
1229 "LU pivot index is out of bounds",
1230 ));
1231 }
1232 perm.swap(step, pivot);
1233 }
1234 let row_map = if inverse {
1235 let mut inv = vec![0usize; rows];
1236 for (row, &source) in perm.iter().enumerate() {
1237 inv[source] = row;
1238 }
1239 inv
1240 } else {
1241 perm
1242 };
1243 let batch_offset = batch * matrix_stride;
1244 for col in 0..cols {
1245 for &source_row in &row_map {
1246 data.push(input_data[batch_offset + source_row + col * rows].clone());
1247 }
1248 }
1249 }
1250
1251 TypedTensor::from_vec_col_major(shape.to_vec(), data)
1252}
1253
1254fn validate_lu_solve_prepared_shapes(
1255 lu_shape: &[usize],
1256 pivots_shape: &[usize],
1257 b_shape: &[usize],
1258) -> tenferro_tensor::Result<()> {
1259 let n = square_matrix_dim("lu_solve_prepared", lu_shape)?;
1260 let (b_rows, _) = matrix_dims("lu_solve_prepared", b_shape)?;
1261 if b_rows != n {
1262 return Err(Error::InvalidConfig {
1263 op: "lu_solve_prepared",
1264 message: format!("rhs row count mismatch: expected {n}, got {b_rows}"),
1265 });
1266 }
1267 if lu_shape[2..] != b_shape[2..] {
1268 return Err(Error::ShapeMismatch {
1269 op: "lu_solve_prepared",
1270 lhs: lu_shape.to_vec(),
1271 rhs: b_shape.to_vec(),
1272 });
1273 }
1274 let mut expected_pivots = vec![n];
1275 expected_pivots.extend_from_slice(&lu_shape[2..]);
1276 if pivots_shape != expected_pivots {
1277 return Err(Error::ShapeMismatch {
1278 op: "lu_solve_prepared",
1279 lhs: expected_pivots,
1280 rhs: pivots_shape.to_vec(),
1281 });
1282 }
1283 Ok(())
1284}
1285
1286fn matrix_dims(op: &'static str, shape: &[usize]) -> tenferro_tensor::Result<(usize, usize)> {
1287 if shape.len() < 2 {
1288 return Err(Error::RankMismatch {
1289 op,
1290 expected: 2,
1291 actual: shape.len(),
1292 });
1293 }
1294 Ok((shape[0], shape[1]))
1295}
1296
1297fn square_matrix_dim(op: &'static str, shape: &[usize]) -> tenferro_tensor::Result<usize> {
1298 let (rows, cols) = matrix_dims(op, shape)?;
1299 if rows != cols {
1300 return Err(Error::ShapeMismatch {
1301 op,
1302 lhs: vec![rows],
1303 rhs: vec![cols],
1304 });
1305 }
1306 Ok(rows)
1307}
1308
1309#[allow(dead_code)]
1312fn unsupported_provider(op: &'static str, kind: CpuBackendKind) -> Error {
1313 Error::InvalidConfig {
1314 op,
1315 message: format!("CPU linalg provider {kind:?} is not compiled in"),
1316 }
1317}
1318
1319fn unsupported_pair(
1320 op: &'static str,
1321 lhs: &Tensor,
1322 rhs: &Tensor,
1323) -> tenferro_tensor::Result<Tensor> {
1324 if lhs.dtype() != rhs.dtype() {
1325 Err(Error::DTypeMismatch {
1326 op,
1327 lhs: lhs.dtype(),
1328 rhs: rhs.dtype(),
1329 })
1330 } else {
1331 Err(unsupported_dtype(op, lhs.dtype()))
1332 }
1333}
1334
1335fn unsupported_dtype(op: &'static str, dtype: DType) -> Error {
1336 Error::backend_failure(op, format!("unsupported dtype {dtype:?}"))
1337}