Skip to main content

tensor4all_treetn/linsolve/square/
projected_state.rs

//! ProjectedState: 2-chain environment for RHS computation (square case).
//!
//! Computes the local RHS term consistent with the square linsolve setup:
//! - ProjectedOperator builds `<ref|H|x>`
//! - ProjectedState builds `<ref|b>`
//!
//! This returns a local tensor with open indices aligned to the current solution's
//! local tensor (up to permutation), so GMRES can operate in the same vector space.
//!
//! This is the V_in = V_out specialized version.
11
12use std::hash::Hash;
13
14use anyhow::Result;
15
16use tensor4all_core::{AllowedPairs, IndexLike, TensorLike};
17
18use crate::linsolve::common::{EnvironmentCache, NetworkTopology};
19use crate::treetn::TreeTN;
20
21/// ProjectedState: Manages 2-chain environments for RHS computation.
22///
23/// This computes `<b|x_local>` for each local region during the sweep.
24///
25/// For Tree Tensor Networks, the environment is computed by contracting
26/// all tensors outside the "open region" into environment tensors.
27/// The open region consists of nodes being updated in the current sweep step.
28///
29/// # Structure
30///
31/// For each edge (from, to) pointing towards the open region, we cache:
32/// ```text
33/// env[(from, to)] = contraction of:
34///   - bra tensor at `from` (conjugated RHS)
35///   - ket tensor at `from` (current solution)
36///   - all child environments (edges pointing away from `to`)
37/// ```
38///
39/// This forms a "2-chain" overlap: `<b|x>` contracted over
40/// all nodes except the open region.
41pub struct ProjectedState<T, V>
42where
43    T: TensorLike,
44    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
45{
46    /// The RHS state |b⟩
47    pub rhs: TreeTN<T, V>,
48    /// Environment cache
49    pub envs: EnvironmentCache<T, V>,
50}
51
52impl<T, V> ProjectedState<T, V>
53where
54    T: TensorLike,
55    <T::Index as IndexLike>::Id:
56        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
57    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
58{
59    /// Create a new ProjectedState.
60    pub fn new(rhs: TreeTN<T, V>) -> Self {
61        Self {
62            rhs,
63            envs: EnvironmentCache::new(),
64        }
65    }
66
67    /// Compute the local constant term (local RHS) for the given region.
68    ///
69    /// This returns the local RHS tensors contracted with environments.
70    ///
71    /// For the square case, the `reference_state` is used as the bra (conjugated),
72    /// and `rhs` is used as the ket, i.e. environments are constructed for `<ref|b>`.
73    ///
74    /// # Arguments
75    /// * `region` - The nodes in the local update region
76    /// * `reference_state` - The current solution state (used as reference for environments)
77    /// * `topology` - The network topology
78    pub fn local_constant_term<NT: NetworkTopology<V>>(
79        &mut self,
80        region: &[V],
81        reference_state: &TreeTN<T, V>,
82        topology: &NT,
83    ) -> Result<T> {
84        // Ensure environments are computed
85        self.ensure_environments(region, reference_state, topology)?;
86
87        // Collect all tensors to contract: local RHS tensors + environments
88        let mut all_tensors: Vec<T> = Vec::new();
89
90        // Collect local RHS tensors (ket side; do NOT conjugate)
91        for node in region {
92            let node_idx = self
93                .rhs
94                .node_index(node)
95                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in RHS", node))?;
96            let tensor = self
97                .rhs
98                .tensor(node_idx)
99                .ok_or_else(|| anyhow::anyhow!("Tensor not found in RHS"))?
100                .clone();
101            all_tensors.push(tensor);
102        }
103
104        // Collect environments from neighbors outside the region
105        for node in region {
106            for neighbor in topology.neighbors(node) {
107                if !region.contains(&neighbor) {
108                    if let Some(env) = self.envs.get(&neighbor, node) {
109                        all_tensors.push(env.clone());
110                    }
111                }
112            }
113        }
114
115        // Use T::contract for optimal contraction ordering
116        let tensor_refs: Vec<&T> = all_tensors.iter().collect();
117        T::contract(&tensor_refs, AllowedPairs::All)
118    }
119
120    /// Ensure environments are computed for neighbors of the region.
121    fn ensure_environments<NT: NetworkTopology<V>>(
122        &mut self,
123        region: &[V],
124        reference_state: &TreeTN<T, V>,
125        topology: &NT,
126    ) -> Result<()> {
127        for node in region {
128            for neighbor in topology.neighbors(node) {
129                if !region.contains(&neighbor) && !self.envs.contains(&neighbor, node) {
130                    let env =
131                        self.compute_environment(&neighbor, node, reference_state, topology)?;
132                    self.envs.insert(neighbor.clone(), node.clone(), env);
133                }
134            }
135        }
136        Ok(())
137    }
138
139    /// Recursively compute environment for edge (from, to).
140    ///
141    /// Computes `<ref|b>` partial contraction at node `from`.
142    fn compute_environment<NT: NetworkTopology<V>>(
143        &mut self,
144        from: &V,
145        to: &V,
146        reference_state: &TreeTN<T, V>,
147        topology: &NT,
148    ) -> Result<T> {
149        // First, ensure child environments are computed
150        let child_neighbors: Vec<V> = topology.neighbors(from).filter(|n| n != to).collect();
151
152        for child in &child_neighbors {
153            if !self.envs.contains(child, from) {
154                let child_env = self.compute_environment(child, from, reference_state, topology)?;
155                self.envs.insert(child.clone(), from.clone(), child_env);
156            }
157        }
158
159        // Collect child environments
160        let child_envs: Vec<T> = child_neighbors
161            .iter()
162            .filter_map(|child| self.envs.get(child, from).cloned())
163            .collect();
164
165        // Contract bra (reference_state) with ket (RHS) at this node
166        let node_idx_ref = reference_state
167            .node_index(from)
168            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in reference_state", from))?;
169        let node_idx_b = self
170            .rhs
171            .node_index(from)
172            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in RHS", from))?;
173
174        let tensor_ref = reference_state
175            .tensor(node_idx_ref)
176            .ok_or_else(|| anyhow::anyhow!("Tensor not found in reference_state"))?;
177        let tensor_b = self
178            .rhs
179            .tensor(node_idx_b)
180            .ok_or_else(|| anyhow::anyhow!("Tensor not found in RHS"))?;
181
182        let bra_conj = tensor_ref.conj();
183
184        // Contract bra and ket - T::contract auto-detects contractable pairs
185        let bra_ket = T::contract(&[&bra_conj, tensor_b], AllowedPairs::All)?;
186
187        // Contract bra*ket with child environments using T::contract
188        if child_envs.is_empty() {
189            Ok(bra_ket)
190        } else {
191            let mut all_tensors: Vec<&T> = vec![&bra_ket];
192            all_tensors.extend(child_envs.iter());
193            T::contract(&all_tensors, AllowedPairs::All)
194        }
195    }
196
197    /// Invalidate caches affected by updates to the given region.
198    pub fn invalidate<NT: NetworkTopology<V>>(&mut self, region: &[V], topology: &NT) {
199        self.envs.invalidate(region, topology);
200    }
201}