1use std::collections::HashMap;
7use std::hash::Hash;
8use std::sync::Arc;
9use std::sync::RwLock;
10
11use anyhow::{Context, Result};
12
13use tensor4all_core::any_scalar::AnyScalar;
14use tensor4all_core::krylov::{gmres, GmresOptions};
15use tensor4all_core::{AllowedPairs, FactorizeOptions, IndexLike, TensorLike};
16
17use super::local_linop::LocalLinOp;
18use super::projected_state::ProjectedState;
19use crate::linsolve::common::{LinsolveOptions, ProjectedOperator};
20use crate::operator::IndexMapping;
21use crate::{
22 factorize_tensor_to_treetn_with, get_boundary_edges, LocalUpdateStep, LocalUpdater, TreeTN,
23 TreeTopology,
24};
25
26#[derive(Debug, Clone)]
28pub struct LinsolveVerifyReport<V> {
29 pub is_valid: bool,
31 pub errors: Vec<String>,
33 pub warnings: Vec<String>,
35 pub node_details: Vec<NodeVerifyDetail<V>>,
37}
38
39impl<V> Default for LinsolveVerifyReport<V> {
40 fn default() -> Self {
41 Self {
42 is_valid: false,
43 errors: Vec::new(),
44 warnings: Vec::new(),
45 node_details: Vec::new(),
46 }
47 }
48}
49
/// Index information recorded for one node during verification.
///
/// Site-index lists hold the `Debug` rendering of each index id; tensor-index
/// lists hold `"{id:?}(dim={dim})"` for each external index of the tensor.
#[derive(Debug, Clone)]
pub struct NodeVerifyDetail<V> {
    /// Name of the node these details describe.
    pub node: V,
    /// Site-space index ids of the node in the state.
    pub state_site_indices: Vec<String>,
    /// Site-space index ids of the node in the operator.
    pub op_site_indices: Vec<String>,
    /// External indices of the state tensor at this node.
    pub state_tensor_indices: Vec<String>,
    /// External indices of the operator tensor at this node.
    pub op_tensor_indices: Vec<String>,
    /// How many state-tensor indices have an id also present on the operator tensor.
    pub common_index_count: usize,
}
66
67impl<V: std::fmt::Debug> std::fmt::Display for LinsolveVerifyReport<V> {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 writeln!(f, "LinsolveVerifyReport:")?;
70 writeln!(f, " Valid: {}", self.is_valid)?;
71
72 if !self.errors.is_empty() {
73 writeln!(f, " Errors:")?;
74 for err in &self.errors {
75 writeln!(f, " - {}", err)?;
76 }
77 }
78
79 if !self.warnings.is_empty() {
80 writeln!(f, " Warnings:")?;
81 for warn in &self.warnings {
82 writeln!(f, " - {}", warn)?;
83 }
84 }
85
86 if !self.node_details.is_empty() {
87 writeln!(f, " Node Details:")?;
88 for detail in &self.node_details {
89 writeln!(f, " {:?}:", detail.node)?;
90 writeln!(
91 f,
92 " State site indices: {:?}",
93 detail.state_site_indices
94 )?;
95 writeln!(f, " Op site indices: {:?}", detail.op_site_indices)?;
96 writeln!(
97 f,
98 " State tensor indices: {:?}",
99 detail.state_tensor_indices
100 )?;
101 writeln!(f, " Op tensor indices: {:?}", detail.op_tensor_indices)?;
102 writeln!(f, " Common index count: {}", detail.common_index_count)?;
103 }
104 }
105
106 Ok(())
107 }
108}
109
110pub struct SquareLinsolveUpdater<T, V>
122where
123 T: TensorLike,
124 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
125{
126 pub projected_operator: Arc<RwLock<ProjectedOperator<T, V>>>,
128 pub projected_state: ProjectedState<T, V>,
130 pub options: LinsolveOptions,
132 reference_state: TreeTN<T, V>,
136 boundary_bond_map: HashMap<(V, V), T::Index>,
139 did_ref_bra_ket_precheck: bool,
141 did_mpo_validation: bool,
143}
144
145impl<T, V> SquareLinsolveUpdater<T, V>
146where
147 T: TensorLike + 'static,
148 T::Index: IndexLike,
149 <T::Index as IndexLike>::Id:
150 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
151 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug + 'static,
152{
153 pub fn new(operator: TreeTN<T, V>, rhs: TreeTN<T, V>, options: LinsolveOptions) -> Self {
157 Self {
158 projected_operator: Arc::new(RwLock::new(ProjectedOperator::new(operator))),
159 projected_state: ProjectedState::new(rhs),
160 options,
161 reference_state: TreeTN::new(),
162 boundary_bond_map: HashMap::new(),
163 did_ref_bra_ket_precheck: false,
164 did_mpo_validation: false,
165 }
166 }
167
168 pub fn with_index_mappings(
182 operator: TreeTN<T, V>,
183 input_mapping: HashMap<V, IndexMapping<T::Index>>,
184 output_mapping: HashMap<V, IndexMapping<T::Index>>,
185 rhs: TreeTN<T, V>,
186 options: LinsolveOptions,
187 ) -> Self {
188 let projected_operator =
189 ProjectedOperator::with_index_mappings(operator, input_mapping, output_mapping);
190 Self {
191 projected_operator: Arc::new(RwLock::new(projected_operator)),
192 projected_state: ProjectedState::new(rhs),
193 options,
194 reference_state: TreeTN::new(),
195 boundary_bond_map: HashMap::new(),
196 did_ref_bra_ket_precheck: false,
197 did_mpo_validation: false,
198 }
199 }
200
201 fn ensure_reference_state_initialized(&mut self, ket_state: &TreeTN<T, V>) -> Result<()> {
208 if !self.reference_state.node_names().is_empty() {
210 return Ok(());
211 }
212
213 let mut reference_state = ket_state.clone();
216
217 reference_state.sim_linkinds_mut()?;
221
222 self.boundary_bond_map.clear();
224
225 self.reference_state = reference_state;
226 Ok(())
227 }
228
229 pub fn verify(&self, state: &TreeTN<T, V>) -> Result<LinsolveVerifyReport<V>> {
238 let mut report = LinsolveVerifyReport::default();
239
240 let proj_op = self
241 .projected_operator
242 .read()
243 .map_err(|e| {
244 anyhow::anyhow!("Failed to acquire read lock on projected_operator: {}", e)
245 })
246 .context("verify: lock poisoned")?;
247 let operator = &proj_op.operator;
248 let rhs = &self.projected_state.rhs;
249
250 let state_nodes: std::collections::BTreeSet<_> = state
252 .site_index_network()
253 .node_names()
254 .into_iter()
255 .collect();
256 let op_nodes: std::collections::BTreeSet<_> = operator
257 .site_index_network()
258 .node_names()
259 .into_iter()
260 .collect();
261 let rhs_nodes: std::collections::BTreeSet<_> =
262 rhs.site_index_network().node_names().into_iter().collect();
263
264 if state_nodes != op_nodes {
265 report.errors.push(format!(
266 "State and operator have different node sets. State: {:?}, Operator: {:?}",
267 state_nodes, op_nodes
268 ));
269 }
270
271 if state_nodes != rhs_nodes {
272 report.errors.push(format!(
273 "State and RHS have different node sets. State: {:?}, RHS: {:?}",
274 state_nodes, rhs_nodes
275 ));
276 }
277
278 for node in &state_nodes {
280 let state_site = state.site_space(node);
281 let op_site = operator.site_space(node);
282
283 if let Some(state_idx) = state.node_index(node) {
285 if let Some(state_tensor) = state.tensor(state_idx) {
286 let state_indices_vec = state_tensor.external_indices();
287 let state_indices: Vec<_> = state_indices_vec
288 .iter()
289 .map(|idx| (idx.id().clone(), idx.dim()))
290 .collect();
291
292 if let Some(op_idx) = operator.node_index(node) {
294 if let Some(op_tensor) = operator.tensor(op_idx) {
295 let op_indices_vec = op_tensor.external_indices();
296 let op_indices: Vec<_> = op_indices_vec
297 .iter()
298 .map(|idx| (idx.id().clone(), idx.dim()))
299 .collect();
300
301 let common_count = state_indices
303 .iter()
304 .filter(|(id, _)| op_indices.iter().any(|(oid, _)| oid == id))
305 .count();
306
307 report.node_details.push(NodeVerifyDetail {
308 node: (*node).clone(),
309 state_site_indices: state_site
310 .map(|s| s.iter().map(|i| format!("{:?}", i.id())).collect())
311 .unwrap_or_default(),
312 op_site_indices: op_site
313 .map(|s| s.iter().map(|i| format!("{:?}", i.id())).collect())
314 .unwrap_or_default(),
315 state_tensor_indices: state_indices
316 .iter()
317 .map(|(id, dim)| format!("{:?}(dim={})", id, dim))
318 .collect(),
319 op_tensor_indices: op_indices
320 .iter()
321 .map(|(id, dim)| format!("{:?}(dim={})", id, dim))
322 .collect(),
323 common_index_count: common_count,
324 });
325
326 if common_count == 0 {
330 report.warnings.push(format!(
331 "Node {:?}: No common indices between state and operator tensors. \
332 State has {:?}, operator has {:?}",
333 node, state_indices, op_indices
334 ));
335 }
336 }
337 }
338 }
339 }
340 }
341
342 report.is_valid = report.errors.is_empty();
344
345 Ok(report)
346 }
347
348 fn contract_region(&self, subtree: &TreeTN<T, V>, region: &[V]) -> Result<T> {
350 if region.is_empty() {
351 return Err(anyhow::anyhow!("Region cannot be empty"));
352 }
353
354 let tensors: Vec<T> = region
356 .iter()
357 .map(|node| {
358 let idx = subtree
359 .node_index(node)
360 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node))?;
361 let tensor = subtree
362 .tensor(idx)
363 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node))?;
364 Ok(tensor.clone())
365 })
366 .collect::<Result<_>>()?;
367
368 let tensor_refs: Vec<&T> = tensors.iter().collect();
370 T::contract(&tensor_refs, AllowedPairs::All)
371 }
372
373 fn build_subtree_topology(
377 &self,
378 solved_tensor: &T,
379 region: &[V],
380 full_treetn: &TreeTN<T, V>,
381 ) -> Result<TreeTopology<V, <T::Index as IndexLike>::Id>> {
382 use std::collections::HashMap;
383
384 let mut nodes: HashMap<V, Vec<<T::Index as IndexLike>::Id>> = HashMap::new();
385 let mut edges: Vec<(V, V)> = Vec::new();
386
387 let solved_indices = solved_tensor.external_indices();
388
389 for node in region {
391 let mut ids = Vec::new();
392
393 if let Some(site_indices) = full_treetn.site_space(node) {
395 for site_idx in site_indices {
396 if solved_indices.iter().any(|idx| idx.id() == site_idx.id()) {
398 ids.push(site_idx.id().clone());
399 }
400 }
401 }
402
403 for neighbor in full_treetn.site_index_network().neighbors(node) {
405 if !region.contains(&neighbor) {
406 if let Some(edge) = full_treetn.edge_between(node, &neighbor) {
408 if let Some(bond) = full_treetn.bond_index(edge) {
409 if solved_indices.iter().any(|idx| idx.id() == bond.id()) {
410 ids.push(bond.id().clone());
411 }
412 }
413 }
414 }
415 }
416
417 nodes.insert(node.clone(), ids);
418 }
419
420 for (i, node_a) in region.iter().enumerate() {
422 for node_b in region.iter().skip(i + 1) {
423 if full_treetn.edge_between(node_a, node_b).is_some() {
424 edges.push((node_a.clone(), node_b.clone()));
425 }
426 }
427 }
428
429 Ok(TreeTopology::new(nodes, edges))
430 }
431
432 fn copy_decomposed_to_subtree(
434 &self,
435 subtree: &mut TreeTN<T, V>,
436 decomposed: &TreeTN<T, V>,
437 region: &[V],
438 full_treetn: &TreeTN<T, V>,
439 ) -> Result<()> {
440 use std::collections::HashMap;
441
442 let mut bond_mapping: HashMap<<T::Index as IndexLike>::Id, T::Index> = HashMap::new();
446
447 for (i, node_a) in region.iter().enumerate() {
448 for node_b in region.iter().skip(i + 1) {
449 if let Some(decomp_edge) = decomposed.edge_between(node_a, node_b) {
451 if let Some(decomp_bond) = decomposed.bond_index(decomp_edge) {
452 if let Some(orig_edge) = subtree.edge_between(node_a, node_b) {
455 let new_bond = decomp_bond.sim();
456 bond_mapping.insert(decomp_bond.id().clone(), new_bond.clone());
457
458 subtree.replace_edge_bond(orig_edge, new_bond)?;
460 }
461 }
462 }
463 }
464 }
465
466 for node in region {
468 let decomp_idx = decomposed
469 .node_index(node)
470 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in decomposed TreeTN", node))?;
471 let mut new_tensor = decomposed
472 .tensor(decomp_idx)
473 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node))?
474 .clone();
475
476 for neighbor in full_treetn.site_index_network().neighbors(node) {
478 if region.contains(&neighbor) {
479 if let Some(decomp_edge) = decomposed.edge_between(node, &neighbor) {
481 if let Some(decomp_bond) = decomposed.bond_index(decomp_edge) {
482 if let Some(new_bond) = bond_mapping.get(decomp_bond.id()) {
483 new_tensor = new_tensor.replaceind(decomp_bond, new_bond)?;
484 }
485 }
486 }
487 }
488 }
489
490 let subtree_idx = subtree
492 .node_index(node)
493 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node))?;
494 subtree.replace_tensor(subtree_idx, new_tensor)?;
495 }
496
497 Ok(())
498 }
499
500 fn solve_local(&mut self, region: &[V], init: &T, state: &TreeTN<T, V>) -> Result<T> {
504 let topology = state.site_index_network();
506
507 let rhs_local_raw = self
509 .projected_state
510 .local_constant_term(region, state, topology)?;
511
512 let init_indices = init.external_indices();
519 let rhs_indices = rhs_local_raw.external_indices();
520
521 let rhs_local = if self.index_sets_match(&init_indices, &rhs_indices) {
522 let indices_match = init_indices
524 .iter()
525 .zip(rhs_indices.iter())
526 .all(|(ii, ri)| ii.id() == ri.id() && ii.dim() == ri.dim());
527 if indices_match {
528 rhs_local_raw
529 } else {
530 rhs_local_raw.permuteinds(&init_indices)?
533 }
534 } else {
535 return Err(anyhow::anyhow!(
536 "{}",
537 self.index_structure_mismatch_message(
538 &init_indices,
539 &rhs_indices,
540 "Index structure mismatch between init and RHS (local tensors)",
541 "This suggests:\n - ProjectedState environment construction may have contracted/left open unexpected indices\n - External indices may not be properly aligned between x and b\n - AllowedPairs::All may have over-contracted external indices in the environment\n\nSee `plan/linsolve-mpo.md` for analysis of external index handling.",
542 )
543 ));
544 };
545
546 let a0 = AnyScalar::new_real(self.options.a0);
548 let a1 = AnyScalar::new_real(self.options.a1);
549
550 let linop = LocalLinOp::new(
553 Arc::clone(&self.projected_operator),
554 region.to_vec(),
555 state.clone(),
556 self.reference_state.clone(),
557 a0,
558 a1,
559 );
560
561 let apply_a = |x: &T| linop.apply(x);
563
564 let gmres_options = GmresOptions {
566 max_iter: self.options.krylov_dim,
567 rtol: self.options.krylov_tol,
568 max_restarts: (self.options.krylov_maxiter / self.options.krylov_dim).max(1),
569 verbose: false,
570 check_true_residual: false,
571 };
572
573 let result = gmres(apply_a, &rhs_local, init, &gmres_options)?;
575
576 Ok(result.solution)
577 }
578
579 fn sync_reference_state_region(
584 &mut self,
585 step: &LocalUpdateStep<V>,
586 ket_state: &TreeTN<T, V>,
587 ) -> Result<()> {
588 let ket_region = ket_state.extract_subtree(&step.nodes)?;
590
591 let mut ket_to_ref_bond_map: HashMap<<T::Index as IndexLike>::Id, T::Index> =
597 HashMap::new();
598
599 let region_nodes: std::collections::HashSet<_> = step.nodes.iter().collect();
604 for node in &step.nodes {
605 for neighbor in ket_state.site_index_network().neighbors(node) {
606 let ket_edge = match ket_state.edge_between(node, &neighbor) {
607 Some(e) => e,
608 None => continue,
609 };
610 let ket_bond = match ket_state.bond_index(ket_edge) {
611 Some(b) => b,
612 None => continue,
613 };
614
615 let ref_bond = if region_nodes.contains(&neighbor) {
616 ket_bond.sim()
618 } else {
619 let ref_edge = match self.reference_state.edge_between(node, &neighbor) {
621 Some(e) => e,
622 None => continue,
623 };
624 match self.reference_state.bond_index(ref_edge) {
625 Some(b) => b.clone(),
626 None => continue,
627 }
628 };
629 ket_to_ref_bond_map.insert(ket_bond.id().clone(), ref_bond);
630 }
631 }
632
633 for boundary_edge in get_boundary_edges(ket_state, &step.nodes)? {
636 if let Some(edge) = self.reference_state.edge_between(
637 &boundary_edge.node_in_region,
638 &boundary_edge.neighbor_outside,
639 ) {
640 if let Some(ref_bond) = self.reference_state.bond_index(edge) {
641 self.boundary_bond_map.insert(
642 (
643 boundary_edge.node_in_region.clone(),
644 boundary_edge.neighbor_outside.clone(),
645 ),
646 ref_bond.clone(),
647 );
648 }
649 }
650 }
651
652 let mut ref_region = ket_region.clone();
654
655 let mut edges_to_update: Vec<(V, V, T::Index)> = Vec::new();
658 for node in &step.nodes {
659 let neighbors: Vec<V> = ref_region.site_index_network().neighbors(node).collect();
660 for neighbor in neighbors {
661 if let Some(edge) = ref_region.edge_between(node, &neighbor) {
662 if let Some(bond) = ref_region.bond_index(edge) {
663 if let Some(new_bond) = ket_to_ref_bond_map.get(bond.id()) {
664 edges_to_update.push((node.clone(), neighbor, new_bond.clone()));
665 }
666 }
667 }
668 }
669 }
670 for (node, neighbor, new_bond) in edges_to_update {
672 if let Some(edge) = ref_region.edge_between(&node, &neighbor) {
673 ref_region.replace_edge_bond(edge, new_bond)?;
674 }
675 }
676
677 for node in &step.nodes {
679 if let Some(node_idx) = ref_region.node_index(node) {
680 if let Some(tensor) = ref_region.tensor(node_idx) {
681 let mut new_tensor = tensor.clone();
682 let tensor_indices = tensor.external_indices();
683
684 for ket_idx in &tensor_indices {
685 if let Some(ref_bond) = ket_to_ref_bond_map.get(ket_idx.id()) {
687 new_tensor = new_tensor.replaceind(ket_idx, ref_bond)?;
688 }
689 }
691
692 ref_region.replace_tensor(node_idx, new_tensor)?;
693 }
694 }
695 }
696
697 self.reference_state
699 .replace_subtree(&step.nodes, &ref_region)?;
700
701 Ok(())
702 }
703}
704
705impl<T, V> LocalUpdater<T, V> for SquareLinsolveUpdater<T, V>
706where
707 T: TensorLike + 'static,
708 T::Index: IndexLike,
709 <T::Index as IndexLike>::Id:
710 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
711 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug + 'static,
712{
713 fn before_step(
714 &mut self,
715 step: &LocalUpdateStep<V>,
716 full_treetn_before: &TreeTN<T, V>,
717 ) -> Result<()> {
718 self.ensure_reference_state_initialized(full_treetn_before)?;
720
721 if !self.did_ref_bra_ket_precheck {
724 self.precheck_ref_bra_ket_convention(step, full_treetn_before)?;
725 self.did_ref_bra_ket_precheck = true;
726 }
727
728 if !self.did_mpo_validation {
730 self.validate_mpo_external_indices(full_treetn_before)?;
731 self.did_mpo_validation = true;
732 }
733 Ok(())
734 }
735
736 fn update(
737 &mut self,
738 mut subtree: TreeTN<T, V>,
739 step: &LocalUpdateStep<V>,
740 full_treetn: &TreeTN<T, V>,
741 ) -> Result<TreeTN<T, V>> {
742 let init_local = self.contract_region(&subtree, &step.nodes)?;
744 let solved_local = self.solve_local(&step.nodes, &init_local, full_treetn)?;
746
747 let topology = self.build_subtree_topology(&solved_local, &step.nodes, full_treetn)?;
749
750 let mut factorize_options = FactorizeOptions::svd();
752 if let Some(max_rank) = self.options.truncation.max_rank() {
753 factorize_options = factorize_options.with_max_rank(max_rank);
754 }
755 if let Some(policy) = self.options.truncation.svd_policy() {
756 factorize_options = factorize_options.with_svd_policy(policy);
757 }
758 let decomposed = factorize_tensor_to_treetn_with(
761 &solved_local,
762 &topology,
763 factorize_options,
764 &step.new_center,
765 )?;
766
767 self.copy_decomposed_to_subtree(&mut subtree, &decomposed, &step.nodes, full_treetn)?;
769
770 subtree.set_canonical_region([step.new_center.clone()])?;
774 if let Some(edges) = subtree.edges_to_canonicalize_by_names(&step.new_center) {
775 for (from, to) in edges {
776 if let Some(edge) = subtree.edge_between(&from, &to) {
777 subtree.set_edge_ortho_towards(edge, Some(to))?;
778 }
779 }
780 }
781
782 Ok(subtree)
783 }
784
785 fn after_step(
786 &mut self,
787 step: &LocalUpdateStep<V>,
788 full_treetn_after: &TreeTN<T, V>,
789 ) -> Result<()> {
790 let topology = full_treetn_after.site_index_network();
792
793 self.sync_reference_state_region(step, full_treetn_after)?;
796
797 {
799 let mut proj_op = self
800 .projected_operator
801 .write()
802 .map_err(|e| anyhow::anyhow!("Failed to acquire write lock: {}", e))
803 .context("after_step: lock poisoned")?;
804 proj_op.invalidate(&step.nodes, topology);
805 }
806 self.projected_state.invalidate(&step.nodes, topology);
807
808 Ok(())
809 }
810}
811
812impl<T, V> SquareLinsolveUpdater<T, V>
813where
814 T: TensorLike + 'static,
815 T::Index: IndexLike,
816 <T::Index as IndexLike>::Id:
817 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync + 'static,
818 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug + 'static,
819{
820 fn index_sets_match(&self, init_indices: &[T::Index], rhs_indices: &[T::Index]) -> bool {
821 if init_indices.len() != rhs_indices.len() {
822 return false;
823 }
824 let init_ids: std::collections::HashSet<_> = init_indices.iter().map(|i| i.id()).collect();
825 let rhs_ids: std::collections::HashSet<_> = rhs_indices.iter().map(|i| i.id()).collect();
826 init_ids == rhs_ids
827 }
828
829 fn index_structure_mismatch_message(
830 &self,
831 init_indices: &[T::Index],
832 rhs_indices: &[T::Index],
833 header: &str,
834 footer: &str,
835 ) -> String {
836 let init_ids: std::collections::HashSet<_> = init_indices.iter().map(|i| i.id()).collect();
837 let rhs_ids: std::collections::HashSet<_> = rhs_indices.iter().map(|i| i.id()).collect();
838 let extra_in_rhs: Vec<_> = rhs_ids
839 .difference(&init_ids)
840 .map(|id| {
841 rhs_indices
842 .iter()
843 .find(|i| i.id() == *id)
844 .map(|i| format!("{:?}:{}", id, i.dim()))
845 .unwrap_or_else(|| format!("{:?}:?", id))
846 })
847 .collect();
848 let missing_in_rhs: Vec<_> = init_ids
849 .difference(&rhs_ids)
850 .map(|id| {
851 init_indices
852 .iter()
853 .find(|i| i.id() == *id)
854 .map(|i| format!("{:?}:{}", id, i.dim()))
855 .unwrap_or_else(|| format!("{:?}:?", id))
856 })
857 .collect();
858
859 format!(
860 "{header}:\n init has {} indices: {:?}\n rhs has {} indices: {:?}\n extra in rhs (not in init): {:?}\n missing in rhs (in init but not in rhs): {:?}\n\n{footer}",
861 init_indices.len(),
862 init_indices
863 .iter()
864 .map(|i| format!("{:?}:{}", i.id(), i.dim()))
865 .collect::<Vec<_>>(),
866 rhs_indices.len(),
867 rhs_indices
868 .iter()
869 .map(|i| format!("{:?}:{}", i.id(), i.dim()))
870 .collect::<Vec<_>>(),
871 extra_in_rhs,
872 missing_in_rhs,
873 )
874 }
875
876 fn precheck_ref_bra_ket_convention(
877 &mut self,
878 step: &LocalUpdateStep<V>,
879 full_treetn_before: &TreeTN<T, V>,
880 ) -> Result<()> {
881 let subtree = full_treetn_before.extract_subtree(&step.nodes)?;
882 let init_local = self.contract_region(&subtree, &step.nodes)?;
883
884 let topology = full_treetn_before.site_index_network();
885 let rhs_local_raw =
886 self.projected_state
887 .local_constant_term(&step.nodes, full_treetn_before, topology)?;
888
889 let init_indices = init_local.external_indices();
890 let rhs_indices = rhs_local_raw.external_indices();
891
892 if !self.index_sets_match(&init_indices, &rhs_indices) {
893 return Err(anyhow::anyhow!(
894 "{}",
895 self.index_structure_mismatch_message(
896 &init_indices,
897 &rhs_indices,
898 "linsolve precheck failed (local index structure mismatch)",
899 "This suggests `<ref|H|x>` vs `<ref|b>` conventions (or external-index contraction rules) are inconsistent for the current region. See `plan/linsolve-mpo.md` for analysis.",
900 )
901 ));
902 }
903
904 Ok(())
905 }
906
907 #[allow(clippy::type_complexity)]
908 fn validate_mpo_external_indices(&mut self, state: &TreeTN<T, V>) -> Result<()> {
909 let (input_mapping, output_mapping): (
911 HashMap<V, IndexMapping<T::Index>>,
912 HashMap<V, IndexMapping<T::Index>>,
913 ) = {
914 let proj_op = self.projected_operator.read().map_err(|e| {
915 anyhow::anyhow!("validate_mpo_external_indices: lock poisoned: {e}")
916 })?;
917 let Some(input) = proj_op.input_mapping.as_ref() else {
918 return Ok(());
919 };
920 let Some(output) = proj_op.output_mapping.as_ref() else {
921 return Ok(());
922 };
923 (input.clone(), output.clone())
924 };
925
926 for node in state.node_names() {
927 let Some(x_sites) = state.site_space(&node) else {
928 continue;
929 };
930 let Some(b_sites) = self.projected_state.rhs.site_space(&node) else {
931 continue;
932 };
933
934 if x_sites.len() != 2 || b_sites.len() != 2 {
936 continue;
937 }
938
939 let x_contracted = input_mapping
940 .get(&node)
941 .ok_or_else(|| {
942 anyhow::anyhow!("MPO validation: missing input_mapping for node {:?}", node)
943 })?
944 .true_index
945 .clone();
946 let b_contracted = output_mapping
947 .get(&node)
948 .ok_or_else(|| {
949 anyhow::anyhow!("MPO validation: missing output_mapping for node {:?}", node)
950 })?
951 .true_index
952 .clone();
953
954 let x_external: Vec<_> = x_sites
955 .iter()
956 .filter(|idx| !idx.same_id(&x_contracted))
957 .cloned()
958 .collect();
959 let b_external: Vec<_> = b_sites
960 .iter()
961 .filter(|idx| !idx.same_id(&b_contracted))
962 .cloned()
963 .collect();
964
965 if x_external.len() != 1 || b_external.len() != 1 {
966 return Err(anyhow::anyhow!(
967 "MPO validation: expected exactly 1 external site index after removing contracted index. node={:?}, x_site_len={}, b_site_len={}, x_external={:?}, b_external={:?}",
968 node,
969 x_sites.len(),
970 b_sites.len(),
971 x_external.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::<Vec<_>>(),
972 b_external.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::<Vec<_>>(),
973 ));
974 }
975
976 let x_ext = &x_external[0];
977 let b_ext = &b_external[0];
978 if !x_ext.same_id(b_ext) || x_ext.dim() != b_ext.dim() {
979 return Err(anyhow::anyhow!(
980 "MPO validation: external index mismatch at node {:?}: x has {:?}:{}, b has {:?}:{}",
981 node,
982 x_ext.id(),
983 x_ext.dim(),
984 b_ext.id(),
985 b_ext.dim(),
986 ));
987 }
988 }
989
990 Ok(())
991 }
992}