Skip to main content

tensor4all_simplett/mpo/
contraction.rs

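//! Contraction of two MPOs with cached environments for element-wise evaluation
//!
//! The Contraction struct wraps two MPOs `A` and `B` and evaluates individual elements
//! of their product `A · B`. Left and right environments (partial contractions over a
//! prefix or suffix of sites) are memoized so they can be reused across evaluations.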
5
6use std::collections::HashMap;
7
8use super::error::{MPOError, Result};
9use super::factorize::SVDScalar;
10use super::matrix2_zeros;
11use super::mpo::MPO;
12use super::types::Tensor4Ops;
13use super::Matrix2;
14
15/// Options for contraction operations
16#[derive(Debug, Clone)]
17pub struct ContractionOptions {
18    /// Tolerance for truncation
19    pub tolerance: f64,
20    /// Maximum bond dimension after contraction
21    pub max_bond_dim: usize,
22    /// Factorization method for compression
23    pub factorize_method: super::factorize::FactorizeMethod,
24}
25
26impl Default for ContractionOptions {
27    fn default() -> Self {
28        Self {
29            tolerance: 1e-12,
30            max_bond_dim: usize::MAX,
31            factorize_method: super::factorize::FactorizeMethod::SVD,
32        }
33    }
34}
35
36/// Contraction of two MPOs with caching
37///
38/// This struct efficiently computes the product of two MPOs by
39/// caching left and right environments for reuse.
40pub struct Contraction<T: SVDScalar>
41where
42    <T as num_complex::ComplexFloat>::Real: Into<f64>,
43{
44    /// First MPO
45    mpo_a: MPO<T>,
46    /// Second MPO
47    mpo_b: MPO<T>,
48    /// Cache for left environments, keyed by index sets
49    left_cache: HashMap<Vec<(usize, usize)>, Matrix2<T>>,
50    /// Cache for right environments, keyed by index sets
51    right_cache: HashMap<Vec<(usize, usize)>, Matrix2<T>>,
52    /// Optional transformation function applied to result
53    transform_fn: Option<Box<dyn Fn(T) -> T + Send + Sync>>,
54    /// Site dimensions for both MPOs [(s1_a, s2_a, s1_b, s2_b), ...]
55    site_dims: Vec<(usize, usize, usize, usize)>,
56}
57
58impl<T: SVDScalar> Contraction<T>
59where
60    <T as num_complex::ComplexFloat>::Real: Into<f64>,
61{
62    /// Create a new Contraction from two MPOs
63    pub fn new(mpo_a: MPO<T>, mpo_b: MPO<T>) -> Result<Self> {
64        if mpo_a.len() != mpo_b.len() {
65            return Err(MPOError::LengthMismatch {
66                expected: mpo_a.len(),
67                got: mpo_b.len(),
68            });
69        }
70
71        // Validate shared dimensions
72        for i in 0..mpo_a.len() {
73            let (_s1_a, s2_a) = mpo_a.site_dim(i);
74            let (s1_b, _s2_b) = mpo_b.site_dim(i);
75            if s2_a != s1_b {
76                return Err(MPOError::SharedDimensionMismatch {
77                    site: i,
78                    dim_a: s2_a,
79                    dim_b: s1_b,
80                });
81            }
82        }
83
84        let site_dims: Vec<_> = (0..mpo_a.len())
85            .map(|i| {
86                let (s1_a, s2_a) = mpo_a.site_dim(i);
87                let (s1_b, s2_b) = mpo_b.site_dim(i);
88                (s1_a, s2_a, s1_b, s2_b)
89            })
90            .collect();
91
92        Ok(Self {
93            mpo_a,
94            mpo_b,
95            left_cache: HashMap::new(),
96            right_cache: HashMap::new(),
97            transform_fn: None,
98            site_dims,
99        })
100    }
101
102    /// Create a new Contraction with a transformation function
103    pub fn with_transform<F>(mpo_a: MPO<T>, mpo_b: MPO<T>, f: F) -> Result<Self>
104    where
105        F: Fn(T) -> T + Send + Sync + 'static,
106    {
107        let mut contraction = Self::new(mpo_a, mpo_b)?;
108        contraction.transform_fn = Some(Box::new(f));
109        Ok(contraction)
110    }
111
112    /// Get the number of sites
113    pub fn len(&self) -> usize {
114        self.mpo_a.len()
115    }
116
117    /// Check if empty
118    pub fn is_empty(&self) -> bool {
119        self.mpo_a.is_empty()
120    }
121
122    /// Get site dimensions for the contracted result
123    ///
124    /// Returns (s1_result, s2_result) at each site where:
125    /// - s1_result = s1_a (first physical index of A)
126    /// - s2_result = s2_b (second physical index of B)
127    pub fn result_site_dims(&self) -> Vec<(usize, usize)> {
128        self.site_dims
129            .iter()
130            .map(|&(s1_a, _, _, s2_b)| (s1_a, s2_b))
131            .collect()
132    }
133
134    /// Clear all cached environments
135    pub fn clear_cache(&mut self) {
136        self.left_cache.clear();
137        self.right_cache.clear();
138    }
139
140    /// Evaluate the contraction at a specific set of indices
141    ///
142    /// indices should be [(i1, j1), (i2, j2), ...] where:
143    /// - i_k is the index for s1 of MPO A at site k
144    /// - j_k is the index for s2 of MPO B at site k
145    pub fn evaluate(&mut self, indices: &[(usize, usize)]) -> Result<T> {
146        if indices.len() != self.len() {
147            return Err(MPOError::InvalidOperation {
148                message: format!("Expected {} index pairs, got {}", self.len(), indices.len()),
149            });
150        }
151
152        if self.is_empty() {
153            return Err(MPOError::Empty);
154        }
155
156        // Contract from left to right
157        let first_a = self.mpo_a.site_tensor(0);
158        let first_b = self.mpo_b.site_tensor(0);
159        let (i0, j0) = indices[0];
160
161        // Sum over shared index k
162        let mut current: Matrix2<T> = matrix2_zeros(first_a.right_dim(), first_b.right_dim());
163        for k in 0..first_a.site_dim_2() {
164            for ra in 0..first_a.right_dim() {
165                for rb in 0..first_b.right_dim() {
166                    current[[ra, rb]] = current[[ra, rb]]
167                        + *first_a.get4(0, i0, k, ra) * *first_b.get4(0, k, j0, rb);
168                }
169            }
170        }
171
172        // Contract through remaining sites
173        #[allow(clippy::needless_range_loop)]
174        for site in 1..self.len() {
175            let a = self.mpo_a.site_tensor(site);
176            let b = self.mpo_b.site_tensor(site);
177            let (i_k, j_k) = indices[site];
178
179            let mut new_current: Matrix2<T> = matrix2_zeros(a.right_dim(), b.right_dim());
180
181            for la in 0..a.left_dim() {
182                for lb in 0..b.left_dim() {
183                    let c_val = current[[la, lb]];
184                    for k in 0..a.site_dim_2() {
185                        for ra in 0..a.right_dim() {
186                            for rb in 0..b.right_dim() {
187                                new_current[[ra, rb]] = new_current[[ra, rb]]
188                                    + c_val * *a.get4(la, i_k, k, ra) * *b.get4(lb, k, j_k, rb);
189                            }
190                        }
191                    }
192                }
193            }
194
195            current = new_current;
196        }
197
198        let result = current[[0, 0]];
199
200        // Apply transformation if present
201        let result = if let Some(ref f) = self.transform_fn {
202            f(result)
203        } else {
204            result
205        };
206
207        Ok(result)
208    }
209
210    /// Evaluate the left environment up to site n (exclusive)
211    ///
212    /// Returns L\[n\] = product of sites 0..n
213    pub fn evaluate_left(&mut self, n: usize, indices: &[(usize, usize)]) -> Result<Matrix2<T>> {
214        if n > self.len() {
215            return Err(MPOError::InvalidOperation {
216                message: format!("Site {} is out of range [0, {}]", n, self.len()),
217            });
218        }
219
220        if n == 0 {
221            let mut env: Matrix2<T> = matrix2_zeros(1, 1);
222            env[[0, 0]] = T::one();
223            return Ok(env);
224        }
225
226        // Check cache
227        let key: Vec<(usize, usize)> = indices[..n].to_vec();
228        if let Some(cached) = self.left_cache.get(&key) {
229            return Ok(cached.clone());
230        }
231
232        // Compute recursively
233        let prev_env = self.evaluate_left(n - 1, indices)?;
234        let a = self.mpo_a.site_tensor(n - 1);
235        let b = self.mpo_b.site_tensor(n - 1);
236        let (i_k, j_k) = indices[n - 1];
237
238        let mut new_env: Matrix2<T> = matrix2_zeros(a.right_dim(), b.right_dim());
239
240        for la in 0..a.left_dim() {
241            for lb in 0..b.left_dim() {
242                let l_val = prev_env[[la, lb]];
243                for k in 0..a.site_dim_2() {
244                    for ra in 0..a.right_dim() {
245                        for rb in 0..b.right_dim() {
246                            new_env[[ra, rb]] = new_env[[ra, rb]]
247                                + l_val * *a.get4(la, i_k, k, ra) * *b.get4(lb, k, j_k, rb);
248                        }
249                    }
250                }
251            }
252        }
253
254        // Cache the result
255        self.left_cache.insert(key, new_env.clone());
256
257        Ok(new_env)
258    }
259
260    /// Evaluate the right environment from site n (exclusive) to the end
261    ///
262    /// Returns R\[n\] = product of sites n..L
263    pub fn evaluate_right(&mut self, n: usize, indices: &[(usize, usize)]) -> Result<Matrix2<T>> {
264        let len = self.len();
265        if n > len {
266            return Err(MPOError::InvalidOperation {
267                message: format!("Site {} is out of range [0, {}]", n, len),
268            });
269        }
270
271        if n == len {
272            let mut env: Matrix2<T> = matrix2_zeros(1, 1);
273            env[[0, 0]] = T::one();
274            return Ok(env);
275        }
276
277        // Check cache
278        let key: Vec<(usize, usize)> = indices[n..].to_vec();
279        if let Some(cached) = self.right_cache.get(&key) {
280            return Ok(cached.clone());
281        }
282
283        // Compute recursively
284        let prev_env = self.evaluate_right(n + 1, indices)?;
285        let a = self.mpo_a.site_tensor(n);
286        let b = self.mpo_b.site_tensor(n);
287        let (i_k, j_k) = indices[n];
288
289        let mut new_env: Matrix2<T> = matrix2_zeros(a.left_dim(), b.left_dim());
290
291        for ra in 0..a.right_dim() {
292            for rb in 0..b.right_dim() {
293                let r_val = prev_env[[ra, rb]];
294                for k in 0..a.site_dim_2() {
295                    for la in 0..a.left_dim() {
296                        for lb in 0..b.left_dim() {
297                            new_env[[la, lb]] = new_env[[la, lb]]
298                                + r_val * *a.get4(la, i_k, k, ra) * *b.get4(lb, k, j_k, rb);
299                        }
300                    }
301                }
302            }
303        }
304
305        // Cache the result
306        self.right_cache.insert(key, new_env.clone());
307
308        Ok(new_env)
309    }
310}
311
312#[cfg(test)]
313mod tests;