// tensor4all_simplett/vidal.rs
1//! Vidal and Inverse tensor train representations
2//!
3//! This module provides tensor train representations with explicit singular values:
4//! - `VidalTensorTrain`: Stores tensors and singular values separately (Vidal canonical form)
5//! - `InverseTensorTrain`: Stores tensors and inverse singular values for efficient local updates
6
7use std::ops::Range;
8
9use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};
10use crate::error::{Result, TensorTrainError};
11use crate::tensortrain::TensorTrain;
12use crate::traits::{AbstractTensorTrain, TTScalar};
13use crate::types::{tensor3_zeros, Tensor3, Tensor3Ops};
14use num_complex::ComplexFloat;
15use num_traits::ToPrimitive;
16use tenferro_tensor::{TensorScalar, TypedTensor};
17use tensor4all_tcicore::matrix::{mat_mul, ncols, nrows, zeros, Matrix};
18use tensor4all_tcicore::Scalar;
19use tensor4all_tcicore::{rrlu, RrLUOptions};
20use tensor4all_tensorbackend::{svd_backend, BackendLinalgScalar};
21
22/// Compute QR decomposition
23fn qr_decomp<T: TTScalar + Scalar>(matrix: &Matrix<T>) -> (Matrix<T>, Matrix<T>) {
24    let options = RrLUOptions {
25        max_rank: ncols(matrix).min(nrows(matrix)),
26        rel_tol: 0.0,
27        abs_tol: 0.0,
28        left_orthogonal: true,
29    };
30    let lu = rrlu(matrix, Some(options)).expect("rrlu failed in QR decomposition");
31    (lu.left(true), lu.right(true))
32}
33
34fn typed_tensor_to_matrix<T>(tensor: &TypedTensor<T>, op: &'static str) -> Result<Matrix<T>>
35where
36    T: TTScalar + Scalar + Default,
37{
38    if tensor.shape.len() != 2 {
39        return Err(TensorTrainError::InvalidOperation {
40            message: format!(
41                "{op} returned rank-{} tensor, expected matrix",
42                tensor.shape.len()
43            ),
44        });
45    }
46
47    let rows = tensor.shape[0];
48    let cols = tensor.shape[1];
49    let data = tensor_to_row_major_vec(tensor);
50
51    let mut matrix = zeros(rows, cols);
52    for i in 0..rows {
53        for j in 0..cols {
54            matrix[[i, j]] = data[i * cols + j];
55        }
56    }
57    Ok(matrix)
58}
59
60fn typed_real_values_to_f64<R>(tensor: &TypedTensor<R>, op: &'static str) -> Result<Vec<f64>>
61where
62    R: TensorScalar + ToPrimitive,
63{
64    let data = tensor_to_row_major_vec(tensor);
65    data.iter()
66        .map(|value| {
67            value
68                .to_f64()
69                .ok_or_else(|| TensorTrainError::InvalidOperation {
70                    message: format!(
71                        "{op} returned a singular value that cannot be represented as f64"
72                    ),
73                })
74        })
75        .collect()
76}
77
78fn svd_factorize_right_matrix<T>(matrix: &Matrix<T>) -> Result<(Matrix<T>, DiagMatrix, Matrix<T>)>
79where
80    T: TTScalar + Scalar + Default + ComplexFloat + BackendLinalgScalar + Copy + 'static,
81    <T as TensorScalar>::Real: TensorScalar + ToPrimitive,
82{
83    let rows = nrows(matrix);
84    let cols = ncols(matrix);
85    if rows == 0 || cols == 0 {
86        return Err(TensorTrainError::InvalidOperation {
87            message: "Cannot compute Vidal singular values for an empty bond matrix".to_string(),
88        });
89    }
90
91    let mut data = Vec::with_capacity(rows * cols);
92    for i in 0..rows {
93        for j in 0..cols {
94            data.push(matrix[[i, j]]);
95        }
96    }
97
98    let typed = typed_tensor_from_row_major_slice(&data, &[rows, cols]);
99    let decomp = svd_backend(&typed).map_err(|e| TensorTrainError::InvalidOperation {
100        message: format!("Failed to compute Vidal bond SVD: {e}"),
101    })?;
102
103    let u = typed_tensor_to_matrix(&decomp.u, "svd.u")?;
104    let vt = typed_tensor_to_matrix(&decomp.vt, "svd.vt")?;
105    let singular_values = typed_real_values_to_f64(&decomp.s, "svd.s")?;
106    let rank = singular_values.len();
107
108    let mut left_scaled = zeros(rows, rank);
109    for i in 0..rows {
110        for j in 0..rank {
111            left_scaled[[i, j]] = u[[i, j]] * T::from_f64(singular_values[j]);
112        }
113    }
114
115    Ok((left_scaled, singular_values, vt))
116}
117
118/// Convert Tensor3 to Matrix with left dimensions flattened
119fn tensor3_to_left_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
120    let left_dim = tensor.left_dim();
121    let site_dim = tensor.site_dim();
122    let right_dim = tensor.right_dim();
123    let rows = left_dim * site_dim;
124    let cols = right_dim;
125
126    let mut mat = zeros(rows, cols);
127    for l in 0..left_dim {
128        for s in 0..site_dim {
129            for r in 0..right_dim {
130                mat[[l * site_dim + s, r]] = *tensor.get3(l, s, r);
131            }
132        }
133    }
134    mat
135}
136
137/// Convert Tensor3 to Matrix with right dimensions flattened
138fn tensor3_to_right_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
139    let left_dim = tensor.left_dim();
140    let site_dim = tensor.site_dim();
141    let right_dim = tensor.right_dim();
142    let rows = left_dim;
143    let cols = site_dim * right_dim;
144
145    let mut mat = zeros(rows, cols);
146    for l in 0..left_dim {
147        for s in 0..site_dim {
148            for r in 0..right_dim {
149                mat[[l, s * right_dim + r]] = *tensor.get3(l, s, r);
150            }
151        }
152    }
153    mat
154}
155
/// Diagonal matrix type (stored as vector of diagonal elements).
///
/// Each entry represents one singular value on the diagonal. The length
/// equals the bond dimension at that link. An empty vector denotes a bond
/// with no stored singular values (treated as identity by consumers).
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::DiagMatrix;
///
/// let diag: DiagMatrix = vec![1.0, 0.5, 0.25];
/// assert_eq!(diag.len(), 3);
/// assert!((diag[0] - 1.0).abs() < 1e-15);
/// ```
pub type DiagMatrix = Vec<f64>;
171
/// Vidal Tensor Train representation
///
/// Stores the tensor train in Vidal canonical form where:
/// - Site tensors are stored separately from singular values
/// - Singular values are stored as diagonal matrices between sites
///
/// This form is useful for algorithms that need to apply local operations
/// and maintain canonical form efficiently.
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, VidalTensorTrain};
///
/// // Build a simple tensor train and convert to Vidal form.
/// let tt = TensorTrain::<f64>::constant(&[2, 3], 1.0);
/// let vidal = VidalTensorTrain::from_tensor_train(&tt).unwrap();
///
/// assert_eq!(vidal.len(), 2);
///
/// // Converting back preserves tensor train values.
/// let tt2 = vidal.to_tensor_train();
/// let val = tt2.evaluate(&[0, 1]).unwrap();
/// assert!((val - 1.0).abs() < 1e-12);
///
/// // The sum is also preserved: 1.0 * 2 * 3 = 6.0
/// assert!((tt2.sum() - 6.0).abs() < 1e-10);
/// ```
#[derive(Debug, Clone)]
pub struct VidalTensorTrain<T: TTScalar> {
    /// Site tensors (unscaled, i.e. singular values divided out)
    tensors: Vec<Tensor3<T>>,
    /// Singular values between sites (length = n-1; entries outside the
    /// active partition remain empty)
    singular_values: Vec<DiagMatrix>,
    /// Active partition range (the site span brought into canonical form)
    partition: Range<usize>,
}
209
210impl<T: TTScalar + Scalar + Default> VidalTensorTrain<T> {
    /// Create a VidalTensorTrain from a regular TensorTrain
    ///
    /// Canonicalizes the whole train, i.e. uses the full range `0..tt.len()`
    /// as the partition.
    ///
    /// # Errors
    ///
    /// Propagates any error from
    /// [`Self::from_tensor_train_with_partition`].
    pub fn from_tensor_train(tt: &TensorTrain<T>) -> Result<Self>
    where
        T: ComplexFloat + BackendLinalgScalar + Copy + 'static,
        <T as TensorScalar>::Real: TensorScalar + ToPrimitive,
    {
        Self::from_tensor_train_with_partition(tt, 0..tt.len())
    }
219
    /// Create a VidalTensorTrain with a specific partition
    ///
    /// Only sites inside `partition` are brought into Vidal canonical form:
    /// a left QR sweep orthogonalizes them, a right SVD sweep recovers the
    /// Schmidt coefficients on each interior bond, and finally the singular
    /// values are divided back out of the site tensors. Sites outside the
    /// partition are copied unchanged and their bonds keep empty singular
    /// value vectors.
    ///
    /// # Errors
    ///
    /// Returns [`TensorTrainError::InvalidOperation`] when `partition.end`
    /// exceeds the tensor train length, or when a bond SVD fails inside the
    /// right sweep.
    pub fn from_tensor_train_with_partition(
        tt: &TensorTrain<T>,
        partition: Range<usize>,
    ) -> Result<Self>
    where
        T: ComplexFloat + BackendLinalgScalar + Copy + 'static,
        <T as TensorScalar>::Real: TensorScalar + ToPrimitive,
    {
        let n = tt.len();
        if n == 0 {
            // Empty train: nothing to canonicalize.
            return Ok(Self {
                tensors: Vec::new(),
                singular_values: Vec::new(),
                partition: 0..0,
            });
        }

        if partition.end > n {
            return Err(TensorTrainError::InvalidOperation {
                message: format!(
                    "Partition end {} exceeds tensor train length {}",
                    partition.end, n
                ),
            });
        }

        let mut tensors: Vec<Tensor3<T>> = tt.site_tensors().to_vec();
        // One (initially empty) singular-value vector per bond.
        let mut singular_values: Vec<DiagMatrix> = vec![Vec::new(); n - 1];

        // Left sweep: QR decomposition to make left-orthogonal
        for i in partition.start..partition.end.saturating_sub(1) {
            let left_dim = tensors[i].left_dim();
            let site_dim = tensors[i].site_dim();

            let mat = tensor3_to_left_matrix(&tensors[i]);
            let (q, r) = qr_decomp(&mat);

            let new_bond_dim = ncols(&q);

            // Update current tensor with Q
            let mut new_tensor = tensor3_zeros(left_dim, site_dim, new_bond_dim);
            for l in 0..left_dim {
                for s in 0..site_dim {
                    for b in 0..new_bond_dim {
                        let row = l * site_dim + s;
                        // Guard against a Q smaller than the target shape.
                        if row < nrows(&q) && b < ncols(&q) {
                            new_tensor.set3(l, s, b, q[[row, b]]);
                        }
                    }
                }
            }
            tensors[i] = new_tensor;

            // Contract R with next tensor
            let next_site_dim = tensors[i + 1].site_dim();
            let next_right_dim = tensors[i + 1].right_dim();
            let next_mat = tensor3_to_right_matrix(&tensors[i + 1]);

            let contracted = mat_mul(&r, &next_mat);

            let mut new_next_tensor = tensor3_zeros(new_bond_dim, next_site_dim, next_right_dim);
            for l in 0..new_bond_dim {
                for s in 0..next_site_dim {
                    for r_idx in 0..next_right_dim {
                        new_next_tensor.set3(
                            l,
                            s,
                            r_idx,
                            contracted[[l, s * next_right_dim + r_idx]],
                        );
                    }
                }
            }
            tensors[i + 1] = new_next_tensor;
        }

        // Right sweep: true bond SVD to recover Schmidt coefficients
        for i in (partition.start + 1..partition.end).rev() {
            let site_dim = tensors[i].site_dim();
            let right_dim = tensors[i].right_dim();

            let mat = tensor3_to_right_matrix(&tensors[i]);
            let (us, sv, vt) = svd_factorize_right_matrix(&mat)?;

            let new_bond_dim = nrows(&vt);

            // Store the Schmidt coefficients on the bond left of site i.
            singular_values[i - 1] = sv;

            // Update current tensor with V^H (right-orthogonal)
            let mut new_tensor = tensor3_zeros(new_bond_dim, site_dim, right_dim);
            for l in 0..new_bond_dim {
                for s in 0..site_dim {
                    for r in 0..right_dim {
                        if l < nrows(&vt) {
                            new_tensor.set3(l, s, r, vt[[l, s * right_dim + r]]);
                        }
                    }
                }
            }
            tensors[i] = new_tensor;

            // Contract the previous tensor with U * S to preserve the TT values.
            if i > partition.start {
                let prev_left_dim = tensors[i - 1].left_dim();
                let prev_site_dim = tensors[i - 1].site_dim();
                let prev_mat = tensor3_to_left_matrix(&tensors[i - 1]);
                let contracted = mat_mul(&prev_mat, &us);

                let mut new_prev_tensor = tensor3_zeros(prev_left_dim, prev_site_dim, new_bond_dim);
                for l in 0..prev_left_dim {
                    for s in 0..prev_site_dim {
                        for r in 0..new_bond_dim {
                            if l * prev_site_dim + s < nrows(&contracted) && r < ncols(&contracted)
                            {
                                new_prev_tensor.set3(
                                    l,
                                    s,
                                    r,
                                    contracted[[l * prev_site_dim + s, r]],
                                );
                            }
                        }
                    }
                }
                tensors[i - 1] = new_prev_tensor;
            }
        }

        // Divide out singular values from tensors to get Vidal form
        for i in partition.start..partition.end.saturating_sub(1) {
            if singular_values[i].is_empty() {
                continue;
            }

            let site_dim = tensors[i].site_dim();
            let right_dim = tensors[i].right_dim();
            let left_dim = tensors[i].left_dim();

            let mut new_tensor = tensor3_zeros(left_dim, site_dim, right_dim);
            for l in 0..left_dim {
                for s in 0..site_dim {
                    for r in 0..right_dim {
                        let val = *tensors[i].get3(l, s, r);
                        // Skip division by (near-)zero singular values to
                        // avoid blowing up the tensor entries.
                        let sv = if r < singular_values[i].len() && singular_values[i][r] > 1e-15 {
                            singular_values[i][r]
                        } else {
                            1.0
                        };
                        new_tensor.set3(l, s, r, val / T::from_f64(sv));
                    }
                }
            }
            tensors[i] = new_tensor;
        }

        Ok(Self {
            tensors,
            singular_values,
            partition,
        })
    }
382
383    /// Create a VidalTensorTrain with given tensors and singular values
384    pub fn new(tensors: Vec<Tensor3<T>>, singular_values: Vec<DiagMatrix>) -> Result<Self> {
385        let n = tensors.len();
386        if n == 0 {
387            return Ok(Self {
388                tensors: Vec::new(),
389                singular_values: Vec::new(),
390                partition: 0..0,
391            });
392        }
393
394        if singular_values.len() != n - 1 {
395            return Err(TensorTrainError::InvalidOperation {
396                message: format!(
397                    "Expected {} singular value vectors, got {}",
398                    n - 1,
399                    singular_values.len()
400                ),
401            });
402        }
403
404        Ok(Self {
405            tensors,
406            singular_values,
407            partition: 0..n,
408        })
409    }
410
    /// Get the singular values between sites i and i+1
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of bounds for the stored singular values.
    pub fn singular_values(&self, i: usize) -> &DiagMatrix {
        &self.singular_values[i]
    }

    /// Get all singular value matrices
    ///
    /// The returned slice has one entry per bond (length `len() - 1` for a
    /// non-empty train).
    pub fn all_singular_values(&self) -> &[DiagMatrix] {
        &self.singular_values
    }

    /// Get the partition range
    ///
    /// The range of sites that was brought into canonical form.
    pub fn partition(&self) -> &Range<usize> {
        &self.partition
    }

    /// Get mutable access to site tensors
    ///
    /// NOTE(review): mutating a tensor does not update the stored singular
    /// values; the caller is responsible for keeping them consistent.
    pub fn site_tensors_mut(&mut self) -> &mut [Tensor3<T>] {
        &mut self.tensors
    }

    /// Get mutable access to singular values
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of bounds for the stored singular values.
    pub fn singular_values_mut(&mut self, i: usize) -> &mut DiagMatrix {
        &mut self.singular_values[i]
    }
435
436    /// Convert to a regular TensorTrain
437    pub fn to_tensor_train(&self) -> TensorTrain<T> {
438        let n = self.len();
439        if n == 0 {
440            return TensorTrain::from_tensors_unchecked(Vec::new());
441        }
442
443        let mut tensors = Vec::with_capacity(n);
444
445        for i in 0..n - 1 {
446            let tensor = &self.tensors[i];
447            let left_dim = tensor.left_dim();
448            let site_dim = tensor.site_dim();
449            let right_dim = tensor.right_dim();
450
451            // Multiply by singular values on the right
452            let mut new_tensor = tensor3_zeros(left_dim, site_dim, right_dim);
453            for l in 0..left_dim {
454                for s in 0..site_dim {
455                    for r in 0..right_dim {
456                        let val = *tensor.get3(l, s, r);
457                        let sv = if r < self.singular_values[i].len() {
458                            self.singular_values[i][r]
459                        } else {
460                            1.0
461                        };
462                        new_tensor.set3(l, s, r, val * T::from_f64(sv));
463                    }
464                }
465            }
466            tensors.push(new_tensor);
467        }
468
469        // Last tensor is unchanged
470        tensors.push(self.tensors[n - 1].clone());
471
472        TensorTrain::from_tensors_unchecked(tensors)
473    }
474}
475
// Expose the unscaled site tensors through the common tensor-train trait.
// Note: these tensors do NOT include the singular values; use
// `to_tensor_train` to obtain a value-equivalent plain tensor train.
impl<T: TTScalar + Scalar + Default> AbstractTensorTrain<T> for VidalTensorTrain<T> {
    /// Number of sites in the train.
    fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Borrow the (unscaled) site tensor at index `i`.
    fn site_tensor(&self, i: usize) -> &Tensor3<T> {
        &self.tensors[i]
    }

