Skip to main content

tensor4all_tcicore/matrixluci/
dense.rs

1//! Dense pivot-kernel implementations.
2
3use crate::matrixluci::kernel::PivotKernel;
4use crate::matrixluci::scalar::Scalar;
5use crate::matrixluci::source::{materialize_source, CandidateMatrixSource};
6use crate::matrixluci::{PivotKernelOptions, PivotSelectionCore, Result};
7use faer::MatRef;
8use num_complex::{Complex32, Complex64};
9
/// Dense full-pivoting LU kernel backed by faer.
///
/// Runs faer's full-pivoting LU decomposition on a dense, column-major
/// view of the candidate matrix. When the source already exposes a dense
/// column-major slice it is used directly; otherwise the source is
/// materialized first. Suitable for small to moderate matrices where
/// holding the full matrix in memory is acceptable.
///
/// This is a field-less unit struct. `PivotKernel` is implemented for it
/// for `f32`, `f64`, `Complex32` and `Complex64` via `impl_dense_kernel!`.
#[derive(Default)]
pub struct DenseFaerLuKernel;
17
18impl DenseFaerLuKernel {
19    fn is_no_truncation(options: &PivotKernelOptions, full_rank: usize) -> bool {
20        options.max_rank >= full_rank && options.rel_tol == 0.0 && options.abs_tol == 0.0
21    }
22
23    fn compute_no_truncation_pivot_errors<T: Scalar>(
24        u: &faer::MatRef<'_, T>,
25        full_rank: usize,
26    ) -> Vec<f64> {
27        let mut pivot_errors = Vec::with_capacity(full_rank + 1);
28        for i in 0..full_rank {
29            let pivot_abs = u[(i, i)].abs_val();
30            if pivot_abs < f64::EPSILON {
31                if pivot_errors.is_empty() {
32                    pivot_errors.push(pivot_abs);
33                } else {
34                    pivot_errors.push(0.0);
35                }
36                return pivot_errors;
37            }
38            pivot_errors.push(pivot_abs);
39        }
40        pivot_errors.push(0.0);
41        pivot_errors
42    }
43
44    fn compute_pivot_errors(
45        diag_abs: &[f64],
46        nrows: usize,
47        ncols: usize,
48        options: &PivotKernelOptions,
49    ) -> Vec<f64> {
50        let full_rank = nrows.min(ncols);
51        if full_rank == 0 {
52            return vec![0.0];
53        }
54
55        if Self::is_no_truncation(options, full_rank) {
56            let mut pivot_errors = Vec::with_capacity(full_rank + 1);
57            for &pivot_abs in diag_abs.iter().take(full_rank) {
58                if pivot_abs < f64::EPSILON {
59                    if pivot_errors.is_empty() {
60                        pivot_errors.push(pivot_abs);
61                    } else {
62                        pivot_errors.push(0.0);
63                    }
64                    return pivot_errors;
65                }
66                pivot_errors.push(pivot_abs);
67            }
68            pivot_errors.push(0.0);
69            return pivot_errors;
70        }
71
72        let max_rank = options.max_rank.min(full_rank);
73        let mut accepted = Vec::new();
74        let mut max_error = 0.0f64;
75        let mut last_error = f64::NAN;
76        let mut rank = 0usize;
77
78        while rank < max_rank {
79            let pivot_abs = diag_abs.get(rank).copied().unwrap_or(0.0);
80            last_error = pivot_abs;
81
82            if rank > 0 && (pivot_abs < options.rel_tol * max_error || pivot_abs < options.abs_tol)
83            {
84                break;
85            }
86
87            if pivot_abs < f64::EPSILON {
88                if rank == 0 {
89                    last_error = pivot_abs;
90                }
91                break;
92            }
93
94            max_error = max_error.max(pivot_abs);
95            accepted.push(pivot_abs);
96            rank += 1;
97        }
98
99        if rank >= full_rank {
100            last_error = 0.0;
101        } else if rank == max_rank && rank > 0 {
102            // Preserve the legacy tcicore-compatible semantics for max_rank stopping.
103            last_error = accepted[rank - 1];
104        }
105
106        accepted.push(last_error);
107        accepted
108    }
109}
110
/// Implements `PivotKernel<$t>` for `DenseFaerLuKernel` for one scalar type.
macro_rules! impl_dense_kernel {
    ($t:ty) => {
        impl PivotKernel<$t> for DenseFaerLuKernel {
            /// Runs a full-pivoting LU decomposition on `source` and returns
            /// the selected row/column pivots together with their errors.
            ///
            /// The number of selected pivots is one less than the length of
            /// the computed pivot-error list.
            fn factorize<S: CandidateMatrixSource<$t>>(
                &self,
                source: &S,
                options: &PivotKernelOptions,
            ) -> Result<PivotSelectionCore> {
                // Factorizes a column-major buffer holding the whole matrix.
                let factorize_dense = |values: &[$t]| -> Result<PivotSelectionCore> {
                    let (nrows, ncols) = (source.nrows(), source.ncols());
                    let full_rank = nrows.min(ncols);

                    let dense = MatRef::from_column_major_slice(values, nrows, ncols);
                    let decomposition = dense.full_piv_lu();
                    let upper = decomposition.U();

                    let pivot_errors = if !DenseFaerLuKernel::is_no_truncation(options, full_rank)
                    {
                        // Truncating path: gather the diagonal magnitudes once
                        // and let the slice-based routine apply the options.
                        let diag_abs: Vec<_> = (0..full_rank)
                            .map(|i| upper[(i, i)].abs_val())
                            .collect();
                        DenseFaerLuKernel::compute_pivot_errors(&diag_abs, nrows, ncols, options)
                    } else {
                        // No truncation requested: read the diagonal directly.
                        DenseFaerLuKernel::compute_no_truncation_pivot_errors(&upper, full_rank)
                    };
                    // The last entry is the trailing error, not a pivot.
                    let rank = pivot_errors.len().saturating_sub(1);

                    let (row_perm, _) = decomposition.P().arrays();
                    let (col_perm, _) = decomposition.Q().arrays();
                    let row_indices = row_perm[..rank].to_vec();
                    let col_indices = col_perm[..rank].to_vec();

                    Ok(PivotSelectionCore {
                        row_indices,
                        col_indices,
                        pivot_errors,
                        rank,
                    })
                };

                // Prefer the source's own dense storage; copy only if needed.
                match source.dense_column_major_slice() {
                    Some(values) => factorize_dense(values),
                    None => {
                        let materialized = materialize_source(source);
                        factorize_dense(materialized.as_slice())
                    }
                }
            }
        }
    };
}
159
// Instantiate the dense kernel for each supported scalar type:
// real and complex, in single and double precision.
impl_dense_kernel!(f32);
impl_dense_kernel!(f64);
impl_dense_kernel!(Complex32);
impl_dense_kernel!(Complex64);
164
// Unit tests are declared out-of-line in the `tests` submodule and are only
// compiled for test builds.
#[cfg(test)]
mod tests;