tensor4all_treetn/linsolve/square/projected_state.rs
1//! ProjectedState: 2-chain environment for RHS computation (square case).
2//!
3//! Computes the local RHS term consistent with the square linsolve setup:
4//! - ProjectedOperator builds `<ref|H|x>`
5//! - ProjectedState builds `<ref|b>`
6//!
7//! This returns a local tensor with open indices aligned to the current solution's
8//! local tensor (up to permutation), so GMRES can operate in the same vector space.
9//!
10//! This is the V_in = V_out specialized version.
11
12use std::hash::Hash;
13
14use anyhow::Result;
15
16use tensor4all_core::{AllowedPairs, IndexLike, TensorLike};
17
18use crate::linsolve::common::{EnvironmentCache, NetworkTopology};
19use crate::treetn::TreeTN;
20
21/// ProjectedState: Manages 2-chain environments for RHS computation.
22///
23/// This computes `<b|x_local>` for each local region during the sweep.
24///
25/// For Tree Tensor Networks, the environment is computed by contracting
26/// all tensors outside the "open region" into environment tensors.
27/// The open region consists of nodes being updated in the current sweep step.
28///
29/// # Structure
30///
31/// For each edge (from, to) pointing towards the open region, we cache:
32/// ```text
33/// env[(from, to)] = contraction of:
34/// - bra tensor at `from` (conjugated RHS)
35/// - ket tensor at `from` (current solution)
36/// - all child environments (edges pointing away from `to`)
37/// ```
38///
39/// This forms a "2-chain" overlap: `<b|x>` contracted over
40/// all nodes except the open region.
41pub struct ProjectedState<T, V>
42where
43 T: TensorLike,
44 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
45{
46 /// The RHS state |b⟩
47 pub rhs: TreeTN<T, V>,
48 /// Environment cache
49 pub envs: EnvironmentCache<T, V>,
50}
51
52impl<T, V> ProjectedState<T, V>
53where
54 T: TensorLike,
55 <T::Index as IndexLike>::Id:
56 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
57 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
58{
59 /// Create a new ProjectedState.
60 pub fn new(rhs: TreeTN<T, V>) -> Self {
61 Self {
62 rhs,
63 envs: EnvironmentCache::new(),
64 }
65 }
66
67 /// Compute the local constant term (local RHS) for the given region.
68 ///
69 /// This returns the local RHS tensors contracted with environments.
70 ///
71 /// For the square case, the `reference_state` is used as the bra (conjugated),
72 /// and `rhs` is used as the ket, i.e. environments are constructed for `<ref|b>`.
73 ///
74 /// # Arguments
75 /// * `region` - The nodes in the local update region
76 /// * `reference_state` - The current solution state (used as reference for environments)
77 /// * `topology` - The network topology
78 pub fn local_constant_term<NT: NetworkTopology<V>>(
79 &mut self,
80 region: &[V],
81 reference_state: &TreeTN<T, V>,
82 topology: &NT,
83 ) -> Result<T> {
84 // Ensure environments are computed
85 self.ensure_environments(region, reference_state, topology)?;
86
87 // Collect all tensors to contract: local RHS tensors + environments
88 let mut all_tensors: Vec<T> = Vec::new();
89
90 // Collect local RHS tensors (ket side; do NOT conjugate)
91 for node in region {
92 let node_idx = self
93 .rhs
94 .node_index(node)
95 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in RHS", node))?;
96 let tensor = self
97 .rhs
98 .tensor(node_idx)
99 .ok_or_else(|| anyhow::anyhow!("Tensor not found in RHS"))?
100 .clone();
101 all_tensors.push(tensor);
102 }
103
104 // Collect environments from neighbors outside the region
105 for node in region {
106 for neighbor in topology.neighbors(node) {
107 if !region.contains(&neighbor) {
108 if let Some(env) = self.envs.get(&neighbor, node) {
109 all_tensors.push(env.clone());
110 }
111 }
112 }
113 }
114
115 // Use T::contract for optimal contraction ordering
116 let tensor_refs: Vec<&T> = all_tensors.iter().collect();
117 T::contract(&tensor_refs, AllowedPairs::All)
118 }
119
120 /// Ensure environments are computed for neighbors of the region.
121 fn ensure_environments<NT: NetworkTopology<V>>(
122 &mut self,
123 region: &[V],
124 reference_state: &TreeTN<T, V>,
125 topology: &NT,
126 ) -> Result<()> {
127 for node in region {
128 for neighbor in topology.neighbors(node) {
129 if !region.contains(&neighbor) && !self.envs.contains(&neighbor, node) {
130 let env =
131 self.compute_environment(&neighbor, node, reference_state, topology)?;
132 self.envs.insert(neighbor.clone(), node.clone(), env);
133 }
134 }
135 }
136 Ok(())
137 }
138
139 /// Recursively compute environment for edge (from, to).
140 ///
141 /// Computes `<ref|b>` partial contraction at node `from`.
142 fn compute_environment<NT: NetworkTopology<V>>(
143 &mut self,
144 from: &V,
145 to: &V,
146 reference_state: &TreeTN<T, V>,
147 topology: &NT,
148 ) -> Result<T> {
149 // First, ensure child environments are computed
150 let child_neighbors: Vec<V> = topology.neighbors(from).filter(|n| n != to).collect();
151
152 for child in &child_neighbors {
153 if !self.envs.contains(child, from) {
154 let child_env = self.compute_environment(child, from, reference_state, topology)?;
155 self.envs.insert(child.clone(), from.clone(), child_env);
156 }
157 }
158
159 // Collect child environments
160 let child_envs: Vec<T> = child_neighbors
161 .iter()
162 .filter_map(|child| self.envs.get(child, from).cloned())
163 .collect();
164
165 // Contract bra (reference_state) with ket (RHS) at this node
166 let node_idx_ref = reference_state
167 .node_index(from)
168 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in reference_state", from))?;
169 let node_idx_b = self
170 .rhs
171 .node_index(from)
172 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in RHS", from))?;
173
174 let tensor_ref = reference_state
175 .tensor(node_idx_ref)
176 .ok_or_else(|| anyhow::anyhow!("Tensor not found in reference_state"))?;
177 let tensor_b = self
178 .rhs
179 .tensor(node_idx_b)
180 .ok_or_else(|| anyhow::anyhow!("Tensor not found in RHS"))?;
181
182 let bra_conj = tensor_ref.conj();
183
184 // Contract bra and ket - T::contract auto-detects contractable pairs
185 let bra_ket = T::contract(&[&bra_conj, tensor_b], AllowedPairs::All)?;
186
187 // Contract bra*ket with child environments using T::contract
188 if child_envs.is_empty() {
189 Ok(bra_ket)
190 } else {
191 let mut all_tensors: Vec<&T> = vec![&bra_ket];
192 all_tensors.extend(child_envs.iter());
193 T::contract(&all_tensors, AllowedPairs::All)
194 }
195 }
196
197 /// Invalidate caches affected by updates to the given region.
198 pub fn invalidate<NT: NetworkTopology<V>>(&mut self, region: &[V], topology: &NT) {
199 self.envs.invalidate(region, topology);
200 }
201}