// tensor4all_simplett — arithmetic.rs
1//! Arithmetic operations for tensor trains
2
3use crate::error::Result;
4use crate::tensortrain::TensorTrain;
5use crate::traits::{AbstractTensorTrain, TTScalar};
6use crate::types::{tensor3_zeros, Tensor3Ops};
7
impl<T: TTScalar> TensorTrain<T> {
    /// Add two tensor trains element-wise: `result[i] = self[i] + other[i]`.
    ///
    /// The result has bond dimension equal to the **sum** of the input bond
    /// dimensions. Call [`compress`](crate::TensorTrain::compress) afterward
    /// to reduce the bond dimension.
    ///
    /// Construction is the standard TT "direct sum": the first cores are
    /// concatenated horizontally `[A | B]`, the last cores stacked vertically
    /// `[A; B]`, and every interior core is placed block-diagonally
    /// `[A 0; 0 B]`. Contracting the chain then gives (A-chain) + (B-chain),
    /// i.e. the element-wise sum.
    ///
    /// # Errors
    ///
    /// Returns an error if the tensor trains have different lengths or
    /// mismatched site dimensions.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// let a = TensorTrain::<f64>::constant(&[2, 3], 1.0);
    /// let b = TensorTrain::<f64>::constant(&[2, 3], 2.0);
    /// let c = a.add(&b).unwrap();
    ///
    /// // Every entry = 1 + 2 = 3
    /// assert!((c.evaluate(&[0, 0]).unwrap() - 3.0).abs() < 1e-12);
    /// // Bond dim = 1 + 1 = 2
    /// assert_eq!(c.rank(), 2);
    /// ```
    pub fn add(&self, other: &Self) -> Result<Self> {
        use crate::error::TensorTrainError;

        if self.len() != other.len() {
            return Err(TensorTrainError::InvalidOperation {
                message: format!(
                    "Cannot add tensor trains of different lengths: {} vs {}",
                    self.len(),
                    other.len()
                ),
            });
        }

        // Lengths match, so `other` is empty too; the sum of two empty
        // tensor trains is the empty tensor train.
        if self.is_empty() {
            return Ok(other.clone());
        }

        let n = self.len();
        let mut tensors = Vec::with_capacity(n);

        for i in 0..n {
            let a = self.site_tensor(i);
            let b = other.site_tensor(i);

            if a.site_dim() != b.site_dim() {
                return Err(TensorTrainError::InvalidOperation {
                    message: format!(
                        "Site dimensions mismatch at site {}: {} vs {}",
                        i,
                        a.site_dim(),
                        b.site_dim()
                    ),
                });
            }

            let site_dim = a.site_dim();

            // NOTE(review): the boundary branches below read only index 0 on
            // the outer legs, i.e. they assume the usual TT invariant that the
            // first core has `left_dim() == 1` and the last core has
            // `right_dim() == 1`. Confirm construction enforces this invariant;
            // otherwise boundary data would be silently dropped here.
            if i == 0 && i == n - 1 {
                // Single-site TTs must keep both bond dimensions at 1.
                // (The direct sum degenerates to a plain element-wise add.)
                let mut new_tensor = tensor3_zeros(1, site_dim, 1);
                for s in 0..site_dim {
                    new_tensor.set3(0, s, 0, *a.get3(0, s, 0) + *b.get3(0, s, 0));
                }
                tensors.push(new_tensor);
            } else if i == 0 {
                // First tensor: [A | B] horizontally
                let new_right_dim = a.right_dim() + b.right_dim();
                let mut new_tensor = tensor3_zeros(1, site_dim, new_right_dim);

                for s in 0..site_dim {
                    // A occupies right-bond columns [0, a.right_dim()).
                    for r in 0..a.right_dim() {
                        new_tensor.set3(0, s, r, *a.get3(0, s, r));
                    }
                    // B occupies the remaining columns, offset by a.right_dim().
                    for r in 0..b.right_dim() {
                        new_tensor.set3(0, s, a.right_dim() + r, *b.get3(0, s, r));
                    }
                }
                tensors.push(new_tensor);
            } else if i == n - 1 {
                // Last tensor: [A; B] vertically
                let new_left_dim = a.left_dim() + b.left_dim();
                let mut new_tensor = tensor3_zeros(new_left_dim, site_dim, 1);

                // A occupies left-bond rows [0, a.left_dim()).
                for l in 0..a.left_dim() {
                    for s in 0..site_dim {
                        new_tensor.set3(l, s, 0, *a.get3(l, s, 0));
                    }
                }
                // B's rows start at offset a.left_dim().
                for l in 0..b.left_dim() {
                    for s in 0..site_dim {
                        new_tensor.set3(a.left_dim() + l, s, 0, *b.get3(l, s, 0));
                    }
                }
                tensors.push(new_tensor);
            } else {
                // Middle tensors: block diagonal [A 0; 0 B]
                // (the off-diagonal blocks stay zero from tensor3_zeros)
                let new_left_dim = a.left_dim() + b.left_dim();
                let new_right_dim = a.right_dim() + b.right_dim();
                let mut new_tensor = tensor3_zeros(new_left_dim, site_dim, new_right_dim);

                // Copy A block (top-left corner of each slice)
                for l in 0..a.left_dim() {
                    for s in 0..site_dim {
                        for r in 0..a.right_dim() {
                            new_tensor.set3(l, s, r, *a.get3(l, s, r));
                        }
                    }
                }
                // Copy B block (bottom-right: offset by A's dims on both bonds)
                for l in 0..b.left_dim() {
                    for s in 0..site_dim {
                        for r in 0..b.right_dim() {
                            new_tensor.set3(
                                a.left_dim() + l,
                                s,
                                a.right_dim() + r,
                                *b.get3(l, s, r),
                            );
                        }
                    }
                }
                tensors.push(new_tensor);
            }
        }

        // Adjacent bond dimensions are consistent by construction (each pair of
        // neighboring cores was sized from the same a/b dims), so validation
        // can safely be skipped.
        Ok(TensorTrain::from_tensors_unchecked(tensors))
    }

