// tensor4all_simplett/mpo/contract_naive.rs
1//! Naive MPO contraction algorithm
2//!
3//! Contracts two MPOs by directly multiplying site tensors,
4//! optionally followed by compression.
5
6use super::contraction::ContractionOptions;
7use super::environment::contract_site_tensors;
8use super::error::{MPOError, Result};
9use super::factorize::{factorize, FactorizeOptions, SVDScalar};
10use super::mpo::MPO;
11use super::types::{tensor4_zeros, Tensor4, Tensor4Ops};
12use super::{matrix2_zeros, Matrix2};
13use crate::einsum_helper::EinsumScalar;
14
15/// Perform naive contraction of two MPOs
16///
17/// This computes C = A * B where the contraction is over the shared
18/// physical index (s2 of A contracts with s1 of B).
19///
20/// The naive algorithm:
21/// 1. Contract each pair of site tensors
22/// 2. Optionally compress the result
23///
24/// # Arguments
25/// * `mpo_a` - First MPO
26/// * `mpo_b` - Second MPO
27/// * `options` - Optional compression options
28///
29/// # Returns
30/// The contracted MPO C with dimensions:
31/// - s1: from A
32/// - s2: from B
33/// - bond dimensions: product of input bond dimensions (before compression)
34pub fn contract_naive<T: SVDScalar + EinsumScalar>(
35    mpo_a: &MPO<T>,
36    mpo_b: &MPO<T>,
37    options: Option<ContractionOptions>,
38) -> Result<MPO<T>>
39where
40    <T as num_complex::ComplexFloat>::Real: Into<f64>,
41{
42    if mpo_a.len() != mpo_b.len() {
43        return Err(MPOError::LengthMismatch {
44            expected: mpo_a.len(),
45            got: mpo_b.len(),
46        });
47    }
48
49    if mpo_a.is_empty() {
50        return Ok(MPO::from_tensors_unchecked(Vec::new()));
51    }
52
53    let n = mpo_a.len();
54
55    // Validate shared dimensions
56    for i in 0..n {
57        let (_, s2_a) = mpo_a.site_dim(i);
58        let (s1_b, _) = mpo_b.site_dim(i);
59        if s2_a != s1_b {
60            return Err(MPOError::SharedDimensionMismatch {
61                site: i,
62                dim_a: s2_a,
63                dim_b: s1_b,
64            });
65        }
66    }
67
68    // Contract each pair of site tensors
69    let mut tensors: Vec<Tensor4<T>> = Vec::with_capacity(n);
70
71    for i in 0..n {
72        let a = mpo_a.site_tensor(i);
73        let b = mpo_b.site_tensor(i);
74
75        // Contract over shared index: a.s2 = b.s1
76        // Result has shape:
77        // (left_a * left_b, s1_a, s2_b, right_a * right_b)
78        let contracted = contract_site_tensors(a, b)?;
79        tensors.push(contracted);
80    }
81
82    let mut result = MPO::from_tensors_unchecked(tensors);
83
84    // Apply compression if options are provided
85    if let Some(opts) = options {
86        compress_mpo(&mut result, &opts)?;
87    }
88
89    Ok(result)
90}
91
92/// Compress an MPO using the specified options
93fn compress_mpo<T: SVDScalar>(mpo: &mut MPO<T>, options: &ContractionOptions) -> Result<()>
94where
95    <T as num_complex::ComplexFloat>::Real: Into<f64>,
96{
97    if mpo.len() <= 1 {
98        return Ok(());
99    }
100
101    let factorize_opts = FactorizeOptions {
102        method: options.factorize_method,
103        tolerance: options.tolerance,
104        max_rank: options.max_bond_dim,
105        left_orthogonal: true,
106        ..Default::default()
107    };
108
109    // Sweep left to right, factorizing each bond
110    for i in 0..(mpo.len() - 1) {
111        let tensor = mpo.site_tensor(i);
112        let left_dim = tensor.left_dim();
113        let s1 = tensor.site_dim_1();
114        let s2 = tensor.site_dim_2();
115        let right_dim = tensor.right_dim();
116
117        // Reshape to matrix: (left * s1 * s2, right)
118        let rows = left_dim * s1 * s2;
119        let cols = right_dim;
120        let mut mat: Matrix2<T> = matrix2_zeros(rows, cols);
121        for l in 0..left_dim {
122            for i1 in 0..s1 {
123                for i2 in 0..s2 {
124                    let row = (l * s1 + i1) * s2 + i2;
125                    for col in 0..cols {
126                        mat[[row, col]] = *tensor.get4(l, i1, i2, col);
127                    }
128                }
129            }
130        }
131
132        // Factorize
133        let fact_result = factorize(&mat, &factorize_opts)?;
134        let new_rank = fact_result.rank;
135
136        // Update current tensor with left factor
137        let mut new_tensor = tensor4_zeros(left_dim, s1, s2, new_rank);
138        let left_rows = fact_result.left.dim(0);
139        let left_cols = fact_result.left.dim(1);
140        for l in 0..left_dim {
141            for i1 in 0..s1 {
142                for i2 in 0..s2 {
143                    let row = (l * s1 + i1) * s2 + i2;
144                    for r in 0..new_rank {
145                        if row < left_rows && r < left_cols {
146                            new_tensor.set4(l, i1, i2, r, fact_result.left[[row, r]]);
147                        }
148                    }
149                }
150            }
151        }
152
153        // Get next tensor's dimensions
154        let next_tensor = mpo.site_tensor(i + 1);
155        let next_left = next_tensor.left_dim();
156        let next_s1 = next_tensor.site_dim_1();
157        let next_s2 = next_tensor.site_dim_2();
158        let next_right = next_tensor.right_dim();
159
160        // Multiply right factor into next tensor
161        // R[new_rank, right_dim] @ next[right_dim, s1, s2, next_right]
162        // = new_next[new_rank, s1, s2, next_right]
163        let right_rows = fact_result.right.dim(0);
164        let right_cols = fact_result.right.dim(1);
165        let mut new_next = tensor4_zeros(new_rank, next_s1, next_s2, next_right);
166        for l in 0..new_rank {
167            for i1 in 0..next_s1 {
168                for i2 in 0..next_s2 {
169                    for r in 0..next_right {
170                        let mut sum = T::zero();
171                        for k in 0..next_left.min(right_cols) {
172                            if l < right_rows {
173                                sum = sum
174                                    + fact_result.right[[l, k]] * *next_tensor.get4(k, i1, i2, r);
175                            }
176                        }
177                        new_next.set4(l, i1, i2, r, sum);
178                    }
179                }
180            }
181        }
182
183        // Update tensors in place
184        *mpo.site_tensor_mut(i) = new_tensor;
185        *mpo.site_tensor_mut(i + 1) = new_next;
186    }
187
188    Ok(())
189}
190
191#[cfg(test)]
192mod tests;