// tensor4all_simplett/compression.rs
1//! Compression algorithms for tensor trains
2
3use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};
4use crate::error::Result;
5use crate::tensortrain::TensorTrain;
6use crate::traits::{AbstractTensorTrain, TTScalar};
7use crate::types::{tensor3_zeros, Tensor3, Tensor3Ops};
8use tenferro_tensor::{TensorScalar, TypedTensor};
9use tensor4all_tcicore::matrix::{mat_mul, ncols, nrows, zeros, Matrix};
10use tensor4all_tcicore::Scalar;
11use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions};
12use tensor4all_tensorbackend::BackendLinalgScalar;
13
/// Matrix decomposition method used during TT compression.
///
/// The method controls how each bond is factored during the right-to-left
/// truncation sweep (and the left-to-right orthogonalization sweep).
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::CompressionMethod;
///
/// let method = CompressionMethod::default();
/// assert_eq!(method, CompressionMethod::LU);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionMethod {
    /// Rank-revealing LU decomposition (default).
    ///
    /// Fast and robust for most use cases.
    #[default]
    LU,
    /// Cross-interpolation (CI) decomposition.
    ///
    /// Can be faster than LU for very large bond dimensions because it
    /// avoids dense matrix operations.
    CI,
    /// Singular value decomposition (SVD).
    ///
    /// Gives the optimal low-rank approximation (Eckart--Young theorem)
    /// but is more expensive than LU or CI. Only supported for `f64` and
    /// `Complex64` scalars (see `svd_dispatch`).
    SVD,
}
45
/// Configuration for tensor train compression.
///
/// Controls the accuracy-vs-cost trade-off when reducing bond dimensions
/// via [`TensorTrain::compress`] or [`TensorTrain::compressed`].
///
/// # Fields
///
/// | Field | Default | Meaning |
/// |-------|---------|---------|
/// | `method` | `LU` | Decomposition algorithm (see [`CompressionMethod`]) |
/// | `tolerance` | `1e-12` | Relative truncation threshold per bond |
/// | `max_bond_dim` | `usize::MAX` | Hard upper bound on any bond dimension |
/// | `normalize_error` | `true` | Whether error is measured relative to the norm |
///
/// # Choosing `tolerance`
///
/// - `1e-12` (default): near machine precision, almost lossless.
/// - `1e-8` to `1e-6`: good for most scientific applications.
/// - `1e-4` to `1e-2`: aggressive compression, useful for exploratory work.
///
/// Tighter tolerances produce larger bond dimensions and slower evaluation.
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::{CompressionOptions, CompressionMethod};
///
/// // Default: LU with tolerance 1e-12
/// let opts = CompressionOptions::default();
/// assert_eq!(opts.method, CompressionMethod::LU);
/// assert!((opts.tolerance - 1e-12).abs() < 1e-15);
/// assert_eq!(opts.max_bond_dim, usize::MAX);
///
/// // Custom: SVD with moderate tolerance and bond dim cap
/// let opts = CompressionOptions {
///     method: CompressionMethod::SVD,
///     tolerance: 1e-6,
///     max_bond_dim: 50,
///     ..Default::default()
/// };
/// assert_eq!(opts.method, CompressionMethod::SVD);
/// ```
#[derive(Debug, Clone)]
pub struct CompressionOptions {
    /// Decomposition method (LU, CI, or SVD).
    pub method: CompressionMethod,
    /// Relative truncation tolerance per bond.
    ///
    /// Singular values (or pivots) smaller than `tolerance * sigma_max` are
    /// discarded. Smaller values preserve more accuracy but produce larger
    /// bond dimensions.
    pub tolerance: f64,
    /// Hard upper bound on any bond dimension.
    ///
    /// Even if the tolerance would allow a larger rank, the bond dimension
    /// is capped at this value. Set to `usize::MAX` (default) for no limit.
    pub max_bond_dim: usize,
    /// Whether to normalize the truncation error by the tensor norm.
    pub normalize_error: bool,
}
106
107impl Default for CompressionOptions {
108    fn default() -> Self {
109        Self {
110            method: CompressionMethod::LU,
111            tolerance: 1e-12,
112            max_bond_dim: usize::MAX,
113            normalize_error: true,
114        }
115    }
116}
117
118/// Convert Tensor3 to Matrix for factorization (left matrix view)
119fn tensor3_to_left_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
120    let left_dim = tensor.left_dim();
121    let site_dim = tensor.site_dim();
122    let right_dim = tensor.right_dim();
123    let rows = left_dim * site_dim;
124    let cols = right_dim;
125
126    let mut mat = zeros(rows, cols);
127    for l in 0..left_dim {
128        for s in 0..site_dim {
129            for r in 0..right_dim {
130                mat[[l * site_dim + s, r]] = *tensor.get3(l, s, r);
131            }
132        }
133    }
134    mat
135}
136
137/// Convert Tensor3 to Matrix for factorization (right matrix view)
138fn tensor3_to_right_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
139    let left_dim = tensor.left_dim();
140    let site_dim = tensor.site_dim();
141    let right_dim = tensor.right_dim();
142    let rows = left_dim;
143    let cols = site_dim * right_dim;
144
145    let mut mat = zeros(rows, cols);
146    for l in 0..left_dim {
147        for s in 0..site_dim {
148            for r in 0..right_dim {
149                mat[[l, s * right_dim + r]] = *tensor.get3(l, s, r);
150            }
151        }
152    }
153    mat
154}
155
/// Factorize a matrix into left and right factors: `matrix ≈ left * right`.
///
/// Dispatches on `method`:
/// - `LU`: rank-revealing LU via [`rrlu`] (permuted factors).
/// - `CI`: cross interpolation via [`MatrixLUCI`].
/// - `SVD`: type-erased dispatch to `factorize_svd` (`f64`/`Complex64` only).
///
/// Returns `(left, right, rank)` where `rank` is the number of retained
/// pivots (LU/CI) or singular values (SVD).
///
/// `left_orthogonal` controls which factor absorbs the weights: `true`
/// during the left-to-right orthogonalization sweep, `false` during the
/// truncating right-to-left sweep.
fn factorize<T>(
    matrix: &Matrix<T>,
    method: CompressionMethod,
    tolerance: f64,
    max_bond_dim: usize,
    left_orthogonal: bool,
) -> crate::error::Result<(Matrix<T>, Matrix<T>, usize)>
where
    T: TTScalar + Scalar + tensor4all_tcicore::MatrixLuciScalar,
    tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
{
    // A non-positive tolerance means "no truncation requested"; the LU/CI
    // kernels still need a positive cutoff, so fall back to a
    // near-machine-precision relative tolerance.
    let reltol = if tolerance > 0.0 { tolerance } else { 1e-14 };
    let abstol = 0.0;

    let options = RrLUOptions {
        max_rank: max_bond_dim,
        rel_tol: reltol,
        abs_tol: abstol,
        left_orthogonal,
    };

    match method {
        CompressionMethod::LU => {
            let lu = rrlu(matrix, Some(options))?;
            let left = lu.left(true); // permuted
            let right = lu.right(true); // permuted
            let npivots = lu.npivots();
            Ok((left, right, npivots))
        }
        CompressionMethod::CI => {
            let luci = MatrixLUCI::from_matrix(matrix, Some(options))?;
            let left = luci.left();
            let right = luci.right();
            let npivots = luci.rank();
            Ok((left, right, npivots))
        }
        CompressionMethod::SVD => {
            // SVD compression requires additional trait bounds. We use type-erased dispatch
            // for the supported scalar types (f64, Complex64).
            svd_dispatch(matrix, tolerance, max_bond_dim, left_orthogonal)
        }
    }
}
200
/// Trait bounds for SVD-compatible scalars in TT compression.
///
/// Bridges between a scalar's associated real type (in which singular
/// values are expressed) and plain `f64` / the scalar itself. Implemented
/// only for `f64` and `Complex64`.
trait SVDCompressScalar:
    TTScalar + Scalar + Default + Copy + BackendLinalgScalar + num_complex::ComplexFloat + 'static
{
    /// Convert a singular value to `f64` for tolerance comparisons.
    fn sv_to_f64(real: <Self as TensorScalar>::Real) -> f64;
    /// Lift a singular value into the scalar type so it can scale factors.
    fn from_sv(real: <Self as TensorScalar>::Real) -> Self;
}
208
impl SVDCompressScalar for f64 {
    // For f64 the associated `Real` type is f64 itself, so both
    // conversions are identities.
    fn sv_to_f64(real: <Self as TensorScalar>::Real) -> f64 {
        real
    }
    fn from_sv(real: <Self as TensorScalar>::Real) -> Self {
        real
    }
}
217
218impl SVDCompressScalar for num_complex::Complex64 {
219    fn sv_to_f64(real: <Self as TensorScalar>::Real) -> f64 {
220        real
221    }
222    fn from_sv(real: <Self as TensorScalar>::Real) -> Self {
223        num_complex::Complex64::new(real, 0.0)
224    }
225}
226
227/// SVD dispatch: convert generic Matrix<T> to concrete type and call factorize_svd.
228///
229/// This is necessary because SVD requires additional trait bounds (LinalgScalar, etc.)
230/// that aren't available on the generic `T: TTScalar + Scalar`.
231fn svd_dispatch<T: TTScalar + Scalar>(
232    matrix: &Matrix<T>,
233    tolerance: f64,
234    max_bond_dim: usize,
235    left_orthogonal: bool,
236) -> crate::error::Result<(Matrix<T>, Matrix<T>, usize)> {
237    use std::any::Any;
238
239    let m = nrows(matrix);
240    let n = ncols(matrix);
241
242    // Try f64
243    if let Some(mat_f64) = (matrix as &dyn Any).downcast_ref::<Matrix<f64>>() {
244        let (l, r, rank) = factorize_svd(mat_f64, tolerance, max_bond_dim, left_orthogonal)?;
245        // Safety: T is f64 in this branch
246        let left = unsafe { std::mem::transmute::<Matrix<f64>, Matrix<T>>(l) };
247        let right = unsafe { std::mem::transmute::<Matrix<f64>, Matrix<T>>(r) };
248        return Ok((left, right, rank));
249    }
250
251    // Try Complex64
252    if let Some(mat_c64) = (matrix as &dyn Any).downcast_ref::<Matrix<num_complex::Complex64>>() {
253        let (l, r, rank) = factorize_svd(mat_c64, tolerance, max_bond_dim, left_orthogonal)?;
254        let left = unsafe { std::mem::transmute::<Matrix<num_complex::Complex64>, Matrix<T>>(l) };
255        let right = unsafe { std::mem::transmute::<Matrix<num_complex::Complex64>, Matrix<T>>(r) };
256        return Ok((left, right, rank));
257    }
258
259    Err(crate::error::TensorTrainError::InvalidOperation {
260        message: format!(
261            "SVD compression not supported for this scalar type (matrix {}x{})",
262            m, n
263        ),
264    })
265}
266
267/// Helper: extract row-major data from a TypedTensor.
268fn typed_tensor_row_major<T: TensorScalar>(
269    tensor: &TypedTensor<T>,
270) -> crate::error::Result<Vec<T>> {
271    Ok(tensor_to_row_major_vec(tensor))
272}
273
/// SVD-based factorization for TT compression.
///
/// Computes `matrix ≈ left * right` via a dense backend SVD, keeping
/// singular values down to `tolerance * sigma_max` and capping the rank at
/// `max_bond_dim`. At least one singular value is always retained, so the
/// returned rank is in `1..=min(m, n)`.
///
/// If `left_orthogonal` is true, the singular values are absorbed into
/// `right` (leaving `left = U[:, :rank]`); otherwise they are absorbed
/// into `left` (leaving `right = Vt[:rank, :]`).
///
/// # Errors
///
/// Fails if the matrix is empty or the backend SVD computation fails.
fn factorize_svd<T: SVDCompressScalar>(
    matrix: &Matrix<T>,
    tolerance: f64,
    max_bond_dim: usize,
    left_orthogonal: bool,
) -> crate::error::Result<(Matrix<T>, Matrix<T>, usize)> {
    let m = nrows(matrix);
    let n = ncols(matrix);

    if m == 0 || n == 0 {
        return Err(crate::error::TensorTrainError::InvalidOperation {
            message: "Cannot factorize empty matrix".to_string(),
        });
    }

    // Convert matrixci::Matrix to TypedTensor (row-major) for SVD
    let mut data = vec![T::zero(); m * n];
    for i in 0..m {
        for j in 0..n {
            data[i * n + j] = matrix[[i, j]];
        }
    }
    let a_tensor = typed_tensor_from_row_major_slice(&data, &[m, n]);

    let svd_result = tensor4all_tensorbackend::svd_backend(&a_tensor).map_err(|e| {
        crate::error::TensorTrainError::InvalidOperation {
            message: format!("SVD computation failed: {e:?}"),
        }
    })?;

    // Extract U, S, Vt as row-major vectors. `u_cols`/`vt_cols` are the row
    // strides used to index into the flattened buffers below.
    let u_data = typed_tensor_row_major(&svd_result.u)?;
    let u_cols = svd_result.u.shape[1];
    let s_data: Vec<<T as TensorScalar>::Real> = typed_tensor_row_major(&svd_result.s)?;
    let vt_data = typed_tensor_row_major(&svd_result.vt)?;
    let vt_cols = svd_result.vt.shape[1];

    // Determine rank based on tolerance and max_bond_dim.
    // NOTE(review): this assumes the backend returns singular values in
    // descending order (s_data[0] == sigma_max) — the usual SVD convention.
    let min_dim = m.min(n);
    let s_max = if !s_data.is_empty() {
        T::sv_to_f64(s_data[0])
    } else {
        0.0
    };

    let mut rank = 0;
    for &singular_value in s_data.iter().take(min_dim) {
        if rank >= max_bond_dim {
            break;
        }
        let sv = T::sv_to_f64(singular_value);
        if sv < tolerance * s_max {
            break;
        }
        rank += 1;
    }
    // Always keep at least one singular value so the factors are non-empty.
    rank = rank.max(1);

    // Build result matrices
    let mut left = zeros(m, rank);
    let mut right = zeros(rank, n);

    if left_orthogonal {
        // Left = U[:, :rank], Right = diag(S[:rank]) * Vt[:rank, :]
        for i in 0..m {
            for j in 0..rank {
                left[[i, j]] = u_data[i * u_cols + j];
            }
        }
        for i in 0..rank {
            let sv = T::from_sv(s_data[i]);
            for j in 0..n {
                right[[i, j]] = sv * vt_data[i * vt_cols + j];
            }
        }
    } else {
        // Left = U[:, :rank] * diag(S[:rank]), Right = Vt[:rank, :]
        for i in 0..m {
            for j in 0..rank {
                let sv = T::from_sv(s_data[j]);
                left[[i, j]] = u_data[i * u_cols + j] * sv;
            }
        }
        for i in 0..rank {
            for j in 0..n {
                right[[i, j]] = vt_data[i * vt_cols + j];
            }
        }
    }

    Ok((left, right, rank))
}
367
impl<T: TTScalar + Scalar + Default> TensorTrain<T> {
    /// Compress the tensor train **in place**, reducing bond dimensions.
    ///
    /// The algorithm performs two sweeps:
    /// 1. **Left-to-right**: orthogonalize each bond without truncation.
    /// 2. **Right-to-left**: truncate each bond according to `options`.
    ///
    /// After compression, the tensor train approximates the original within
    /// the specified tolerance while using smaller bond dimensions.
    ///
    /// # Errors
    ///
    /// Returns an error if the internal factorization fails.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, CompressionOptions};
    ///
    /// // Add two constant TTs to get bond dim 2, then compress back to 1
    /// let a = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
    /// let b = TensorTrain::<f64>::constant(&[2, 3, 4], 2.0);
    /// let mut sum = a.add(&b).unwrap(); // bond dim = 2
    /// assert_eq!(sum.rank(), 2);
    ///
    /// sum.compress(&CompressionOptions::default()).unwrap();
    /// assert_eq!(sum.rank(), 1); // compressed back to optimal
    ///
    /// // Values are preserved: 1.0 + 2.0 = 3.0
    /// assert!((sum.evaluate(&[0, 0, 0]).unwrap() - 3.0).abs() < 1e-10);
    /// ```
    pub fn compress(&mut self, options: &CompressionOptions) -> Result<()>
    where
        T: tensor4all_tcicore::MatrixLuciScalar,
        tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
    {
        let n = self.len();
        // A single site has no bonds to compress.
        if n <= 1 {
            return Ok(());
        }

        let tensors = self.site_tensors_mut();

        // Left-to-right sweep: make left-orthogonal without truncation
        for ell in 0..n - 1 {
            let left_dim = tensors[ell].left_dim();
            let site_dim = tensors[ell].site_dim();

            // Reshape to matrix: (left_dim * site_dim, right_dim)
            let mat = tensor3_to_left_matrix(&tensors[ell]);

            // Factorize without truncation
            // (note: `factorize` maps tolerance 0.0 to a near-machine-precision
            // relative cutoff, so this sweep is effectively lossless).
            let (left_factor, right_factor, new_bond_dim) = factorize(
                &mat,
                options.method,
                0.0,        // No truncation in left sweep
                usize::MAX, // No max bond dim in left sweep
                true,       // left orthogonal
            )?;

            // Update current tensor: unfuse rows of `left_factor` back into
            // (left, site) legs; the new right leg is the factorization rank.
            let mut new_tensor = tensor3_zeros(left_dim, site_dim, new_bond_dim);
            for l in 0..left_dim {
                for s in 0..site_dim {
                    for r in 0..new_bond_dim {
                        let row = l * site_dim + s;
                        // Defensive bounds guard in case the factor is smaller
                        // than the requested shape.
                        if row < nrows(&left_factor) && r < ncols(&left_factor) {
                            new_tensor.set3(l, s, r, left_factor[[row, r]]);
                        }
                    }
                }
            }
            tensors[ell] = new_tensor;

            // Contract right_factor with next tensor
            let next_site_dim = tensors[ell + 1].site_dim();
            let next_right_dim = tensors[ell + 1].right_dim();

            // Build next tensor as matrix (old_left_dim, site_dim * right_dim)
            let next_mat = tensor3_to_right_matrix(&tensors[ell + 1]);

            // Multiply: right_factor * next_mat
            let contracted = mat_mul(&right_factor, &next_mat);

            // Update next tensor: unfuse columns back into (site, right) legs.
            let mut new_next_tensor = tensor3_zeros(new_bond_dim, next_site_dim, next_right_dim);
            for l in 0..new_bond_dim {
                for s in 0..next_site_dim {
                    for r in 0..next_right_dim {
                        new_next_tensor.set3(l, s, r, contracted[[l, s * next_right_dim + r]]);
                    }
                }
            }
            tensors[ell + 1] = new_next_tensor;
        }

        // Right-to-left sweep: truncate
        for ell in (1..n).rev() {
            let site_dim = tensors[ell].site_dim();
            let right_dim = tensors[ell].right_dim();

            // Reshape to matrix: (left_dim, site_dim * right_dim)
            let mat = tensor3_to_right_matrix(&tensors[ell]);

            // Factorize with truncation (this sweep applies the user's
            // tolerance and bond-dimension cap).
            let (left_factor, right_factor, new_bond_dim) = factorize(
                &mat,
                options.method,
                options.tolerance,
                options.max_bond_dim,
                false, // right orthogonal
            )?;

            // Update current tensor from right_factor: unfuse columns back
            // into (site, right) legs; the new left leg is the rank.
            let mut new_tensor = tensor3_zeros(new_bond_dim, site_dim, right_dim);
            for l in 0..new_bond_dim {
                for s in 0..site_dim {
                    for r in 0..right_dim {
                        new_tensor.set3(l, s, r, right_factor[[l, s * right_dim + r]]);
                    }
                }
            }
            tensors[ell] = new_tensor;

            // Contract previous tensor with left_factor
            let prev_left_dim = tensors[ell - 1].left_dim();
            let prev_site_dim = tensors[ell - 1].site_dim();

            // Build prev tensor as matrix (left_dim * site_dim, old_right_dim)
            let prev_mat = tensor3_to_left_matrix(&tensors[ell - 1]);

            // Multiply: prev_mat * left_factor
            let contracted = mat_mul(&prev_mat, &left_factor);

            // Update prev tensor: unfuse rows back into (left, site) legs.
            let mut new_prev_tensor = tensor3_zeros(prev_left_dim, prev_site_dim, new_bond_dim);
            for l in 0..prev_left_dim {
                for s in 0..prev_site_dim {
                    for r in 0..new_bond_dim {
                        new_prev_tensor.set3(l, s, r, contracted[[l * prev_site_dim + s, r]]);
                    }
                }
            }
            tensors[ell - 1] = new_prev_tensor;
        }

        Ok(())
    }

