Skip to main content

tenferro_linalg/cpu/
backend.rs

1use crate::backend::LinalgBackend;
2
3use super::linalg;
4
5use num_complex::{Complex32, Complex64};
6use tenferro_cpu::{CpuBackend, CpuBackendKind};
7use tenferro_tensor::{
8    validate::validate_nonsingular_u, DType, Error, Tensor, TensorElementwise, TensorStructural,
9    TensorView, TensorViewCanonicalization, TypedTensor,
10};
11
12impl LinalgBackend for CpuBackend {
13    fn cholesky(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
14        ensure_host_tensor("cholesky", input)?;
15        match self.kind() {
16            CpuBackendKind::Faer => {
17                #[cfg(feature = "cpu-faer")]
18                {
19                    let ctx = self.linalg_context();
20                    self.with_linalg_pool(|buffers| match input {
21                        Tensor::F32(t) => {
22                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::F32)
23                        }
24                        Tensor::F64(t) => {
25                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::F64)
26                        }
27                        Tensor::C32(t) => {
28                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::C32)
29                        }
30                        Tensor::C64(t) => {
31                            linalg::faer::cholesky(ctx.as_ref(), buffers, t).map(Tensor::C64)
32                        }
33                        _ => Err(unsupported_dtype("cholesky", input.dtype())),
34                    })
35                }
36                #[cfg(not(feature = "cpu-faer"))]
37                {
38                    Err(unsupported_provider("cholesky", self.kind()))
39                }
40            }
41            CpuBackendKind::Blas => {
42                #[cfg(feature = "cpu-blas")]
43                {
44                    self.with_linalg_pool(|buffers| match input {
45                        Tensor::F32(t) => linalg::blas::cholesky(buffers, t).map(Tensor::F32),
46                        Tensor::F64(t) => linalg::blas::cholesky(buffers, t).map(Tensor::F64),
47                        Tensor::C32(t) => linalg::blas::cholesky(buffers, t).map(Tensor::C32),
48                        Tensor::C64(t) => linalg::blas::cholesky(buffers, t).map(Tensor::C64),
49                        _ => Err(unsupported_dtype("cholesky", input.dtype())),
50                    })
51                }
52                #[cfg(not(feature = "cpu-blas"))]
53                {
54                    Err(unsupported_provider("cholesky", self.kind()))
55                }
56            }
57        }
58    }
59
60    fn triangular_solve(
61        &mut self,
62        a: &Tensor,
63        b: &Tensor,
64        left_side: bool,
65        lower: bool,
66        transpose_a: bool,
67        unit_diagonal: bool,
68    ) -> tenferro_tensor::Result<Tensor> {
69        ensure_host_tensor("triangular_solve", a)?;
70        ensure_host_tensor("triangular_solve", b)?;
71        match self.kind() {
72            CpuBackendKind::Faer => {
73                #[cfg(feature = "cpu-faer")]
74                {
75                    let ctx = self.linalg_context();
76                    self.with_linalg_pool(|buffers| match (a, b) {
77                        (Tensor::F32(a), Tensor::F32(b)) => linalg::faer::triangular_solve(
78                            ctx.as_ref(),
79                            buffers,
80                            a,
81                            b,
82                            left_side,
83                            lower,
84                            transpose_a,
85                            unit_diagonal,
86                        )
87                        .map(Tensor::F32),
88                        (Tensor::F64(a), Tensor::F64(b)) => linalg::faer::triangular_solve(
89                            ctx.as_ref(),
90                            buffers,
91                            a,
92                            b,
93                            left_side,
94                            lower,
95                            transpose_a,
96                            unit_diagonal,
97                        )
98                        .map(Tensor::F64),
99                        (Tensor::C32(a), Tensor::C32(b)) => linalg::faer::triangular_solve(
100                            ctx.as_ref(),
101                            buffers,
102                            a,
103                            b,
104                            left_side,
105                            lower,
106                            transpose_a,
107                            unit_diagonal,
108                        )
109                        .map(Tensor::C32),
110                        (Tensor::C64(a), Tensor::C64(b)) => linalg::faer::triangular_solve(
111                            ctx.as_ref(),
112                            buffers,
113                            a,
114                            b,
115                            left_side,
116                            lower,
117                            transpose_a,
118                            unit_diagonal,
119                        )
120                        .map(Tensor::C64),
121                        _ => unsupported_pair("triangular_solve", a, b),
122                    })
123                }
124                #[cfg(not(feature = "cpu-faer"))]
125                {
126                    Err(unsupported_provider("triangular_solve", self.kind()))
127                }
128            }
129            CpuBackendKind::Blas => {
130                #[cfg(feature = "cpu-blas")]
131                {
132                    self.with_linalg_pool(|buffers| match (a, b) {
133                        (Tensor::F32(a), Tensor::F32(b)) => linalg::blas::triangular_solve(
134                            buffers,
135                            a,
136                            b,
137                            left_side,
138                            lower,
139                            transpose_a,
140                            unit_diagonal,
141                        )
142                        .map(Tensor::F32),
143                        (Tensor::F64(a), Tensor::F64(b)) => linalg::blas::triangular_solve(
144                            buffers,
145                            a,
146                            b,
147                            left_side,
148                            lower,
149                            transpose_a,
150                            unit_diagonal,
151                        )
152                        .map(Tensor::F64),
153                        (Tensor::C32(a), Tensor::C32(b)) => linalg::blas::triangular_solve(
154                            buffers,
155                            a,
156                            b,
157                            left_side,
158                            lower,
159                            transpose_a,
160                            unit_diagonal,
161                        )
162                        .map(Tensor::C32),
163                        (Tensor::C64(a), Tensor::C64(b)) => linalg::blas::triangular_solve(
164                            buffers,
165                            a,
166                            b,
167                            left_side,
168                            lower,
169                            transpose_a,
170                            unit_diagonal,
171                        )
172                        .map(Tensor::C64),
173                        _ => unsupported_pair("triangular_solve", a, b),
174                    })
175                }
176                #[cfg(not(feature = "cpu-blas"))]
177                {
178                    Err(unsupported_provider("triangular_solve", self.kind()))
179                }
180            }
181        }
182    }
183
184    fn lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
185        ensure_host_tensor("lu", input)?;
186        match self.kind() {
187            CpuBackendKind::Faer => {
188                #[cfg(feature = "cpu-faer")]
189                {
190                    let ctx = self.linalg_context();
191                    self.with_linalg_pool(|buffers| match input {
192                        Tensor::F32(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
193                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
194                        Tensor::F64(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
195                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
196                        Tensor::C32(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
197                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
198                        Tensor::C64(t) => linalg::faer::lu(ctx.as_ref(), buffers, t)
199                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
200                        _ => Err(unsupported_dtype("lu", input.dtype())),
201                    })
202                }
203                #[cfg(not(feature = "cpu-faer"))]
204                {
205                    Err(unsupported_provider("lu", self.kind()))
206                }
207            }
208            CpuBackendKind::Blas => {
209                #[cfg(feature = "cpu-blas")]
210                {
211                    self.with_linalg_pool(|buffers| match input {
212                        Tensor::F32(t) => linalg::blas::lu(buffers, t)
213                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
214                        Tensor::F64(t) => linalg::blas::lu(buffers, t)
215                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
216                        Tensor::C32(t) => linalg::blas::lu(buffers, t)
217                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
218                        Tensor::C64(t) => linalg::blas::lu(buffers, t)
219                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
220                        _ => Err(unsupported_dtype("lu", input.dtype())),
221                    })
222                }
223                #[cfg(not(feature = "cpu-blas"))]
224                {
225                    Err(unsupported_provider("lu", self.kind()))
226                }
227            }
228        }
229    }
230
231    fn lu_factor(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
232        ensure_host_tensor("lu_factor", input)?;
233        match self.kind() {
234            CpuBackendKind::Faer => {
235                #[cfg(feature = "cpu-faer")]
236                {
237                    let ctx = self.linalg_context();
238                    self.with_linalg_pool(|buffers| match input {
239                        Tensor::F32(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
240                            |(lu, pivots, parity)| {
241                                vec![Tensor::F32(lu), Tensor::I32(pivots), Tensor::F32(parity)]
242                            },
243                        ),
244                        Tensor::F64(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
245                            |(lu, pivots, parity)| {
246                                vec![Tensor::F64(lu), Tensor::I32(pivots), Tensor::F64(parity)]
247                            },
248                        ),
249                        Tensor::C32(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
250                            |(lu, pivots, parity)| {
251                                vec![Tensor::C32(lu), Tensor::I32(pivots), Tensor::C32(parity)]
252                            },
253                        ),
254                        Tensor::C64(t) => linalg::faer::lu_factor(ctx.as_ref(), buffers, t).map(
255                            |(lu, pivots, parity)| {
256                                vec![Tensor::C64(lu), Tensor::I32(pivots), Tensor::C64(parity)]
257                            },
258                        ),
259                        _ => Err(unsupported_dtype("lu_factor", input.dtype())),
260                    })
261                }
262                #[cfg(not(feature = "cpu-faer"))]
263                {
264                    Err(unsupported_provider("lu_factor", self.kind()))
265                }
266            }
267            CpuBackendKind::Blas => {
268                #[cfg(feature = "cpu-blas")]
269                {
270                    self.with_linalg_pool(|buffers| match input {
271                        Tensor::F32(t) => {
272                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
273                                vec![Tensor::F32(lu), Tensor::I32(pivots), Tensor::F32(parity)]
274                            })
275                        }
276                        Tensor::F64(t) => {
277                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
278                                vec![Tensor::F64(lu), Tensor::I32(pivots), Tensor::F64(parity)]
279                            })
280                        }
281                        Tensor::C32(t) => {
282                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
283                                vec![Tensor::C32(lu), Tensor::I32(pivots), Tensor::C32(parity)]
284                            })
285                        }
286                        Tensor::C64(t) => {
287                            linalg::blas::lu_factor(buffers, t).map(|(lu, pivots, parity)| {
288                                vec![Tensor::C64(lu), Tensor::I32(pivots), Tensor::C64(parity)]
289                            })
290                        }
291                        _ => Err(unsupported_dtype("lu_factor", input.dtype())),
292                    })
293                }
294                #[cfg(not(feature = "cpu-blas"))]
295                {
296                    Err(unsupported_provider("lu_factor", self.kind()))
297                }
298            }
299        }
300    }
301
302    fn full_piv_lu(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
303        ensure_host_tensor("full_piv_lu", input)?;
304        match self.kind() {
305            CpuBackendKind::Faer => {
306                #[cfg(feature = "cpu-faer")]
307                {
308                    let ctx = self.linalg_context();
309                    self.with_linalg_pool(|buffers| match input {
310                        Tensor::F32(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
311                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
312                        Tensor::F64(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
313                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
314                        Tensor::C32(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
315                            .and_then(full_piv_lu_c32_outputs_to_public_tensors),
316                        Tensor::C64(t) => linalg::faer::full_piv_lu(ctx.as_ref(), buffers, t)
317                            .and_then(full_piv_lu_c64_outputs_to_public_tensors),
318                        _ => Err(unsupported_dtype("full_piv_lu", input.dtype())),
319                    })
320                }
321                #[cfg(not(feature = "cpu-faer"))]
322                {
323                    Err(unsupported_provider("full_piv_lu", self.kind()))
324                }
325            }
326            CpuBackendKind::Blas => {
327                #[cfg(feature = "cpu-blas")]
328                {
329                    self.with_linalg_pool(|buffers| match input {
330                        Tensor::F32(t) => linalg::blas::full_piv_lu(buffers, t)
331                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
332                        Tensor::F64(t) => linalg::blas::full_piv_lu(buffers, t)
333                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
334                        Tensor::C32(t) => linalg::blas::full_piv_lu(buffers, t)
335                            .and_then(full_piv_lu_c32_outputs_to_public_tensors),
336                        Tensor::C64(t) => linalg::blas::full_piv_lu(buffers, t)
337                            .and_then(full_piv_lu_c64_outputs_to_public_tensors),
338                        _ => Err(unsupported_dtype("full_piv_lu", input.dtype())),
339                    })
340                }
341                #[cfg(not(feature = "cpu-blas"))]
342                {
343                    Err(unsupported_provider("full_piv_lu", self.kind()))
344                }
345            }
346        }
347    }
348
349    fn full_piv_lu_solve(
350        &mut self,
351        a: &Tensor,
352        b: &Tensor,
353        transpose_a: bool,
354    ) -> tenferro_tensor::Result<Tensor> {
355        ensure_host_tensor("full_piv_lu_solve", a)?;
356        ensure_host_tensor("full_piv_lu_solve", b)?;
357        ensure_supported_linalg_pair("full_piv_lu_solve", a, b)?;
358        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
359            return zeros_like_tensor(b);
360        }
361
362        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
363            (
364                self.reshape(b, &matrix_rhs_shape)?,
365                Some(b.shape().to_vec()),
366            )
367        } else {
368            (b.clone(), None)
369        };
370
371        let result = match self.kind() {
372            CpuBackendKind::Faer => {
373                #[cfg(feature = "cpu-faer")]
374                {
375                    let ctx = self.linalg_context();
376                    self.with_linalg_pool(|buffers| match (a, &rhs) {
377                        (Tensor::F32(a), Tensor::F32(b)) => linalg::faer::full_piv_lu_solve(
378                            ctx.as_ref(),
379                            buffers,
380                            a,
381                            b,
382                            transpose_a,
383                        )
384                        .map(Tensor::F32),
385                        (Tensor::F64(a), Tensor::F64(b)) => linalg::faer::full_piv_lu_solve(
386                            ctx.as_ref(),
387                            buffers,
388                            a,
389                            b,
390                            transpose_a,
391                        )
392                        .map(Tensor::F64),
393                        (Tensor::C32(a), Tensor::C32(b)) => linalg::faer::full_piv_lu_solve(
394                            ctx.as_ref(),
395                            buffers,
396                            a,
397                            b,
398                            transpose_a,
399                        )
400                        .map(Tensor::C32),
401                        (Tensor::C64(a), Tensor::C64(b)) => linalg::faer::full_piv_lu_solve(
402                            ctx.as_ref(),
403                            buffers,
404                            a,
405                            b,
406                            transpose_a,
407                        )
408                        .map(Tensor::C64),
409                        _ => unsupported_pair("full_piv_lu_solve", a, &rhs),
410                    })
411                }
412                #[cfg(not(feature = "cpu-faer"))]
413                {
414                    Err(unsupported_provider("full_piv_lu_solve", self.kind()))
415                }
416            }
417            CpuBackendKind::Blas => {
418                #[cfg(feature = "cpu-blas")]
419                {
420                    self.with_linalg_pool(|buffers| match (a, &rhs) {
421                        (Tensor::F32(a), Tensor::F32(b)) => {
422                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
423                                .map(Tensor::F32)
424                        }
425                        (Tensor::F64(a), Tensor::F64(b)) => {
426                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
427                                .map(Tensor::F64)
428                        }
429                        (Tensor::C32(a), Tensor::C32(b)) => {
430                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
431                                .map(Tensor::C32)
432                        }
433                        (Tensor::C64(a), Tensor::C64(b)) => {
434                            linalg::blas::full_piv_lu_solve(buffers, a, b, transpose_a)
435                                .map(Tensor::C64)
436                        }
437                        _ => unsupported_pair("full_piv_lu_solve", a, &rhs),
438                    })
439                }
440                #[cfg(not(feature = "cpu-blas"))]
441                {
442                    Err(unsupported_provider("full_piv_lu_solve", self.kind()))
443                }
444            }
445        }?;
446
447        if let Some(shape) = restore_shape {
448            self.reshape(&result, &shape)
449        } else {
450            Ok(result)
451        }
452    }
453
454    fn svd(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
455        ensure_host_tensor("svd", input)?;
456        match self.kind() {
457            CpuBackendKind::Faer => {
458                #[cfg(feature = "cpu-faer")]
459                {
460                    let ctx = self.linalg_context();
461                    self.with_linalg_pool(|buffers| match input {
462                        Tensor::F32(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
463                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
464                        Tensor::F64(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
465                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
466                        Tensor::C32(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
467                            .and_then(svd_c32_outputs_to_public_tensors),
468                        Tensor::C64(t) => linalg::faer::svd(ctx.as_ref(), buffers, t)
469                            .and_then(svd_c64_outputs_to_public_tensors),
470                        _ => Err(unsupported_dtype("svd", input.dtype())),
471                    })
472                }
473                #[cfg(not(feature = "cpu-faer"))]
474                {
475                    Err(unsupported_provider("svd", self.kind()))
476                }
477            }
478            CpuBackendKind::Blas => {
479                #[cfg(feature = "cpu-blas")]
480                {
481                    self.with_linalg_pool(|buffers| match input {
482                        Tensor::F32(t) => linalg::blas::svd(buffers, t)
483                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
484                        Tensor::F64(t) => linalg::blas::svd(buffers, t)
485                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
486                        Tensor::C32(t) => linalg::blas::svd(buffers, t)
487                            .and_then(svd_c32_outputs_to_public_tensors),
488                        Tensor::C64(t) => linalg::blas::svd(buffers, t)
489                            .and_then(svd_c64_outputs_to_public_tensors),
490                        _ => Err(unsupported_dtype("svd", input.dtype())),
491                    })
492                }
493                #[cfg(not(feature = "cpu-blas"))]
494                {
495                    Err(unsupported_provider("svd", self.kind()))
496                }
497            }
498        }
499    }
500
501    fn svd_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
502        ensure_host_tensor("svd_values", input)?;
503        match self.kind() {
504            CpuBackendKind::Faer => {
505                #[cfg(feature = "cpu-faer")]
506                {
507                    let ctx = self.linalg_context();
508                    self.with_linalg_pool(|buffers| match input {
509                        Tensor::F32(t) => {
510                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
511                        }
512                        Tensor::F64(t) => {
513                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
514                        }
515                        Tensor::C32(t) => {
516                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
517                        }
518                        Tensor::C64(t) => {
519                            linalg::faer::svd_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
520                        }
521                        _ => Err(unsupported_dtype("svd_values", input.dtype())),
522                    })
523                }
524                #[cfg(not(feature = "cpu-faer"))]
525                {
526                    Err(unsupported_provider("svd_values", self.kind()))
527                }
528            }
529            CpuBackendKind::Blas => {
530                #[cfg(feature = "cpu-blas")]
531                {
532                    self.with_linalg_pool(|buffers| match input {
533                        Tensor::F32(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F32),
534                        Tensor::F64(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F64),
535                        Tensor::C32(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F32),
536                        Tensor::C64(t) => linalg::blas::svd_values(buffers, t).map(Tensor::F64),
537                        _ => Err(unsupported_dtype("svd_values", input.dtype())),
538                    })
539                }
540                #[cfg(not(feature = "cpu-blas"))]
541                {
542                    Err(unsupported_provider("svd_values", self.kind()))
543                }
544            }
545        }
546    }
547
548    fn svd_read(&mut self, input: TensorView<'_>) -> tenferro_tensor::Result<Vec<Tensor>> {
549        match input {
550            TensorView::F32(view) => {
551                let compact = self.to_contiguous(&view)?;
552                let input = Tensor::F32(compact);
553                self.svd(&input)
554            }
555            TensorView::F64(view) => {
556                let compact = self.to_contiguous(&view)?;
557                let input = Tensor::F64(compact);
558                self.svd(&input)
559            }
560            TensorView::C32(view) => {
561                let compact = self.to_contiguous(&view)?;
562                let input = Tensor::C32(compact);
563                self.svd(&input)
564            }
565            TensorView::C64(view) => {
566                let compact = self.to_contiguous(&view)?;
567                let input = Tensor::C64(compact);
568                self.svd(&input)
569            }
570            TensorView::I32(_) | TensorView::I64(_) | TensorView::Bool(_) => {
571                Err(unsupported_dtype("svd", input.dtype()))
572            }
573        }
574    }
575
576    fn qr(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
577        ensure_host_tensor("qr", input)?;
578        match self.kind() {
579            CpuBackendKind::Faer => {
580                #[cfg(feature = "cpu-faer")]
581                {
582                    let ctx = self.linalg_context();
583                    self.with_linalg_pool(|buffers| match input {
584                        Tensor::F32(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
585                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
586                        Tensor::F64(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
587                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
588                        Tensor::C32(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
589                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
590                        Tensor::C64(t) => linalg::faer::qr(ctx.as_ref(), buffers, t)
591                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
592                        _ => Err(unsupported_dtype("qr", input.dtype())),
593                    })
594                }
595                #[cfg(not(feature = "cpu-faer"))]
596                {
597                    Err(unsupported_provider("qr", self.kind()))
598                }
599            }
600            CpuBackendKind::Blas => {
601                #[cfg(feature = "cpu-blas")]
602                {
603                    self.with_linalg_pool(|buffers| match input {
604                        Tensor::F32(t) => linalg::blas::qr(buffers, t)
605                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
606                        Tensor::F64(t) => linalg::blas::qr(buffers, t)
607                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
608                        Tensor::C32(t) => linalg::blas::qr(buffers, t)
609                            .map(|outputs| outputs.into_iter().map(Tensor::C32).collect()),
610                        Tensor::C64(t) => linalg::blas::qr(buffers, t)
611                            .map(|outputs| outputs.into_iter().map(Tensor::C64).collect()),
612                        _ => Err(unsupported_dtype("qr", input.dtype())),
613                    })
614                }
615                #[cfg(not(feature = "cpu-blas"))]
616                {
617                    Err(unsupported_provider("qr", self.kind()))
618                }
619            }
620        }
621    }
622
    /// Hermitian (complex) / symmetric (real) eigendecomposition of a host tensor.
    ///
    /// Dispatches on `self.kind()` to the faer or BLAS provider. A provider whose
    /// cargo feature (`cpu-faer` / `cpu-blas`) is not compiled in yields the
    /// `InvalidConfig` error built by `unsupported_provider`.
    ///
    /// For real inputs every provider output is re-wrapped with the input dtype.
    /// For complex inputs the provider must return exactly `[values, vectors]`;
    /// the eigenvalues are narrowed to their real part (`C32 -> F32`,
    /// `C64 -> F64`) while the eigenvectors keep the complex dtype.
    ///
    /// # Errors
    ///
    /// An input backed by a backend buffer, an integer/bool dtype, a provider that
    /// is not compiled in, a complex output count other than two, or any provider
    /// failure.
    fn eigh(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
        ensure_host_tensor("eigh", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    // NOTE(review): the context is presumably fetched before
                    // `with_linalg_pool` because that call borrows `self`.
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        // Complex eigenvalues are converted to a real tensor here.
                        Tensor::C32(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .and_then(eigh_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::faer::eigh(ctx.as_ref(), buffers, t)
                            .and_then(eigh_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("eigh", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("eigh", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::eigh(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F32).collect()),
                        Tensor::F64(t) => linalg::blas::eigh(buffers, t)
                            .map(|outputs| outputs.into_iter().map(Tensor::F64).collect()),
                        // Complex eigenvalues are converted to a real tensor here.
                        Tensor::C32(t) => linalg::blas::eigh(buffers, t)
                            .and_then(eigh_c32_outputs_to_public_tensors),
                        Tensor::C64(t) => linalg::blas::eigh(buffers, t)
                            .and_then(eigh_c64_outputs_to_public_tensors),
                        _ => Err(unsupported_dtype("eigh", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("eigh", self.kind()))
                }
            }
        }
    }
669
    /// Eigenvalues only of a Hermitian (complex) / symmetric (real) host tensor.
    ///
    /// Dispatches on `self.kind()` to the faer or BLAS provider, or returns the
    /// `unsupported_provider` error when that provider's feature is not compiled
    /// in. Real inputs keep their dtype. For complex inputs the provider already
    /// returns a real-typed tensor, which is wrapped as `F32` (from `C32`) or
    /// `F64` (from `C64`).
    ///
    /// # Errors
    ///
    /// An input backed by a backend buffer, an integer/bool dtype, a provider that
    /// is not compiled in, or any provider failure.
    fn eigh_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("eigh_values", input)?;
        match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
                        }
                        Tensor::F64(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
                        }
                        // Complex inputs produce real eigenvalues (C32 -> F32).
                        Tensor::C32(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F32)
                        }
                        // Complex inputs produce real eigenvalues (C64 -> F64).
                        Tensor::C64(t) => {
                            linalg::faer::eigh_values(ctx.as_ref(), buffers, t).map(Tensor::F64)
                        }
                        _ => Err(unsupported_dtype("eigh_values", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("eigh_values", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    // C32/C64 inputs map to real F32/F64 eigenvalue tensors.
                    self.with_linalg_pool(|buffers| match input {
                        Tensor::F32(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F32),
                        Tensor::F64(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F64),
                        Tensor::C32(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F32),
                        Tensor::C64(t) => linalg::blas::eigh_values(buffers, t).map(Tensor::F64),
                        _ => Err(unsupported_dtype("eigh_values", input.dtype())),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("eigh_values", self.kind()))
                }
            }
        }
    }
716
717    fn eig(&mut self, input: &Tensor) -> tenferro_tensor::Result<Vec<Tensor>> {
718        ensure_host_tensor("eig", input)?;
719        if !matches!(
720            input,
721            Tensor::F32(_) | Tensor::F64(_) | Tensor::C32(_) | Tensor::C64(_)
722        ) {
723            return Err(unsupported_dtype("eig", input.dtype()));
724        }
725        match self.kind() {
726            CpuBackendKind::Faer => {
727                #[cfg(feature = "cpu-faer")]
728                {
729                    let ctx = self.linalg_context();
730                    self.with_linalg_pool(|buffers| linalg::faer::eig(ctx.as_ref(), buffers, input))
731                }
732                #[cfg(not(feature = "cpu-faer"))]
733                {
734                    Err(unsupported_provider("eig", self.kind()))
735                }
736            }
737            CpuBackendKind::Blas => {
738                #[cfg(feature = "cpu-blas")]
739                {
740                    self.with_linalg_pool(|buffers| linalg::blas::eig(buffers, input))
741                }
742                #[cfg(not(feature = "cpu-blas"))]
743                {
744                    Err(unsupported_provider("eig", self.kind()))
745                }
746            }
747        }
748    }
749
750    fn eig_values(&mut self, input: &Tensor) -> tenferro_tensor::Result<Tensor> {
751        ensure_host_tensor("eig_values", input)?;
752        if !matches!(
753            input,
754            Tensor::F32(_) | Tensor::F64(_) | Tensor::C32(_) | Tensor::C64(_)
755        ) {
756            return Err(unsupported_dtype("eig_values", input.dtype()));
757        }
758        match self.kind() {
759            CpuBackendKind::Faer => {
760                #[cfg(feature = "cpu-faer")]
761                {
762                    let ctx = self.linalg_context();
763                    self.with_linalg_pool(|buffers| {
764                        linalg::faer::eig_values(ctx.as_ref(), buffers, input)
765                    })
766                }
767                #[cfg(not(feature = "cpu-faer"))]
768                {
769                    Err(unsupported_provider("eig_values", self.kind()))
770                }
771            }
772            CpuBackendKind::Blas => {
773                #[cfg(feature = "cpu-blas")]
774                {
775                    self.with_linalg_pool(|buffers| linalg::blas::eig_values(buffers, input))
776                }
777                #[cfg(not(feature = "cpu-blas"))]
778                {
779                    Err(unsupported_provider("eig_values", self.kind()))
780                }
781            }
782        }
783    }
784
785    fn lu_solve_prepared(
786        &mut self,
787        a: &Tensor,
788        packed_lu: &Tensor,
789        pivots: &Tensor,
790        b: &Tensor,
791        transpose_a: bool,
792        conjugate_a: bool,
793    ) -> tenferro_tensor::Result<Tensor> {
794        const OP: &str = "lu_solve_prepared";
795
796        ensure_host_tensor(OP, a)?;
797        ensure_host_tensor(OP, packed_lu)?;
798        ensure_host_tensor(OP, pivots)?;
799        ensure_host_tensor(OP, b)?;
800        ensure_supported_linalg_pair(OP, a, b)?;
801        ensure_supported_linalg_pair(OP, a, packed_lu)?;
802        if !matches!(pivots, Tensor::I32(_)) {
803            return Err(Error::DTypeMismatch {
804                op: OP,
805                lhs: DType::I32,
806                rhs: pivots.dtype(),
807            });
808        }
809        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
810            return zeros_like_tensor(b);
811        }
812
813        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
814            (
815                self.reshape(b, &matrix_rhs_shape)?,
816                Some(b.shape().to_vec()),
817            )
818        } else {
819            (b.clone(), None)
820        };
821
822        validate_lu_solve_prepared_shapes(packed_lu.shape(), pivots.shape(), rhs.shape())?;
823        validate_nonsingular_u(packed_lu)?;
824        let lu_op = if conjugate_a {
825            self.conj(packed_lu)?
826        } else {
827            packed_lu.clone()
828        };
829        let result = if transpose_a {
830            let z = self.triangular_solve(&lu_op, &rhs, true, false, true, false)?;
831            let y = self.triangular_solve(&lu_op, &z, true, true, true, true)?;
832            apply_lu_pivots_cpu(&y, pivots, true)?
833        } else {
834            let pb = apply_lu_pivots_cpu(&rhs, pivots, false)?;
835            let y = self.triangular_solve(&lu_op, &pb, true, true, false, true)?;
836            self.triangular_solve(&lu_op, &y, true, false, false, false)?
837        };
838
839        if let Some(shape) = restore_shape {
840            self.reshape(&result, &shape)
841        } else {
842            Ok(result)
843        }
844    }
845
    /// Solves `A x = b` for a square (optionally batched) `A` on the CPU.
    ///
    /// `a` and `b` must be host tensors of the same floating/complex dtype. If
    /// either has a zero-length dimension the call short-circuits and returns
    /// zeros shaped like `b`, without shape validation. A rank-1 `b`, or a `b`
    /// whose shape is `a`'s with the column axis removed (see
    /// `batched_vector_rhs_shape`), is reshaped to a single-column matrix for the
    /// provider call. The result is then reshaped back to `b`'s original shape.
    ///
    /// # Errors
    ///
    /// An input backed by a backend buffer, mismatched or unsupported dtypes, a
    /// provider that is not compiled in, or any provider/reshape failure.
    fn solve(&mut self, a: &Tensor, b: &Tensor) -> tenferro_tensor::Result<Tensor> {
        ensure_host_tensor("solve", a)?;
        ensure_host_tensor("solve", b)?;
        ensure_supported_linalg_pair("solve", a, b)?;
        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
            return zeros_like_tensor(b);
        }

        // Vector right-hand sides are solved as one-column matrices; remember the
        // caller's shape so the result can be reshaped back afterwards.
        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
            (
                self.reshape(b, &matrix_rhs_shape)?,
                Some(b.shape().to_vec()),
            )
        } else {
            (b.clone(), None)
        };

        // NOTE(review): the trailing `false` passed to each provider `solve` is a
        // provider-defined flag (presumably transpose/adjoint) — confirm in
        // `linalg::{faer,blas}::solve`.
        let result = match self.kind() {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = self.linalg_context();
                    self.with_linalg_pool(|buffers| match (a, &rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::F32)
                        }
                        (Tensor::F64(a), Tensor::F64(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::F64)
                        }
                        (Tensor::C32(a), Tensor::C32(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::C32)
                        }
                        (Tensor::C64(a), Tensor::C64(b)) => {
                            linalg::faer::solve(ctx.as_ref(), buffers, a, b, false).map(Tensor::C64)
                        }
                        _ => unsupported_pair("solve", a, &rhs),
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unsupported_provider("solve", self.kind()))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.with_linalg_pool(|buffers| match (a, &rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::F32)
                        }
                        (Tensor::F64(a), Tensor::F64(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::F64)
                        }
                        (Tensor::C32(a), Tensor::C32(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::C32)
                        }
                        (Tensor::C64(a), Tensor::C64(b)) => {
                            linalg::blas::solve(buffers, a, b, false).map(Tensor::C64)
                        }
                        _ => unsupported_pair("solve", a, &rhs),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unsupported_provider("solve", self.kind()))
                }
            }
        }?;

        if let Some(shape) = restore_shape {
            self.reshape(&result, &shape)
        } else {
            Ok(result)
        }
    }
