tensor4all_core/defaults/
factorize.rs

1//! Unified tensor factorization module.
2//!
3//! This module provides a unified `factorize()` function that dispatches to
4//! SVD, QR, LU, or CI (Cross Interpolation) algorithms based on options.
5//!
6//! # Note
7//!
8//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
9//! Generic tensor types are not supported.
10//!
11//! # Example
12//!
13//! ```ignore
14//! use tensor4all_core::{factorize, FactorizeOptions, FactorizeAlg, Canonical};
15//!
16//! let result = factorize(&tensor, &left_inds, &FactorizeOptions::default())?;
17//! // result.left * result.right ≈ tensor
18//! ```
19
20use crate::defaults::DynIndex;
21use crate::{unfold_split, TensorDynLen};
22use num_complex::{Complex64, ComplexFloat};
23use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar};
24use tensor4all_tensorbackend::TensorElement;
25
26use crate::qr::{qr_with, QrOptions};
27use crate::svd::{svd_for_factorize, SvdOptions};
28
29// Re-export types from tensor_like for backwards compatibility
30pub use crate::tensor_like::{
31    Canonical, FactorizeAlg, FactorizeError, FactorizeOptions, FactorizeResult,
32};
33
34/// Factorize a tensor into left and right factors.
35///
36/// This function dispatches to the appropriate algorithm based on `options.alg`:
37/// - `SVD`: Singular Value Decomposition
38/// - `QR`: QR decomposition
39/// - `LU`: Rank-revealing LU decomposition
40/// - `CI`: Cross Interpolation
41///
42/// The `canonical` option controls which factor is "canonical":
43/// - `Canonical::Left`: Left factor is orthogonal (SVD/QR) or unit-diagonal (LU/CI)
44/// - `Canonical::Right`: Right factor is orthogonal (SVD) or unit-diagonal (LU/CI)
45///
46/// # Arguments
47/// * `t` - Input tensor
48/// * `left_inds` - Indices to place on the left side
49/// * `options` - Factorization options
50///
51/// # Returns
52/// A `FactorizeResult` containing the left and right factors, bond index,
53/// singular values (for SVD), and rank.
54///
55/// # Errors
56/// Returns `FactorizeError` if:
57/// - The storage type is not supported (only DenseF64 and DenseC64)
58/// - QR is used with `Canonical::Right`
59/// - The underlying algorithm fails
60pub fn factorize(
61    t: &TensorDynLen,
62    left_inds: &[DynIndex],
63    options: &FactorizeOptions,
64) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
65    if t.is_diag() {
66        return Err(FactorizeError::UnsupportedStorage(
67            "Diagonal storage not supported for factorize",
68        ));
69    }
70
71    if t.is_f64() {
72        factorize_impl_f64(t, left_inds, options)
73    } else if t.is_complex() {
74        factorize_impl_c64(t, left_inds, options)
75    } else {
76        Err(FactorizeError::UnsupportedStorage(
77            "factorize currently supports only f64 and Complex64 tensors",
78        ))
79    }
80}
81
82fn factorize_impl_f64(
83    t: &TensorDynLen,
84    left_inds: &[DynIndex],
85    options: &FactorizeOptions,
86) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
87    match options.alg {
88        FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
89        FactorizeAlg::QR => factorize_qr(t, left_inds, options),
90        FactorizeAlg::LU => factorize_lu::<f64>(t, left_inds, options),
91        FactorizeAlg::CI => factorize_ci::<f64>(t, left_inds, options),
92    }
93}
94
95fn factorize_impl_c64(
96    t: &TensorDynLen,
97    left_inds: &[DynIndex],
98    options: &FactorizeOptions,
99) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
100    match options.alg {
101        FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
102        FactorizeAlg::QR => factorize_qr(t, left_inds, options),
103        FactorizeAlg::LU => factorize_lu::<Complex64>(t, left_inds, options),
104        FactorizeAlg::CI => factorize_ci::<Complex64>(t, left_inds, options),
105    }
106}
107
108/// SVD factorization implementation.
109fn factorize_svd(
110    t: &TensorDynLen,
111    left_inds: &[DynIndex],
112    options: &FactorizeOptions,
113) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
114    let mut svd_options = SvdOptions::default();
115    if let Some(rtol) = options.rtol {
116        svd_options.truncation.rtol = Some(rtol);
117    }
118    if let Some(max_rank) = options.max_rank {
119        svd_options.truncation.max_rank = Some(max_rank);
120    }
121
122    let result = svd_for_factorize(t, left_inds, &svd_options)?;
123    let u = result.u;
124    let s = result.s;
125    let vh = result.vh;
126    let bond_index = result.bond_index;
127    let singular_values = result.singular_values;
128    let rank = result.rank;
129    let sim_bond_index = s.indices[1].clone();
130
131    match options.canonical {
132        Canonical::Left => {
133            // L = U (orthogonal), R = S * V^H
134            let right_contracted = s.contract(&vh);
135            let right = right_contracted.replaceind(&sim_bond_index, &bond_index);
136            Ok(FactorizeResult {
137                left: u,
138                right,
139                bond_index,
140                singular_values: Some(singular_values),
141                rank,
142            })
143        }
144        Canonical::Right => {
145            // L = U * S, R = V^H
146            let left_contracted = u.contract(&s);
147            let left = left_contracted.replaceind(&sim_bond_index, &bond_index);
148            Ok(FactorizeResult {
149                left,
150                right: vh,
151                bond_index,
152                singular_values: Some(singular_values),
153                rank,
154            })
155        }
156    }
157}
158
/// QR factorization implementation.
///
/// Produces `t ≈ Q * R` with `Q` carrying orthonormal columns, so only
/// `Canonical::Left` is supported (a right-canonical form would need LQ).
///
/// NOTE(review): `qr_with` is called with a hard-coded `f64` type parameter
/// even when this function is reached from the Complex64 dispatch path
/// (`factorize_impl_c64`), unlike the LU/CI paths which dispatch on the
/// element type. Confirm that `qr_with`'s type parameter does not affect
/// complex inputs, or add element-type dispatch here.
///
/// NOTE(review): `options.max_rank` is silently ignored on this path; only
/// `rtol` is forwarded into `QrOptions`. Verify whether rank truncation for
/// QR should be supported or explicitly documented as unsupported.
fn factorize_qr(
    t: &TensorDynLen,
    left_inds: &[DynIndex],
    options: &FactorizeOptions,
) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
    // QR yields an orthogonal *left* factor only; reject right-canonical requests.
    if options.canonical == Canonical::Right {
        return Err(FactorizeError::UnsupportedCanonical(
            "QR only supports Canonical::Left (would need LQ for right)",
        ));
    }

    // Forward the caller-supplied relative tolerance, if any.
    let mut qr_options = QrOptions::default();
    if let Some(rtol) = options.rtol {
        qr_options.truncation.rtol = Some(rtol);
    }

    let (q, r) = qr_with::<f64>(t, left_inds, &qr_options)?;

    // Get bond index from Q tensor (last index)
    let bond_index = q.indices.last().unwrap().clone();
    // Rank is the last dimension of Q
    let q_dims = q.dims();
    let rank = *q_dims.last().unwrap();

    Ok(FactorizeResult {
        left: q,
        right: r,
        bond_index,
        // QR produces no singular values; spectrum information is unavailable.
        singular_values: None,
        rank,
    })
}
192
193/// LU factorization implementation.
194fn factorize_lu<T>(
195    t: &TensorDynLen,
196    left_inds: &[DynIndex],
197    options: &FactorizeOptions,
198) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
199where
200    T: TensorElement
201        + ComplexFloat
202        + Default
203        + From<<T as ComplexFloat>::Real>
204        + MatrixScalar
205        + tensor4all_tcicore::MatrixLuciScalar
206        + 'static,
207    <T as ComplexFloat>::Real: Into<f64> + 'static,
208    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
209{
210    // Unfold tensor into matrix
211    let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
212        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
213
214    // Convert to Matrix type for rrlu
215    let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;
216
217    // Set up LU options
218    let left_orthogonal = options.canonical == Canonical::Left;
219    let lu_options = RrLUOptions {
220        max_rank: options.max_rank.unwrap_or(usize::MAX),
221        rel_tol: options.rtol.unwrap_or(1e-14),
222        abs_tol: 0.0,
223        left_orthogonal,
224    };
225
226    // Perform LU decomposition
227    let lu = rrlu(&a_matrix, Some(lu_options))?;
228    let rank = lu.npivots();
229
230    // Extract L and U matrices (permuted)
231    let l_matrix = lu.left(true);
232    let u_matrix = lu.right(true);
233
234    // Create bond index
235    let bond_index = DynIndex::new_bond(rank)
236        .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
237
238    // Convert L matrix back to tensor
239    let l_vec = matrix_to_vec(&l_matrix);
240    let mut l_indices = left_indices.clone();
241    l_indices.push(bond_index.clone());
242    let left =
243        TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;
244
245    // Convert U matrix back to tensor
246    let u_vec = matrix_to_vec(&u_matrix);
247    let mut r_indices = vec![bond_index.clone()];
248    r_indices.extend_from_slice(&right_indices);
249    let right =
250        TensorDynLen::from_dense(r_indices, u_vec).map_err(FactorizeError::ComputationError)?;
251
252    Ok(FactorizeResult {
253        left,
254        right,
255        bond_index,
256        singular_values: None,
257        rank,
258    })
259}
260
261/// CI (Cross Interpolation) factorization implementation.
262fn factorize_ci<T>(
263    t: &TensorDynLen,
264    left_inds: &[DynIndex],
265    options: &FactorizeOptions,
266) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
267where
268    T: TensorElement
269        + ComplexFloat
270        + Default
271        + From<<T as ComplexFloat>::Real>
272        + MatrixScalar
273        + tensor4all_tcicore::MatrixLuciScalar
274        + 'static,
275    <T as ComplexFloat>::Real: Into<f64> + 'static,
276    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
277{
278    // Unfold tensor into matrix
279    let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
280        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
281
282    // Convert to Matrix type for MatrixLUCI
283    let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;
284
285    // Set up LU options for CI
286    let left_orthogonal = options.canonical == Canonical::Left;
287    let lu_options = RrLUOptions {
288        max_rank: options.max_rank.unwrap_or(usize::MAX),
289        rel_tol: options.rtol.unwrap_or(1e-14),
290        abs_tol: 0.0,
291        left_orthogonal,
292    };
293
294    // Perform CI decomposition
295    let ci = MatrixLUCI::from_matrix(&a_matrix, Some(lu_options))?;
296    let rank = ci.rank();
297
298    // Get left and right matrices from CI
299    let l_matrix = ci.left();
300    let r_matrix = ci.right();
301
302    // Create bond index
303    let bond_index = DynIndex::new_bond(rank)
304        .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
305
306    // Convert L matrix back to tensor
307    let l_vec = matrix_to_vec(&l_matrix);
308    let mut l_indices = left_indices.clone();
309    l_indices.push(bond_index.clone());
310    let left =
311        TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;
312
313    // Convert R matrix back to tensor
314    let r_vec = matrix_to_vec(&r_matrix);
315    let mut r_indices = vec![bond_index.clone()];
316    r_indices.extend_from_slice(&right_indices);
317    let right =
318        TensorDynLen::from_dense(r_indices, r_vec).map_err(FactorizeError::ComputationError)?;
319
320    Ok(FactorizeResult {
321        left,
322        right,
323        bond_index,
324        singular_values: None,
325        rank,
326    })
327}
328
329/// Convert a native rank-2 tensor into a `tensor4all_tcicore::Matrix`.
330fn native_tensor_to_matrix<T>(
331    tensor: &tenferro::Tensor,
332    m: usize,
333    n: usize,
334) -> Result<tensor4all_tcicore::Matrix<T>, FactorizeError>
335where
336    T: TensorElement + MatrixScalar + Copy,
337{
338    let data = T::dense_values_from_native_col_major(tensor).map_err(|e| {
339        FactorizeError::ComputationError(anyhow::anyhow!(
340            "failed to extract dense matrix entries from native tensor: {e}"
341        ))
342    })?;
343    if data.len() != m * n {
344        return Err(FactorizeError::ComputationError(anyhow::anyhow!(
345            "native matrix materialization produced {} entries for shape ({m}, {n})",
346            data.len()
347        )));
348    }
349
350    let mut matrix = tensor4all_tcicore::matrix::zeros(m, n);
351    for i in 0..m {
352        for j in 0..n {
353            matrix[[i, j]] = data[j * m + i];
354        }
355    }
356    Ok(matrix)
357}
358
359/// Convert Matrix to Vec for storage.
360fn matrix_to_vec<T>(matrix: &tensor4all_tcicore::Matrix<T>) -> Vec<T>
361where
362    T: Clone,
363{
364    let m = tensor4all_tcicore::matrix::nrows(matrix);
365    let n = tensor4all_tcicore::matrix::ncols(matrix);
366    let mut vec = Vec::with_capacity(m * n);
367    for j in 0..n {
368        for i in 0..m {
369            vec.push(matrix[[i, j]].clone());
370        }
371    }
372    vec
373}
374
375#[cfg(test)]
376mod tests;