// tensor4all_simplett/tensortrain.rs

//! TensorTrain implementation

use crate::error::{Result, TensorTrainError};
use crate::traits::{AbstractTensorTrain, TTScalar};
use crate::types::{tensor3_zeros, Tensor3, Tensor3Ops};

/// Tensor Train (Matrix Product State) representation.
///
/// A tensor train decomposes a high-dimensional tensor `T[i0, i1, ..., i_{L-1}]`
/// into a chain of rank-3 core tensors:
///
/// ```text
/// T[i0, i1, ..., i_{L-1}] = A0[i0] * A1[i1] * ... * A_{L-1}[i_{L-1}]
/// ```
///
/// where each core `Ak` has shape `(r_{k-1}, d_k, r_k)` with:
/// - `r_k` = bond dimension (link between site `k` and `k+1`),
/// - `d_k` = physical (site) dimension at site `k`,
/// - `r_{-1} = r_{L-1} = 1` (boundary condition).
///
/// # Construction
///
/// - [`TensorTrain::constant`] -- all entries equal to a given value
/// - [`TensorTrain::zeros`] -- all entries zero
/// - [`TensorTrain::new`] -- from explicit rank-3 core tensors
///
/// # Related types
///
/// - [`CompressionOptions`](crate::CompressionOptions) -- configure compression
/// - [`TTCache`](crate::TTCache) -- cached evaluation
/// - [`SiteTensorTrain`](crate::SiteTensorTrain) -- center-canonical form
/// - [`VidalTensorTrain`](crate::VidalTensorTrain) -- Vidal canonical form
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
///
/// // Create a constant tensor train: T[i,j,k] = 3.0 for all i,j,k
/// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 3.0);
///
/// assert_eq!(tt.len(), 3);
/// assert_eq!(tt.site_dims(), vec![2, 3, 4]);
/// assert_eq!(tt.link_dims(), vec![1, 1]); // bond dim = 1 for constant
///
/// // Evaluate at a specific index
/// let val = tt.evaluate(&[0, 1, 2]).unwrap();
/// assert!((val - 3.0).abs() < 1e-12);
///
/// // Sum over all indices: 3.0 * 2 * 3 * 4 = 72.0
/// let s = tt.sum();
/// assert!((s - 72.0).abs() < 1e-10);
/// ```
#[derive(Debug, Clone)]
pub struct TensorTrain<T: TTScalar> {
    /// The core tensors that make up the tensor train, in site order.
    /// Each tensor has shape (left_bond, site_dim, right_bond).
    /// Invariant (enforced by [`TensorTrain::new`], assumed by
    /// `from_tensors_unchecked`): adjacent cores agree on the shared bond
    /// dimension, and the outermost bonds are 1.
    tensors: Vec<Tensor3<T>>,
}
60
61impl<T: TTScalar> TensorTrain<T> {
62    /// Create a new tensor train from a list of rank-3 core tensors.
63    ///
64    /// Each tensor must have shape `(left_bond, site_dim, right_bond)` where
65    /// the `right_bond` of tensor `i` equals the `left_bond` of tensor `i+1`.
66    /// The first tensor must have `left_bond = 1` and the last must have
67    /// `right_bond = 1`.
68    ///
69    /// # Errors
70    ///
71    /// Returns [`TensorTrainError::DimensionMismatch`] if adjacent bond
72    /// dimensions do not match, or [`TensorTrainError::InvalidOperation`] if
73    /// boundary dimensions are not 1.
74    ///
75    /// # Examples
76    ///
77    /// ```
78    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, Tensor3Ops, tensor3_zeros};
79    ///
80    /// // Build a 2-site TT with bond dimension 1 and site dimensions [2, 3]
81    /// let mut t0 = tensor3_zeros::<f64>(1, 2, 1);
82    /// t0.set3(0, 0, 0, 1.0);
83    /// t0.set3(0, 1, 0, 2.0);
84    ///
85    /// let mut t1 = tensor3_zeros::<f64>(1, 3, 1);
86    /// t1.set3(0, 0, 0, 10.0);
87    /// t1.set3(0, 1, 0, 20.0);
88    /// t1.set3(0, 2, 0, 30.0);
89    ///
90    /// let tt = TensorTrain::new(vec![t0, t1]).unwrap();
91    /// assert_eq!(tt.len(), 2);
92    ///
93    /// // T[0, 2] = 1.0 * 30.0 = 30.0
94    /// let val = tt.evaluate(&[0, 2]).unwrap();
95    /// assert!((val - 30.0).abs() < 1e-12);
96    /// ```
97    pub fn new(tensors: Vec<Tensor3<T>>) -> Result<Self> {
98        // Validate dimensions
99        for i in 0..tensors.len().saturating_sub(1) {
100            if tensors[i].right_dim() != tensors[i + 1].left_dim() {
101                return Err(TensorTrainError::DimensionMismatch { site: i });
102            }
103        }
104
105        // First tensor should have left_dim = 1
106        if !tensors.is_empty() && tensors[0].left_dim() != 1 {
107            return Err(TensorTrainError::InvalidOperation {
108                message: "First tensor must have left dimension 1".to_string(),
109            });
110        }
111
112        // Last tensor should have right_dim = 1
113        if !tensors.is_empty() && tensors.last().unwrap().right_dim() != 1 {
114            return Err(TensorTrainError::InvalidOperation {
115                message: "Last tensor must have right dimension 1".to_string(),
116            });
117        }
118
119        Ok(Self { tensors })
120    }
121
122    /// Create a tensor train from tensors without dimension validation
123    /// (for internal use when dimensions are known to be correct)
124    pub(crate) fn from_tensors_unchecked(tensors: Vec<Tensor3<T>>) -> Self {
125        Self { tensors }
126    }
127
128    /// Create a tensor train where every entry is zero.
129    ///
130    /// The resulting TT has bond dimension 1 at every link.
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
136    ///
137    /// let tt = TensorTrain::<f64>::zeros(&[2, 3]);
138    /// assert!((tt.evaluate(&[1, 2]).unwrap()).abs() < 1e-14);
139    /// assert!((tt.sum()).abs() < 1e-14);
140    /// ```
141    pub fn zeros(site_dims: &[usize]) -> Self {
142        let tensors: Vec<Tensor3<T>> = site_dims.iter().map(|&d| tensor3_zeros(1, d, 1)).collect();
143        Self { tensors }
144    }
145
146    /// Create a tensor train where every entry equals `value`.
147    ///
148    /// The resulting TT has bond dimension 1 at every link.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
154    ///
155    /// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 5.0);
156    ///
157    /// // Every entry is 5.0
158    /// assert!((tt.evaluate(&[0, 0, 0]).unwrap() - 5.0).abs() < 1e-12);
159    /// assert!((tt.evaluate(&[1, 2, 3]).unwrap() - 5.0).abs() < 1e-12);
160    ///
161    /// // Sum = 5.0 * 2 * 3 * 4 = 120.0
162    /// assert!((tt.sum() - 120.0).abs() < 1e-10);
163    /// ```
164    pub fn constant(site_dims: &[usize], value: T) -> Self {
165        if site_dims.is_empty() {
166            return Self {
167                tensors: Vec::new(),
168            };
169        }
170
171        let n = site_dims.len();
172        let mut tensors = Vec::with_capacity(n);
173
174        // First tensor: all ones
175        let mut first = tensor3_zeros(1, site_dims[0], 1);
176        for s in 0..site_dims[0] {
177            first.set3(0, s, 0, T::one());
178        }
179        tensors.push(first);
180
181        // Middle tensors: all ones (only if n > 2)
182        if n > 2 {
183            for &d in &site_dims[1..n - 1] {
184                let mut tensor = tensor3_zeros(1, d, 1);
185                for s in 0..d {
186                    tensor.set3(0, s, 0, T::one());
187                }
188                tensors.push(tensor);
189            }
190        }
191
192        // Last tensor: multiply by value
193        if n > 1 {
194            let mut last = tensor3_zeros(1, site_dims[n - 1], 1);
195            for s in 0..site_dims[n - 1] {
196                last.set3(0, s, 0, value);
197            }
198            tensors.push(last);
199        } else {
200            // Single site: multiply the first (and only) tensor by value
201            for s in 0..site_dims[0] {
202                let current = *tensors[0].get3(0, s, 0);
203                tensors[0].set3(0, s, 0, current * value);
204            }
205        }
206
207        Self { tensors }
208    }
209
210    /// Get mutable access to the site tensors
211    pub fn site_tensors_mut(&mut self) -> &mut [Tensor3<T>] {
212        &mut self.tensors
213    }
214
215    /// Multiply every entry of the tensor train by `factor` in place.
216    ///
217    /// Only the last core tensor is rescaled, so this is an O(d * r^2) operation
218    /// where `d` is the site dimension and `r` the bond dimension of the last site.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
224    ///
225    /// let mut tt = TensorTrain::<f64>::constant(&[2, 3], 1.0);
226    /// tt.scale(3.0);
227    /// assert!((tt.evaluate(&[0, 0]).unwrap() - 3.0).abs() < 1e-12);
228    /// assert!((tt.sum() - 18.0).abs() < 1e-10);
229    /// ```
230    pub fn scale(&mut self, factor: T) {
231        if !self.tensors.is_empty() {
232            let last = self.tensors.len() - 1;
233            let tensor = &mut self.tensors[last];
234            for l in 0..tensor.left_dim() {
235                for s in 0..tensor.site_dim() {
236                    for r in 0..tensor.right_dim() {
237                        let val = *tensor.get3(l, s, r);
238                        tensor.set3(l, s, r, val * factor);
239                    }
240                }
241            }
242        }
243    }
244
245    /// Return a new tensor train with every entry multiplied by `factor`.
246    ///
247    /// This is the non-mutating version of [`scale`](Self::scale).
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
253    ///
254    /// let tt = TensorTrain::<f64>::constant(&[2, 3], 1.0);
255    /// let tt2 = tt.scaled(4.0);
256    /// // Original is unchanged
257    /// assert!((tt.evaluate(&[0, 0]).unwrap() - 1.0).abs() < 1e-12);
258    /// // Scaled copy
259    /// assert!((tt2.evaluate(&[0, 0]).unwrap() - 4.0).abs() < 1e-12);
260    /// ```
261    pub fn scaled(&self, factor: T) -> Self {
262        let mut result = self.clone();
263        result.scale(factor);
264        result
265    }
266
267    /// Reverse the order of sites in the tensor train.
268    ///
269    /// The reversed TT satisfies `reversed.evaluate(&[i_{L-1}, ..., i_0]) ==
270    /// original.evaluate(&[i_0, ..., i_{L-1}])`.
271    ///
272    /// # Examples
273    ///
274    /// ```
275    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, Tensor3Ops, tensor3_zeros};
276    ///
277    /// let mut t0 = tensor3_zeros::<f64>(1, 2, 1);
278    /// t0.set3(0, 0, 0, 1.0);
279    /// t0.set3(0, 1, 0, 2.0);
280    /// let mut t1 = tensor3_zeros::<f64>(1, 3, 1);
281    /// t1.set3(0, 0, 0, 10.0);
282    /// t1.set3(0, 1, 0, 20.0);
283    /// t1.set3(0, 2, 0, 30.0);
284    /// let tt = TensorTrain::new(vec![t0, t1]).unwrap();
285    ///
286    /// let rev = tt.reverse();
287    /// assert_eq!(rev.site_dims(), vec![3, 2]);
288    /// // T[0, 1] = 1.0 * 10.0 = 10.0, reversed: T_rev[0, 1] should also be 10.0 (site 0->10, site 1->2)
289    /// // Original: T[1, 0] = 2.0 * 10.0 = 20.0
290    /// // Reversed: T_rev[0, 1] = 20.0
291    /// assert!((rev.evaluate(&[0, 1]).unwrap() - tt.evaluate(&[1, 0]).unwrap()).abs() < 1e-12);
292    /// ```
293    pub fn reverse(&self) -> Self {
294        let mut new_tensors = Vec::with_capacity(self.tensors.len());
295        for tensor in self.tensors.iter().rev() {
296            // Swap left and right dimensions
297            let mut new_tensor =
298                tensor3_zeros(tensor.right_dim(), tensor.site_dim(), tensor.left_dim());
299            for l in 0..tensor.left_dim() {
300                for s in 0..tensor.site_dim() {
301                    for r in 0..tensor.right_dim() {
302                        new_tensor.set3(r, s, l, *tensor.get3(l, s, r));
303                    }
304                }
305            }
306            new_tensors.push(new_tensor);
307        }
308        Self {
309            tensors: new_tensors,
310        }
311    }
312}
313
314impl<T: TTScalar> TensorTrain<T> {
315    /// Materialize the tensor train as a full dense tensor.
316    ///
317    /// Returns `(data, shape)` where `data` is a flat vector in **column-major**
318    /// order and `shape` is the site dimensions. The total number of elements
319    /// is `prod(shape)`.
320    ///
321    /// **Warning:** The full tensor can be extremely large for high-dimensional
322    /// problems. Only use this for small tensors or debugging.
323    ///
324    /// # Examples
325    ///
326    /// ```
327    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
328    ///
329    /// let tt = TensorTrain::<f64>::constant(&[2, 3], 7.0);
330    /// let (data, shape) = tt.fulltensor();
331    /// assert_eq!(shape, vec![2, 3]);
332    /// assert_eq!(data.len(), 6);
333    /// // Every element should be 7.0
334    /// assert!(data.iter().all(|&v| (v - 7.0).abs() < 1e-12));
335    /// ```
336    pub fn fulltensor(&self) -> (Vec<T>, Vec<usize>) {
337        if self.is_empty() {
338            return (Vec::new(), Vec::new());
339        }
340
341        let site_dims: Vec<usize> = self.site_dims();
342        let total_size: usize = site_dims.iter().product();
343
344        if total_size == 0 {
345            return (Vec::new(), site_dims);
346        }
347
348        // Build full tensor by iterating over all indices
349        let mut result = Vec::with_capacity(total_size);
350        let mut indices = vec![0usize; site_dims.len()];
351
352        loop {
353            // Evaluate at current indices
354            if let Ok(val) = self.evaluate(&indices) {
355                result.push(val);
356            } else {
357                result.push(T::zero());
358            }
359
360            // Increment indices in column-major order, leftmost index fastest.
361            let mut carry = true;
362            for i in 0..site_dims.len() {
363                if carry {
364                    indices[i] += 1;
365                    if indices[i] >= site_dims[i] {
366                        indices[i] = 0;
367                    } else {
368                        carry = false;
369                    }
370                }
371            }
372
373            if carry {
374                break; // All indices wrapped around
375            }
376        }
377
378        (result, site_dims)
379    }
380}
381
impl<T: TTScalar> TensorTrain<T> {
    /// Sum (trace out) selected site dimensions, returning a lower-order TT.
    ///
    /// `dims` is a slice of 0-indexed site positions to sum over. The
    /// remaining sites keep their original order. If *all* dimensions are
    /// summed, the result is a 1-site TT wrapping the scalar total.
    ///
    /// Algorithm: a single left-to-right sweep. An accumulator matrix
    /// `tprod` absorbs each summed site (as a `left x right` matrix of
    /// site-sums); when a kept site is reached, `tprod` is folded into its
    /// left bond and reset to the identity. Any leftover `tprod` after the
    /// sweep is folded into the right bond of the last kept core.
    ///
    /// # Errors
    ///
    /// Returns an error if any element of `dims` is out of range.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// // 3-site constant TT: T[i,j,k] = 1.0, dims = [2, 3, 4]
    /// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
    ///
    /// // Sum over the middle site (index 1): result has dims [2, 4]
    /// let summed = tt.partial_sum(&[1]).unwrap();
    /// assert_eq!(summed.site_dims(), vec![2, 4]);
    ///
    /// // Each remaining entry = 1.0 * 3 (summed over dim=3)
    /// let val = summed.evaluate(&[0, 0]).unwrap();
    /// assert!((val - 3.0).abs() < 1e-12);
    /// ```
    pub fn partial_sum(&self, dims: &[usize]) -> Result<TensorTrain<T>> {
        // Dense matrix helpers from the core crate; `zeros` is aliased so it
        // does not shadow `TensorTrain::zeros`.
        use tensor4all_tcicore::matrix::{mat_mul, ncols, nrows, zeros as mat_zeros};

        let n = self.len();
        if n == 0 {
            return Ok(Self {
                tensors: Vec::new(),
            });
        }

        // Validate dims up front so the sweep below never indexes out of range.
        for &d in dims {
            if d >= n {
                return Err(TensorTrainError::InvalidOperation {
                    message: format!("Dimension {} out of range (0..{})", d, n),
                });
            }
        }

        let mut result_tensors: Vec<Tensor3<T>> = Vec::new();

        // Tprod: accumulator matrix, starts as 1x1 identity
        // (matches the left boundary bond of dimension 1).
        let mut tprod = mat_zeros(1, 1);
        tprod[[0, 0]] = T::one();

        for site in 0..n {
            let t = self.site_tensor(site);
            let left_dim = t.left_dim();
            let site_dim = t.site_dim();
            let right_dim = t.right_dim();

            // NOTE(review): `contains` makes this loop O(n * |dims|); a
            // HashSet/bool mask would be O(n), but |dims| is typically tiny.
            if dims.contains(&site) {
                // Sum over site index: result is (left_dim, right_dim) matrix
                // sum(T, dims=2)[:, 1, :] in Julia
                let mut site_sum = mat_zeros(left_dim, right_dim);
                for l in 0..left_dim {
                    for r in 0..right_dim {
                        let mut acc = T::zero();
                        for s in 0..site_dim {
                            acc = acc + *t.get3(l, s, r);
                        }
                        site_sum[[l, r]] = acc;
                    }
                }
                // Tprod = Tprod * site_sum  (absorb the summed site)
                tprod = mat_mul(&tprod, &site_sum);
            } else {
                // Keep this dimension: fold Tprod into the site tensor's left bond.
                // Tprod (tprod_rows, left_dim) * T reshaped to (left_dim, site_dim * right_dim)
                let tprod_rows = nrows(&tprod);
                let mut t_reshaped = mat_zeros(left_dim, site_dim * right_dim);
                for l in 0..left_dim {
                    for s in 0..site_dim {
                        for r in 0..right_dim {
                            // Column index flattens (s, r) with r fastest.
                            t_reshaped[[l, s * right_dim + r]] = *t.get3(l, s, r);
                        }
                    }
                }
                let product = mat_mul(&tprod, &t_reshaped);

                // Reshape product (tprod_rows, site_dim * right_dim)
                // into tensor (tprod_rows, site_dim, right_dim)
                // using the same (s, r) flattening as above.
                let mut new_tensor = tensor3_zeros(tprod_rows, site_dim, right_dim);
                for l in 0..tprod_rows {
                    for s in 0..site_dim {
                        for r in 0..right_dim {
                            new_tensor.set3(l, s, r, product[[l, s * right_dim + r]]);
                        }
                    }
                }
                result_tensors.push(new_tensor);

                // Reset Tprod to identity of size right_dim so subsequent
                // summed sites accumulate relative to this core's right bond.
                tprod = mat_zeros(right_dim, right_dim);
                for i in 0..right_dim {
                    tprod[[i, i]] = T::one();
                }
            }
        }

        if result_tensors.is_empty() {
            // All dims summed → return 1-site TT wrapping scalar.
            // tprod is 1×1 here because both boundary bonds are 1.
            let scalar = tprod[[0, 0]];
            let mut t = tensor3_zeros(1, 1, 1);
            t.set3(0, 0, 0, scalar);
            return TensorTrain::new(vec![t]);
        }

        // Contract final Tprod into last result tensor (handles trailing
        // summed sites; if the last site was kept, Tprod is the identity).
        let last = result_tensors.last().unwrap();
        let last_left = last.left_dim();
        let last_site = last.site_dim();
        let last_right = last.right_dim();
        let tprod_cols = ncols(&tprod);

        // Reshape last tensor to (last_left * last_site, last_right),
        // flattening (l, s) with s fastest.
        let mut last_mat = mat_zeros(last_left * last_site, last_right);
        for l in 0..last_left {
            for s in 0..last_site {
                for r in 0..last_right {
                    last_mat[[l * last_site + s, r]] = *last.get3(l, s, r);
                }
            }
        }

        // Multiply: last_mat * Tprod → (last_left * last_site, tprod_cols)
        let contracted = mat_mul(&last_mat, &tprod);

        // Reshape back to tensor (last_left, last_site, tprod_cols)
        let mut new_last = tensor3_zeros(last_left, last_site, tprod_cols);
        for l in 0..last_left {
            for s in 0..last_site {
                for r in 0..tprod_cols {
                    new_last.set3(l, s, r, contracted[[l * last_site + s, r]]);
                }
            }
        }
        *result_tensors.last_mut().unwrap() = new_last;

        // Re-validate bond dimensions of the assembled train.
        TensorTrain::new(result_tensors)
    }
}
532
impl<T: TTScalar> AbstractTensorTrain<T> for TensorTrain<T> {
    /// Number of sites (core tensors) in the train.
    fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Borrow the core tensor at site `i`.
    ///
    /// Panics if `i >= self.len()` (slice indexing).
    fn site_tensor(&self, i: usize) -> &Tensor3<T> {
        &self.tensors[i]
    }

    /// Borrow all core tensors, in site order.
    fn site_tensors(&self) -> &[Tensor3<T>] {
        &self.tensors
    }
}
546
// Unit tests live in the adjacent `tests` submodule; compiled only for `cargo test`.
#[cfg(test)]
mod tests;