Skip to main content

tensor4all_simplett/mpo/
factorize.rs

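//! Factorization methods for MPO tensors
//!
//! This module provides matrix factorizations (SVD, LU, CI) used to compress
//! and reshape MPO tensors. [`factorize`] dispatches on [`FactorizeMethod`]:
//! SVD is fully supported, RSVD currently returns an error, and LU/CI fall back
//! to SVD there — call [`factorize_lu`] or [`factorize_ci`] directly to get the
//! pivot-based factorizations.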
5
6use super::error::{MPOError, Result};
7use super::Matrix2;
8use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};
9use num_complex::{Complex64, ComplexFloat};
10use tenferro_tensor::{TensorScalar, TypedTensor};
11use tensor4all_tensorbackend::{svd_backend, BackendLinalgScalar};
12
/// Algorithm used to split a matrix into a left and a right factor.
///
/// Note that [`factorize`] only runs [`FactorizeMethod::SVD`] itself:
/// [`FactorizeMethod::RSVD`] currently produces an error, and
/// [`FactorizeMethod::LU`] / [`FactorizeMethod::CI`] fall back to SVD there.
/// Use [`factorize_lu`] or [`factorize_ci`] for the pivot-based variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FactorizeMethod {
    /// Full singular value decomposition (the default method).
    #[default]
    SVD,
    /// Randomized SVD, intended for large matrices; not implemented yet.
    RSVD,
    /// Rank-revealing LU decomposition with pivoting.
    LU,
    /// Cross interpolation (pivot-based matrix CI).
    CI,
}
26
27/// Options for factorization
28#[derive(Debug, Clone)]
29pub struct FactorizeOptions {
30    /// Factorization method to use
31    pub method: FactorizeMethod,
32    /// Tolerance for truncation
33    pub tolerance: f64,
34    /// Maximum rank (bond dimension) after factorization
35    pub max_rank: usize,
36    /// Whether to return the left factor as left-orthogonal
37    pub left_orthogonal: bool,
38    /// Number of random projections for RSVD (parameter q)
39    pub rsvd_q: usize,
40    /// Oversampling parameter for RSVD (parameter p)
41    pub rsvd_p: usize,
42}
43
44impl Default for FactorizeOptions {
45    fn default() -> Self {
46        Self {
47            method: FactorizeMethod::SVD,
48            tolerance: 1e-12,
49            max_rank: usize::MAX,
50            left_orthogonal: true,
51            rsvd_q: 2,
52            rsvd_p: 10,
53        }
54    }
55}
56
57/// Result of factorization
58#[derive(Debug, Clone)]
59pub struct FactorizeResult<T: TensorScalar> {
60    /// Left factor matrix (m x rank)
61    pub left: Matrix2<T>,
62    /// Right factor matrix (rank x n)
63    pub right: Matrix2<T>,
64    /// New rank (number of columns in left / rows in right)
65    pub rank: usize,
66    /// Discarded weight (for error estimation)
67    pub discarded: f64,
68}
69
70/// Trait bounds for SVD-compatible scalars
71pub trait SVDScalar:
72    crate::traits::TTScalar + ComplexFloat + Default + Copy + BackendLinalgScalar + 'static
73{
74    /// Convert a backend singular value into `f64` for truncation logic.
75    fn linalg_real_to_f64(real: <Self as TensorScalar>::Real) -> f64;
76    /// Promote a backend singular value into the matrix scalar type.
77    fn from_linalg_real(real: <Self as TensorScalar>::Real) -> Self;
78}
79
80impl SVDScalar for f64 {
81    fn linalg_real_to_f64(real: <Self as TensorScalar>::Real) -> f64 {
82        real
83    }
84
85    fn from_linalg_real(real: <Self as TensorScalar>::Real) -> Self {
86        real
87    }
88}
89
90impl SVDScalar for Complex64 {
91    fn linalg_real_to_f64(real: <Self as TensorScalar>::Real) -> f64 {
92        real
93    }
94
95    fn from_linalg_real(real: <Self as TensorScalar>::Real) -> Self {
96        Complex64::new(real, 0.0)
97    }
98}
99
100/// Factorize a matrix into left and right factors
101///
102/// Returns (L, R, rank, discarded) where:
103/// - L: left factor matrix (rows x rank)
104/// - R: right factor matrix (rank x cols)
105/// - rank: the resulting rank after truncation
106/// - discarded: the discarded weight (for error estimation)
107///
108/// The original matrix M ≈ L @ R
109///
110/// Note: Only SVD method is fully supported. LU and CI require additional
111/// traits and should use `factorize_lu` directly.
112pub fn factorize<T: SVDScalar>(
113    matrix: &Matrix2<T>,
114    options: &FactorizeOptions,
115) -> Result<FactorizeResult<T>> {
116    match options.method {
117        FactorizeMethod::SVD => factorize_svd(matrix, options),
118        FactorizeMethod::RSVD => factorize_rsvd(matrix, options),
119        FactorizeMethod::LU | FactorizeMethod::CI => {
120            // For LU/CI, fall back to SVD for now
121            // Full LU/CI support requires tensor4all_tcicore::Scalar trait
122            factorize_svd(matrix, options)
123        }
124    }
125}
126
127// Use the shared matrix2_zeros from the parent module
128use super::matrix2_zeros;
129
130fn matrix2_to_typed_tensor<T>(matrix: &Matrix2<T>) -> Result<TypedTensor<T>>
131where
132    T: TensorScalar,
133{
134    let dims = [matrix.dim(0), matrix.dim(1)];
135    let data: Vec<T> = matrix.iter().copied().collect();
136    Ok(typed_tensor_from_row_major_slice(&data, &dims))
137}
138
139fn typed_tensor_to_matrix2<T>(tensor: &TypedTensor<T>, op: &'static str) -> Result<Matrix2<T>>
140where
141    T: crate::traits::TTScalar + Default,
142{
143    if tensor.shape.len() != 2 {
144        return Err(MPOError::FactorizationError {
145            message: format!(
146                "{op} returned rank-{} tensor, expected matrix",
147                tensor.shape.len()
148            ),
149        });
150    }
151
152    let rows = tensor.shape[0];
153    let cols = tensor.shape[1];
154    let data = tensor_to_row_major_vec(tensor);
155
156    let mut matrix = matrix2_zeros(rows, cols);
157    for i in 0..rows {
158        for j in 0..cols {
159            matrix[[i, j]] = data[i * cols + j];
160        }
161    }
162    Ok(matrix)
163}
164
165fn typed_row_major_values<T>(tensor: &TypedTensor<T>, op: &'static str) -> Result<Vec<T>>
166where
167    T: TensorScalar,
168{
169    let _ = op;
170    Ok(tensor_to_row_major_vec(tensor))
171}
172
173/// Factorize using SVD
174fn factorize_svd<T: SVDScalar>(
175    matrix: &Matrix2<T>,
176    options: &FactorizeOptions,
177) -> Result<FactorizeResult<T>> {
178    let m = matrix.dim(0);
179    let n = matrix.dim(1);
180
181    if m == 0 || n == 0 {
182        return Err(MPOError::FactorizationError {
183            message: "Cannot factorize empty matrix".to_string(),
184        });
185    }
186
187    // Compute SVD using tensorbackend (tenferro-backed implementation)
188    let a_tensor = matrix2_to_typed_tensor(matrix)?;
189    let svd_result = svd_backend(&a_tensor).map_err(|e| MPOError::FactorizationError {
190        message: format!("SVD computation failed: {:?}", e),
191    })?;
192
193    let u = typed_tensor_to_matrix2(&svd_result.u, "svd.u")?;
194    let vt = typed_tensor_to_matrix2(&svd_result.vt, "svd.vt")?;
195    let singular_values = typed_row_major_values(&svd_result.s, "svd.s")?;
196
197    // Determine rank based on tolerance and max_rank
198    let min_dim = m.min(n);
199    let mut rank = 0;
200    let mut total_weight: f64 = 0.0;
201
202    // Sum all squared singular values for total weight
203    // Singular values are stored in first row: s[[0, i]] (LAPACK-style convention)
204    for &singular_value in singular_values.iter().take(min_dim) {
205        let sv = T::linalg_real_to_f64(singular_value);
206        total_weight += sv * sv;
207    }
208
209    // Find rank by keeping singular values above tolerance
210    let mut kept_weight: f64 = 0.0;
211    for &singular_value in singular_values.iter().take(min_dim) {
212        if rank >= options.max_rank {
213            break;
214        }
215        let sv = T::linalg_real_to_f64(singular_value);
216        if sv < options.tolerance {
217            break;
218        }
219        kept_weight += sv * sv;
220        rank += 1;
221    }
222
223    // Ensure at least rank 1
224    rank = rank.max(1);
225
226    // Calculate discarded weight
227    let discarded: f64 = if total_weight > 0.0 {
228        1.0 - kept_weight / total_weight
229    } else {
230        0.0
231    };
232
233    // Build result matrices
234    let mut left: Matrix2<T> = matrix2_zeros(m, rank);
235    let mut right: Matrix2<T> = matrix2_zeros(rank, n);
236
237    if options.left_orthogonal {
238        // Left = U[:, :rank], Right = diag(S[:rank]) * Vt[:rank, :]
239        //
240        // `svd_backend` returns `vt` in backend convention
241        // (V^T for real and V^H for complex), which is used directly here.
242        for i in 0..m {
243            for j in 0..rank {
244                left[[i, j]] = u[[i, j]];
245            }
246        }
247        for i in 0..rank {
248            // Singular values are stored in first row: s[[0, i]] (LAPACK-style convention)
249            let sv = T::from_linalg_real(singular_values[i]);
250            for j in 0..n {
251                right[[i, j]] = sv * vt[[i, j]];
252            }
253        }
254    } else {
255        // Left = U[:, :rank] * diag(S[:rank]), Right = Vt[:rank, :]
256        for i in 0..m {
257            for j in 0..rank {
258                // Singular values are stored in first row: s[[0, j]] (LAPACK-style convention)
259                let sv = T::from_linalg_real(singular_values[j]);
260                left[[i, j]] = u[[i, j]] * sv;
261            }
262        }
263        for i in 0..rank {
264            for j in 0..n {
265                right[[i, j]] = vt[[i, j]];
266            }
267        }
268    }
269
270    Ok(FactorizeResult {
271        left,
272        right,
273        rank,
274        discarded,
275    })
276}
277
278/// Factorize using randomized SVD
279fn factorize_rsvd<T: SVDScalar>(
280    _matrix: &Matrix2<T>,
281    _options: &FactorizeOptions,
282) -> Result<FactorizeResult<T>> {
283    // TODO: Implement RSVD-based factorization
284    Err(MPOError::FactorizationError {
285        message: "RSVD factorization not yet implemented".to_string(),
286    })
287}
288
289/// Factorize using LU decomposition
290///
291/// This function requires the tensor4all_tcicore::Scalar trait.
292/// Use this directly when you need LU-based factorization.
293pub fn factorize_lu<T>(
294    matrix: &Matrix2<T>,
295    options: &FactorizeOptions,
296) -> Result<FactorizeResult<T>>
297where
298    T: SVDScalar + tensor4all_tcicore::Scalar + tensor4all_tcicore::MatrixLuciScalar,
299    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
300{
301    use tensor4all_tcicore::{AbstractMatrixCI, MatrixLUCI, RrLUOptions};
302
303    let m = matrix.dim(0);
304    let n = matrix.dim(1);
305
306    // Convert Matrix2 to tensor4all_tcicore::Matrix for LU/CI factorization.
307    let mut mat_ci: tensor4all_tcicore::Matrix<T> = tensor4all_tcicore::matrix::zeros(m, n);
308    for i in 0..m {
309        for j in 0..n {
310            mat_ci[[i, j]] = matrix[[i, j]];
311        }
312    }
313
314    let lu_options = RrLUOptions {
315        max_rank: options.max_rank,
316        rel_tol: options.tolerance,
317        abs_tol: 0.0,
318        left_orthogonal: options.left_orthogonal,
319    };
320
321    let luci = MatrixLUCI::from_matrix(&mat_ci, Some(lu_options))?;
322    let left_ci = luci.left();
323    let right_ci = luci.right();
324    let rank = luci.rank().max(1);
325
326    // Convert back to Matrix2
327    let left_m = tensor4all_tcicore::matrix::nrows(&left_ci);
328    let left_n = tensor4all_tcicore::matrix::ncols(&left_ci);
329    let mut left: Matrix2<T> = matrix2_zeros(left_m, left_n);
330    for i in 0..left_m {
331        for j in 0..left_n {
332            left[[i, j]] = left_ci[[i, j]];
333        }
334    }
335
336    let right_m = tensor4all_tcicore::matrix::nrows(&right_ci);
337    let right_n = tensor4all_tcicore::matrix::ncols(&right_ci);
338    let mut right: Matrix2<T> = matrix2_zeros(right_m, right_n);
339    for i in 0..right_m {
340        for j in 0..right_n {
341            right[[i, j]] = right_ci[[i, j]];
342        }
343    }
344
345    Ok(FactorizeResult {
346        left,
347        right,
348        rank,
349        discarded: 0.0,
350    })
351}
352
353/// Factorize using Cross Interpolation
354///
355/// This function requires the tensor4all_tcicore::Scalar trait.
356/// Use this directly when you need CI-based factorization.
357pub fn factorize_ci<T>(
358    matrix: &Matrix2<T>,
359    options: &FactorizeOptions,
360) -> Result<FactorizeResult<T>>
361where
362    T: SVDScalar + tensor4all_tcicore::Scalar + tensor4all_tcicore::MatrixLuciScalar,
363    <T as ComplexFloat>::Real: Into<f64>,
364    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
365{
366    // CI uses the same LUCI implementation as LU
367    factorize_lu(matrix, options)
368}
369
370#[cfg(test)]
371mod tests;