// tensor4all_treetn/linsolve/common/projected_operator.rs

use std::collections::{HashMap, HashSet};
use std::hash::Hash;

use anyhow::Result;

use tensor4all_core::{AllowedPairs, IndexLike, TensorLike};

use super::environment::{EnvironmentCache, NetworkTopology};
use crate::operator::IndexMapping;
use crate::treetn::TreeTN;
21
22pub struct ProjectedOperator<T, V>
44where
45 T: TensorLike,
46 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
47{
48 pub operator: TreeTN<T, V>,
50 pub envs: EnvironmentCache<T, V>,
52 pub input_mapping: Option<HashMap<V, IndexMapping<T::Index>>>,
55 pub output_mapping: Option<HashMap<V, IndexMapping<T::Index>>>,
57}
58
59impl<T, V> ProjectedOperator<T, V>
60where
61 T: TensorLike,
62 <T::Index as IndexLike>::Id:
63 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
64 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
65{
66 pub fn new(operator: TreeTN<T, V>) -> Self {
68 Self {
69 operator,
70 envs: EnvironmentCache::new(),
71 input_mapping: None,
72 output_mapping: None,
73 }
74 }
75
76 pub fn with_index_mappings(
82 operator: TreeTN<T, V>,
83 input_mapping: HashMap<V, IndexMapping<T::Index>>,
84 output_mapping: HashMap<V, IndexMapping<T::Index>>,
85 ) -> Self {
86 Self {
87 operator,
88 envs: EnvironmentCache::new(),
89 input_mapping: Some(input_mapping),
90 output_mapping: Some(output_mapping),
91 }
92 }
93
94 pub fn apply<NT: NetworkTopology<V>>(
115 &mut self,
116 v: &T,
117 region: &[V],
118 ket_state: &TreeTN<T, V>,
119 bra_state: &TreeTN<T, V>,
120 topology: &NT,
121 ) -> Result<T> {
122 self.ensure_environments(region, ket_state, bra_state, topology)?;
124
125 let mut all_tensors: Vec<T> = Vec::new();
126 let mut temp_out_to_true: Vec<(T::Index, T::Index)> = Vec::new();
127
128 if let (Some(ref input_mapping), Some(ref output_mapping)) =
129 (&self.input_mapping, &self.output_mapping)
130 {
131 let mut per_node: Vec<(T::Index, T::Index, T::Index)> = Vec::new();
135 for node in region {
136 let im = input_mapping
137 .get(node)
138 .ok_or_else(|| anyhow::anyhow!("Missing input_mapping for node {:?}", node))?;
139 let om = output_mapping
140 .get(node)
141 .ok_or_else(|| anyhow::anyhow!("Missing output_mapping for node {:?}", node))?;
142 let temp_in = im.internal_index.sim();
143 let temp_out = om.internal_index.sim();
144 per_node.push((temp_in, temp_out, om.true_index.clone()));
145 }
146
147 let mut transformed_v = v.clone();
148 for (node, (temp_in, _temp_out, _)) in region.iter().zip(per_node.iter()) {
149 let im = input_mapping.get(node).unwrap();
150 transformed_v = transformed_v.replaceind(&im.true_index, temp_in)?;
151 }
152 all_tensors.push(transformed_v);
153
154 for (node, (temp_in, temp_out, true_idx)) in region.iter().zip(per_node.iter()) {
155 let im = input_mapping.get(node).unwrap();
156 let om = output_mapping.get(node).unwrap();
157 let node_idx = self
158 .operator
159 .node_index(node)
160 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", node))?;
161 let mut t = self
162 .operator
163 .tensor(node_idx)
164 .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?
165 .clone();
166 t = t.replaceind(&im.internal_index, temp_in)?;
167 t = t.replaceind(&om.internal_index, temp_out)?;
168 all_tensors.push(t);
169 temp_out_to_true.push((temp_out.clone(), true_idx.clone()));
170 }
171 } else {
172 all_tensors.push(v.clone());
174 for node in region {
175 let node_idx = self
176 .operator
177 .node_index(node)
178 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", node))?;
179 let tensor = self
180 .operator
181 .tensor(node_idx)
182 .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?
183 .clone();
184 all_tensors.push(tensor);
185 }
186 }
187
188 for node in region {
190 for neighbor in topology.neighbors(node) {
191 if region.contains(&neighbor) {
192 continue;
193 }
194 if let Some(env) = self.envs.get(&neighbor, node) {
195 all_tensors.push(env.clone());
196 }
197 }
198 }
199
200 let tensor_refs: Vec<&T> = all_tensors.iter().collect();
201 let mut contracted = T::contract(&tensor_refs, AllowedPairs::All)?;
202
203 for (temp_out, true_idx) in &temp_out_to_true {
205 contracted = contracted.replaceind(temp_out, true_idx)?;
206 }
207
208 for node in region {
210 for neighbor in topology.neighbors(node) {
211 if region.contains(&neighbor) {
212 continue;
213 }
214 let ket_edge = match ket_state.edge_between(node, &neighbor) {
215 Some(e) => e,
216 None => continue,
217 };
218 let bra_edge = match bra_state.edge_between(node, &neighbor) {
219 Some(e) => e,
220 None => continue,
221 };
222 let ket_bond = match ket_state.bond_index(ket_edge) {
223 Some(b) => b.clone(),
224 None => continue,
225 };
226 let bra_bond = match bra_state.bond_index(bra_edge) {
227 Some(b) => b.clone(),
228 None => continue,
229 };
230 if contracted
231 .external_indices()
232 .iter()
233 .any(|i| i.id() == bra_bond.id())
234 {
235 contracted = contracted.replaceind(&bra_bond, &ket_bond)?;
236 }
237 }
238 }
239
240 let v_inds = v.external_indices();
242 let res_inds = contracted.external_indices();
243 let v_ids: std::collections::HashSet<_> = v_inds.iter().map(|i| i.id()).collect();
244 let res_ids: std::collections::HashSet<_> = res_inds.iter().map(|i| i.id()).collect();
245 if v_ids == res_ids && v_inds.len() == res_inds.len() {
246 contracted = contracted.permuteinds(&v_inds)?;
247 }
248
249 Ok(contracted)
250 }
251
252 fn ensure_environments<NT: NetworkTopology<V>>(
254 &mut self,
255 region: &[V],
256 ket_state: &TreeTN<T, V>,
257 bra_state: &TreeTN<T, V>,
258 topology: &NT,
259 ) -> Result<()> {
260 for node in region {
261 for neighbor in topology.neighbors(node) {
262 if !region.contains(&neighbor) && !self.envs.contains(&neighbor, node) {
263 let env =
264 self.compute_environment(&neighbor, node, ket_state, bra_state, topology)?;
265 self.envs.insert(neighbor.clone(), node.clone(), env);
266 }
267 }
268 }
269 Ok(())
270 }
271
272 fn compute_environment<NT: NetworkTopology<V>>(
281 &mut self,
282 from: &V,
283 to: &V,
284 ket_state: &TreeTN<T, V>,
285 bra_state: &TreeTN<T, V>,
286 topology: &NT,
287 ) -> Result<T> {
288 let child_neighbors: Vec<V> = topology.neighbors(from).filter(|n| n != to).collect();
290
291 for child in &child_neighbors {
292 if !self.envs.contains(child, from) {
293 let child_env =
294 self.compute_environment(child, from, ket_state, bra_state, topology)?;
295 self.envs.insert(child.clone(), from.clone(), child_env);
296 }
297 }
298
299 let child_envs: Vec<T> = child_neighbors
301 .iter()
302 .filter_map(|child| self.envs.get(child, from).cloned())
303 .collect();
304
305 let node_idx_bra = bra_state
307 .node_index(from)
308 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in bra_state", from))?;
309 let node_idx_op = self
310 .operator
311 .node_index(from)
312 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", from))?;
313 let node_idx_ket = ket_state
314 .node_index(from)
315 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in ket_state", from))?;
316
317 let tensor_bra = bra_state
318 .tensor(node_idx_bra)
319 .ok_or_else(|| anyhow::anyhow!("Tensor not found in bra_state"))?;
320 let tensor_op = self
321 .operator
322 .tensor(node_idx_op)
323 .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?;
324 let tensor_ket = ket_state
325 .tensor(node_idx_ket)
326 .ok_or_else(|| anyhow::anyhow!("Tensor not found in ket_state"))?;
327
328 let bra_conj = tensor_bra.conj();
337
338 let transformed_ket = if let Some(ref input_mapping) = self.input_mapping {
340 if let Some(mapping) = input_mapping.get(from) {
341 tensor_ket.replaceind(&mapping.true_index, &mapping.internal_index)?
342 } else {
343 tensor_ket.clone()
344 }
345 } else {
346 tensor_ket.clone()
347 };
348
349 let transformed_bra_conj = if let Some(ref output_mapping) = self.output_mapping {
351 if let Some(mapping) = output_mapping.get(from) {
352 bra_conj.replaceind(&mapping.true_index, &mapping.internal_index)?
353 } else {
354 bra_conj.clone()
355 }
356 } else {
357 bra_conj.clone()
358 };
359
360 let mut tensor_refs: Vec<&T> = vec![&transformed_ket, tensor_op, &transformed_bra_conj];
363 tensor_refs.extend(child_envs.iter());
364 T::contract(&tensor_refs, AllowedPairs::All)
365 }
366
367 pub fn local_dimension(&self, region: &[V]) -> usize {
369 let mut dim = 1;
370 for node in region {
371 if let Some(site_space) = self.operator.site_space(node) {
372 for idx in site_space {
373 dim *= idx.dim();
374 }
375 }
376 }
377 dim
378 }
379
380 pub fn invalidate<NT: NetworkTopology<V>>(&mut self, region: &[V], topology: &NT) {
382 self.envs.invalidate(region, topology);
383 }
384}