Skip to main content

tensor4all_tcicore/matrixluci/
block_rook.rs

1//! Lazy pivot-kernel implementations.
2
3use crate::matrixluci::factors::{invert_square, load_block, matmul, subtract_inplace};
4use crate::matrixluci::kernel::PivotKernel;
5use crate::matrixluci::scalar::Scalar;
6use crate::matrixluci::source::CandidateMatrixSource;
7use crate::matrixluci::types::{DenseOwnedMatrix, PivotKernelOptions, PivotSelectionCore};
8use crate::matrixluci::Result;
9use num_complex::{Complex32, Complex64};
10
/// Lazy pivot kernel based on residual row/column rook search.
///
/// Pivots are chosen by evaluating single residual rows and columns on demand,
/// so the full candidate matrix is never materialized. This makes the kernel
/// suited to large matrices accessed through
/// [`LazyMatrixSource`](super::LazyMatrixSource).
///
/// The kernel carries no state; all tuning is supplied per call through
/// [`PivotKernelOptions`].
#[derive(Debug, Clone, Copy, Default)]
pub struct LazyBlockRookKernel;
18
19fn residual_block<T: Scalar, S: CandidateMatrixSource<T>>(
20    source: &S,
21    rows: &[usize],
22    cols: &[usize],
23    selected_rows: &[usize],
24    selected_cols: &[usize],
25    pivot_inv: Option<&DenseOwnedMatrix<T>>,
26) -> DenseOwnedMatrix<T> {
27    let mut residual = load_block(source, rows, cols);
28    if selected_rows.is_empty() {
29        return residual;
30    }
31
32    let a_rj = load_block(source, rows, selected_cols);
33    let a_ic = load_block(source, selected_rows, cols);
34    let temp = matmul(&a_rj, pivot_inv.unwrap());
35    let approx = matmul(&temp, &a_ic);
36    subtract_inplace(&mut residual, &approx);
37    residual
38}
39
40fn argmax_abs<T: Scalar>(matrix: &DenseOwnedMatrix<T>) -> (usize, usize, f64) {
41    let mut best_row = 0usize;
42    let mut best_col = 0usize;
43    let mut best_abs = -1.0f64;
44    for col in 0..matrix.ncols() {
45        for row in 0..matrix.nrows() {
46            let value = matrix[[row, col]].abs_val();
47            if value > best_abs {
48                best_row = row;
49                best_col = col;
50                best_abs = value;
51            }
52        }
53    }
54    (best_row, best_col, best_abs.max(0.0))
55}
56
/// Returns, in ascending order, every index in `0..total` that does not occur
/// in `selected`.
///
/// # Panics
///
/// Panics if any entry of `selected` is `>= total`.
fn remaining_indices(total: usize, selected: &[usize]) -> Vec<usize> {
    let mut taken = vec![false; total];
    selected.iter().for_each(|&index| taken[index] = true);
    taken
        .iter()
        .enumerate()
        .filter(|&(_, &is_taken)| !is_taken)
        .map(|(index, _)| index)
        .collect()
}
64
65fn rook_pivot<T: Scalar, S: CandidateMatrixSource<T>>(
66    source: &S,
67    remaining_rows: &[usize],
68    remaining_cols: &[usize],
69    selected_rows: &[usize],
70    selected_cols: &[usize],
71    pivot_inv: Option<&DenseOwnedMatrix<T>>,
72) -> (usize, usize, f64) {
73    let mut current_col = remaining_cols[0];
74    let mut current_row = remaining_rows[0];
75    let max_steps = remaining_rows.len() + remaining_cols.len() + 1;
76
77    for _ in 0..max_steps {
78        let col_residual = residual_block(
79            source,
80            remaining_rows,
81            &[current_col],
82            selected_rows,
83            selected_cols,
84            pivot_inv,
85        );
86        let (best_row_pos, _, _) = argmax_abs(&col_residual);
87        current_row = remaining_rows[best_row_pos];
88
89        let row_residual = residual_block(
90            source,
91            &[current_row],
92            remaining_cols,
93            selected_rows,
94            selected_cols,
95            pivot_inv,
96        );
97        let (_, best_col_pos, best_abs) = argmax_abs(&row_residual);
98        let next_col = remaining_cols[best_col_pos];
99
100        if next_col == current_col {
101            return (current_row, current_col, best_abs);
102        }
103        current_col = next_col;
104    }
105
106    let row_residual = residual_block(
107        source,
108        &[current_row],
109        remaining_cols,
110        selected_rows,
111        selected_cols,
112        pivot_inv,
113    );
114    let (_, best_col_pos, best_abs) = argmax_abs(&row_residual);
115    (current_row, remaining_cols[best_col_pos], best_abs)
116}
117
118fn factorize_lazy<T: Scalar, S: CandidateMatrixSource<T>>(
119    source: &S,
120    options: &PivotKernelOptions,
121) -> Result<PivotSelectionCore> {
122    let nrows = source.nrows();
123    let ncols = source.ncols();
124    let full_rank = nrows.min(ncols);
125    if full_rank == 0 {
126        return Ok(PivotSelectionCore {
127            row_indices: Vec::new(),
128            col_indices: Vec::new(),
129            pivot_errors: vec![0.0],
130            rank: 0,
131        });
132    }
133
134    let max_rank = options.max_rank.min(full_rank);
135    let mut selected_rows = Vec::with_capacity(max_rank);
136    let mut selected_cols = Vec::with_capacity(max_rank);
137    let mut accepted = Vec::with_capacity(max_rank + 1);
138    let mut max_error = 0.0f64;
139    let mut last_error = f64::NAN;
140
141    while selected_rows.len() < max_rank {
142        let remaining_rows = remaining_indices(nrows, &selected_rows);
143        let remaining_cols = remaining_indices(ncols, &selected_cols);
144        if remaining_rows.is_empty() || remaining_cols.is_empty() {
145            break;
146        }
147
148        let pivot_inv = if selected_rows.is_empty() {
149            None
150        } else {
151            let pivot = load_block(source, &selected_rows, &selected_cols);
152            Some(invert_square(&pivot)?)
153        };
154
155        let (pivot_row, pivot_col, pivot_abs) = rook_pivot(
156            source,
157            &remaining_rows,
158            &remaining_cols,
159            &selected_rows,
160            &selected_cols,
161            pivot_inv.as_ref(),
162        );
163        last_error = pivot_abs;
164
165        if !selected_rows.is_empty()
166            && (pivot_abs < options.rel_tol * max_error || pivot_abs < options.abs_tol)
167        {
168            break;
169        }
170
171        if pivot_abs < T::epsilon() {
172            if selected_rows.is_empty() {
173                last_error = pivot_abs;
174            }
175            break;
176        }
177
178        max_error = max_error.max(pivot_abs);
179        selected_rows.push(pivot_row);
180        selected_cols.push(pivot_col);
181        accepted.push(pivot_abs);
182    }
183
184    let rank = selected_rows.len();
185    if rank >= full_rank {
186        last_error = 0.0;
187    } else if rank == max_rank && rank > 0 {
188        last_error = accepted[rank - 1];
189    }
190    accepted.push(last_error);
191
192    Ok(PivotSelectionCore {
193        row_indices: selected_rows,
194        col_indices: selected_cols,
195        pivot_errors: accepted,
196        rank,
197    })
198}
199
200macro_rules! impl_lazy_block_rook_kernel {
201    ($t:ty) => {
202        impl PivotKernel<$t> for LazyBlockRookKernel {
203            fn factorize<S: CandidateMatrixSource<$t>>(
204                &self,
205                source: &S,
206                options: &PivotKernelOptions,
207            ) -> Result<PivotSelectionCore> {
208                factorize_lazy(source, options)
209            }
210        }
211    };
212}
213
214impl_lazy_block_rook_kernel!(f32);
215impl_lazy_block_rook_kernel!(f64);
216impl_lazy_block_rook_kernel!(Complex32);
217impl_lazy_block_rook_kernel!(Complex64);
218
219#[cfg(test)]
220mod tests;