// tensor4all_core/defaults/factorize.rs
1//! Unified tensor factorization module.
2//!
3//! This module provides a unified `factorize()` function that dispatches to
4//! SVD, QR, LU, or CI (Cross Interpolation) algorithms based on options.
5//!
6//! # Note
7//!
8//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
9//! Generic tensor types are not supported.
10//!
11//! # Example
12//!
13//! ```
14//! use tensor4all_core::{factorize, Canonical, DynIndex, FactorizeOptions, TensorDynLen};
15//!
16//! # fn main() -> anyhow::Result<()> {
17//! let i = DynIndex::new_dyn(2);
18//! let j = DynIndex::new_dyn(2);
19//! let tensor = TensorDynLen::from_dense(
20//!     vec![i.clone(), j.clone()],
21//!     vec![1.0, 0.0, 0.0, 1.0],
22//! )?;
23//! let result = factorize(
24//!     &tensor,
25//!     std::slice::from_ref(&i),
26//!     &FactorizeOptions::svd().with_canonical(Canonical::Left),
27//! )?;
28//!
29//! assert_eq!(result.rank, 2);
30//! assert_eq!(result.left.dims(), vec![2, 2]);
31//! # Ok(())
32//! # }
33//! ```
34
35use crate::defaults::DynIndex;
36use crate::{unfold_split, TensorDynLen};
37use num_complex::{Complex64, ComplexFloat};
38use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar};
39use tensor4all_tensorbackend::TensorElement;
40
41use crate::defaults::svd::svd_for_factorize;
42use crate::qr::{qr_with, QrOptions};
43use crate::svd::SvdOptions;
44
45// Re-export types from tensor_like for backwards compatibility
46pub use crate::tensor_like::{
47    Canonical, FactorizeAlg, FactorizeError, FactorizeOptions, FactorizeResult,
48};
49
50/// Factorize a tensor into left and right factors.
51///
52/// This function dispatches to the appropriate algorithm based on `options.alg`:
53/// - `SVD`: Singular Value Decomposition
54/// - `QR`: QR decomposition
55/// - `LU`: Rank-revealing LU decomposition
56/// - `CI`: Cross Interpolation
57///
58/// The `canonical` option controls which factor is "canonical":
59/// - `Canonical::Left`: Left factor is orthogonal (SVD/QR) or unit-diagonal (LU/CI)
60/// - `Canonical::Right`: Right factor is orthogonal (SVD) or unit-diagonal (LU/CI)
61///
62/// # Arguments
63/// * `t` - Input tensor
64/// * `left_inds` - Indices to place on the left side
65/// * `options` - Factorization options
66///
67/// # Returns
68/// A `FactorizeResult` containing the left and right factors, bond index,
69/// singular values (for SVD), and rank.
70///
71/// # Errors
72/// Returns `FactorizeError` if:
73/// - The storage type is not supported (only DenseF64 and DenseC64)
74/// - QR is used with `Canonical::Right`
75/// - The underlying algorithm fails
76pub fn factorize(
77    t: &TensorDynLen,
78    left_inds: &[DynIndex],
79    options: &FactorizeOptions,
80) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
81    options.validate()?;
82
83    if t.is_diag() {
84        return Err(FactorizeError::UnsupportedStorage(
85            "Diagonal storage not supported for factorize",
86        ));
87    }
88
89    if t.is_f64() {
90        factorize_impl_f64(t, left_inds, options)
91    } else if t.is_complex() {
92        factorize_impl_c64(t, left_inds, options)
93    } else {
94        Err(FactorizeError::UnsupportedStorage(
95            "factorize currently supports only f64 and Complex64 tensors",
96        ))
97    }
98}
99
100/// Factorize a tensor without applying algorithm-specific truncation options.
101///
102/// This path is intended for canonicalization and other exact tensor-network
103/// rewrites where the decomposition must preserve the represented tensor rather
104/// than obey global SVD/QR/LU rank-dropping defaults.
105///
106/// # Arguments
107/// * `t` - Input tensor.
108/// * `left_inds` - Indices to place on the left side.
109/// * `alg` - Decomposition algorithm to use.
110/// * `canonical` - Which factor should carry the canonical form.
111///
112/// # Returns
113/// A factorization whose contracted factors reconstruct `t` up to numerical
114/// roundoff, with no tolerance-based or maximum-rank truncation applied.
115///
116/// # Errors
117/// Returns [`FactorizeError`] if the storage type is unsupported, the canonical
118/// direction is unsupported for the selected algorithm, or the underlying
119/// decomposition fails.
120///
121/// # Examples
122///
123/// ```
124/// use tensor4all_core::{
125///     factorize_full_rank, Canonical, DynIndex, FactorizeAlg, TensorDynLen, TensorLike,
126/// };
127///
128/// let i = DynIndex::new_dyn(2);
129/// let j = DynIndex::new_dyn(2);
130/// let tensor = TensorDynLen::from_dense(
131///     vec![i.clone(), j.clone()],
132///     vec![1.0_f64, 0.0, 0.0, 1.0e-16],
133/// )?;
134///
135/// let result = factorize_full_rank(
136///     &tensor,
137///     std::slice::from_ref(&i),
138///     FactorizeAlg::QR,
139///     Canonical::Left,
140/// )?;
141/// let reconstructed = result.left.contract(&result.right);
142/// assert!((tensor - reconstructed).maxabs() < 1.0e-18);
143/// # Ok::<(), Box<dyn std::error::Error>>(())
144/// ```
145pub fn factorize_full_rank(
146    t: &TensorDynLen,
147    left_inds: &[DynIndex],
148    alg: FactorizeAlg,
149    canonical: Canonical,
150) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
151    if t.is_diag() {
152        return Err(FactorizeError::UnsupportedStorage(
153            "Diagonal storage not supported for factorize",
154        ));
155    }
156
157    if t.is_f64() {
158        factorize_impl_f64_full_rank(t, left_inds, alg, canonical)
159    } else if t.is_complex() {
160        factorize_impl_c64_full_rank(t, left_inds, alg, canonical)
161    } else {
162        Err(FactorizeError::UnsupportedStorage(
163            "factorize currently supports only f64 and Complex64 tensors",
164        ))
165    }
166}
167
168fn factorize_impl_f64(
169    t: &TensorDynLen,
170    left_inds: &[DynIndex],
171    options: &FactorizeOptions,
172) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
173    match options.alg {
174        FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
175        FactorizeAlg::QR => factorize_qr(t, left_inds, options),
176        FactorizeAlg::LU => factorize_lu::<f64>(t, left_inds, options),
177        FactorizeAlg::CI => factorize_ci::<f64>(t, left_inds, options),
178    }
179}
180
181fn factorize_impl_f64_full_rank(
182    t: &TensorDynLen,
183    left_inds: &[DynIndex],
184    alg: FactorizeAlg,
185    canonical: Canonical,
186) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
187    match alg {
188        FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
189        FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
190        FactorizeAlg::LU => factorize_lu_full_rank::<f64>(t, left_inds, canonical),
191        FactorizeAlg::CI => factorize_ci_full_rank::<f64>(t, left_inds, canonical),
192    }
193}
194
195fn factorize_impl_c64(
196    t: &TensorDynLen,
197    left_inds: &[DynIndex],
198    options: &FactorizeOptions,
199) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
200    match options.alg {
201        FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
202        FactorizeAlg::QR => factorize_qr(t, left_inds, options),
203        FactorizeAlg::LU => factorize_lu::<Complex64>(t, left_inds, options),
204        FactorizeAlg::CI => factorize_ci::<Complex64>(t, left_inds, options),
205    }
206}
207
208fn factorize_impl_c64_full_rank(
209    t: &TensorDynLen,
210    left_inds: &[DynIndex],
211    alg: FactorizeAlg,
212    canonical: Canonical,
213) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
214    match alg {
215        FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
216        FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
217        FactorizeAlg::LU => factorize_lu_full_rank::<Complex64>(t, left_inds, canonical),
218        FactorizeAlg::CI => factorize_ci_full_rank::<Complex64>(t, left_inds, canonical),
219    }
220}
221
222/// SVD factorization implementation.
223fn factorize_svd(
224    t: &TensorDynLen,
225    left_inds: &[DynIndex],
226    options: &FactorizeOptions,
227) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
228    let mut svd_options = SvdOptions::new();
229    if let Some(policy) = options.svd_policy {
230        svd_options = svd_options.with_policy(policy);
231    }
232    if let Some(max_rank) = options.max_rank {
233        svd_options = svd_options.with_max_rank(max_rank);
234    }
235
236    factorize_svd_with_options(t, left_inds, options.canonical, &svd_options)
237}
238
239fn factorize_svd_full_rank(
240    t: &TensorDynLen,
241    left_inds: &[DynIndex],
242    canonical: Canonical,
243) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
244    let svd_options = SvdOptions::full_rank();
245    factorize_svd_with_options(t, left_inds, canonical, &svd_options)
246}
247
248fn factorize_svd_with_options(
249    t: &TensorDynLen,
250    left_inds: &[DynIndex],
251    canonical: Canonical,
252    svd_options: &SvdOptions,
253) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
254    let result = svd_for_factorize(t, left_inds, svd_options)?;
255    let u = result.u;
256    let s = result.s;
257    let vh = result.vh;
258    let bond_index = result.bond_index;
259    let singular_values = result.singular_values;
260    let rank = result.rank;
261    let sim_bond_index = s.indices[1].clone();
262
263    match canonical {
264        Canonical::Left => {
265            // L = U (orthogonal), R = S * V^H
266            let right_contracted = s.contract(&vh);
267            let right = right_contracted.replaceind(&sim_bond_index, &bond_index);
268            Ok(FactorizeResult {
269                left: u,
270                right,
271                bond_index,
272                singular_values: Some(singular_values),
273                rank,
274            })
275        }
276        Canonical::Right => {
277            // L = U * S, R = V^H
278            let left_contracted = u.contract(&s);
279            let left = left_contracted.replaceind(&sim_bond_index, &bond_index);
280            Ok(FactorizeResult {
281                left,
282                right: vh,
283                bond_index,
284                singular_values: Some(singular_values),
285                rank,
286            })
287        }
288    }
289}
290
291/// QR factorization implementation.
292fn factorize_qr(
293    t: &TensorDynLen,
294    left_inds: &[DynIndex],
295    options: &FactorizeOptions,
296) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
297    if options.canonical == Canonical::Right {
298        return Err(FactorizeError::UnsupportedCanonical(
299            "QR only supports Canonical::Left (would need LQ for right)",
300        ));
301    }
302
303    let qr_options = if let Some(rtol) = options.qr_rtol {
304        QrOptions::new().with_rtol(rtol)
305    } else {
306        QrOptions::new()
307    };
308
309    factorize_qr_with_options(t, left_inds, &qr_options)
310}
311
312fn factorize_qr_full_rank(
313    t: &TensorDynLen,
314    left_inds: &[DynIndex],
315    canonical: Canonical,
316) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
317    if canonical == Canonical::Right {
318        return Err(FactorizeError::UnsupportedCanonical(
319            "QR only supports Canonical::Left (would need LQ for right)",
320        ));
321    }
322
323    factorize_qr_with_options(t, left_inds, &QrOptions::full_rank())
324}
325
326fn factorize_qr_with_options(
327    t: &TensorDynLen,
328    left_inds: &[DynIndex],
329    qr_options: &QrOptions,
330) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
331    let (q, r) = qr_with::<f64>(t, left_inds, qr_options)?;
332
333    // Get bond index from Q tensor (last index)
334    let bond_index = q.indices.last().unwrap().clone();
335    // Rank is the last dimension of Q
336    let q_dims = q.dims();
337    let rank = *q_dims.last().unwrap();
338
339    Ok(FactorizeResult {
340        left: q,
341        right: r,
342        bond_index,
343        singular_values: None,
344        rank,
345    })
346}
347
348/// LU factorization implementation.
349fn factorize_lu<T>(
350    t: &TensorDynLen,
351    left_inds: &[DynIndex],
352    options: &FactorizeOptions,
353) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
354where
355    T: TensorElement
356        + ComplexFloat
357        + Default
358        + From<<T as ComplexFloat>::Real>
359        + MatrixScalar
360        + tensor4all_tcicore::MatrixLuciScalar
361        + 'static,
362    <T as ComplexFloat>::Real: Into<f64> + 'static,
363    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
364{
365    factorize_lu_with_options::<T>(
366        t,
367        left_inds,
368        options.canonical,
369        options.max_rank.unwrap_or(usize::MAX),
370        1e-14,
371    )
372}
373
374fn factorize_lu_full_rank<T>(
375    t: &TensorDynLen,
376    left_inds: &[DynIndex],
377    canonical: Canonical,
378) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
379where
380    T: TensorElement
381        + ComplexFloat
382        + Default
383        + From<<T as ComplexFloat>::Real>
384        + MatrixScalar
385        + tensor4all_tcicore::MatrixLuciScalar
386        + 'static,
387    <T as ComplexFloat>::Real: Into<f64> + 'static,
388    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
389{
390    factorize_lu_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
391}
392
/// Shared LU driver used by both the options-based and full-rank entry points.
///
/// Unfolds `t` into an `m`-by-`n` matrix across the `left_inds` / remaining
/// split, runs a rank-revealing LU with the given truncation parameters, and
/// folds the L and U factors back into tensors joined by a fresh bond index
/// of dimension `rank`.
fn factorize_lu_with_options<T>(
    t: &TensorDynLen,
    left_inds: &[DynIndex],
    canonical: Canonical,
    max_rank: usize,
    rel_tol: f64,
) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
where
    T: TensorElement
        + ComplexFloat
        + Default
        + From<<T as ComplexFloat>::Real>
        + MatrixScalar
        + tensor4all_tcicore::MatrixLuciScalar
        + 'static,
    <T as ComplexFloat>::Real: Into<f64> + 'static,
    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
{
    // Unfold tensor into matrix: left_inds become rows, the rest columns.
    let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;

    // Convert to Matrix type for rrlu
    let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;

    // Set up LU options. `left_orthogonal` selects which factor carries the
    // canonical (unit-diagonal) form; abs_tol = 0 makes truncation purely
    // relative.
    let left_orthogonal = canonical == Canonical::Left;
    let lu_options = RrLUOptions {
        max_rank,
        rel_tol,
        abs_tol: 0.0,
        left_orthogonal,
    };

    // Perform LU decomposition; the number of accepted pivots is the rank.
    let lu = rrlu(&a_matrix, Some(lu_options))?;
    let rank = lu.npivots();

    // Extract L and U matrices (the `true` flag requests the permuted form).
    let l_matrix = lu.left(true);
    let u_matrix = lu.right(true);

    // Create the bond index of dimension `rank` joining the two factors.
    let bond_index = DynIndex::new_bond(rank)
        .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;

    // Convert L matrix back to a tensor with indices (left..., bond).
    let l_vec = matrix_to_vec(&l_matrix);
    let mut l_indices = left_indices.clone();
    l_indices.push(bond_index.clone());
    let left =
        TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;

    // Convert U matrix back to a tensor with indices (bond, right...).
    let u_vec = matrix_to_vec(&u_matrix);
    let mut r_indices = vec![bond_index.clone()];
    r_indices.extend_from_slice(&right_indices);
    let right =
        TensorDynLen::from_dense(r_indices, u_vec).map_err(FactorizeError::ComputationError)?;

    Ok(FactorizeResult {
        left,
        right,
        bond_index,
        singular_values: None, // LU provides no singular values
        rank,
    })
}
461
462/// CI (Cross Interpolation) factorization implementation.
463fn factorize_ci<T>(
464    t: &TensorDynLen,
465    left_inds: &[DynIndex],
466    options: &FactorizeOptions,
467) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
468where
469    T: TensorElement
470        + ComplexFloat
471        + Default
472        + From<<T as ComplexFloat>::Real>
473        + MatrixScalar
474        + tensor4all_tcicore::MatrixLuciScalar
475        + 'static,
476    <T as ComplexFloat>::Real: Into<f64> + 'static,
477    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
478{
479    factorize_ci_with_options::<T>(
480        t,
481        left_inds,
482        options.canonical,
483        options.max_rank.unwrap_or(usize::MAX),
484        1e-14,
485    )
486}
487
488fn factorize_ci_full_rank<T>(
489    t: &TensorDynLen,
490    left_inds: &[DynIndex],
491    canonical: Canonical,
492) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
493where
494    T: TensorElement
495        + ComplexFloat
496        + Default
497        + From<<T as ComplexFloat>::Real>
498        + MatrixScalar
499        + tensor4all_tcicore::MatrixLuciScalar
500        + 'static,
501    <T as ComplexFloat>::Real: Into<f64> + 'static,
502    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
503{
504    factorize_ci_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
505}
506
/// Shared CI driver used by both the options-based and full-rank entry points.
///
/// Unfolds `t` into an `m`-by-`n` matrix across the `left_inds` / remaining
/// split, runs LU-based cross interpolation with the given truncation
/// parameters, and folds the interpolation factors back into tensors joined
/// by a fresh bond index of dimension `rank`.
fn factorize_ci_with_options<T>(
    t: &TensorDynLen,
    left_inds: &[DynIndex],
    canonical: Canonical,
    max_rank: usize,
    rel_tol: f64,
) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
where
    T: TensorElement
        + ComplexFloat
        + Default
        + From<<T as ComplexFloat>::Real>
        + MatrixScalar
        + tensor4all_tcicore::MatrixLuciScalar
        + 'static,
    <T as ComplexFloat>::Real: Into<f64> + 'static,
    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
{
    // Unfold tensor into matrix: left_inds become rows, the rest columns.
    let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;

    // Convert to Matrix type for MatrixLUCI
    let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;

    // Set up LU options for CI. `left_orthogonal` selects which factor
    // carries the canonical form; abs_tol = 0 makes truncation purely
    // relative.
    let left_orthogonal = canonical == Canonical::Left;
    let lu_options = RrLUOptions {
        max_rank,
        rel_tol,
        abs_tol: 0.0,
        left_orthogonal,
    };

    // Perform CI decomposition
    let ci = MatrixLUCI::from_matrix(&a_matrix, Some(lu_options))?;
    let rank = ci.rank();

    // Get left and right interpolation matrices from CI
    let l_matrix = ci.left();
    let r_matrix = ci.right();

    // Create the bond index of dimension `rank` joining the two factors.
    let bond_index = DynIndex::new_bond(rank)
        .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;

    // Convert L matrix back to a tensor with indices (left..., bond).
    let l_vec = matrix_to_vec(&l_matrix);
    let mut l_indices = left_indices.clone();
    l_indices.push(bond_index.clone());
    let left =
        TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;

    // Convert R matrix back to a tensor with indices (bond, right...).
    let r_vec = matrix_to_vec(&r_matrix);
    let mut r_indices = vec![bond_index.clone()];
    r_indices.extend_from_slice(&right_indices);
    let right =
        TensorDynLen::from_dense(r_indices, r_vec).map_err(FactorizeError::ComputationError)?;

    Ok(FactorizeResult {
        left,
        right,
        bond_index,
        singular_values: None, // CI provides no singular values
        rank,
    })
}
575
576/// Convert a native rank-2 tensor into a `tensor4all_tcicore::Matrix`.
577fn native_tensor_to_matrix<T>(
578    tensor: &tenferro::Tensor,
579    m: usize,
580    n: usize,
581) -> Result<tensor4all_tcicore::Matrix<T>, FactorizeError>
582where
583    T: TensorElement + MatrixScalar + Copy,
584{
585    let data = T::dense_values_from_native_col_major(tensor).map_err(|e| {
586        FactorizeError::ComputationError(anyhow::anyhow!(
587            "failed to extract dense matrix entries from native tensor: {e}"
588        ))
589    })?;
590    if data.len() != m * n {
591        return Err(FactorizeError::ComputationError(anyhow::anyhow!(
592            "native matrix materialization produced {} entries for shape ({m}, {n})",
593            data.len()
594        )));
595    }
596
597    let mut matrix = tensor4all_tcicore::matrix::zeros(m, n);
598    for i in 0..m {
599        for j in 0..n {
600            matrix[[i, j]] = data[j * m + i];
601        }
602    }
603    Ok(matrix)
604}
605
606/// Convert Matrix to Vec for storage.
607fn matrix_to_vec<T>(matrix: &tensor4all_tcicore::Matrix<T>) -> Vec<T>
608where
609    T: Clone,
610{
611    let m = tensor4all_tcicore::matrix::nrows(matrix);
612    let n = tensor4all_tcicore::matrix::ncols(matrix);
613    let mut vec = Vec::with_capacity(m * n);
614    for j in 0..n {
615        for i in 0..m {
616            vec.push(matrix[[i, j]].clone());
617        }
618    }
619    vec
620}
621
622#[cfg(test)]
623mod tests;