tensor4all_tcicore/matrixluci/
dense.rs1use crate::matrixluci::kernel::PivotKernel;
4use crate::matrixluci::scalar::Scalar;
5use crate::matrixluci::source::{materialize_source, CandidateMatrixSource};
6use crate::matrixluci::{PivotKernelOptions, PivotSelectionCore, Result};
7use faer::MatRef;
8use num_complex::{Complex32, Complex64};
9
/// Dense pivot-selection kernel built on faer's full-pivoting LU factorization.
///
/// This is a stateless unit type: all configuration arrives through the
/// `PivotKernelOptions` passed to each `factorize` call. `PivotKernel` is
/// implemented for it once per supported scalar type further down this file.
#[derive(Debug, Clone, Copy, Default)]
pub struct DenseFaerLuKernel;
17
18impl DenseFaerLuKernel {
19 fn is_no_truncation(options: &PivotKernelOptions, full_rank: usize) -> bool {
20 options.max_rank >= full_rank && options.rel_tol == 0.0 && options.abs_tol == 0.0
21 }
22
23 fn compute_no_truncation_pivot_errors<T: Scalar>(
24 u: &faer::MatRef<'_, T>,
25 full_rank: usize,
26 ) -> Vec<f64> {
27 let mut pivot_errors = Vec::with_capacity(full_rank + 1);
28 for i in 0..full_rank {
29 let pivot_abs = u[(i, i)].abs_val();
30 if pivot_abs < f64::EPSILON {
31 if pivot_errors.is_empty() {
32 pivot_errors.push(pivot_abs);
33 } else {
34 pivot_errors.push(0.0);
35 }
36 return pivot_errors;
37 }
38 pivot_errors.push(pivot_abs);
39 }
40 pivot_errors.push(0.0);
41 pivot_errors
42 }
43
44 fn compute_pivot_errors(
45 diag_abs: &[f64],
46 nrows: usize,
47 ncols: usize,
48 options: &PivotKernelOptions,
49 ) -> Vec<f64> {
50 let full_rank = nrows.min(ncols);
51 if full_rank == 0 {
52 return vec![0.0];
53 }
54
55 if Self::is_no_truncation(options, full_rank) {
56 let mut pivot_errors = Vec::with_capacity(full_rank + 1);
57 for &pivot_abs in diag_abs.iter().take(full_rank) {
58 if pivot_abs < f64::EPSILON {
59 if pivot_errors.is_empty() {
60 pivot_errors.push(pivot_abs);
61 } else {
62 pivot_errors.push(0.0);
63 }
64 return pivot_errors;
65 }
66 pivot_errors.push(pivot_abs);
67 }
68 pivot_errors.push(0.0);
69 return pivot_errors;
70 }
71
72 let max_rank = options.max_rank.min(full_rank);
73 let mut accepted = Vec::new();
74 let mut max_error = 0.0f64;
75 let mut last_error = f64::NAN;
76 let mut rank = 0usize;
77
78 while rank < max_rank {
79 let pivot_abs = diag_abs.get(rank).copied().unwrap_or(0.0);
80 last_error = pivot_abs;
81
82 if rank > 0 && (pivot_abs < options.rel_tol * max_error || pivot_abs < options.abs_tol)
83 {
84 break;
85 }
86
87 if pivot_abs < f64::EPSILON {
88 if rank == 0 {
89 last_error = pivot_abs;
90 }
91 break;
92 }
93
94 max_error = max_error.max(pivot_abs);
95 accepted.push(pivot_abs);
96 rank += 1;
97 }
98
99 if rank >= full_rank {
100 last_error = 0.0;
101 } else if rank == max_rank && rank > 0 {
102 last_error = accepted[rank - 1];
104 }
105
106 accepted.push(last_error);
107 accepted
108 }
109}
110
/// Generates `impl PivotKernel<$t> for DenseFaerLuKernel` for one scalar type.
///
/// The generated `factorize` borrows the candidate matrix directly when the
/// source exposes a dense column-major slice, and materializes it first
/// otherwise. It runs faer's `full_piv_lu` on that buffer and derives pivot
/// errors from the diagonal of `U`. The leading `rank` entries of the row and
/// column permutations become the selected pivot indices.
macro_rules! impl_dense_kernel {
    ($t:ty) => {
        impl PivotKernel<$t> for DenseFaerLuKernel {
            fn factorize<S: CandidateMatrixSource<$t>>(
                &self,
                source: &S,
                options: &PivotKernelOptions,
            ) -> Result<PivotSelectionCore> {
                // Factorizes a column-major buffer shaped like `source` and turns
                // the LU factors into a pivot selection.
                let select_pivots = |values: &[$t]| -> Result<PivotSelectionCore> {
                    let (m, n) = (source.nrows(), source.ncols());
                    let factors = MatRef::from_column_major_slice(values, m, n).full_piv_lu();
                    let cap = m.min(n);
                    let upper = factors.U();

                    let pivot_errors = if !DenseFaerLuKernel::is_no_truncation(options, cap) {
                        let diagonal: Vec<f64> =
                            (0..cap).map(|k| upper[(k, k)].abs_val()).collect();
                        DenseFaerLuKernel::compute_pivot_errors(&diagonal, m, n, options)
                    } else {
                        DenseFaerLuKernel::compute_no_truncation_pivot_errors(&upper, cap)
                    };
                    // The final entry of `pivot_errors` is the trailing error, not a pivot.
                    let rank = pivot_errors.len().saturating_sub(1);

                    // Forward permutation arrays: position -> original row/column.
                    let row_order = factors.P().arrays().0;
                    let col_order = factors.Q().arrays().0;

                    Ok(PivotSelectionCore {
                        row_indices: row_order[..rank].to_vec(),
                        col_indices: col_order[..rank].to_vec(),
                        pivot_errors,
                        rank,
                    })
                };

                match source.dense_column_major_slice() {
                    Some(values) => select_pivots(values),
                    None => {
                        let owned = materialize_source(source);
                        select_pivots(owned.as_slice())
                    }
                }
            }
        }
    };
}
159
// Instantiate `PivotKernel` for `DenseFaerLuKernel` at each supported scalar
// type: single- and double-precision real (`f32`, `f64`) and complex
// (`Complex32`, `Complex64`).
impl_dense_kernel!(f32);
impl_dense_kernel!(f64);
impl_dense_kernel!(Complex32);
impl_dense_kernel!(Complex64);
164
// Unit tests are declared out-of-line (`mod tests;` loads them from a separate
// file) and compiled only for test builds.
#[cfg(test)]
mod tests;