// tensor4all_simplett/contraction.rs
//! Contraction operations for tensor trains
//!
//! This module provides various ways to combine tensor trains:
//! - `dot`: Inner product (returns scalar)

use crate::compression::CompressionMethod;
use crate::einsum_helper::{
    einsum_tensors, tensor_to_row_major_vec, typed_tensor_from_row_major_slice, EinsumScalar,
};
use crate::error::{Result, TensorTrainError};
use crate::tensortrain::TensorTrain;
use crate::traits::{AbstractTensorTrain, TTScalar};
use crate::types::Tensor3Ops;
use tensor4all_tcicore::matrix::Matrix;
use tensor4all_tcicore::Scalar;

17/// Options for MPO-MPO contraction with on-the-fly compression.
18///
19/// # Examples
20///
21/// ```
22/// use tensor4all_simplett::ContractionOptions;
23///
24/// let opts = ContractionOptions::default();
25/// assert!((opts.tolerance - 1e-12).abs() < 1e-15);
26/// assert_eq!(opts.max_bond_dim, usize::MAX);
27/// ```
28#[derive(Debug, Clone)]
29pub struct ContractionOptions {
30 /// Relative truncation tolerance during contraction.
31 pub tolerance: f64,
32 /// Hard upper bound on bond dimension.
33 pub max_bond_dim: usize,
34 /// Decomposition method for intermediate compression.
35 pub method: CompressionMethod,
36}
37
38impl Default for ContractionOptions {
39 fn default() -> Self {
40 Self {
41 tolerance: 1e-12,
42 max_bond_dim: usize::MAX,
43 method: CompressionMethod::LU,
44 }
45 }
46}
47
48impl<T: TTScalar + Scalar + Default + EinsumScalar> TensorTrain<T> {
49 /// Inner product (dot product) of two tensor trains.
50 ///
51 /// Computes `sum_i self[i] * other[i]` by contracting the site tensors
52 /// from left to right. Both tensor trains must have the same length and
53 /// matching site dimensions.
54 ///
55 /// # Errors
56 ///
57 /// Returns an error if lengths or site dimensions do not match.
58 ///
59 /// # Examples
60 ///
61 /// ```
62 /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
63 ///
64 /// let a = TensorTrain::<f64>::constant(&[2, 3], 1.0);
65 /// let b = TensorTrain::<f64>::constant(&[2, 3], 2.0);
66 ///
67 /// // dot = sum_ij a[i,j]*b[i,j] = 1*2 * 2*3 = 12
68 /// let d = a.dot(&b).unwrap();
69 /// assert!((d - 12.0).abs() < 1e-10);
70 /// ```
71 pub fn dot(&self, other: &Self) -> Result<T> {
72 if self.len() != other.len() {
73 return Err(TensorTrainError::InvalidOperation {
74 message: format!(
75 "Cannot compute dot product of tensor trains with different lengths: {} vs {}",
76 self.len(),
77 other.len()
78 ),
79 });
80 }
81
82 if self.is_empty() {
83 return Ok(T::zero());
84 }
85
86 let n = self.len();
87
88 // Start with contraction of first site
89 // result[ra, rb] = sum_s a[0, s, ra] * b[0, s, rb]
90 let a0 = self.site_tensor(0);
91 let b0 = other.site_tensor(0);
92
93 if a0.site_dim() != b0.site_dim() {
94 return Err(TensorTrainError::InvalidOperation {
95 message: format!(
96 "Site dimensions mismatch at site 0: {} vs {}",
97 a0.site_dim(),
98 b0.site_dim()
99 ),
100 });
101 }
102
103 let mut result = Matrix::from_raw_vec(
104 a0.right_dim(),
105 b0.right_dim(),
106 tensor_to_row_major_vec(&einsum_tensors(
107 "asr,ast->rt",
108 &[a0.as_inner(), b0.as_inner()],
109 )),
110 );
111
112 // Contract through remaining sites
113 for i in 1..n {
114 let a = self.site_tensor(i);
115 let b = other.site_tensor(i);
116
117 if a.site_dim() != b.site_dim() {
118 return Err(TensorTrainError::InvalidOperation {
119 message: format!(
120 "Site dimensions mismatch at site {}: {} vs {}",
121 i,
122 a.site_dim(),
123 b.site_dim()
124 ),
125 });
126 }
127
128 let result_tf = typed_tensor_from_row_major_slice(
129 result.as_slice(),
130 &[result.nrows(), result.ncols()],
131 );
132
133 result = Matrix::from_raw_vec(
134 a.right_dim(),
135 b.right_dim(),
136 tensor_to_row_major_vec(&einsum_tensors(
137 "ij,isk,jsl->kl",
138 &[&result_tf, a.as_inner(), b.as_inner()],
139 )),
140 );
141 }
142
143 // Final result should be 1x1
144 Ok(result[[0, 0]])
145 }
146}
147
148/// Free-function wrapper for [`TensorTrain::dot`].
149///
150/// # Examples
151///
152/// ```
153/// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain, contraction::dot};
154///
155/// let a = TensorTrain::<f64>::constant(&[2, 3], 3.0);
156/// let b = TensorTrain::<f64>::constant(&[2, 3], 4.0);
157/// let d = dot(&a, &b).unwrap();
158/// // 3*4 * 2*3 = 72
159/// assert!((d - 72.0).abs() < 1e-10);
160/// ```
161pub fn dot<T: TTScalar + Scalar + Default + EinsumScalar>(
162 a: &TensorTrain<T>,
163 b: &TensorTrain<T>,
164) -> Result<T> {
165 a.dot(b)
166}
167
// Unit tests live in the sibling `tests.rs` / `tests/mod.rs` file.
#[cfg(test)]
mod tests;