1use std::collections::HashMap;
7
8use super::error::{MPOError, Result};
9use super::factorize::SVDScalar;
10use super::matrix2_zeros;
11use super::mpo::MPO;
12use super::types::Tensor4Ops;
13use super::Matrix2;
14
/// Tuning parameters for an MPO–MPO contraction.
///
/// NOTE(review): none of these fields are read in this part of the file; they are presumably
/// consumed where a contracted result is factorized/compressed — confirm at the call site.
#[derive(Debug, Clone)]
pub struct ContractionOptions {
    /// Numerical tolerance (default `1e-12`); presumably a truncation threshold — confirm.
    pub tolerance: f64,
    /// Upper bound on the bond dimension (default `usize::MAX`, i.e. no explicit cap).
    pub max_bond_dim: usize,
    /// Factorization algorithm to use (default `FactorizeMethod::SVD`).
    pub factorize_method: super::factorize::FactorizeMethod,
}
25
26impl Default for ContractionOptions {
27 fn default() -> Self {
28 Self {
29 tolerance: 1e-12,
30 max_bond_dim: usize::MAX,
31 factorize_method: super::factorize::FactorizeMethod::SVD,
32 }
33 }
34}
35
/// On-demand contraction of two MPOs `A` and `B` with the same number of sites.
///
/// Individual elements of the product are computed by [`Contraction::evaluate`]; partial
/// left/right environments are available (and memoized) through `evaluate_left` and
/// `evaluate_right`.
pub struct Contraction<T: SVDScalar>
where
    <T as num_complex::ComplexFloat>::Real: Into<f64>,
{
    /// Left operand; its second site index is contracted at every site.
    mpo_a: MPO<T>,
    /// Right operand; its first site index is contracted at every site.
    mpo_b: MPO<T>,
    /// Memoized left environments, keyed by the index pairs of the contracted prefix.
    left_cache: HashMap<Vec<(usize, usize)>, Matrix2<T>>,
    /// Memoized right environments, keyed by the index pairs of the contracted suffix.
    right_cache: HashMap<Vec<(usize, usize)>, Matrix2<T>>,
    /// Optional function applied to every scalar returned by `evaluate`.
    transform_fn: Option<Box<dyn Fn(T) -> T + Send + Sync>>,
    /// Per-site dimensions `(s1_a, s2_a, s1_b, s2_b)` captured at construction time.
    site_dims: Vec<(usize, usize, usize, usize)>,
}
57
58impl<T: SVDScalar> Contraction<T>
59where
60 <T as num_complex::ComplexFloat>::Real: Into<f64>,
61{
62 pub fn new(mpo_a: MPO<T>, mpo_b: MPO<T>) -> Result<Self> {
64 if mpo_a.len() != mpo_b.len() {
65 return Err(MPOError::LengthMismatch {
66 expected: mpo_a.len(),
67 got: mpo_b.len(),
68 });
69 }
70
71 for i in 0..mpo_a.len() {
73 let (_s1_a, s2_a) = mpo_a.site_dim(i);
74 let (s1_b, _s2_b) = mpo_b.site_dim(i);
75 if s2_a != s1_b {
76 return Err(MPOError::SharedDimensionMismatch {
77 site: i,
78 dim_a: s2_a,
79 dim_b: s1_b,
80 });
81 }
82 }
83
84 let site_dims: Vec<_> = (0..mpo_a.len())
85 .map(|i| {
86 let (s1_a, s2_a) = mpo_a.site_dim(i);
87 let (s1_b, s2_b) = mpo_b.site_dim(i);
88 (s1_a, s2_a, s1_b, s2_b)
89 })
90 .collect();
91
92 Ok(Self {
93 mpo_a,
94 mpo_b,
95 left_cache: HashMap::new(),
96 right_cache: HashMap::new(),
97 transform_fn: None,
98 site_dims,
99 })
100 }
101
102 pub fn with_transform<F>(mpo_a: MPO<T>, mpo_b: MPO<T>, f: F) -> Result<Self>
104 where
105 F: Fn(T) -> T + Send + Sync + 'static,
106 {
107 let mut contraction = Self::new(mpo_a, mpo_b)?;
108 contraction.transform_fn = Some(Box::new(f));
109 Ok(contraction)
110 }
111
112 pub fn len(&self) -> usize {
114 self.mpo_a.len()
115 }
116
117 pub fn is_empty(&self) -> bool {
119 self.mpo_a.is_empty()
120 }
121
122 pub fn result_site_dims(&self) -> Vec<(usize, usize)> {
128 self.site_dims
129 .iter()
130 .map(|&(s1_a, _, _, s2_b)| (s1_a, s2_b))
131 .collect()
132 }
133
134 pub fn clear_cache(&mut self) {
136 self.left_cache.clear();
137 self.right_cache.clear();
138 }
139
140 pub fn evaluate(&mut self, indices: &[(usize, usize)]) -> Result<T> {
146 if indices.len() != self.len() {
147 return Err(MPOError::InvalidOperation {
148 message: format!("Expected {} index pairs, got {}", self.len(), indices.len()),
149 });
150 }
151
152 if self.is_empty() {
153 return Err(MPOError::Empty);
154 }
155
156 let first_a = self.mpo_a.site_tensor(0);
158 let first_b = self.mpo_b.site_tensor(0);
159 let (i0, j0) = indices[0];
160
161 let mut current: Matrix2<T> = matrix2_zeros(first_a.right_dim(), first_b.right_dim());
163 for k in 0..first_a.site_dim_2() {
164 for ra in 0..first_a.right_dim() {
165 for rb in 0..first_b.right_dim() {
166 current[[ra, rb]] = current[[ra, rb]]
167 + *first_a.get4(0, i0, k, ra) * *first_b.get4(0, k, j0, rb);
168 }
169 }
170 }
171
172 #[allow(clippy::needless_range_loop)]
174 for site in 1..self.len() {
175 let a = self.mpo_a.site_tensor(site);
176 let b = self.mpo_b.site_tensor(site);
177 let (i_k, j_k) = indices[site];
178
179 let mut new_current: Matrix2<T> = matrix2_zeros(a.right_dim(), b.right_dim());
180
181 for la in 0..a.left_dim() {
182 for lb in 0..b.left_dim() {
183 let c_val = current[[la, lb]];
184 for k in 0..a.site_dim_2() {
185 for ra in 0..a.right_dim() {
186 for rb in 0..b.right_dim() {
187 new_current[[ra, rb]] = new_current[[ra, rb]]
188 + c_val * *a.get4(la, i_k, k, ra) * *b.get4(lb, k, j_k, rb);
189 }
190 }
191 }
192 }
193 }
194
195 current = new_current;
196 }
197
198 let result = current[[0, 0]];
199
200 let result = if let Some(ref f) = self.transform_fn {
202 f(result)
203 } else {
204 result
205 };
206
207 Ok(result)
208 }
209
210 pub fn evaluate_left(&mut self, n: usize, indices: &[(usize, usize)]) -> Result<Matrix2<T>> {
214 if n > self.len() {
215 return Err(MPOError::InvalidOperation {
216 message: format!("Site {} is out of range [0, {}]", n, self.len()),
217 });
218 }
219
220 if n == 0 {
221 let mut env: Matrix2<T> = matrix2_zeros(1, 1);
222 env[[0, 0]] = T::one();
223 return Ok(env);
224 }
225
226 let key: Vec<(usize, usize)> = indices[..n].to_vec();
228 if let Some(cached) = self.left_cache.get(&key) {
229 return Ok(cached.clone());
230 }
231
232 let prev_env = self.evaluate_left(n - 1, indices)?;
234 let a = self.mpo_a.site_tensor(n - 1);
235 let b = self.mpo_b.site_tensor(n - 1);
236 let (i_k, j_k) = indices[n - 1];
237
238 let mut new_env: Matrix2<T> = matrix2_zeros(a.right_dim(), b.right_dim());
239
240 for la in 0..a.left_dim() {
241 for lb in 0..b.left_dim() {
242 let l_val = prev_env[[la, lb]];
243 for k in 0..a.site_dim_2() {
244 for ra in 0..a.right_dim() {
245 for rb in 0..b.right_dim() {
246 new_env[[ra, rb]] = new_env[[ra, rb]]
247 + l_val * *a.get4(la, i_k, k, ra) * *b.get4(lb, k, j_k, rb);
248 }
249 }
250 }
251 }
252 }
253
254 self.left_cache.insert(key, new_env.clone());
256
257 Ok(new_env)
258 }
259
260 pub fn evaluate_right(&mut self, n: usize, indices: &[(usize, usize)]) -> Result<Matrix2<T>> {
264 let len = self.len();
265 if n > len {
266 return Err(MPOError::InvalidOperation {
267 message: format!("Site {} is out of range [0, {}]", n, len),
268 });
269 }
270
271 if n == len {
272 let mut env: Matrix2<T> = matrix2_zeros(1, 1);
273 env[[0, 0]] = T::one();
274 return Ok(env);
275 }
276
277 let key: Vec<(usize, usize)> = indices[n..].to_vec();
279 if let Some(cached) = self.right_cache.get(&key) {
280 return Ok(cached.clone());
281 }
282
283 let prev_env = self.evaluate_right(n + 1, indices)?;
285 let a = self.mpo_a.site_tensor(n);
286 let b = self.mpo_b.site_tensor(n);
287 let (i_k, j_k) = indices[n];
288
289 let mut new_env: Matrix2<T> = matrix2_zeros(a.left_dim(), b.left_dim());
290
291 for ra in 0..a.right_dim() {
292 for rb in 0..b.right_dim() {
293 let r_val = prev_env[[ra, rb]];
294 for k in 0..a.site_dim_2() {
295 for la in 0..a.left_dim() {
296 for lb in 0..b.left_dim() {
297 new_env[[la, lb]] = new_env[[la, lb]]
298 + r_val * *a.get4(la, i_k, k, ra) * *b.get4(lb, k, j_k, rb);
299 }
300 }
301 }
302 }
303 }
304
305 self.right_cache.insert(key, new_env.clone());
307
308 Ok(new_env)
309 }
310}
311
312#[cfg(test)]
313mod tests;