Skip to main content

tensor4all_treetn/linsolve/square/
updater.rs

1//! SquareLinsolveUpdater: Local update implementation for square linsolve.
2//!
3//! Uses GMRES (via tensor4all_core::krylov) to solve the local linear problem at each sweep step.
4//! This is the V_in = V_out specialized version.
5
6use std::collections::HashMap;
7use std::hash::Hash;
8use std::sync::Arc;
9use std::sync::RwLock;
10
11use anyhow::{Context, Result};
12
13use tensor4all_core::any_scalar::AnyScalar;
14use tensor4all_core::krylov::{gmres, GmresOptions};
15use tensor4all_core::{AllowedPairs, FactorizeOptions, IndexLike, TensorLike};
16
17use super::local_linop::LocalLinOp;
18use super::projected_state::ProjectedState;
19use crate::linsolve::common::{LinsolveOptions, ProjectedOperator};
20use crate::operator::IndexMapping;
21use crate::{
22    factorize_tensor_to_treetn_with, get_boundary_edges, LocalUpdateStep, LocalUpdater, TreeTN,
23    TreeTopology,
24};
25
26/// Report from SquareLinsolveUpdater::verify().
27#[derive(Debug, Clone)]
28pub struct LinsolveVerifyReport<V> {
29    /// Whether the configuration is valid
30    pub is_valid: bool,
31    /// Errors that would prevent linsolve from working
32    pub errors: Vec<String>,
33    /// Warnings that might indicate issues
34    pub warnings: Vec<String>,
35    /// Per-node details
36    pub node_details: Vec<NodeVerifyDetail<V>>,
37}
38
39impl<V> Default for LinsolveVerifyReport<V> {
40    fn default() -> Self {
41        Self {
42            is_valid: false,
43            errors: Vec::new(),
44            warnings: Vec::new(),
45            node_details: Vec::new(),
46        }
47    }
48}
49
50/// Per-node verification details.
51#[derive(Debug, Clone)]
52pub struct NodeVerifyDetail<V> {
53    /// Node name
54    pub node: V,
55    /// State's site space index IDs
56    pub state_site_indices: Vec<String>,
57    /// Operator's site space index IDs
58    pub op_site_indices: Vec<String>,
59    /// State tensor's all index IDs with dimensions
60    pub state_tensor_indices: Vec<String>,
61    /// Operator tensor's all index IDs with dimensions
62    pub op_tensor_indices: Vec<String>,
63    /// Number of common indices between state and operator
64    pub common_index_count: usize,
65}
66
67impl<V: std::fmt::Debug> std::fmt::Display for LinsolveVerifyReport<V> {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        writeln!(f, "LinsolveVerifyReport:")?;
70        writeln!(f, "  Valid: {}", self.is_valid)?;
71
72        if !self.errors.is_empty() {
73            writeln!(f, "  Errors:")?;
74            for err in &self.errors {
75                writeln!(f, "    - {}", err)?;
76            }
77        }
78
79        if !self.warnings.is_empty() {
80            writeln!(f, "  Warnings:")?;
81            for warn in &self.warnings {
82                writeln!(f, "    - {}", warn)?;
83            }
84        }
85
86        if !self.node_details.is_empty() {
87            writeln!(f, "  Node Details:")?;
88            for detail in &self.node_details {
89                writeln!(f, "    {:?}:", detail.node)?;
90                writeln!(
91                    f,
92                    "      State site indices: {:?}",
93                    detail.state_site_indices
94                )?;
95                writeln!(f, "      Op site indices: {:?}", detail.op_site_indices)?;
96                writeln!(
97                    f,
98                    "      State tensor indices: {:?}",
99                    detail.state_tensor_indices
100                )?;
101                writeln!(f, "      Op tensor indices: {:?}", detail.op_tensor_indices)?;
102                writeln!(f, "      Common index count: {}", detail.common_index_count)?;
103            }
104        }
105
106        Ok(())
107    }
108}
109
110/// SquareLinsolveUpdater: Implements LocalUpdater for the square linsolve algorithm.
111///
112/// At each sweep step:
113/// 1. Compute local operator (from ProjectedOperator environments)
114/// 2. Compute local RHS (from ProjectedState environments)
115/// 3. Solve local linear system using GMRES
116/// 4. Factorize the result and update the state
117///
118/// This is the V_in = V_out specialized version. The current solution x is used
119/// with a separate reference state (with different bond indices) for stable
120/// environment computation.
121pub struct SquareLinsolveUpdater<T, V>
122where
123    T: TensorLike,
124    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
125{
126    /// Projected operator (3-chain), wrapped in `Arc<RwLock>` for GMRES.
127    pub projected_operator: Arc<RwLock<ProjectedOperator<T, V>>>,
128    /// Projected state for RHS (2-chain)
129    pub projected_state: ProjectedState<T, V>,
130    /// Solver options
131    pub options: LinsolveOptions,
132    /// Reference state (separate from ket to avoid unintended contractions).
133    /// Link indices are different from ket_state to prevent bra↔ket link contractions.
134    /// Boundary bonds (region ↔ outside) maintain stable IDs for cache consistency.
135    reference_state: TreeTN<T, V>,
136    /// Mapping from boundary edge (node_in_region, neighbor_outside) to reference-side bond index.
137    /// This ensures boundary bonds keep stable IDs across updates for environment cache reuse.
138    boundary_bond_map: HashMap<(V, V), T::Index>,
139    /// Run the bra/ket convention precheck only once.
140    did_ref_bra_ket_precheck: bool,
141    /// Run MPO structure validation only once.
142    did_mpo_validation: bool,
143}
144
145impl<T, V> SquareLinsolveUpdater<T, V>
146where
147    T: TensorLike + 'static,
148    T::Index: IndexLike,
149    <T::Index as IndexLike>::Id:
150        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
151    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug + 'static,
152{
153    /// Create a new SquareLinsolveUpdater.
154    ///
155    /// The reference_state will be initialized lazily on the first `before_step` call.
156    pub fn new(operator: TreeTN<T, V>, rhs: TreeTN<T, V>, options: LinsolveOptions) -> Self {
157        Self {
158            projected_operator: Arc::new(RwLock::new(ProjectedOperator::new(operator))),
159            projected_state: ProjectedState::new(rhs),
160            options,
161            reference_state: TreeTN::new(),
162            boundary_bond_map: HashMap::new(),
163            did_ref_bra_ket_precheck: false,
164            did_mpo_validation: false,
165        }
166    }
167
168    /// Create a new SquareLinsolveUpdater with index mappings for correct index handling.
169    ///
170    /// Use this when the MPO uses internal indices (s_in_tmp, s_out_tmp) that differ
171    /// from the state's site indices. The mappings define how to translate between them.
172    ///
173    /// The reference_state will be initialized lazily on the first `before_step` call.
174    ///
175    /// # Arguments
176    /// * `operator` - The MPO with internal index IDs
177    /// * `input_mapping` - Mapping from true input indices to MPO's internal indices
178    /// * `output_mapping` - Mapping from true output indices to MPO's internal indices
179    /// * `rhs` - The RHS b
180    /// * `options` - Solver options
181    pub fn with_index_mappings(
182        operator: TreeTN<T, V>,
183        input_mapping: HashMap<V, IndexMapping<T::Index>>,
184        output_mapping: HashMap<V, IndexMapping<T::Index>>,
185        rhs: TreeTN<T, V>,
186        options: LinsolveOptions,
187    ) -> Self {
188        let projected_operator =
189            ProjectedOperator::with_index_mappings(operator, input_mapping, output_mapping);
190        Self {
191            projected_operator: Arc::new(RwLock::new(projected_operator)),
192            projected_state: ProjectedState::new(rhs),
193            options,
194            reference_state: TreeTN::new(),
195            boundary_bond_map: HashMap::new(),
196            did_ref_bra_ket_precheck: false,
197            did_mpo_validation: false,
198        }
199    }
200
201    /// Initialize reference_state from ket_state if not already initialized.
202    ///
203    /// Creates reference_state by relabeling link indices (using sim() for internal bonds,
204    /// preserving boundary bonds for cache consistency).
205    ///
206    /// This is called lazily on the first `before_step` to ensure we have the initial ket_state.
207    fn ensure_reference_state_initialized(&mut self, ket_state: &TreeTN<T, V>) -> Result<()> {
208        // Check if reference_state is already initialized (has nodes)
209        if !self.reference_state.node_names().is_empty() {
210            return Ok(());
211        }
212
213        // Initialize reference_state by cloning ket_state and relabeling link indices
214        // For boundary bonds, we'll preserve the mapping for later reuse
215        let mut reference_state = ket_state.clone();
216
217        // Get all edges to determine which are boundary bonds
218        // Since we don't have a region yet, we'll relabel all links initially
219        // Boundary bonds will be stabilized in after_step when we know the region
220        reference_state.sim_linkinds_mut()?;
221
222        // Initialize boundary_bond_map as empty (will be populated per-region in after_step)
223        self.boundary_bond_map.clear();
224
225        self.reference_state = reference_state;
226        Ok(())
227    }
228
229    /// Verify internal data consistency between operator, RHS, and state.
230    ///
231    /// This function checks that:
232    /// 1. The operator's site space structure is compatible with the state
233    /// 2. The operator's input indices can match the state's site indices
234    /// 3. Environment computation requirements are satisfiable
235    ///
236    /// Returns a detailed report of any inconsistencies found.
237    pub fn verify(&self, state: &TreeTN<T, V>) -> Result<LinsolveVerifyReport<V>> {
238        let mut report = LinsolveVerifyReport::default();
239
240        let proj_op = self
241            .projected_operator
242            .read()
243            .map_err(|e| {
244                anyhow::anyhow!("Failed to acquire read lock on projected_operator: {}", e)
245            })
246            .context("verify: lock poisoned")?;
247        let operator = &proj_op.operator;
248        let rhs = &self.projected_state.rhs;
249
250        // Check node consistency
251        let state_nodes: std::collections::BTreeSet<_> = state
252            .site_index_network()
253            .node_names()
254            .into_iter()
255            .collect();
256        let op_nodes: std::collections::BTreeSet<_> = operator
257            .site_index_network()
258            .node_names()
259            .into_iter()
260            .collect();
261        let rhs_nodes: std::collections::BTreeSet<_> =
262            rhs.site_index_network().node_names().into_iter().collect();
263
264        if state_nodes != op_nodes {
265            report.errors.push(format!(
266                "State and operator have different node sets. State: {:?}, Operator: {:?}",
267                state_nodes, op_nodes
268            ));
269        }
270
271        if state_nodes != rhs_nodes {
272            report.errors.push(format!(
273                "State and RHS have different node sets. State: {:?}, RHS: {:?}",
274                state_nodes, rhs_nodes
275            ));
276        }
277
278        // Check site index compatibility per node
279        for node in &state_nodes {
280            let state_site = state.site_space(node);
281            let op_site = operator.site_space(node);
282
283            // Get state tensor indices
284            if let Some(state_idx) = state.node_index(node) {
285                if let Some(state_tensor) = state.tensor(state_idx) {
286                    let state_indices_vec = state_tensor.external_indices();
287                    let state_indices: Vec<_> = state_indices_vec
288                        .iter()
289                        .map(|idx| (idx.id().clone(), idx.dim()))
290                        .collect();
291
292                    // Get operator tensor indices
293                    if let Some(op_idx) = operator.node_index(node) {
294                        if let Some(op_tensor) = operator.tensor(op_idx) {
295                            let op_indices_vec = op_tensor.external_indices();
296                            let op_indices: Vec<_> = op_indices_vec
297                                .iter()
298                                .map(|idx| (idx.id().clone(), idx.dim()))
299                                .collect();
300
301                            // Check for common indices (should have at least bond indices)
302                            let common_count = state_indices
303                                .iter()
304                                .filter(|(id, _)| op_indices.iter().any(|(oid, _)| oid == id))
305                                .count();
306
307                            report.node_details.push(NodeVerifyDetail {
308                                node: (*node).clone(),
309                                state_site_indices: state_site
310                                    .map(|s| s.iter().map(|i| format!("{:?}", i.id())).collect())
311                                    .unwrap_or_default(),
312                                op_site_indices: op_site
313                                    .map(|s| s.iter().map(|i| format!("{:?}", i.id())).collect())
314                                    .unwrap_or_default(),
315                                state_tensor_indices: state_indices
316                                    .iter()
317                                    .map(|(id, dim)| format!("{:?}(dim={})", id, dim))
318                                    .collect(),
319                                op_tensor_indices: op_indices
320                                    .iter()
321                                    .map(|(id, dim)| format!("{:?}(dim={})", id, dim))
322                                    .collect(),
323                                common_index_count: common_count,
324                            });
325
326                            // Warn if no site indices in common (expected for MPO)
327                            // In proper MPO structure, operator should have input indices
328                            // that match state's site indices
329                            if common_count == 0 {
330                                report.warnings.push(format!(
331                                    "Node {:?}: No common indices between state and operator tensors. \
332                                     State has {:?}, operator has {:?}",
333                                    node, state_indices, op_indices
334                                ));
335                            }
336                        }
337                    }
338                }
339            }
340        }
341
342        // Final verdict
343        report.is_valid = report.errors.is_empty();
344
345        Ok(report)
346    }
347
348    /// Contract all tensors in the region into a single local tensor.
349    fn contract_region(&self, subtree: &TreeTN<T, V>, region: &[V]) -> Result<T> {
350        if region.is_empty() {
351            return Err(anyhow::anyhow!("Region cannot be empty"));
352        }
353
354        // Collect all tensors in the region
355        let tensors: Vec<T> = region
356            .iter()
357            .map(|node| {
358                let idx = subtree
359                    .node_index(node)
360                    .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node))?;
361                let tensor = subtree
362                    .tensor(idx)
363                    .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node))?;
364                Ok(tensor.clone())
365            })
366            .collect::<Result<_>>()?;
367
368        // Use TensorLike::contract for contraction
369        let tensor_refs: Vec<&T> = tensors.iter().collect();
370        T::contract(&tensor_refs, AllowedPairs::All)
371    }
372
373    /// Build TreeTopology for the subtree region from the solved tensor.
374    ///
375    /// Maps each node to the index IDs belonging to it in the solved tensor.
376    fn build_subtree_topology(
377        &self,
378        solved_tensor: &T,
379        region: &[V],
380        full_treetn: &TreeTN<T, V>,
381    ) -> Result<TreeTopology<V, <T::Index as IndexLike>::Id>> {
382        use std::collections::HashMap;
383
384        let mut nodes: HashMap<V, Vec<<T::Index as IndexLike>::Id>> = HashMap::new();
385        let mut edges: Vec<(V, V)> = Vec::new();
386
387        let solved_indices = solved_tensor.external_indices();
388
389        // For each node in the region, find which index IDs belong to it
390        for node in region {
391            let mut ids = Vec::new();
392
393            // Get site indices for this node
394            if let Some(site_indices) = full_treetn.site_space(node) {
395                for site_idx in site_indices {
396                    // Verify the index exists in solved_tensor and collect its ID
397                    if solved_indices.iter().any(|idx| idx.id() == site_idx.id()) {
398                        ids.push(site_idx.id().clone());
399                    }
400                }
401            }
402
403            // Get bond indices to neighbors outside the region
404            for neighbor in full_treetn.site_index_network().neighbors(node) {
405                if !region.contains(&neighbor) {
406                    // This is an external neighbor - the bond belongs to this node
407                    if let Some(edge) = full_treetn.edge_between(node, &neighbor) {
408                        if let Some(bond) = full_treetn.bond_index(edge) {
409                            if solved_indices.iter().any(|idx| idx.id() == bond.id()) {
410                                ids.push(bond.id().clone());
411                            }
412                        }
413                    }
414                }
415            }
416
417            nodes.insert(node.clone(), ids);
418        }
419
420        // Build edges between nodes in the region
421        for (i, node_a) in region.iter().enumerate() {
422            for node_b in region.iter().skip(i + 1) {
423                if full_treetn.edge_between(node_a, node_b).is_some() {
424                    edges.push((node_a.clone(), node_b.clone()));
425                }
426            }
427        }
428
429        Ok(TreeTopology::new(nodes, edges))
430    }
431
432    /// Copy decomposed tensors back to subtree, preserving original bond IDs.
433    fn copy_decomposed_to_subtree(
434        &self,
435        subtree: &mut TreeTN<T, V>,
436        decomposed: &TreeTN<T, V>,
437        region: &[V],
438        full_treetn: &TreeTN<T, V>,
439    ) -> Result<()> {
440        use std::collections::HashMap;
441
442        // Phase 1: Build a mapping from decomposed bond IDs to new bond indices
443        // For internal bonds, we create a single new bond index that will be used
444        // for both nodes sharing that edge
445        let mut bond_mapping: HashMap<<T::Index as IndexLike>::Id, T::Index> = HashMap::new();
446
447        for (i, node_a) in region.iter().enumerate() {
448            for node_b in region.iter().skip(i + 1) {
449                // Check if there's an edge between these nodes
450                if let Some(decomp_edge) = decomposed.edge_between(node_a, node_b) {
451                    if let Some(decomp_bond) = decomposed.bond_index(decomp_edge) {
452                        // Create a new bond index matching decomposed bond dimension.
453                        // Use sim() once for this edge to avoid ID collisions.
454                        if let Some(orig_edge) = subtree.edge_between(node_a, node_b) {
455                            let new_bond = decomp_bond.sim();
456                            bond_mapping.insert(decomp_bond.id().clone(), new_bond.clone());
457
458                            // Update the edge bond in subtree
459                            subtree.replace_edge_bond(orig_edge, new_bond)?;
460                        }
461                    }
462                }
463            }
464        }
465
466        // Phase 2: For each node in the region, update its tensor using the bond mapping
467        for node in region {
468            let decomp_idx = decomposed
469                .node_index(node)
470                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in decomposed TreeTN", node))?;
471            let mut new_tensor = decomposed
472                .tensor(decomp_idx)
473                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node))?
474                .clone();
475
476            // Replace bond indices using the pre-computed mapping
477            for neighbor in full_treetn.site_index_network().neighbors(node) {
478                if region.contains(&neighbor) {
479                    // Internal bond - use the mapped bond
480                    if let Some(decomp_edge) = decomposed.edge_between(node, &neighbor) {
481                        if let Some(decomp_bond) = decomposed.bond_index(decomp_edge) {
482                            if let Some(new_bond) = bond_mapping.get(decomp_bond.id()) {
483                                new_tensor = new_tensor.replaceind(decomp_bond, new_bond)?;
484                            }
485                        }
486                    }
487                }
488            }
489
490            // Update the tensor in subtree
491            let subtree_idx = subtree
492                .node_index(node)
493                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node))?;
494            subtree.replace_tensor(subtree_idx, new_tensor)?;
495        }
496
497        Ok(())
498    }
499
500    /// Solve the local linear problem using GMRES.
501    ///
502    /// Solves: (a₀ + a₁ * H_local) |x_local⟩ = |b_local⟩
503    fn solve_local(&mut self, region: &[V], init: &T, state: &TreeTN<T, V>) -> Result<T> {
504        // Use state's SiteIndexNetwork directly (implements NetworkTopology)
505        let topology = state.site_index_network();
506
507        // Get local RHS: <b|_local
508        let rhs_local_raw = self
509            .projected_state
510            .local_constant_term(region, state, topology)?;
511
512        // Align RHS indices with init indices.
513        // The RHS may have indices from the `rhs` TreeTN, while init has indices from
514        // the current `state`. For GMRES operations (like b - A*x), they must match.
515        //
516        // For MPO cases, external indices should match (validated earlier), but the order
517        // may differ. We use ID-based matching to align them.
518        let init_indices = init.external_indices();
519        let rhs_indices = rhs_local_raw.external_indices();
520
521        let rhs_local = if self.index_sets_match(&init_indices, &rhs_indices) {
522            // Same set of indices (by ID) and same count - check if order matches
523            let indices_match = init_indices
524                .iter()
525                .zip(rhs_indices.iter())
526                .all(|(ii, ri)| ii.id() == ri.id() && ii.dim() == ri.dim());
527            if indices_match {
528                rhs_local_raw
529            } else {
530                // Permute RHS indices to init index order (matched by ID)
531                // Note: permuteinds requires same length, which we've already checked
532                rhs_local_raw.permuteinds(&init_indices)?
533            }
534        } else {
535            return Err(anyhow::anyhow!(
536                "{}",
537                self.index_structure_mismatch_message(
538                    &init_indices,
539                    &rhs_indices,
540                    "Index structure mismatch between init and RHS (local tensors)",
541                    "This suggests:\n  - ProjectedState environment construction may have contracted/left open unexpected indices\n  - External indices may not be properly aligned between x and b\n  - AllowedPairs::All may have over-contracted external indices in the environment\n\nSee `plan/linsolve-mpo.md` for analysis of external index handling.",
542                )
543            ));
544        };
545
546        // Convert coefficients to AnyScalar
547        let a0 = AnyScalar::new_real(self.options.a0);
548        let a1 = AnyScalar::new_real(self.options.a1);
549
550        // Create local linear operator with separate reference_state
551        // This prevents unintended bra↔ket link contractions in environment computation
552        let linop = LocalLinOp::new(
553            Arc::clone(&self.projected_operator),
554            region.to_vec(),
555            state.clone(),
556            self.reference_state.clone(),
557            a0,
558            a1,
559        );
560
561        // Create closure for GMRES that applies the linear operator
562        let apply_a = |x: &T| linop.apply(x);
563
564        // Set up GMRES options
565        let gmres_options = GmresOptions {
566            max_iter: self.options.krylov_dim,
567            rtol: self.options.krylov_tol,
568            max_restarts: (self.options.krylov_maxiter / self.options.krylov_dim).max(1),
569            verbose: false,
570            check_true_residual: false,
571        };
572
573        // Solve using GMRES (works directly with TensorDynLen)
574        let result = gmres(apply_a, &rhs_local, init, &gmres_options)?;
575
576        Ok(result.solution)
577    }
578
579    /// Synchronize reference_state region with ket_state, preserving boundary bond IDs.
580    ///
581    /// This ensures reference_state stays in sync with ket_state updates while maintaining
582    /// stable boundary bond IDs for environment cache reuse.
583    fn sync_reference_state_region(
584        &mut self,
585        step: &LocalUpdateStep<V>,
586        ket_state: &TreeTN<T, V>,
587    ) -> Result<()> {
588        // Extract updated region from ket_state
589        let ket_region = ket_state.extract_subtree(&step.nodes)?;
590
591        // Build mapping from ket bond IDs to reference bond indices for *all* bonds incident to the region.
592        //
593        // Important: reference_state bond IDs must remain stable across steps, even when an edge alternates
594        // between being a boundary edge and an internal edge in different steps (as happens in sweeps).
595        // Therefore, we always reuse the current reference_state bond indices, and never create fresh IDs here.
596        let mut ket_to_ref_bond_map: HashMap<<T::Index as IndexLike>::Id, T::Index> =
597            HashMap::new();
598
599        // Populate mapping for all edges incident to region nodes.
600        // - Boundary edges (region ↔ outside): use reference_state's existing bond IDs for cache stability
601        // - Internal edges (within region): use sim() to create new IDs distinct from ket
602        // This ensures reference_state bond IDs are always different from ket_state bond IDs.
603        let region_nodes: std::collections::HashSet<_> = step.nodes.iter().collect();
604        for node in &step.nodes {
605            for neighbor in ket_state.site_index_network().neighbors(node) {
606                let ket_edge = match ket_state.edge_between(node, &neighbor) {
607                    Some(e) => e,
608                    None => continue,
609                };
610                let ket_bond = match ket_state.bond_index(ket_edge) {
611                    Some(b) => b,
612                    None => continue,
613                };
614
615                let ref_bond = if region_nodes.contains(&neighbor) {
616                    // Internal edge: create new ID using sim()
617                    ket_bond.sim()
618                } else {
619                    // Boundary edge: use reference_state's existing bond ID
620                    let ref_edge = match self.reference_state.edge_between(node, &neighbor) {
621                        Some(e) => e,
622                        None => continue,
623                    };
624                    match self.reference_state.bond_index(ref_edge) {
625                        Some(b) => b.clone(),
626                        None => continue,
627                    }
628                };
629                ket_to_ref_bond_map.insert(ket_bond.id().clone(), ref_bond);
630            }
631        }
632
633        // Keep a small explicit cache for boundary edges (region ↔ outside) for inspection/debugging.
634        // This is not used to drive the mapping logic.
635        for boundary_edge in get_boundary_edges(ket_state, &step.nodes)? {
636            if let Some(edge) = self.reference_state.edge_between(
637                &boundary_edge.node_in_region,
638                &boundary_edge.neighbor_outside,
639            ) {
640                if let Some(ref_bond) = self.reference_state.bond_index(edge) {
641                    self.boundary_bond_map.insert(
642                        (
643                            boundary_edge.node_in_region.clone(),
644                            boundary_edge.neighbor_outside.clone(),
645                        ),
646                        ref_bond.clone(),
647                    );
648                }
649            }
650        }
651
652        // Create new ref_region by copying ket_region
653        let mut ref_region = ket_region.clone();
654
655        // First, update edge bonds in ref_region to reference-side IDs
656        // This must be done before replacing tensors to ensure consistency
657        let mut edges_to_update: Vec<(V, V, T::Index)> = Vec::new();
658        for node in &step.nodes {
659            let neighbors: Vec<V> = ref_region.site_index_network().neighbors(node).collect();
660            for neighbor in neighbors {
661                if let Some(edge) = ref_region.edge_between(node, &neighbor) {
662                    if let Some(bond) = ref_region.bond_index(edge) {
663                        if let Some(new_bond) = ket_to_ref_bond_map.get(bond.id()) {
664                            edges_to_update.push((node.clone(), neighbor, new_bond.clone()));
665                        }
666                    }
667                }
668            }
669        }
670        // Update edges (ref_region is no longer borrowed)
671        for (node, neighbor, new_bond) in edges_to_update {
672            if let Some(edge) = ref_region.edge_between(&node, &neighbor) {
673                ref_region.replace_edge_bond(edge, new_bond)?;
674            }
675        }
676
677        // Now replace all bond indices (including boundary bonds) in ref_region tensors with reference-side IDs
678        for node in &step.nodes {
679            if let Some(node_idx) = ref_region.node_index(node) {
680                if let Some(tensor) = ref_region.tensor(node_idx) {
681                    let mut new_tensor = tensor.clone();
682                    let tensor_indices = tensor.external_indices();
683
684                    for ket_idx in &tensor_indices {
685                        // Replace if this index is one of the region's bond indices (internal or boundary)
686                        if let Some(ref_bond) = ket_to_ref_bond_map.get(ket_idx.id()) {
687                            new_tensor = new_tensor.replaceind(ket_idx, ref_bond)?;
688                        }
689                        // Site indices are kept as-is (same IDs in reference and ket)
690                    }
691
692                    ref_region.replace_tensor(node_idx, new_tensor)?;
693                }
694            }
695        }
696
697        // Replace the region back into reference_state
698        self.reference_state
699            .replace_subtree(&step.nodes, &ref_region)?;
700
701        Ok(())
702    }
703}
704
705impl<T, V> LocalUpdater<T, V> for SquareLinsolveUpdater<T, V>
706where
707    T: TensorLike + 'static,
708    T::Index: IndexLike,
709    <T::Index as IndexLike>::Id:
710        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
711    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug + 'static,
712{
713    fn before_step(
714        &mut self,
715        step: &LocalUpdateStep<V>,
716        full_treetn_before: &TreeTN<T, V>,
717    ) -> Result<()> {
718        // Initialize reference_state lazily on first call
719        self.ensure_reference_state_initialized(full_treetn_before)?;
720
721        // (1) Precheck: ensure local RHS indices align with local init indices
722        // (bra/ket convention sanity for `<ref|H|x>` vs `<ref|b>`).
723        if !self.did_ref_bra_ket_precheck {
724            self.precheck_ref_bra_ket_convention(step, full_treetn_before)?;
725            self.did_ref_bra_ket_precheck = true;
726        }
727
728        // (3) MPO structure validation (fail fast) – run once.
729        if !self.did_mpo_validation {
730            self.validate_mpo_external_indices(full_treetn_before)?;
731            self.did_mpo_validation = true;
732        }
733        Ok(())
734    }
735
736    fn update(
737        &mut self,
738        mut subtree: TreeTN<T, V>,
739        step: &LocalUpdateStep<V>,
740        full_treetn: &TreeTN<T, V>,
741    ) -> Result<TreeTN<T, V>> {
742        // Contract tensors in the region into a single local tensor
743        let init_local = self.contract_region(&subtree, &step.nodes)?;
744        // Solve local linear problem using GMRES
745        let solved_local = self.solve_local(&step.nodes, &init_local, full_treetn)?;
746
747        // Build TreeTopology for the subtree region
748        let topology = self.build_subtree_topology(&solved_local, &step.nodes, full_treetn)?;
749
750        // Decompose solved tensor back into TreeTN using factorize_tensor_to_treetn
751        let mut factorize_options = FactorizeOptions::svd();
752        if let Some(max_rank) = self.options.truncation.max_rank() {
753            factorize_options = factorize_options.with_max_rank(max_rank);
754        }
755        if let Some(policy) = self.options.truncation.svd_policy() {
756            factorize_options = factorize_options.with_svd_policy(policy);
757        }
758        // Force decomposition root to be consistent with the sweep plan's new center.
759        // This keeps the norm-carrying tensor on the declared canonical center.
760        let decomposed = factorize_tensor_to_treetn_with(
761            &solved_local,
762            &topology,
763            factorize_options,
764            &step.new_center,
765        )?;
766
767        // Copy decomposed tensors back to subtree, preserving original bond IDs
768        self.copy_decomposed_to_subtree(&mut subtree, &decomposed, &step.nodes, full_treetn)?;
769
770        // Force-move the canonical region metadata AND the bond ortho directions to the new center.
771        // (apply_local_update_sweep also updates canonical_region on the full TreeTN, but we keep
772        // the subtree self-consistent here.)
773        subtree.set_canonical_region([step.new_center.clone()])?;
774        if let Some(edges) = subtree.edges_to_canonicalize_by_names(&step.new_center) {
775            for (from, to) in edges {
776                if let Some(edge) = subtree.edge_between(&from, &to) {
777                    subtree.set_edge_ortho_towards(edge, Some(to))?;
778                }
779            }
780        }
781
782        Ok(subtree)
783    }
784
785    fn after_step(
786        &mut self,
787        step: &LocalUpdateStep<V>,
788        full_treetn_after: &TreeTN<T, V>,
789    ) -> Result<()> {
790        // Use state's SiteIndexNetwork directly (implements NetworkTopology)
791        let topology = full_treetn_after.site_index_network();
792
793        // Synchronize reference_state with ket_state (full_treetn_after) for the updated region
794        // while preserving boundary bond IDs for cache consistency
795        self.sync_reference_state_region(step, full_treetn_after)?;
796
797        // Invalidate all caches affected by the updated region
798        {
799            let mut proj_op = self
800                .projected_operator
801                .write()
802                .map_err(|e| anyhow::anyhow!("Failed to acquire write lock: {}", e))
803                .context("after_step: lock poisoned")?;
804            proj_op.invalidate(&step.nodes, topology);
805        }
806        self.projected_state.invalidate(&step.nodes, topology);
807
808        Ok(())
809    }
810}
811
812impl<T, V> SquareLinsolveUpdater<T, V>
813where
814    T: TensorLike + 'static,
815    T::Index: IndexLike,
816    <T::Index as IndexLike>::Id:
817        Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
818    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug + 'static,
819{
820    fn index_sets_match(&self, init_indices: &[T::Index], rhs_indices: &[T::Index]) -> bool {
821        if init_indices.len() != rhs_indices.len() {
822            return false;
823        }
824        let init_ids: std::collections::HashSet<_> = init_indices.iter().map(|i| i.id()).collect();
825        let rhs_ids: std::collections::HashSet<_> = rhs_indices.iter().map(|i| i.id()).collect();
826        init_ids == rhs_ids
827    }
828
829    fn index_structure_mismatch_message(
830        &self,
831        init_indices: &[T::Index],
832        rhs_indices: &[T::Index],
833        header: &str,
834        footer: &str,
835    ) -> String {
836        let init_ids: std::collections::HashSet<_> = init_indices.iter().map(|i| i.id()).collect();
837        let rhs_ids: std::collections::HashSet<_> = rhs_indices.iter().map(|i| i.id()).collect();
838        let extra_in_rhs: Vec<_> = rhs_ids
839            .difference(&init_ids)
840            .map(|id| {
841                rhs_indices
842                    .iter()
843                    .find(|i| i.id() == *id)
844                    .map(|i| format!("{:?}:{}", id, i.dim()))
845                    .unwrap_or_else(|| format!("{:?}:?", id))
846            })
847            .collect();
848        let missing_in_rhs: Vec<_> = init_ids
849            .difference(&rhs_ids)
850            .map(|id| {
851                init_indices
852                    .iter()
853                    .find(|i| i.id() == *id)
854                    .map(|i| format!("{:?}:{}", id, i.dim()))
855                    .unwrap_or_else(|| format!("{:?}:?", id))
856            })
857            .collect();
858
859        format!(
860            "{header}:\n  init has {} indices: {:?}\n  rhs has {} indices: {:?}\n  extra in rhs (not in init): {:?}\n  missing in rhs (in init but not in rhs): {:?}\n\n{footer}",
861            init_indices.len(),
862            init_indices
863                .iter()
864                .map(|i| format!("{:?}:{}", i.id(), i.dim()))
865                .collect::<Vec<_>>(),
866            rhs_indices.len(),
867            rhs_indices
868                .iter()
869                .map(|i| format!("{:?}:{}", i.id(), i.dim()))
870                .collect::<Vec<_>>(),
871            extra_in_rhs,
872            missing_in_rhs,
873        )
874    }
875
876    fn precheck_ref_bra_ket_convention(
877        &mut self,
878        step: &LocalUpdateStep<V>,
879        full_treetn_before: &TreeTN<T, V>,
880    ) -> Result<()> {
881        let subtree = full_treetn_before.extract_subtree(&step.nodes)?;
882        let init_local = self.contract_region(&subtree, &step.nodes)?;
883
884        let topology = full_treetn_before.site_index_network();
885        let rhs_local_raw =
886            self.projected_state
887                .local_constant_term(&step.nodes, full_treetn_before, topology)?;
888
889        let init_indices = init_local.external_indices();
890        let rhs_indices = rhs_local_raw.external_indices();
891
892        if !self.index_sets_match(&init_indices, &rhs_indices) {
893            return Err(anyhow::anyhow!(
894                "{}",
895                self.index_structure_mismatch_message(
896                    &init_indices,
897                    &rhs_indices,
898                    "linsolve precheck failed (local index structure mismatch)",
899                    "This suggests `<ref|H|x>` vs `<ref|b>` conventions (or external-index contraction rules) are inconsistent for the current region. See `plan/linsolve-mpo.md` for analysis.",
900                )
901            ));
902        }
903
904        Ok(())
905    }
906
907    #[allow(clippy::type_complexity)]
908    fn validate_mpo_external_indices(&mut self, state: &TreeTN<T, V>) -> Result<()> {
909        // Only validate when operator mappings exist (MPO-with-mappings path).
910        let (input_mapping, output_mapping): (
911            HashMap<V, IndexMapping<T::Index>>,
912            HashMap<V, IndexMapping<T::Index>>,
913        ) = {
914            let proj_op = self.projected_operator.read().map_err(|e| {
915                anyhow::anyhow!("validate_mpo_external_indices: lock poisoned: {e}")
916            })?;
917            let Some(input) = proj_op.input_mapping.as_ref() else {
918                return Ok(());
919            };
920            let Some(output) = proj_op.output_mapping.as_ref() else {
921                return Ok(());
922            };
923            (input.clone(), output.clone())
924        };
925
926        for node in state.node_names() {
927            let Some(x_sites) = state.site_space(&node) else {
928                continue;
929            };
930            let Some(b_sites) = self.projected_state.rhs.site_space(&node) else {
931                continue;
932            };
933
934            // Only apply strict MPO validation when both look like MPO (2 site indices).
935            if x_sites.len() != 2 || b_sites.len() != 2 {
936                continue;
937            }
938
939            let x_contracted = input_mapping
940                .get(&node)
941                .ok_or_else(|| {
942                    anyhow::anyhow!("MPO validation: missing input_mapping for node {:?}", node)
943                })?
944                .true_index
945                .clone();
946            let b_contracted = output_mapping
947                .get(&node)
948                .ok_or_else(|| {
949                    anyhow::anyhow!("MPO validation: missing output_mapping for node {:?}", node)
950                })?
951                .true_index
952                .clone();
953
954            let x_external: Vec<_> = x_sites
955                .iter()
956                .filter(|idx| !idx.same_id(&x_contracted))
957                .cloned()
958                .collect();
959            let b_external: Vec<_> = b_sites
960                .iter()
961                .filter(|idx| !idx.same_id(&b_contracted))
962                .cloned()
963                .collect();
964
965            if x_external.len() != 1 || b_external.len() != 1 {
966                return Err(anyhow::anyhow!(
967                    "MPO validation: expected exactly 1 external site index after removing contracted index. node={:?}, x_site_len={}, b_site_len={}, x_external={:?}, b_external={:?}",
968                    node,
969                    x_sites.len(),
970                    b_sites.len(),
971                    x_external.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::<Vec<_>>(),
972                    b_external.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::<Vec<_>>(),
973                ));
974            }
975
976            let x_ext = &x_external[0];
977            let b_ext = &b_external[0];
978            if !x_ext.same_id(b_ext) || x_ext.dim() != b_ext.dim() {
979                return Err(anyhow::anyhow!(
980                    "MPO validation: external index mismatch at node {:?}: x has {:?}:{}, b has {:?}:{}",
981                    node,
982                    x_ext.id(),
983                    x_ext.dim(),
984                    b_ext.id(),
985                    b_ext.dim(),
986                ));
987            }
988        }
989
990        Ok(())
991    }
992}