1use faer::dyn_stack::{MemBuffer, MemStack};
2use faer::{
3 diag::{Diag, DiagRef},
4 Conj, Mat, MatMut, MatRef,
5};
6use num_complex::Complex64;
7
8use crate::buffer_pool::{BufferPool, PoolScalar};
9use crate::cpu::CpuContext;
10use crate::{Buffer, Tensor, TypedTensor};
11
/// Scalar types whose dense linear-algebra factorizations are implemented on
/// the CPU backend via the `faer` crate (implemented below for `f64` and
/// `Complex64`).
pub(crate) trait FaerLinalg: Copy + Clone + PoolScalar {
    /// Multiplicative identity in this scalar type, used as the LU
    /// permutation-parity value (`+1`).
    fn parity_one() -> Self;
    /// LLT Cholesky factorization of a square 2-D matrix; returns the lower
    /// triangular factor. Errors if the matrix is not positive definite.
    fn cholesky_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> crate::Result<TypedTensor<Self>>;
    /// Partial-pivoting LU of an `m x n` matrix; returns `[P, L, U, parity]`
    /// (see the impls below for the exact shapes).
    fn lu_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
    /// Solves a triangular system with `a` on the left (`op(A) X = B`) or the
    /// right (`X op(A) = B`); the flags select triangle, transposition of `a`,
    /// and whether its diagonal is implicitly unit.
    fn triangular_solve_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        a: &TypedTensor<Self>,
        b: &TypedTensor<Self>,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> TypedTensor<Self>;
    /// Thin SVD; returns `[U, S, Vt]`.
    fn svd_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
    /// Thin QR factorization; returns `[Q, R]`.
    fn qr_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
    /// Self-adjoint eigendecomposition; returns `[eigenvalues, eigenvectors]`.
    fn eigh_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>>;
}
50
51fn matrix_dims<T>(input: &TypedTensor<T>, op: &str) -> (usize, usize) {
52 assert_eq!(input.shape.len(), 2, "{op}: expected a 2D matrix");
53 (input.shape[0], input.shape[1])
54}
55
56fn square_matrix_dim<T>(input: &TypedTensor<T>, op: &str) -> usize {
57 let (rows, cols) = matrix_dims(input, op);
58 assert_eq!(rows, cols, "{op}: expected a square matrix");
59 rows
60}
61
62fn tensor_from_vec_with_template<T: Clone, U>(
63 shape: Vec<usize>,
64 data: Vec<T>,
65 template: &TypedTensor<U>,
66) -> TypedTensor<T> {
67 TypedTensor {
68 buffer: Buffer::Host(data),
69 shape,
70 placement: template.placement.clone(),
71 }
72}
73
74fn col_major_vec_from_mat<T: Copy + PoolScalar>(
75 buffers: &mut BufferPool,
76 mat: MatRef<'_, T>,
77) -> Vec<T> {
78 let (rows, cols) = mat.shape();
79 let mut data = buffers.acquire_with_capacity::<T>(rows * cols);
80 for j in 0..cols {
81 for i in 0..rows {
82 data.push(mat[(i, j)]);
83 }
84 }
85 data
86}
87
88fn vec_from_diag<T: Copy + PoolScalar>(buffers: &mut BufferPool, diag: DiagRef<'_, T>) -> Vec<T> {
89 let col = diag.column_vector();
90 let mut data = buffers.acquire_with_capacity::<T>(col.nrows());
91 for i in 0..col.nrows() {
92 data.push(col[i]);
93 }
94 data
95}
96
/// Reinterprets a `num_complex::Complex64` slice as `faer::c64` without
/// copying.
fn complex64_to_faer_slice(data: &[Complex64]) -> &[faer::c64] {
    // Sanity-check (debug builds only) that the two types have identical
    // size and alignment before the pointer cast below.
    debug_assert_eq!(
        std::mem::size_of::<Complex64>(),
        std::mem::size_of::<faer::c64>()
    );
    debug_assert_eq!(
        std::mem::align_of::<Complex64>(),
        std::mem::align_of::<faer::c64>()
    );

    // SAFETY: same length and base pointer as the input slice; relies on the
    // two types being layout-compatible (both are a pair of f64, re then im)
    // — NOTE(review): layout compatibility is assumed, not proven here;
    // confirm both crates guarantee a repr(C) (re, im) layout.
    unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<faer::c64>(), data.len()) }
}
109
/// Mutable counterpart of [`complex64_to_faer_slice`]: reinterprets a
/// `&mut [Complex64]` as `&mut [faer::c64]` without copying.
fn complex64_to_faer_slice_mut(data: &mut [Complex64]) -> &mut [faer::c64] {
    // Sanity-check (debug builds only) that the two types have identical
    // size and alignment before the pointer cast below.
    debug_assert_eq!(
        std::mem::size_of::<Complex64>(),
        std::mem::size_of::<faer::c64>()
    );
    debug_assert_eq!(
        std::mem::align_of::<Complex64>(),
        std::mem::align_of::<faer::c64>()
    );

    // SAFETY: exclusive borrow of the input guarantees no aliasing; same
    // layout-compatibility assumption as the shared-slice variant —
    // NOTE(review): confirm both crates guarantee a repr(C) (re, im) layout.
    unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<faer::c64>(), data.len()) }
}
122
123fn complex_vec_from_real_diag(
124 buffers: &mut BufferPool,
125 diag: DiagRef<'_, faer::c64>,
126) -> Vec<Complex64> {
127 let col = diag.column_vector();
128 let mut data = buffers.acquire_with_capacity::<Complex64>(col.nrows());
129 for i in 0..col.nrows() {
130 data.push(Complex64::new(col[i].re, 0.0));
131 }
132 data
133}
134
135fn complex_vec_from_diag(buffers: &mut BufferPool, diag: DiagRef<'_, faer::c64>) -> Vec<Complex64> {
136 let col = diag.column_vector();
137 let mut data = buffers.acquire_with_capacity::<Complex64>(col.nrows());
138 for i in 0..col.nrows() {
139 data.push(Complex64::new(col[i].re, col[i].im));
140 }
141 data
142}
143
144fn complex_vec_from_mat(buffers: &mut BufferPool, mat: MatRef<'_, faer::c64>) -> Vec<Complex64> {
145 let (rows, cols) = mat.shape();
146 let mut data = buffers.acquire_with_capacity::<Complex64>(rows * cols);
147 for j in 0..cols {
148 for i in 0..rows {
149 let value = mat[(i, j)];
150 data.push(Complex64::new(value.re, value.im));
151 }
152 }
153 data
154}
155
156fn matrix_from_predicate<T: Copy + Default>(
157 mat: MatRef<'_, T>,
158 rows: usize,
159 cols: usize,
160 predicate: impl Fn(usize, usize) -> bool,
161) -> Vec<T> {
162 let mut data = vec![T::default(); rows * cols];
163 for j in 0..cols {
164 for i in 0..rows {
165 if predicate(i, j) {
166 data[i + j * rows] = mat[(i, j)];
167 }
168 }
169 }
170 data
171}
172
173fn lower_triangle_vec_from_mat<T: Copy + Default>(mat: MatRef<'_, T>) -> Vec<T> {
174 let (rows, cols) = mat.shape();
175 matrix_from_predicate(mat, rows, cols, |row, col| row >= col)
176}
177
178fn upper_triangle_vec_from_mat<T: Copy + Default>(mat: MatRef<'_, T>) -> Vec<T> {
179 let (rows, cols) = mat.shape();
180 matrix_from_predicate(mat, rows, cols, |row, col| row <= col)
181}
182
183fn complex_matrix_from_predicate(
184 mat: MatRef<'_, faer::c64>,
185 rows: usize,
186 cols: usize,
187 predicate: impl Fn(usize, usize) -> bool,
188) -> Vec<Complex64> {
189 let mut data = vec![Complex64::new(0.0, 0.0); rows * cols];
190 for j in 0..cols {
191 for i in 0..rows {
192 if predicate(i, j) {
193 let value = mat[(i, j)];
194 data[i + j * rows] = Complex64::new(value.re, value.im);
195 }
196 }
197 }
198 data
199}
200
/// Expands the packed real-arithmetic output of a general (non-symmetric)
/// real eigendecomposition into full complex eigenvectors/eigenvalues.
///
/// Follows the LAPACK-style packing: a complex-conjugate eigenvalue pair
/// occupies columns `j` and `j + 1` of `u_real`, holding Re(v) and Im(v)
/// respectively, with `s_im[j] != 0` marking the pair.
///
/// Returns `(u, s)` as column-major `n x n` eigenvector data and the `n`
/// eigenvalues.
fn real_eig_to_complex_outputs(
    buffers: &mut BufferPool,
    u_real: MatRef<'_, f64>,
    s_re: DiagRef<'_, f64>,
    s_im: DiagRef<'_, f64>,
) -> (Vec<Complex64>, Vec<Complex64>) {
    let n = u_real.nrows();
    // The buffers are filled by index below, so pool_acquire is expected to
    // return length-`n*n` / length-`n` vectors, not just capacity —
    // NOTE(review): confirm pool_acquire's length semantics.
    let mut u = unsafe { <Complex64 as PoolScalar>::pool_acquire(buffers, n * n) };
    let mut s = unsafe { <Complex64 as PoolScalar>::pool_acquire(buffers, n) };
    let mut j = 0;
    while j < n {
        if s_im[j] == 0.0 {
            // Real eigenvalue: copy the column as-is with zero imaginary part.
            s[j] = Complex64::new(s_re[j], 0.0);
            for i in 0..n {
                u[i + j * n] = Complex64::new(u_real[(i, j)], 0.0);
            }
            j += 1;
        } else {
            // Conjugate pair: column j carries Re(v), column j+1 carries
            // Im(v); reconstruct v and its conjugate. Assumes the producer
            // never leaves a pair dangling at j == n - 1.
            s[j] = Complex64::new(s_re[j], s_im[j]);
            s[j + 1] = Complex64::new(s_re[j], -s_im[j]);
            for i in 0..n {
                u[i + j * n] = Complex64::new(u_real[(i, j)], u_real[(i, j + 1)]);
                u[i + (j + 1) * n] = Complex64::new(u_real[(i, j)], -u_real[(i, j + 1)]);
            }
            j += 2;
        }
    }
    (u, s)
}
232
233fn split_core_and_batch<'a, T>(
234 input: &'a TypedTensor<T>,
235 core_rank: usize,
236 op: &str,
237) -> (&'a [usize], &'a [usize]) {
238 assert!(
239 input.shape.len() >= core_rank,
240 "{op}: expected rank >= {core_rank}"
241 );
242 input.shape.split_at(core_rank)
243}
244
245fn transpose_col_major_data<T: Copy + PoolScalar>(
246 buffers: &mut BufferPool,
247 data: &[T],
248 rows: usize,
249 cols: usize,
250) -> Vec<T> {
251 let mut transposed = buffers.acquire_with_capacity::<T>(data.len());
252 for j in 0..rows {
253 for i in 0..cols {
254 transposed.push(data[j + i * rows]);
255 }
256 }
257 transposed
258}
259
/// Runs a fallible single-output core op over trailing batch dimensions.
///
/// The leading `core_rank` dims of `input.shape` form the per-slice core
/// shape; any remaining dims are batch dims. Each batch slice is copied out,
/// processed independently, and the per-batch outputs are concatenated; the
/// batch dims are appended after the output core shape. All batches must
/// produce the same output core shape.
fn batched_single<T, F>(
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
    core_rank: usize,
    op: F,
) -> crate::Result<TypedTensor<T>>
where
    T: Clone + PoolScalar,
    F: Fn(&mut BufferPool, &TypedTensor<T>) -> crate::Result<TypedTensor<T>>,
{
    let (core_shape, batch_shape) = split_core_and_batch(input, core_rank, "batched_single");
    if batch_shape.is_empty() {
        // No batch dims: delegate directly to the core op.
        return op(buffers, input);
    }

    let slice_size: usize = core_shape.iter().product();
    let batch_count: usize = batch_shape.iter().product();
    assert!(
        batch_count > 0,
        "batched_single: zero-sized batch dims are unsupported"
    );

    // Both populated on the first iteration, once the output core shape is
    // known (it determines the total buffer size).
    let mut out_core_shape: Option<Vec<usize>> = None;
    let mut out_data: Option<Vec<T>> = None;

    for batch_idx in 0..batch_count {
        let start = batch_idx * slice_size;
        let end = start + slice_size;
        let batch_input = tensor_from_vec_with_template(
            core_shape.to_vec(),
            input.host_data()[start..end].to_vec(),
            input,
        );
        let batch_output = op(buffers, &batch_input)?;

        if let Some(expected_shape) = &out_core_shape {
            assert_eq!(
                batch_output.shape.as_slice(),
                expected_shape.as_slice(),
                "batched_single: output core shape mismatch across batches"
            );
        } else {
            // First batch: pre-size the pooled output buffer for all batches.
            out_data =
                Some(buffers.acquire_with_capacity::<T>(batch_output.n_elements() * batch_count));
            out_core_shape = Some(batch_output.shape.clone());
        }

        out_data
            .as_mut()
            .expect("batched_single: missing output buffer")
            .extend_from_slice(batch_output.host_data());
    }

    let mut out_shape = out_core_shape.expect("batched_single: missing output shape");
    out_shape.extend_from_slice(batch_shape);
    Ok(tensor_from_vec_with_template(
        out_shape,
        out_data.expect("batched_single: missing output data"),
        input,
    ))
}
321
/// Runs a multi-output core op over trailing batch dimensions.
///
/// Like [`batched_single`], but the core op returns several tensors per
/// batch. Every batch must return the same number of outputs with matching
/// core shapes; outputs are concatenated position-wise and each gets the
/// batch dims appended after its core shape.
fn batched_multi<T, F>(
    buffers: &mut BufferPool,
    input: &TypedTensor<T>,
    core_rank: usize,
    op: F,
) -> Vec<TypedTensor<T>>
where
    T: Clone + PoolScalar,
    F: Fn(&mut BufferPool, &TypedTensor<T>) -> Vec<TypedTensor<T>>,
{
    let (core_shape, batch_shape) = split_core_and_batch(input, core_rank, "batched_multi");
    if batch_shape.is_empty() {
        // No batch dims: delegate directly to the core op.
        return op(buffers, input);
    }

    let slice_size: usize = core_shape.iter().product();
    let batch_count: usize = batch_shape.iter().product();
    assert!(
        batch_count > 0,
        "batched_multi: zero-sized batch dims are unsupported"
    );

    // Filled on the first iteration: one shape + one pooled buffer per output.
    let mut out_shapes: Vec<Vec<usize>> = Vec::new();
    let mut out_data: Vec<Vec<T>> = Vec::new();

    for batch_idx in 0..batch_count {
        let start = batch_idx * slice_size;
        let end = start + slice_size;
        let batch_input = tensor_from_vec_with_template(
            core_shape.to_vec(),
            input.host_data()[start..end].to_vec(),
            input,
        );
        let batch_outputs = op(buffers, &batch_input);

        if out_shapes.is_empty() {
            // First batch fixes the expected output count/shapes and lets us
            // pre-size the pooled buffers for all batches at once.
            out_shapes = batch_outputs
                .iter()
                .map(|tensor| tensor.shape.clone())
                .collect();
            let mut pooled_outputs = Vec::with_capacity(batch_outputs.len());
            for tensor in &batch_outputs {
                pooled_outputs
                    .push(buffers.acquire_with_capacity::<T>(tensor.n_elements() * batch_count));
            }
            out_data = pooled_outputs;
        } else {
            assert_eq!(
                batch_outputs.len(),
                out_shapes.len(),
                "batched_multi: output count mismatch across batches"
            );
        }

        for (idx, batch_output) in batch_outputs.iter().enumerate() {
            assert_eq!(
                batch_output.shape.as_slice(),
                out_shapes[idx].as_slice(),
                "batched_multi: output core shape mismatch across batches"
            );
            out_data[idx].extend_from_slice(batch_output.host_data());
        }
    }

    out_shapes
        .into_iter()
        .zip(out_data)
        .map(|(mut out_shape, out_data)| {
            out_shape.extend_from_slice(batch_shape);
            tensor_from_vec_with_template(out_shape, out_data, input)
        })
        .collect()
}
395
396fn batched_multi_convert<InT, OutT, F>(
397 buffers: &mut BufferPool,
398 input: &TypedTensor<InT>,
399 core_rank: usize,
400 op: F,
401) -> Vec<TypedTensor<OutT>>
402where
403 InT: Clone,
404 OutT: Clone + PoolScalar,
405 F: Fn(&mut BufferPool, &TypedTensor<InT>) -> Vec<TypedTensor<OutT>>,
406{
407 let (core_shape, batch_shape) = split_core_and_batch(input, core_rank, "batched_multi");
408 if batch_shape.is_empty() {
409 return op(buffers, input);
410 }
411
412 let slice_size: usize = core_shape.iter().product();
413 let batch_count: usize = batch_shape.iter().product();
414 assert!(
415 batch_count > 0,
416 "batched_multi: zero-sized batch dims are unsupported"
417 );
418
419 let mut out_shapes: Vec<Vec<usize>> = Vec::new();
420 let mut out_data: Vec<Vec<OutT>> = Vec::new();
421
422 for batch_idx in 0..batch_count {
423 let start = batch_idx * slice_size;
424 let end = start + slice_size;
425 let batch_input = tensor_from_vec_with_template(
426 core_shape.to_vec(),
427 input.host_data()[start..end].to_vec(),
428 input,
429 );
430 let batch_outputs = op(buffers, &batch_input);
431
432 if out_shapes.is_empty() {
433 out_shapes = batch_outputs
434 .iter()
435 .map(|tensor| tensor.shape.clone())
436 .collect();
437 let mut pooled_outputs = Vec::with_capacity(batch_outputs.len());
438 for tensor in &batch_outputs {
439 pooled_outputs
440 .push(buffers.acquire_with_capacity::<OutT>(tensor.n_elements() * batch_count));
441 }
442 out_data = pooled_outputs;
443 } else {
444 assert_eq!(
445 batch_outputs.len(),
446 out_shapes.len(),
447 "batched_multi: output count mismatch across batches"
448 );
449 }
450
451 for (idx, batch_output) in batch_outputs.iter().enumerate() {
452 assert_eq!(
453 batch_output.shape.as_slice(),
454 out_shapes[idx].as_slice(),
455 "batched_multi: output core shape mismatch across batches"
456 );
457 out_data[idx].extend_from_slice(batch_output.host_data());
458 }
459 }
460
461 out_shapes
462 .into_iter()
463 .zip(out_data)
464 .map(|(mut out_shape, out_data)| {
465 out_shape.extend_from_slice(batch_shape);
466 tensor_from_vec_with_template(out_shape, out_data, input)
467 })
468 .collect()
469}
470
/// Runs a two-input, single-output core op over trailing batch dimensions.
///
/// `a` and `b` may have different core ranks but must have identical batch
/// dims. Each pair of batch slices is processed independently; outputs are
/// concatenated and the batch dims appended after the output core shape.
fn batched_binary<T, F>(
    buffers: &mut BufferPool,
    a: &TypedTensor<T>,
    b: &TypedTensor<T>,
    core_rank_a: usize,
    core_rank_b: usize,
    op: F,
) -> TypedTensor<T>
where
    T: Clone + PoolScalar,
    F: Fn(&mut BufferPool, &TypedTensor<T>, &TypedTensor<T>) -> TypedTensor<T>,
{
    let (a_core_shape, a_batch_shape) = split_core_and_batch(a, core_rank_a, "batched_binary");
    let (b_core_shape, b_batch_shape) = split_core_and_batch(b, core_rank_b, "batched_binary");
    assert_eq!(
        a_batch_shape, b_batch_shape,
        "batched_binary: batch shape mismatch"
    );

    if a_batch_shape.is_empty() {
        // No batch dims: delegate directly to the core op.
        return op(buffers, a, b);
    }

    let a_slice_size: usize = a_core_shape.iter().product();
    let b_slice_size: usize = b_core_shape.iter().product();
    let batch_count: usize = a_batch_shape.iter().product();
    assert!(
        batch_count > 0,
        "batched_binary: zero-sized batch dims are unsupported"
    );

    // Both populated on the first iteration, once the output core shape is
    // known.
    let mut out_core_shape: Option<Vec<usize>> = None;
    let mut out_data: Option<Vec<T>> = None;

    for batch_idx in 0..batch_count {
        let a_start = batch_idx * a_slice_size;
        let a_end = a_start + a_slice_size;
        let b_start = batch_idx * b_slice_size;
        let b_end = b_start + b_slice_size;

        let batch_a = tensor_from_vec_with_template(
            a_core_shape.to_vec(),
            a.host_data()[a_start..a_end].to_vec(),
            a,
        );
        let batch_b = tensor_from_vec_with_template(
            b_core_shape.to_vec(),
            b.host_data()[b_start..b_end].to_vec(),
            b,
        );
        let batch_output = op(buffers, &batch_a, &batch_b);

        if let Some(expected_shape) = &out_core_shape {
            assert_eq!(
                batch_output.shape.as_slice(),
                expected_shape.as_slice(),
                "batched_binary: output core shape mismatch across batches"
            );
        } else {
            // First batch: pre-size the pooled output buffer for all batches.
            out_data =
                Some(buffers.acquire_with_capacity::<T>(batch_output.n_elements() * batch_count));
            out_core_shape = Some(batch_output.shape.clone());
        }

        out_data
            .as_mut()
            .expect("batched_binary: missing output buffer")
            .extend_from_slice(batch_output.host_data());
    }

    let mut out_shape = out_core_shape.expect("batched_binary: missing output shape");
    out_shape.extend_from_slice(a_batch_shape);
    // NOTE(review): the output inherits its placement from `b` (the rhs
    // template), not `a` — presumably intentional (result placed like the
    // rhs), but worth confirming against callers.
    tensor_from_vec_with_template(
        out_shape,
        out_data.expect("batched_binary: missing output data"),
        b,
    )
}
549
impl FaerLinalg for f64 {
    fn parity_one() -> Self {
        1.0
    }

