// tensor4all_simplett/mpo/contract_naive.rs
use super::contraction::ContractionOptions;
use super::environment::contract_site_tensors;
use super::error::{MPOError, Result};
use super::factorize::{factorize, FactorizeOptions, SVDScalar};
use super::mpo::MPO;
use super::types::{tensor4_zeros, Tensor4, Tensor4Ops};
use super::{matrix2_zeros, Matrix2};
use crate::einsum_helper::EinsumScalar;
14
15pub fn contract_naive<T: SVDScalar + EinsumScalar>(
35 mpo_a: &MPO<T>,
36 mpo_b: &MPO<T>,
37 options: Option<ContractionOptions>,
38) -> Result<MPO<T>>
39where
40 <T as num_complex::ComplexFloat>::Real: Into<f64>,
41{
42 if mpo_a.len() != mpo_b.len() {
43 return Err(MPOError::LengthMismatch {
44 expected: mpo_a.len(),
45 got: mpo_b.len(),
46 });
47 }
48
49 if mpo_a.is_empty() {
50 return Ok(MPO::from_tensors_unchecked(Vec::new()));
51 }
52
53 let n = mpo_a.len();
54
55 for i in 0..n {
57 let (_, s2_a) = mpo_a.site_dim(i);
58 let (s1_b, _) = mpo_b.site_dim(i);
59 if s2_a != s1_b {
60 return Err(MPOError::SharedDimensionMismatch {
61 site: i,
62 dim_a: s2_a,
63 dim_b: s1_b,
64 });
65 }
66 }
67
68 let mut tensors: Vec<Tensor4<T>> = Vec::with_capacity(n);
70
71 for i in 0..n {
72 let a = mpo_a.site_tensor(i);
73 let b = mpo_b.site_tensor(i);
74
75 let contracted = contract_site_tensors(a, b)?;
79 tensors.push(contracted);
80 }
81
82 let mut result = MPO::from_tensors_unchecked(tensors);
83
84 if let Some(opts) = options {
86 compress_mpo(&mut result, &opts)?;
87 }
88
89 Ok(result)
90}
91
92fn compress_mpo<T: SVDScalar>(mpo: &mut MPO<T>, options: &ContractionOptions) -> Result<()>
94where
95 <T as num_complex::ComplexFloat>::Real: Into<f64>,
96{
97 if mpo.len() <= 1 {
98 return Ok(());
99 }
100
101 let factorize_opts = FactorizeOptions {
102 method: options.factorize_method,
103 tolerance: options.tolerance,
104 max_rank: options.max_bond_dim,
105 left_orthogonal: true,
106 ..Default::default()
107 };
108
109 for i in 0..(mpo.len() - 1) {
111 let tensor = mpo.site_tensor(i);
112 let left_dim = tensor.left_dim();
113 let s1 = tensor.site_dim_1();
114 let s2 = tensor.site_dim_2();
115 let right_dim = tensor.right_dim();
116
117 let rows = left_dim * s1 * s2;
119 let cols = right_dim;
120 let mut mat: Matrix2<T> = matrix2_zeros(rows, cols);
121 for l in 0..left_dim {
122 for i1 in 0..s1 {
123 for i2 in 0..s2 {
124 let row = (l * s1 + i1) * s2 + i2;
125 for col in 0..cols {
126 mat[[row, col]] = *tensor.get4(l, i1, i2, col);
127 }
128 }
129 }
130 }
131
132 let fact_result = factorize(&mat, &factorize_opts)?;
134 let new_rank = fact_result.rank;
135
136 let mut new_tensor = tensor4_zeros(left_dim, s1, s2, new_rank);
138 let left_rows = fact_result.left.dim(0);
139 let left_cols = fact_result.left.dim(1);
140 for l in 0..left_dim {
141 for i1 in 0..s1 {
142 for i2 in 0..s2 {
143 let row = (l * s1 + i1) * s2 + i2;
144 for r in 0..new_rank {
145 if row < left_rows && r < left_cols {
146 new_tensor.set4(l, i1, i2, r, fact_result.left[[row, r]]);
147 }
148 }
149 }
150 }
151 }
152
153 let next_tensor = mpo.site_tensor(i + 1);
155 let next_left = next_tensor.left_dim();
156 let next_s1 = next_tensor.site_dim_1();
157 let next_s2 = next_tensor.site_dim_2();
158 let next_right = next_tensor.right_dim();
159
160 let right_rows = fact_result.right.dim(0);
164 let right_cols = fact_result.right.dim(1);
165 let mut new_next = tensor4_zeros(new_rank, next_s1, next_s2, next_right);
166 for l in 0..new_rank {
167 for i1 in 0..next_s1 {
168 for i2 in 0..next_s2 {
169 for r in 0..next_right {
170 let mut sum = T::zero();
171 for k in 0..next_left.min(right_cols) {
172 if l < right_rows {
173 sum = sum
174 + fact_result.right[[l, k]] * *next_tensor.get4(k, i1, i2, r);
175 }
176 }
177 new_next.set4(l, i1, i2, r, sum);
178 }
179 }
180 }
181 }
182
183 *mpo.site_tensor_mut(i) = new_tensor;
185 *mpo.site_tensor_mut(i + 1) = new_next;
186 }
187
188 Ok(())
189}
190
// Unit tests for this module are loaded from a separate `tests` module file
// and compiled only for test builds.
#[cfg(test)]
mod tests;