1use std::sync::Arc;
2
3use num_complex::{Complex32, Complex64};
4use tenferro_runtime::extension::apply;
5use tenferro_runtime::{CompareDir, DType, DotGeneralConfig, Error, Result, TracedTensor};
6
7use crate::extension::{LinalgExtensionOp, LinalgOp};
8
9pub trait TracedTensorLinalgExt {
11 fn svd(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor)>;
12 fn svd_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor, TracedTensor)>;
13 fn qr(&self) -> Result<(TracedTensor, TracedTensor)>;
14 fn eigh(&self) -> Result<(TracedTensor, TracedTensor)>;
15 fn eigh_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor)>;
16 fn cholesky(&self) -> Result<TracedTensor>;
17 fn lu(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)>;
18 fn full_piv_lu(
19 &self,
20 ) -> Result<(
21 TracedTensor,
22 TracedTensor,
23 TracedTensor,
24 TracedTensor,
25 TracedTensor,
26 )>;
27 fn eig(&self) -> Result<(TracedTensor, TracedTensor)>;
28 fn solve(&self, b: &TracedTensor) -> Result<TracedTensor>;
29 fn full_piv_lu_solve(&self, b: &TracedTensor) -> Result<TracedTensor>;
30 fn triangular_solve(
31 &self,
32 b: &TracedTensor,
33 left_side: bool,
34 lower: bool,
35 transpose_a: bool,
36 unit_diagonal: bool,
37 ) -> Result<TracedTensor>;
38 fn slogdet(&self) -> Result<(TracedTensor, TracedTensor)>;
39 fn det(&self) -> Result<TracedTensor>;
40 fn inv(&self) -> Result<TracedTensor>;
41 fn eigvalsh(&self) -> Result<TracedTensor>;
42 fn eigvals(&self) -> Result<TracedTensor>;
43 fn pinv(&self) -> Result<TracedTensor>;
44 fn pinv_with_rtol(&self, rtol: f64) -> Result<TracedTensor>;
45 fn norm(&self, ord: Option<f64>, dim: Option<&[usize]>, keepdim: bool) -> Result<TracedTensor>;
46}
47
48impl TracedTensorLinalgExt for TracedTensor {
49 fn svd(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
50 svd(self)
51 }
52
53 fn svd_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
54 svd_with_eps(self, eps)
55 }
56
57 fn qr(&self) -> Result<(TracedTensor, TracedTensor)> {
58 qr(self)
59 }
60
61 fn eigh(&self) -> Result<(TracedTensor, TracedTensor)> {
62 eigh(self)
63 }
64
65 fn eigh_with_eps(&self, eps: f64) -> Result<(TracedTensor, TracedTensor)> {
66 eigh_with_eps(self, eps)
67 }
68
69 fn cholesky(&self) -> Result<TracedTensor> {
70 cholesky(self)
71 }
72
73 fn lu(&self) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)> {
74 lu(self)
75 }
76
77 fn full_piv_lu(
78 &self,
79 ) -> Result<(
80 TracedTensor,
81 TracedTensor,
82 TracedTensor,
83 TracedTensor,
84 TracedTensor,
85 )> {
86 full_piv_lu(self)
87 }
88
89 fn eig(&self) -> Result<(TracedTensor, TracedTensor)> {
90 eig(self)
91 }
92
93 fn solve(&self, b: &TracedTensor) -> Result<TracedTensor> {
94 solve(self, b)
95 }
96
97 fn full_piv_lu_solve(&self, b: &TracedTensor) -> Result<TracedTensor> {
98 full_piv_lu_solve(self, b)
99 }
100
101 fn triangular_solve(
102 &self,
103 b: &TracedTensor,
104 left_side: bool,
105 lower: bool,
106 transpose_a: bool,
107 unit_diagonal: bool,
108 ) -> Result<TracedTensor> {
109 triangular_solve(self, b, left_side, lower, transpose_a, unit_diagonal)
110 }
111
112 fn slogdet(&self) -> Result<(TracedTensor, TracedTensor)> {
113 slogdet(self)
114 }
115
116 fn det(&self) -> Result<TracedTensor> {
117 det(self)
118 }
119
120 fn inv(&self) -> Result<TracedTensor> {
121 inv(self)
122 }
123
124 fn eigvalsh(&self) -> Result<TracedTensor> {
125 eigvalsh(self)
126 }
127
128 fn eigvals(&self) -> Result<TracedTensor> {
129 eigvals(self)
130 }
131
132 fn pinv(&self) -> Result<TracedTensor> {
133 pinv(self)
134 }
135
136 fn pinv_with_rtol(&self, rtol: f64) -> Result<TracedTensor> {
137 pinv_with_rtol(self, rtol)
138 }
139
140 fn norm(&self, ord: Option<f64>, dim: Option<&[usize]>, keepdim: bool) -> Result<TracedTensor> {
141 norm(self, ord, dim, keepdim)
142 }
143}
144
145pub fn svd(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
160 svd_with_eps(a, 1e-12)
161}
162
163pub fn svd_with_eps(
176 a: &TracedTensor,
177 eps: f64,
178) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
179 three_outputs(
180 apply(
181 Arc::new(LinalgExtensionOp::new(LinalgOp::Svd { eps })),
182 &[a],
183 )?,
184 "svd",
185 )
186}
187
188pub fn qr(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
202 two_outputs(
203 apply(Arc::new(LinalgExtensionOp::new(LinalgOp::Qr)), &[a])?,
204 "qr",
205 )
206}
207
208pub fn eigh(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
222 eigh_with_eps(a, 1e-12)
223}
224
225pub fn eigh_with_eps(a: &TracedTensor, eps: f64) -> Result<(TracedTensor, TracedTensor)> {
238 two_outputs(
239 apply(
240 Arc::new(LinalgExtensionOp::new(LinalgOp::Eigh { eps })),
241 &[a],
242 )?,
243 "eigh",
244 )
245}
246
247pub fn cholesky(a: &TracedTensor) -> Result<TracedTensor> {
260 one_output(
261 apply(Arc::new(LinalgExtensionOp::new(LinalgOp::Cholesky)), &[a])?,
262 "cholesky",
263 )
264}
265
266pub fn lu(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)> {
282 four_outputs(
283 apply(Arc::new(LinalgExtensionOp::new(LinalgOp::Lu)), &[a])?,
284 "lu",
285 )
286}
287
288pub fn full_piv_lu(
310 a: &TracedTensor,
311) -> Result<(
312 TracedTensor,
313 TracedTensor,
314 TracedTensor,
315 TracedTensor,
316 TracedTensor,
317)> {
318 five_outputs(
319 apply(Arc::new(LinalgExtensionOp::new(LinalgOp::FullPivLu)), &[a])?,
320 "full_piv_lu",
321 )
322}
323
324pub fn eig(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
338 two_outputs(
339 apply(
340 Arc::new(LinalgExtensionOp::new(LinalgOp::Eig {
341 input_dtype: a.dtype,
342 })),
343 &[a],
344 )?,
345 "eig",
346 )
347}
348
349pub fn solve(a: &TracedTensor, b: &TracedTensor) -> Result<TracedTensor> {
363 let mut factor_outputs =
364 apply(Arc::new(LinalgExtensionOp::new(LinalgOp::LuFactor)), &[a])?.into_iter();
365 let (packed_lu, pivots) = match (
366 factor_outputs.next(),
367 factor_outputs.next(),
368 factor_outputs.next(),
369 factor_outputs.next(),
370 ) {
371 (Some(packed_lu), Some(pivots), Some(_parity), None) => (packed_lu, pivots),
372 _ => return Err(unexpected_output_count("lu_factor", 3)),
373 };
374 one_output(
375 apply(
376 Arc::new(LinalgExtensionOp::new(LinalgOp::LuSolvePrepared {
377 transpose_a: false,
378 conjugate_a: false,
379 })),
380 &[a, &packed_lu, &pivots, b],
381 )?,
382 "solve",
383 )
384}
385
386pub fn full_piv_lu_solve(a: &TracedTensor, b: &TracedTensor) -> Result<TracedTensor> {
400 one_output(
401 apply(
402 Arc::new(LinalgExtensionOp::new(LinalgOp::FullPivLuSolve {
403 transpose_a: false,
404 })),
405 &[a, b],
406 )?,
407 "full_piv_lu_solve",
408 )
409}
410
411pub fn triangular_solve(
425 a: &TracedTensor,
426 b: &TracedTensor,
427 left_side: bool,
428 lower: bool,
429 transpose_a: bool,
430 unit_diagonal: bool,
431) -> Result<TracedTensor> {
432 one_output(
433 apply(
434 Arc::new(LinalgExtensionOp::new(LinalgOp::TriangularSolve {
435 left_side,
436 lower,
437 transpose_a,
438 unit_diagonal,
439 })),
440 &[a, b],
441 )?,
442 "triangular_solve",
443 )
444}
445
446pub fn slogdet(a: &TracedTensor) -> Result<(TracedTensor, TracedTensor)> {
460 let mut factor_outputs =
461 apply(Arc::new(LinalgExtensionOp::new(LinalgOp::LuFactor)), &[a])?.into_iter();
462 let (packed_lu, parity) = match (
463 factor_outputs.next(),
464 factor_outputs.next(),
465 factor_outputs.next(),
466 factor_outputs.next(),
467 ) {
468 (Some(packed_lu), Some(_pivots), Some(parity), None) => (packed_lu, parity),
469 _ => return Err(unexpected_output_count("lu_factor", 3)),
470 };
471 let diag_u = packed_lu.extract_diag(0, 1)?;
472 let sign_u = diag_u.sign().reduce_prod(&[0])?;
473 let sign = (&parity * &sign_u)?;
474 let logabsdet = diag_u.abs().log().reduce_sum(&[0])?;
475 Ok((sign, logabsdet))
476}
477
478pub fn det(a: &TracedTensor) -> Result<TracedTensor> {
491 let (sign, logabsdet) = slogdet(a)?;
492 &sign * &logabsdet.exp()
493}
494
495pub fn inv(a: &TracedTensor) -> Result<TracedTensor> {
508 ensure_min_rank("inv", a.rank, 2)?;
509 let shape = require_concrete_shape("inv", a)?;
510 let eye = eye_like(a, shape[0])?;
511 solve(a, &eye)
512}
513
514pub fn eigvalsh(a: &TracedTensor) -> Result<TracedTensor> {
527 eigh_values(a)
528}
529
530pub fn eigvals(a: &TracedTensor) -> Result<TracedTensor> {
543 eig_values(a)
544}
545
546pub fn pinv(a: &TracedTensor) -> Result<TracedTensor> {
562 ensure_float_or_complex("pinv", a.dtype)?;
563 let shape = require_concrete_shape("pinv", a)?;
564 let max_dim = match (shape.first(), shape.get(1)) {
565 (Some(&m), Some(&n)) => m.max(n),
566 (Some(&m), None) => m,
567 _ => 0,
568 };
569 pinv_with_rtol(a, default_pinv_rtol(a.dtype, max_dim))
570}
571
572pub fn pinv_with_rtol(a: &TracedTensor, rtol: f64) -> Result<TracedTensor> {
588 ensure_float_or_complex("pinv_with_rtol", a.dtype)?;
589 require_concrete_shape("pinv_with_rtol", a)?;
590 let (u, s, vt) = svd(a)?;
591 let abs_s = s.abs();
592 let s_max = abs_s.reduce_max(&[0])?;
593 let s_max_shape = s_max.concrete_shape()?;
594 let threshold_scalar = broadcast_scalar(scalar_real(s.dtype, rtol.max(0.0))?, &s_max_shape)?;
595 let threshold = (&s_max * &threshold_scalar)?;
596 let s_shape = s.concrete_shape()?;
597 let threshold = broadcast_batch_scalar_to_leading_axis(&threshold, &s_shape)?;
598 let mask = abs_s.compare(&threshold, CompareDir::Gt)?;
599 let mask = mask.convert(s.dtype)?;
600 let ones = ones_like(&s)?;
601 let denom = (&s + &(&ones + &(-&mask))?)?;
602 let s_inv = (&mask / &denom)?;
603
604 let v = vt.conj().transpose(&matrix_transpose_perm(vt.rank))?;
605 let uh = u.conj().transpose(&matrix_transpose_perm(u.rank))?;
606 let vs = scale_matrix_columns(&v, &s_inv)?;
607 matmul_preserve_trailing_batch(&vs, &uh)
608}
609
610pub fn norm(
626 a: &TracedTensor,
627 ord: Option<f64>,
628 dim: Option<&[usize]>,
629 keepdim: bool,
630) -> Result<TracedTensor> {
631 ensure_float_or_complex("norm", a.dtype)?;
632 let shape = require_concrete_shape("norm", a)?;
633 let axes = dim.map_or_else(|| (0..a.rank).collect::<Vec<_>>(), |dims| dims.to_vec());
634 if axes.is_empty() {
635 return Ok(a.clone());
636 }
637 validate_axes("norm", a.rank, &axes)?;
638
639 let out = match axes.len() {
640 1 => vector_norm(a, axes[0], ord)?,
641 2 => matrix_norm(a, &axes, ord)?,
642 _ => {
643 let abs = a.abs();
644 match ord {
645 None => frobenius_norm(&abs, &axes)?,
646 Some(p) if p == f64::INFINITY => abs.reduce_max(&axes)?,
647 Some(p) if p == f64::NEG_INFINITY => abs.reduce_min(&axes)?,
648 Some(0.0) => count_nonzero(&abs, &axes)?,
649 Some(p) => p_norm(&abs, &axes, p)?,
650 }
651 }
652 };
653 Ok(restore_keepdim(out, &shape, &axes, keepdim))
654}
655
656fn unexpected_output_count(name: &str, expected: usize) -> Error {
657 Error::Internal(format!("{name} must produce exactly {expected} outputs"))
658}
659
660fn one_output(outputs: Vec<TracedTensor>, name: &str) -> Result<TracedTensor> {
661 let mut outputs = outputs.into_iter();
662 match (outputs.next(), outputs.next()) {
663 (Some(output), None) => Ok(output),
664 _ => Err(unexpected_output_count(name, 1)),
665 }
666}
667
668fn two_outputs(outputs: Vec<TracedTensor>, name: &str) -> Result<(TracedTensor, TracedTensor)> {
669 let mut outputs = outputs.into_iter();
670 match (outputs.next(), outputs.next(), outputs.next()) {
671 (Some(lhs), Some(rhs), None) => Ok((lhs, rhs)),
672 _ => Err(unexpected_output_count(name, 2)),
673 }
674}
675
676fn three_outputs(
677 outputs: Vec<TracedTensor>,
678 name: &str,
679) -> Result<(TracedTensor, TracedTensor, TracedTensor)> {
680 let mut outputs = outputs.into_iter();
681 match (
682 outputs.next(),
683 outputs.next(),
684 outputs.next(),
685 outputs.next(),
686 ) {
687 (Some(first), Some(second), Some(third), None) => Ok((first, second, third)),
688 _ => Err(unexpected_output_count(name, 3)),
689 }
690}
691
692fn four_outputs(
693 outputs: Vec<TracedTensor>,
694 name: &str,
695) -> Result<(TracedTensor, TracedTensor, TracedTensor, TracedTensor)> {
696 let mut outputs = outputs.into_iter();
697 match (
698 outputs.next(),
699 outputs.next(),
700 outputs.next(),
701 outputs.next(),
702 outputs.next(),
703 ) {
704 (Some(first), Some(second), Some(third), Some(fourth), None) => {
705 Ok((first, second, third, fourth))
706 }
707 _ => Err(unexpected_output_count(name, 4)),
708 }
709}
710
711fn five_outputs(
712 outputs: Vec<TracedTensor>,
713 name: &str,
714) -> Result<(
715 TracedTensor,
716 TracedTensor,
717 TracedTensor,
718 TracedTensor,
719 TracedTensor,
720)> {
721 let mut outputs = outputs.into_iter();
722 match (
723 outputs.next(),
724 outputs.next(),
725 outputs.next(),
726 outputs.next(),
727 outputs.next(),
728 outputs.next(),
729 ) {
730 (Some(first), Some(second), Some(third), Some(fourth), Some(fifth), None) => {
731 Ok((first, second, third, fourth, fifth))
732 }
733 _ => Err(unexpected_output_count(name, 5)),
734 }
735}
736
737fn scalar_real(dtype: DType, value: f64) -> Result<TracedTensor> {
738 match dtype {
739 DType::F64 => TracedTensor::from_vec_col_major(vec![], vec![value]),
740 DType::F32 => TracedTensor::from_vec_col_major(vec![], vec![value as f32]),
741 DType::I32 => TracedTensor::from_vec_col_major(vec![], vec![value.round() as i32]),
742 DType::I64 => TracedTensor::from_vec_col_major(vec![], vec![value.round() as i64]),
743 DType::Bool => TracedTensor::from_vec_col_major(vec![], vec![value != 0.0]),
744 DType::C64 => TracedTensor::from_vec_col_major(vec![], vec![Complex64::new(value, 0.0)]),
745 DType::C32 => {
746 TracedTensor::from_vec_col_major(vec![], vec![Complex32::new(value as f32, 0.0)])
747 }
748 }
749}
750
751fn ensure_float_or_complex(op: &'static str, dtype: DType) -> Result<()> {
752 match dtype {
753 DType::F32 | DType::F64 | DType::C32 | DType::C64 => Ok(()),
754 DType::I32 | DType::I64 | DType::Bool => Err(Error::TensorRuntime(
755 tenferro_tensor::Error::backend_failure(op, format!("unsupported dtype {dtype:?}")),
756 )),
757 }
758}
759
760fn ensure_min_rank(op: &'static str, actual: usize, expected: usize) -> Result<()> {
761 if actual < expected {
762 return Err(Error::TensorRuntime(tenferro_tensor::Error::RankMismatch {
763 op,
764 expected,
765 actual,
766 }));
767 }
768 Ok(())
769}
770
771fn validate_axes(op: &'static str, rank: usize, axes: &[usize]) -> Result<()> {
772 for &axis in axes {
773 if axis >= rank {
774 return Err(Error::TensorRuntime(
775 tenferro_tensor::Error::AxisOutOfBounds { op, axis, rank },
776 ));
777 }
778 }
779 Ok(())
780}
781
782fn require_concrete_shape(op: &'static str, input: &TracedTensor) -> Result<Vec<usize>> {
783 input.try_concrete_shape().ok_or_else(|| {
784 Error::TensorRuntime(tenferro_tensor::Error::backend_failure(
785 op,
786 "symbolic shape is not supported by this traced linalg helper",
787 ))
788 })
789}
790
791fn zero_scalar(dtype: DType) -> Result<TracedTensor> {
792 scalar_real(dtype, 0.0)
793}
794
795fn one_scalar(dtype: DType) -> Result<TracedTensor> {
796 scalar_real(dtype, 1.0)
797}
798
799fn ones_like(input: &TracedTensor) -> Result<TracedTensor> {
800 let shape = input.concrete_shape()?;
801 broadcast_scalar(one_scalar(input.dtype)?, &shape)
802}
803
804fn eye_like(anchor: &TracedTensor, size: usize) -> Result<TracedTensor> {
805 let mut vector_shape = vec![size];
806 let anchor_shape = anchor.concrete_shape()?;
807 vector_shape.extend_from_slice(&anchor_shape[2..]);
808 let diagonal = broadcast_scalar(one_scalar(anchor.dtype)?, &vector_shape)?;
809 diagonal.embed_diag(0, 1)
810}
811
812fn broadcast_scalar(input: TracedTensor, shape: &[usize]) -> Result<TracedTensor> {
813 let input_shape = input.concrete_shape()?;
814 if input_shape == shape {
815 return Ok(input);
816 }
817 input.broadcast_in_dim(shape, &[])
818}
819
820fn broadcast_batch_scalar_to_leading_axis(
821 input: &TracedTensor,
822 shape: &[usize],
823) -> Result<TracedTensor> {
824 let input_shape = input.concrete_shape()?;
825 if input_shape == shape {
826 return Ok(input.clone());
827 }
828 let dims: Vec<usize> = (1..shape.len()).collect();
829 input.broadcast_in_dim(shape, &dims)
830}
831
832fn matmul_preserve_trailing_batch(lhs: &TracedTensor, rhs: &TracedTensor) -> Result<TracedTensor> {
833 let rank = lhs.rank;
834 let batch_dims: Vec<usize> = (2..rank).collect();
835 lhs.dot_general(
836 rhs,
837 DotGeneralConfig {
838 lhs_contracting_dims: vec![1],
839 rhs_contracting_dims: vec![0],
840 lhs_batch_dims: batch_dims.clone(),
841 rhs_batch_dims: batch_dims,
842 },
843 )
844}
845
/// Permutation `[1, 0, 2, 3, ..., rank - 1]`: swaps the two leading (matrix) axes.
///
/// Panics when `rank < 2`.
fn matrix_transpose_perm(rank: usize) -> Vec<usize> {
    let mut order = Vec::with_capacity(rank);
    order.extend(0..rank);
    order.swap(0, 1);
    order
}
851
852fn frobenius_norm(abs: &TracedTensor, axes: &[usize]) -> Result<TracedTensor> {
853 let squared = abs.pow(&scalar_real(abs.dtype, 2.0)?)?;
854 Ok(squared.reduce_sum(axes)?.sqrt())
855}
856
857fn p_norm(abs: &TracedTensor, axes: &[usize], p: f64) -> Result<TracedTensor> {
858 let power = abs.pow(&scalar_real(abs.dtype, p)?)?;
859 let inv_p = scalar_real(abs.dtype, 1.0 / p)?;
860 power.reduce_sum(axes)?.pow(&inv_p)
861}
862
863fn default_pinv_rtol(dtype: DType, max_dim: usize) -> f64 {
864 let eps = match dtype {
865 DType::F32 | DType::C32 => f32::EPSILON as f64,
866 DType::F64 | DType::C64 => f64::EPSILON,
867 DType::I32 | DType::I64 | DType::Bool => 0.0,
868 };
869 eps * max_dim as f64
870}
871
872fn vector_norm(a: &TracedTensor, axis: usize, ord: Option<f64>) -> Result<TracedTensor> {
873 let abs = a.abs();
874 match ord {
875 None => frobenius_norm(&abs, &[axis]),
876 Some(0.0) => count_nonzero(&abs, &[axis]),
877 Some(p) if p == f64::INFINITY => abs.reduce_max(&[axis]),
878 Some(p) if p == f64::NEG_INFINITY => abs.reduce_min(&[axis]),
879 Some(p) => p_norm(&abs, &[axis], p),
880 }
881}
882
883fn matrix_norm(a: &TracedTensor, axes: &[usize], ord: Option<f64>) -> Result<TracedTensor> {
884 let matrix = move_axes_to_front(a, axes)?;
885 let abs = matrix.abs();
886 match ord {
887 None => frobenius_norm(&abs, &[0, 1]),
888 Some(p) if p == f64::INFINITY => matrix_row_sum_norm(&abs, true),
889 Some(p) if p == f64::NEG_INFINITY => matrix_row_sum_norm(&abs, false),
890 Some(1.0) => matrix_col_sum_norm(&abs, true),
891 Some(-1.0) => matrix_col_sum_norm(&abs, false),
892 Some(2.0) => {
893 let singular_values = svd_values(&matrix)?.abs();
894 singular_values.reduce_max(&[0])
895 }
896 Some(-2.0) => {
897 let singular_values = svd_values(&matrix)?.abs();
898 singular_values.reduce_min(&[0])
899 }
900 Some(0.0) => count_nonzero(&abs, &[0, 1]),
901 Some(p) => p_norm(&abs, &[0, 1], p),
902 }
903}
904
905fn svd_values(a: &TracedTensor) -> Result<TracedTensor> {
906 one_output(
907 apply(
908 Arc::new(LinalgExtensionOp::new(LinalgOp::SvdVals { eps: 1e-12 })),
909 &[a],
910 )?,
911 "svd_values",
912 )
913}
914
915fn eigh_values(a: &TracedTensor) -> Result<TracedTensor> {
916 one_output(
917 apply(
918 Arc::new(LinalgExtensionOp::new(LinalgOp::EighVals { eps: 1e-12 })),
919 &[a],
920 )?,
921 "eigh_values",
922 )
923}
924
925fn eig_values(a: &TracedTensor) -> Result<TracedTensor> {
926 one_output(
927 apply(
928 Arc::new(LinalgExtensionOp::new(LinalgOp::EigVals {
929 input_dtype: a.dtype,
930 })),
931 &[a],
932 )?,
933 "eig_values",
934 )
935}
936
937fn scale_matrix_columns(matrix: &TracedTensor, scale: &TracedTensor) -> Result<TracedTensor> {
938 let matrix_shape = matrix.concrete_shape()?;
939 let scale_shape_input = scale.concrete_shape()?;
940 let mut scale_shape = vec![1, scale_shape_input[0]];
941 scale_shape.extend_from_slice(&matrix_shape[2..]);
942 let dims: Vec<usize> = (0..matrix_shape.len()).collect();
943 let scale = scale
944 .reshape(&scale_shape)
945 .broadcast_in_dim(&matrix_shape, &dims)?;
946 matrix * &scale
947}
948
949fn count_nonzero(abs: &TracedTensor, axes: &[usize]) -> Result<TracedTensor> {
950 let mask = abs.compare(&zero_scalar(abs.dtype)?, CompareDir::Gt)?;
951 mask.convert(abs.dtype)?.reduce_sum(axes)
952}
953
954fn matrix_row_sum_norm(abs: &TracedTensor, take_max: bool) -> Result<TracedTensor> {
955 let row_sums = abs.reduce_sum(&[1])?;
956 if take_max {
957 row_sums.reduce_max(&[0])
958 } else {
959 row_sums.reduce_min(&[0])
960 }
961}
962
963fn matrix_col_sum_norm(abs: &TracedTensor, take_max: bool) -> Result<TracedTensor> {
964 let col_sums = abs.reduce_sum(&[0])?;
965 if take_max {
966 col_sums.reduce_max(&[0])
967 } else {
968 col_sums.reduce_min(&[0])
969 }
970}
971
972fn move_axes_to_front(tensor: &TracedTensor, axes: &[usize]) -> Result<TracedTensor> {
973 if axes.iter().enumerate().all(|(index, &axis)| index == axis) {
974 return Ok(tensor.clone());
975 }
976
977 let mut selected = vec![false; tensor.rank];
978 for &axis in axes {
979 selected[axis] = true;
980 }
981
982 let mut perm = Vec::with_capacity(tensor.rank);
983 perm.extend_from_slice(axes);
984 for (axis, is_selected) in selected.iter().enumerate().take(tensor.rank) {
985 if !*is_selected {
986 perm.push(axis);
987 }
988 }
989 tensor.transpose(&perm)
990}
991
992fn restore_keepdim(
993 reduced: TracedTensor,
994 original_shape: &[usize],
995 axes: &[usize],
996 keepdim: bool,
997) -> TracedTensor {
998 if !keepdim {
999 return reduced;
1000 }
1001 let mut kept_shape = original_shape.to_vec();
1002 for &axis in axes {
1003 kept_shape[axis] = 1;
1004 }
1005 reduced.reshape(&kept_shape)
1006}