    /// LLT Cholesky via faer's in-place factorization; returns the `n x n`
    /// lower-triangular factor (strictly-upper entries zeroed).
    fn cholesky_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> crate::Result<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "cholesky");
        let mut l = Mat::zeros(n, n);
        l.copy_from(MatRef::from_column_major_slice(input.host_data(), n, n));
        // Scratch workspace sized by faer for the in-place factorization.
        let mut mem = MemBuffer::new(
            faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<Self>(
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::cholesky::llt::factor::cholesky_in_place(
            l.as_mut(),
            Default::default(),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .map_err(|_| crate::Error::BackendFailure {
            op: "cholesky",
            message: "matrix is not positive definite".into(),
        })?;
        Ok(tensor_from_vec_with_template(
            vec![n, n],
            // Only the lower triangle of the factored matrix is meaningful.
            lower_triangle_vec_from_mat(l.as_ref()),
            input,
        ))
    }

    /// Partial-pivoting LU of an `m x n` matrix.
    ///
    /// Returns `[P (m x m), L (m x k, unit diagonal), U (k x n), parity]`
    /// with `k = min(m, n)`; `parity` is the sign of the permutation
    /// (+1 / -1), usable for determinant computation.
    fn lu_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "lu");
        let k = m.min(n);
        let mut lu = Mat::zeros(m, n);
        lu.copy_from(MatRef::from_column_major_slice(input.host_data(), m, n));
        let mut perm = vec![0usize; m];
        let mut perm_inv = vec![0usize; m];
        let mut mem = MemBuffer::new(
            faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, Self>(
                m,
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        let info = faer::linalg::lu::partial_pivoting::factor::lu_in_place(
            lu.as_mut(),
            &mut perm,
            &mut perm_inv,
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .0;

        // Materialize the permutation as a dense matrix with a single 1 per
        // row, placed by the `perm` mapping.
        let mut p_data = vec![0.0; m * m];
        for (row, &col) in perm.iter().enumerate() {
            p_data[row + col * m] = 1.0;
        }
        // Even number of row swaps => +1, odd => -1.
        let parity = if info.transposition_count % 2 == 0 {
            1.0
        } else {
            -1.0
        };

        // L and U are both packed into `lu`; split them out and force L's
        // unit diagonal explicitly.
        let mut l_data = matrix_from_predicate(lu.as_ref(), m, k, |row, col| row >= col);
        for i in 0..k {
            l_data[i + i * m] = 1.0;
        }
        let u_data = upper_triangle_vec_from_mat(lu.as_ref().get(..k, ..));

        vec![
            tensor_from_vec_with_template(vec![m, m], p_data, input),
            tensor_from_vec_with_template(vec![m, k], l_data, input),
            tensor_from_vec_with_template(vec![k, n], u_data, input),
            tensor_from_vec_with_template(vec![], vec![parity], input),
        ]
    }

    /// Triangular solve.
    ///
    /// Left side solves `op(A) X = B` in place on a copy of `b`. Right side
    /// solves `X op(A) = B` by transposing to `op(A)^T X^T = B^T`, which is
    /// why the (transpose_a, lower) -> faer-routine mapping is mirrored
    /// between the two branches.
    fn triangular_solve_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        a: &TypedTensor<Self>,
        b: &TypedTensor<Self>,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> TypedTensor<Self> {
        let n = square_matrix_dim(a, "triangular_solve");
        let (b_rows, b_cols) = matrix_dims(b, "triangular_solve");
        let a_mat = MatRef::from_column_major_slice(a.host_data(), n, n);

        if left_side {
            assert_eq!(b_rows, n, "triangular_solve: rhs row count mismatch");
            let mut rhs_data = buffers.acquire_with_capacity::<Self>(b.host_data().len());
            rhs_data.extend_from_slice(b.host_data());
            let rhs = MatMut::from_column_major_slice_mut(&mut rhs_data, n, b_cols);
            // transpose_a flips the triangle: (A lower)^T is upper, etc.
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            tensor_from_vec_with_template(vec![n, b_cols], rhs_data, b)
        } else {
            assert_eq!(b_cols, n, "triangular_solve: rhs column count mismatch");
            let nrhs = b_rows;
            // Solve on B^T (column-major transpose of b), then transpose back.
            let mut rhs_transposed = transpose_col_major_data(buffers, b.host_data(), nrhs, n);
            let rhs = MatMut::from_column_major_slice_mut(&mut rhs_transposed, n, nrhs);
            // The extra transposition from the B^T trick composes with
            // transpose_a, so the routine selection is the mirror of the
            // left-side branch.
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            // Transpose back to nrhs x n and return the scratch to the pool.
            let result = transpose_col_major_data(buffers, &rhs_transposed, n, nrhs);
            <Self as PoolScalar>::pool_release(buffers, rhs_transposed);
            tensor_from_vec_with_template(vec![nrhs, n], result, b)
        }
    }

    /// Thin SVD: returns `[U (m x k), S (k), Vt (k x n)]` with `k = min(m, n)`.
    fn svd_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "svd");
        let k = m.min(n);
        let mat = MatRef::from_column_major_slice(input.host_data(), m, n);
        let mut u = Mat::zeros(m, k);
        let mut v = Mat::zeros(n, k);
        let mut s = Diag::zeros(k);
        let mut mem = MemBuffer::new(faer::linalg::svd::svd_scratch::<Self>(
            m,
            n,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::svd::svd(
            mat,
            s.as_mut(),
            Some(u.as_mut()),
            Some(v.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("svd: decomposition failed"));

        let u = tensor_from_vec_with_template(
            vec![m, k],
            col_major_vec_from_mat(buffers, u.as_ref()),
            input,
        );
        let s = tensor_from_vec_with_template(vec![k], vec_from_diag(buffers, s.as_ref()), input);
        // faer yields V (n x k); emit V^T (k x n) column-major by swapping
        // the index order during the copy.
        let mut vt_data = buffers.acquire_with_capacity::<Self>(k * n);
        for j in 0..n {
            for i in 0..k {
                vt_data.push(v[(j, i)]);
            }
        }
        let vt = tensor_from_vec_with_template(vec![k, n], vt_data, input);

        vec![u, s, vt]
    }

    /// Thin QR: returns `[Q (m x k), R (k x n)]` with `k = min(m, n)`.
    fn qr_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "qr");
        let k = m.min(n);
        let mat = MatRef::from_column_major_slice(input.host_data(), m, n);
        let block_size =
            faer::linalg::qr::no_pivoting::factor::recommended_block_size::<Self>(m, n);
        let mut qr = Mat::zeros(m, n);
        qr.copy_from(mat);
        // Householder coefficients produced by the blocked factorization.
        let mut coeff = Mat::zeros(block_size, k);
        let mut mem = MemBuffer::new(
            faer::linalg::qr::no_pivoting::factor::qr_in_place_scratch::<Self>(
                m,
                n,
                block_size,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::qr::no_pivoting::factor::qr_in_place(
            qr.as_mut(),
            coeff.as_mut(),
            ctx.faer_par(),
            stack,
            Default::default(),
        );
        // Materialize Q by applying the stored Householder reflectors to an
        // m x k identity.
        let mut q = Mat::identity(m, k);
        let mut mem = MemBuffer::new(
            faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<Self>(
                m,
                block_size,
                k,
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            qr.as_ref().subcols(0, k),
            coeff.as_ref(),
            Conj::No,
            q.as_mut(),
            ctx.faer_par(),
            stack,
        );
        let q = tensor_from_vec_with_template(
            vec![m, k],
            col_major_vec_from_mat(buffers, q.as_ref()),
            input,
        );
        // R is the upper triangle of the first k rows of the packed factor.
        let r = tensor_from_vec_with_template(
            vec![k, n],
            upper_triangle_vec_from_mat(qr.as_ref().get(..k, ..)),
            input,
        );

        vec![q, r]
    }

    /// Self-adjoint eigendecomposition of a symmetric `n x n` matrix;
    /// returns `[eigenvalues (n), eigenvectors (n x n)]`.
    fn eigh_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "eigh");
        let mat = MatRef::from_column_major_slice(input.host_data(), n, n);
        let mut values = Diag::zeros(n);
        let mut vectors = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(faer::linalg::evd::self_adjoint_evd_scratch::<Self>(
            n,
            faer::linalg::evd::ComputeEigenvectors::Yes,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::evd::self_adjoint_evd(
            mat,
            values.as_mut(),
            Some(vectors.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("eigh: decomposition failed"));

        let values =
            tensor_from_vec_with_template(vec![n], vec_from_diag(buffers, values.as_ref()), input);
        let vectors = tensor_from_vec_with_template(
            vec![n, n],
            col_major_vec_from_mat(buffers, vectors.as_ref()),
            input,
        );

        vec![values, vectors]
    }
}
936
937impl FaerLinalg for Complex64 {
    /// Complex multiplicative identity (`1 + 0i`) used as LU parity seed.
    fn parity_one() -> Self {
        Complex64::new(1.0, 0.0)
    }
941
    /// Complex LLT Cholesky via faer (operating on the data reinterpreted as
    /// `faer::c64`); returns the lower-triangular factor with the
    /// strictly-upper entries zeroed.
    fn cholesky_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> crate::Result<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "cholesky");
        let mut l = Mat::zeros(n, n);
        l.copy_from(MatRef::from_column_major_slice(
            complex64_to_faer_slice(input.host_data()),
            n,
            n,
        ));
        // Scratch workspace sized by faer for the in-place factorization.
        let mut mem = MemBuffer::new(
            faer::linalg::cholesky::llt::factor::cholesky_in_place_scratch::<faer::c64>(
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::cholesky::llt::factor::cholesky_in_place(
            l.as_mut(),
            Default::default(),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .map_err(|_| crate::Error::BackendFailure {
            op: "cholesky",
            message: "matrix is not positive definite".into(),
        })?;
        Ok(tensor_from_vec_with_template(
            vec![n, n],
            // Keep only the lower triangle of the factored matrix.
            complex_matrix_from_predicate(l.as_ref(), n, n, |row, col| row >= col),
            input,
        ))
    }
979
    /// Complex partial-pivoting LU of an `m x n` matrix.
    ///
    /// Returns `[P (m x m), L (m x k, unit diagonal), U (k x n), parity]`
    /// with `k = min(m, n)`; `parity` is +1/-1 (real-valued complex scalar)
    /// according to the permutation sign.
    fn lu_2d(
        ctx: &CpuContext,
        _buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "lu");
        let k = m.min(n);
        let mut lu = Mat::zeros(m, n);
        lu.copy_from(MatRef::from_column_major_slice(
            complex64_to_faer_slice(input.host_data()),
            m,
            n,
        ));
        let mut perm = vec![0usize; m];
        let mut perm_inv = vec![0usize; m];
        let mut mem = MemBuffer::new(
            faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, faer::c64>(
                m,
                n,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        let info = faer::linalg::lu::partial_pivoting::factor::lu_in_place(
            lu.as_mut(),
            &mut perm,
            &mut perm_inv,
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .0;

        // Materialize the permutation as a dense matrix with a single 1 per
        // row, placed by the `perm` mapping.
        let mut p_data = vec![Complex64::new(0.0, 0.0); m * m];
        for (row, &col) in perm.iter().enumerate() {
            p_data[row + col * m] = Complex64::new(1.0, 0.0);
        }
        // Even number of row swaps => +1, odd => -1.
        let parity = if info.transposition_count % 2 == 0 {
            Complex64::new(1.0, 0.0)
        } else {
            Complex64::new(-1.0, 0.0)
        };
        // L and U are both packed into `lu`; split them out and force L's
        // unit diagonal explicitly.
        let mut l_data = complex_matrix_from_predicate(lu.as_ref(), m, k, |row, col| row >= col);
        for i in 0..k {
            l_data[i + i * m] = Complex64::new(1.0, 0.0);
        }
        let u_data = complex_matrix_from_predicate(lu.as_ref(), k, n, |row, col| row <= col);

        vec![
            tensor_from_vec_with_template(vec![m, m], p_data, input),
            tensor_from_vec_with_template(vec![m, k], l_data, input),
            tensor_from_vec_with_template(vec![k, n], u_data, input),
            tensor_from_vec_with_template(vec![], vec![parity], input),
        ]
    }
1036
    /// Complex triangular solve.
    ///
    /// Same structure as the `f64` implementation: left side solves
    /// `op(A) X = B` in place; right side solves `X op(A) = B` via the
    /// `op(A)^T X^T = B^T` trick, mirroring the routine selection.
    /// NOTE(review): `transpose_a` maps to a plain (non-conjugate) transpose
    /// here — confirm callers expect `A^T`, not `A^H`, for complex input.
    fn triangular_solve_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        a: &TypedTensor<Self>,
        b: &TypedTensor<Self>,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> TypedTensor<Self> {
        let n = square_matrix_dim(a, "triangular_solve");
        let (b_rows, b_cols) = matrix_dims(b, "triangular_solve");
        let a_mat = MatRef::from_column_major_slice(complex64_to_faer_slice(a.host_data()), n, n);

        if left_side {
            assert_eq!(b_rows, n, "triangular_solve: rhs row count mismatch");
            let mut rhs_data = buffers.acquire_with_capacity::<Self>(b.host_data().len());
            rhs_data.extend_from_slice(b.host_data());
            let rhs = MatMut::from_column_major_slice_mut(
                complex64_to_faer_slice_mut(&mut rhs_data),
                n,
                b_cols,
            );
            // transpose_a flips the triangle: (A lower)^T is upper, etc.
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            tensor_from_vec_with_template(vec![n, b_cols], rhs_data, b)
        } else {
            assert_eq!(b_cols, n, "triangular_solve: rhs column count mismatch");
            let nrhs = b_rows;
            // Solve on B^T (column-major transpose of b), then transpose back.
            let mut rhs_transposed = transpose_col_major_data(buffers, b.host_data(), nrhs, n);
            let rhs = MatMut::from_column_major_slice_mut(
                complex64_to_faer_slice_mut(&mut rhs_transposed),
                n,
                nrhs,
            );
            // The extra transposition from the B^T trick composes with
            // transpose_a, so the routine selection mirrors the left side.
            match (transpose_a, lower, unit_diagonal) {
                (false, true, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (false, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat.transpose(),
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, false) => {
                    faer::linalg::triangular_solve::solve_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, true, true) => {
                    faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, false) => {
                    faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
                (true, false, true) => {
                    faer::linalg::triangular_solve::solve_unit_upper_triangular_in_place(
                        a_mat,
                        rhs,
                        ctx.faer_par(),
                    );
                }
            }
            // Transpose back to nrhs x n and return the scratch to the pool.
            let result = transpose_col_major_data(buffers, &rhs_transposed, n, nrhs);
            <Self as PoolScalar>::pool_release(buffers, rhs_transposed);
            tensor_from_vec_with_template(vec![nrhs, n], result, b)
        }
    }
1191
    /// Thin (economy) SVD of a single 2D matrix.
    ///
    /// Returns `[U, S, Vᴴ]` with shapes `[m, k]`, `[k]`, and `[k, n]` where
    /// `k = min(m, n)`. Panics via `matrix_dims` if `input` is not 2D and via
    /// the `unwrap_or_else` below if the decomposition fails.
    fn svd_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "svd");
        let k = m.min(n);
        // Tensor buffers are column-major, matching faer's native layout, so
        // the host data can be viewed as a MatRef without copying.
        let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), m, n);
        let mut u = Mat::zeros(m, k);
        let mut v = Mat::zeros(n, k);
        let mut s = Diag::zeros(k);
        // Workspace sized for a thin SVD computing both singular vector sets.
        let mut mem = MemBuffer::new(faer::linalg::svd::svd_scratch::<faer::c64>(
            m,
            n,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            faer::linalg::svd::ComputeSvdVectors::Thin,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::svd::svd(
            mat,
            s.as_mut(),
            Some(u.as_mut()),
            Some(v.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("svd: decomposition failed"));

        let u = tensor_from_vec_with_template(
            vec![m, k],
            complex_vec_from_mat(buffers, u.as_ref()),
            input,
        );
        // Singular values come back real; widen them into the complex element
        // type so all three outputs share the input dtype.
        let s = tensor_from_vec_with_template(
            vec![k],
            complex_vec_from_real_diag(buffers, s.as_ref()),
            input,
        );
        // faer produces V (n x k); emit Vᴴ (k x n) in column-major order:
        // column j of Vᴴ is conj(V[j, 0..k]).
        let mut vt_data = buffers.acquire_with_capacity::<Self>(k * n);
        for j in 0..n {
            for i in 0..k {
                vt_data.push(v[(j, i)].conj());
            }
        }
        let vt = tensor_from_vec_with_template(vec![k, n], vt_data, input);

        vec![u, s, vt]
    }
1243
    /// Thin QR factorization of a single 2D matrix.
    ///
    /// Returns `[Q, R]` with shapes `[m, k]` and `[k, n]` where `k = min(m, n)`.
    /// Q is formed explicitly by applying the Householder sequence to an
    /// identity block; R is the upper-triangular part of the in-place factor.
    fn qr_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let (m, n) = matrix_dims(input, "qr");
        let k = m.min(n);
        let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), m, n);
        let block_size =
            faer::linalg::qr::no_pivoting::factor::recommended_block_size::<faer::c64>(m, n);
        // qr is overwritten in place: R on/above the diagonal, Householder
        // vectors below it.
        let mut qr = Mat::zeros(m, n);
        qr.copy_from(mat);
        // Block Householder coefficients (T factors), one block per column panel.
        let mut coeff = Mat::zeros(block_size, k);
        let mut mem = MemBuffer::new(
            faer::linalg::qr::no_pivoting::factor::qr_in_place_scratch::<faer::c64>(
                m,
                n,
                block_size,
                ctx.faer_par(),
                Default::default(),
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::qr::no_pivoting::factor::qr_in_place(
            qr.as_mut(),
            coeff.as_mut(),
            ctx.faer_par(),
            stack,
            Default::default(),
        );
        // Materialize the thin Q by applying the stored Householder sequence
        // to the first k columns of the identity.
        let mut q = Mat::identity(m, k);
        let mut mem = MemBuffer::new(
            faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<faer::c64>(
                m,
                block_size,
                k,
            ),
        );
        let stack = MemStack::new(&mut mem);
        faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            qr.as_ref().subcols(0, k),
            coeff.as_ref(),
            Conj::No,
            q.as_mut(),
            ctx.faer_par(),
            stack,
        );
        let q = tensor_from_vec_with_template(
            vec![m, k],
            complex_vec_from_mat(buffers, q.as_ref()),
            input,
        );
        // Keep only the upper triangle (row <= col) of the factored matrix as R;
        // entries below the diagonal are Householder data, not part of R.
        let r = tensor_from_vec_with_template(
            vec![k, n],
            complex_matrix_from_predicate(qr.as_ref(), k, n, |row, col| row <= col),
            input,
        );

        vec![q, r]
    }
1304
    /// Self-adjoint (Hermitian) eigendecomposition of a single square matrix.
    ///
    /// Returns `[values, vectors]` with shapes `[n]` and `[n, n]`. Eigenvalues
    /// of a self-adjoint matrix are real; they are widened back into the
    /// complex element type so outputs match the input dtype.
    /// NOTE(review): assumes the input is actually self-adjoint — faer's
    /// self_adjoint_evd presumably reads only one triangle; confirm which.
    fn eigh_2d(
        ctx: &CpuContext,
        buffers: &mut BufferPool,
        input: &TypedTensor<Self>,
    ) -> Vec<TypedTensor<Self>> {
        let n = square_matrix_dim(input, "eigh");
        let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), n, n);
        let mut values = Diag::zeros(n);
        let mut vectors = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(faer::linalg::evd::self_adjoint_evd_scratch::<faer::c64>(
            n,
            faer::linalg::evd::ComputeEigenvectors::Yes,
            ctx.faer_par(),
            Default::default(),
        ));
        let stack = MemStack::new(&mut mem);
        faer::linalg::evd::self_adjoint_evd(
            mat,
            values.as_mut(),
            Some(vectors.as_mut()),
            ctx.faer_par(),
            stack,
            Default::default(),
        )
        .unwrap_or_else(|_| panic!("eigh: decomposition failed"));

        let values = tensor_from_vec_with_template(
            vec![n],
            complex_vec_from_real_diag(buffers, values.as_ref()),
            input,
        );
        let vectors = tensor_from_vec_with_template(
            vec![n, n],
            complex_vec_from_mat(buffers, vectors.as_ref()),
            input,
        );

        vec![values, vectors]
    }
1344}
1345
1346pub(crate) fn cholesky<T: FaerLinalg>(
1347 ctx: &CpuContext,
1348 buffers: &mut BufferPool,
1349 input: &TypedTensor<T>,
1350) -> crate::Result<TypedTensor<T>> {
1351 if has_zero_dim(&input.shape) {
1352 return Ok(tensor_from_vec_with_template(
1353 input.shape.clone(),
1354 Vec::new(),
1355 input,
1356 ));
1357 }
1358 batched_single(buffers, input, 2, |buffers, batch| {
1359 T::cholesky_2d(ctx, buffers, batch)
1360 })
1361}
1362
1363pub(crate) fn lu<T: FaerLinalg>(
1364 ctx: &CpuContext,
1365 buffers: &mut BufferPool,
1366 input: &TypedTensor<T>,
1367) -> Vec<TypedTensor<T>> {
1368 if has_zero_dim(&input.shape) {
1369 let m = input.shape[0];
1370 let n = input.shape[1];
1371 let k = m.min(n);
1372 let batch_shape = &input.shape[2..];
1373 let parity_elements: usize = batch_shape.iter().product::<usize>().max(1);
1374 return vec![
1375 tensor_from_vec_with_template(
1376 matrix_with_batch_shape(m, m, batch_shape),
1377 Vec::new(),
1378 input,
1379 ),
1380 tensor_from_vec_with_template(
1381 matrix_with_batch_shape(m, k, batch_shape),
1382 Vec::new(),
1383 input,
1384 ),
1385 tensor_from_vec_with_template(
1386 matrix_with_batch_shape(k, n, batch_shape),
1387 Vec::new(),
1388 input,
1389 ),
1390 tensor_from_vec_with_template(
1391 batch_shape.to_vec(),
1392 vec![T::parity_one(); parity_elements],
1393 input,
1394 ),
1395 ];
1396 }
1397 batched_multi(buffers, input, 2, |buffers, batch| {
1398 T::lu_2d(ctx, buffers, batch)
1399 })
1400}
1401
1402pub(crate) fn triangular_solve<T: FaerLinalg>(
1403 ctx: &CpuContext,
1404 buffers: &mut BufferPool,
1405 a: &TypedTensor<T>,
1406 b: &TypedTensor<T>,
1407 left_side: bool,
1408 lower: bool,
1409 transpose_a: bool,
1410 unit_diagonal: bool,
1411) -> TypedTensor<T> {
1412 if has_zero_dim(&a.shape) || has_zero_dim(&b.shape) {
1413 return tensor_from_vec_with_template(b.shape.clone(), Vec::new(), b);
1414 }
1415 batched_binary(buffers, a, b, 2, 2, |buffers, a, b| {
1416 T::triangular_solve_2d(
1417 ctx,
1418 buffers,
1419 a,
1420 b,
1421 left_side,
1422 lower,
1423 transpose_a,
1424 unit_diagonal,
1425 )
1426 })
1427}
1428
1429pub(crate) fn svd<T: FaerLinalg>(
1430 ctx: &CpuContext,
1431 buffers: &mut BufferPool,
1432 input: &TypedTensor<T>,
1433) -> Vec<TypedTensor<T>> {
1434 if has_zero_dim(&input.shape) {
1435 let (matrix_shape, batch_shape) = split_core_and_batch(input, 2, "svd");
1436 let m = matrix_shape[0];
1437 let n = matrix_shape[1];
1438 let k = m.min(n);
1439 return vec![
1440 tensor_from_vec_with_template(
1441 matrix_with_batch_shape(m, k, batch_shape),
1442 Vec::new(),
1443 input,
1444 ),
1445 tensor_from_vec_with_template(
1446 vector_with_batch_shape(k, batch_shape),
1447 Vec::new(),
1448 input,
1449 ),
1450 tensor_from_vec_with_template(
1451 matrix_with_batch_shape(k, n, batch_shape),
1452 Vec::new(),
1453 input,
1454 ),
1455 ];
1456 }
1457 batched_multi(buffers, input, 2, |buffers, batch| {
1458 T::svd_2d(ctx, buffers, batch)
1459 })
1460}
1461
1462pub(crate) fn qr<T: FaerLinalg>(
1463 ctx: &CpuContext,
1464 buffers: &mut BufferPool,
1465 input: &TypedTensor<T>,
1466) -> Vec<TypedTensor<T>> {
1467 if has_zero_dim(&input.shape) {
1468 let (matrix_shape, batch_shape) = split_core_and_batch(input, 2, "qr");
1469 let m = matrix_shape[0];
1470 let n = matrix_shape[1];
1471 let k = m.min(n);
1472 return vec![
1473 tensor_from_vec_with_template(
1474 matrix_with_batch_shape(m, k, batch_shape),
1475 Vec::new(),
1476 input,
1477 ),
1478 tensor_from_vec_with_template(
1479 matrix_with_batch_shape(k, n, batch_shape),
1480 Vec::new(),
1481 input,
1482 ),
1483 ];
1484 }
1485 batched_multi(buffers, input, 2, |buffers, batch| {
1486 T::qr_2d(ctx, buffers, batch)
1487 })
1488}
1489
1490pub(crate) fn eigh<T: FaerLinalg>(
1491 ctx: &CpuContext,
1492 buffers: &mut BufferPool,
1493 input: &TypedTensor<T>,
1494) -> Vec<TypedTensor<T>> {
1495 if has_zero_dim(&input.shape) {
1496 let n = input.shape[0];
1497 let batch_shape = &input.shape[2..];
1498 return vec![
1499 tensor_from_vec_with_template(
1500 vector_with_batch_shape(n, batch_shape),
1501 Vec::new(),
1502 input,
1503 ),
1504 tensor_from_vec_with_template(
1505 matrix_with_batch_shape(n, n, batch_shape),
1506 Vec::new(),
1507 input,
1508 ),
1509 ];
1510 }
1511 batched_multi(buffers, input, 2, |buffers, batch| {
1512 T::eigh_2d(ctx, buffers, batch)
1513 })
1514}
1515
/// General (non-symmetric) eigendecomposition of a single real square matrix.
///
/// Returns `[values, vectors]` shaped `[n]` and `[n, n]` as complex tensors:
/// real input can have complex-conjugate eigenvalue pairs, so outputs are
/// always Complex64.
fn eig_real_2d(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<f64>,
) -> Vec<TypedTensor<Complex64>> {
    let n = square_matrix_dim(input, "eig");
    let mat = MatRef::from_column_major_slice(input.host_data(), n, n);
    // faer's real EVD reports eigenvalues as separate real/imaginary parts
    // and packs eigenvectors of conjugate pairs into real storage.
    let mut u_real = Mat::zeros(n, n);
    let mut s_re = Diag::zeros(n);
    let mut s_im = Diag::zeros(n);
    // Scratch: left eigenvectors are not computed (No), right ones are (Yes) —
    // matching the None/Some arguments passed to evd_real below.
    let mut mem = MemBuffer::new(faer::linalg::evd::evd_scratch::<f64>(
        n,
        faer::linalg::evd::ComputeEigenvectors::No,
        faer::linalg::evd::ComputeEigenvectors::Yes,
        ctx.faer_par(),
        Default::default(),
    ));
    let stack = MemStack::new(&mut mem);
    faer::linalg::evd::evd_real(
        mat,
        s_re.as_mut(),
        s_im.as_mut(),
        None,
        Some(u_real.as_mut()),
        ctx.faer_par(),
        stack,
        Default::default(),
    )
    .unwrap_or_else(|_| panic!("eig: decomposition failed"));
    // Reassemble the packed real-format output into complex eigenvalues and
    // eigenvectors.
    let (u, s) =
        real_eig_to_complex_outputs(buffers, u_real.as_ref(), s_re.as_ref(), s_im.as_ref());

    vec![
        tensor_from_vec_with_template(vec![n], s, input),
        tensor_from_vec_with_template(vec![n, n], u, input),
    ]
}
1553
/// General (non-Hermitian) eigendecomposition of a single complex square
/// matrix.
///
/// Returns `[values, vectors]` shaped `[n]` and `[n, n]`, both Complex64.
fn eig_complex_2d(
    ctx: &CpuContext,
    buffers: &mut BufferPool,
    input: &TypedTensor<Complex64>,
) -> Vec<TypedTensor<Complex64>> {
    let n = square_matrix_dim(input, "eig");
    let mat = MatRef::from_column_major_slice(complex64_to_faer_slice(input.host_data()), n, n);
    let mut u = Mat::zeros(n, n);
    let mut s = Diag::zeros(n);
    // Scratch: left eigenvectors not computed (No), right ones computed (Yes) —
    // matching the None/Some arguments passed to evd_cplx below.
    let mut mem = MemBuffer::new(faer::linalg::evd::evd_scratch::<faer::c64>(
        n,
        faer::linalg::evd::ComputeEigenvectors::No,
        faer::linalg::evd::ComputeEigenvectors::Yes,
        ctx.faer_par(),
        Default::default(),
    ));
    let stack = MemStack::new(&mut mem);
    faer::linalg::evd::evd_cplx(
        mat,
        s.as_mut(),
        None,
        Some(u.as_mut()),
        ctx.faer_par(),
        stack,
        Default::default(),
    )
    .unwrap_or_else(|_| panic!("eig: decomposition failed"));

    vec![
        tensor_from_vec_with_template(vec![n], complex_vec_from_diag(buffers, s.as_ref()), input),
        tensor_from_vec_with_template(vec![n, n], complex_vec_from_mat(buffers, u.as_ref()), input),
    ]
}
1587
1588pub(crate) fn eig(ctx: &CpuContext, buffers: &mut BufferPool, input: &Tensor) -> Vec<Tensor> {
1589 if has_zero_dim(input.shape()) {
1590 let n = input.shape()[0];
1591 let batch_shape = &input.shape()[2..];
1592 return vec![
1593 Tensor::C64(TypedTensor::from_vec(
1594 vector_with_batch_shape(n, batch_shape),
1595 Vec::new(),
1596 )),
1597 Tensor::C64(TypedTensor::from_vec(
1598 matrix_with_batch_shape(n, n, batch_shape),
1599 Vec::new(),
1600 )),
1601 ];
1602 }
1603
1604 match input {
1605 Tensor::F64(t) => batched_multi_convert(buffers, t, 2, |buffers, batch| {
1606 eig_real_2d(ctx, buffers, batch)
1607 })
1608 .into_iter()
1609 .map(Tensor::C64)
1610 .collect(),
1611 Tensor::C64(t) => batched_multi_convert(buffers, t, 2, |buffers, batch| {
1612 eig_complex_2d(ctx, buffers, batch)
1613 })
1614 .into_iter()
1615 .map(Tensor::C64)
1616 .collect(),
1617 _ => todo!("eig: unsupported dtype"),
1618 }
1619}
1620
/// Returns `true` when any extent in `shape` is zero, i.e. the tensor holds
/// no elements at all. An empty shape (scalar) has no zero dims.
fn has_zero_dim(shape: &[usize]) -> bool {
    shape.iter().any(|&dim| dim == 0)
}
1624
/// Builds the shape `[rows, cols, ...batch_shape]` for a batched matrix
/// output — matrix dims lead and batch dims trail, matching this module's
/// tensor layout.
fn matrix_with_batch_shape(rows: usize, cols: usize, batch_shape: &[usize]) -> Vec<usize> {
    [rows, cols]
        .into_iter()
        .chain(batch_shape.iter().copied())
        .collect()
}
1630
/// Builds the shape `[len, ...batch_shape]` for a batched vector output —
/// the vector dim leads and batch dims trail, matching this module's layout.
fn vector_with_batch_shape(len: usize, batch_shape: &[usize]) -> Vec<usize> {
    std::iter::once(len)
        .chain(batch_shape.iter().copied())
        .collect()
}