    /// Borrow all (unscaled) site tensors.
    fn site_tensors(&self) -> &[Tensor3<T>] {
        &self.tensors
    }
}
489
/// Inverse Tensor Train representation
///
/// Similar to VidalTensorTrain but stores inverse singular values instead.
/// This is useful for algorithms that need to efficiently apply inverse
/// operations during local updates.
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, InverseTensorTrain};
///
/// // Build a simple tensor train and convert to inverse form.
/// let tt = TensorTrain::<f64>::constant(&[2, 3], 1.0);
/// let inv = InverseTensorTrain::from_tensor_train(&tt).unwrap();
///
/// assert_eq!(inv.len(), 2);
///
/// // Converting back preserves tensor train values.
/// let tt2 = inv.to_tensor_train();
/// let val = tt2.evaluate(&[0, 1]).unwrap();
/// assert!((val - 1.0).abs() < 1e-12);
///
/// // The sum is also preserved: 1.0 * 2 * 3 = 6.0
/// assert!((tt2.sum() - 6.0).abs() < 1e-10);
/// ```
#[derive(Debug, Clone)]
pub struct InverseTensorTrain<T: TTScalar> {
    /// Site tensors, each scaled by the singular values of BOTH adjacent
    /// bonds (see `from_vidal`)
    tensors: Vec<Tensor3<T>>,
    /// Inverse singular values between sites (length = n-1); zero entries
    /// mark singular values too small to invert
    inverse_singular_values: Vec<DiagMatrix>,
    /// Active partition range (inherited from the source Vidal form)
    partition: Range<usize>,
}
524
525impl<T: TTScalar + Scalar + Default> InverseTensorTrain<T> {
    /// Create an InverseTensorTrain from a VidalTensorTrain
    ///
    /// Each site tensor absorbs the singular values of both neighboring
    /// bonds (the first/last absorb only their single neighbor), and the
    /// inverse singular values are stored on the bonds. Contracting
    /// `T_0 · S_0^{-1} · T_1 · S_1^{-1} · …` therefore reproduces the
    /// original tensor train values.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` is kept for interface
    /// consistency with the other constructors.
    pub fn from_vidal(vidal: &VidalTensorTrain<T>) -> Result<Self> {
        let n = vidal.len();
        if n == 0 {
            return Ok(Self {
                tensors: Vec::new(),
                inverse_singular_values: Vec::new(),
                partition: 0..0,
            });
        }

        let mut tensors = Vec::with_capacity(n);

        // First tensor: multiply by S[0] on the right
        if n > 1 && !vidal.singular_values[0].is_empty() {
            let tensor = &vidal.tensors[0];
            let left_dim = tensor.left_dim();
            let site_dim = tensor.site_dim();
            let right_dim = tensor.right_dim();

            let mut new_tensor = tensor3_zeros(left_dim, site_dim, right_dim);
            for l in 0..left_dim {
                for s in 0..site_dim {
                    for r in 0..right_dim {
                        let val = *tensor.get3(l, s, r);
                        // Missing entries are treated as identity (1.0).
                        let sv = if r < vidal.singular_values[0].len() {
                            vidal.singular_values[0][r]
                        } else {
                            1.0
                        };
                        new_tensor.set3(l, s, r, val * T::from_f64(sv));
                    }
                }
            }
            tensors.push(new_tensor);
        } else {
            // Single-site train or empty bond: nothing to absorb.
            tensors.push(vidal.tensors[0].clone());
        }

        // Middle tensors: multiply by S[i-1] on left and S[i] on right
        for i in 1..n - 1 {
            let tensor = &vidal.tensors[i];
            let left_dim = tensor.left_dim();
            let site_dim = tensor.site_dim();
            let right_dim = tensor.right_dim();

            let mut new_tensor = tensor3_zeros(left_dim, site_dim, right_dim);
            for l in 0..left_dim {
                for s in 0..site_dim {
                    for r in 0..right_dim {
                        let val = *tensor.get3(l, s, r);
                        let sv_left = if l < vidal.singular_values[i - 1].len() {
                            vidal.singular_values[i - 1][l]
                        } else {
                            1.0
                        };
                        let sv_right = if r < vidal.singular_values[i].len() {
                            vidal.singular_values[i][r]
                        } else {
                            1.0
                        };
                        new_tensor.set3(
                            l,
                            s,
                            r,
                            val * T::from_f64(sv_left) * T::from_f64(sv_right),
                        );
                    }
                }
            }
            tensors.push(new_tensor);
        }

        // Last tensor: multiply by S[n-2] on left
        if n > 1 {
            let tensor = &vidal.tensors[n - 1];
            let left_dim = tensor.left_dim();
            let site_dim = tensor.site_dim();
            let right_dim = tensor.right_dim();

            let mut new_tensor = tensor3_zeros(left_dim, site_dim, right_dim);
            for l in 0..left_dim {
                for s in 0..site_dim {
                    for r in 0..right_dim {
                        let val = *tensor.get3(l, s, r);
                        let sv = if l < vidal.singular_values[n - 2].len() {
                            vidal.singular_values[n - 2][l]
                        } else {
                            1.0
                        };
                        new_tensor.set3(l, s, r, val * T::from_f64(sv));
                    }
                }
            }
            tensors.push(new_tensor);
        }

        // Compute inverse singular values; values below the 1e-15 threshold
        // are mapped to 0.0 rather than producing huge/infinite inverses.
        let inverse_singular_values: Vec<DiagMatrix> = vidal
            .singular_values
            .iter()
            .map(|sv| {
                sv.iter()
                    .map(|&v| if v.abs() > 1e-15 { 1.0 / v } else { 0.0 })
                    .collect()
            })
            .collect();

        Ok(Self {
            tensors,
            inverse_singular_values,
            partition: vidal.partition.clone(),
        })
    }
640
641    /// Create an InverseTensorTrain from a regular TensorTrain
642    pub fn from_tensor_train(tt: &TensorTrain<T>) -> Result<Self>
643    where
644        T: ComplexFloat + BackendLinalgScalar + Copy + 'static,
645        <T as TensorScalar>::Real: TensorScalar + ToPrimitive,
646    {
647        let vidal = VidalTensorTrain::from_tensor_train(tt)?;
648        Self::from_vidal(&vidal)
649    }
650
    /// Get the inverse singular values between sites i and i+1
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of bounds for the stored inverse singular values.
    pub fn inverse_singular_values(&self, i: usize) -> &DiagMatrix {
        &self.inverse_singular_values[i]
    }

    /// Get all inverse singular value matrices
    ///
    /// The returned slice has one entry per bond (length `len() - 1` for a
    /// non-empty train).
    pub fn all_inverse_singular_values(&self) -> &[DiagMatrix] {
        &self.inverse_singular_values
    }

    /// Get the partition range
    ///
    /// Inherited from the Vidal form this train was built from.
    pub fn partition(&self) -> &Range<usize> {
        &self.partition
    }

    /// Get mutable access to site tensors
    ///
    /// NOTE(review): mutating a tensor does not update the stored inverse
    /// singular values; the caller must keep them consistent.
    pub fn site_tensors_mut(&mut self) -> &mut [Tensor3<T>] {
        &mut self.tensors
    }
670
671    /// Set two adjacent site tensors along with their inverse singular values
672    pub fn set_two_site_tensors(
673        &mut self,
674        i: usize,
675        tensor1: Tensor3<T>,
676        inv_sv: DiagMatrix,
677        tensor2: Tensor3<T>,
678    ) -> Result<()> {
679        if i >= self.len() - 1 {
680            return Err(TensorTrainError::InvalidOperation {
681                message: format!(
682                    "Cannot set two-site tensors at site {} (max {})",
683                    i,
684                    self.len() - 2
685                ),
686            });
687        }
688
689        self.tensors[i] = tensor1;
690        self.inverse_singular_values[i] = inv_sv;
691        self.tensors[i + 1] = tensor2;
692        Ok(())
693    }
694
695    /// Convert to a regular TensorTrain
696    pub fn to_tensor_train(&self) -> TensorTrain<T> {
697        let n = self.len();
698        if n == 0 {
699            return TensorTrain::from_tensors_unchecked(Vec::new());
700        }
701
702        let mut tensors = Vec::with_capacity(n);
703
704        // All tensors except last: multiply by inverse singular values on right
705        for i in 0..n - 1 {
706            let tensor = &self.tensors[i];
707            let left_dim = tensor.left_dim();
708            let site_dim = tensor.site_dim();
709            let right_dim = tensor.right_dim();
710
711            let mut new_tensor = tensor3_zeros(left_dim, site_dim, right_dim);
712            for l in 0..left_dim {
713                for s in 0..site_dim {
714                    for r in 0..right_dim {
715                        let val = *tensor.get3(l, s, r);
716                        let inv_sv = if r < self.inverse_singular_values[i].len() {
717                            self.inverse_singular_values[i][r]
718                        } else {
719                            1.0
720                        };
721                        new_tensor.set3(l, s, r, val * T::from_f64(inv_sv));
722                    }
723                }
724            }
725            tensors.push(new_tensor);
726        }
727
728        // Last tensor is unchanged
729        tensors.push(self.tensors[n - 1].clone());
730
731        TensorTrain::from_tensors_unchecked(tensors)
732    }
733}
734
// Expose the stored (bond-scaled) site tensors through the common trait.
// Note: these tensors include the absorbed singular values; use
// `to_tensor_train` to obtain a value-equivalent plain tensor train.
impl<T: TTScalar + Scalar + Default> AbstractTensorTrain<T> for InverseTensorTrain<T> {
    /// Number of sites in the train.
    fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Borrow the (scaled) site tensor at index `i`.
    fn site_tensor(&self, i: usize) -> &Tensor3<T> {
        &self.tensors[i]
    }

    /// Borrow all (scaled) site tensors.
    fn site_tensors(&self) -> &[Tensor3<T>] {
        &self.tensors
    }
}
748
749#[cfg(test)]
750mod tests;