tensor4all_simplett/tensortrain.rs
1//! TensorTrain implementation
2
3use crate::error::{Result, TensorTrainError};
4use crate::traits::{AbstractTensorTrain, TTScalar};
5use crate::types::{tensor3_zeros, Tensor3, Tensor3Ops};
6
7/// Tensor Train (Matrix Product State) representation.
8///
9/// A tensor train decomposes a high-dimensional tensor `T[i0, i1, ..., i_{L-1}]`
10/// into a chain of rank-3 core tensors:
11///
12/// ```text
13/// T[i0, i1, ..., i_{L-1}] = A0[i0] * A1[i1] * ... * A_{L-1}[i_{L-1}]
14/// ```
15///
16/// where each core `Ak` has shape `(r_{k-1}, d_k, r_k)` with:
17/// - `r_k` = bond dimension (link between site `k` and `k+1`),
18/// - `d_k` = physical (site) dimension at site `k`,
19/// - `r_{-1} = r_{L-1} = 1` (boundary condition).
20///
21/// # Construction
22///
23/// - [`TensorTrain::constant`] -- all entries equal to a given value
24/// - [`TensorTrain::zeros`] -- all entries zero
25/// - [`TensorTrain::new`] -- from explicit rank-3 core tensors
26///
27/// # Related types
28///
29/// - [`CompressionOptions`](crate::CompressionOptions) -- configure compression
30/// - [`TTCache`](crate::TTCache) -- cached evaluation
31/// - [`SiteTensorTrain`](crate::SiteTensorTrain) -- center-canonical form
32/// - [`VidalTensorTrain`](crate::VidalTensorTrain) -- Vidal canonical form
33///
34/// # Examples
35///
36/// ```
37/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
38///
39/// // Create a constant tensor train: T[i,j,k] = 3.0 for all i,j,k
40/// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 3.0);
41///
42/// assert_eq!(tt.len(), 3);
43/// assert_eq!(tt.site_dims(), vec![2, 3, 4]);
44/// assert_eq!(tt.link_dims(), vec![1, 1]); // bond dim = 1 for constant
45///
46/// // Evaluate at a specific index
47/// let val = tt.evaluate(&[0, 1, 2]).unwrap();
48/// assert!((val - 3.0).abs() < 1e-12);
49///
50/// // Sum over all indices: 3.0 * 2 * 3 * 4 = 72.0
51/// let s = tt.sum();
52/// assert!((s - 72.0).abs() < 1e-10);
53/// ```
54#[derive(Debug, Clone)]
55pub struct TensorTrain<T: TTScalar> {
56 /// The tensors that make up the tensor train
57 /// Each tensor has shape (left_bond, site_dim, right_bond)
58 tensors: Vec<Tensor3<T>>,
59}
60
61impl<T: TTScalar> TensorTrain<T> {
62 /// Create a new tensor train from a list of rank-3 core tensors.
63 ///
64 /// Each tensor must have shape `(left_bond, site_dim, right_bond)` where
65 /// the `right_bond` of tensor `i` equals the `left_bond` of tensor `i+1`.
66 /// The first tensor must have `left_bond = 1` and the last must have
67 /// `right_bond = 1`.
68 ///
69 /// # Errors
70 ///
71 /// Returns [`TensorTrainError::DimensionMismatch`] if adjacent bond
72 /// dimensions do not match, or [`TensorTrainError::InvalidOperation`] if
73 /// boundary dimensions are not 1.
74 ///
75 /// # Examples
76 ///
77 /// ```
78 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, Tensor3Ops, tensor3_zeros};
79 ///
80 /// // Build a 2-site TT with bond dimension 1 and site dimensions [2, 3]
81 /// let mut t0 = tensor3_zeros::<f64>(1, 2, 1);
82 /// t0.set3(0, 0, 0, 1.0);
83 /// t0.set3(0, 1, 0, 2.0);
84 ///
85 /// let mut t1 = tensor3_zeros::<f64>(1, 3, 1);
86 /// t1.set3(0, 0, 0, 10.0);
87 /// t1.set3(0, 1, 0, 20.0);
88 /// t1.set3(0, 2, 0, 30.0);
89 ///
90 /// let tt = TensorTrain::new(vec![t0, t1]).unwrap();
91 /// assert_eq!(tt.len(), 2);
92 ///
93 /// // T[0, 2] = 1.0 * 30.0 = 30.0
94 /// let val = tt.evaluate(&[0, 2]).unwrap();
95 /// assert!((val - 30.0).abs() < 1e-12);
96 /// ```
97 pub fn new(tensors: Vec<Tensor3<T>>) -> Result<Self> {
98 // Validate dimensions
99 for i in 0..tensors.len().saturating_sub(1) {
100 if tensors[i].right_dim() != tensors[i + 1].left_dim() {
101 return Err(TensorTrainError::DimensionMismatch { site: i });
102 }
103 }
104
105 // First tensor should have left_dim = 1
106 if !tensors.is_empty() && tensors[0].left_dim() != 1 {
107 return Err(TensorTrainError::InvalidOperation {
108 message: "First tensor must have left dimension 1".to_string(),
109 });
110 }
111
112 // Last tensor should have right_dim = 1
113 if !tensors.is_empty() && tensors.last().unwrap().right_dim() != 1 {
114 return Err(TensorTrainError::InvalidOperation {
115 message: "Last tensor must have right dimension 1".to_string(),
116 });
117 }
118
119 Ok(Self { tensors })
120 }
121
122 /// Create a tensor train from tensors without dimension validation
123 /// (for internal use when dimensions are known to be correct)
124 pub(crate) fn from_tensors_unchecked(tensors: Vec<Tensor3<T>>) -> Self {
125 Self { tensors }
126 }
127
128 /// Create a tensor train where every entry is zero.
129 ///
130 /// The resulting TT has bond dimension 1 at every link.
131 ///
132 /// # Examples
133 ///
134 /// ```
135 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
136 ///
137 /// let tt = TensorTrain::<f64>::zeros(&[2, 3]);
138 /// assert!((tt.evaluate(&[1, 2]).unwrap()).abs() < 1e-14);
139 /// assert!((tt.sum()).abs() < 1e-14);
140 /// ```
141 pub fn zeros(site_dims: &[usize]) -> Self {
142 let tensors: Vec<Tensor3<T>> = site_dims.iter().map(|&d| tensor3_zeros(1, d, 1)).collect();
143 Self { tensors }
144 }
145
146 /// Create a tensor train where every entry equals `value`.
147 ///
148 /// The resulting TT has bond dimension 1 at every link.
149 ///
150 /// # Examples
151 ///
152 /// ```
153 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
154 ///
155 /// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 5.0);
156 ///
157 /// // Every entry is 5.0
158 /// assert!((tt.evaluate(&[0, 0, 0]).unwrap() - 5.0).abs() < 1e-12);
159 /// assert!((tt.evaluate(&[1, 2, 3]).unwrap() - 5.0).abs() < 1e-12);
160 ///
161 /// // Sum = 5.0 * 2 * 3 * 4 = 120.0
162 /// assert!((tt.sum() - 120.0).abs() < 1e-10);
163 /// ```
164 pub fn constant(site_dims: &[usize], value: T) -> Self {
165 if site_dims.is_empty() {
166 return Self {
167 tensors: Vec::new(),
168 };
169 }
170
171 let n = site_dims.len();
172 let mut tensors = Vec::with_capacity(n);
173
174 // First tensor: all ones
175 let mut first = tensor3_zeros(1, site_dims[0], 1);
176 for s in 0..site_dims[0] {
177 first.set3(0, s, 0, T::one());
178 }
179 tensors.push(first);
180
181 // Middle tensors: all ones (only if n > 2)
182 if n > 2 {
183 for &d in &site_dims[1..n - 1] {
184 let mut tensor = tensor3_zeros(1, d, 1);
185 for s in 0..d {
186 tensor.set3(0, s, 0, T::one());
187 }
188 tensors.push(tensor);
189 }
190 }
191
192 // Last tensor: multiply by value
193 if n > 1 {
194 let mut last = tensor3_zeros(1, site_dims[n - 1], 1);
195 for s in 0..site_dims[n - 1] {
196 last.set3(0, s, 0, value);
197 }
198 tensors.push(last);
199 } else {
200 // Single site: multiply the first (and only) tensor by value
201 for s in 0..site_dims[0] {
202 let current = *tensors[0].get3(0, s, 0);
203 tensors[0].set3(0, s, 0, current * value);
204 }
205 }
206
207 Self { tensors }
208 }
209
210 /// Get mutable access to the site tensors
211 pub fn site_tensors_mut(&mut self) -> &mut [Tensor3<T>] {
212 &mut self.tensors
213 }
214
215 /// Multiply every entry of the tensor train by `factor` in place.
216 ///
217 /// Only the last core tensor is rescaled, so this is an O(d * r^2) operation
218 /// where `d` is the site dimension and `r` the bond dimension of the last site.
219 ///
220 /// # Examples
221 ///
222 /// ```
223 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
224 ///
225 /// let mut tt = TensorTrain::<f64>::constant(&[2, 3], 1.0);
226 /// tt.scale(3.0);
227 /// assert!((tt.evaluate(&[0, 0]).unwrap() - 3.0).abs() < 1e-12);
228 /// assert!((tt.sum() - 18.0).abs() < 1e-10);
229 /// ```
230 pub fn scale(&mut self, factor: T) {
231 if !self.tensors.is_empty() {
232 let last = self.tensors.len() - 1;
233 let tensor = &mut self.tensors[last];
234 for l in 0..tensor.left_dim() {
235 for s in 0..tensor.site_dim() {
236 for r in 0..tensor.right_dim() {
237 let val = *tensor.get3(l, s, r);
238 tensor.set3(l, s, r, val * factor);
239 }
240 }
241 }
242 }
243 }
244
245 /// Return a new tensor train with every entry multiplied by `factor`.
246 ///
247 /// This is the non-mutating version of [`scale`](Self::scale).
248 ///
249 /// # Examples
250 ///
251 /// ```
252 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
253 ///
254 /// let tt = TensorTrain::<f64>::constant(&[2, 3], 1.0);
255 /// let tt2 = tt.scaled(4.0);
256 /// // Original is unchanged
257 /// assert!((tt.evaluate(&[0, 0]).unwrap() - 1.0).abs() < 1e-12);
258 /// // Scaled copy
259 /// assert!((tt2.evaluate(&[0, 0]).unwrap() - 4.0).abs() < 1e-12);
260 /// ```
261 pub fn scaled(&self, factor: T) -> Self {
262 let mut result = self.clone();
263 result.scale(factor);
264 result
265 }
266
267 /// Reverse the order of sites in the tensor train.
268 ///
269 /// The reversed TT satisfies `reversed.evaluate(&[i_{L-1}, ..., i_0]) ==
270 /// original.evaluate(&[i_0, ..., i_{L-1}])`.
271 ///
272 /// # Examples
273 ///
274 /// ```
275 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, Tensor3Ops, tensor3_zeros};
276 ///
277 /// let mut t0 = tensor3_zeros::<f64>(1, 2, 1);
278 /// t0.set3(0, 0, 0, 1.0);
279 /// t0.set3(0, 1, 0, 2.0);
280 /// let mut t1 = tensor3_zeros::<f64>(1, 3, 1);
281 /// t1.set3(0, 0, 0, 10.0);
282 /// t1.set3(0, 1, 0, 20.0);
283 /// t1.set3(0, 2, 0, 30.0);
284 /// let tt = TensorTrain::new(vec![t0, t1]).unwrap();
285 ///
286 /// let rev = tt.reverse();
287 /// assert_eq!(rev.site_dims(), vec![3, 2]);
288 /// // T[0, 1] = 1.0 * 10.0 = 10.0, reversed: T_rev[0, 1] should also be 10.0 (site 0->10, site 1->2)
289 /// // Original: T[1, 0] = 2.0 * 10.0 = 20.0
290 /// // Reversed: T_rev[0, 1] = 20.0
291 /// assert!((rev.evaluate(&[0, 1]).unwrap() - tt.evaluate(&[1, 0]).unwrap()).abs() < 1e-12);
292 /// ```
293 pub fn reverse(&self) -> Self {
294 let mut new_tensors = Vec::with_capacity(self.tensors.len());
295 for tensor in self.tensors.iter().rev() {
296 // Swap left and right dimensions
297 let mut new_tensor =
298 tensor3_zeros(tensor.right_dim(), tensor.site_dim(), tensor.left_dim());
299 for l in 0..tensor.left_dim() {
300 for s in 0..tensor.site_dim() {
301 for r in 0..tensor.right_dim() {
302 new_tensor.set3(r, s, l, *tensor.get3(l, s, r));
303 }
304 }
305 }
306 new_tensors.push(new_tensor);
307 }
308 Self {
309 tensors: new_tensors,
310 }
311 }
312}
313
314impl<T: TTScalar> TensorTrain<T> {
315 /// Materialize the tensor train as a full dense tensor.
316 ///
317 /// Returns `(data, shape)` where `data` is a flat vector in **column-major**
318 /// order and `shape` is the site dimensions. The total number of elements
319 /// is `prod(shape)`.
320 ///
321 /// **Warning:** The full tensor can be extremely large for high-dimensional
322 /// problems. Only use this for small tensors or debugging.
323 ///
324 /// # Examples
325 ///
326 /// ```
327 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
328 ///
329 /// let tt = TensorTrain::<f64>::constant(&[2, 3], 7.0);
330 /// let (data, shape) = tt.fulltensor();
331 /// assert_eq!(shape, vec![2, 3]);
332 /// assert_eq!(data.len(), 6);
333 /// // Every element should be 7.0
334 /// assert!(data.iter().all(|&v| (v - 7.0).abs() < 1e-12));
335 /// ```
336 pub fn fulltensor(&self) -> (Vec<T>, Vec<usize>) {
337 if self.is_empty() {
338 return (Vec::new(), Vec::new());
339 }
340
341 let site_dims: Vec<usize> = self.site_dims();
342 let total_size: usize = site_dims.iter().product();
343
344 if total_size == 0 {
345 return (Vec::new(), site_dims);
346 }
347
348 // Build full tensor by iterating over all indices
349 let mut result = Vec::with_capacity(total_size);
350 let mut indices = vec![0usize; site_dims.len()];
351
352 loop {
353 // Evaluate at current indices
354 if let Ok(val) = self.evaluate(&indices) {
355 result.push(val);
356 } else {
357 result.push(T::zero());
358 }
359
360 // Increment indices in column-major order, leftmost index fastest.
361 let mut carry = true;
362 for i in 0..site_dims.len() {
363 if carry {
364 indices[i] += 1;
365 if indices[i] >= site_dims[i] {
366 indices[i] = 0;
367 } else {
368 carry = false;
369 }
370 }
371 }
372
373 if carry {
374 break; // All indices wrapped around
375 }
376 }
377
378 (result, site_dims)
379 }
380}
381
382impl<T: TTScalar> TensorTrain<T> {
383 /// Sum (trace out) selected site dimensions, returning a lower-order TT.
384 ///
385 /// `dims` is a slice of 0-indexed site positions to sum over. The
386 /// remaining sites keep their original order. If *all* dimensions are
387 /// summed, the result is a 1-site TT wrapping the scalar total.
388 ///
389 /// # Errors
390 ///
391 /// Returns an error if any element of `dims` is out of range.
392 ///
393 /// # Examples
394 ///
395 /// ```
396 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
397 ///
398 /// // 3-site constant TT: T[i,j,k] = 1.0, dims = [2, 3, 4]
399 /// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
400 ///
401 /// // Sum over the middle site (index 1): result has dims [2, 4]
402 /// let summed = tt.partial_sum(&[1]).unwrap();
403 /// assert_eq!(summed.site_dims(), vec![2, 4]);
404 ///
405 /// // Each remaining entry = 1.0 * 3 (summed over dim=3)
406 /// let val = summed.evaluate(&[0, 0]).unwrap();
407 /// assert!((val - 3.0).abs() < 1e-12);
408 /// ```
409 pub fn partial_sum(&self, dims: &[usize]) -> Result<TensorTrain<T>> {
410 use tensor4all_tcicore::matrix::{mat_mul, ncols, nrows, zeros as mat_zeros};
411
412 let n = self.len();
413 if n == 0 {
414 return Ok(Self {
415 tensors: Vec::new(),
416 });
417 }
418
419 // Validate dims
420 for &d in dims {
421 if d >= n {
422 return Err(TensorTrainError::InvalidOperation {
423 message: format!("Dimension {} out of range (0..{})", d, n),
424 });
425 }
426 }
427
428 let mut result_tensors: Vec<Tensor3<T>> = Vec::new();
429
430 // Tprod: accumulator matrix, starts as 1x1 identity
431 let mut tprod = mat_zeros(1, 1);
432 tprod[[0, 0]] = T::one();
433
434 for site in 0..n {
435 let t = self.site_tensor(site);
436 let left_dim = t.left_dim();
437 let site_dim = t.site_dim();
438 let right_dim = t.right_dim();
439
440 if dims.contains(&site) {
441 // Sum over site index: result is (left_dim, right_dim) matrix
442 // sum(T, dims=2)[:, 1, :] in Julia
443 let mut site_sum = mat_zeros(left_dim, right_dim);
444 for l in 0..left_dim {
445 for r in 0..right_dim {
446 let mut acc = T::zero();
447 for s in 0..site_dim {
448 acc = acc + *t.get3(l, s, r);
449 }
450 site_sum[[l, r]] = acc;
451 }
452 }
453 // Tprod = Tprod * site_sum
454 tprod = mat_mul(&tprod, &site_sum);
455 } else {
456 // Keep this dimension: multiply Tprod into the site tensor
457 // Tprod (tprod_rows, left_dim) * T reshaped to (left_dim, site_dim * right_dim)
458 let tprod_rows = nrows(&tprod);
459 let mut t_reshaped = mat_zeros(left_dim, site_dim * right_dim);
460 for l in 0..left_dim {
461 for s in 0..site_dim {
462 for r in 0..right_dim {
463 t_reshaped[[l, s * right_dim + r]] = *t.get3(l, s, r);
464 }
465 }
466 }
467 let product = mat_mul(&tprod, &t_reshaped);
468
469 // Reshape product (tprod_rows, site_dim * right_dim)
470 // into tensor (tprod_rows, site_dim, right_dim)
471 let mut new_tensor = tensor3_zeros(tprod_rows, site_dim, right_dim);
472 for l in 0..tprod_rows {
473 for s in 0..site_dim {
474 for r in 0..right_dim {
475 new_tensor.set3(l, s, r, product[[l, s * right_dim + r]]);
476 }
477 }
478 }
479 result_tensors.push(new_tensor);
480
481 // Reset Tprod to identity of size right_dim
482 tprod = mat_zeros(right_dim, right_dim);
483 for i in 0..right_dim {
484 tprod[[i, i]] = T::one();
485 }
486 }
487 }
488
489 if result_tensors.is_empty() {
490 // All dims summed → return 1-site TT wrapping scalar
491 // tprod should be 1×1
492 let scalar = tprod[[0, 0]];
493 let mut t = tensor3_zeros(1, 1, 1);
494 t.set3(0, 0, 0, scalar);
495 return TensorTrain::new(vec![t]);
496 }
497
498 // Contract final Tprod into last result tensor
499 let last = result_tensors.last().unwrap();
500 let last_left = last.left_dim();
501 let last_site = last.site_dim();
502 let last_right = last.right_dim();
503 let tprod_cols = ncols(&tprod);
504
505 // Reshape last tensor to (last_left * last_site, last_right)
506 let mut last_mat = mat_zeros(last_left * last_site, last_right);
507 for l in 0..last_left {
508 for s in 0..last_site {
509 for r in 0..last_right {
510 last_mat[[l * last_site + s, r]] = *last.get3(l, s, r);
511 }
512 }
513 }
514
515 // Multiply: last_mat * Tprod → (last_left * last_site, tprod_cols)
516 let contracted = mat_mul(&last_mat, &tprod);
517
518 // Reshape back to tensor (last_left, last_site, tprod_cols)
519 let mut new_last = tensor3_zeros(last_left, last_site, tprod_cols);
520 for l in 0..last_left {
521 for s in 0..last_site {
522 for r in 0..tprod_cols {
523 new_last.set3(l, s, r, contracted[[l * last_site + s, r]]);
524 }
525 }
526 }
527 *result_tensors.last_mut().unwrap() = new_last;
528
529 TensorTrain::new(result_tensors)
530 }
531}
532
533impl<T: TTScalar> AbstractTensorTrain<T> for TensorTrain<T> {
534 fn len(&self) -> usize {
535 self.tensors.len()
536 }
537
538 fn site_tensor(&self, i: usize) -> &Tensor3<T> {
539 &self.tensors[i]
540 }
541
542 fn site_tensors(&self) -> &[Tensor3<T>] {
543 &self.tensors
544 }
545}
546
// Unit tests live in a separate `tests` module file, compiled only for test builds.
#[cfg(test)]
mod tests;