    /// Subtract element-wise: `result[i] = self[i] - other[i]`.
    ///
    /// Equivalent to `self.add(&other.negate())`.
    ///
    /// # Errors
    ///
    /// Returns an error if the tensor trains have different lengths or
    /// mismatched site dimensions (propagated from [`TensorTrain::add`]).
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// let a = TensorTrain::<f64>::constant(&[2, 3], 5.0);
    /// let b = TensorTrain::<f64>::constant(&[2, 3], 2.0);
    /// let c = a.sub(&b).unwrap();
    /// assert!((c.evaluate(&[0, 0]).unwrap() - 3.0).abs() < 1e-12);
    /// ```
    pub fn sub(&self, other: &Self) -> Result<Self> {
        // Negate by scaling with -1, then reuse the addition routine.
        let neg_other = other.scaled(-T::one());
        self.add(&neg_other)
    }

    /// Negate every entry: `result[i] = -self[i]`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_simplett::{TensorTrain, AbstractTensorTrain};
    ///
    /// let tt = TensorTrain::<f64>::constant(&[2, 3], 7.0);
    /// let neg = tt.negate();
    /// assert!((neg.evaluate(&[0, 0]).unwrap() + 7.0).abs() < 1e-12);
    /// ```
    pub fn negate(&self) -> Self {
        // Scaling any single core by -1 flips the sign of the whole contraction.
        self.scaled(-T::one())
    }
}
176
177impl<T: TTScalar> std::ops::Add for TensorTrain<T> {
178    type Output = Result<Self>;
179
180    fn add(self, other: Self) -> Self::Output {
181        TensorTrain::add(&self, &other)
182    }
183}
184
185impl<T: TTScalar> std::ops::Add for &TensorTrain<T> {
186    type Output = Result<TensorTrain<T>>;
187
188    fn add(self, other: Self) -> Self::Output {
189        TensorTrain::add(self, other)
190    }
191}
192
193impl<T: TTScalar> std::ops::Sub for TensorTrain<T> {
194    type Output = Result<Self>;
195
196    fn sub(self, other: Self) -> Self::Output {
197        TensorTrain::sub(&self, &other)
198    }
199}
200
201impl<T: TTScalar> std::ops::Sub for &TensorTrain<T> {
202    type Output = Result<TensorTrain<T>>;
203
204    fn sub(self, other: Self) -> Self::Output {
205        TensorTrain::sub(self, other)
206    }
207}
208
209impl<T: TTScalar> std::ops::Neg for TensorTrain<T> {
210    type Output = Self;
211
212    fn neg(self) -> Self::Output {
213        self.negate()
214    }
215}
216
217impl<T: TTScalar> std::ops::Neg for &TensorTrain<T> {
218    type Output = TensorTrain<T>;
219
220    fn neg(self) -> Self::Output {
221        self.negate()
222    }
223}
224
225impl<T: TTScalar> std::ops::Mul<T> for TensorTrain<T> {
226    type Output = Self;
227
228    fn mul(self, scalar: T) -> Self::Output {
229        self.scaled(scalar)
230    }
231}
232
233impl<T: TTScalar> std::ops::Mul<T> for &TensorTrain<T> {
234    type Output = TensorTrain<T>;
235
236    fn mul(self, scalar: T) -> Self::Output {
237        self.scaled(scalar)
238    }
239}
240
// Unit tests for the arithmetic routines live in the `tests` submodule,
// compiled only under `cargo test`.
#[cfg(test)]
mod tests;