tenferro_tensor/cpu/linalg/faer_linalg.rs

use faer::dyn_stack::{MemBuffer, MemStack};
use faer::{
    diag::{Diag, DiagRef},
    Conj, Mat, MatMut, MatRef,
};
use num_complex::Complex64;

use crate::buffer_pool::{BufferPool, PoolScalar};
use crate::cpu::CpuContext;
use crate::{Buffer, Tensor, TypedTensor};

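/// Scalar-specific entry points into `faer`'s dense linear-algebra kernels.
///
/// Each `*_2d` method factors a single column-major 2-D matrix. Batching over
/// trailing dimensions is layered on top by the `batched_*` helpers below;
/// the `f64` and `Complex64` impls route to the matching real/complex kernel.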
pub(crate) trait FaerLinalg: Copy + Clone + PoolScalar {
    fn parity_one() -> Self;
    fn cholesky_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> crate::Result<TypedTensor<Self>>;
    fn lu_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
    fn triangular_solve_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        a: &TypedTensor<Self>,
        b: &TypedTensor<Self>,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> TypedTensor<Self>;
    fn svd_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
    fn qr_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
    fn eigh_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
}

fn matrix_dims<T>(input: &TypedTensor<T>, op: &str) -> (usize, usize) {
    assert_eq!(input.shape.len(), 2, "{op}: expected a 2D matrix");
    (input.shape[0], input.shape[1])
}

fn square_matrix_dim<T>(input: &TypedTensor<T>, op: &str) -> usize {
    let (rows, cols) = matrix_dims(input, op);
    assert_eq!(rows, cols, "{op}: expected a square matrix");
    rows
}

fn tensor_from_vec_with_template<T: Clone, U>(
    shape: Vec<usize>,
    data: Vec<T>,
    template: &TypedTensor<U>,
) -> TypedTensor<T> {
    TypedTensor {
        buffer: Buffer::Host(data),
        shape,
        placement: template.placement.clone(),
    }
}

fn col_major_vec_from_mat<T: Copy + PoolScalar>(
    buffers: &mut BufferPool,
    mat: MatRef<'_, T>,
) -> Vec<T> {
    let (rows, cols) = mat.shape();
    let mut data = buffers.acquire_with_capacity::<T>(rows * cols);
    for j in 0..cols {
        for i in 0..rows {
            data.push(mat[(i, j)]);
        }
    }
    data
}

fn vec_from_diag<T: Copy + PoolScalar>(buffers: &mut BufferPool, diag: DiagRef<'_, T>) -> Vec<T> {
    let col = diag.column_vector();
    let mut data = buffers.acquire_with_capacity::<T>(col.nrows());
    for i in 0..col.nrows() {
        data.push(col[i]);
    }
    data
}

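// Rationale for the unsafe slice casts below (an assumption, not guaranteed by
// either crate's documentation here): `num_complex::Complex64` and `faer::c64`
// are taken to have identical layout, i.e. a `#[repr(C)]` pair of `f64`s
// (re, im). The debug asserts check size and alignment only; matching field
// order is assumed.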
fn complex64_to_faer_slice(data: &[Complex64]) -> &[faer::c64] {
    debug_assert_eq!(
        std::mem::size_of::<Complex64>(),
        std::mem::size_of::<faer::c64>()
    );
    debug_assert_eq!(
        std::mem::align_of::<Complex64>(),
        std::mem::align_of::<faer::c64>()
    );

    unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<faer::c64>(), data.len()) }
}

fn complex64_to_faer_slice_mut(data: &mut [Complex64]) -> &mut [faer::c64] {
    debug_assert_eq!(
        std::mem::size_of::<Complex64>(),
        std::mem::size_of::<faer::c64>()
    );
    debug_assert_eq!(
        std::mem::align_of::<Complex64>(),
        std::mem::align_of::<faer::c64>()
    );

    unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<faer::c64>(), data.len()) }
}

fn complex_vec_from_real_diag(
    buffers: &mut BufferPool,
    diag: DiagRef<'_, faer::c64>,
) -> Vec<Complex64> {
    let col = diag.column_vector();
    let mut data = buffers.acquire_with_capacity::<Complex64>(col.nrows());
    for i in 0..col.nrows() {
        data.push(Complex64::new(col[i].re, 0.0));
    }
    data
}

fn complex_vec_from_diag(buffers: &mut BufferPool, diag: DiagRef<'_, faer::c64>) -> Vec<Complex64> {
    let col = diag.column_vector();
    let mut data = buffers.acquire_with_capacity::<Complex64>(col.nrows());
    for i in 0..col.nrows() {
        data.push(Complex64::new(col[i].re, col[i].im));
    }
    data
}

fn complex_vec_from_mat(buffers: &mut BufferPool, mat: MatRef<'_, faer::c64>) -> Vec<Complex64> {
    let (rows, cols) = mat.shape();
    let mut data = buffers.acquire_with_capacity::<Complex64>(rows * cols);
    for j in 0..cols {
        for i in 0..rows {
            let value = mat[(i, j)];
            data.push(Complex64::new(value.re, value.im));
        }
    }
    data
}

fn matrix_from_predicate<T: Copy + Default>(
    mat: MatRef<'_, T>,
    rows: usize,
    cols: usize,
    predicate: impl Fn(usize, usize) -> bool,
) -> Vec<T> {
    let mut data = vec![T::default(); rows * cols];
    for j in 0..cols {
        for i in 0..rows {
            if predicate(i, j) {
                data[i + j * rows] = mat[(i, j)];
            }
        }
    }
    data
}

fn lower_triangle_vec_from_mat<T: Copy + Default>(mat: MatRef<'_, T>) -> Vec<T> {
    let (rows, cols) = mat.shape();
    matrix_from_predicate(mat, rows, cols, |row, col| row >= col)
}

fn upper_triangle_vec_from_mat<T: Copy + Default>(mat: MatRef<'_, T>) -> Vec<T> {
    let (rows, cols) = mat.shape();
    matrix_from_predicate(mat, rows, cols, |row, col| row <= col)
}

fn complex_matrix_from_predicate(
    mat: MatRef<'_, faer::c64>,
    rows: usize,
    cols: usize,
    predicate: impl Fn(usize, usize) -> bool,
) -> Vec<Complex64> {
    let mut data = vec![Complex64::new(0.0, 0.0); rows * cols];
    for j in 0..cols {
        for i in 0..rows {
            if predicate(i, j) {
                let value = mat[(i, j)];
                data[i + j * rows] = Complex64::new(value.re, value.im);
            }
        }
    }
    data
}

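// The real EVD follows the LAPACK packing convention: a complex-conjugate
// eigenvalue pair a ± bi occupies two consecutive columns of the real
// eigenvector matrix, holding Re(v) and Im(v). The helper below expands that
// packed form into explicit complex eigenpairs, e.g. packed columns (re, im)
// become the eigenvectors re + i*im and re - i*im.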
fn real_eig_to_complex_outputs(
    buffers: &mut BufferPool,
    u_real: MatRef<'_, f64>,
    s_re: DiagRef<'_, f64>,
    s_im: DiagRef<'_, f64>,
) -> (Vec<Complex64>, Vec<Complex64>) {
    let n = u_real.nrows();
    // SAFETY: the loop below writes every element of `u` and `s` before any read.
    let mut u = unsafe { <Complex64 as PoolScalar>::pool_acquire(buffers, n * n) };
    // SAFETY: the loop below writes every element of `u` and `s` before any read.
    let mut s = unsafe { <Complex64 as PoolScalar>::pool_acquire(buffers, n) };
    let mut j = 0;
    while j < n {
        if s_im[j] == 0.0 {
            s[j] = Complex64::new(s_re[j], 0.0);
            for i in 0..n {
                u[i + j * n] = Complex64::new(u_real[(i, j)], 0.0);
            }
            j += 1;
        } else {
            s[j] = Complex64::new(s_re[j], s_im[j]);
            s[j + 1] = Complex64::new(s_re[j], -s_im[j]);
            for i in 0..n {
                u[i + j * n] = Complex64::new(u_real[(i, j)], u_real[(i, j + 1)]);
                u[i + (j + 1) * n] = Complex64::new(u_real[(i, j)], -u_real[(i, j + 1)]);
            }
            j += 2;
        }
    }
    (u, s)
}

fn split_core_and_batch<'a, T>(
    input: &'a TypedTensor<T>,
    core_rank: usize,
    op: &str,
) -> (&'a [usize], &'a [usize]) {
    assert!(
        input.shape.len() >= core_rank,
        "{op}: expected rank >= {core_rank}"
    );
    input.shape.split_at(core_rank)
}

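// Example of the index mapping below: a 2x3 column-major buffer
// [a, b, c, d, e, f] (columns [a, b], [c, d], [e, f]) transposes to the 3x2
// column-major buffer [a, c, e, b, d, f].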
fn transpose_col_major_data<T: Copy + PoolScalar>(
    buffers: &mut BufferPool,
    data: &[T],
    rows: usize,
    cols: usize,
) -> Vec<T> {
    let mut transposed = buffers.acquire_with_capacity::<T>(data.len());
    for j in 0..rows {
        for i in 0..cols {
            transposed.push(data[j + i * rows]);
        }
    }
    transposed
}

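// Batching convention used by all `batched_*` helpers: shapes are
// [core dims..., batch dims...] and buffers are column-major, so each core
// slice is contiguous; batch `b` occupies elements
// [b * slice_size, (b + 1) * slice_size). For example, a [3, 3, 4] tensor is
// four stacked 3x3 matrices of 9 contiguous elements each.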
fn batched_single<T, F>(
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
    core_rank: usize,
    op: F,
) -> crate::Result<TypedTensor<T>>
where
    T: Clone + PoolScalar,
    F: Fn(&mut BufferPool, &TypedTensor<T>) -> crate::Result<TypedTensor<T>>,
{
    let (core_shape, batch_shape) = split_core_and_batch(input, core_rank, "batched_single");
    if batch_shape.is_empty() {
        return op(buffers, input);
    }

    let slice_size: usize = core_shape.iter().product();
    let batch_count: usize = batch_shape.iter().product();
    assert!(
        batch_count > 0,
        "batched_single: zero-sized batch dims are unsupported"
    );

    let mut out_core_shape: Option<Vec<usize>> = None;
    let mut out_data: Option<Vec<T>> = None;

    for batch_idx in 0..batch_count {
        let start = batch_idx * slice_size;
        let end = start + slice_size;
        let batch_input = tensor_from_vec_with_template(
            core_shape.to_vec(),
            input.host_data()[start..end].to_vec(),
            input,
        );
        let batch_output = op(buffers, &batch_input)?;

        if let Some(expected_shape) = &out_core_shape {
            assert_eq!(
                batch_output.shape.as_slice(),
                expected_shape.as_slice(),
                "batched_single: output core shape mismatch across batches"
            );
        } else {
            out_data =
                Some(buffers.acquire_with_capacity::<T>(batch_output.n_elements() * batch_count));
            out_core_shape = Some(batch_output.shape.clone());
        }

        out_data
            .as_mut()
            .expect("batched_single: missing output buffer")
            .extend_from_slice(batch_output.host_data());
    }

    let mut out_shape = out_core_shape.expect("batched_single: missing output shape");
    out_shape.extend_from_slice(batch_shape);
    Ok(tensor_from_vec_with_template(
        out_shape,
        out_data.expect("batched_single: missing output data"),
        input,
    ))
}

fn batched_multi<T, F>(
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
    core_rank: usize,
    op: F,
) -> Vec<TypedTensor<T>>
where
    T: Clone + PoolScalar,
    F: Fn(&mut BufferPool, &TypedTensor<T>) -> Vec<TypedTensor<T>>,
{
    let (core_shape, batch_shape) = split_core_and_batch(input, core_rank, "batched_multi");
    if batch_shape.is_empty() {
        return op(buffers, input);
    }

    let slice_size: usize = core_shape.iter().product();
    let batch_count: usize = batch_shape.iter().product();
    assert!(
        batch_count > 0,
        "batched_multi: zero-sized batch dims are unsupported"
    );

    let mut out_shapes: Vec<Vec<usize>> = Vec::new();
    let mut out_data: Vec<Vec<T>> = Vec::new();

    for batch_idx in 0..batch_count {
        let start = batch_idx * slice_size;
        let end = start + slice_size;
        let batch_input = tensor_from_vec_with_template(
            core_shape.to_vec(),
            input.host_data()[start..end].to_vec(),
            input,
        );
        let batch_outputs = op(buffers, &batch_input);

        if out_shapes.is_empty() {
            out_shapes = batch_outputs
                .iter()
                .map(|tensor| tensor.shape.clone())
                .collect();
            let mut pooled_outputs = Vec::with_capacity(batch_outputs.len());
            for tensor in &batch_outputs {
                pooled_outputs
                    .push(buffers.acquire_with_capacity::<T>(tensor.n_elements() * batch_count));
            }
            out_data = pooled_outputs;
        } else {
            assert_eq!(
                batch_outputs.len(),
                out_shapes.len(),
                "batched_multi: output count mismatch across batches"
            );
        }

        for (idx, batch_output) in batch_outputs.iter().enumerate() {
            assert_eq!(
                batch_output.shape.as_slice(),
                out_shapes[idx].as_slice(),
                "batched_multi: output core shape mismatch across batches"
            );
            out_data[idx].extend_from_slice(batch_output.host_data());
        }
    }

    out_shapes
        .into_iter()
        .zip(out_data)
        .map(|(mut out_shape, out_data)| {
            out_shape.extend_from_slice(batch_shape);
            tensor_from_vec_with_template(out_shape, out_data, input)
        })
        .collect()
}

// Mirrors `batched_multi`, but allows the per-batch op to change the scalar
// type (e.g. real input producing complex eigen outputs).
fn batched_multi_convert<InT, OutT, F>(
    buffers: &mut BufferPool,
    input: &TypedTensor<InT>,
    core_rank: usize,
    op: F,
) -> Vec<TypedTensor<OutT>>
where
    InT: Clone,
    OutT: Clone + PoolScalar,
    F: Fn(&mut BufferPool, &TypedTensor<InT>) -> Vec<TypedTensor<OutT>>,
{
    let (core_shape, batch_shape) =
        split_core_and_batch(input, core_rank, "batched_multi_convert");
    if batch_shape.is_empty() {
        return op(buffers, input);
    }

    let slice_size: usize = core_shape.iter().product();
    let batch_count: usize = batch_shape.iter().product();
    assert!(
        batch_count > 0,
        "batched_multi_convert: zero-sized batch dims are unsupported"
    );

    let mut out_shapes: Vec<Vec<usize>> = Vec::new();
    let mut out_data: Vec<Vec<OutT>> = Vec::new();

    for batch_idx in 0..batch_count {
        let start = batch_idx * slice_size;
        let end = start + slice_size;
        let batch_input = tensor_from_vec_with_template(
            core_shape.to_vec(),
            input.host_data()[start..end].to_vec(),
            input,
        );
        let batch_outputs = op(buffers, &batch_input);

        if out_shapes.is_empty() {
            out_shapes = batch_outputs
                .iter()
                .map(|tensor| tensor.shape.clone())
                .collect();
            let mut pooled_outputs = Vec::with_capacity(batch_outputs.len());
            for tensor in &batch_outputs {
                pooled_outputs
                    .push(buffers.acquire_with_capacity::<OutT>(tensor.n_elements() * batch_count));
            }
            out_data = pooled_outputs;
        } else {
            assert_eq!(
                batch_outputs.len(),
                out_shapes.len(),
                "batched_multi_convert: output count mismatch across batches"
            );
        }

        for (idx, batch_output) in batch_outputs.iter().enumerate() {
            assert_eq!(
                batch_output.shape.as_slice(),
                out_shapes[idx].as_slice(),
                "batched_multi_convert: output core shape mismatch across batches"
            );
            out_data[idx].extend_from_slice(batch_output.host_data());
        }
    }

    out_shapes
        .into_iter()
        .zip(out_data)
        .map(|(mut out_shape, out_data)| {
            out_shape.extend_from_slice(batch_shape);
            tensor_from_vec_with_template(out_shape, out_data, input)
        })
        .collect()
}

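// Zips two batched tensors slice-by-slice through a binary 2-D kernel. The
// batch shapes must match exactly (no broadcasting), and the output tensor
// takes its placement from `b`, the right-hand side.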
fn batched_binary<T, F>(
    buffers: &mut BufferPool,
    a: &TypedTensor<T>,
    b: &TypedTensor<T>,
    core_rank_a: usize,
    core_rank_b: usize,
    op: F,
) -> TypedTensor<T>
where
    T: Clone + PoolScalar,
    F: Fn(&mut BufferPool, &TypedTensor<T>, &TypedTensor<T>) -> TypedTensor<T>,
{
    let (a_core_shape, a_batch_shape) = split_core_and_batch(a, core_rank_a, "batched_binary");
    let (b_core_shape, b_batch_shape) = split_core_and_batch(b, core_rank_b, "batched_binary");
    assert_eq!(
        a_batch_shape, b_batch_shape,
        "batched_binary: batch shape mismatch"
    );

    if a_batch_shape.is_empty() {
        return op(buffers, a, b);
    }

    let a_slice_size: usize = a_core_shape.iter().product();
    let b_slice_size: usize = b_core_shape.iter().product();
    let batch_count: usize = a_batch_shape.iter().product();
    assert!(
        batch_count > 0,
        "batched_binary: zero-sized batch dims are unsupported"
    );

    let mut out_core_shape: Option<Vec<usize>> = None;
    let mut out_data: Option<Vec<T>> = None;

    for batch_idx in 0..batch_count {
        let a_start = batch_idx * a_slice_size;
        let a_end = a_start + a_slice_size;
        let b_start = batch_idx * b_slice_size;
        let b_end = b_start + b_slice_size;

        let batch_a = tensor_from_vec_with_template(
            a_core_shape.to_vec(),
            a.host_data()[a_start..a_end].to_vec(),
            a,
        );
        let batch_b = tensor_from_vec_with_template(
            b_core_shape.to_vec(),
            b.host_data()[b_start..b_end].to_vec(),
            b,
        );
        let batch_output = op(buffers, &batch_a, &batch_b);

        if let Some(expected_shape) = &out_core_shape {
            assert_eq!(
                batch_output.shape.as_slice(),
                expected_shape.as_slice(),
                "batched_binary: output core shape mismatch across batches"
            );
        } else {
            out_data =
                Some(buffers.acquire_with_capacity::<T>(batch_output.n_elements() * batch_count));
            out_core_shape = Some(batch_output.shape.clone());
        }

        out_data
            .as_mut()
            .expect("batched_binary: missing output buffer")
            .extend_from_slice(batch_output.host_data());
    }

    let mut out_shape = out_core_shape.expect("batched_binary: missing output shape");
    out_shape.extend_from_slice(a_batch_shape);
    tensor_from_vec_with_template(
        out_shape,
        out_data.expect("batched_binary: missing output data"),
        b,
    )
}

impl FaerLinalg for f64 {
    fn parity_one() -> Self {
        1.0
    }

    fn cholesky_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> crate::Result<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "cholesky");
        let mut l = Mat::zeros(n, n);
        l.copy_from(MatRef::from_column_major_slice(input.host_data(), n, n));
        let mut mem = MemBuffer::new(
            faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<Self>(
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::cholesky::llt::factor::cholesky_in_place(
            l.as_mut(),
            Default::default(),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .map_err(|_| crate::Error::BackendFailure {
            op: "cholesky",
            message: "matrix is not positive definite".into(),
        })?;
        Ok(tensor_from_vec_with_template(
            vec![n, n],
            lower_triangle_vec_from_mat(l.as_ref()),
            input,
        ))
    }

    fn lu_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "lu");
        let k = m.min(n);
        let mut lu = Mat::zeros(m, n);
        lu.copy_from(MatRef::from_column_major_slice(input.host_data(), m, n));
        let mut perm = vec![0usize; m];
        let mut perm_inv = vec![0usize; m];
        let mut mem = MemBuffer::new(
            faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, Self>(
                m,
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        let info = faer::linalg::lu::partial_pivoting::factor::lu_in_place(
            lu.as_mut(),
            &mut perm,
            &mut perm_inv,
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .0;

        let mut p_data = vec![0.0; m * m];
        for (row, &col) in perm.iter().enumerate() {
            p_data[row + col * m] = 1.0;
        }
        let parity = if info.transposition_count % 2 == 0 {
            1.0
        } else {
            -1.0
        };

        let mut l_data = matrix_from_predicate(lu.as_ref(), m, k, |row, col| row >= col);
        for i in 0..k {
            l_data[i + i * m] = 1.0;
        }
        let u_data = upper_triangle_vec_from_mat(lu.as_ref().get(..k, ..));

        vec![
            tensor_from_vec_with_template(vec![m, m], p_data, input),
            tensor_from_vec_with_template(vec![m, k], l_data, input),
            tensor_from_vec_with_template(vec![k, n], u_data, input),
            tensor_from_vec_with_template(vec![], vec![parity], input),
        ]
    }

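    // Case mapping for the solver tables below: when `transpose_a` is set,
    // the solve runs on `a_mat.transpose()` with the opposite-triangle
    // kernel, since the transpose of a lower-triangular matrix is
    // upper-triangular. The right-side branch (`left_side == false`) solves
    // X * op(A) = B by rewriting it as op(A)^T * X^T = B^T: the RHS is
    // transposed in, solved, and transposed back out.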
    fn triangular_solve_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        a: &TypedTensor<Self>,
        b: &TypedTensor<Self>,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> TypedTensor<Self> {
        let n = square_matrix_dim(a, "triangular_solve");
        let (b_rows, b_cols) = matrix_dims(b, "triangular_solve");
        let a_mat = MatRef::from_column_major_slice(a.host_data(), n, n);

        if left_side {
            assert_eq!(b_rows, n, "triangular_solve: rhs row count mismatch");
            let mut rhs_data = buffers.acquire_with_capacity::<Self>(b.host_data().len());
            rhs_data.extend_from_slice(b.host_data());
            let rhs = MatMut::from_column_major_slice_mut(&mut rhs_data, n, b_cols);
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            tensor_from_vec_with_template(vec![n, b_cols], rhs_data, b)
        } else {
            assert_eq!(b_cols, n, "triangular_solve: rhs column count mismatch");
            let nrhs = b_rows;
            let mut rhs_transposed = transpose_col_major_data(buffers, b.host_data(), nrhs, n);
            let rhs = MatMut::from_column_major_slice_mut(&mut rhs_transposed, n, nrhs);
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            let result = transpose_col_major_data(buffers, &rhs_transposed, n, nrhs);
            <Self as PoolScalar>::pool_release(buffers, rhs_transposed);
            tensor_from_vec_with_template(vec![nrhs, n], result, b)
        }
    }

    fn svd_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "svd");
        let k = m.min(n);
        let mat = MatRef::from_column_major_slice(input.host_data(), m, n);
        let mut u = Mat::zeros(m, k);
        let mut v = Mat::zeros(n, k);
        let mut s = Diag::zeros(k);
        let mut mem = MemBuffer::new(faer::linalg::svd::svd_scratch::<Self>(
            m,
            n,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::svd::svd(
            mat,
            s.as_mut(),
            Some(u.as_mut()),
            Some(v.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("svd: decomposition failed"));

        let u = tensor_from_vec_with_template(
            vec![m, k],
            col_major_vec_from_mat(buffers, u.as_ref()),
            input,
        );
        let s = tensor_from_vec_with_template(vec![k], vec_from_diag(buffers, s.as_ref()), input);
        let mut vt_data = buffers.acquire_with_capacity::<Self>(k * n);
        for j in 0..n {
            for i in 0..k {
                vt_data.push(v[(j, i)]);
            }
        }
        let vt = tensor_from_vec_with_template(vec![k, n], vt_data, input);

        vec![u, s, vt]
    }

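    // `qr_in_place` leaves R in the upper triangle of `qr` and the
    // Householder reflectors below it; Q is then materialized explicitly by
    // applying the block Householder sequence to an m x k identity matrix.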
    fn qr_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "qr");
        let k = m.min(n);
        let mat = MatRef::from_column_major_slice(input.host_data(), m, n);
        let block_size =
            faer::linalg::qr::no_pivoting::factor::recommended_block_size::<Self>(m, n);
        let mut qr = Mat::zeros(m, n);
        qr.copy_from(mat);
        let mut coeff = Mat::zeros(block_size, k);
        let mut mem = MemBuffer::new(
            faer::linalg::qr::no_pivoting::factor::qr_in_place_scratch::<Self>(
                m,
                n,
                block_size,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::qr::no_pivoting::factor::qr_in_place(
            qr.as_mut(),
            coeff.as_mut(),
            ctx.faer_par(),
            stack,
            Default::default(),
        );
        let mut q = Mat::identity(m, k);
        let mut mem = MemBuffer::new(
            faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<Self>(
                m,
                block_size,
                k,
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            qr.as_ref().subcols(0, k),
            coeff.as_ref(),
            Conj::No,
            q.as_mut(),
            ctx.faer_par(),
            stack,
        );
        let q = tensor_from_vec_with_template(
            vec![m, k],
            col_major_vec_from_mat(buffers, q.as_ref()),
            input,
        );
        let r = tensor_from_vec_with_template(
            vec![k, n],
            upper_triangle_vec_from_mat(qr.as_ref().get(..k, ..)),
            input,
        );

        vec![q, r]
    }

    fn eigh_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "eigh");
        let mat = MatRef::from_column_major_slice(input.host_data(), n, n);
        let mut values = Diag::zeros(n);
        let mut vectors = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(faer::linalg::evd::self_adjoint_evd_scratch::<Self>(
            n,
            faer::linalg::evd::ComputeEigenvectors::Yes,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::evd::self_adjoint_evd(
            mat,
            values.as_mut(),
            Some(vectors.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("eigh: decomposition failed"));

        let values =
            tensor_from_vec_with_template(vec![n], vec_from_diag(buffers, values.as_ref()), input);
        let vectors = tensor_from_vec_with_template(
            vec![n, n],
            col_major_vec_from_mat(buffers, vectors.as_ref()),
            input,
        );

        vec![values, vectors]
    }
}

impl FaerLinalg for Complex64 {
    fn parity_one() -> Self {
        Complex64::new(1.0, 0.0)
    }

    fn cholesky_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> crate::Result<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "cholesky");
        let mut l = Mat::zeros(n, n);
        l.copy_from(MatRef::from_column_major_slice(
            complex64_to_faer_slice(input.host_data()),
            n,
            n,
        ));
        let mut mem = MemBuffer::new(
            faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<faer::c64>(
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::cholesky::llt::factor::cholesky_in_place(
            l.as_mut(),
            Default::default(),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .map_err(|_| crate::Error::BackendFailure {
            op: "cholesky",
            message: "matrix is not positive definite".into(),
        })?;
        Ok(tensor_from_vec_with_template(
            vec![n, n],
            complex_matrix_from_predicate(l.as_ref(), n, n, |row, col| row >= col),
            input,
        ))
    }

    fn lu_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "lu");
        let k = m.min(n);
        let mut lu = Mat::zeros(m, n);
        lu.copy_from(MatRef::from_column_major_slice(
            complex64_to_faer_slice(input.host_data()),
            m,
            n,
        ));
        let mut perm = vec![0usize; m];
        let mut perm_inv = vec![0usize; m];
        let mut mem = MemBuffer::new(
            faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, faer::c64>(
                m,
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        let info = faer::linalg::lu::partial_pivoting::factor::lu_in_place(
            lu.as_mut(),
            &mut perm,
            &mut perm_inv,
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .0;

        let mut p_data = vec![Complex64::new(0.0, 0.0); m * m];
        for (row, &col) in perm.iter().enumerate() {
            p_data[row + col * m] = Complex64::new(1.0, 0.0);
        }
        let parity = if info.transposition_count % 2 == 0 {
            Complex64::new(1.0, 0.0)
        } else {
            Complex64::new(-1.0, 0.0)
        };
        let mut l_data = complex_matrix_from_predicate(lu.as_ref(), m, k, |row, col| row >= col);
        for i in 0..k {
            l_data[i + i * m] = Complex64::new(1.0, 0.0);
        }
        let u_data = complex_matrix_from_predicate(lu.as_ref(), k, n, |row, col| row <= col);

        vec![
            tensor_from_vec_with_template(vec![m, m], p_data, input),
            tensor_from_vec_with_template(vec![m, k], l_data, input),
            tensor_from_vec_with_template(vec![k, n], u_data, input),
            tensor_from_vec_with_template(vec![], vec![parity], input),
        ]
    }

    fn triangular_solve_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        a: &TypedTensor<Self>,
        b: &TypedTensor<Self>,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> TypedTensor<Self> {
        let n = square_matrix_dim(a, "triangular_solve");
        let (b_rows, b_cols) = matrix_dims(b, "triangular_solve");
        let a_mat = MatRef::from_column_major_slice(complex64_to_faer_slice(a.host_data()), n, n);

        if left_side {
            assert_eq!(b_rows, n, "triangular_solve: rhs row count mismatch");
            let mut rhs_data = buffers.acquire_with_capacity::<Self>(b.host_data().len());
            rhs_data.extend_from_slice(b.host_data());
            let rhs = MatMut::from_column_major_slice_mut(
                complex64_to_faer_slice_mut(&mut rhs_data),
                n,
                b_cols,
            );
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            tensor_from_vec_with_template(vec![n, b_cols], rhs_data, b)
        } else {
            assert_eq!(b_cols, n, "triangular_solve: rhs column count mismatch");
            let nrhs = b_rows;
            let mut rhs_transposed = transpose_col_major_data(buffers, b.host_data(), nrhs, n);
            let rhs = MatMut::from_column_major_slice_mut(
                complex64_to_faer_slice_mut(&mut rhs_transposed),
                n,
                nrhs,
            );
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            let result = transpose_col_major_data(buffers, &rhs_transposed, n, nrhs);
            <Self as PoolScalar>::pool_release(buffers, rhs_transposed);
            tensor_from_vec_with_template(vec![nrhs, n], result, b)
        }
    }

    fn svd_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "svd");
        let k = m.min(n);
        let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), m, n);
        let mut u = Mat::zeros(m, k);
        let mut v = Mat::zeros(n, k);
        let mut s = Diag::zeros(k);
        let mut mem = MemBuffer::new(faer::linalg::svd::svd_scratch::<faer::c64>(
            m,
            n,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::svd::svd(
            mat,
            s.as_mut(),
            Some(u.as_mut()),
            Some(v.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("svd: decomposition failed"));

        let u = tensor_from_vec_with_template(
            vec![m, k],
            complex_vec_from_mat(buffers, u.as_ref()),
            input,
        );
        let s = tensor_from_vec_with_template(
            vec![k],
            complex_vec_from_real_diag(buffers, s.as_ref()),
            input,
        );
        let mut vt_data = buffers.acquire_with_capacity::<Self>(k * n);
        for j in 0..n {
            for i in 0..k {
                vt_data.push(v[(j, i)].conj());
            }
        }
        let vt = tensor_from_vec_with_template(vec![k, n], vt_data, input);

        vec![u, s, vt]
    }

    fn qr_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "qr");
        let k = m.min(n);
        let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), m, n);
        let block_size =
            faer::linalg::qr::no_pivoting::factor::recommended_block_size::<faer::c64>(m, n);
        let mut qr = Mat::zeros(m, n);
        qr.copy_from(mat);
        let mut coeff = Mat::zeros(block_size, k);
        let mut mem = MemBuffer::new(
            faer::linalg::qr::no_pivoting::factor::qr_in_place_scratch::<faer::c64>(
                m,
                n,
                block_size,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::qr::no_pivoting::factor::qr_in_place(
            qr.as_mut(),
            coeff.as_mut(),
            ctx.faer_par(),
            stack,
            Default::default(),
        );
        let mut q = Mat::identity(m, k);
        let mut mem = MemBuffer::new(
            faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<faer::c64>(
                m,
                block_size,
                k,
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            qr.as_ref().subcols(0, k),
            coeff.as_ref(),
            Conj::No,
            q.as_mut(),
            ctx.faer_par(),
            stack,
        );
        let q = tensor_from_vec_with_template(
            vec![m, k],
            complex_vec_from_mat(buffers, q.as_ref()),
            input,
        );
        let r = tensor_from_vec_with_template(
            vec![k, n],
            complex_matrix_from_predicate(qr.as_ref(), k, n, |row, col| row <= col),
            input,
        );

        vec![q, r]
    }

    fn eigh_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "eigh");
        let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), n, n);
        let mut values = Diag::zeros(n);
        let mut vectors = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(faer::linalg::evd::self_adjoint_evd_scratch::<faer::c64>(
            n,
            faer::linalg::evd::ComputeEigenvectors::Yes,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::evd::self_adjoint_evd(
            mat,
            values.as_mut(),
            Some(vectors.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("eigh: decomposition failed"));

        let values = tensor_from_vec_with_template(
            vec![n],
            complex_vec_from_real_diag(buffers, values.as_ref()),
            input,
        );
        let vectors = tensor_from_vec_with_template(
            vec![n, n],
            complex_vec_from_mat(buffers, vectors.as_ref()),
            input,
        );

        vec![values, vectors]
    }
}

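// Public entry points: each wrapper short-circuits tensors with a zero-sized
// dimension (returning empty outputs of the correct shape) and otherwise
// loops the scalar-specific 2-D kernel over the trailing batch dimensions.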
pub(crate) fn cholesky<T: FaerLinalg>(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
) -> crate::Result<TypedTensor<T>> {
    if has_zero_dim(&input.shape) {
        return Ok(tensor_from_vec_with_template(
            input.shape.clone(),
            Vec::new(),
            input,
        ));
    }
    batched_single(buffers, input, 2, |buffers, batch| {
        T::cholesky_2d(ctx, buffers, batch)
    })
}

pub(crate) fn lu<T: FaerLinalg>(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
) -> Vec<TypedTensor<T>> {
    if has_zero_dim(&input.shape) {
        let m = input.shape[0];
        let n = input.shape[1];
        let k = m.min(n);
        let batch_shape = &input.shape[2..];
        let parity_elements: usize = batch_shape.iter().product::<usize>().max(1);
        return vec![
            tensor_from_vec_with_template(
                matrix_with_batch_shape(m, m, batch_shape),
                Vec::new(),
                input,
            ),
            tensor_from_vec_with_template(
                matrix_with_batch_shape(m, k, batch_shape),
                Vec::new(),
                input,
            ),
            tensor_from_vec_with_template(
                matrix_with_batch_shape(k, n, batch_shape),
                Vec::new(),
                input,
            ),
            tensor_from_vec_with_template(
                batch_shape.to_vec(),
                vec![T::parity_one(); parity_elements],
                input,
            ),
        ];
    }
    batched_multi(buffers, input, 2, |buffers, batch| {
        T::lu_2d(ctx, buffers, batch)
    })
}

pub(crate) fn triangular_solve<T: FaerLinalg>(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    a: &TypedTensor<T>,
    b: &TypedTensor<T>,
    left_side: bool,
    lower: bool,
    transpose_a: bool,
    unit_diagonal: bool,
) -> TypedTensor<T> {
    if has_zero_dim(&a.shape) || has_zero_dim(&b.shape) {
        return tensor_from_vec_with_template(b.shape.clone(), Vec::new(), b);
    }
    batched_binary(buffers, a, b, 2, 2, |buffers, a, b| {
        T::triangular_solve_2d(
            ctx,
            buffers,
            a,
            b,
            left_side,
            lower,
            transpose_a,
            unit_diagonal,
        )
    })
}

pub(crate) fn svd<T: FaerLinalg>(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
) -> Vec<TypedTensor<T>> {
    if has_zero_dim(&input.shape) {
        let (matrix_shape, batch_shape) = split_core_and_batch(input, 2, "svd");
        let m = matrix_shape[0];
        let n = matrix_shape[1];
        let k = m.min(n);
        return vec![
            tensor_from_vec_with_template(
                matrix_with_batch_shape(m, k, batch_shape),
                Vec::new(),
                input,
            ),
            tensor_from_vec_with_template(
                vector_with_batch_shape(k, batch_shape),
                Vec::new(),
                input,
            ),
            tensor_from_vec_with_template(
                matrix_with_batch_shape(k, n, batch_shape),
                Vec::new(),
                input,
            ),
        ];
    }
    batched_multi(buffers, input, 2, |buffers, batch| {
        T::svd_2d(ctx, buffers, batch)
    })
}

pub(crate) fn qr<T: FaerLinalg>(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
) -> Vec<TypedTensor<T>> {
    if has_zero_dim(&input.shape) {
        let (matrix_shape, batch_shape) = split_core_and_batch(input, 2, "qr");
        let m = matrix_shape[0];
        let n = matrix_shape[1];
        let k = m.min(n);
        return vec![
            tensor_from_vec_with_template(
                matrix_with_batch_shape(m, k, batch_shape),
                Vec::new(),
                input,
            ),
            tensor_from_vec_with_template(
                matrix_with_batch_shape(k, n, batch_shape),
                Vec::new(),
                input,
            ),
        ];
    }
    batched_multi(buffers, input, 2, |buffers, batch| {
        T::qr_2d(ctx, buffers, batch)
    })
}

pub(crate) fn eigh<T: FaerLinalg>(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
) -> Vec<TypedTensor<T>> {
    if has_zero_dim(&input.shape) {
        let n = input.shape[0];
        let batch_shape = &input.shape[2..];
        return vec![
            tensor_from_vec_with_template(
                vector_with_batch_shape(n, batch_shape),
                Vec::new(),
                input,
            ),
            tensor_from_vec_with_template(
                matrix_with_batch_shape(n, n, batch_shape),
                Vec::new(),
                input,
            ),
        ];
    }
    batched_multi(buffers, input, 2, |buffers, batch| {
        T::eigh_2d(ctx, buffers, batch)
    })
}

fn eig_real_2d(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<f64>,
) -> Vec<TypedTensor<Complex64>> {
    let n = square_matrix_dim(input, "eig");
    let mat = MatRef::from_column_major_slice(input.host_data(), n, n);
    let mut u_real = Mat::zeros(n, n);
    let mut s_re = Diag::zeros(n);
    let mut s_im = Diag::zeros(n);
    let mut mem = MemBuffer::new(faer::linalg::evd::evd_scratch::<f64>(
        n,
        faer::linalg::evd::ComputeEigenvectors::No,
        faer::linalg::evd::ComputeEigenvectors::Yes,
        ctx.faer_par(),
        Default::default(),
    ));
    let stack = MemStack::new(&mut mem);
    faer::linalg::evd::evd_real(
        mat,
        s_re.as_mut(),
        s_im.as_mut(),
        None,
        Some(u_real.as_mut()),
        ctx.faer_par(),
        stack,
        Default::default(),
    )
    .unwrap_or_else(|_| panic!("eig: decomposition failed"));
    let (u, s) =
        real_eig_to_complex_outputs(buffers, u_real.as_ref(), s_re.as_ref(), s_im.as_ref());

    vec![
        tensor_from_vec_with_template(vec![n], s, input),
        tensor_from_vec_with_template(vec![n, n], u, input),
    ]
}

fn eig_complex_2d(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<Complex64>,
) -> Vec<TypedTensor<Complex64>> {
    let n = square_matrix_dim(input, "eig");
    let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), n, n);
    let mut u = Mat::zeros(n, n);
    let mut s = Diag::zeros(n);
    let mut mem = MemBuffer::new(faer::linalg::evd::evd_scratch::<faer::c64>(
        n,
        faer::linalg::evd::ComputeEigenvectors::No,
        faer::linalg::evd::ComputeEigenvectors::Yes,
        ctx.faer_par(),
        Default::default(),
    ));
    let stack = MemStack::new(&mut mem);
    faer::linalg::evd::evd_cplx(
        mat,
        s.as_mut(),
        None,
        Some(u.as_mut()),
        ctx.faer_par(),
        stack,
        Default::default(),
    )
    .unwrap_or_else(|_| panic!("eig: decomposition failed"));

    vec![
        tensor_from_vec_with_template(vec![n], complex_vec_from_diag(buffers, s.as_ref()), input),
        tensor_from_vec_with_template(vec![n, n], complex_vec_from_mat(buffers, u.as_ref()), input),
    ]
}

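// `eig` returns complex outputs for every input dtype: a real nonsymmetric
// matrix can have complex-conjugate eigenpairs, so the `Tensor::F64` arm also
// produces `Tensor::C64` results (via `eig_real_2d` and
// `real_eig_to_complex_outputs`).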
pub(crate) fn eig(ctx: &CpuContext, buffers: &mut BufferPool, input: &Tensor) -> Vec<Tensor> {
    if has_zero_dim(input.shape()) {
        let n = input.shape()[0];
        let batch_shape = &input.shape()[2..];
        return vec![
            Tensor::C64(TypedTensor::from_vec(
                vector_with_batch_shape(n, batch_shape),
                Vec::new(),
            )),
            Tensor::C64(TypedTensor::from_vec(
                matrix_with_batch_shape(n, n, batch_shape),
                Vec::new(),
            )),
        ];
    }

    match input {
        Tensor::F64(t) => batched_multi_convert(buffers, t, 2, |buffers, batch| {
            eig_real_2d(ctx, buffers, batch)
        })
        .into_iter()
        .map(Tensor::C64)
        .collect(),
        Tensor::C64(t) => batched_multi_convert(buffers, t, 2, |buffers, batch| {
            eig_complex_2d(ctx, buffers, batch)
        })
        .into_iter()
        .map(Tensor::C64)
        .collect(),
        _ => todo!("eig: unsupported dtype"),
    }
}

fn has_zero_dim(shape: &[usize]) -> bool {
    shape.contains(&0)
}

fn matrix_with_batch_shape(rows: usize, cols: usize, batch_shape: &[usize]) -> Vec<usize> {
    let mut shape = vec![rows, cols];
    shape.extend_from_slice(batch_shape);
    shape
}

fn vector_with_batch_shape(len: usize, batch_shape: &[usize]) -> Vec<usize> {
    let mut shape = vec![len];
    shape.extend_from_slice(batch_shape);
    shape
}
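
// Sketches of unit tests for the pure helpers in this file, added for
// illustration; they deliberately avoid `CpuContext` and `BufferPool`, whose
// constructors are not shown here, so they exercise only the slice-level
// helpers defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn triangle_extraction_zeroes_the_other_half() {
        // Column-major data for the 2x2 matrix [[1, 3], [2, 4]].
        let data = [1.0f64, 2.0, 3.0, 4.0];
        let mat = MatRef::from_column_major_slice(&data[..], 2, 2);
        // Lower triangle keeps (0,0), (1,0), (1,1); upper keeps (0,0), (0,1), (1,1).
        assert_eq!(lower_triangle_vec_from_mat(mat), vec![1.0, 2.0, 0.0, 4.0]);
        assert_eq!(upper_triangle_vec_from_mat(mat), vec![1.0, 0.0, 3.0, 4.0]);
    }

    #[test]
    fn batch_shapes_append_trailing_dims() {
        // Core dims come first, batch dims trail, matching the layout the
        // batched_* helpers assume.
        assert_eq!(matrix_with_batch_shape(2, 3, &[4, 5]), vec![2, 3, 4, 5]);
        assert_eq!(vector_with_batch_shape(2, &[4]), vec![2, 4]);
        assert!(has_zero_dim(&[2, 0, 3]));
        assert!(!has_zero_dim(&[2, 3]));
    }
}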