921}
922
923fn ensure_host_tensor(op: &'static str, input: &Tensor) -> tenferro_tensor::Result<()> {
924    match input {
925        Tensor::F32(t) => ensure_host_typed_tensor(op, t),
926        Tensor::F64(t) => ensure_host_typed_tensor(op, t),
927        Tensor::I32(t) => ensure_host_typed_tensor(op, t),
928        Tensor::I64(t) => ensure_host_typed_tensor(op, t),
929        Tensor::Bool(t) => ensure_host_typed_tensor(op, t),
930        Tensor::C32(t) => ensure_host_typed_tensor(op, t),
931        Tensor::C64(t) => ensure_host_typed_tensor(op, t),
932    }
933}
934
935fn ensure_host_typed_tensor<T: 'static>(
936    op: &'static str,
937    input: &TypedTensor<T>,
938) -> tenferro_tensor::Result<()> {
939    if input.as_view().backend_buffer().is_some() {
940        return Err(Error::backend_failure(
941            op,
942            "CPU linalg backend received a backend buffer; download the tensor to host before CPU execution",
943        ));
944    }
945    Ok(())
946}
947
948fn ensure_supported_linalg_pair(
949    op: &'static str,
950    lhs: &Tensor,
951    rhs: &Tensor,
952) -> tenferro_tensor::Result<()> {
953    if lhs.dtype() != rhs.dtype() {
954        return Err(Error::DTypeMismatch {
955            op,
956            lhs: lhs.dtype(),
957            rhs: rhs.dtype(),
958        });
959    }
960    match lhs {
961        Tensor::F32(_) | Tensor::F64(_) | Tensor::C32(_) | Tensor::C64(_) => Ok(()),
962        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => {
963            Err(unsupported_dtype(op, lhs.dtype()))
964        }
965    }
966}
967
/// Returns `true` when any axis of `shape` has length zero (the tensor is empty).
fn has_zero_dim(shape: &[usize]) -> bool {
    shape.iter().any(|&extent| extent == 0)
}
971
/// Number of matrices described by the trailing batch dimensions.
///
/// An empty batch shape counts as one matrix. The result is clamped to at least
/// 1, so a batch shape that contains a zero also yields 1. Callers must filter
/// out zero-size tensors before relying on this count.
fn batch_count(batch_shape: &[usize]) -> usize {
    let total: usize = batch_shape.iter().product();
    std::cmp::max(total, 1)
}
975
976fn batched_vector_rhs_shape(a: &Tensor, b: &Tensor) -> Option<Vec<usize>> {
977    if b.shape().len() == 1 {
978        return Some(vec![b.shape()[0], 1]);
979    }
980
981    let is_batched_vector_rhs = a.shape().len() == b.shape().len() + 1
982        && !b.shape().is_empty()
983        && b.shape()[0] == a.shape()[0]
984        && b.shape()[1..] == a.shape()[2..];
985    if !is_batched_vector_rhs {
986        return None;
987    }
988
989    let mut rhs_shape = vec![b.shape()[0], 1];
990    rhs_shape.extend_from_slice(&b.shape()[1..]);
991    Some(rhs_shape)
992}
993
994fn zeros_like_tensor(input: &Tensor) -> tenferro_tensor::Result<Tensor> {
995    Ok(match input {
996        Tensor::F32(t) => Tensor::F32(TypedTensor::zeros(t.shape().to_vec())?),
997        Tensor::F64(t) => Tensor::F64(TypedTensor::zeros(t.shape().to_vec())?),
998        Tensor::I32(t) => Tensor::I32(TypedTensor::zeros(t.shape().to_vec())?),
999        Tensor::I64(t) => Tensor::I64(TypedTensor::zeros(t.shape().to_vec())?),
1000        Tensor::Bool(t) => Tensor::Bool(TypedTensor::from_vec_col_major(
1001            t.shape().to_vec(),
1002            vec![false; t.n_elements()],
1003        )?),
1004        Tensor::C32(t) => Tensor::C32(TypedTensor::zeros(t.shape().to_vec())?),
1005        Tensor::C64(t) => Tensor::C64(TypedTensor::zeros(t.shape().to_vec())?),
1006    })
1007}
1008
1009fn complex32_real_part_tensor(
1010    values: TypedTensor<Complex32>,
1011) -> tenferro_tensor::Result<TypedTensor<f32>> {
1012    let mut out = TypedTensor::from_vec_col_major(
1013        values.shape().to_vec(),
1014        values.host_data()?.iter().map(|value| value.re).collect(),
1015    )?;
1016    out.set_placement(values.placement().clone());
1017    Ok(out)
1018}
1019
1020fn complex64_real_part_tensor(
1021    values: TypedTensor<Complex64>,
1022) -> tenferro_tensor::Result<TypedTensor<f64>> {
1023    let mut out = TypedTensor::from_vec_col_major(
1024        values.shape().to_vec(),
1025        values.host_data()?.iter().map(|value| value.re).collect(),
1026    )?;
1027    out.set_placement(values.placement().clone());
1028    Ok(out)
1029}
1030
1031fn svd_output_count_error(count: usize) -> Error {
1032    Error::backend_failure("svd", format!("expected 3 outputs, got {count}"))
1033}
1034
1035fn full_piv_lu_output_count_error(count: usize) -> Error {
1036    Error::backend_failure("full_piv_lu", format!("expected 5 outputs, got {count}"))
1037}
1038
1039fn eigh_output_count_error(count: usize) -> Error {
1040    Error::backend_failure("eigh", format!("expected 2 outputs, got {count}"))
1041}
1042
1043fn full_piv_lu_c32_outputs_to_public_tensors(
1044    outputs: Vec<TypedTensor<Complex32>>,
1045) -> tenferro_tensor::Result<Vec<Tensor>> {
1046    let count = outputs.len();
1047    let mut outputs = outputs.into_iter();
1048    match (
1049        outputs.next(),
1050        outputs.next(),
1051        outputs.next(),
1052        outputs.next(),
1053        outputs.next(),
1054        outputs.next(),
1055    ) {
1056        (Some(p), Some(l), Some(u), Some(q), Some(parity), None) => Ok(vec![
1057            Tensor::C32(p),
1058            Tensor::C32(l),
1059            Tensor::C32(u),
1060            Tensor::C32(q),
1061            Tensor::F32(complex32_real_part_tensor(parity)?),
1062        ]),
1063        _ => Err(full_piv_lu_output_count_error(count)),
1064    }
1065}
1066
1067fn full_piv_lu_c64_outputs_to_public_tensors(
1068    outputs: Vec<TypedTensor<Complex64>>,
1069) -> tenferro_tensor::Result<Vec<Tensor>> {
1070    let count = outputs.len();
1071    let mut outputs = outputs.into_iter();
1072    match (
1073        outputs.next(),
1074        outputs.next(),
1075        outputs.next(),
1076        outputs.next(),
1077        outputs.next(),
1078        outputs.next(),
1079    ) {
1080        (Some(p), Some(l), Some(u), Some(q), Some(parity), None) => Ok(vec![
1081            Tensor::C64(p),
1082            Tensor::C64(l),
1083            Tensor::C64(u),
1084            Tensor::C64(q),
1085            Tensor::F64(complex64_real_part_tensor(parity)?),
1086        ]),
1087        _ => Err(full_piv_lu_output_count_error(count)),
1088    }
1089}
1090
1091fn svd_c32_outputs_to_public_tensors(
1092    outputs: Vec<TypedTensor<Complex32>>,
1093) -> tenferro_tensor::Result<Vec<Tensor>> {
1094    let count = outputs.len();
1095    let mut outputs = outputs.into_iter();
1096    match (
1097        outputs.next(),
1098        outputs.next(),
1099        outputs.next(),
1100        outputs.next(),
1101    ) {
1102        (Some(u), Some(values), Some(vt), None) => Ok(vec![
1103            Tensor::C32(u),
1104            Tensor::F32(complex32_real_part_tensor(values)?),
1105            Tensor::C32(vt),
1106        ]),
1107        _ => Err(svd_output_count_error(count)),
1108    }
1109}
1110
1111fn svd_c64_outputs_to_public_tensors(
1112    outputs: Vec<TypedTensor<Complex64>>,
1113) -> tenferro_tensor::Result<Vec<Tensor>> {
1114    let count = outputs.len();
1115    let mut outputs = outputs.into_iter();
1116    match (
1117        outputs.next(),
1118        outputs.next(),
1119        outputs.next(),
1120        outputs.next(),
1121    ) {
1122        (Some(u), Some(values), Some(vt), None) => Ok(vec![
1123            Tensor::C64(u),
1124            Tensor::F64(complex64_real_part_tensor(values)?),
1125            Tensor::C64(vt),
1126        ]),
1127        _ => Err(svd_output_count_error(count)),
1128    }
1129}
1130
1131fn eigh_c32_outputs_to_public_tensors(
1132    outputs: Vec<TypedTensor<Complex32>>,
1133) -> tenferro_tensor::Result<Vec<Tensor>> {
1134    let count = outputs.len();
1135    let mut outputs = outputs.into_iter();
1136    match (outputs.next(), outputs.next(), outputs.next()) {
1137        (Some(values), Some(vectors), None) => Ok(vec![
1138            Tensor::F32(complex32_real_part_tensor(values)?),
1139            Tensor::C32(vectors),
1140        ]),
1141        _ => Err(eigh_output_count_error(count)),
1142    }
1143}
1144
1145fn eigh_c64_outputs_to_public_tensors(
1146    outputs: Vec<TypedTensor<Complex64>>,
1147) -> tenferro_tensor::Result<Vec<Tensor>> {
1148    let count = outputs.len();
1149    let mut outputs = outputs.into_iter();
1150    match (outputs.next(), outputs.next(), outputs.next()) {
1151        (Some(values), Some(vectors), None) => Ok(vec![
1152            Tensor::F64(complex64_real_part_tensor(values)?),
1153            Tensor::C64(vectors),
1154        ]),
1155        _ => Err(eigh_output_count_error(count)),
1156    }
1157}
1158
1159fn apply_lu_pivots_cpu(
1160    input: &Tensor,
1161    pivots: &Tensor,
1162    inverse: bool,
1163) -> tenferro_tensor::Result<Tensor> {
1164    let Tensor::I32(pivots) = pivots else {
1165        return Err(Error::DTypeMismatch {
1166            op: "lu_solve_prepared",
1167            lhs: DType::I32,
1168            rhs: pivots.dtype(),
1169        });
1170    };
1171    match input {
1172        Tensor::F32(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::F32),
1173        Tensor::F64(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::F64),
1174        Tensor::C32(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::C32),
1175        Tensor::C64(t) => apply_lu_pivots_typed(t, pivots, inverse).map(Tensor::C64),
1176        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => {
1177            Err(unsupported_dtype("lu_solve_prepared", input.dtype()))
1178        }
1179    }
1180}
1181
1182fn apply_lu_pivots_typed<T: Clone>(
1183    input: &TypedTensor<T>,
1184    pivots: &TypedTensor<i32>,
1185    inverse: bool,
1186) -> tenferro_tensor::Result<TypedTensor<T>> {
1187    let shape = input.shape();
1188    if shape.len() < 2 {
1189        return Err(Error::RankMismatch {
1190            op: "lu_solve_prepared",
1191            expected: 2,
1192            actual: shape.len(),
1193        });
1194    }
1195    let rows = shape[0];
1196    let cols = shape[1];
1197    let k = pivots.shape()[0];
1198    if k > rows || pivots.shape()[1..] != shape[2..] {
1199        return Err(Error::ShapeMismatch {
1200            op: "lu_solve_prepared",
1201            lhs: pivots.shape().to_vec(),
1202            rhs: shape.to_vec(),
1203        });
1204    }
1205    let batch_total = batch_count(&shape[2..]);
1206    let matrix_stride = rows * cols;
1207    let pivot_stride = k;
1208    let input_data = input.host_data()?;
1209    let pivot_data = pivots.host_data()?;
1210    let mut data = Vec::with_capacity(input_data.len());
1211
1212    for batch in 0..batch_total {
1213        let mut perm: Vec<usize> = (0..rows).collect();
1214        let pivot_offset = batch * pivot_stride;
1215        for step in 0..k {
1216            let pivot_one_based = pivot_data[pivot_offset + step];
1217            if pivot_one_based <= 0 {
1218                return Err(Error::backend_failure(
1219                    "lu_solve_prepared",
1220                    "LU pivot index must be 1-based and positive",
1221                ));
1222            }
1223            let pivot = usize::try_from(pivot_one_based - 1).map_err(|_| {
1224                Error::backend_failure("lu_solve_prepared", "LU pivot index is invalid")
1225            })?;
1226            if pivot >= rows {
1227                return Err(Error::backend_failure(
1228                    "lu_solve_prepared",
1229                    "LU pivot index is out of bounds",
1230                ));
1231            }
1232            perm.swap(step, pivot);
1233        }
1234        let row_map = if inverse {
1235            let mut inv = vec![0usize; rows];
1236            for (row, &source) in perm.iter().enumerate() {
1237                inv[source] = row;
1238            }
1239            inv
1240        } else {
1241            perm
1242        };
1243        let batch_offset = batch * matrix_stride;
1244        for col in 0..cols {
1245            for &source_row in &row_map {
1246                data.push(input_data[batch_offset + source_row + col * rows].clone());
1247            }
1248        }
1249    }
1250
1251    TypedTensor::from_vec_col_major(shape.to_vec(), data)
1252}
1253
1254fn validate_lu_solve_prepared_shapes(
1255    lu_shape: &[usize],
1256    pivots_shape: &[usize],
1257    b_shape: &[usize],
1258) -> tenferro_tensor::Result<()> {
1259    let n = square_matrix_dim("lu_solve_prepared", lu_shape)?;
1260    let (b_rows, _) = matrix_dims("lu_solve_prepared", b_shape)?;
1261    if b_rows != n {
1262        return Err(Error::InvalidConfig {
1263            op: "lu_solve_prepared",
1264            message: format!("rhs row count mismatch: expected {n}, got {b_rows}"),
1265        });
1266    }
1267    if lu_shape[2..] != b_shape[2..] {
1268        return Err(Error::ShapeMismatch {
1269            op: "lu_solve_prepared",
1270            lhs: lu_shape.to_vec(),
1271            rhs: b_shape.to_vec(),
1272        });
1273    }
1274    let mut expected_pivots = vec![n];
1275    expected_pivots.extend_from_slice(&lu_shape[2..]);
1276    if pivots_shape != expected_pivots {
1277        return Err(Error::ShapeMismatch {
1278            op: "lu_solve_prepared",
1279            lhs: expected_pivots,
1280            rhs: pivots_shape.to_vec(),
1281        });
1282    }
1283    Ok(())
1284}
1285
1286fn matrix_dims(op: &'static str, shape: &[usize]) -> tenferro_tensor::Result<(usize, usize)> {
1287    if shape.len() < 2 {
1288        return Err(Error::RankMismatch {
1289            op,
1290            expected: 2,
1291            actual: shape.len(),
1292        });
1293    }
1294    Ok((shape[0], shape[1]))
1295}
1296
1297fn square_matrix_dim(op: &'static str, shape: &[usize]) -> tenferro_tensor::Result<usize> {
1298    let (rows, cols) = matrix_dims(op, shape)?;
1299    if rows != cols {
1300        return Err(Error::ShapeMismatch {
1301            op,
1302            lhs: vec![rows],
1303            rhs: vec![cols],
1304        });
1305    }
1306    Ok(rows)
1307}
1308
1309// Used only by feature-disabled provider branches, so default feature builds
1310// may not compile a direct call site.
1311#[allow(dead_code)]
1312fn unsupported_provider(op: &'static str, kind: CpuBackendKind) -> Error {
1313    Error::InvalidConfig {
1314        op,
1315        message: format!("CPU linalg provider {kind:?} is not compiled in"),
1316    }
1317}
1318
1319fn unsupported_pair(
1320    op: &'static str,
1321    lhs: &Tensor,
1322    rhs: &Tensor,
1323) -> tenferro_tensor::Result<Tensor> {
1324    if lhs.dtype() != rhs.dtype() {
1325        Err(Error::DTypeMismatch {
1326            op,
1327            lhs: lhs.dtype(),
1328            rhs: rhs.dtype(),
1329        })
1330    } else {
1331        Err(unsupported_dtype(op, lhs.dtype()))
1332    }
1333}
1334
1335fn unsupported_dtype(op: &'static str, dtype: DType) -> Error {
1336    Error::backend_failure(op, format!("unsupported dtype {dtype:?}"))
1337}