Skip to main content

tensor4all_simplett/
traits.rs

1//! Abstract traits for tensor train objects
2
3use crate::error::Result;
4use crate::types::{LocalIndex, Tensor3, Tensor3Ops};
5use tenferro_algebra::Scalar as TfScalar;
6use tenferro_tensor::TensorScalar;
7
8/// Scalar trait bound shared by all simplett tensor types.
9///
10/// Combines [`tensor4all_core::CommonScalar`] (arithmetic, conversion) with
11/// [`tenferro_algebra::Scalar`] (backend compatibility). Both `f64` and
12/// `Complex64` implement this trait.
13///
14/// # Examples
15///
16/// ```
17/// use tensor4all_simplett::TTScalar;
18///
19/// // f64 satisfies TTScalar
20/// fn uses_ttscalar<T: TTScalar>(x: T, y: T) -> T { x + y }
21///
22/// let result = uses_ttscalar(1.0_f64, 2.0_f64);
23/// assert!((result - 3.0).abs() < 1e-15);
24///
25/// // Complex64 also satisfies TTScalar
26/// use num_complex::Complex64;
27/// let c = uses_ttscalar(Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0));
28/// assert!((c.re - 1.0).abs() < 1e-15);
29/// assert!((c.im - 1.0).abs() < 1e-15);
30/// ```
31pub trait TTScalar: tensor4all_core::CommonScalar + TfScalar + TensorScalar {}
32
33impl<T> TTScalar for T where T: tensor4all_core::CommonScalar + TfScalar + TensorScalar {}
34
35/// Common interface implemented by all tensor train representations.
36///
37/// Provides read-only access to site tensors plus derived operations:
38/// [`evaluate`](Self::evaluate), [`sum`](Self::sum),
39/// [`norm`](Self::norm), and [`log_norm`](Self::log_norm).
40///
41/// # Implementors
42///
43/// - [`TensorTrain`](crate::TensorTrain) -- primary container
44/// - [`SiteTensorTrain`](crate::SiteTensorTrain) -- center-canonical form
45/// - [`VidalTensorTrain`](crate::VidalTensorTrain) -- Vidal form (after
46///   conversion to `TensorTrain`)
47///
48/// # Examples
49///
50/// ```
51/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
52///
53/// // TensorTrain implements AbstractTensorTrain.
54/// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
55///
56/// // Query structure
57/// assert_eq!(tt.len(), 3);
58/// assert!(!tt.is_empty());
59/// assert_eq!(tt.site_dims(), vec![2, 3, 4]);
60/// assert_eq!(tt.site_dim(1), 3);
61/// assert_eq!(tt.link_dims(), vec![1, 1]);
62///
63/// // Evaluate, sum, and norm
64/// let val = tt.evaluate(&[0, 0, 0]).unwrap();
65/// assert!((val - 1.0).abs() < 1e-12);
66///
67/// let s = tt.sum();
68/// assert!((s - 24.0).abs() < 1e-10);
69///
70/// let n = tt.norm();
71/// assert!((n - 24.0_f64.sqrt()).abs() < 1e-10);
72/// ```
73pub trait AbstractTensorTrain<T: TTScalar>: Sized {
74    /// Number of sites (core tensors) in the tensor train.
75    fn len(&self) -> usize;
76
77    /// Returns `true` if the tensor train has zero sites.
78    fn is_empty(&self) -> bool {
79        self.len() == 0
80    }
81
82    /// Borrow the rank-3 core tensor at site `i`.
83    fn site_tensor(&self, i: usize) -> &Tensor3<T>;
84
85    /// Borrow all core tensors as a slice.
86    fn site_tensors(&self) -> &[Tensor3<T>];
87
88    /// Bond dimensions at every link (length = `len() - 1`).
89    fn link_dims(&self) -> Vec<usize> {
90        if self.len() <= 1 {
91            return Vec::new();
92        }
93        (1..self.len())
94            .map(|i| self.site_tensor(i).left_dim())
95            .collect()
96    }
97
98    /// Bond dimension at the link between site `i` and site `i+1`.
99    fn link_dim(&self, i: usize) -> usize {
100        self.site_tensor(i + 1).left_dim()
101    }
102
103    /// Physical (site) dimensions for every site.
104    fn site_dims(&self) -> Vec<usize> {
105        (0..self.len())
106            .map(|i| self.site_tensor(i).site_dim())
107            .collect()
108    }
109
110    /// Physical (site) dimension at site `i`.
111    fn site_dim(&self, i: usize) -> usize {
112        self.site_tensor(i).site_dim()
113    }
114
115    /// Maximum bond dimension across all links.
116    fn rank(&self) -> usize {
117        let lds = self.link_dims();
118        if lds.is_empty() {
119            1
120        } else {
121            *lds.iter().max().unwrap_or(&1)
122        }
123    }
124
125    /// Evaluate the tensor train at a given index set
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
131    ///
132    /// // Constant TT: all values are 5.0
133    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 5.0);
134    ///
135    /// let val = tt.evaluate(&[1, 2]).unwrap();
136    /// assert!((val - 5.0).abs() < 1e-12);
137    ///
138    /// // Wrong number of indices returns an error
139    /// assert!(tt.evaluate(&[0]).is_err());
140    /// ```
141    fn evaluate(&self, indices: &[LocalIndex]) -> Result<T> {
142        use crate::error::TensorTrainError;
143
144        if indices.len() != self.len() {
145            return Err(TensorTrainError::IndexLengthMismatch {
146                expected: self.len(),
147                got: indices.len(),
148            });
149        }
150
151        if self.is_empty() {
152            return Err(TensorTrainError::Empty);
153        }
154
155        // Start with the first tensor slice
156        let first = self.site_tensor(0);
157        let idx0 = indices[0];
158        if idx0 >= first.site_dim() {
159            return Err(TensorTrainError::IndexOutOfBounds {
160                site: 0,
161                index: idx0,
162                max: first.site_dim(),
163            });
164        }
165
166        // Vector of size right_dim
167        let mut current: Vec<T> = first.slice_site(idx0);
168
169        // Contract with remaining tensors
170        for (site, &idx) in indices.iter().enumerate().skip(1) {
171            let tensor = self.site_tensor(site);
172            if idx >= tensor.site_dim() {
173                return Err(TensorTrainError::IndexOutOfBounds {
174                    site,
175                    index: idx,
176                    max: tensor.site_dim(),
177                });
178            }
179
180            let slice = tensor.slice_site(idx);
181            let left_dim = tensor.left_dim();
182            let right_dim = tensor.right_dim();
183
184            // Contract: current (of size left_dim) with slice (left_dim x right_dim)
185            let mut next = vec![T::zero(); right_dim];
186            for r in 0..right_dim {
187                let mut sum = T::zero();
188                for l in 0..left_dim {
189                    sum = sum + current[l] * slice[l * right_dim + r];
190                }
191                next[r] = sum;
192            }
193            current = next;
194        }
195
196        // Should have a single element
197        if current.len() != 1 {
198            return Err(TensorTrainError::InvalidOperation {
199                message: format!(
200                    "Final contraction resulted in {} elements, expected 1",
201                    current.len()
202                ),
203            });
204        }
205
206        Ok(current[0])
207    }
208
209    /// Sum over all indices of the tensor train
210    ///
211    /// # Examples
212    ///
213    /// ```
214    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
215    ///
216    /// // Constant TT with value 2.0 over 3×4 grid: sum = 2.0 * 3 * 4 = 24.0
217    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
218    /// let s = tt.sum();
219    /// assert!((s - 24.0).abs() < 1e-10);
220    ///
221    /// // Zero TT sums to 0.0
222    /// let zero_tt = TensorTrain::<f64>::zeros(&[2, 3]);
223    /// assert!((zero_tt.sum() - 0.0).abs() < 1e-12);
224    /// ```
225    #[allow(clippy::needless_range_loop)]
226    fn sum(&self) -> T {
227        if self.is_empty() {
228            return T::zero();
229        }
230
231        // Start with sum over first tensor
232        let first = self.site_tensor(0);
233        let mut current = vec![T::zero(); first.right_dim()];
234        for s in 0..first.site_dim() {
235            for r in 0..first.right_dim() {
236                current[r] = current[r] + *first.get3(0, s, r);
237            }
238        }
239
240        // Contract with sums of remaining tensors
241        for site in 1..self.len() {
242            let tensor = self.site_tensor(site);
243            let left_dim = tensor.left_dim();
244            let right_dim = tensor.right_dim();
245
246            // Sum over site index
247            let mut site_sum = vec![T::zero(); left_dim * right_dim];
248            for l in 0..left_dim {
249                for s in 0..tensor.site_dim() {
250                    for r in 0..right_dim {
251                        site_sum[l * right_dim + r] =
252                            site_sum[l * right_dim + r] + *tensor.get3(l, s, r);
253                    }
254                }
255            }
256
257            // Contract with current
258            let mut next = vec![T::zero(); right_dim];
259            for r in 0..right_dim {
260                let mut sum = T::zero();
261                for l in 0..left_dim {
262                    sum = sum + current[l] * site_sum[l * right_dim + r];
263                }
264                next[r] = sum;
265            }
266            current = next;
267        }
268
269        current[0]
270    }
271
272    /// Squared Frobenius norm: `sum_i |T[i]|^2`.
273    ///
274    /// # Examples
275    ///
276    /// ```
277    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
278    ///
279    /// // Constant TT: T[i,j] = 2.0 on a 3x4 grid
280    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
281    /// // norm^2 = 2^2 * 3 * 4 = 48
282    /// assert!((tt.norm2() - 48.0).abs() < 1e-10);
283    /// ```
284    fn norm2(&self) -> f64 {
285        if self.is_empty() {
286            return 0.0;
287        }
288
289        // Contract tt with its conjugate
290        // result[la, la_conj, ra, ra_conj] at each step
291        let first = self.site_tensor(0);
292        let right_dim = first.right_dim();
293
294        // current[ra, ra_conj] = sum_s first[0, s, ra] * conj(first[0, s, ra_conj])
295        let mut current = vec![T::zero(); right_dim * right_dim];
296        for s in 0..first.site_dim() {
297            for ra in 0..right_dim {
298                for ra_c in 0..right_dim {
299                    let idx = ra * right_dim + ra_c;
300                    current[idx] =
301                        current[idx] + *first.get3(0, s, ra) * first.get3(0, s, ra_c).conj();
302                }
303            }
304        }
305
306        // Contract through remaining tensors
307        for site in 1..self.len() {
308            let tensor = self.site_tensor(site);
309            let left_dim = tensor.left_dim();
310            let right_dim = tensor.right_dim();
311
312            // new_current[ra, ra_conj] = sum_{la, la_conj, s}
313            //     current[la, la_conj] * tensor[la, s, ra] * conj(tensor[la_conj, s, ra_conj])
314            let mut new_current = vec![T::zero(); right_dim * right_dim];
315
316            for la in 0..left_dim {
317                for la_c in 0..left_dim {
318                    let c_val = current[la * left_dim + la_c];
319                    for s in 0..tensor.site_dim() {
320                        for ra in 0..right_dim {
321                            for ra_c in 0..right_dim {
322                                let idx = ra * right_dim + ra_c;
323                                new_current[idx] = new_current[idx]
324                                    + c_val
325                                        * *tensor.get3(la, s, ra)
326                                        * tensor.get3(la_c, s, ra_c).conj();
327                            }
328                        }
329                    }
330                }
331            }
332
333            current = new_current;
334        }
335
336        // Final result should be a single element (1x1)
337        current[0].abs_sq().sqrt()
338    }
339
340    /// Frobenius norm: `sqrt(sum_i |T[i]|^2)`.
341    ///
342    /// # Examples
343    ///
344    /// ```
345    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
346    ///
347    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
348    /// // norm = sqrt(48) ~ 6.928
349    /// assert!((tt.norm() - 48.0_f64.sqrt()).abs() < 1e-10);
350    /// ```
351    fn norm(&self) -> f64 {
352        self.norm2().sqrt()
353    }
354
355    /// Logarithm of the Frobenius norm: `ln(norm())`.
356    ///
357    /// This is more numerically stable than `norm().ln()` for tensor trains
358    /// with very large or very small norms, because it normalizes at each
359    /// contraction step to avoid overflow/underflow.
360    ///
361    /// Returns `f64::NEG_INFINITY` for zero tensor trains.
362    ///
363    /// # Examples
364    ///
365    /// ```
366    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
367    ///
368    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
369    /// let log_n = tt.log_norm();
370    /// assert!((log_n - tt.norm().ln()).abs() < 1e-10);
371    ///
372    /// // Zero TT returns negative infinity
373    /// let zero_tt = TensorTrain::<f64>::zeros(&[2, 3]);
374    /// assert_eq!(zero_tt.log_norm(), f64::NEG_INFINITY);
375    /// ```
376    fn log_norm(&self) -> f64 {
377        if self.is_empty() {
378            return f64::NEG_INFINITY;
379        }
380
381        // Contract tt with its conjugate, normalizing at each step
382        // to avoid overflow/underflow
383        let first = self.site_tensor(0);
384        let right_dim = first.right_dim();
385
386        // current[ra, ra_conj] = sum_s first[0, s, ra] * conj(first[0, s, ra_conj])
387        let mut current = vec![T::zero(); right_dim * right_dim];
388        for s in 0..first.site_dim() {
389            for ra in 0..right_dim {
390                for ra_c in 0..right_dim {
391                    let idx = ra * right_dim + ra_c;
392                    current[idx] =
393                        current[idx] + *first.get3(0, s, ra) * first.get3(0, s, ra_c).conj();
394                }
395            }
396        }
397
398        // Normalize and accumulate log scale
399        let mut log_scale = 0.0;
400        let scale = current
401            .iter()
402            .map(|x| x.abs_sq())
403            .fold(0.0, f64::max)
404            .sqrt();
405        if scale > 0.0 {
406            log_scale += scale.ln();
407            let inv_scale = T::one() / T::from_f64(scale);
408            for val in &mut current {
409                *val = *val * inv_scale;
410            }
411        } else if self.len() == 1 {
412            // Single site with zero norm
413            return f64::NEG_INFINITY;
414        }
415
416        // Contract through remaining tensors
417        for site in 1..self.len() {
418            let tensor = self.site_tensor(site);
419            let left_dim = tensor.left_dim();
420            let right_dim = tensor.right_dim();
421
422            // new_current[ra, ra_conj] = sum_{la, la_conj, s}
423            //     current[la, la_conj] * tensor[la, s, ra] * conj(tensor[la_conj, s, ra_conj])
424            let mut new_current = vec![T::zero(); right_dim * right_dim];
425
426            for la in 0..left_dim {
427                for la_c in 0..left_dim {
428                    let c_val = current[la * left_dim + la_c];
429                    for s in 0..tensor.site_dim() {
430                        for ra in 0..right_dim {
431                            for ra_c in 0..right_dim {
432                                let idx = ra * right_dim + ra_c;
433                                new_current[idx] = new_current[idx]
434                                    + c_val
435                                        * *tensor.get3(la, s, ra)
436                                        * tensor.get3(la_c, s, ra_c).conj();
437                            }
438                        }
439                    }
440                }
441            }
442
443            // Normalize and accumulate log scale
444            let scale = new_current
445                .iter()
446                .map(|x| x.abs_sq())
447                .fold(0.0, f64::max)
448                .sqrt();
449            if scale > 0.0 {
450                log_scale += scale.ln();
451                let inv_scale = T::one() / T::from_f64(scale);
452                for val in &mut new_current {
453                    *val = *val * inv_scale;
454                }
455            }
456
457            current = new_current;
458        }
459
460        // Final result: norm = sqrt(norm2), where norm2 = final_val * cumulative_scale
461        // log(norm) = 0.5 * log(norm2) = 0.5 * (log(final_val) + log_scale)
462        let final_val = current[0].abs_sq().sqrt();
463        if final_val > 0.0 {
464            0.5 * (final_val.ln() + log_scale)
465        } else {
466            f64::NEG_INFINITY
467        }
468    }
469}