Skip to main content

tensor4all_core/defaults/
factorize.rs

1//! Unified tensor factorization module.
2//!
3//! This module provides a unified `factorize()` function that dispatches to
4//! SVD, QR, LU, or CI (Cross Interpolation) algorithms based on options.
5//!
6//! # Note
7//!
8//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
9//! Generic tensor types are not supported.
10//!
11//! # Example
12//!
13//! ```
14//! use tensor4all_core::{factorize, Canonical, DynIndex, FactorizeOptions, TensorDynLen, TensorLike};
15//!
16//! # fn main() -> anyhow::Result<()> {
17//! let i = DynIndex::new_dyn(2);
18//! let j = DynIndex::new_dyn(2);
19//! let tensor = TensorDynLen::from_dense(
20//!     vec![i.clone(), j.clone()],
21//!     vec![1.0, 0.0, 0.0, 1.0],
22//! )?;
23//! let result = factorize(
24//!     &tensor,
25//!     std::slice::from_ref(&i),
26//!     &FactorizeOptions::svd().with_canonical(Canonical::Left),
27//! )?;
28//!
29//! assert_eq!(result.rank, 2);
30//! assert_eq!(result.left.dims(), vec![2, 2]);
31//! # Ok(())
32//! # }
33//! ```
34
35use crate::defaults::tensordynlen::unfold_split_inner;
36use crate::defaults::DynIndex;
37use crate::{contract_pair, unfold_split, TensorDynLen};
38use num_complex::{Complex64, ComplexFloat};
39use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar};
40use tensor4all_tensorbackend::{Matrix, TensorElement};
41
42use crate::defaults::svd::svd_for_factorize;
43use crate::qr::{qr_with, QrOptions};
44use crate::svd::SvdOptions;
45
46// Re-export types from tensor_like for backwards compatibility
47pub use crate::tensor_like::{
48    Canonical, FactorizeAlg, FactorizeError, FactorizeOptions, FactorizeResult,
49};
50
51/// Factorize a tensor into left and right factors.
52///
53/// This function dispatches to the appropriate algorithm based on `options.alg`:
54/// - `SVD`: Singular Value Decomposition
55/// - `QR`: QR decomposition
56/// - `LU`: Rank-revealing LU decomposition
57/// - `CI`: Cross Interpolation
58///
59/// The `canonical` option controls which factor is "canonical":
60/// - `Canonical::Left`: Left factor is orthogonal (SVD/QR) or unit-diagonal (LU/CI)
61/// - `Canonical::Right`: Right factor is orthogonal (SVD) or unit-diagonal (LU/CI)
62///
63/// # Arguments
64/// * `t` - Input tensor
65/// * `left_inds` - Indices to place on the left side
66/// * `options` - Factorization options
67///
68/// # Returns
69/// A `FactorizeResult` containing the left and right factors, bond index,
70/// singular values (for SVD), and rank.
71///
72/// # Errors
73/// Returns `FactorizeError` if:
74/// - The storage type is not supported (only DenseF64 and DenseC64)
75/// - QR is used with `Canonical::Right`
76/// - The underlying algorithm fails
77pub fn factorize(
78    t: &TensorDynLen,
79    left_inds: &[DynIndex],
80    options: &FactorizeOptions,
81) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
82    options.validate()?;
83
84    if t.is_diag() {
85        return Err(FactorizeError::UnsupportedStorage(
86            "Diagonal storage not supported for factorize",
87        ));
88    }
89
90    if t.is_f64() {
91        factorize_impl_f64(t, left_inds, options)
92    } else if t.is_complex() {
93        factorize_impl_c64(t, left_inds, options)
94    } else {
95        Err(FactorizeError::UnsupportedStorage(
96            "factorize currently supports only f64 and Complex64 tensors",
97        ))
98    }
99}
100
101/// Factorize a tensor without applying algorithm-specific truncation options.
102///
103/// This path is intended for canonicalization and other exact tensor-network
104/// rewrites where the decomposition must preserve the represented tensor rather
105/// than obey global SVD/QR/LU rank-dropping defaults.
106///
107/// # Arguments
108/// * `t` - Input tensor.
109/// * `left_inds` - Indices to place on the left side.
110/// * `alg` - Decomposition algorithm to use.
111/// * `canonical` - Which factor should carry the canonical form.
112///
113/// # Returns
114/// A factorization whose contracted factors reconstruct `t` up to numerical
115/// roundoff, with no tolerance-based or maximum-rank truncation applied.
116///
117/// # Errors
118/// Returns [`FactorizeError`] if the storage type is unsupported, the canonical
119/// direction is unsupported for the selected algorithm, or the underlying
120/// decomposition fails.
121///
122/// # Examples
123///
124/// ```
125/// use tensor4all_core::{
126///     factorize_full_rank, Canonical, DynIndex, FactorizeAlg, TensorContractionLike, TensorDynLen,
127/// };
128///
129/// let i = DynIndex::new_dyn(2);
130/// let j = DynIndex::new_dyn(2);
131/// let tensor = TensorDynLen::from_dense(
132///     vec![i.clone(), j.clone()],
133///     vec![1.0_f64, 0.0, 0.0, 1.0e-16],
134/// )?;
135///
136/// let result = factorize_full_rank(
137///     &tensor,
138///     std::slice::from_ref(&i),
139///     FactorizeAlg::QR,
140///     Canonical::Left,
141/// )?;
142/// let reconstructed = result.left.contract_pair(&result.right).unwrap();
143/// assert!(tensor.sub(&reconstructed)?.maxabs() < 1.0e-18);
144/// # Ok::<(), Box<dyn std::error::Error>>(())
145/// ```
146pub fn factorize_full_rank(
147    t: &TensorDynLen,
148    left_inds: &[DynIndex],
149    alg: FactorizeAlg,
150    canonical: Canonical,
151) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
152    if t.is_diag() {
153        return Err(FactorizeError::UnsupportedStorage(
154            "Diagonal storage not supported for factorize",
155        ));
156    }
157
158    if t.is_f64() {
159        factorize_impl_f64_full_rank(t, left_inds, alg, canonical)
160    } else if t.is_complex() {
161        factorize_impl_c64_full_rank(t, left_inds, alg, canonical)
162    } else {
163        Err(FactorizeError::UnsupportedStorage(
164            "factorize currently supports only f64 and Complex64 tensors",
165        ))
166    }
167}
168
169fn factorize_impl_f64(
170    t: &TensorDynLen,
171    left_inds: &[DynIndex],
172    options: &FactorizeOptions,
173) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
174    match options.alg {
175        FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
176        FactorizeAlg::QR => factorize_qr(t, left_inds, options),
177        FactorizeAlg::LU => factorize_lu::<f64>(t, left_inds, options),
178        FactorizeAlg::CI => factorize_ci::<f64>(t, left_inds, options),
179    }
180}
181
182fn factorize_impl_f64_full_rank(
183    t: &TensorDynLen,
184    left_inds: &[DynIndex],
185    alg: FactorizeAlg,
186    canonical: Canonical,
187) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
188    match alg {
189        FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
190        FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
191        FactorizeAlg::LU => factorize_lu_full_rank::<f64>(t, left_inds, canonical),
192        FactorizeAlg::CI => factorize_ci_full_rank::<f64>(t, left_inds, canonical),
193    }
194}
195
196fn factorize_impl_c64(
197    t: &TensorDynLen,
198    left_inds: &[DynIndex],
199    options: &FactorizeOptions,
200) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
201    match options.alg {
202        FactorizeAlg::SVD => factorize_svd(t, left_inds, options),
203        FactorizeAlg::QR => factorize_qr(t, left_inds, options),
204        FactorizeAlg::LU => factorize_lu::<Complex64>(t, left_inds, options),
205        FactorizeAlg::CI => factorize_ci::<Complex64>(t, left_inds, options),
206    }
207}
208
209fn factorize_impl_c64_full_rank(
210    t: &TensorDynLen,
211    left_inds: &[DynIndex],
212    alg: FactorizeAlg,
213    canonical: Canonical,
214) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
215    match alg {
216        FactorizeAlg::SVD => factorize_svd_full_rank(t, left_inds, canonical),
217        FactorizeAlg::QR => factorize_qr_full_rank(t, left_inds, canonical),
218        FactorizeAlg::LU => factorize_lu_full_rank::<Complex64>(t, left_inds, canonical),
219        FactorizeAlg::CI => factorize_ci_full_rank::<Complex64>(t, left_inds, canonical),
220    }
221}
222
223/// SVD factorization implementation.
224fn factorize_svd(
225    t: &TensorDynLen,
226    left_inds: &[DynIndex],
227    options: &FactorizeOptions,
228) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
229    let mut svd_options = SvdOptions::new();
230    if let Some(policy) = options.svd_policy {
231        svd_options = svd_options.with_policy(policy);
232    }
233    if let Some(max_rank) = options.max_rank {
234        svd_options = svd_options.with_max_rank(max_rank);
235    }
236
237    factorize_svd_with_options(t, left_inds, options.canonical, &svd_options)
238}
239
240fn factorize_svd_full_rank(
241    t: &TensorDynLen,
242    left_inds: &[DynIndex],
243    canonical: Canonical,
244) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
245    let svd_options = SvdOptions::full_rank();
246    factorize_svd_with_options(t, left_inds, canonical, &svd_options)
247}
248
249fn factorize_svd_with_options(
250    t: &TensorDynLen,
251    left_inds: &[DynIndex],
252    canonical: Canonical,
253    svd_options: &SvdOptions,
254) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
255    let result = svd_for_factorize(t, left_inds, svd_options)?;
256    let u = result.u;
257    let s = result.s;
258    let vh = result.vh;
259    let bond_index = result.bond_index;
260    let singular_values = result.singular_values;
261    let rank = result.rank;
262    let sim_bond_index = s.indices[1].clone();
263
264    match canonical {
265        Canonical::Left => {
266            // L = U (orthogonal), R = S * V^H
267            let right_contracted = contract_pair(&s, &vh)?;
268            let right = right_contracted.replaceind(&sim_bond_index, &bond_index)?;
269            Ok(FactorizeResult {
270                left: u,
271                right,
272                bond_index,
273                singular_values: Some(singular_values),
274                rank,
275            })
276        }
277        Canonical::Right => {
278            // L = U * S, R = V^H
279            let left_contracted = contract_pair(&u, &s)?;
280            let left = left_contracted.replaceind(&sim_bond_index, &bond_index)?;
281            Ok(FactorizeResult {
282                left,
283                right: vh,
284                bond_index,
285                singular_values: Some(singular_values),
286                rank,
287            })
288        }
289    }
290}
291
292/// QR factorization implementation.
293fn factorize_qr(
294    t: &TensorDynLen,
295    left_inds: &[DynIndex],
296    options: &FactorizeOptions,
297) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
298    if options.canonical == Canonical::Right {
299        return Err(FactorizeError::UnsupportedCanonical(
300            "QR only supports Canonical::Left (would need LQ for right)",
301        ));
302    }
303
304    let qr_options = if let Some(rtol) = options.qr_rtol {
305        QrOptions::new().with_rtol(rtol)
306    } else {
307        QrOptions::new()
308    };
309
310    factorize_qr_with_options(t, left_inds, &qr_options)
311}
312
313fn factorize_qr_full_rank(
314    t: &TensorDynLen,
315    left_inds: &[DynIndex],
316    canonical: Canonical,
317) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
318    if canonical == Canonical::Right {
319        return Err(FactorizeError::UnsupportedCanonical(
320            "QR only supports Canonical::Left (would need LQ for right)",
321        ));
322    }
323
324    factorize_qr_with_options(t, left_inds, &QrOptions::full_rank())
325}
326
327fn factorize_qr_with_options(
328    t: &TensorDynLen,
329    left_inds: &[DynIndex],
330    qr_options: &QrOptions,
331) -> Result<FactorizeResult<TensorDynLen>, FactorizeError> {
332    let (q, r) = qr_with::<f64>(t, left_inds, qr_options)?;
333
334    // Get bond index from Q tensor (last index)
335    let bond_index = q
336        .indices
337        .last()
338        .cloned()
339        .ok_or_else(|| anyhow::anyhow!("QR factorization returned a rank-0 Q tensor"))?;
340    // Rank is the last dimension of Q
341    let q_dims = q.dims();
342    let rank = *q_dims
343        .last()
344        .ok_or_else(|| anyhow::anyhow!("QR factorization returned Q with no dimensions"))?;
345
346    Ok(FactorizeResult {
347        left: q,
348        right: r,
349        bond_index,
350        singular_values: None,
351        rank,
352    })
353}
354
355/// LU factorization implementation.
356fn factorize_lu<T>(
357    t: &TensorDynLen,
358    left_inds: &[DynIndex],
359    options: &FactorizeOptions,
360) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
361where
362    T: TensorElement
363        + ComplexFloat
364        + Default
365        + From<<T as ComplexFloat>::Real>
366        + MatrixScalar
367        + tensor4all_tcicore::MatrixLuciScalar
368        + 'static,
369    <T as ComplexFloat>::Real: Into<f64> + 'static,
370{
371    factorize_lu_with_options::<T>(
372        t,
373        left_inds,
374        options.canonical,
375        options.max_rank.unwrap_or(usize::MAX),
376        1e-14,
377    )
378}
379
380fn factorize_lu_full_rank<T>(
381    t: &TensorDynLen,
382    left_inds: &[DynIndex],
383    canonical: Canonical,
384) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
385where
386    T: TensorElement
387        + ComplexFloat
388        + Default
389        + From<<T as ComplexFloat>::Real>
390        + MatrixScalar
391        + tensor4all_tcicore::MatrixLuciScalar
392        + 'static,
393    <T as ComplexFloat>::Real: Into<f64> + 'static,
394{
395    factorize_lu_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
396}
397
398fn factorize_lu_with_options<T>(
399    t: &TensorDynLen,
400    left_inds: &[DynIndex],
401    canonical: Canonical,
402    max_rank: usize,
403    rel_tol: f64,
404) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
405where
406    T: TensorElement
407        + ComplexFloat
408        + Default
409        + From<<T as ComplexFloat>::Real>
410        + MatrixScalar
411        + tensor4all_tcicore::MatrixLuciScalar
412        + 'static,
413    <T as ComplexFloat>::Real: Into<f64> + 'static,
414{
415    // Unfold tensor into matrix
416    let (a_tensor, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
417        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
418
419    // Convert to Matrix type for rrlu
420    let a_matrix = native_tensor_to_matrix::<T>(&a_tensor, m, n)?;
421
422    // Set up LU options
423    let left_orthogonal = canonical == Canonical::Left;
424    let lu_options = RrLUOptions {
425        max_rank,
426        rel_tol,
427        abs_tol: 0.0,
428        left_orthogonal,
429    };
430
431    // Perform LU decomposition
432    let lu = rrlu(&a_matrix, Some(lu_options))?;
433    let rank = lu.npivots();
434
435    // Extract L and U matrices (permuted)
436    let l_matrix = lu.left(true);
437    let u_matrix = lu.right(true);
438
439    // Create bond index
440    let bond_index = DynIndex::new_bond(rank)
441        .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
442
443    // Convert L matrix back to tensor
444    let l_vec = matrix_to_vec(&l_matrix);
445    let mut l_indices = left_indices.clone();
446    l_indices.push(bond_index.clone());
447    let left =
448        TensorDynLen::from_dense(l_indices, l_vec).map_err(FactorizeError::ComputationError)?;
449
450    // Convert U matrix back to tensor
451    let u_vec = matrix_to_vec(&u_matrix);
452    let mut r_indices = vec![bond_index.clone()];
453    r_indices.extend_from_slice(&right_indices);
454    let right =
455        TensorDynLen::from_dense(r_indices, u_vec).map_err(FactorizeError::ComputationError)?;
456
457    Ok(FactorizeResult {
458        left,
459        right,
460        bond_index,
461        singular_values: None,
462        rank,
463    })
464}
465
466/// CI (Cross Interpolation) factorization implementation.
467fn factorize_ci<T>(
468    t: &TensorDynLen,
469    left_inds: &[DynIndex],
470    options: &FactorizeOptions,
471) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
472where
473    T: TensorElement
474        + ComplexFloat
475        + Default
476        + From<<T as ComplexFloat>::Real>
477        + MatrixScalar
478        + tensor4all_tcicore::MatrixLuciScalar
479        + 'static,
480    <T as ComplexFloat>::Real: Into<f64> + 'static,
481{
482    factorize_ci_with_options::<T>(
483        t,
484        left_inds,
485        options.canonical,
486        options.max_rank.unwrap_or(usize::MAX),
487        1e-14,
488    )
489}
490
491fn factorize_ci_full_rank<T>(
492    t: &TensorDynLen,
493    left_inds: &[DynIndex],
494    canonical: Canonical,
495) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
496where
497    T: TensorElement
498        + ComplexFloat
499        + Default
500        + From<<T as ComplexFloat>::Real>
501        + MatrixScalar
502        + tensor4all_tcicore::MatrixLuciScalar
503        + 'static,
504    <T as ComplexFloat>::Real: Into<f64> + 'static,
505{
506    factorize_ci_with_options::<T>(t, left_inds, canonical, usize::MAX, 0.0)
507}
508
509fn factorize_ci_with_options<T>(
510    t: &TensorDynLen,
511    left_inds: &[DynIndex],
512    canonical: Canonical,
513    max_rank: usize,
514    rel_tol: f64,
515) -> Result<FactorizeResult<TensorDynLen>, FactorizeError>
516where
517    T: TensorElement
518        + ComplexFloat
519        + Default
520        + From<<T as ComplexFloat>::Real>
521        + MatrixScalar
522        + tensor4all_tcicore::MatrixLuciScalar
523        + 'static,
524    <T as ComplexFloat>::Real: Into<f64> + 'static,
525{
526    // Unfold tensor into an eager matrix. Pivot selection is primal-only, but
527    // fixed-pivot CI factors are rebuilt from this eager value below.
528    let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
529        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))?;
530
531    // Convert to Matrix type for MatrixLUCI
532    let a_matrix = native_tensor_to_matrix::<T>(matrix_inner.data(), m, n)?;
533
534    // Set up LU options for CI
535    let left_orthogonal = canonical == Canonical::Left;
536    let lu_options = RrLUOptions {
537        max_rank,
538        rel_tol,
539        abs_tol: 0.0,
540        left_orthogonal,
541    };
542
543    // Perform CI decomposition
544    let ci = MatrixLUCI::from_matrix(&a_matrix, Some(lu_options))?;
545    let rank = ci.rank();
546
547    let row_indices = ci.row_indices().to_vec();
548    let col_indices = ci.col_indices().to_vec();
549    let cols_inner = matrix_inner.take_cols(&col_indices).map_err(|e| {
550        FactorizeError::ComputationError(anyhow::anyhow!(
551            "fixed-pivot CI column gather failed: {e}"
552        ))
553    })?;
554    let rows_inner = matrix_inner.take_rows(&row_indices).map_err(|e| {
555        FactorizeError::ComputationError(anyhow::anyhow!("fixed-pivot CI row gather failed: {e}"))
556    })?;
557    let pivot_inner = matrix_inner
558        .take_block(&row_indices, &col_indices)
559        .map_err(|e| {
560            FactorizeError::ComputationError(anyhow::anyhow!(
561                "fixed-pivot CI pivot block gather failed: {e}"
562            ))
563        })?;
564    let (left_inner, right_inner) = match canonical {
565        Canonical::Left => {
566            let left = pivot_inner.right_solve(&cols_inner).map_err(|e| {
567                FactorizeError::ComputationError(anyhow::anyhow!(
568                    "fixed-pivot CI right solve failed: {e}"
569                ))
570            })?;
571            (left, rows_inner)
572        }
573        Canonical::Right => {
574            let right = pivot_inner.solve(&rows_inner).map_err(|e| {
575                FactorizeError::ComputationError(anyhow::anyhow!(
576                    "fixed-pivot CI solve failed: {e}"
577                ))
578            })?;
579            (cols_inner, right)
580        }
581    };
582
583    // Create bond index
584    let bond_index = DynIndex::new_bond(rank)
585        .map_err(|e| anyhow::anyhow!("Failed to create bond index: {:?}", e))?;
586
587    let mut l_indices = left_indices.clone();
588    l_indices.push(bond_index.clone());
589    let l_dims: Vec<usize> = l_indices.iter().map(|idx| idx.dim).collect();
590    let left_inner = left_inner.reshape(&l_dims).map_err(|e| {
591        FactorizeError::ComputationError(anyhow::anyhow!("fixed-pivot CI left reshape failed: {e}"))
592    })?;
593    let left = TensorDynLen::from_inner(l_indices, left_inner)
594        .map_err(FactorizeError::ComputationError)?;
595
596    let mut r_indices = vec![bond_index.clone()];
597    r_indices.extend_from_slice(&right_indices);
598    let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
599    let right_inner = right_inner.reshape(&r_dims).map_err(|e| {
600        FactorizeError::ComputationError(anyhow::anyhow!(
601            "fixed-pivot CI right reshape failed: {e}"
602        ))
603    })?;
604    let right = TensorDynLen::from_inner(r_indices, right_inner)
605        .map_err(FactorizeError::ComputationError)?;
606
607    Ok(FactorizeResult {
608        left,
609        right,
610        bond_index,
611        singular_values: None,
612        rank,
613    })
614}
615
616/// Convert a native rank-2 tensor into a backend [`Matrix`].
617fn native_tensor_to_matrix<T>(
618    tensor: &tenferro::Tensor,
619    m: usize,
620    n: usize,
621) -> Result<Matrix<T>, FactorizeError>
622where
623    T: TensorElement + MatrixScalar + Copy,
624{
625    let data = T::dense_values_from_native_col_major(tensor).map_err(|e| {
626        FactorizeError::ComputationError(anyhow::anyhow!(
627            "failed to extract dense matrix entries from native tensor: {e}"
628        ))
629    })?;
630    if data.len() != m * n {
631        return Err(FactorizeError::ComputationError(anyhow::anyhow!(
632            "native matrix materialization produced {} entries for shape ({m}, {n})",
633            data.len()
634        )));
635    }
636
637    let mut matrix = Matrix::zeros(m, n);
638    for i in 0..m {
639        for j in 0..n {
640            matrix[[i, j]] = data[j * m + i];
641        }
642    }
643    Ok(matrix)
644}
645
646/// Convert Matrix to Vec for storage.
647fn matrix_to_vec<T>(matrix: &Matrix<T>) -> Vec<T>
648where
649    T: Clone,
650{
651    let m = matrix.nrows();
652    let n = matrix.ncols();
653    let mut vec = Vec::with_capacity(m * n);
654    for j in 0..n {
655        for i in 0..m {
656            vec.push(matrix[[i, j]].clone());
657        }
658    }
659    vec
660}
661
662#[cfg(test)]
663mod tests;