tensor4all_simplett/mpo/
environment.rs1use crate::einsum_helper::{einsum_tensors, typed_tensor_reshape, EinsumScalar};
7
8use super::error::{MPOError, Result};
9use super::factorize::SVDScalar;
10use super::mpo::MPO;
11use super::types::{Tensor4, Tensor4Ops};
12use super::{matrix2_zeros, Matrix2};
13pub fn contract_site_tensors<T: SVDScalar + EinsumScalar>(
24 a: &Tensor4<T>,
25 b: &Tensor4<T>,
26) -> Result<Tensor4<T>>
27where
28 <T as num_complex::ComplexFloat>::Real: Into<f64>,
29{
30 if a.site_dim_2() != b.site_dim_1() {
32 return Err(MPOError::SharedDimensionMismatch {
33 site: 0,
34 dim_a: a.site_dim_2(),
35 dim_b: b.site_dim_1(),
36 });
37 }
38
39 let left_a = a.left_dim();
40 let s1_a = a.site_dim_1();
41 let right_a = a.right_dim();
42
43 let left_b = b.left_dim();
44 let s2_b = b.site_dim_2();
45 let right_b = b.right_dim();
46
47 let new_left = left_a * left_b;
49 let new_s1 = s1_a;
50 let new_s2 = s2_b;
51 let new_right = right_a * right_b;
52
53 let contracted = einsum_tensors("askr,bktq->bastqr", &[a.as_inner(), b.as_inner()]);
61 let reshaped = typed_tensor_reshape(&contracted, &[new_left, new_s1, new_s2, new_right]);
62
63 Ok(Tensor4::from_tenferro(reshaped))
64}
65
66pub fn left_environment<T: SVDScalar + EinsumScalar>(
74 mpo_a: &MPO<T>,
75 mpo_b: &MPO<T>,
76 site: usize,
77 cache: &mut Vec<Option<Matrix2<T>>>,
78) -> Result<Matrix2<T>>
79where
80 <T as num_complex::ComplexFloat>::Real: Into<f64>,
81{
82 if mpo_a.len() != mpo_b.len() {
83 return Err(MPOError::LengthMismatch {
84 expected: mpo_a.len(),
85 got: mpo_b.len(),
86 });
87 }
88
89 if site == 0 {
91 let mut env: Matrix2<T> = matrix2_zeros(1, 1);
92 env[[0, 0]] = T::one();
93 return Ok(env);
94 }
95
96 if site <= cache.len() && cache[site - 1].is_some() {
98 return Ok(cache[site - 1].as_ref().unwrap().clone());
99 }
100
101 let prev_env = left_environment(mpo_a, mpo_b, site - 1, cache)?;
103 let a = mpo_a.site_tensor(site - 1);
104 let b = mpo_b.site_tensor(site - 1);
105
106 if a.site_dim_1() != b.site_dim_1() || a.site_dim_2() != b.site_dim_2() {
116 return Err(MPOError::SharedDimensionMismatch {
117 site: site - 1,
118 dim_a: a.site_dim_1(),
119 dim_b: b.site_dim_1(),
120 });
121 }
122
123 let new_env = Matrix2::from_tenferro(einsum_tensors(
124 "ab,asdr,bsdt->rt",
125 &[prev_env.as_inner(), a.as_inner(), b.as_inner()],
126 ));
127
128 while cache.len() < site {
130 cache.push(None);
131 }
132 cache[site - 1] = Some(new_env.clone());
133
134 Ok(new_env)
135}
136
137pub fn right_environment<T: SVDScalar + EinsumScalar>(
145 mpo_a: &MPO<T>,
146 mpo_b: &MPO<T>,
147 site: usize,
148 cache: &mut Vec<Option<Matrix2<T>>>,
149) -> Result<Matrix2<T>>
150where
151 <T as num_complex::ComplexFloat>::Real: Into<f64>,
152{
153 if mpo_a.len() != mpo_b.len() {
154 return Err(MPOError::LengthMismatch {
155 expected: mpo_a.len(),
156 got: mpo_b.len(),
157 });
158 }
159
160 let n = mpo_a.len();
161
162 if site == n - 1 {
164 let mut env: Matrix2<T> = matrix2_zeros(1, 1);
165 env[[0, 0]] = T::one();
166 return Ok(env);
167 }
168
169 let cache_idx = n - site - 2;
171 if cache_idx < cache.len() && cache[cache_idx].is_some() {
172 return Ok(cache[cache_idx].as_ref().unwrap().clone());
173 }
174
175 let prev_env = right_environment(mpo_a, mpo_b, site + 1, cache)?;
177 let a = mpo_a.site_tensor(site + 1);
178 let b = mpo_b.site_tensor(site + 1);
179
180 if a.site_dim_1() != b.site_dim_1() || a.site_dim_2() != b.site_dim_2() {
190 return Err(MPOError::SharedDimensionMismatch {
191 site: site + 1,
192 dim_a: a.site_dim_1(),
193 dim_b: b.site_dim_1(),
194 });
195 }
196
197 let new_env = Matrix2::from_tenferro(einsum_tensors(
198 "rt,asdr,bsdt->ab",
199 &[prev_env.as_inner(), a.as_inner(), b.as_inner()],
200 ));
201
202 while cache.len() <= cache_idx {
204 cache.push(None);
205 }
206 cache[cache_idx] = Some(new_env.clone());
207
208 Ok(new_env)
209}
210
// Declares the `tests` submodule. It is compiled only when the `test` cfg is
// active, and its body lives in a separate file (the declaration ends in `;`).
211#[cfg(test)]
212mod tests;