tensor4all_simplett/traits.rs
//! Abstract traits for tensor train objects

use crate::error::Result;
use crate::types::{LocalIndex, Tensor3, Tensor3Ops};
use tenferro_algebra::Scalar as TfScalar;
use tenferro_tensor::TensorScalar;

/// Scalar trait bound shared by all simplett tensor types.
///
/// Combines [`tensor4all_core::CommonScalar`] (arithmetic, conversion) with
/// [`tenferro_algebra::Scalar`] (backend compatibility). Both `f64` and
/// `Complex64` implement this trait.
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::TTScalar;
///
/// // f64 satisfies TTScalar
/// fn uses_ttscalar<T: TTScalar>(x: T, y: T) -> T { x + y }
///
/// let result = uses_ttscalar(1.0_f64, 2.0_f64);
/// assert!((result - 3.0).abs() < 1e-15);
///
/// // Complex64 also satisfies TTScalar
/// use num_complex::Complex64;
/// let c = uses_ttscalar(Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0));
/// assert!((c.re - 1.0).abs() < 1e-15);
/// assert!((c.im - 1.0).abs() < 1e-15);
/// ```
pub trait TTScalar: tensor4all_core::CommonScalar + TfScalar + TensorScalar {}

impl<T> TTScalar for T where T: tensor4all_core::CommonScalar + TfScalar + TensorScalar {}

/// Common interface implemented by all tensor train representations.
///
/// Provides read-only access to site tensors plus derived operations:
/// [`evaluate`](Self::evaluate), [`sum`](Self::sum),
/// [`norm`](Self::norm), and [`log_norm`](Self::log_norm).
///
/// # Implementors
///
/// - [`TensorTrain`](crate::TensorTrain) -- primary container
/// - [`SiteTensorTrain`](crate::SiteTensorTrain) -- center-canonical form
/// - [`VidalTensorTrain`](crate::VidalTensorTrain) -- Vidal form (after
///   conversion to `TensorTrain`)
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
///
/// // TensorTrain implements AbstractTensorTrain.
/// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
///
/// // Query structure
/// assert_eq!(tt.len(), 3);
/// assert!(!tt.is_empty());
/// assert_eq!(tt.site_dims(), vec![2, 3, 4]);
/// assert_eq!(tt.site_dim(1), 3);
/// assert_eq!(tt.link_dims(), vec![1, 1]);
///
/// // Evaluate, sum, and norm
/// let val = tt.evaluate(&[0, 0, 0]).unwrap();
/// assert!((val - 1.0).abs() < 1e-12);
///
/// let s = tt.sum();
/// assert!((s - 24.0).abs() < 1e-10);
///
/// let n = tt.norm();
/// assert!((n - 24.0_f64.sqrt()).abs() < 1e-10);
/// ```
pub trait AbstractTensorTrain<T: TTScalar>: Sized {
    /// Number of sites (core tensors) in the tensor train.
    fn len(&self) -> usize;

    /// Returns `true` if the tensor train has zero sites.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrow the rank-3 core tensor at site `i`.
    fn site_tensor(&self, i: usize) -> &Tensor3<T>;

    /// Borrow all core tensors as a slice.
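    ///
    /// # Examples
    ///
    /// A minimal check mirroring the trait-level example above (one core
    /// tensor per site):
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// let tt = TensorTrain::<f64>::constant(&[2, 3], 1.0);
    /// assert_eq!(tt.site_tensors().len(), 2);
    /// ```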
    fn site_tensors(&self) -> &[Tensor3<T>];

    /// Bond dimensions at every link (length `len() - 1`; empty for trains
    /// with at most one site).
    fn link_dims(&self) -> Vec<usize> {
        if self.len() <= 1 {
            return Vec::new();
        }
        (1..self.len())
            .map(|i| self.site_tensor(i).left_dim())
            .collect()
    }

    /// Bond dimension at the link between site `i` and site `i+1`.
    ///
    /// `i` must satisfy `i < len() - 1`.
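    ///
    /// # Examples
    ///
    /// A constant train has bond dimension 1 at every link (see the
    /// trait-level example):
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
    /// assert_eq!(tt.link_dim(0), 1);
    /// assert_eq!(tt.link_dim(1), 1);
    /// ```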
    fn link_dim(&self, i: usize) -> usize {
        self.site_tensor(i + 1).left_dim()
    }

    /// Physical (site) dimensions for every site.
    fn site_dims(&self) -> Vec<usize> {
        (0..self.len())
            .map(|i| self.site_tensor(i).site_dim())
            .collect()
    }

    /// Physical (site) dimension at site `i`.
    fn site_dim(&self, i: usize) -> usize {
        self.site_tensor(i).site_dim()
    }

    /// Maximum bond dimension across all links.
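    ///
    /// Returns 1 for empty and single-site trains.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// // A constant (product) train has bond dimension 1 everywhere.
    /// let tt = TensorTrain::<f64>::constant(&[2, 3, 4], 1.0);
    /// assert_eq!(tt.rank(), 1);
    /// ```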
    fn rank(&self) -> usize {
        let lds = self.link_dims();
        if lds.is_empty() {
            1
        } else {
            *lds.iter().max().unwrap_or(&1)
        }
    }

    /// Evaluate the tensor train at the given multi-index.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// // Constant TT: all values are 5.0
    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 5.0);
    ///
    /// let val = tt.evaluate(&[1, 2]).unwrap();
    /// assert!((val - 5.0).abs() < 1e-12);
    ///
    /// // Wrong number of indices returns an error
    /// assert!(tt.evaluate(&[0]).is_err());
    /// ```
    fn evaluate(&self, indices: &[LocalIndex]) -> Result<T> {
        use crate::error::TensorTrainError;

        if indices.len() != self.len() {
            return Err(TensorTrainError::IndexLengthMismatch {
                expected: self.len(),
                got: indices.len(),
            });
        }

        if self.is_empty() {
            return Err(TensorTrainError::Empty);
        }

        // Start with the first tensor slice
        let first = self.site_tensor(0);
        let idx0 = indices[0];
        if idx0 >= first.site_dim() {
            return Err(TensorTrainError::IndexOutOfBounds {
                site: 0,
                index: idx0,
                max: first.site_dim(),
            });
        }

        // Vector of size right_dim (the first core has left_dim == 1)
        let mut current: Vec<T> = first.slice_site(idx0);

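        // The full evaluation is the chain product
        //   TT(i_0, ..., i_{L-1}) = A_0[i_0] * A_1[i_1] * ... * A_{L-1}[i_{L-1}],
        // where A_k[i_k] is the left_dim x right_dim matrix obtained by fixing
        // core k's site index; `current` carries the partial product so far.
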
        // Contract with remaining tensors
        for (site, &idx) in indices.iter().enumerate().skip(1) {
            let tensor = self.site_tensor(site);
            if idx >= tensor.site_dim() {
                return Err(TensorTrainError::IndexOutOfBounds {
                    site,
                    index: idx,
                    max: tensor.site_dim(),
                });
            }

            let slice = tensor.slice_site(idx);
            let left_dim = tensor.left_dim();
            let right_dim = tensor.right_dim();

            // Contract: current (of size left_dim) with slice (left_dim x right_dim)
            let mut next = vec![T::zero(); right_dim];
            for r in 0..right_dim {
                let mut sum = T::zero();
                for l in 0..left_dim {
                    sum = sum + current[l] * slice[l * right_dim + r];
                }
                next[r] = sum;
            }
            current = next;
        }

        // A well-formed train ends with right_dim == 1, leaving a single element
        if current.len() != 1 {
            return Err(TensorTrainError::InvalidOperation {
                message: format!(
                    "Final contraction resulted in {} elements, expected 1",
                    current.len()
                ),
            });
        }

        Ok(current[0])
    }

    /// Sum over all indices of the tensor train.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// // Constant TT with value 2.0 over 3×4 grid: sum = 2.0 * 3 * 4 = 24.0
    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
    /// let s = tt.sum();
    /// assert!((s - 24.0).abs() < 1e-10);
    ///
    /// // Zero TT sums to 0.0
    /// let zero_tt = TensorTrain::<f64>::zeros(&[2, 3]);
    /// assert!((zero_tt.sum() - 0.0).abs() < 1e-12);
    /// ```
    #[allow(clippy::needless_range_loop)]
    fn sum(&self) -> T {
        if self.is_empty() {
            return T::zero();
        }

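        // The grand sum factorizes over the chain:
        //   sum = prod_k (sum_s A_k[s]),
        // a left-to-right product of the site-summed core matrices.
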
        // Start with sum over first tensor
        let first = self.site_tensor(0);
        let mut current = vec![T::zero(); first.right_dim()];
        for s in 0..first.site_dim() {
            for r in 0..first.right_dim() {
                current[r] = current[r] + *first.get3(0, s, r);
            }
        }

        // Contract with sums of remaining tensors
        for site in 1..self.len() {
            let tensor = self.site_tensor(site);
            let left_dim = tensor.left_dim();
            let right_dim = tensor.right_dim();

            // Sum over site index
            let mut site_sum = vec![T::zero(); left_dim * right_dim];
            for l in 0..left_dim {
                for s in 0..tensor.site_dim() {
                    for r in 0..right_dim {
                        site_sum[l * right_dim + r] =
                            site_sum[l * right_dim + r] + *tensor.get3(l, s, r);
                    }
                }
            }

            // Contract with current
            let mut next = vec![T::zero(); right_dim];
            for r in 0..right_dim {
                let mut sum = T::zero();
                for l in 0..left_dim {
                    sum = sum + current[l] * site_sum[l * right_dim + r];
                }
                next[r] = sum;
            }
            current = next;
        }

        current[0]
    }

    /// Squared Frobenius norm: `sum_i |T[i]|^2`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// // Constant TT: T[i,j] = 2.0 on a 3x4 grid
    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
    /// // norm^2 = 2^2 * 3 * 4 = 48
    /// assert!((tt.norm2() - 48.0).abs() < 1e-10);
    /// ```
    fn norm2(&self) -> f64 {
        if self.is_empty() {
            return 0.0;
        }

        // Contract the train with its complex conjugate from left to right;
        // current[ra, ra_conj] holds the overlap contracted up to the current site.
        let first = self.site_tensor(0);
        let right_dim = first.right_dim();

        // current[ra, ra_conj] = sum_s first[0, s, ra] * conj(first[0, s, ra_conj])
        let mut current = vec![T::zero(); right_dim * right_dim];
        for s in 0..first.site_dim() {
            for ra in 0..right_dim {
                for ra_c in 0..right_dim {
                    let idx = ra * right_dim + ra_c;
                    current[idx] =
                        current[idx] + *first.get3(0, s, ra) * first.get3(0, s, ra_c).conj();
                }
            }
        }

        // Contract through remaining tensors
        for site in 1..self.len() {
            let tensor = self.site_tensor(site);
            let left_dim = tensor.left_dim();
            let right_dim = tensor.right_dim();

            // new_current[ra, ra_conj] = sum_{la, la_conj, s}
            //   current[la, la_conj] * tensor[la, s, ra] * conj(tensor[la_conj, s, ra_conj])
            let mut new_current = vec![T::zero(); right_dim * right_dim];

            for la in 0..left_dim {
                for la_c in 0..left_dim {
                    let c_val = current[la * left_dim + la_c];
                    for s in 0..tensor.site_dim() {
                        for ra in 0..right_dim {
                            for ra_c in 0..right_dim {
                                let idx = ra * right_dim + ra_c;
                                new_current[idx] = new_current[idx]
                                    + c_val
                                        * *tensor.get3(la, s, ra)
                                        * tensor.get3(la_c, s, ra_c).conj();
                            }
                        }
                    }
                }
            }

            current = new_current;
        }

        // The final `current` is 1x1 and its entry is the (real, non-negative)
        // squared norm; abs_sq().sqrt() extracts its magnitude as an f64.
        current[0].abs_sq().sqrt()
    }

    /// Frobenius norm: `sqrt(sum_i |T[i]|^2)`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
    /// // norm = sqrt(48) ~ 6.928
    /// assert!((tt.norm() - 48.0_f64.sqrt()).abs() < 1e-10);
    /// ```
    fn norm(&self) -> f64 {
        self.norm2().sqrt()
    }

    /// Logarithm of the Frobenius norm: `ln(norm())`.
    ///
    /// This is more numerically stable than `norm().ln()` for tensor trains
    /// with very large or very small norms, because it normalizes at each
    /// contraction step to avoid overflow/underflow.
    ///
    /// Returns `f64::NEG_INFINITY` for empty or identically zero tensor trains.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// let tt = TensorTrain::<f64>::constant(&[3, 4], 2.0);
    /// let log_n = tt.log_norm();
    /// assert!((log_n - tt.norm().ln()).abs() < 1e-10);
    ///
    /// // Zero TT returns negative infinity
    /// let zero_tt = TensorTrain::<f64>::zeros(&[2, 3]);
    /// assert_eq!(zero_tt.log_norm(), f64::NEG_INFINITY);
    /// ```
    fn log_norm(&self) -> f64 {
        if self.is_empty() {
            return f64::NEG_INFINITY;
        }

        // Contract the train with its conjugate, normalizing at each step
        // to avoid overflow/underflow
        let first = self.site_tensor(0);
        let right_dim = first.right_dim();

        // current[ra, ra_conj] = sum_s first[0, s, ra] * conj(first[0, s, ra_conj])
        let mut current = vec![T::zero(); right_dim * right_dim];
        for s in 0..first.site_dim() {
            for ra in 0..right_dim {
                for ra_c in 0..right_dim {
                    let idx = ra * right_dim + ra_c;
                    current[idx] =
                        current[idx] + *first.get3(0, s, ra) * first.get3(0, s, ra_c).conj();
                }
            }
        }

        // Normalize and accumulate log scale
        let mut log_scale = 0.0;
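        // Invariant kept below: the exact overlap matrix equals
        // `current` scaled by exp(log_scale).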
        let scale = current
            .iter()
            .map(|x| x.abs_sq())
            .fold(0.0, f64::max)
            .sqrt();
        if scale > 0.0 {
            log_scale += scale.ln();
            let inv_scale = T::one() / T::from_f64(scale);
            for val in &mut current {
                *val = *val * inv_scale;
            }
        } else if self.len() == 1 {
            // Single site with zero norm
            return f64::NEG_INFINITY;
        }

        // Contract through remaining tensors
        for site in 1..self.len() {
            let tensor = self.site_tensor(site);
            let left_dim = tensor.left_dim();
            let right_dim = tensor.right_dim();

            // new_current[ra, ra_conj] = sum_{la, la_conj, s}
            //   current[la, la_conj] * tensor[la, s, ra] * conj(tensor[la_conj, s, ra_conj])
            let mut new_current = vec![T::zero(); right_dim * right_dim];

            for la in 0..left_dim {
                for la_c in 0..left_dim {
                    let c_val = current[la * left_dim + la_c];
                    for s in 0..tensor.site_dim() {
                        for ra in 0..right_dim {
                            for ra_c in 0..right_dim {
                                let idx = ra * right_dim + ra_c;
                                new_current[idx] = new_current[idx]
                                    + c_val
                                        * *tensor.get3(la, s, ra)
                                        * tensor.get3(la_c, s, ra_c).conj();
                            }
                        }
                    }
                }
            }

            // Normalize and accumulate log scale
            let scale = new_current
                .iter()
                .map(|x| x.abs_sq())
                .fold(0.0, f64::max)
                .sqrt();
            if scale > 0.0 {
                log_scale += scale.ln();
                let inv_scale = T::one() / T::from_f64(scale);
                for val in &mut new_current {
                    *val = *val * inv_scale;
                }
            }

            current = new_current;
        }

        // Final result: norm = sqrt(norm2), where norm2 = final_val * exp(log_scale),
        // so log(norm) = 0.5 * log(norm2) = 0.5 * (ln(final_val) + log_scale)
        let final_val = current[0].abs_sq().sqrt();
        if final_val > 0.0 {
            0.5 * (final_val.ln() + log_scale)
        } else {
            f64::NEG_INFINITY
        }
    }
}