Skip to main content

tensor4all_simplett/mpo/
contract_zipup.rs

//! Zip-up MPO contraction algorithm
//!
//! Contracts two MPOs with on-the-fly compression at each step.
//! This is more memory-efficient than naive contraction followed by compression.

use super::contraction::ContractionOptions;
use super::error::{MPOError, Result};
use super::factorize::{factorize, FactorizeOptions, SVDScalar};
use super::mpo::MPO;
use super::types::{tensor4_zeros, Tensor4, Tensor4Ops};
use super::{matrix2_zeros, Matrix2};
12
13/// Perform zip-up contraction of two MPOs
14///
15/// This computes C = A * B where the contraction is over the shared
16/// physical index (s2 of A contracts with s1 of B), with on-the-fly
17/// compression at each step.
18///
19/// The zip-up algorithm:
20/// 1. Start from the left with a remainder tensor R = \[\[1\]\]
21/// 2. At each site:
22///    a. Contract R with A\[i\] and B\[i\]
23///    b. Reshape to matrix
24///    c. Factorize into left and right factors
25///    d. Store left factor as result tensor
26///    e. Use right factor as new remainder R
27/// 3. At the last site, just store the contracted tensor
28///
29/// # Arguments
30/// * `mpo_a` - First MPO
31/// * `mpo_b` - Second MPO
32/// * `options` - Contraction options (tolerance, max_bond_dim, method)
33///
34/// # Returns
35/// The contracted and compressed MPO C
36pub fn contract_zipup<T: SVDScalar>(
37    mpo_a: &MPO<T>,
38    mpo_b: &MPO<T>,
39    options: &ContractionOptions,
40) -> Result<MPO<T>>
41where
42    <T as num_complex::ComplexFloat>::Real: Into<f64>,
43{
44    if mpo_a.len() != mpo_b.len() {
45        return Err(MPOError::LengthMismatch {
46            expected: mpo_a.len(),
47            got: mpo_b.len(),
48        });
49    }
50
51    if mpo_a.is_empty() {
52        return Ok(MPO::from_tensors_unchecked(Vec::new()));
53    }
54
55    let n = mpo_a.len();
56
57    // Validate shared dimensions
58    for i in 0..n {
59        let (_, s2_a) = mpo_a.site_dim(i);
60        let (s1_b, _) = mpo_b.site_dim(i);
61        if s2_a != s1_b {
62            return Err(MPOError::SharedDimensionMismatch {
63                site: i,
64                dim_a: s2_a,
65                dim_b: s1_b,
66            });
67        }
68    }
69
70    // Remainder tensor: R[new_link, link_a, link_b]
71    // Start with 1x1x1 identity
72    let mut r_left_dim = 1usize;
73    let mut r_a_dim = 1usize;
74    let mut r_b_dim = 1usize;
75    let mut r_data: Vec<T> = vec![T::one()];
76
77    let mut result_tensors: Vec<Tensor4<T>> = Vec::with_capacity(n);
78
79    let factorize_opts = FactorizeOptions {
80        method: options.factorize_method,
81        tolerance: options.tolerance,
82        max_rank: options.max_bond_dim,
83        left_orthogonal: true,
84        ..Default::default()
85    };
86
87    for i in 0..n {
88        let a = mpo_a.site_tensor(i);
89        let b = mpo_b.site_tensor(i);
90
91        let s1_a = a.site_dim_1();
92        let s2_a = a.site_dim_2(); // shared with s1_b
93        let s2_b = b.site_dim_2();
94        let right_a = a.right_dim();
95        let right_b = b.right_dim();
96
97        // Contract R with A and B:
98        // C[new_link, s1_a, s2_b, right_a, right_b]
99        // = sum_{link_a, link_b, k} R[new_link, link_a, link_b]
100        //   * A[link_a, s1_a, k, right_a] * B[link_b, k, s2_b, right_b]
101
102        let c_new_link = r_left_dim;
103        let c_s1 = s1_a;
104        let c_s2 = s2_b;
105        let c_right_a = right_a;
106        let c_right_b = right_b;
107
108        // C as a 5D array, but we'll store it flat
109        // Shape: (c_new_link * c_s1 * c_s2) x (c_right_a * c_right_b)
110        let rows = c_new_link * c_s1 * c_s2;
111        let cols = c_right_a * c_right_b;
112        let mut c_mat: Matrix2<T> = matrix2_zeros(rows, cols);
113
114        for ln in 0..r_left_dim {
115            for la in 0..r_a_dim {
116                for lb in 0..r_b_dim {
117                    let r_idx = (ln * r_a_dim + la) * r_b_dim + lb;
118                    let r_val = r_data[r_idx];
119                    for s1 in 0..c_s1 {
120                        for s2 in 0..c_s2 {
121                            for k in 0..s2_a {
122                                for ra in 0..c_right_a {
123                                    for rb in 0..c_right_b {
124                                        let row = (ln * c_s1 + s1) * c_s2 + s2;
125                                        let col = ra * c_right_b + rb;
126                                        c_mat[[row, col]] = c_mat[[row, col]]
127                                            + r_val
128                                                * *a.get4(la, s1, k, ra)
129                                                * *b.get4(lb, k, s2, rb);
130                                    }
131                                }
132                            }
133                        }
134                    }
135                }
136            }
137        }
138
139        if i == n - 1 {
140            // Last site: just reshape and store
141            // Result should have right_dim = 1
142            let mut result_tensor = tensor4_zeros(c_new_link, c_s1, c_s2, 1);
143            for ln in 0..c_new_link {
144                for s1 in 0..c_s1 {
145                    for s2 in 0..c_s2 {
146                        let row = (ln * c_s1 + s1) * c_s2 + s2;
147                        // Sum over all right indices (should be 1x1)
148                        let val = c_mat[[row, 0]];
149                        result_tensor.set4(ln, s1, s2, 0, val);
150                    }
151                }
152            }
153            result_tensors.push(result_tensor);
154        } else {
155            // Factorize C into left and right parts
156            let fact_result = factorize(&c_mat, &factorize_opts)?;
157            let new_bond_dim = fact_result.rank.max(1);
158            let left_rows = fact_result.left.dim(0);
159            let left_cols = fact_result.left.dim(1);
160            let right_rows = fact_result.right.dim(0);
161            let right_cols = fact_result.right.dim(1);
162
163            // Store left factor as site tensor: (c_new_link, c_s1, c_s2, new_bond_dim)
164            let mut result_tensor = tensor4_zeros(c_new_link, c_s1, c_s2, new_bond_dim);
165            for ln in 0..c_new_link {
166                for s1 in 0..c_s1 {
167                    for s2 in 0..c_s2 {
168                        for r in 0..new_bond_dim {
169                            let row = (ln * c_s1 + s1) * c_s2 + s2;
170                            if row < left_rows && r < left_cols {
171                                result_tensor.set4(ln, s1, s2, r, fact_result.left[[row, r]]);
172                            }
173                        }
174                    }
175                }
176            }
177            result_tensors.push(result_tensor);
178
179            // Update R tensor for next iteration: R[new_bond_dim, right_a, right_b]
180            r_left_dim = new_bond_dim;
181            r_a_dim = c_right_a;
182            r_b_dim = c_right_b;
183            r_data = vec![T::zero(); r_left_dim * r_a_dim * r_b_dim];
184
185            for l in 0..new_bond_dim {
186                for ra in 0..c_right_a {
187                    for rb in 0..c_right_b {
188                        let col = ra * c_right_b + rb;
189                        let r_idx = (l * r_a_dim + ra) * r_b_dim + rb;
190                        if l < right_rows && col < right_cols {
191                            r_data[r_idx] = fact_result.right[[l, col]];
192                        }
193                    }
194                }
195            }
196        }
197    }
198
199    Ok(MPO::from_tensors_unchecked(result_tensors))
200}
201
#[cfg(test)]
mod tests;