1#![allow(dead_code)]
8
9mod addition;
10mod canonicalize;
11pub mod contraction;
12mod decompose;
13mod fit;
14mod localupdate;
15mod operator_impl;
16mod ops;
17pub mod partial_contraction;
18mod restructure;
19mod swap;
20mod tensor_like;
21mod transform;
22mod truncate;
23
24use petgraph::stable_graph::{EdgeIndex, NodeIndex};
25use petgraph::visit::{Dfs, EdgeRef};
26use std::collections::HashMap;
27use std::collections::HashSet;
28use std::hash::Hash;
29
30use anyhow::{Context, Result};
31
32use crate::algorithm::CanonicalForm;
33use tensor4all_core::{AllowedPairs, Canonical, FactorizeOptions, IndexLike, TensorLike};
34
35use crate::named_graph::NamedGraph;
36use crate::site_index_network::SiteIndexNetwork;
37
38pub use decompose::{factorize_tensor_to_treetn, factorize_tensor_to_treetn_with, TreeTopology};
40
41pub use localupdate::{
43 apply_local_update_sweep, get_boundary_edges, BoundaryEdge, LocalUpdateStep,
44 LocalUpdateSweepPlan, LocalUpdater, TruncateUpdater,
45};
46
47pub use partial_contraction::{partial_contract, PartialContractionSpec};
49
50pub use swap::{ScheduledSwapStep, SwapOptions, SwapSchedule};
52
53pub struct TreeTN<T = tensor4all_core::TensorDynLen, V = NodeIndex>
105where
106 T: TensorLike,
107 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
108{
109 pub(crate) graph: NamedGraph<V, T, T::Index>,
112 pub(crate) canonical_region: HashSet<V>,
117 pub(crate) canonical_form: Option<CanonicalForm>,
121 pub(crate) site_index_network: SiteIndexNetwork<V, T::Index>,
124 pub(crate) link_index_network: crate::link_index_network::LinkIndexNetwork<T::Index>,
127 pub(crate) ortho_towards: HashMap<T::Index, V>,
134}
135
136#[derive(Debug)]
140pub(crate) struct SweepContext {
141 pub(crate) edges: Vec<(NodeIndex, NodeIndex)>,
144}
145
146impl<T, V> TreeTN<T, V>
151where
152 T: TensorLike,
153 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
154{
155 pub fn new() -> Self {
159 Self {
160 graph: NamedGraph::new(),
161 canonical_region: HashSet::new(),
162 canonical_form: None,
163 site_index_network: SiteIndexNetwork::new(),
164 link_index_network: crate::link_index_network::LinkIndexNetwork::new(),
165 ortho_towards: HashMap::new(),
166 }
167 }
168
169 pub fn from_tensors(tensors: Vec<T>, node_names: Vec<V>) -> Result<Self>
218 where
219 <T::Index as IndexLike>::Id:
220 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
221 V: Ord,
222 {
223 let treetn = Self::from_tensors_unchecked(tensors, node_names)?;
224
225 treetn.verify_internal_consistency().context(
227 "TreeTN::from_tensors: constructed TreeTN failed internal consistency check",
228 )?;
229
230 Ok(treetn)
231 }
232
233 fn from_tensors_unchecked(tensors: Vec<T>, node_names: Vec<V>) -> Result<Self>
236 where
237 <T::Index as IndexLike>::Id:
238 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
239 {
240 if tensors.len() != node_names.len() {
242 return Err(anyhow::anyhow!(
243 "Length mismatch: {} tensors but {} node names",
244 tensors.len(),
245 node_names.len()
246 ))
247 .context("TreeTN::from_tensors: tensors and node_names must have the same length");
248 }
249
250 let mut treetn = Self::new();
252
253 let mut node_indices = Vec::with_capacity(tensors.len());
255 for (tensor, node_name) in tensors.into_iter().zip(node_names) {
256 let node_idx = treetn.add_tensor_internal(node_name, tensor)?;
257 node_indices.push(node_idx);
258 }
259
260 #[allow(clippy::type_complexity)]
263 let mut index_map: HashMap<
264 <T::Index as IndexLike>::Id,
265 Vec<(NodeIndex, T::Index)>,
266 > = HashMap::new();
267
268 for node_idx in &node_indices {
269 let tensor = treetn
270 .tensor(*node_idx)
271 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_idx))?;
272
273 for index in tensor.external_indices() {
274 index_map
275 .entry(index.id().clone())
276 .or_insert_with(Vec::new)
277 .push((*node_idx, index.clone()));
278 }
279 }
280
281 for (index_id, nodes_with_index) in index_map {
284 match nodes_with_index.len() {
285 0 => unreachable!(),
286 1 => {
287 continue;
289 }
290 2 => {
291 let (node_a, index_a) = &nodes_with_index[0];
293 let (node_b, index_b) = &nodes_with_index[1];
294
295 treetn
296 .connect_internal(*node_a, index_a, *node_b, index_b)
297 .with_context(|| {
298 format!(
299 "Failed to connect nodes {:?} and {:?} via index ID {:?}",
300 node_a, node_b, index_id
301 )
302 })?;
303 }
304 n => {
305 return Err(anyhow::anyhow!(
307 "Index ID {:?} appears in {} tensors, but TreeTN requires exactly 2 (tree structure)",
308 index_id, n
309 ))
310 .context("TreeTN::from_tensors: each bond index must connect exactly 2 nodes");
311 }
312 }
313 }
314
315 Ok(treetn)
316 }
317
318 pub fn add_tensor(&mut self, node_name: V, tensor: T) -> Result<NodeIndex> {
325 self.add_tensor_internal(node_name, tensor)
326 }
327
328 pub fn add_tensor_auto_name(&mut self, tensor: T) -> NodeIndex
334 where
335 V: From<NodeIndex> + Into<NodeIndex>,
336 {
337 let temp_idx = self.graph.graph_mut().add_node(tensor.clone());
339 let node_name = V::from(temp_idx);
340
341 self.graph.graph_mut().remove_node(temp_idx);
343
344 self.add_tensor_internal(node_name, tensor)
346 .expect("add_tensor_internal failed for auto-named tensor")
347 }
348
349 pub fn connect(
362 &mut self,
363 node_a: NodeIndex,
364 index_a: &T::Index,
365 node_b: NodeIndex,
366 index_b: &T::Index,
367 ) -> Result<EdgeIndex> {
368 self.connect_internal(node_a, index_a, node_b, index_b)
369 }
370}
371
372impl<T, V> TreeTN<T, V>
377where
378 T: TensorLike,
379 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
380{
381 pub(crate) fn add_tensor_internal(&mut self, node_name: V, tensor: T) -> Result<NodeIndex> {
387 let physical_indices: HashSet<T::Index> = tensor.external_indices().into_iter().collect();
389
390 let node_idx = self
392 .graph
393 .add_node(node_name.clone(), tensor)
394 .map_err(|e| anyhow::anyhow!(e))?;
395
396 self.site_index_network
398 .add_node(node_name, physical_indices)
399 .map_err(|e| anyhow::anyhow!("Failed to add node to site_index_network: {}", e))?;
400
401 Ok(node_idx)
402 }
403
404 pub(crate) fn connect_internal(
408 &mut self,
409 node_a: NodeIndex,
410 index_a: &T::Index,
411 node_b: NodeIndex,
412 index_b: &T::Index,
413 ) -> Result<EdgeIndex> {
414 if index_a.id() != index_b.id() {
416 return Err(anyhow::anyhow!(
417 "Index IDs must match in Einsum mode: {:?} != {:?}",
418 index_a.id(),
419 index_b.id()
420 ))
421 .context("Failed to connect tensors");
422 }
423
424 if !self.graph.contains_node(node_a) || !self.graph.contains_node(node_b) {
426 return Err(anyhow::anyhow!("One or both nodes do not exist"))
427 .context("Failed to connect tensors");
428 }
429
430 let tensor_a = self
432 .tensor(node_a)
433 .ok_or_else(|| anyhow::anyhow!("Tensor for node_a not found"))?;
434 let tensor_b = self
435 .tensor(node_b)
436 .ok_or_else(|| anyhow::anyhow!("Tensor for node_b not found"))?;
437
438 let has_index_a = tensor_a.external_indices().iter().any(|idx| idx == index_a);
440 let has_index_b = tensor_b.external_indices().iter().any(|idx| idx == index_b);
441
442 if !has_index_a {
443 return Err(anyhow::anyhow!("Index not found in tensor_a"))
444 .context("Failed to connect: index_a must exist in tensor_a");
445 }
446 if !has_index_b {
447 return Err(anyhow::anyhow!("Index not found in tensor_b"))
448 .context("Failed to connect: index_b must exist in tensor_b");
449 }
450
451 let bond_index = tensor_a
453 .external_indices()
454 .iter()
455 .find(|idx| idx.same_id(index_a))
456 .unwrap()
457 .clone();
458
459 let node_name_a = self
461 .graph
462 .node_name(node_a)
463 .ok_or_else(|| anyhow::anyhow!("Node name for node_a not found"))?
464 .clone();
465 let node_name_b = self
466 .graph
467 .node_name(node_b)
468 .ok_or_else(|| anyhow::anyhow!("Node name for node_b not found"))?
469 .clone();
470
471 let edge_idx = self
473 .graph
474 .graph_mut()
475 .add_edge(node_a, node_b, bond_index.clone());
476
477 self.site_index_network
479 .add_edge(&node_name_a, &node_name_b)
480 .map_err(|e| anyhow::anyhow!("Failed to add edge to site_index_network: {}", e))?;
481
482 let _ = self
485 .site_index_network
486 .remove_site_index(&node_name_a, &bond_index);
487 let _ = self
488 .site_index_network
489 .remove_site_index(&node_name_b, &bond_index);
490
491 self.link_index_network.insert(edge_idx, &bond_index);
493
494 Ok(edge_idx)
495 }
496
497 pub(crate) fn prepare_sweep_to_center(
511 &mut self,
512 canonical_region: impl IntoIterator<Item = V>,
513 context_name: &str,
514 ) -> Result<Option<SweepContext>> {
515 self.validate_tree()
517 .with_context(|| format!("{}: graph must be a tree", context_name))?;
518
519 let canonical_region_v: Vec<V> = canonical_region.into_iter().collect();
521 self.set_canonical_region(canonical_region_v)
522 .with_context(|| format!("{}: failed to set canonical_region", context_name))?;
523
524 if self.canonical_region.is_empty() {
525 return Ok(None); }
527
528 let center_indices: HashSet<NodeIndex> = self
530 .canonical_region
531 .iter()
532 .filter_map(|name| self.graph.node_index(name))
533 .collect();
534
535 if !self.site_index_network.is_connected_subset(¢er_indices) {
537 return Err(anyhow::anyhow!(
538 "canonical_region is not connected: {} centers but not all reachable",
539 self.canonical_region.len()
540 ))
541 .with_context(|| {
542 format!(
543 "{}: canonical_region must form a connected subtree",
544 context_name
545 )
546 });
547 }
548
549 let canonicalize_edges = self
551 .site_index_network
552 .edges_to_canonicalize_to_region(¢er_indices);
553 let edges: Vec<(NodeIndex, NodeIndex)> = canonicalize_edges.into_iter().collect();
554
555 Ok(Some(SweepContext { edges }))
556 }
557
558 pub(crate) fn sweep_edge(
572 &mut self,
573 src: NodeIndex,
574 dst: NodeIndex,
575 factorize_options: &FactorizeOptions,
576 context_name: &str,
577 ) -> Result<()> {
578 let edge = {
580 let g = self.graph.graph();
581 g.edges_connecting(src, dst)
582 .next()
583 .ok_or_else(|| {
584 anyhow::anyhow!("No edge found between node {:?} and {:?}", src, dst)
585 })
586 .with_context(|| format!("{}: edge not found", context_name))?
587 .id()
588 };
589
590 let bond_on_src = self
592 .bond_index(edge)
593 .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge"))
594 .with_context(|| format!("{}: failed to get bond index on src", context_name))?
595 .clone();
596
597 let tensor_src = self
599 .tensor(src)
600 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", src))
601 .with_context(|| format!("{}: tensor not found", context_name))?;
602
603 let left_inds: Vec<T::Index> = tensor_src
605 .external_indices()
606 .iter()
607 .filter(|idx| idx.id() != bond_on_src.id())
608 .cloned()
609 .collect();
610
611 let tensor_external_indices = tensor_src.external_indices();
612 if left_inds.is_empty() {
613 let tensor_dst = self
614 .tensor(dst)
615 .ok_or_else(|| anyhow::anyhow!("Tensor not found for dst node {:?}", dst))
616 .with_context(|| format!("{}: dst tensor not found", context_name))?;
617
618 let src_norm = tensor_src.norm();
619 let updated_src_tensor = if src_norm > 0.0 {
620 tensor_src
621 .scale(tensor4all_core::AnyScalar::new_real(1.0 / src_norm))
622 .with_context(|| format!("{}: failed to normalize src tensor", context_name))?
623 } else {
624 tensor_src.clone()
625 };
626 let updated_dst_tensor = if src_norm > 0.0 {
627 tensor_dst
628 .scale(tensor4all_core::AnyScalar::new_real(src_norm))
629 .with_context(|| format!("{}: failed to scale dst tensor", context_name))?
630 } else {
631 tensor_dst.clone()
632 };
633
634 self.replace_tensor(src, updated_src_tensor)
635 .with_context(|| {
636 format!("{}: failed to replace tensor at src node", context_name)
637 })?;
638 self.replace_tensor(dst, updated_dst_tensor)
639 .with_context(|| {
640 format!("{}: failed to replace tensor at dst node", context_name)
641 })?;
642
643 let dst_name = self
644 .graph
645 .node_name(dst)
646 .ok_or_else(|| anyhow::anyhow!("Dst node name not found"))?
647 .clone();
648 self.set_edge_ortho_towards(edge, Some(dst_name))
649 .with_context(|| format!("{}: failed to set ortho_towards", context_name))?;
650
651 return Ok(());
652 }
653
654 if left_inds.len() == tensor_external_indices.len() {
655 return Err(anyhow::anyhow!(
656 "Cannot process node {:?}: need at least one left index and one right index",
657 src
658 ))
659 .with_context(|| format!("{}: invalid tensor rank for factorization", context_name));
660 }
661
662 let factorize_result = tensor_src
664 .factorize(&left_inds, factorize_options)
665 .map_err(|e| anyhow::anyhow!("Factorization failed: {}", e))
666 .with_context(|| format!("{}: factorization failed", context_name))?;
667
668 let left_tensor = factorize_result.left;
669 let right_tensor = factorize_result.right;
670
671 let tensor_dst = self
673 .tensor(dst)
674 .ok_or_else(|| anyhow::anyhow!("Tensor not found for dst node {:?}", dst))
675 .with_context(|| format!("{}: dst tensor not found", context_name))?;
676
677 let updated_dst_tensor = T::contract(&[tensor_dst, &right_tensor], AllowedPairs::All)
678 .with_context(|| {
679 format!(
680 "{}: failed to absorb right factor into dst tensor",
681 context_name
682 )
683 })?;
684
685 let new_bond_index = factorize_result.bond_index;
687 self.replace_edge_bond(edge, new_bond_index.clone())
688 .with_context(|| format!("{}: failed to update edge bond index", context_name))?;
689
690 self.replace_tensor(src, left_tensor)
692 .with_context(|| format!("{}: failed to replace tensor at src node", context_name))?;
693 self.replace_tensor(dst, updated_dst_tensor)
694 .with_context(|| format!("{}: failed to replace tensor at dst node", context_name))?;
695
696 let dst_name = self
698 .graph
699 .node_name(dst)
700 .ok_or_else(|| anyhow::anyhow!("Dst node name not found"))?
701 .clone();
702 self.set_edge_ortho_towards(edge, Some(dst_name))
703 .with_context(|| format!("{}: failed to set ortho_towards", context_name))?;
704
705 Ok(())
706 }
707
708 pub fn tensor(&self, node: NodeIndex) -> Option<&T> {
714 self.graph.graph().node_weight(node)
715 }
716
717 pub fn tensor_mut(&mut self, node: NodeIndex) -> Option<&mut T> {
719 self.graph.graph_mut().node_weight_mut(node)
720 }
721
722 pub fn replace_tensor(&mut self, node: NodeIndex, new_tensor: T) -> Result<Option<T>> {
729 if !self.graph.contains_node(node) {
731 return Ok(None);
732 }
733
734 let edges = self.edges_for_node(node);
736 let connection_indices: Vec<T::Index> = edges
737 .iter()
738 .filter_map(|(edge_idx, _neighbor)| self.bond_index(*edge_idx).cloned())
739 .collect();
740
741 let new_tensor_indices = new_tensor.external_indices();
743 let common = common_inds(&connection_indices, &new_tensor_indices);
744 if common.len() != connection_indices.len() {
745 return Err(anyhow::anyhow!(
746 "New tensor is missing {} connection index(es): found {} out of {} required indices",
747 connection_indices.len() - common.len(),
748 common.len(),
749 connection_indices.len()
750 ))
751 .context("replace_tensor: new tensor must contain all indices used in connections");
752 }
753
754 let node_name = self
756 .graph
757 .node_name(node)
758 .ok_or_else(|| anyhow::anyhow!("Node name not found"))?
759 .clone();
760
761 let connection_indices_set: HashSet<T::Index> =
763 connection_indices.iter().cloned().collect();
764 let new_physical_indices: HashSet<T::Index> = new_tensor_indices
765 .iter()
766 .filter(|idx| !connection_indices_set.contains(idx))
767 .cloned()
768 .collect();
769
770 let old_tensor = self
772 .graph
773 .graph_mut()
774 .node_weight_mut(node)
775 .map(|old| std::mem::replace(old, new_tensor));
776
777 self.site_index_network
780 .set_site_space(&node_name, new_physical_indices)
781 .map_err(|e| anyhow::anyhow!("Failed to update site_index_network: {}", e))?;
782
783 Ok(old_tensor)
784 }
785
786 pub fn bond_index(&self, edge: EdgeIndex) -> Option<&T::Index> {
788 self.graph.graph().edge_weight(edge)
789 }
790
791 pub fn bond_index_mut(&mut self, edge: EdgeIndex) -> Option<&mut T::Index> {
793 self.graph.graph_mut().edge_weight_mut(edge)
794 }
795
796 pub fn edges_for_node(&self, node: NodeIndex) -> Vec<(EdgeIndex, NodeIndex)> {
798 self.graph
799 .graph()
800 .edges(node)
801 .map(|edge| {
802 let target = edge.target();
803 (edge.id(), target)
804 })
805 .collect()
806 }
807
808 pub fn replace_edge_bond(&mut self, edge: EdgeIndex, new_bond_index: T::Index) -> Result<()> {
813 let (source, target) = self
815 .graph
816 .graph()
817 .edge_endpoints(edge)
818 .ok_or_else(|| anyhow::anyhow!("Edge does not exist"))?;
819
820 let old_bond_index = self
822 .bond_index(edge)
823 .ok_or_else(|| anyhow::anyhow!("Bond index not found"))?
824 .clone();
825
826 let node_name_a = self
828 .graph
829 .node_name(source)
830 .ok_or_else(|| anyhow::anyhow!("Node name for source not found"))?
831 .clone();
832 let node_name_b = self
833 .graph
834 .node_name(target)
835 .ok_or_else(|| anyhow::anyhow!("Node name for target not found"))?
836 .clone();
837
838 *self
840 .bond_index_mut(edge)
841 .ok_or_else(|| anyhow::anyhow!("Bond index not found"))? = new_bond_index.clone();
842
843 self.link_index_network
845 .replace_index(&old_bond_index, &new_bond_index, edge)
846 .map_err(|e| anyhow::anyhow!("{}", e))?;
847
848 if let Some(dir) = self.ortho_towards.remove(&old_bond_index) {
850 self.ortho_towards.insert(new_bond_index.clone(), dir);
851 }
852
853 if let Some(site_space_a) = self.site_index_network.site_space_mut(&node_name_a) {
857 site_space_a.insert(old_bond_index.clone());
858 site_space_a.remove(&new_bond_index);
859 }
860 if let Some(site_space_b) = self.site_index_network.site_space_mut(&node_name_b) {
861 site_space_b.insert(old_bond_index);
862 site_space_b.remove(&new_bond_index);
863 }
864
865 Ok(())
866 }
867
868 pub fn sim_linkinds(&self) -> Result<Self>
882 where
883 T::Index: IndexLike,
884 {
885 let mut result = self.clone();
886 result.sim_linkinds_mut()?;
887 Ok(result)
888 }
889
890 pub fn sim_linkinds_mut(&mut self) -> Result<()>
894 where
895 T::Index: IndexLike,
896 {
897 let edges: Vec<EdgeIndex> = self.graph.graph().edge_indices().collect();
899 for edge in edges {
900 let old_bond = self
901 .bond_index(edge)
902 .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge {:?}", edge))?
903 .clone();
904 let new_bond = old_bond.sim();
905
906 *self
908 .bond_index_mut(edge)
909 .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge {:?}", edge))? =
910 new_bond.clone();
911
912 let (node_a, node_b) = self
914 .graph
915 .graph()
916 .edge_endpoints(edge)
917 .ok_or_else(|| anyhow::anyhow!("Edge {:?} not found", edge))?;
918 for node in [node_a, node_b] {
919 let tensor = self
920 .tensor(node)
921 .ok_or_else(|| anyhow::anyhow!("Tensor not found"))?;
922 let old_in_tensor = tensor
923 .external_indices()
924 .iter()
925 .find(|idx| idx.id() == old_bond.id())
926 .ok_or_else(|| anyhow::anyhow!("Bond index not found in endpoint tensor"))?
927 .clone();
928 let new_tensor = tensor.replaceind(&old_in_tensor, &new_bond)?;
929 self.replace_tensor(node, new_tensor)?;
930 }
931
932 if let Some((key, dir)) = self
934 .ortho_towards
935 .iter()
936 .find(|(k, _)| k.id() == old_bond.id())
937 .map(|(k, v)| (k.clone(), v.clone()))
938 {
939 self.ortho_towards.remove(&key);
940 self.ortho_towards.insert(new_bond.clone(), dir);
941 }
942
943 self.link_index_network
945 .replace_index(&old_bond, &new_bond, edge)
946 .map_err(|e| anyhow::anyhow!("{}", e))?;
947 }
948 Ok(())
949 }
950
951 pub fn set_ortho_towards(&mut self, index: &T::Index, dir: Option<V>) {
959 match dir {
960 Some(node_name) => {
961 self.ortho_towards.insert(index.clone(), node_name);
962 }
963 None => {
964 self.ortho_towards.remove(index);
965 }
966 }
967 }
968
969 pub fn ortho_towards_for_index(&self, index: &T::Index) -> Option<&V> {
973 self.ortho_towards.get(index)
974 }
975
976 pub fn set_edge_ortho_towards(
983 &mut self,
984 edge: petgraph::stable_graph::EdgeIndex,
985 dir: Option<V>,
986 ) -> Result<()> {
987 let bond = self
989 .bond_index(edge)
990 .ok_or_else(|| anyhow::anyhow!("Edge does not exist"))?
991 .clone();
992
993 if let Some(ref node_name) = dir {
995 let (source, target) = self
996 .graph
997 .graph()
998 .edge_endpoints(edge)
999 .ok_or_else(|| anyhow::anyhow!("Edge does not exist"))?;
1000
1001 let source_name = self.graph.node_name(source);
1002 let target_name = self.graph.node_name(target);
1003
1004 if source_name != Some(node_name) && target_name != Some(node_name) {
1005 return Err(anyhow::anyhow!(
1006 "ortho_towards node {:?} must be one of the edge endpoints",
1007 node_name
1008 ))
1009 .context("set_edge_ortho_towards: invalid node");
1010 }
1011 }
1012
1013 self.set_ortho_towards(&bond, dir);
1014 Ok(())
1015 }
1016
1017 pub fn ortho_towards_node(&self, edge: petgraph::stable_graph::EdgeIndex) -> Option<&V> {
1021 self.bond_index(edge)
1022 .and_then(|bond| self.ortho_towards.get(bond))
1023 }
1024
1025 pub fn ortho_towards_node_index(
1029 &self,
1030 edge: petgraph::stable_graph::EdgeIndex,
1031 ) -> Option<NodeIndex> {
1032 self.ortho_towards_node(edge)
1033 .and_then(|name| self.graph.node_index(name))
1034 }
1035
1036 pub fn validate_tree(&self) -> Result<()> {
1042 let g = self.graph.graph();
1043 if g.node_count() == 0 {
1044 return Ok(()); }
1046
1047 let mut visited = std::collections::HashSet::new();
1049 let start_node = g
1050 .node_indices()
1051 .next()
1052 .ok_or_else(|| anyhow::anyhow!("Graph has no nodes"))?;
1053
1054 let mut dfs = Dfs::new(g, start_node);
1056 while let Some(node) = dfs.next(g) {
1057 visited.insert(node);
1058 }
1059
1060 if visited.len() != g.node_count() {
1061 return Err(anyhow::anyhow!(
1062 "Graph is not connected: {} nodes reachable out of {}",
1063 visited.len(),
1064 g.node_count()
1065 ))
1066 .context("validate_tree: graph must be connected");
1067 }
1068
1069 let node_count = g.node_count();
1071 let edge_count = g.edge_count();
1072
1073 if edge_count != node_count - 1 {
1074 return Err(anyhow::anyhow!(
1075 "Graph does not satisfy tree condition: {} edges != {} nodes - 1",
1076 edge_count,
1077 node_count
1078 ))
1079 .context("validate_tree: tree must have edges = nodes - 1");
1080 }
1081
1082 Ok(())
1083 }
1084
1085 pub fn node_count(&self) -> usize {
1087 self.graph.graph().node_count()
1088 }
1089
1090 pub fn edge_count(&self) -> usize {
1092 self.graph.graph().edge_count()
1093 }
1094
1095 pub fn node_index(&self, node_name: &V) -> Option<NodeIndex> {
1097 self.graph.node_index(node_name)
1098 }
1099
1100 pub fn rename_node(&mut self, old_name: &V, new_name: V) -> Result<()> {
1103 if old_name == &new_name {
1104 return Ok(());
1105 }
1106
1107 self.graph
1108 .rename_node(old_name, new_name.clone())
1109 .map_err(|e| anyhow::anyhow!(e))
1110 .context("rename_node: failed to rename graph node")?;
1111 self.site_index_network
1112 .rename_node(old_name, new_name.clone())
1113 .map_err(|e| anyhow::anyhow!(e))
1114 .context("rename_node: failed to rename site-index node")?;
1115
1116 if self.canonical_region.remove(old_name) {
1117 self.canonical_region.insert(new_name.clone());
1118 }
1119
1120 for target in self.ortho_towards.values_mut() {
1121 if target == old_name {
1122 *target = new_name.clone();
1123 }
1124 }
1125
1126 Ok(())
1127 }
1128
1129 pub fn edge_between(&self, node_a: &V, node_b: &V) -> Option<EdgeIndex> {
1133 let idx_a = self.graph.node_index(node_a)?;
1134 let idx_b = self.graph.node_index(node_b)?;
1135 self.graph
1136 .graph()
1137 .find_edge(idx_a, idx_b)
1138 .or_else(|| self.graph.graph().find_edge(idx_b, idx_a))
1139 }
1140
1141 pub fn node_indices(&self) -> Vec<NodeIndex> {
1143 self.graph.graph().node_indices().collect()
1144 }
1145
1146 pub fn node_names(&self) -> Vec<V> {
1148 self.graph
1149 .graph()
1150 .node_indices()
1151 .filter_map(|idx| self.graph.node_name(idx).cloned())
1152 .collect()
1153 }
1154
1155 pub fn edges_to_canonicalize_by_names(&self, target: &V) -> Option<Vec<(V, V)>> {
1170 self.site_index_network
1171 .edges_to_canonicalize_by_names(target)
1172 }
1173
1174 pub fn canonical_region(&self) -> &HashSet<V> {
1178 &self.canonical_region
1179 }
1180
1181 pub fn is_canonicalized(&self) -> bool {
1185 !self.canonical_region.is_empty()
1186 }
1187
1188 pub fn set_canonical_region(&mut self, region: impl IntoIterator<Item = V>) -> Result<()> {
1192 let region: HashSet<V> = region.into_iter().collect();
1193
1194 for node_name in ®ion {
1196 if !self.graph.has_node(node_name) {
1197 return Err(anyhow::anyhow!(
1198 "Node {:?} does not exist in the graph",
1199 node_name
1200 ))
1201 .context("set_canonical_region: all nodes must be valid");
1202 }
1203 }
1204
1205 self.canonical_region = region;
1206 Ok(())
1207 }
1208
1209 pub fn clear_canonical_region(&mut self) {
1213 self.canonical_region.clear();
1214 self.canonical_form = None;
1215 }
1216
1217 pub fn canonical_form(&self) -> Option<CanonicalForm> {
1221 self.canonical_form
1222 }
1223
1224 pub fn add_to_canonical_region(&mut self, node_name: V) -> Result<()> {
1228 if !self.graph.has_node(&node_name) {
1229 return Err(anyhow::anyhow!(
1230 "Node {:?} does not exist in the graph",
1231 node_name
1232 ))
1233 .context("add_to_canonical_region: node must be valid");
1234 }
1235 self.canonical_region.insert(node_name);
1236 Ok(())
1237 }
1238
1239 pub fn remove_from_canonical_region(&mut self, node_name: &V) -> bool {
1243 self.canonical_region.remove(node_name)
1244 }
1245
1246 pub fn site_index_network(&self) -> &SiteIndexNetwork<V, T::Index> {
1250 &self.site_index_network
1251 }
1252
1253 pub fn site_index_network_mut(&mut self) -> &mut SiteIndexNetwork<V, T::Index> {
1255 &mut self.site_index_network
1256 }
1257
1258 pub fn site_space(&self, node_name: &V) -> Option<&std::collections::HashSet<T::Index>> {
1260 self.site_index_network.site_space(node_name)
1261 }
1262
1263 pub fn site_space_mut(
1265 &mut self,
1266 node_name: &V,
1267 ) -> Option<&mut std::collections::HashSet<T::Index>> {
1268 self.site_index_network.site_space_mut(node_name)
1269 }
1270
1271 pub fn share_equivalent_site_index_network(&self, other: &Self) -> bool
1285 where
1286 <T::Index as IndexLike>::Id: Ord,
1287 {
1288 self.site_index_network
1289 .share_equivalent_site_index_network(&other.site_index_network)
1290 }
1291
1292 pub fn same_topology(&self, other: &Self) -> bool {
1300 self.site_index_network
1301 .topology()
1302 .same_topology(other.site_index_network.topology())
1303 }
1304
1305 pub fn same_appearance(&self, other: &Self) -> bool
1327 where
1328 <T::Index as IndexLike>::Id: Ord,
1329 V: Ord,
1330 {
1331 if !self.share_equivalent_site_index_network(other) {
1333 return false;
1334 }
1335
1336 let mut self_bond_ortho_count = 0;
1339 let mut other_bond_ortho_count = 0;
1340
1341 for node_name in self.node_names() {
1343 let self_neighbors: Vec<V> = self.site_index_network.neighbors(&node_name).collect();
1344
1345 for neighbor_name in self_neighbors {
1346 if node_name >= neighbor_name {
1348 continue;
1349 }
1350
1351 let self_edge = match self.edge_between(&node_name, &neighbor_name) {
1353 Some(e) => e,
1354 None => continue,
1355 };
1356 let self_bond = match self.bond_index(self_edge) {
1357 Some(b) => b,
1358 None => continue,
1359 };
1360
1361 let other_edge = match other.edge_between(&node_name, &neighbor_name) {
1363 Some(e) => e,
1364 None => return false, };
1366 let other_bond = match other.bond_index(other_edge) {
1367 Some(b) => b,
1368 None => return false,
1369 };
1370
1371 let self_ortho = self.ortho_towards.get(self_bond);
1373 let other_ortho = other.ortho_towards.get(other_bond);
1374
1375 match (self_ortho, other_ortho) {
1376 (None, None) => {} (Some(self_dir), Some(other_dir)) => {
1378 if self_dir != other_dir {
1380 return false;
1381 }
1382 self_bond_ortho_count += 1;
1383 other_bond_ortho_count += 1;
1384 }
1385 _ => return false, }
1387 }
1388 }
1389
1390 let self_total_bond_entries: usize = self
1394 .graph
1395 .graph()
1396 .edge_indices()
1397 .filter_map(|e| self.bond_index(e))
1398 .filter(|b| self.ortho_towards.contains_key(b))
1399 .count();
1400 let other_total_bond_entries: usize = other
1401 .graph
1402 .graph()
1403 .edge_indices()
1404 .filter_map(|e| other.bond_index(e))
1405 .filter(|b| other.ortho_towards.contains_key(b))
1406 .count();
1407
1408 if self_bond_ortho_count != self_total_bond_entries
1409 || other_bond_ortho_count != other_total_bond_entries
1410 {
1411 return false;
1412 }
1413
1414 true
1415 }
1416
1417 pub(crate) fn swap_on_edge(
1423 &mut self,
1424 node_a_idx: NodeIndex,
1425 node_b_idx: NodeIndex,
1426 a_side_sites: &HashSet<<T::Index as IndexLike>::Id>,
1427 b_side_sites: &HashSet<<T::Index as IndexLike>::Id>,
1428 factorize_options: &FactorizeOptions,
1429 ) -> Result<()>
1430 where
1431 <T::Index as IndexLike>::Id: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
1432 {
1433 let node_b_name = self
1434 .graph
1435 .node_name(node_b_idx)
1436 .ok_or_else(|| anyhow::anyhow!("swap_on_edge: node_b not found"))?
1437 .clone();
1438
1439 let edge = {
1440 let g = self.graph.graph();
1441 g.edges_connecting(node_a_idx, node_b_idx)
1442 .next()
1443 .ok_or_else(|| anyhow::anyhow!("swap_on_edge: no edge between nodes"))?
1444 .id()
1445 };
1446 let bond_ab = self
1447 .bond_index(edge)
1448 .ok_or_else(|| anyhow::anyhow!("swap_on_edge: bond not found"))?
1449 .clone();
1450
1451 let other_bond_ids_a: HashSet<<T::Index as IndexLike>::Id> = self
1453 .edges_for_node(node_a_idx)
1454 .iter()
1455 .filter_map(|(e, _)| self.bond_index(*e).cloned())
1456 .filter(|b| b.id() != bond_ab.id())
1457 .map(|b| b.id().to_owned())
1458 .collect();
1459 let other_bond_ids_b: HashSet<<T::Index as IndexLike>::Id> = self
1460 .edges_for_node(node_b_idx)
1461 .iter()
1462 .filter_map(|(e, _)| self.bond_index(*e).cloned())
1463 .filter(|b| b.id() != bond_ab.id())
1464 .map(|b| b.id().to_owned())
1465 .collect();
1466
1467 let tensor_a = self
1468 .tensor(node_a_idx)
1469 .ok_or_else(|| anyhow::anyhow!("swap_on_edge: tensor_a not found"))?
1470 .clone();
1471 let tensor_b = self
1472 .tensor(node_b_idx)
1473 .ok_or_else(|| anyhow::anyhow!("swap_on_edge: tensor_b not found"))?
1474 .clone();
1475
1476 let site_ids_a: HashSet<<T::Index as IndexLike>::Id> = tensor_a
1478 .external_indices()
1479 .iter()
1480 .filter(|i| i.id() != bond_ab.id() && !other_bond_ids_a.contains(i.id()))
1481 .map(|i| i.id().to_owned())
1482 .collect();
1483 let site_ids_b: HashSet<<T::Index as IndexLike>::Id> = tensor_b
1484 .external_indices()
1485 .iter()
1486 .filter(|i| i.id() != bond_ab.id() && !other_bond_ids_b.contains(i.id()))
1487 .map(|i| i.id().to_owned())
1488 .collect();
1489 let all_site_ids: HashSet<_> = site_ids_a.union(&site_ids_b).cloned().collect();
1490 let assigned_site_ids: HashSet<_> = a_side_sites.union(b_side_sites).cloned().collect();
1491
1492 if !a_side_sites.is_disjoint(b_side_sites) {
1493 return Err(anyhow::anyhow!(
1494 "swap_on_edge: a_side_sites and b_side_sites overlap"
1495 ));
1496 }
1497 if assigned_site_ids != all_site_ids {
1498 return Err(anyhow::anyhow!(
1499 "swap_on_edge: scheduled site partition does not match current edge sites"
1500 ));
1501 }
1502
1503 let tensor_ab = T::contract(&[&tensor_a, &tensor_b], AllowedPairs::All)
1504 .context("swap_on_edge: contract")?;
1505
1506 let ab_indices = tensor_ab.external_indices();
1507 let left_inds: Vec<T::Index> = ab_indices
1508 .iter()
1509 .filter(|i| other_bond_ids_a.contains(i.id()) || a_side_sites.contains(i.id()))
1510 .cloned()
1511 .collect();
1512
1513 let result =
1514 swap::factorize_or_trivial(&tensor_ab, &left_inds, &ab_indices, factorize_options)
1515 .context("swap_on_edge: factorize")?;
1516
1517 self.replace_edge_bond(edge, result.bond_index)
1518 .context("swap_on_edge: replace_edge_bond")?;
1519 self.replace_tensor(node_a_idx, result.left)
1520 .context("swap_on_edge: replace tensor_a")?;
1521 self.replace_tensor(node_b_idx, result.right)
1522 .context("swap_on_edge: replace tensor_b")?;
1523 self.set_edge_ortho_towards(edge, Some(node_b_name))
1524 .context("swap_on_edge: set_edge_ortho_towards")?;
1525
1526 Ok(())
1527 }
1528
1529 pub fn swap_site_indices(
1544 &mut self,
1545 target_assignment: &HashMap<<T::Index as IndexLike>::Id, V>,
1546 options: &swap::SwapOptions,
1547 ) -> Result<()>
1548 where
1549 <T::Index as IndexLike>::Id:
1550 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1551 V: Ord,
1552 {
1553 if target_assignment.is_empty() {
1554 return Ok(());
1555 }
1556
1557 let current = swap::current_site_assignment(self);
1558 let root = self
1559 .node_names()
1560 .into_iter()
1561 .min()
1562 .ok_or_else(|| anyhow::anyhow!("swap_site_indices: empty network"))?;
1563 let schedule = swap::SwapSchedule::build(
1564 self.site_index_network().topology(),
1565 ¤t,
1566 target_assignment,
1567 &root,
1568 )
1569 .context("swap_site_indices: build schedule")?;
1570
1571 if schedule.steps.is_empty() {
1572 return Ok(());
1573 }
1574
1575 self.canonicalize_mut(
1576 std::iter::once(schedule.root.clone()),
1577 crate::options::CanonicalizationOptions::default(),
1578 )
1579 .context("swap_site_indices: canonicalize")?;
1580
1581 let mut swap_factorize_options = FactorizeOptions::svd().with_canonical(Canonical::Left);
1582 if let Some(mr) = options.max_rank {
1583 swap_factorize_options = swap_factorize_options.with_max_rank(mr);
1584 }
1585 if let Some(rtol) = options.rtol {
1586 swap_factorize_options = swap_factorize_options
1587 .with_svd_policy(tensor4all_core::SvdTruncationPolicy::new(rtol));
1588 }
1589 let transport_factorize_options = FactorizeOptions::svd().with_canonical(Canonical::Left);
1590
1591 for step in &schedule.steps {
1592 for edge in step.transport_path.windows(2) {
1593 let src_name = &edge[0];
1594 let dst_name = &edge[1];
1595 let src_idx = self.node_index(src_name).ok_or_else(|| {
1596 anyhow::anyhow!("swap_site_indices: transport node {:?} not found", src_name)
1597 })?;
1598 let dst_idx = self.node_index(dst_name).ok_or_else(|| {
1599 anyhow::anyhow!("swap_site_indices: transport node {:?} not found", dst_name)
1600 })?;
1601 self.sweep_edge(
1602 src_idx,
1603 dst_idx,
1604 &transport_factorize_options,
1605 "swap_transport",
1606 )
1607 .context("swap_site_indices: transport")?;
1608 }
1609
1610 let a_idx = self.node_index(&step.node_a).ok_or_else(|| {
1611 anyhow::anyhow!("swap_site_indices: node {:?} not found", step.node_a)
1612 })?;
1613 let b_idx = self.node_index(&step.node_b).ok_or_else(|| {
1614 anyhow::anyhow!("swap_site_indices: node {:?} not found", step.node_b)
1615 })?;
1616 self.swap_on_edge(
1617 a_idx,
1618 b_idx,
1619 &step.a_side_sites,
1620 &step.b_side_sites,
1621 &swap_factorize_options,
1622 )
1623 .context("swap_site_indices: swap_on_edge")?;
1624 self.set_canonical_region([step.node_b.clone()])
1625 .context("swap_site_indices: set_canonical_region")?;
1626 }
1627
1628 Ok(())
1629 }
1630
1631 pub fn swap_site_indices_by_index(
1674 &mut self,
1675 target_assignment: &HashMap<T::Index, V>,
1676 options: &swap::SwapOptions,
1677 ) -> Result<()>
1678 where
1679 <T::Index as IndexLike>::Id:
1680 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1681 T::Index: Hash + Eq,
1682 V: Ord,
1683 {
1684 let id_assignment: HashMap<_, _> = target_assignment
1685 .iter()
1686 .map(|(idx, node)| (idx.id().clone(), node.clone()))
1687 .collect();
1688
1689 self.swap_site_indices(&id_assignment, options)
1690 }
1691
1692 pub fn verify_internal_consistency(&self) -> Result<()>
1714 where
1715 <T::Index as IndexLike>::Id:
1716 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1717 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
1718 {
1719 let num_nodes = self.graph.graph().node_count();
1722 if num_nodes > 1 {
1723 if let Some(start_node) = self.graph.graph().node_indices().next() {
1725 let mut dfs = Dfs::new(self.graph.graph(), start_node);
1726 let mut visited_count = 0;
1727 while dfs.next(self.graph.graph()).is_some() {
1728 visited_count += 1;
1729 }
1730 if visited_count != num_nodes {
1731 return Err(anyhow::anyhow!(
1732 "TreeTN is disconnected: DFS visited {} of {} nodes. All tensors must be connected.",
1733 visited_count,
1734 num_nodes
1735 ))
1736 .context("verify_internal_consistency: graph must be connected");
1737 }
1738 }
1739 }
1740
1741 let mut index_id_to_nodes: HashMap<<T::Index as IndexLike>::Id, Vec<NodeIndex>> =
1744 HashMap::new();
1745 for node_idx in self.graph.graph().node_indices() {
1746 if let Some(tensor) = self.tensor(node_idx) {
1747 for index in tensor.external_indices() {
1748 index_id_to_nodes
1749 .entry(index.id().clone())
1750 .or_default()
1751 .push(node_idx);
1752 }
1753 }
1754 }
1755
1756 for (index_id, nodes) in &index_id_to_nodes {
1758 if nodes.len() > 2 {
1759 return Err(anyhow::anyhow!(
1761 "Index ID {:?} is shared by {} nodes, but tree structure allows at most 2",
1762 index_id,
1763 nodes.len()
1764 ))
1765 .context("verify_internal_consistency: index ID shared by too many nodes");
1766 }
1767 if nodes.len() == 2 {
1768 let node_a = nodes[0];
1770 let node_b = nodes[1];
1771 if self.graph.graph().find_edge(node_a, node_b).is_none()
1772 && self.graph.graph().find_edge(node_b, node_a).is_none()
1773 {
1774 let name_a = self.graph.node_name(node_a);
1775 let name_b = self.graph.node_name(node_b);
1776 return Err(anyhow::anyhow!(
1777 "Non-adjacent nodes {:?} and {:?} share index ID {:?}. \
1778 Only adjacent (edge-connected) nodes may share index IDs.",
1779 name_a,
1780 name_b,
1781 index_id
1782 ))
1783 .context("verify_internal_consistency: non-adjacent nodes share index ID");
1784 }
1785 }
1786 }
1787
1788 let node_names: Vec<V> = self.node_names();
1790 let tensors: Vec<T> = node_names
1791 .iter()
1792 .filter_map(|name| {
1793 let idx = self.graph.node_index(name)?;
1794 self.tensor(idx).cloned()
1795 })
1796 .collect();
1797
1798 if tensors.len() != node_names.len() {
1799 return Err(anyhow::anyhow!(
1800 "Internal inconsistency: {} node names but {} tensors found",
1801 node_names.len(),
1802 tensors.len()
1803 ));
1804 }
1805
1806 let reconstructed = TreeTN::<T, V>::from_tensors_unchecked(tensors, node_names)
1809 .context("verify_internal_consistency: failed to reconstruct TreeTN")?;
1810
1811 if !self.same_topology(&reconstructed) {
1813 return Err(anyhow::anyhow!(
1814 "Internal inconsistency: topology does not match after reconstruction"
1815 ))
1816 .context("verify_internal_consistency: topology mismatch");
1817 }
1818
1819 if !self
1821 .site_index_network
1822 .share_equivalent_site_index_network(&reconstructed.site_index_network)
1823 {
1824 return Err(anyhow::anyhow!(
1825 "Internal inconsistency: site index network does not match after reconstruction"
1826 ))
1827 .context("verify_internal_consistency: site space mismatch");
1828 }
1829
1830 for node_name in self.node_names() {
1832 let idx_self = self
1833 .graph
1834 .node_index(&node_name)
1835 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in original", node_name))?;
1836 let idx_reconstructed =
1837 reconstructed.graph.node_index(&node_name).ok_or_else(|| {
1838 anyhow::anyhow!("Node {:?} not found in reconstructed", node_name)
1839 })?;
1840
1841 let tensor_self = self.tensor(idx_self).ok_or_else(|| {
1842 anyhow::anyhow!("Tensor not found for node {:?} in original", node_name)
1843 })?;
1844 let tensor_reconstructed =
1845 reconstructed.tensor(idx_reconstructed).ok_or_else(|| {
1846 anyhow::anyhow!("Tensor not found for node {:?} in reconstructed", node_name)
1847 })?;
1848
1849 let indices_self: HashSet<_> = tensor_self.external_indices().into_iter().collect();
1851 let indices_reconstructed: HashSet<_> = tensor_reconstructed
1852 .external_indices()
1853 .into_iter()
1854 .collect();
1855 if indices_self != indices_reconstructed {
1856 return Err(anyhow::anyhow!(
1857 "Internal inconsistency: tensor indices differ at node {:?}",
1858 node_name
1859 ))
1860 .context("verify_internal_consistency: tensor index mismatch");
1861 }
1862
1863 if tensor_self.num_external_indices() != tensor_reconstructed.num_external_indices() {
1865 return Err(anyhow::anyhow!(
1866 "Internal inconsistency: tensor dimensions differ at node {:?}: {} vs {}",
1867 node_name,
1868 tensor_self.num_external_indices(),
1869 tensor_reconstructed.num_external_indices()
1870 ))
1871 .context("verify_internal_consistency: tensor dimension mismatch");
1872 }
1873 }
1874
1875 Ok(())
1876 }
1877}
1878
1879pub(crate) fn common_inds<I: IndexLike>(inds_a: &[I], inds_b: &[I]) -> Vec<I> {
1885 let set_b: HashSet<_> = inds_b.iter().map(|idx| idx.id()).collect();
1886 inds_a
1887 .iter()
1888 .filter(|idx| set_b.contains(idx.id()))
1889 .cloned()
1890 .collect()
1891}