// tensor4all_tcicore/matrixlu.rs
1//! Rank-Revealing LU decomposition (rrLU) implementation.
2//!
3//! Provides [`RrLU`], a full-pivoting LU decomposition that reveals the
4//! numerical rank of a matrix. The decomposition is:
5//!
6//! ```text
7//! P_row * A * P_col = L * U
8//! ```
9//!
10//! where `P_row`, `P_col` are permutation matrices. The rank is determined
11//! by the number of pivots exceeding the tolerance thresholds in
12//! [`RrLUOptions`].
13//!
14//! # Examples
15//!
16//! ```
17//! use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
18//!
19//! let m = from_vec2d(vec![
20//!     vec![1.0_f64, 2.0],
21//!     vec![3.0, 4.0],
22//! ]);
23//! let lu = rrlu(&m, None).unwrap();
24//! assert_eq!(lu.npivots(), 2);
25//! ```
26
27use crate::error::{MatrixCIError, Result};
28use crate::matrix::{
29    ncols, nrows, submatrix_argmax, swap_cols, swap_rows, transpose, zeros, Matrix,
30};
31use crate::scalar::Scalar;
32
/// Rank-Revealing LU decomposition.
///
/// Represents a matrix `A` as `P_row * A * P_col = L * U`, where `P_row`
/// and `P_col` are permutation matrices, `L` is lower-triangular, and `U`
/// is upper-triangular. One of `L` or `U` has unit diagonal, controlled by
/// the `left_orthogonal` option.
///
/// # Examples
///
/// ```
/// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu, matrix::mat_mul};
///
/// let m = from_vec2d(vec![
///     vec![1.0_f64, 2.0, 3.0],
///     vec![4.0, 5.0, 6.0],
///     vec![7.0, 8.0, 10.0],
/// ]);
///
/// let lu = rrlu(&m, None).unwrap();
/// assert_eq!(lu.npivots(), 3);
///
/// // Verify L * U reconstructs the permuted matrix
/// let l = lu.left(false);
/// let u = lu.right(false);
/// let reconstructed = mat_mul(&l, &u);
///
/// // Check reconstruction matches the permuted matrix
/// for i in 0..3 {
///     for j in 0..3 {
///         let orig_row = lu.row_permutation()[i];
///         let orig_col = lu.col_permutation()[j];
///         assert!((reconstructed[[i, j]] - m[[orig_row, orig_col]]).abs() < 1e-10);
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct RrLU<T: Scalar> {
    /// Row permutation: `row_permutation[i]` is the original row index that
    /// ended up at permuted row `i`.
    row_permutation: Vec<usize>,
    /// Column permutation: `col_permutation[j]` is the original column index
    /// that ended up at permuted column `j`.
    col_permutation: Vec<usize>,
    /// Lower-triangular factor `L` (`nrows x n_pivot` after factorization).
    l: Matrix<T>,
    /// Upper-triangular factor `U` (`n_pivot x ncols` after factorization).
    u: Matrix<T>,
    /// If true, `L` carries the unit diagonal; otherwise `U` does.
    left_orthogonal: bool,
    /// Number of accepted pivots (the revealed numerical rank).
    n_pivot: usize,
    /// Magnitude of the first rejected pivot; 0.0 for a full-rank
    /// decomposition, NaN before any factorization has run.
    error: f64,
}
85
86impl<T: Scalar> RrLU<T> {
87    /// Create an empty rrLU for a matrix of given size.
88    ///
89    /// Used internally. Most users should call [`rrlu`] or [`rrlu_inplace`]
90    /// instead.
91    ///
92    /// # Examples
93    ///
94    /// ```
95    /// use tensor4all_tcicore::RrLU;
96    ///
97    /// let lu = RrLU::<f64>::new(3, 4, true);
98    /// assert_eq!(lu.nrows(), 3);
99    /// assert_eq!(lu.ncols(), 4);
100    /// assert_eq!(lu.npivots(), 0);
101    /// assert!(lu.is_left_orthogonal());
102    /// ```
103    pub fn new(nr: usize, nc: usize, left_orthogonal: bool) -> Self {
104        Self {
105            row_permutation: (0..nr).collect(),
106            col_permutation: (0..nc).collect(),
107            l: zeros(nr, 0),
108            u: zeros(0, nc),
109            left_orthogonal,
110            n_pivot: 0,
111            error: f64::NAN,
112        }
113    }
114
115    /// Number of rows
116    ///
117    /// # Examples
118    ///
119    /// ```
120    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
121    ///
122    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
123    /// let lu = rrlu(&m, None).unwrap();
124    /// assert_eq!(lu.nrows(), 3);
125    /// ```
126    pub fn nrows(&self) -> usize {
127        nrows(&self.l)
128    }
129
130    /// Number of columns
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
136    ///
137    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
138    /// let lu = rrlu(&m, None).unwrap();
139    /// assert_eq!(lu.ncols(), 3);
140    /// ```
141    pub fn ncols(&self) -> usize {
142        ncols(&self.u)
143    }
144
145    /// Number of pivots
146    ///
147    /// # Examples
148    ///
149    /// ```
150    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
151    ///
152    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
153    /// let lu = rrlu(&m, None).unwrap();
154    /// assert_eq!(lu.npivots(), 2);
155    /// ```
156    pub fn npivots(&self) -> usize {
157        self.n_pivot
158    }
159
160    /// Row permutation
161    ///
162    /// # Examples
163    ///
164    /// ```
165    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
166    ///
167    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
168    /// let lu = rrlu(&m, None).unwrap();
169    /// let perm = lu.row_permutation();
170    /// assert_eq!(perm.len(), 2);
171    /// // Permutation is a rearrangement of 0..nrows
172    /// let mut sorted = perm.to_vec();
173    /// sorted.sort();
174    /// assert_eq!(sorted, vec![0, 1]);
175    /// ```
176    pub fn row_permutation(&self) -> &[usize] {
177        &self.row_permutation
178    }
179
180    /// Column permutation
181    ///
182    /// # Examples
183    ///
184    /// ```
185    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
186    ///
187    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
188    /// let lu = rrlu(&m, None).unwrap();
189    /// let perm = lu.col_permutation();
190    /// assert_eq!(perm.len(), 2);
191    /// let mut sorted = perm.to_vec();
192    /// sorted.sort();
193    /// assert_eq!(sorted, vec![0, 1]);
194    /// ```
195    pub fn col_permutation(&self) -> &[usize] {
196        &self.col_permutation
197    }
198
199    /// Get row indices (selected pivots)
200    ///
201    /// # Examples
202    ///
203    /// ```
204    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu, RrLUOptions};
205    ///
206    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
207    /// let lu = rrlu(&m, Some(RrLUOptions { max_rank: 1, ..Default::default() })).unwrap();
208    /// let rows = lu.row_indices();
209    /// assert_eq!(rows.len(), 1);
210    /// assert!(rows[0] < 2);
211    /// ```
212    pub fn row_indices(&self) -> Vec<usize> {
213        self.row_permutation[0..self.n_pivot].to_vec()
214    }
215
216    /// Get column indices (selected pivots)
217    ///
218    /// # Examples
219    ///
220    /// ```
221    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu, RrLUOptions};
222    ///
223    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
224    /// let lu = rrlu(&m, Some(RrLUOptions { max_rank: 1, ..Default::default() })).unwrap();
225    /// let cols = lu.col_indices();
226    /// assert_eq!(cols.len(), 1);
227    /// assert!(cols[0] < 2);
228    /// ```
229    pub fn col_indices(&self) -> Vec<usize> {
230        self.col_permutation[0..self.n_pivot].to_vec()
231    }
232
233    /// Get left matrix (optionally permuted)
234    ///
235    /// # Examples
236    ///
237    /// ```
238    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu, matrix::mat_mul};
239    ///
240    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
241    /// let lu = rrlu(&m, None).unwrap();
242    ///
243    /// // Unpermuted: L * U reconstructs the row/col-permuted matrix
244    /// let l = lu.left(false);
245    /// let u = lu.right(false);
246    /// let prod = mat_mul(&l, &u);
247    /// for i in 0..2 {
248    ///     for j in 0..2 {
249    ///         let ri = lu.row_permutation()[i];
250    ///         let cj = lu.col_permutation()[j];
251    ///         assert!((prod[[i, j]] - m[[ri, cj]]).abs() < 1e-10);
252    ///     }
253    /// }
254    /// ```
255    pub fn left(&self, permute: bool) -> Matrix<T> {
256        if permute {
257            let mut result = zeros(nrows(&self.l), ncols(&self.l));
258            for (new_i, &old_i) in self.row_permutation.iter().enumerate() {
259                for j in 0..ncols(&self.l) {
260                    result[[old_i, j]] = self.l[[new_i, j]];
261                }
262            }
263            result
264        } else {
265            self.l.clone()
266        }
267    }
268
269    /// Get right matrix (optionally permuted)
270    ///
271    /// # Examples
272    ///
273    /// ```
274    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu, matrix::mat_mul};
275    ///
276    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
277    /// let lu = rrlu(&m, None).unwrap();
278    /// let l = lu.left(false);
279    /// let u = lu.right(false);
280    /// assert_eq!(u.nrows(), lu.npivots());
281    /// assert_eq!(u.ncols(), lu.ncols());
282    /// // L * U reconstructs the permuted matrix
283    /// let prod = mat_mul(&l, &u);
284    /// assert!((prod[[0, 0]] - m[[lu.row_permutation()[0], lu.col_permutation()[0]]]).abs() < 1e-10);
285    /// ```
286    pub fn right(&self, permute: bool) -> Matrix<T> {
287        if permute {
288            let mut result = zeros(nrows(&self.u), ncols(&self.u));
289            for i in 0..nrows(&self.u) {
290                for (new_j, &old_j) in self.col_permutation.iter().enumerate() {
291                    result[[i, old_j]] = self.u[[i, new_j]];
292                }
293            }
294            result
295        } else {
296            self.u.clone()
297        }
298    }
299
300    /// Get diagonal elements
301    ///
302    /// # Examples
303    ///
304    /// ```
305    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
306    ///
307    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
308    /// let lu = rrlu(&m, None).unwrap();
309    /// let d = lu.diag();
310    /// assert_eq!(d.len(), lu.npivots());
311    /// // Diagonal elements are non-zero for full-rank matrices
312    /// for &val in &d {
313    ///     assert!(val.abs() > 1e-14);
314    /// }
315    /// ```
316    pub fn diag(&self) -> Vec<T> {
317        let n = self.n_pivot;
318        if self.left_orthogonal {
319            (0..n).map(|i| self.u[[i, i]]).collect()
320        } else {
321            (0..n).map(|i| self.l[[i, i]]).collect()
322        }
323    }
324
325    /// Get pivot errors
326    ///
327    /// # Examples
328    ///
329    /// ```
330    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
331    ///
332    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
333    /// let lu = rrlu(&m, None).unwrap();
334    /// let errs = lu.pivot_errors();
335    /// // One entry per pivot plus the final residual
336    /// assert_eq!(errs.len(), lu.npivots() + 1);
337    /// // Errors are non-negative
338    /// for &e in &errs {
339    ///     assert!(e >= 0.0);
340    /// }
341    /// ```
342    pub fn pivot_errors(&self) -> Vec<f64> {
343        let mut errors: Vec<f64> = self.diag().iter().map(|d| f64::sqrt(d.abs_sq())).collect();
344        errors.push(self.error);
345        errors
346    }
347
348    /// Get last pivot error
349    ///
350    /// # Examples
351    ///
352    /// ```
353    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
354    ///
355    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
356    /// let lu = rrlu(&m, None).unwrap();
357    /// // Full-rank decomposition has zero residual
358    /// assert_eq!(lu.last_pivot_error(), 0.0);
359    /// ```
360    pub fn last_pivot_error(&self) -> f64 {
361        self.error
362    }
363
364    /// Transpose the decomposition
365    ///
366    /// # Examples
367    ///
368    /// ```
369    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
370    ///
371    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
372    /// let lu = rrlu(&m, None).unwrap();
373    /// let lu_t = lu.transpose();
374    /// assert_eq!(lu_t.nrows(), lu.ncols());
375    /// assert_eq!(lu_t.ncols(), lu.nrows());
376    /// assert_eq!(lu_t.npivots(), lu.npivots());
377    /// assert_eq!(lu_t.is_left_orthogonal(), !lu.is_left_orthogonal());
378    /// ```
379    pub fn transpose(&self) -> RrLU<T> {
380        RrLU {
381            row_permutation: self.col_permutation.clone(),
382            col_permutation: self.row_permutation.clone(),
383            l: transpose(&self.u),
384            u: transpose(&self.l),
385            left_orthogonal: !self.left_orthogonal,
386            n_pivot: self.n_pivot,
387            error: self.error,
388        }
389    }
390
391    /// Check if left-orthogonal (L has 1s on diagonal)
392    ///
393    /// # Examples
394    ///
395    /// ```
396    /// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu, RrLUOptions};
397    ///
398    /// let m = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
399    ///
400    /// let lu = rrlu(&m, None).unwrap();
401    /// assert!(lu.is_left_orthogonal()); // default
402    ///
403    /// let lu2 = rrlu(&m, Some(RrLUOptions {
404    ///     left_orthogonal: false, ..Default::default()
405    /// })).unwrap();
406    /// assert!(!lu2.is_left_orthogonal());
407    /// ```
408    pub fn is_left_orthogonal(&self) -> bool {
409        self.left_orthogonal
410    }
411}
412
/// Options for rank-revealing LU decomposition.
///
/// # Examples
///
/// ```
/// use tensor4all_tcicore::RrLUOptions;
///
/// // Default: rel_tol = 1e-14, no absolute tolerance, no rank limit
/// let opts = RrLUOptions::default();
/// assert_eq!(opts.rel_tol, 1e-14);
/// assert_eq!(opts.abs_tol, 0.0);
/// assert!(opts.left_orthogonal);
///
/// // Limit rank to 5
/// let opts = RrLUOptions { max_rank: 5, ..Default::default() };
/// assert_eq!(opts.max_rank, 5);
/// ```
#[derive(Debug, Clone)]
pub struct RrLUOptions {
    /// Maximum number of pivots (rank cap); `usize::MAX` means no limit.
    pub max_rank: usize,
    /// Relative tolerance: pivoting stops once a pivot magnitude drops below
    /// `rel_tol` times the largest pivot seen so far.
    pub rel_tol: f64,
    /// Absolute tolerance: pivoting stops once a pivot magnitude drops below
    /// this value (0.0 disables the check).
    pub abs_tol: f64,
    /// If true, `L` has 1s on its diagonal; otherwise `U` does.
    pub left_orthogonal: bool,
}
441
442impl Default for RrLUOptions {
443    fn default() -> Self {
444        Self {
445            max_rank: usize::MAX,
446            rel_tol: 1e-14,
447            abs_tol: 0.0,
448            left_orthogonal: true,
449        }
450    }
451}
452
453/// Perform in-place rank-revealing LU decomposition.
454///
455/// The input matrix `a` is modified in place. Use [`rrlu`] for a
456/// non-destructive version.
457///
458/// # Errors
459///
460/// Returns [`MatrixCIError::NaNEncountered`]
461/// if NaN values appear in the L or U factors.
462///
463/// # Examples
464///
465/// ```
466/// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu_inplace, RrLUOptions};
467///
468/// let mut m = from_vec2d(vec![
469///     vec![1.0_f64, 2.0],
470///     vec![3.0, 4.0],
471/// ]);
472/// let lu = rrlu_inplace(&mut m, Some(RrLUOptions { max_rank: 1, ..Default::default() })).unwrap();
473/// assert_eq!(lu.npivots(), 1);
474/// ```
475pub fn rrlu_inplace<T: Scalar>(a: &mut Matrix<T>, options: Option<RrLUOptions>) -> Result<RrLU<T>> {
476    let opts = options.unwrap_or_default();
477    let nr = nrows(a);
478    let nc = ncols(a);
479
480    let mut lu = RrLU::new(nr, nc, opts.left_orthogonal);
481    let max_rank = opts.max_rank.min(nr).min(nc);
482    let mut max_error = 0.0f64;
483
484    while lu.n_pivot < max_rank {
485        let k = lu.n_pivot;
486
487        if k >= nr || k >= nc {
488            break;
489        }
490
491        // Find pivot with maximum absolute value in submatrix
492        let (pivot_row, pivot_col, pivot_val) = submatrix_argmax(a, k..nr, k..nc);
493
494        let pivot_abs = f64::sqrt(pivot_val.abs_sq());
495        lu.error = pivot_abs;
496
497        // Check stopping criteria (but add at least 1 pivot)
498        if lu.n_pivot > 0 && (pivot_abs < opts.rel_tol * max_error || pivot_abs < opts.abs_tol) {
499            break;
500        }
501
502        // Guard against near-zero pivot to prevent NaN from division
503        if pivot_abs < f64::EPSILON {
504            if lu.n_pivot == 0 {
505                // First pivot is near-zero: the matrix is effectively zero
506                lu.error = pivot_abs;
507            }
508            break;
509        }
510
511        max_error = max_error.max(pivot_abs);
512
513        // Swap rows and columns
514        if pivot_row != k {
515            swap_rows(a, k, pivot_row);
516            lu.row_permutation.swap(k, pivot_row);
517        }
518        if pivot_col != k {
519            swap_cols(a, k, pivot_col);
520            lu.col_permutation.swap(k, pivot_col);
521        }
522
523        let pivot = a[[k, k]];
524
525        // Eliminate
526        if opts.left_orthogonal {
527            // Scale column below pivot
528            for i in (k + 1)..nr {
529                let val = a[[i, k]] / pivot;
530                a[[i, k]] = val;
531            }
532        } else {
533            // Scale row to the right of pivot
534            for j in (k + 1)..nc {
535                let val = a[[k, j]] / pivot;
536                a[[k, j]] = val;
537            }
538        }
539
540        // Update submatrix: A[k+1:, k+1:] -= A[k+1:, k] * A[k, k+1:]
541        for i in (k + 1)..nr {
542            for j in (k + 1)..nc {
543                let x = a[[i, k]];
544                let y = a[[k, j]];
545                let old = a[[i, j]];
546                a[[i, j]] = old - x * y;
547            }
548        }
549
550        lu.n_pivot += 1;
551    }
552
553    // Extract L and U
554    let n = lu.n_pivot;
555
556    // L is lower triangular part
557    let mut l = zeros(nr, n);
558    for i in 0..nr {
559        for j in 0..n.min(i + 1) {
560            l[[i, j]] = a[[i, j]];
561        }
562    }
563
564    // U is upper triangular part
565    let mut u = zeros(n, nc);
566    for i in 0..n {
567        for j in i..nc {
568            u[[i, j]] = a[[i, j]];
569        }
570    }
571
572    // Set diagonal to 1 for the orthogonal factor
573    if opts.left_orthogonal {
574        for i in 0..n {
575            l[[i, i]] = T::one();
576        }
577    } else {
578        for i in 0..n {
579            u[[i, i]] = T::one();
580        }
581    }
582
583    // Check for NaNs (return error instead of panicking)
584    for i in 0..nrows(&l) {
585        for j in 0..ncols(&l) {
586            if l[[i, j]].is_nan() {
587                return Err(MatrixCIError::NaNEncountered {
588                    matrix: "L".to_string(),
589                });
590            }
591        }
592    }
593    for i in 0..nrows(&u) {
594        for j in 0..ncols(&u) {
595            if u[[i, j]].is_nan() {
596                return Err(MatrixCIError::NaNEncountered {
597                    matrix: "U".to_string(),
598                });
599            }
600        }
601    }
602
603    // Set error to 0 if full rank
604    if n >= nr.min(nc) {
605        lu.error = 0.0;
606    }
607
608    lu.l = l;
609    lu.u = u;
610
611    Ok(lu)
612}
613
614/// Perform rank-revealing LU decomposition (non-destructive).
615///
616/// Clones the input matrix and calls [`rrlu_inplace`].
617///
618/// # Errors
619///
620/// Returns [`MatrixCIError::NaNEncountered`]
621/// if NaN values appear in the L or U factors.
622///
623/// # Examples
624///
625/// ```
626/// use tensor4all_tcicore::{from_vec2d, matrixlu::rrlu};
627///
628/// let m = from_vec2d(vec![
629///     vec![1.0_f64, 0.0],
630///     vec![0.0, 2.0],
631/// ]);
632/// let lu = rrlu(&m, None).unwrap();
633/// assert_eq!(lu.npivots(), 2);
634/// assert_eq!(lu.nrows(), 2);
635/// assert_eq!(lu.ncols(), 2);
636/// ```
637pub fn rrlu<T: Scalar>(a: &Matrix<T>, options: Option<RrLUOptions>) -> Result<RrLU<T>> {
638    let mut a_copy = a.clone();
639    rrlu_inplace(&mut a_copy, options)
640}
641
642/// Convert L matrix to solve L * X = B given pivot matrix P
643///
644/// Modifies `c` in place so that the columns satisfy the triangular
645/// system defined by `p`. The matrix `c` must have at least `nrows(p)`
646/// columns.
647///
648/// # Examples
649///
650/// ```
651/// use tensor4all_tcicore::{from_vec2d, matrixlu::cols_to_l_matrix};
652///
653/// // Upper-triangular P (2x2)
654/// let p = from_vec2d(vec![
655///     vec![2.0_f64, 1.0],
656///     vec![0.0, 3.0],
657/// ]);
658/// // c has 3 rows, 2 columns (ncols >= nrows(p))
659/// let mut c = from_vec2d(vec![
660///     vec![4.0_f64, 5.0],
661///     vec![6.0, 9.0],
662///     vec![8.0, 7.0],
663/// ]);
664/// cols_to_l_matrix(&mut c, &p, true);
665/// // After processing: c[:,0] was divided by p[0,0]=2
666/// assert!((c[[0, 0]] - 2.0).abs() < 1e-10);
667/// assert!((c[[1, 0]] - 3.0).abs() < 1e-10);
668/// assert!((c[[2, 0]] - 4.0).abs() < 1e-10);
669/// ```
670pub fn cols_to_l_matrix<T: Scalar>(c: &mut Matrix<T>, p: &Matrix<T>, _left_orthogonal: bool) {
671    let n = nrows(p);
672
673    for k in 0..n {
674        let pivot = p[[k, k]];
675        // c[:, k] /= pivot
676        for i in 0..nrows(c) {
677            let val = c[[i, k]] / pivot;
678            c[[i, k]] = val;
679        }
680
681        // c[:, k+1:] -= c[:, k] * p[k, k+1:]
682        for j in (k + 1)..ncols(c) {
683            let p_kj = p[[k, j]];
684            for i in 0..nrows(c) {
685                let c_ik = c[[i, k]];
686                let old = c[[i, j]];
687                c[[i, j]] = old - c_ik * p_kj;
688            }
689        }
690    }
691}
692
693/// Convert R matrix to solve X * U = B given pivot matrix P
694///
695/// Modifies `r` in place so that the rows satisfy the triangular
696/// system defined by `p`. The matrix `r` must have at least `nrows(p)`
697/// rows.
698///
699/// # Examples
700///
701/// ```
702/// use tensor4all_tcicore::{from_vec2d, matrixlu::rows_to_u_matrix};
703///
704/// // Lower-triangular P (2x2)
705/// let p = from_vec2d(vec![
706///     vec![2.0_f64, 0.0],
707///     vec![1.0, 3.0],
708/// ]);
709/// // r has 2 rows (nrows >= nrows(p)), 3 columns
710/// let mut r = from_vec2d(vec![
711///     vec![4.0_f64, 6.0, 8.0],
712///     vec![5.0, 9.0, 7.0],
713/// ]);
714/// rows_to_u_matrix(&mut r, &p, true);
715/// // After processing: r[0,:] was divided by p[0,0]=2
716/// assert!((r[[0, 0]] - 2.0).abs() < 1e-10);
717/// assert!((r[[0, 1]] - 3.0).abs() < 1e-10);
718/// assert!((r[[0, 2]] - 4.0).abs() < 1e-10);
719/// ```
720pub fn rows_to_u_matrix<T: Scalar>(r: &mut Matrix<T>, p: &Matrix<T>, _left_orthogonal: bool) {
721    let n = nrows(p);
722
723    for k in 0..n {
724        let pivot = p[[k, k]];
725        // r[k, :] /= pivot
726        for j in 0..ncols(r) {
727            let val = r[[k, j]] / pivot;
728            r[[k, j]] = val;
729        }
730
731        // r[k+1:, :] -= p[k+1:, k] * r[k, :]
732        for i in (k + 1)..nrows(r) {
733            let p_ik = p[[i, k]];
734            for j in 0..ncols(r) {
735                let r_kj = r[[k, j]];
736                let old = r[[i, j]];
737                r[[i, j]] = old - p_ik * r_kj;
738            }
739        }
740    }
741}
742
743/// Solve LU * x = b
744///
745/// Given factors `L` and `U` such that `A = L * U`, solves `A * x = b`
746/// via forward and back substitution.
747///
748/// # Examples
749///
750/// ```
751/// use tensor4all_tcicore::{from_vec2d, matrixlu::{rrlu, solve_lu}};
752///
753/// let m = from_vec2d(vec![vec![2.0_f64, 1.0], vec![1.0, 3.0]]);
754/// let lu = rrlu(&m, None).unwrap();
755/// let l = lu.left(false);
756/// let u = lu.right(false);
757///
758/// // b = [5, 10] => solve Ax = b in permuted coordinates
759/// let b = from_vec2d(vec![
760///     vec![m[[lu.row_permutation()[0], 0]] * 1.0 + m[[lu.row_permutation()[0], 1]] * 2.0],
761///     vec![m[[lu.row_permutation()[1], 0]] * 1.0 + m[[lu.row_permutation()[1], 1]] * 2.0],
762/// ]);
763/// let x = solve_lu(&l, &u, &b).unwrap();
764/// // Solution in permuted col order
765/// assert!((x[[lu.col_permutation().iter().position(|&c| c == 0).unwrap(), 0]] - 1.0).abs() < 1e-10);
766/// assert!((x[[lu.col_permutation().iter().position(|&c| c == 1).unwrap(), 0]] - 2.0).abs() < 1e-10);
767/// ```
768pub fn solve_lu<T: Scalar>(l: &Matrix<T>, u: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>> {
769    let _n1 = nrows(l);
770    let n2 = ncols(l);
771    let n3 = ncols(u);
772    let m = ncols(b);
773
774    // Solve L * y = b (forward substitution)
775    let mut y: Matrix<T> = zeros(n2, m);
776    for i in 0..n2 {
777        for k in 0..m {
778            let mut sum = b[[i, k]];
779            for j in 0..i {
780                sum = sum - l[[i, j]] * y[[j, k]];
781            }
782            let diag = l[[i, i]];
783            y[[i, k]] = sum / diag;
784        }
785    }
786
787    // Solve U * x = y (back substitution)
788    let mut x: Matrix<T> = zeros(n3, m);
789    for i in (0..n3).rev() {
790        for k in 0..m {
791            let mut sum = y[[i, k]];
792            for j in (i + 1)..n3 {
793                sum = sum - u[[i, j]] * x[[j, k]];
794            }
795            let diag = u[[i, i]];
796            x[[i, k]] = sum / diag;
797        }
798    }
799
800    Ok(x)
801}
802
// Unit tests live in a sibling file (`tests.rs` in this module's directory).
#[cfg(test)]
mod tests;