Skip to main content

tensor4all_simplett/mpo/
environment.rs

1//! Environment computation for MPO contractions
2//!
3//! This module provides functions for computing left and right environments
4//! used in variational algorithms and efficient MPO evaluation.
5
6use crate::einsum_helper::{einsum_tensors, typed_tensor_reshape, EinsumScalar};
7
8use super::error::{MPOError, Result};
9use super::factorize::SVDScalar;
10use super::mpo::MPO;
11use super::types::{Tensor4, Tensor4Ops};
12use super::{matrix2_zeros, Matrix2};
13/// Contract two 4D site tensors over their shared physical index
14///
15/// Given two 4D tensors:
16/// - A: (left_a, s1_a, s2_a, right_a) where s2_a is the shared index
17/// - B: (left_b, s1_b, s2_b, right_b) where s1_b is the shared index
18///
19/// Contracts over s2_a = s1_b to produce:
20/// - C: (left_a * left_b, s1_a, s2_b, right_a * right_b)
21///
22/// This is the Rust equivalent of `_contractsitetensors` from Julia.
23pub fn contract_site_tensors<T: SVDScalar + EinsumScalar>(
24    a: &Tensor4<T>,
25    b: &Tensor4<T>,
26) -> Result<Tensor4<T>>
27where
28    <T as num_complex::ComplexFloat>::Real: Into<f64>,
29{
30    // Check that shared dimensions match
31    if a.site_dim_2() != b.site_dim_1() {
32        return Err(MPOError::SharedDimensionMismatch {
33            site: 0,
34            dim_a: a.site_dim_2(),
35            dim_b: b.site_dim_1(),
36        });
37    }
38
39    let left_a = a.left_dim();
40    let s1_a = a.site_dim_1();
41    let right_a = a.right_dim();
42
43    let left_b = b.left_dim();
44    let s2_b = b.site_dim_2();
45    let right_b = b.right_dim();
46
47    // Result dimensions
48    let new_left = left_a * left_b;
49    let new_s1 = s1_a;
50    let new_s2 = s2_b;
51    let new_right = right_a * right_b;
52
53    // Arrange the open bond indices as (left_b, left_a, ..., right_b, right_a)
54    // so the column-major reshape preserves the existing la * left_b + lb and
55    // ra * right_b + rb indexing.
56    //
57    // TODO: Remove this materialization once tenferro supports reshaping
58    // layout-compatible strided views directly.
59    // Tracking issue: https://github.com/tensor4all/tenferro-rs/issues/575
60    let contracted = einsum_tensors("askr,bktq->bastqr", &[a.as_inner(), b.as_inner()]);
61    let reshaped = typed_tensor_reshape(&contracted, &[new_left, new_s1, new_s2, new_right]);
62
63    Ok(Tensor4::from_tenferro(reshaped))
64}
65
66/// Compute the left environment at site i for MPO contraction
67///
68/// The left environment L\[i\] represents the contraction of all sites 0..i
69/// for the product of two MPOs A and B.
70///
71/// L\[i\] has shape (left_a_i, left_b_i) representing the accumulated
72/// contraction from the left.
73pub fn left_environment<T: SVDScalar + EinsumScalar>(
74    mpo_a: &MPO<T>,
75    mpo_b: &MPO<T>,
76    site: usize,
77    cache: &mut Vec<Option<Matrix2<T>>>,
78) -> Result<Matrix2<T>>
79where
80    <T as num_complex::ComplexFloat>::Real: Into<f64>,
81{
82    if mpo_a.len() != mpo_b.len() {
83        return Err(MPOError::LengthMismatch {
84            expected: mpo_a.len(),
85            got: mpo_b.len(),
86        });
87    }
88
89    // Base case: left of site 0 is just [[1]]
90    if site == 0 {
91        let mut env: Matrix2<T> = matrix2_zeros(1, 1);
92        env[[0, 0]] = T::one();
93        return Ok(env);
94    }
95
96    // Check cache
97    if site <= cache.len() && cache[site - 1].is_some() {
98        return Ok(cache[site - 1].as_ref().unwrap().clone());
99    }
100
101    // Recursively compute from the left
102    let prev_env = left_environment(mpo_a, mpo_b, site - 1, cache)?;
103    let a = mpo_a.site_tensor(site - 1);
104    let b = mpo_b.site_tensor(site - 1);
105
106    // Contract: L[i-1] with A[i-1] and B[i-1]
107    // L[i-1]: (left_a, left_b)
108    // A[i-1]: (left_a, s1_a, s2_a, right_a)
109    // B[i-1]: (left_b, s1_b, s2_b, right_b)
110    //
111    // Sum over left_a, left_b, s1_a (=s1_b), s2_a (=s2_b) to get:
112    // L[i]: (right_a, right_b)
113
114    // Check shared dimensions
115    if a.site_dim_1() != b.site_dim_1() || a.site_dim_2() != b.site_dim_2() {
116        return Err(MPOError::SharedDimensionMismatch {
117            site: site - 1,
118            dim_a: a.site_dim_1(),
119            dim_b: b.site_dim_1(),
120        });
121    }
122
123    let new_env = Matrix2::from_tenferro(einsum_tensors(
124        "ab,asdr,bsdt->rt",
125        &[prev_env.as_inner(), a.as_inner(), b.as_inner()],
126    ));
127
128    // Update cache
129    while cache.len() < site {
130        cache.push(None);
131    }
132    cache[site - 1] = Some(new_env.clone());
133
134    Ok(new_env)
135}
136
137/// Compute the right environment at site i for MPO contraction
138///
139/// The right environment R\[i\] represents the contraction of all sites i+1..L
140/// for the product of two MPOs A and B.
141///
142/// R\[i\] has shape (right_a_i, right_b_i) representing the accumulated
143/// contraction from the right.
144pub fn right_environment<T: SVDScalar + EinsumScalar>(
145    mpo_a: &MPO<T>,
146    mpo_b: &MPO<T>,
147    site: usize,
148    cache: &mut Vec<Option<Matrix2<T>>>,
149) -> Result<Matrix2<T>>
150where
151    <T as num_complex::ComplexFloat>::Real: Into<f64>,
152{
153    if mpo_a.len() != mpo_b.len() {
154        return Err(MPOError::LengthMismatch {
155            expected: mpo_a.len(),
156            got: mpo_b.len(),
157        });
158    }
159
160    let n = mpo_a.len();
161
162    // Base case: right of last site is just [[1]]
163    if site == n - 1 {
164        let mut env: Matrix2<T> = matrix2_zeros(1, 1);
165        env[[0, 0]] = T::one();
166        return Ok(env);
167    }
168
169    // Check cache
170    let cache_idx = n - site - 2;
171    if cache_idx < cache.len() && cache[cache_idx].is_some() {
172        return Ok(cache[cache_idx].as_ref().unwrap().clone());
173    }
174
175    // Recursively compute from the right
176    let prev_env = right_environment(mpo_a, mpo_b, site + 1, cache)?;
177    let a = mpo_a.site_tensor(site + 1);
178    let b = mpo_b.site_tensor(site + 1);
179
180    // Contract: R[i+1] with A[i+1] and B[i+1]
181    // R[i+1]: (right_a, right_b)
182    // A[i+1]: (left_a, s1_a, s2_a, right_a)
183    // B[i+1]: (left_b, s1_b, s2_b, right_b)
184    //
185    // Sum over right_a, right_b, s1_a (=s1_b), s2_a (=s2_b) to get:
186    // R[i]: (left_a, left_b)
187
188    // Check shared dimensions
189    if a.site_dim_1() != b.site_dim_1() || a.site_dim_2() != b.site_dim_2() {
190        return Err(MPOError::SharedDimensionMismatch {
191            site: site + 1,
192            dim_a: a.site_dim_1(),
193            dim_b: b.site_dim_1(),
194        });
195    }
196
197    let new_env = Matrix2::from_tenferro(einsum_tensors(
198        "rt,asdr,bsdt->ab",
199        &[prev_env.as_inner(), a.as_inner(), b.as_inner()],
200    ));
201
202    // Update cache
203    while cache.len() <= cache_idx {
204        cache.push(None);
205    }
206    cache[cache_idx] = Some(new_env.clone());
207
208    Ok(new_env)
209}
210
211#[cfg(test)]
212mod tests;