    /// Return a compressed copy of the tensor train (non-mutating).
    ///
    /// Equivalent to cloning and calling [`compress`](Self::compress).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, CompressionOptions};
    ///
    /// // A tensor train with redundant bond dimension can be compressed
    /// let tt = TensorTrain::<f64>::constant(&[2, 3, 2], 1.0);
    ///
    /// let opts = CompressionOptions::default();
    /// let compressed = tt.compressed(&opts).unwrap();
    ///
    /// // The compressed TT has the same number of sites
    /// assert_eq!(compressed.len(), tt.len());
    ///
    /// // Evaluations agree
    /// let val_orig = tt.evaluate(&[0, 1, 0]).unwrap();
    /// let val_comp = compressed.evaluate(&[0, 1, 0]).unwrap();
    /// assert!((val_orig - val_comp).abs() < 1e-10);
    /// ```
    pub fn compressed(&self, options: &CompressionOptions) -> Result<Self>
    where
        T: tensor4all_tcicore::MatrixLuciScalar,
        tensor4all_tcicore::DenseFaerLuKernel: tensor4all_tcicore::PivotKernel<T>,
    {
        let mut result = self.clone();
        result.compress(options)?;
        Ok(result)
    }
}
550
// Unit tests live in the sibling `tests` module file.
#[cfg(test)]
mod tests;