Skip to main content

tensor4all_simplett/mpo/
contract_fit.rs

//! Variational fitting algorithm for MPO contraction.
//!
//! This module implements a variational (DMRG-like) fitting algorithm for
//! computing the product of two MPOs with a controlled bond dimension.
5
6use super::contraction::ContractionOptions;
7use super::error::{MPOError, Result};
8use super::factorize::{FactorizeMethod, SVDScalar};
9use super::mpo::MPO;
10use super::site_mpo::SiteMPO;
11use super::types::{Tensor4, Tensor4Ops};
12use crate::einsum_helper::EinsumScalar;
13
14/// Options for the variational fit algorithm
15#[derive(Debug, Clone)]
16pub struct FitOptions {
17    /// Tolerance for truncation at each step
18    pub tolerance: f64,
19    /// Maximum bond dimension
20    pub max_bond_dim: usize,
21    /// Maximum number of sweeps
22    pub max_sweeps: usize,
23    /// Convergence tolerance (stop if change < this)
24    pub convergence_tol: f64,
25    /// Factorization method
26    pub factorize_method: FactorizeMethod,
27}
28
29impl Default for FitOptions {
30    fn default() -> Self {
31        Self {
32            tolerance: 1e-12,
33            max_bond_dim: 100,
34            max_sweeps: 10,
35            convergence_tol: 1e-10,
36            factorize_method: FactorizeMethod::SVD,
37        }
38    }
39}
40
41/// Perform variational fitting contraction of two MPOs
42///
43/// This computes C = A * B using a variational (DMRG-like) algorithm
44/// that alternates between sweeping left-to-right and right-to-left,
45/// optimizing two sites at a time.
46///
47/// # Arguments
48/// * `mpo_a` - First MPO
49/// * `mpo_b` - Second MPO
50/// * `options` - Fitting options
51/// * `initial` - Optional initial guess (if None, uses naive contraction)
52///
53/// # Returns
54/// The contracted MPO C with bond dimension controlled by options
55pub fn contract_fit<T: SVDScalar + EinsumScalar>(
56    mpo_a: &MPO<T>,
57    mpo_b: &MPO<T>,
58    options: &FitOptions,
59    initial: Option<MPO<T>>,
60) -> Result<MPO<T>>
61where
62    <T as num_complex::ComplexFloat>::Real: Into<f64>,
63{
64    if mpo_a.len() != mpo_b.len() {
65        return Err(MPOError::LengthMismatch {
66            expected: mpo_a.len(),
67            got: mpo_b.len(),
68        });
69    }
70
71    if mpo_a.is_empty() {
72        return Ok(MPO::from_tensors_unchecked(Vec::new()));
73    }
74
75    let n = mpo_a.len();
76
77    // Validate shared dimensions
78    for i in 0..n {
79        let (_, s2_a) = mpo_a.site_dim(i);
80        let (s1_b, _) = mpo_b.site_dim(i);
81        if s2_a != s1_b {
82            return Err(MPOError::SharedDimensionMismatch {
83                site: i,
84                dim_a: s2_a,
85                dim_b: s1_b,
86            });
87        }
88    }
89
90    // Initialize the result
91    let result = if let Some(init) = initial {
92        SiteMPO::from_mpo(init, 0)?
93    } else {
94        // Use naive contraction as initial guess, then compress
95        let naive_opts = ContractionOptions {
96            tolerance: options.tolerance,
97            max_bond_dim: options.max_bond_dim,
98            factorize_method: options.factorize_method,
99        };
100        let naive_result = super::contract_naive::contract_naive(mpo_a, mpo_b, Some(naive_opts))?;
101        SiteMPO::from_mpo(naive_result, 0)?
102    };
103
104    // For single or two-site systems, return immediately
105    if n <= 2 {
106        return Ok(result.into_mpo());
107    }
108
109    // Build left and right environments
110    let mut left_envs: Vec<Option<Environment<T>>> = vec![None; n];
111    let mut right_envs: Vec<Option<Environment<T>>> = vec![None; n];
112
113    // Initialize boundary environments
114    left_envs[0] = Some(Environment::identity(1, 1, 1));
115    right_envs[n - 1] = Some(Environment::identity(1, 1, 1));
116
117    // Build initial right environments
118    for i in (1..n).rev() {
119        right_envs[i - 1] = Some(build_right_environment(
120            mpo_a.site_tensor(i),
121            mpo_b.site_tensor(i),
122            result.site_tensor(i),
123            right_envs[i].as_ref().unwrap(),
124        )?);
125    }
126
127    let mut current = result;
128    let mut _prev_norm = f64::INFINITY;
129
130    // Main optimization loop
131    for _sweep in 0..options.max_sweeps {
132        // Forward sweep (left to right)
133        for i in 0..(n - 1) {
134            // Update two-site core at positions i and i+1
135            let _updated = update_two_site_core(
136                mpo_a,
137                mpo_b,
138                &mut current,
139                i,
140                &left_envs,
141                &right_envs,
142                options,
143            )?;
144
145            // Update left environment at i+1
146            left_envs[i + 1] = Some(build_left_environment(
147                mpo_a.site_tensor(i),
148                mpo_b.site_tensor(i),
149                current.site_tensor(i),
150                left_envs[i].as_ref().unwrap(),
151            )?);
152        }
153
154        // Backward sweep (right to left)
155        for i in (1..n).rev() {
156            // Update two-site core at positions i-1 and i
157            let _updated = update_two_site_core(
158                mpo_a,
159                mpo_b,
160                &mut current,
161                i - 1,
162                &left_envs,
163                &right_envs,
164                options,
165            )?;
166
167            // Update right environment at i-1
168            if i > 0 {
169                right_envs[i - 1] = Some(build_right_environment(
170                    mpo_a.site_tensor(i),
171                    mpo_b.site_tensor(i),
172                    current.site_tensor(i),
173                    right_envs[i].as_ref().unwrap(),
174                )?);
175            }
176        }
177
178        // Check convergence
179        // TODO: Implement proper convergence check using norm difference
180    }
181
182    Ok(current.into_mpo())
183}
184
/// Dense three-leg environment tensor used by the variational sweeps.
///
/// Elements are stored row-major over the legs `(link_result, link_a, link_b)`:
/// element `(r, a, b)` lives at flat offset `(r * dim_a + a) * dim_b + b`.
#[derive(Debug, Clone)]
struct Environment<T> {
    /// Flat element storage, `dim_result * dim_a * dim_b` entries.
    data: Vec<T>,
    /// Extent of the leg attached to the result MPO.
    dim_result: usize,
    /// Extent of the leg attached to MPO `A`.
    dim_a: usize,
    /// Extent of the leg attached to MPO `B`.
    dim_b: usize,
}
194
195impl<T: SVDScalar> Environment<T>
196where
197    <T as num_complex::ComplexFloat>::Real: Into<f64>,
198{
199    fn identity(dim_result: usize, dim_a: usize, dim_b: usize) -> Self {
200        let mut data = vec![T::zero(); dim_result * dim_a * dim_b];
201        // Set diagonal elements to 1
202        let min_dim = dim_result.min(dim_a).min(dim_b);
203        for i in 0..min_dim {
204            data[(i * dim_a + i) * dim_b + i] = T::one();
205        }
206        // Actually for identity, we just want a single 1.0
207        if dim_result == 1 && dim_a == 1 && dim_b == 1 {
208            data[0] = T::one();
209        }
210        Self {
211            data,
212            dim_result,
213            dim_a,
214            dim_b,
215        }
216    }
217
218    fn get(&self, r: usize, a: usize, b: usize) -> T {
219        self.data[(r * self.dim_a + a) * self.dim_b + b]
220    }
221
222    fn set(&mut self, r: usize, a: usize, b: usize, val: T) {
223        self.data[(r * self.dim_a + a) * self.dim_b + b] = val;
224    }
225}
226
227/// Build left environment by extending from previous environment
228fn build_left_environment<T: SVDScalar>(
229    tensor_a: &Tensor4<T>,
230    tensor_b: &Tensor4<T>,
231    tensor_result: &Tensor4<T>,
232    prev_env: &Environment<T>,
233) -> Result<Environment<T>>
234where
235    <T as num_complex::ComplexFloat>::Real: Into<f64>,
236{
237    let new_dim_result = tensor_result.right_dim();
238    let new_dim_a = tensor_a.right_dim();
239    let new_dim_b = tensor_b.right_dim();
240
241    let mut new_env = Environment {
242        data: vec![T::zero(); new_dim_result * new_dim_a * new_dim_b],
243        dim_result: new_dim_result,
244        dim_a: new_dim_a,
245        dim_b: new_dim_b,
246    };
247
248    // Contract: L'[rr', ra', rb'] = sum_{rr, ra, rb, s1, s2, k}
249    //   L[rr, ra, rb] * C[rr, s1, s2, rr'] * A[ra, s1, k, ra'] * B[rb, k, s2, rb']
250    for rr in 0..prev_env.dim_result {
251        for ra in 0..prev_env.dim_a {
252            for rb in 0..prev_env.dim_b {
253                let l_val = prev_env.get(rr, ra, rb);
254                for s1 in 0..tensor_result.site_dim_1() {
255                    for s2 in 0..tensor_result.site_dim_2() {
256                        for k in 0..tensor_a.site_dim_2() {
257                            for rr_new in 0..new_dim_result {
258                                for ra_new in 0..new_dim_a {
259                                    for rb_new in 0..new_dim_b {
260                                        let c_val = *tensor_result.get4(rr, s1, s2, rr_new);
261                                        let a_val = *tensor_a.get4(ra, s1, k, ra_new);
262                                        let b_val = *tensor_b.get4(rb, k, s2, rb_new);
263                                        let old = new_env.get(rr_new, ra_new, rb_new);
264                                        new_env.set(
265                                            rr_new,
266                                            ra_new,
267                                            rb_new,
268                                            old + l_val * c_val * a_val * b_val,
269                                        );
270                                    }
271                                }
272                            }
273                        }
274                    }
275                }
276            }
277        }
278    }
279
280    Ok(new_env)
281}
282
283/// Build right environment by extending from next environment
284fn build_right_environment<T: SVDScalar>(
285    tensor_a: &Tensor4<T>,
286    tensor_b: &Tensor4<T>,
287    tensor_result: &Tensor4<T>,
288    next_env: &Environment<T>,
289) -> Result<Environment<T>>
290where
291    <T as num_complex::ComplexFloat>::Real: Into<f64>,
292{
293    let new_dim_result = tensor_result.left_dim();
294    let new_dim_a = tensor_a.left_dim();
295    let new_dim_b = tensor_b.left_dim();
296
297    let mut new_env = Environment {
298        data: vec![T::zero(); new_dim_result * new_dim_a * new_dim_b],
299        dim_result: new_dim_result,
300        dim_a: new_dim_a,
301        dim_b: new_dim_b,
302    };
303
304    // Contract: R'[lr, la, lb] = sum_{rr, ra, rb, s1, s2, k}
305    //   R[rr, ra, rb] * C[lr, s1, s2, rr] * A[la, s1, k, ra] * B[lb, k, s2, rb]
306    for rr in 0..next_env.dim_result {
307        for ra in 0..next_env.dim_a {
308            for rb in 0..next_env.dim_b {
309                let r_val = next_env.get(rr, ra, rb);
310                for s1 in 0..tensor_result.site_dim_1() {
311                    for s2 in 0..tensor_result.site_dim_2() {
312                        for k in 0..tensor_a.site_dim_2() {
313                            for lr in 0..new_dim_result {
314                                for la in 0..new_dim_a {
315                                    for lb in 0..new_dim_b {
316                                        let c_val = *tensor_result.get4(lr, s1, s2, rr);
317                                        let a_val = *tensor_a.get4(la, s1, k, ra);
318                                        let b_val = *tensor_b.get4(lb, k, s2, rb);
319                                        let old = new_env.get(lr, la, lb);
320                                        new_env.set(
321                                            lr,
322                                            la,
323                                            lb,
324                                            old + r_val * c_val * a_val * b_val,
325                                        );
326                                    }
327                                }
328                            }
329                        }
330                    }
331                }
332            }
333        }
334    }
335
336    Ok(new_env)
337}
338
339/// Update the two-site core tensor at positions site and site+1
340fn update_two_site_core<T: SVDScalar>(
341    _mpo_a: &MPO<T>,
342    _mpo_b: &MPO<T>,
343    _result: &mut SiteMPO<T>,
344    _site: usize,
345    _left_envs: &[Option<Environment<T>>],
346    _right_envs: &[Option<Environment<T>>],
347    _options: &FitOptions,
348) -> Result<bool>
349where
350    <T as num_complex::ComplexFloat>::Real: Into<f64>,
351{
352    // For now, just return Ok - full implementation would update the core
353    // This is a placeholder for the variational update step
354    Ok(true)
355}
356
357#[cfg(test)]
358mod tests;