Skip to main content

tensor4all_treetn/linsolve/common/
projected_operator.rs

1//! ProjectedOperator: 3-chain environment for operator application.
2//!
3//! Computes `<psi|H|v>` efficiently for Tree Tensor Networks.
4//!
5//! # Index Mapping
6//!
7//! For MPOs where input/output site indices have different IDs from the state's
8//! site indices (required because a tensor cannot have two indices with the same ID),
9//! use `with_index_mappings` to define the correspondence.
10
11use std::collections::HashMap;
12use std::hash::Hash;
13
14use anyhow::Result;
15
16use tensor4all_core::{AllowedPairs, IndexLike, TensorLike};
17
18use super::environment::{EnvironmentCache, NetworkTopology};
19use crate::operator::IndexMapping;
20use crate::treetn::TreeTN;
21
22/// ProjectedOperator: Manages 3-chain environments for operator application.
23///
24/// This computes `<psi|H|v>` for each local region during the sweep.
25///
26/// For Tree Tensor Networks, the environment is computed by contracting
27/// all tensors outside the "open region" into environment tensors.
28/// The open region consists of nodes being updated in the current sweep step.
29///
30/// # Structure
31///
32/// For each edge (from, to) pointing towards the open region, we cache:
33/// ```text
34/// env[(from, to)] = contraction of:
35///   - bra tensor at `from` (conjugated)
36///   - operator tensor at `from`
37///   - ket tensor at `from`
38///   - all child environments (edges pointing away from `to`)
39/// ```
40///
41/// This forms a "3-chain" sandwich: `<bra| H |ket>` contracted over
42/// all nodes except the open region.
43pub struct ProjectedOperator<T, V>
44where
45    T: TensorLike,
46    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
47{
48    /// The operator H
49    pub operator: TreeTN<T, V>,
50    /// Environment cache
51    pub envs: EnvironmentCache<T, V>,
52    /// Input index mapping (true site index -> MPO's internal input index)
53    /// Used when MPO has internal indices different from state's site indices.
54    pub input_mapping: Option<HashMap<V, IndexMapping<T::Index>>>,
55    /// Output index mapping (true site index -> MPO's internal output index)
56    pub output_mapping: Option<HashMap<V, IndexMapping<T::Index>>>,
57}
58
59impl<T, V> ProjectedOperator<T, V>
60where
61    T: TensorLike,
62    <T::Index as IndexLike>::Id:
63        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
64    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
65{
66    /// Create a new ProjectedOperator.
67    pub fn new(operator: TreeTN<T, V>) -> Self {
68        Self {
69            operator,
70            envs: EnvironmentCache::new(),
71            input_mapping: None,
72            output_mapping: None,
73        }
74    }
75
76    /// Create a new ProjectedOperator with index mappings from a LinearOperator.
77    ///
78    /// The mappings define how state's site indices relate to MPO's internal indices.
79    /// This is required when the MPO uses internal indices (s_in_tmp, s_out_tmp)
80    /// that differ from the state's site indices.
81    pub fn with_index_mappings(
82        operator: TreeTN<T, V>,
83        input_mapping: HashMap<V, IndexMapping<T::Index>>,
84        output_mapping: HashMap<V, IndexMapping<T::Index>>,
85    ) -> Self {
86        Self {
87            operator,
88            envs: EnvironmentCache::new(),
89            input_mapping: Some(input_mapping),
90            output_mapping: Some(output_mapping),
91        }
92    }
93
94    /// Apply the operator to a local tensor: compute `H|v⟩` at the current position.
95    ///
96    /// If index mappings are set (via `with_index_mappings`), this method:
97    /// 1. Transforms input `v`'s site indices using **unique** temp indices (avoids duplicate IDs)
98    /// 2. Contracts with MPO tensors and environment tensors
99    /// 3. Transforms result's temp output indices back to true site indices
100    /// 4. Replaces bra-side boundary bonds with ket-side so output lives in same space as `v`
101    /// 5. Permutes result to `v`'s index order so output structure matches input
102    ///
103    /// # Arguments
104    /// * `v` - The local tensor to apply the operator to
105    /// * `region` - The nodes in the open region
106    /// * `ket_state` - The current state |ket⟩ (used for ket in environment computation)
107    /// * `bra_state` - The reference state ⟨bra| (used for bra in environment computation)
108    ///   For V_in = V_out, this is the same as ket_state.
109    ///   For V_in ≠ V_out, this should be a state in V_out.
110    /// * `topology` - Network topology for traversal
111    ///
112    /// # Returns
113    /// The result of applying H to v: `H|v⟩`, with same index set and order as `v`.
114    pub fn apply<NT: NetworkTopology<V>>(
115        &mut self,
116        v: &T,
117        region: &[V],
118        ket_state: &TreeTN<T, V>,
119        bra_state: &TreeTN<T, V>,
120        topology: &NT,
121    ) -> Result<T> {
122        // Ensure environments are computed
123        self.ensure_environments(region, ket_state, bra_state, topology)?;
124
125        let mut all_tensors: Vec<T> = Vec::new();
126        let mut temp_out_to_true: Vec<(T::Index, T::Index)> = Vec::new();
127
128        if let (Some(ref input_mapping), Some(ref output_mapping)) =
129            (&self.input_mapping, &self.output_mapping)
130        {
131            // MPO-with-mappings path: use unique temp indices to avoid duplicate IDs.
132            // Replace true_index -> temp_in on v (never use internal_index on v).
133            // Use same temp_in/temp_out on op tensors so they contract with v.
134            let mut per_node: Vec<(T::Index, T::Index, T::Index)> = Vec::new();
135            for node in region {
136                let im = input_mapping
137                    .get(node)
138                    .ok_or_else(|| anyhow::anyhow!("Missing input_mapping for node {:?}", node))?;
139                let om = output_mapping
140                    .get(node)
141                    .ok_or_else(|| anyhow::anyhow!("Missing output_mapping for node {:?}", node))?;
142                let temp_in = im.internal_index.sim();
143                let temp_out = om.internal_index.sim();
144                per_node.push((temp_in, temp_out, om.true_index.clone()));
145            }
146
147            let mut transformed_v = v.clone();
148            for (node, (temp_in, _temp_out, _)) in region.iter().zip(per_node.iter()) {
149                let im = input_mapping.get(node).unwrap();
150                transformed_v = transformed_v.replaceind(&im.true_index, temp_in)?;
151            }
152            all_tensors.push(transformed_v);
153
154            for (node, (temp_in, temp_out, true_idx)) in region.iter().zip(per_node.iter()) {
155                let im = input_mapping.get(node).unwrap();
156                let om = output_mapping.get(node).unwrap();
157                let node_idx = self
158                    .operator
159                    .node_index(node)
160                    .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", node))?;
161                let mut t = self
162                    .operator
163                    .tensor(node_idx)
164                    .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?
165                    .clone();
166                t = t.replaceind(&im.internal_index, temp_in)?;
167                t = t.replaceind(&om.internal_index, temp_out)?;
168                all_tensors.push(t);
169                temp_out_to_true.push((temp_out.clone(), true_idx.clone()));
170            }
171        } else {
172            // No mappings: plain path
173            all_tensors.push(v.clone());
174            for node in region {
175                let node_idx = self
176                    .operator
177                    .node_index(node)
178                    .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", node))?;
179                let tensor = self
180                    .operator
181                    .tensor(node_idx)
182                    .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?
183                    .clone();
184                all_tensors.push(tensor);
185            }
186        }
187
188        // Collect environments from neighbors outside the region
189        for node in region {
190            for neighbor in topology.neighbors(node) {
191                if region.contains(&neighbor) {
192                    continue;
193                }
194                if let Some(env) = self.envs.get(&neighbor, node) {
195                    all_tensors.push(env.clone());
196                }
197            }
198        }
199
200        let tensor_refs: Vec<&T> = all_tensors.iter().collect();
201        let mut contracted = T::contract(&tensor_refs, AllowedPairs::All)?;
202
203        // Replace temp_out -> true_index
204        for (temp_out, true_idx) in &temp_out_to_true {
205            contracted = contracted.replaceind(temp_out, true_idx)?;
206        }
207
208        // Bra -> ket boundary bonds in result so output lives in same space as v (ket bonds).
209        for node in region {
210            for neighbor in topology.neighbors(node) {
211                if region.contains(&neighbor) {
212                    continue;
213                }
214                let ket_edge = match ket_state.edge_between(node, &neighbor) {
215                    Some(e) => e,
216                    None => continue,
217                };
218                let bra_edge = match bra_state.edge_between(node, &neighbor) {
219                    Some(e) => e,
220                    None => continue,
221                };
222                let ket_bond = match ket_state.bond_index(ket_edge) {
223                    Some(b) => b.clone(),
224                    None => continue,
225                };
226                let bra_bond = match bra_state.bond_index(bra_edge) {
227                    Some(b) => b.clone(),
228                    None => continue,
229                };
230                if contracted
231                    .external_indices()
232                    .iter()
233                    .any(|i| i.id() == bra_bond.id())
234                {
235                    contracted = contracted.replaceind(&bra_bond, &ket_bond)?;
236                }
237            }
238        }
239
240        // Align result to v's index order
241        let v_inds = v.external_indices();
242        let res_inds = contracted.external_indices();
243        let v_ids: std::collections::HashSet<_> = v_inds.iter().map(|i| i.id()).collect();
244        let res_ids: std::collections::HashSet<_> = res_inds.iter().map(|i| i.id()).collect();
245        if v_ids == res_ids && v_inds.len() == res_inds.len() {
246            contracted = contracted.permuteinds(&v_inds)?;
247        }
248
249        Ok(contracted)
250    }
251
252    /// Ensure environments are computed for neighbors of the region.
253    fn ensure_environments<NT: NetworkTopology<V>>(
254        &mut self,
255        region: &[V],
256        ket_state: &TreeTN<T, V>,
257        bra_state: &TreeTN<T, V>,
258        topology: &NT,
259    ) -> Result<()> {
260        for node in region {
261            for neighbor in topology.neighbors(node) {
262                if !region.contains(&neighbor) && !self.envs.contains(&neighbor, node) {
263                    let env =
264                        self.compute_environment(&neighbor, node, ket_state, bra_state, topology)?;
265                    self.envs.insert(neighbor.clone(), node.clone(), env);
266                }
267            }
268        }
269        Ok(())
270    }
271
272    /// Recursively compute environment for edge (from, to).
273    ///
274    /// # Arguments
275    /// * `from` - Source node of the edge
276    /// * `to` - Destination node of the edge
277    /// * `ket_state` - State for ket tensors (input space, V_in)
278    /// * `bra_state` - State for bra tensors (output space, V_out)
279    /// * `topology` - Network topology
280    fn compute_environment<NT: NetworkTopology<V>>(
281        &mut self,
282        from: &V,
283        to: &V,
284        ket_state: &TreeTN<T, V>,
285        bra_state: &TreeTN<T, V>,
286        topology: &NT,
287    ) -> Result<T> {
288        // First, ensure child environments are computed
289        let child_neighbors: Vec<V> = topology.neighbors(from).filter(|n| n != to).collect();
290
291        for child in &child_neighbors {
292            if !self.envs.contains(child, from) {
293                let child_env =
294                    self.compute_environment(child, from, ket_state, bra_state, topology)?;
295                self.envs.insert(child.clone(), from.clone(), child_env);
296            }
297        }
298
299        // Collect child environments
300        let child_envs: Vec<T> = child_neighbors
301            .iter()
302            .filter_map(|child| self.envs.get(child, from).cloned())
303            .collect();
304
305        // Get tensors from bra (V_out), operator, and ket (V_in) at this node
306        let node_idx_bra = bra_state
307            .node_index(from)
308            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in bra_state", from))?;
309        let node_idx_op = self
310            .operator
311            .node_index(from)
312            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", from))?;
313        let node_idx_ket = ket_state
314            .node_index(from)
315            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in ket_state", from))?;
316
317        let tensor_bra = bra_state
318            .tensor(node_idx_bra)
319            .ok_or_else(|| anyhow::anyhow!("Tensor not found in bra_state"))?;
320        let tensor_op = self
321            .operator
322            .tensor(node_idx_op)
323            .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?;
324        let tensor_ket = ket_state
325            .tensor(node_idx_ket)
326            .ok_or_else(|| anyhow::anyhow!("Tensor not found in ket_state"))?;
327
328        // Environment contraction for 3-chain: <bra| H |ket>
329        //
330        // When using index mappings (from LinearOperator):
331        // - ket's site index (s) needs to be replaced with MPO's input index (s_in_tmp) for contraction
332        // - bra's site index (s) needs to be replaced with MPO's output index (s_out_tmp) for contraction
333        //
334        // Without mappings: indices are assumed to match directly (same ID).
335
336        let bra_conj = tensor_bra.conj();
337
338        // Transform ket tensor for contraction with operator
339        let transformed_ket = if let Some(ref input_mapping) = self.input_mapping {
340            if let Some(mapping) = input_mapping.get(from) {
341                tensor_ket.replaceind(&mapping.true_index, &mapping.internal_index)?
342            } else {
343                tensor_ket.clone()
344            }
345        } else {
346            tensor_ket.clone()
347        };
348
349        // Transform bra_conj tensor for contraction with operator
350        let transformed_bra_conj = if let Some(ref output_mapping) = self.output_mapping {
351            if let Some(mapping) = output_mapping.get(from) {
352                bra_conj.replaceind(&mapping.true_index, &mapping.internal_index)?
353            } else {
354                bra_conj.clone()
355            }
356        } else {
357            bra_conj.clone()
358        };
359
360        // Contract ket, op, bra, and child environments together
361        // Let contract() find the optimal contraction order
362        let mut tensor_refs: Vec<&T> = vec![&transformed_ket, tensor_op, &transformed_bra_conj];
363        tensor_refs.extend(child_envs.iter());
364        T::contract(&tensor_refs, AllowedPairs::All)
365    }
366
367    /// Compute the local dimension (size of the local Hilbert space).
368    pub fn local_dimension(&self, region: &[V]) -> usize {
369        let mut dim = 1;
370        for node in region {
371            if let Some(site_space) = self.operator.site_space(node) {
372                for idx in site_space {
373                    dim *= idx.dim();
374                }
375            }
376        }
377        dim
378    }
379
380    /// Invalidate caches affected by updates to the given region.
381    pub fn invalidate<NT: NetworkTopology<V>>(&mut self, region: &[V], topology: &NT) {
382        self.envs.invalidate(region, topology);
383    }
384}