tensor4all_tcicore/traits.rs
1//! Abstract traits for matrix cross interpolation.
2//!
3//! The [`AbstractMatrixCI`] trait provides a common interface for all matrix
4//! cross interpolation implementations ([`MatrixLUCI`](crate::MatrixLUCI),
5//! [`MatrixACA`](crate::MatrixACA)).
6
7use crate::error::Result;
8use crate::matrix::{submatrix, zeros, Matrix};
9use crate::scalar::Scalar;
10
11/// Common interface for matrix cross interpolation objects.
12///
13/// Implementors provide low-rank approximations of matrices via
14/// selected pivot rows and columns. This trait unifies the API for
15/// [`MatrixLUCI`](crate::MatrixLUCI) and [`MatrixACA`](crate::MatrixACA).
16///
17/// # Examples
18///
19/// ```
20/// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
21///
22/// let m = from_vec2d(vec![
23/// vec![1.0_f64, 2.0],
24/// vec![3.0, 4.0],
25/// ]);
26/// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
27///
28/// // All AbstractMatrixCI methods are available:
29/// assert_eq!(ci.nrows(), 2);
30/// assert_eq!(ci.ncols(), 2);
31/// assert!(ci.rank() >= 1);
32/// assert!(!ci.is_empty());
33///
34/// // Full reconstruction
35/// let full = ci.to_matrix();
36/// for i in 0..2 {
37/// for j in 0..2 {
38/// assert!((full[[i, j]] - m[[i, j]]).abs() < 1e-10);
39/// }
40/// }
41/// ```
42pub trait AbstractMatrixCI<T: Scalar>: Sized {
43 /// Number of rows in the approximated matrix
44 ///
45 /// # Examples
46 ///
47 /// ```
48 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
49 ///
50 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
51 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
52 /// assert_eq!(ci.nrows(), 3);
53 /// ```
54 fn nrows(&self) -> usize;
55
56 /// Number of columns in the approximated matrix
57 ///
58 /// # Examples
59 ///
60 /// ```
61 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
62 ///
63 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0, 3.0]]);
64 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
65 /// assert_eq!(ci.ncols(), 3);
66 /// ```
67 fn ncols(&self) -> usize;
68
69 /// Current rank of the approximation (number of pivots)
70 ///
71 /// # Examples
72 ///
73 /// ```
74 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
75 ///
76 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
77 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
78 /// assert!(ci.rank() >= 1);
79 /// assert!(ci.rank() <= 2);
80 /// ```
81 fn rank(&self) -> usize;
82
83 /// Row indices selected as pivots (I set)
84 ///
85 /// # Examples
86 ///
87 /// ```
88 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
89 ///
90 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
91 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
92 /// let rows = ci.row_indices();
93 /// assert_eq!(rows.len(), ci.rank());
94 /// for &r in rows {
95 /// assert!(r < ci.nrows());
96 /// }
97 /// ```
98 fn row_indices(&self) -> &[usize];
99
100 /// Column indices selected as pivots (J set)
101 ///
102 /// # Examples
103 ///
104 /// ```
105 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
106 ///
107 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
108 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
109 /// let cols = ci.col_indices();
110 /// assert_eq!(cols.len(), ci.rank());
111 /// for &c in cols {
112 /// assert!(c < ci.ncols());
113 /// }
114 /// ```
115 fn col_indices(&self) -> &[usize];
116
117 /// Check if the approximation is empty (no pivots)
118 ///
119 /// # Examples
120 ///
121 /// ```
122 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixACA};
123 ///
124 /// let aca = MatrixACA::<f64>::new(2, 2);
125 /// assert!(aca.is_empty());
126 /// ```
127 fn is_empty(&self) -> bool {
128 self.rank() == 0
129 }
130
131 /// Evaluate the approximation at position (i, j)
132 ///
133 /// # Examples
134 ///
135 /// ```
136 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
137 ///
138 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
139 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
140 /// assert!((ci.evaluate(0, 0) - 1.0).abs() < 1e-10);
141 /// assert!((ci.evaluate(1, 1) - 4.0).abs() < 1e-10);
142 /// ```
143 fn evaluate(&self, i: usize, j: usize) -> T;
144
145 /// Get a submatrix of the approximation
146 ///
147 /// # Examples
148 ///
149 /// ```
150 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
151 ///
152 /// let m = from_vec2d(vec![
153 /// vec![1.0_f64, 2.0, 3.0],
154 /// vec![4.0, 5.0, 6.0],
155 /// vec![7.0, 8.0, 9.0],
156 /// ]);
157 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
158 /// let sub = ci.submatrix(&[0, 2], &[1]);
159 /// assert_eq!(sub.nrows(), 2);
160 /// assert_eq!(sub.ncols(), 1);
161 /// assert!((sub[[0, 0]] - 2.0).abs() < 1e-10);
162 /// assert!((sub[[1, 0]] - 8.0).abs() < 1e-10);
163 /// ```
164 fn submatrix(&self, rows: &[usize], cols: &[usize]) -> Matrix<T>;
165
166 /// Get a row of the approximation
167 ///
168 /// # Examples
169 ///
170 /// ```
171 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
172 ///
173 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
174 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
175 /// let row0 = ci.row(0);
176 /// assert_eq!(row0.len(), 2);
177 /// assert!((row0[0] - 1.0).abs() < 1e-10);
178 /// assert!((row0[1] - 2.0).abs() < 1e-10);
179 /// ```
180 fn row(&self, i: usize) -> Vec<T> {
181 let cols: Vec<usize> = (0..self.ncols()).collect();
182 let sub = self.submatrix(&[i], &cols);
183 (0..self.ncols()).map(|j| sub[[0, j]]).collect()
184 }
185
186 /// Get a column of the approximation
187 ///
188 /// # Examples
189 ///
190 /// ```
191 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
192 ///
193 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
194 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
195 /// let col1 = ci.col(1);
196 /// assert_eq!(col1.len(), 2);
197 /// assert!((col1[0] - 2.0).abs() < 1e-10);
198 /// assert!((col1[1] - 4.0).abs() < 1e-10);
199 /// ```
200 fn col(&self, j: usize) -> Vec<T> {
201 let rows: Vec<usize> = (0..self.nrows()).collect();
202 let sub = self.submatrix(&rows, &[j]);
203 (0..self.nrows()).map(|i| sub[[i, 0]]).collect()
204 }
205
206 /// Get the full approximated matrix
207 ///
208 /// # Examples
209 ///
210 /// ```
211 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, from_vec2d};
212 ///
213 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
214 /// let ci = MatrixLUCI::from_matrix(&m, None).unwrap();
215 /// let full = ci.to_matrix();
216 /// assert_eq!(full.nrows(), 2);
217 /// assert_eq!(full.ncols(), 2);
218 /// for i in 0..2 {
219 /// for j in 0..2 {
220 /// assert!((full[[i, j]] - m[[i, j]]).abs() < 1e-10);
221 /// }
222 /// }
223 /// ```
224 fn to_matrix(&self) -> Matrix<T> {
225 let rows: Vec<usize> = (0..self.nrows()).collect();
226 let cols: Vec<usize> = (0..self.ncols()).collect();
227 self.submatrix(&rows, &cols)
228 }
229
230 /// Get available row indices (rows without pivots)
231 ///
232 /// # Examples
233 ///
234 /// ```
235 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixACA, from_vec2d};
236 ///
237 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
238 /// let aca = MatrixACA::from_matrix_with_pivot(&m, (1, 0)).unwrap();
239 /// let avail = aca.available_rows();
240 /// // Row 1 was used as pivot, so 0 and 2 remain
241 /// assert_eq!(avail, vec![0, 2]);
242 /// ```
243 fn available_rows(&self) -> Vec<usize> {
244 let pivot_rows: std::collections::HashSet<usize> =
245 self.row_indices().iter().copied().collect();
246 (0..self.nrows())
247 .filter(|i| !pivot_rows.contains(i))
248 .collect()
249 }
250
251 /// Get available column indices (columns without pivots)
252 ///
253 /// # Examples
254 ///
255 /// ```
256 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixACA, from_vec2d};
257 ///
258 /// let m = from_vec2d(vec![vec![1.0_f64, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
259 /// let aca = MatrixACA::from_matrix_with_pivot(&m, (0, 1)).unwrap();
260 /// let avail = aca.available_cols();
261 /// // Column 1 was used as pivot, so 0 and 2 remain
262 /// assert_eq!(avail, vec![0, 2]);
263 /// ```
264 fn available_cols(&self) -> Vec<usize> {
265 let pivot_cols: std::collections::HashSet<usize> =
266 self.col_indices().iter().copied().collect();
267 (0..self.ncols())
268 .filter(|j| !pivot_cols.contains(j))
269 .collect()
270 }
271
272 /// Compute local error |A - CI| for given indices
273 ///
274 /// # Examples
275 ///
276 /// ```
277 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixACA, from_vec2d};
278 ///
279 /// let m = from_vec2d(vec![
280 /// vec![1.0_f64, 2.0, 3.0],
281 /// vec![4.0, 5.0, 6.0],
282 /// vec![7.0, 8.0, 10.0],
283 /// ]);
284 /// let aca = MatrixACA::from_matrix_with_pivot(&m, (0, 0)).unwrap();
285 /// let err = aca.local_error(&m, &[1, 2], &[1, 2]);
286 /// // Error at pivot position (0,0) would be zero; off-pivot may be non-zero
287 /// assert_eq!(err.nrows(), 2);
288 /// assert_eq!(err.ncols(), 2);
289 /// ```
290 fn local_error(&self, a: &Matrix<T>, rows: &[usize], cols: &[usize]) -> Matrix<T>
291 where
292 T: std::ops::Sub<Output = T>,
293 {
294 let sub_a = submatrix(a, rows, cols);
295 let sub_ci = self.submatrix(rows, cols);
296
297 let mut result = zeros(rows.len(), cols.len());
298 for i in 0..rows.len() {
299 for j in 0..cols.len() {
300 let diff = sub_a[[i, j]] - sub_ci[[i, j]];
301 result[[i, j]] = diff.abs();
302 }
303 }
304 result
305 }
306
307 /// Find a new pivot that maximizes the local error
308 ///
309 /// # Examples
310 ///
311 /// ```
312 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixACA, from_vec2d};
313 ///
314 /// let m = from_vec2d(vec![
315 /// vec![1.0_f64, 2.0, 3.0],
316 /// vec![4.0, 5.0, 6.0],
317 /// vec![7.0, 8.0, 10.0],
318 /// ]);
319 /// let aca = MatrixACA::from_matrix_with_pivot(&m, (0, 0)).unwrap();
320 /// let ((r, c), err_val) = aca.find_new_pivot(&m).unwrap();
321 /// // New pivot must be in available rows/cols (not row 0 or col 0)
322 /// assert_ne!(r, 0);
323 /// assert_ne!(c, 0);
324 /// ```
325 fn find_new_pivot(&self, a: &Matrix<T>) -> Result<((usize, usize), T)>
326 where
327 T: std::ops::Sub<Output = T>,
328 {
329 let avail_rows = self.available_rows();
330 let avail_cols = self.available_cols();
331
332 self.find_new_pivot_in(a, &avail_rows, &avail_cols)
333 }
334
335 /// Find a new pivot in the given row/column subsets
336 ///
337 /// # Examples
338 ///
339 /// ```
340 /// use tensor4all_tcicore::{AbstractMatrixCI, MatrixACA, from_vec2d};
341 ///
342 /// let m = from_vec2d(vec![
343 /// vec![1.0_f64, 2.0, 3.0],
344 /// vec![4.0, 5.0, 6.0],
345 /// vec![7.0, 8.0, 10.0],
346 /// ]);
347 /// let aca = MatrixACA::from_matrix_with_pivot(&m, (0, 0)).unwrap();
348 /// // Search only in rows [1,2] and cols [1,2]
349 /// let ((r, c), _) = aca.find_new_pivot_in(&m, &[1, 2], &[1, 2]).unwrap();
350 /// assert!(r == 1 || r == 2);
351 /// assert!(c == 1 || c == 2);
352 /// ```
353 fn find_new_pivot_in(
354 &self,
355 a: &Matrix<T>,
356 rows: &[usize],
357 cols: &[usize],
358 ) -> Result<((usize, usize), T)>
359 where
360 T: std::ops::Sub<Output = T>,
361 {
362 use crate::error::MatrixCIError;
363
364 if self.rank() == self.nrows().min(self.ncols()) {
365 return Err(MatrixCIError::FullRank);
366 }
367
368 if rows.is_empty() {
369 return Err(MatrixCIError::EmptyIndexSet {
370 dimension: "rows".to_string(),
371 });
372 }
373
374 if cols.is_empty() {
375 return Err(MatrixCIError::EmptyIndexSet {
376 dimension: "cols".to_string(),
377 });
378 }
379
380 let errors = self.local_error(a, rows, cols);
381
382 // Find maximum error position (comparing by abs_sq which returns f64)
383 let mut max_val_sq: f64 = errors[[0, 0]].abs_sq();
384 let mut max_i = 0;
385 let mut max_j = 0;
386
387 for i in 0..rows.len() {
388 for j in 0..cols.len() {
389 let val_sq: f64 = errors[[i, j]].abs_sq();
390 if val_sq > max_val_sq {
391 max_val_sq = val_sq;
392 max_i = i;
393 max_j = j;
394 }
395 }
396 }
397
398 Ok(((rows[max_i], cols[max_j]), errors[[max_i, max_j]]))
399 }
400}