Skip to main content

tensor4all_simplett/
canonical.rs

//! Canonical forms for tensor trains
//!
//! This module provides tensor train representations with canonical forms:
//! - `SiteTensorTrain`: center-canonical form where tensors left of the center are
//!   left-orthogonal and tensors right of the center are right-orthogonal.
//! - `center_canonicalize`: in-place routine that brings a slice of site tensors
//!   into the same center-canonical form.
6
7use std::ops::Range;
8
9use crate::error::{Result, TensorTrainError};
10use crate::tensortrain::TensorTrain;
11use crate::traits::{AbstractTensorTrain, TTScalar};
12use crate::types::{tensor3_zeros, Tensor3, Tensor3Ops};
13use tensor4all_tcicore::matrix::{mat_mul, ncols, nrows, transpose, zeros, Matrix};
14use tensor4all_tcicore::Scalar;
15use tensor4all_tcicore::{rrlu, RrLUOptions};
16
17/// Compute QR decomposition using rank-revealing LU with left-orthogonal output
18fn qr_decomp<T: TTScalar + Scalar>(matrix: &Matrix<T>) -> (Matrix<T>, Matrix<T>) {
19    let options = RrLUOptions {
20        max_rank: ncols(matrix).min(nrows(matrix)),
21        rel_tol: 0.0, // No truncation
22        abs_tol: 0.0,
23        left_orthogonal: true,
24    };
25    let lu = rrlu(matrix, Some(options)).expect("rrlu failed in QR decomposition");
26    (lu.left(true), lu.right(true))
27}
28
29/// Compute LQ decomposition (transpose, QR, transpose)
30fn lq_decomp<T: TTScalar + Scalar>(matrix: &Matrix<T>) -> (Matrix<T>, Matrix<T>) {
31    let at = transpose(matrix);
32    let (qt, lt) = qr_decomp(&at);
33    (transpose(&lt), transpose(&qt))
34}
35
36/// Convert Tensor3 to Matrix with left dimensions flattened
37fn tensor3_to_left_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
38    let left_dim = tensor.left_dim();
39    let site_dim = tensor.site_dim();
40    let right_dim = tensor.right_dim();
41    let rows = left_dim * site_dim;
42    let cols = right_dim;
43
44    let mut mat = zeros(rows, cols);
45    for l in 0..left_dim {
46        for s in 0..site_dim {
47            for r in 0..right_dim {
48                mat[[l * site_dim + s, r]] = *tensor.get3(l, s, r);
49            }
50        }
51    }
52    mat
53}
54
55/// Convert Tensor3 to Matrix with right dimensions flattened
56fn tensor3_to_right_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
57    let left_dim = tensor.left_dim();
58    let site_dim = tensor.site_dim();
59    let right_dim = tensor.right_dim();
60    let rows = left_dim;
61    let cols = site_dim * right_dim;
62
63    let mut mat = zeros(rows, cols);
64    for l in 0..left_dim {
65        for s in 0..site_dim {
66            for r in 0..right_dim {
67                mat[[l, s * right_dim + r]] = *tensor.get3(l, s, r);
68            }
69        }
70    }
71    mat
72}
73
74/// Site Tensor Train with center canonical form
75///
76/// A tensor train where:
77/// - Tensors at indices < center are left-orthogonal
78/// - Tensors at indices > center are right-orthogonal
79/// - The tensor at the center index is the "center" tensor
80///
81/// # Examples
82///
83/// ```
84/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, SiteTensorTrain};
85///
86/// // Build a constant tensor train and convert to center-canonical form.
87/// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
88/// let stt = SiteTensorTrain::from_tensor_train(&tt, 1).unwrap();
89///
90/// // The center is at site 1.
91/// assert_eq!(stt.center(), 1);
92/// assert_eq!(stt.len(), 3);
93///
94/// // Converting back preserves tensor train values.
95/// let tt2 = stt.to_tensor_train();
96/// let val = tt2.evaluate(&[0, 1, 2]).unwrap();
97/// assert!((val - 1.0).abs() < 1e-12);
98/// ```
99#[derive(Debug, Clone)]
100pub struct SiteTensorTrain<T: TTScalar> {
101    /// Site tensors
102    tensors: Vec<Tensor3<T>>,
103    /// Center index (0-based)
104    center: usize,
105    /// Active partition range
106    partition: Range<usize>,
107}
108
109impl<T: TTScalar + Scalar + Default> SiteTensorTrain<T> {
110    /// Create a new SiteTensorTrain from tensors with specified center
111    pub fn new(tensors: Vec<Tensor3<T>>, center: usize) -> Result<Self> {
112        let n = tensors.len();
113        if n == 0 {
114            return Err(TensorTrainError::Empty);
115        }
116        if center >= n {
117            return Err(TensorTrainError::InvalidOperation {
118                message: format!("Center {} is out of range for {} tensors", center, n),
119            });
120        }
121
122        // Validate dimensions
123        for i in 0..n.saturating_sub(1) {
124            if tensors[i].right_dim() != tensors[i + 1].left_dim() {
125                return Err(TensorTrainError::DimensionMismatch { site: i });
126            }
127        }
128
129        let mut result = Self {
130            tensors,
131            center,
132            partition: 0..n,
133        };
134        result.canonicalize();
135        Ok(result)
136    }
137
138    /// Create from TensorTrain with specified center
139    pub fn from_tensor_train(tt: &TensorTrain<T>, center: usize) -> Result<Self> {
140        let tensors = tt.site_tensors().to_vec();
141        Self::new(tensors, center)
142    }
143
144    /// Get the center index
145    pub fn center(&self) -> usize {
146        self.center
147    }
148
149    /// Get the partition range
150    pub fn partition(&self) -> &Range<usize> {
151        &self.partition
152    }
153
154    /// Get mutable access to site tensors
155    pub fn site_tensors_mut(&mut self) -> &mut [Tensor3<T>] {
156        &mut self.tensors
157    }
158
159    /// Canonicalize the tensor train around the center
160    fn canonicalize(&mut self) {
161        let n = self.len();
162        if n <= 1 {
163            return;
164        }
165
166        // Left sweep: make tensors [0..center) left-orthogonal
167        for i in 0..self.center {
168            self.make_left_orthogonal(i);
169        }
170
171        // Right sweep: make tensors (center..n] right-orthogonal
172        for i in (self.center + 1..n).rev() {
173            self.make_right_orthogonal(i);
174        }
175    }
176
177    /// Make tensor at site i left-orthogonal, pushing R to site i+1
178    fn make_left_orthogonal(&mut self, i: usize) {
179        if i >= self.len() - 1 {
180            return;
181        }
182
183        let left_dim = self.tensors[i].left_dim();
184        let site_dim = self.tensors[i].site_dim();
185
186        // Reshape to (left_dim * site_dim, right_dim)
187        let mat = tensor3_to_left_matrix(&self.tensors[i]);
188        let (q, r) = qr_decomp(&mat);
189
190        let new_bond_dim = ncols(&q);
191
192        // Update current tensor with Q
193        let mut new_tensor = tensor3_zeros(left_dim, site_dim, new_bond_dim);
194        for l in 0..left_dim {
195            for s in 0..site_dim {
196                for b in 0..new_bond_dim {
197                    let row = l * site_dim + s;
198                    if row < nrows(&q) && b < ncols(&q) {
199                        new_tensor.set3(l, s, b, q[[row, b]]);
200                    }
201                }
202            }
203        }
204        self.tensors[i] = new_tensor;
205
206        // Contract R with next tensor
207        let next_site_dim = self.tensors[i + 1].site_dim();
208        let next_right_dim = self.tensors[i + 1].right_dim();
209        let next_mat = tensor3_to_right_matrix(&self.tensors[i + 1]);
210
211        // R * next_mat
212        let contracted = mat_mul(&r, &next_mat);
213
214        // Update next tensor
215        let mut new_next_tensor = tensor3_zeros(new_bond_dim, next_site_dim, next_right_dim);
216        for l in 0..new_bond_dim {
217            for s in 0..next_site_dim {
218                for r_idx in 0..next_right_dim {
219                    new_next_tensor.set3(l, s, r_idx, contracted[[l, s * next_right_dim + r_idx]]);
220                }
221            }
222        }
223        self.tensors[i + 1] = new_next_tensor;
224    }
225
226    /// Make tensor at site i right-orthogonal, pushing L to site i-1
227    fn make_right_orthogonal(&mut self, i: usize) {
228        if i == 0 {
229            return;
230        }
231
232        let site_dim = self.tensors[i].site_dim();
233        let right_dim = self.tensors[i].right_dim();
234
235        // Reshape to (left_dim, site_dim * right_dim)
236        let mat = tensor3_to_right_matrix(&self.tensors[i]);
237        let (l_mat, q) = lq_decomp(&mat);
238
239        let new_bond_dim = nrows(&q);
240
241        // Update current tensor with Q
242        let mut new_tensor = tensor3_zeros(new_bond_dim, site_dim, right_dim);
243        for l in 0..new_bond_dim {
244            for s in 0..site_dim {
245                for r in 0..right_dim {
246                    new_tensor.set3(l, s, r, q[[l, s * right_dim + r]]);
247                }
248            }
249        }
250        self.tensors[i] = new_tensor;
251
252        // Contract previous tensor with L
253        let prev_left_dim = self.tensors[i - 1].left_dim();
254        let prev_site_dim = self.tensors[i - 1].site_dim();
255        let prev_mat = tensor3_to_left_matrix(&self.tensors[i - 1]);
256
257        // prev_mat * L
258        let contracted = mat_mul(&prev_mat, &l_mat);
259
260        // Update previous tensor
261        let mut new_prev_tensor = tensor3_zeros(prev_left_dim, prev_site_dim, new_bond_dim);
262        for l in 0..prev_left_dim {
263            for s in 0..prev_site_dim {
264                for r in 0..new_bond_dim {
265                    new_prev_tensor.set3(l, s, r, contracted[[l * prev_site_dim + s, r]]);
266                }
267            }
268        }
269        self.tensors[i - 1] = new_prev_tensor;
270    }
271
272    /// Move the center one position to the right
273    pub fn move_center_right(&mut self) -> Result<()> {
274        if self.center >= self.len() - 1 {
275            return Err(TensorTrainError::InvalidOperation {
276                message: "Cannot move center right: already at rightmost position".to_string(),
277            });
278        }
279
280        self.make_left_orthogonal(self.center);
281        self.center += 1;
282        Ok(())
283    }
284
285    /// Move the center one position to the left
286    pub fn move_center_left(&mut self) -> Result<()> {
287        if self.center == 0 {
288            return Err(TensorTrainError::InvalidOperation {
289                message: "Cannot move center left: already at leftmost position".to_string(),
290            });
291        }
292
293        self.make_right_orthogonal(self.center);
294        self.center -= 1;
295        Ok(())
296    }
297
298    /// Move the center to a specific position
299    pub fn set_center(&mut self, new_center: usize) -> Result<()> {
300        if new_center >= self.len() {
301            return Err(TensorTrainError::InvalidOperation {
302                message: format!(
303                    "New center {} is out of range for {} tensors",
304                    new_center,
305                    self.len()
306                ),
307            });
308        }
309
310        while self.center < new_center {
311            self.move_center_right()?;
312        }
313        while self.center > new_center {
314            self.move_center_left()?;
315        }
316        Ok(())
317    }
318
319    /// Convert to a regular TensorTrain
320    pub fn to_tensor_train(&self) -> TensorTrain<T> {
321        TensorTrain::from_tensors_unchecked(self.tensors.clone())
322    }
323
324    /// Set the tensor at a specific site
325    ///
326    /// Note: This may invalidate the canonical form. Use with caution.
327    pub fn set_site_tensor(&mut self, i: usize, tensor: Tensor3<T>) {
328        self.tensors[i] = tensor;
329    }
330
331    /// Set two adjacent site tensors (useful for TEBD-like algorithms)
332    pub fn set_two_site_tensors(
333        &mut self,
334        i: usize,
335        tensor1: Tensor3<T>,
336        tensor2: Tensor3<T>,
337    ) -> Result<()> {
338        if i >= self.len() - 1 {
339            return Err(TensorTrainError::InvalidOperation {
340                message: format!(
341                    "Cannot set two-site tensors at site {} (max {})",
342                    i,
343                    self.len() - 2
344                ),
345            });
346        }
347
348        self.tensors[i] = tensor1;
349        self.tensors[i + 1] = tensor2;
350        Ok(())
351    }
352}
353
354impl<T: TTScalar + Scalar + Default> AbstractTensorTrain<T> for SiteTensorTrain<T> {
355    fn len(&self) -> usize {
356        self.tensors.len()
357    }
358
359    fn site_tensor(&self, i: usize) -> &Tensor3<T> {
360        &self.tensors[i]
361    }
362
363    fn site_tensors(&self) -> &[Tensor3<T>] {
364        &self.tensors
365    }
366}
367
368/// Center canonicalize a vector of tensors in place
369///
370/// After this call, tensors at indices `< center` are left-orthogonal and
371/// tensors at indices `> center` are right-orthogonal. The overall tensor
372/// train represented by the slice is unchanged.
373///
374/// # Examples
375///
376/// ```
377/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, tensor3_zeros, Tensor3Ops};
378/// use tensor4all_simplett::canonical::center_canonicalize;
379///
380/// // Start with a simple 3-site constant TT.
381/// let tt = TensorTrain::<f64>::constant(&[2, 3, 2], 1.0);
382/// let mut tensors: Vec<_> = tt.site_tensors().to_vec();
383///
384/// // Canonicalize around site 1.
385/// center_canonicalize(&mut tensors, 1);
386///
387/// // Rebuild TT from the canonicalized tensors; values are preserved.
388/// let tt2 = TensorTrain::new(tensors).unwrap();
389/// let val = tt2.evaluate(&[0, 1, 0]).unwrap();
390/// assert!((val - 1.0).abs() < 1e-12);
391/// ```
392pub fn center_canonicalize<T: TTScalar + Scalar + Default>(
393    tensors: &mut [Tensor3<T>],
394    center: usize,
395) {
396    let n = tensors.len();
397    if n <= 1 || center >= n {
398        return;
399    }
400
401    // Left sweep: make tensors [0..center) left-orthogonal
402    for i in 0..center {
403        let left_dim = tensors[i].left_dim();
404        let site_dim = tensors[i].site_dim();
405
406        let mat = tensor3_to_left_matrix(&tensors[i]);
407        let (q, r) = qr_decomp(&mat);
408
409        let new_bond_dim = ncols(&q);
410
411        // Update current tensor
412        let mut new_tensor = tensor3_zeros(left_dim, site_dim, new_bond_dim);
413        for l in 0..left_dim {
414            for s in 0..site_dim {
415                for b in 0..new_bond_dim {
416                    let row = l * site_dim + s;
417                    if row < nrows(&q) && b < ncols(&q) {
418                        new_tensor.set3(l, s, b, q[[row, b]]);
419                    }
420                }
421            }
422        }
423        tensors[i] = new_tensor;
424
425        // Contract R with next tensor
426        if i + 1 < n {
427            let next_site_dim = tensors[i + 1].site_dim();
428            let next_right_dim = tensors[i + 1].right_dim();
429            let next_mat = tensor3_to_right_matrix(&tensors[i + 1]);
430
431            let contracted = mat_mul(&r, &next_mat);
432
433            let mut new_next_tensor = tensor3_zeros(new_bond_dim, next_site_dim, next_right_dim);
434            for l in 0..new_bond_dim {
435                for s in 0..next_site_dim {
436                    for r_idx in 0..next_right_dim {
437                        new_next_tensor.set3(
438                            l,
439                            s,
440                            r_idx,
441                            contracted[[l, s * next_right_dim + r_idx]],
442                        );
443                    }
444                }
445            }
446            tensors[i + 1] = new_next_tensor;
447        }
448    }
449
450    // Right sweep: make tensors (center..n] right-orthogonal
451    for i in (center + 1..n).rev() {
452        let site_dim = tensors[i].site_dim();
453        let right_dim = tensors[i].right_dim();
454
455        let mat = tensor3_to_right_matrix(&tensors[i]);
456        let (l_mat, q) = lq_decomp(&mat);
457
458        let new_bond_dim = nrows(&q);
459
460        // Update current tensor
461        let mut new_tensor = tensor3_zeros(new_bond_dim, site_dim, right_dim);
462        for l in 0..new_bond_dim {
463            for s in 0..site_dim {
464                for r in 0..right_dim {
465                    new_tensor.set3(l, s, r, q[[l, s * right_dim + r]]);
466                }
467            }
468        }
469        tensors[i] = new_tensor;
470
471        // Contract L with previous tensor
472        if i > 0 {
473            let prev_left_dim = tensors[i - 1].left_dim();
474            let prev_site_dim = tensors[i - 1].site_dim();
475            let prev_mat = tensor3_to_left_matrix(&tensors[i - 1]);
476
477            let contracted = mat_mul(&prev_mat, &l_mat);
478
479            let mut new_prev_tensor = tensor3_zeros(prev_left_dim, prev_site_dim, new_bond_dim);
480            for l in 0..prev_left_dim {
481                for s in 0..prev_site_dim {
482                    for r in 0..new_bond_dim {
483                        new_prev_tensor.set3(l, s, r, contracted[[l * prev_site_dim + s, r]]);
484                    }
485                }
486            }
487            tensors[i - 1] = new_prev_tensor;
488        }
489    }
490}
491
492#[cfg(test)]
493mod tests;