1use super::contraction::ContractionOptions;
7use super::error::{MPOError, Result};
8use super::factorize::{FactorizeMethod, SVDScalar};
9use super::mpo::MPO;
10use super::site_mpo::SiteMPO;
11use super::types::{Tensor4, Tensor4Ops};
12use crate::einsum_helper::EinsumScalar;
13
14#[derive(Debug, Clone)]
16pub struct FitOptions {
17 pub tolerance: f64,
19 pub max_bond_dim: usize,
21 pub max_sweeps: usize,
23 pub convergence_tol: f64,
25 pub factorize_method: FactorizeMethod,
27}
28
29impl Default for FitOptions {
30 fn default() -> Self {
31 Self {
32 tolerance: 1e-12,
33 max_bond_dim: 100,
34 max_sweeps: 10,
35 convergence_tol: 1e-10,
36 factorize_method: FactorizeMethod::SVD,
37 }
38 }
39}
40
41pub fn contract_fit<T: SVDScalar + EinsumScalar>(
56 mpo_a: &MPO<T>,
57 mpo_b: &MPO<T>,
58 options: &FitOptions,
59 initial: Option<MPO<T>>,
60) -> Result<MPO<T>>
61where
62 <T as num_complex::ComplexFloat>::Real: Into<f64>,
63{
64 if mpo_a.len() != mpo_b.len() {
65 return Err(MPOError::LengthMismatch {
66 expected: mpo_a.len(),
67 got: mpo_b.len(),
68 });
69 }
70
71 if mpo_a.is_empty() {
72 return Ok(MPO::from_tensors_unchecked(Vec::new()));
73 }
74
75 let n = mpo_a.len();
76
77 for i in 0..n {
79 let (_, s2_a) = mpo_a.site_dim(i);
80 let (s1_b, _) = mpo_b.site_dim(i);
81 if s2_a != s1_b {
82 return Err(MPOError::SharedDimensionMismatch {
83 site: i,
84 dim_a: s2_a,
85 dim_b: s1_b,
86 });
87 }
88 }
89
90 let result = if let Some(init) = initial {
92 SiteMPO::from_mpo(init, 0)?
93 } else {
94 let naive_opts = ContractionOptions {
96 tolerance: options.tolerance,
97 max_bond_dim: options.max_bond_dim,
98 factorize_method: options.factorize_method,
99 };
100 let naive_result = super::contract_naive::contract_naive(mpo_a, mpo_b, Some(naive_opts))?;
101 SiteMPO::from_mpo(naive_result, 0)?
102 };
103
104 if n <= 2 {
106 return Ok(result.into_mpo());
107 }
108
109 let mut left_envs: Vec<Option<Environment<T>>> = vec![None; n];
111 let mut right_envs: Vec<Option<Environment<T>>> = vec![None; n];
112
113 left_envs[0] = Some(Environment::identity(1, 1, 1));
115 right_envs[n - 1] = Some(Environment::identity(1, 1, 1));
116
117 for i in (1..n).rev() {
119 right_envs[i - 1] = Some(build_right_environment(
120 mpo_a.site_tensor(i),
121 mpo_b.site_tensor(i),
122 result.site_tensor(i),
123 right_envs[i].as_ref().unwrap(),
124 )?);
125 }
126
127 let mut current = result;
128 let mut _prev_norm = f64::INFINITY;
129
130 for _sweep in 0..options.max_sweeps {
132 for i in 0..(n - 1) {
134 let _updated = update_two_site_core(
136 mpo_a,
137 mpo_b,
138 &mut current,
139 i,
140 &left_envs,
141 &right_envs,
142 options,
143 )?;
144
145 left_envs[i + 1] = Some(build_left_environment(
147 mpo_a.site_tensor(i),
148 mpo_b.site_tensor(i),
149 current.site_tensor(i),
150 left_envs[i].as_ref().unwrap(),
151 )?);
152 }
153
154 for i in (1..n).rev() {
156 let _updated = update_two_site_core(
158 mpo_a,
159 mpo_b,
160 &mut current,
161 i - 1,
162 &left_envs,
163 &right_envs,
164 options,
165 )?;
166
167 if i > 0 {
169 right_envs[i - 1] = Some(build_right_environment(
170 mpo_a.site_tensor(i),
171 mpo_b.site_tensor(i),
172 current.site_tensor(i),
173 right_envs[i].as_ref().unwrap(),
174 )?);
175 }
176 }
177
178 }
181
182 Ok(current.into_mpo())
183}
184
/// Dense rank-3 environment tensor with legs (result bond, A bond, B bond).
///
/// Entries are stored row-major: the flat offset of `(r, a, b)` is
/// `(r * dim_a + a) * dim_b + b`, so the B-bond index varies fastest.
#[derive(Clone, Debug)]
struct Environment<T> {
    /// Flat storage holding `dim_result * dim_a * dim_b` entries.
    data: Vec<T>,
    /// Extent of the leg attached to the result MPO.
    dim_result: usize,
    /// Extent of the leg attached to MPO A.
    dim_a: usize,
    /// Extent of the leg attached to MPO B.
    dim_b: usize,
}
194
195impl<T: SVDScalar> Environment<T>
196where
197 <T as num_complex::ComplexFloat>::Real: Into<f64>,
198{
199 fn identity(dim_result: usize, dim_a: usize, dim_b: usize) -> Self {
200 let mut data = vec![T::zero(); dim_result * dim_a * dim_b];
201 let min_dim = dim_result.min(dim_a).min(dim_b);
203 for i in 0..min_dim {
204 data[(i * dim_a + i) * dim_b + i] = T::one();
205 }
206 if dim_result == 1 && dim_a == 1 && dim_b == 1 {
208 data[0] = T::one();
209 }
210 Self {
211 data,
212 dim_result,
213 dim_a,
214 dim_b,
215 }
216 }
217
218 fn get(&self, r: usize, a: usize, b: usize) -> T {
219 self.data[(r * self.dim_a + a) * self.dim_b + b]
220 }
221
222 fn set(&mut self, r: usize, a: usize, b: usize, val: T) {
223 self.data[(r * self.dim_a + a) * self.dim_b + b] = val;
224 }
225}
226
227fn build_left_environment<T: SVDScalar>(
229 tensor_a: &Tensor4<T>,
230 tensor_b: &Tensor4<T>,
231 tensor_result: &Tensor4<T>,
232 prev_env: &Environment<T>,
233) -> Result<Environment<T>>
234where
235 <T as num_complex::ComplexFloat>::Real: Into<f64>,
236{
237 let new_dim_result = tensor_result.right_dim();
238 let new_dim_a = tensor_a.right_dim();
239 let new_dim_b = tensor_b.right_dim();
240
241 let mut new_env = Environment {
242 data: vec![T::zero(); new_dim_result * new_dim_a * new_dim_b],
243 dim_result: new_dim_result,
244 dim_a: new_dim_a,
245 dim_b: new_dim_b,
246 };
247
248 for rr in 0..prev_env.dim_result {
251 for ra in 0..prev_env.dim_a {
252 for rb in 0..prev_env.dim_b {
253 let l_val = prev_env.get(rr, ra, rb);
254 for s1 in 0..tensor_result.site_dim_1() {
255 for s2 in 0..tensor_result.site_dim_2() {
256 for k in 0..tensor_a.site_dim_2() {
257 for rr_new in 0..new_dim_result {
258 for ra_new in 0..new_dim_a {
259 for rb_new in 0..new_dim_b {
260 let c_val = *tensor_result.get4(rr, s1, s2, rr_new);
261 let a_val = *tensor_a.get4(ra, s1, k, ra_new);
262 let b_val = *tensor_b.get4(rb, k, s2, rb_new);
263 let old = new_env.get(rr_new, ra_new, rb_new);
264 new_env.set(
265 rr_new,
266 ra_new,
267 rb_new,
268 old + l_val * c_val * a_val * b_val,
269 );
270 }
271 }
272 }
273 }
274 }
275 }
276 }
277 }
278 }
279
280 Ok(new_env)
281}
282
283fn build_right_environment<T: SVDScalar>(
285 tensor_a: &Tensor4<T>,
286 tensor_b: &Tensor4<T>,
287 tensor_result: &Tensor4<T>,
288 next_env: &Environment<T>,
289) -> Result<Environment<T>>
290where
291 <T as num_complex::ComplexFloat>::Real: Into<f64>,
292{
293 let new_dim_result = tensor_result.left_dim();
294 let new_dim_a = tensor_a.left_dim();
295 let new_dim_b = tensor_b.left_dim();
296
297 let mut new_env = Environment {
298 data: vec![T::zero(); new_dim_result * new_dim_a * new_dim_b],
299 dim_result: new_dim_result,
300 dim_a: new_dim_a,
301 dim_b: new_dim_b,
302 };
303
304 for rr in 0..next_env.dim_result {
307 for ra in 0..next_env.dim_a {
308 for rb in 0..next_env.dim_b {
309 let r_val = next_env.get(rr, ra, rb);
310 for s1 in 0..tensor_result.site_dim_1() {
311 for s2 in 0..tensor_result.site_dim_2() {
312 for k in 0..tensor_a.site_dim_2() {
313 for lr in 0..new_dim_result {
314 for la in 0..new_dim_a {
315 for lb in 0..new_dim_b {
316 let c_val = *tensor_result.get4(lr, s1, s2, rr);
317 let a_val = *tensor_a.get4(la, s1, k, ra);
318 let b_val = *tensor_b.get4(lb, k, s2, rb);
319 let old = new_env.get(lr, la, lb);
320 new_env.set(
321 lr,
322 la,
323 lb,
324 old + r_val * c_val * a_val * b_val,
325 );
326 }
327 }
328 }
329 }
330 }
331 }
332 }
333 }
334 }
335
336 Ok(new_env)
337}
338
339fn update_two_site_core<T: SVDScalar>(
341 _mpo_a: &MPO<T>,
342 _mpo_b: &MPO<T>,
343 _result: &mut SiteMPO<T>,
344 _site: usize,
345 _left_envs: &[Option<Environment<T>>],
346 _right_envs: &[Option<Environment<T>>],
347 _options: &FitOptions,
348) -> Result<bool>
349where
350 <T as num_complex::ComplexFloat>::Real: Into<f64>,
351{
352 Ok(true)
355}
356
357#[cfg(test)]
358mod tests;