Skip to main content

tensor4all_tcicore/matrixluci/
factors.rs

//! Cross-factor reconstruction helpers.
//!
//! [`CrossFactors`] gathers the pivot block, pivot columns, and pivot rows
//! from a [`CandidateMatrixSource`] and provides methods to compute the left
//! and right cross-interpolation (CI) factors.
6
7use crate::matrixluci::error::MatrixLuciError;
8use crate::matrixluci::scalar::Scalar;
9use crate::matrixluci::source::CandidateMatrixSource;
10use crate::matrixluci::types::{DenseOwnedMatrix, PivotSelectionCore};
11use crate::matrixluci::Result;
12
13/// Gather a dense column-major block from a source.
14pub(crate) fn load_block<T: Scalar, S: CandidateMatrixSource<T>>(
15    source: &S,
16    rows: &[usize],
17    cols: &[usize],
18) -> DenseOwnedMatrix<T> {
19    let mut data = vec![T::zero(); rows.len() * cols.len()];
20    source.get_block(rows, cols, &mut data);
21    DenseOwnedMatrix::from_column_major(data, rows.len(), cols.len())
22}
23
24/// Dense matrix product in column-major layout.
25pub(crate) fn matmul<T: Scalar>(
26    lhs: &DenseOwnedMatrix<T>,
27    rhs: &DenseOwnedMatrix<T>,
28) -> DenseOwnedMatrix<T> {
29    assert_eq!(lhs.ncols(), rhs.nrows());
30    let mut out = DenseOwnedMatrix::zeros(lhs.nrows(), rhs.ncols());
31    for j in 0..rhs.ncols() {
32        for k in 0..lhs.ncols() {
33            let rhs_kj = rhs[[k, j]];
34            for i in 0..lhs.nrows() {
35                out[[i, j]] = out[[i, j]] + lhs[[i, k]] * rhs_kj;
36            }
37        }
38    }
39    out
40}
41
42/// Subtract one dense matrix from another in place.
43pub(crate) fn subtract_inplace<T: Scalar>(
44    lhs: &mut DenseOwnedMatrix<T>,
45    rhs: &DenseOwnedMatrix<T>,
46) {
47    assert_eq!(lhs.nrows(), rhs.nrows());
48    assert_eq!(lhs.ncols(), rhs.ncols());
49    for j in 0..lhs.ncols() {
50        for i in 0..lhs.nrows() {
51            lhs[[i, j]] = lhs[[i, j]] - rhs[[i, j]];
52        }
53    }
54}
55
56fn swap_rows<T: Scalar>(matrix: &mut DenseOwnedMatrix<T>, a: usize, b: usize) {
57    if a == b {
58        return;
59    }
60    for col in 0..matrix.ncols() {
61        let tmp = matrix[[a, col]];
62        matrix[[a, col]] = matrix[[b, col]];
63        matrix[[b, col]] = tmp;
64    }
65}
66
67/// Invert a small square dense matrix with Gauss-Jordan elimination.
68pub(crate) fn invert_square<T: Scalar>(
69    matrix: &DenseOwnedMatrix<T>,
70) -> Result<DenseOwnedMatrix<T>> {
71    if matrix.nrows() != matrix.ncols() {
72        return Err(MatrixLuciError::InvalidArgument {
73            message: "pivot block must be square".to_string(),
74        });
75    }
76
77    let n = matrix.nrows();
78    let mut aug = DenseOwnedMatrix::zeros(n, 2 * n);
79    for j in 0..n {
80        for i in 0..n {
81            aug[[i, j]] = matrix[[i, j]];
82        }
83        aug[[j, n + j]] = T::one();
84    }
85
86    for k in 0..n {
87        let mut pivot_row = k;
88        let mut pivot_abs = 0.0f64;
89        for row in k..n {
90            let candidate = aug[[row, k]].abs_val();
91            if candidate > pivot_abs {
92                pivot_abs = candidate;
93                pivot_row = row;
94            }
95        }
96
97        if pivot_abs < T::epsilon() {
98            return Err(MatrixLuciError::SingularPivotBlock);
99        }
100
101        swap_rows(&mut aug, k, pivot_row);
102
103        let pivot = aug[[k, k]];
104        for col in 0..(2 * n) {
105            aug[[k, col]] = aug[[k, col]] / pivot;
106        }
107
108        for row in 0..n {
109            if row == k {
110                continue;
111            }
112            let factor = aug[[row, k]];
113            if factor.abs_val() < T::epsilon() {
114                continue;
115            }
116            for col in 0..(2 * n) {
117                aug[[row, col]] = aug[[row, col]] - factor * aug[[k, col]];
118            }
119        }
120    }
121
122    let mut inv = DenseOwnedMatrix::zeros(n, n);
123    for j in 0..n {
124        for i in 0..n {
125            inv[[i, j]] = aug[[i, n + j]];
126        }
127    }
128    Ok(inv)
129}
130
131/// Dense factors derived from a pivot selection.
132///
133/// Contains the pivot block `A[I, J]`, full pivot columns `A[:, J]`, and
134/// full pivot rows `A[I, :]`. These are used to reconstruct the left and
135/// right CI factors for the cross interpolation approximation
136/// `A ~ A[:, J] * A[I, J]^{-1} * A[I, :]`.
137#[derive(Debug, Clone)]
138pub struct CrossFactors<T: Scalar> {
139    /// Pivot block `A[I, J]`.
140    pub pivot: DenseOwnedMatrix<T>,
141    /// Columns through selected pivot columns `A[:, J]`.
142    pub pivot_cols: DenseOwnedMatrix<T>,
143    /// Rows through selected pivot rows `A[I, :]`.
144    pub pivot_rows: DenseOwnedMatrix<T>,
145}
146
147impl<T: Scalar> CrossFactors<T> {
148    /// Gather a dense block from a source.
149    pub fn gather<S: CandidateMatrixSource<T>>(
150        source: &S,
151        rows: &[usize],
152        cols: &[usize],
153    ) -> DenseOwnedMatrix<T> {
154        load_block(source, rows, cols)
155    }
156
157    /// Reconstruct factors from a source and pivot-only selection.
158    pub fn from_source<S: CandidateMatrixSource<T>>(
159        source: &S,
160        selection: &PivotSelectionCore,
161    ) -> Result<Self> {
162        let all_rows: Vec<usize> = (0..source.nrows()).collect();
163        let all_cols: Vec<usize> = (0..source.ncols()).collect();
164        Ok(Self {
165            pivot: load_block(source, &selection.row_indices, &selection.col_indices),
166            pivot_cols: load_block(source, &all_rows, &selection.col_indices),
167            pivot_rows: load_block(source, &selection.row_indices, &all_cols),
168        })
169    }
170
171    /// Invert the pivot block.
172    pub fn pivot_inverse(&self) -> Result<DenseOwnedMatrix<T>> {
173        invert_square(&self.pivot)
174    }
175
176    /// Form `A[:, J] * A[I, J]^{-1}`.
177    pub fn cols_times_pivot_inv(&self) -> Result<DenseOwnedMatrix<T>> {
178        let pivot_inv = self.pivot_inverse()?;
179        Ok(matmul(&self.pivot_cols, &pivot_inv))
180    }
181
182    /// Form `A[I, J]^{-1} * A[I, :]`.
183    pub fn pivot_inv_times_rows(&self) -> Result<DenseOwnedMatrix<T>> {
184        let pivot_inv = self.pivot_inverse()?;
185        Ok(matmul(&pivot_inv, &self.pivot_rows))
186    }
187}
188
189#[cfg(test)]
190mod tests;