Skip to main content

tensor4all_treetn/treetn/
mod.rs

1//! Tree Tensor Network implementation.
2//!
3//! This module provides the [`TreeTN`] type, a tree-structured tensor network
4//! for efficient tensor operations with canonicalization and truncation support.
5
6// Some utility functions are WIP and not yet connected
7#![allow(dead_code)]
8
9mod addition;
10mod canonicalize;
11pub mod contraction;
12mod decompose;
13mod fit;
14mod localupdate;
15mod operator_impl;
16mod ops;
17pub mod partial_contraction;
18mod restructure;
19mod swap;
20mod tensor_like;
21mod transform;
22mod truncate;
23
24use petgraph::stable_graph::{EdgeIndex, NodeIndex};
25use petgraph::visit::{Dfs, EdgeRef};
26use std::collections::HashMap;
27use std::collections::HashSet;
28use std::hash::Hash;
29
30use anyhow::{Context, Result};
31
32use crate::algorithm::CanonicalForm;
33use tensor4all_core::{AllowedPairs, Canonical, FactorizeOptions, IndexLike, TensorLike};
34
35use crate::named_graph::NamedGraph;
36use crate::site_index_network::SiteIndexNetwork;
37
38// Re-export the decomposition functions and types
39pub use decompose::{factorize_tensor_to_treetn, factorize_tensor_to_treetn_with, TreeTopology};
40
41// Re-export local update types
42pub use localupdate::{
43    apply_local_update_sweep, get_boundary_edges, BoundaryEdge, LocalUpdateStep,
44    LocalUpdateSweepPlan, LocalUpdater, TruncateUpdater,
45};
46
47// Re-export partial contraction types
48pub use partial_contraction::{partial_contract, PartialContractionSpec};
49
50// Re-export swap types
51pub use swap::{ScheduledSwapStep, SwapOptions, SwapSchedule};
52
53/// Tree Tensor Network structure (inspired by ITensorNetworks.jl's TreeTensorNetwork).
54///
55/// Maintains a graph of tensors connected by bonds (edges).
56/// Each node stores a tensor, and edges store `Connection` objects
57/// that hold the bond index.
58///
59/// The structure uses SiteIndexNetwork to manage:
60/// - **Topology**: Graph structure (which nodes connect to which)
61/// - **Site Space**: Physical indices organized by node
62///
63/// # Type Parameters
64/// - `T`: Tensor type implementing `TensorLike` (default: `TensorDynLen`)
65/// - `V`: Node name type for named nodes (default: NodeIndex for backward compatibility)
66///
67/// # Construction
68///
69/// - `TreeTN::new()`: Create an empty network, then use `add_tensor()` and `connect()` to build.
70/// - `TreeTN::from_tensors(tensors, node_names)`: Create from tensors with auto-connection by matching index IDs.
71///
72/// # Examples
73///
74/// Build a 2-node chain manually and verify node count:
75///
76/// ```
77/// use tensor4all_treetn::TreeTN;
78/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
79///
80/// // Create site and bond indices
81/// let s0 = DynIndex::new_dyn(2);
82/// let bond = DynIndex::new_dyn(3);
83/// let s1 = DynIndex::new_dyn(2);
84///
85/// // Build tensors
86/// let t0 = TensorDynLen::from_dense(
87///     vec![s0.clone(), bond.clone()],
88///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
89/// ).unwrap();
90/// let t1 = TensorDynLen::from_dense(
91///     vec![bond.clone(), s1.clone()],
92///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
93/// ).unwrap();
94///
95/// // Use from_tensors (auto-connects nodes sharing the same index ID)
96/// let tn = TreeTN::<_, String>::from_tensors(
97///     vec![t0, t1],
98///     vec!["A".to_string(), "B".to_string()],
99/// ).unwrap();
100///
101/// assert_eq!(tn.node_count(), 2);
102/// assert_eq!(tn.edge_count(), 1);
103/// ```
104pub struct TreeTN<T = tensor4all_core::TensorDynLen, V = NodeIndex>
105where
106    T: TensorLike,
107    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
108{
109    /// Named graph wrapper: provides mapping between node names (V) and NodeIndex
110    /// Edges store the bond Index directly.
111    pub(crate) graph: NamedGraph<V, T, T::Index>,
112    /// Orthogonalization region (canonical_region).
113    /// When empty, the network is not canonicalized.
114    /// When non-empty, contains the node names (V) of the orthogonalization region.
115    /// The region must form a connected subtree in the network.
116    pub(crate) canonical_region: HashSet<V>,
117    /// Canonical form used for the current canonicalization.
118    /// `None` if not canonicalized (canonical_region is empty).
119    /// `Some(form)` if canonicalized with the specified form.
120    pub(crate) canonical_form: Option<CanonicalForm>,
121    /// Site index network: manages topology and site space (physical indices).
122    /// This structure enables topology and site space comparison independent of tensor data.
123    pub(crate) site_index_network: SiteIndexNetwork<V, T::Index>,
124    /// Link index network: manages bond/link indices with reverse lookup.
125    /// Provides O(1) lookup from index ID to edge.
126    pub(crate) link_index_network: crate::link_index_network::LinkIndexNetwork<T::Index>,
127    /// Orthogonalization direction for each index (bond or site).
128    /// Maps index to the node name (V) that the orthogonalization points towards.
129    /// - For bond indices: points towards the canonical center direction
130    /// - For site indices: points to the node that owns the index (always towards canonical center)
131    ///
132    /// Note: Uses the full index as the key (via `IndexLike: Eq + Hash`).
133    pub(crate) ortho_towards: HashMap<T::Index, V>,
134}
135
136/// Internal context for sweep-to-center operations.
137///
138/// Contains precomputed information needed for both canonicalization and truncation.
139#[derive(Debug)]
140pub(crate) struct SweepContext {
141    /// Edges to process, ordered from leaves towards center.
142    /// Each edge is (src, dst) where src is the node to factorize and dst is its parent.
143    pub(crate) edges: Vec<(NodeIndex, NodeIndex)>,
144}
145
146// ============================================================================
147// Construction methods
148// ============================================================================
149
150impl<T, V> TreeTN<T, V>
151where
152    T: TensorLike,
153    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
154{
155    /// Create a new empty TreeTN.
156    ///
157    /// Use `add_tensor()` to add tensors and `connect()` to establish bonds manually.
158    pub fn new() -> Self {
159        Self {
160            graph: NamedGraph::new(),
161            canonical_region: HashSet::new(),
162            canonical_form: None,
163            site_index_network: SiteIndexNetwork::new(),
164            link_index_network: crate::link_index_network::LinkIndexNetwork::new(),
165            ortho_towards: HashMap::new(),
166        }
167    }
168
169    /// Create a TreeTN from a list of tensors and node names using einsum rule.
170    ///
171    /// This function connects tensors that share common indices (by ID).
172    /// The algorithm is O(n) where n is the number of tensors:
173    /// 1. Add all tensors as nodes
174    /// 2. Build a map from index ID to (node, index) pairs in a single pass
175    /// 3. Connect nodes that share the same index ID
176    ///
177    /// # Arguments
178    /// * `tensors` - Vector of tensors to add to the network
179    /// * `node_names` - Vector of node names corresponding to each tensor
180    ///
181    /// # Returns
182    /// A new TreeTN with tensors connected by common indices, or an error if:
183    /// - The lengths of `tensors` and `node_names` don't match
184    /// - An index ID appears in more than 2 tensors (TreeTN is a tree, so each bond connects exactly 2 nodes)
185    /// - Connection fails (e.g., dimension mismatch)
186    ///
187    /// # Errors
188    /// Returns an error if validation fails or connection fails.
189    ///
190    /// # Examples
191    ///
192    /// ```
193    /// use tensor4all_treetn::TreeTN;
194    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
195    ///
196    /// let s0 = DynIndex::new_dyn(2);
197    /// let bond = DynIndex::new_dyn(3);
198    /// let s1 = DynIndex::new_dyn(2);
199    ///
200    /// let t0 = TensorDynLen::from_dense(
201    ///     vec![s0.clone(), bond.clone()],
202    ///     vec![1.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0],
203    /// ).unwrap();
204    /// let t1 = TensorDynLen::from_dense(
205    ///     vec![bond.clone(), s1.clone()],
206    ///     vec![1.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0],
207    /// ).unwrap();
208    ///
209    /// let tn = TreeTN::<_, String>::from_tensors(
210    ///     vec![t0, t1],
211    ///     vec!["A".to_string(), "B".to_string()],
212    /// ).unwrap();
213    ///
214    /// assert_eq!(tn.node_count(), 2);
215    /// assert_eq!(tn.edge_count(), 1);
216    /// ```
217    pub fn from_tensors(tensors: Vec<T>, node_names: Vec<V>) -> Result<Self>
218    where
219        <T::Index as IndexLike>::Id:
220            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
221        V: Ord,
222    {
223        let treetn = Self::from_tensors_unchecked(tensors, node_names)?;
224
225        // Verify structural constraints after construction
226        treetn.verify_internal_consistency().context(
227            "TreeTN::from_tensors: constructed TreeTN failed internal consistency check",
228        )?;
229
230        Ok(treetn)
231    }
232
233    /// Internal version of `from_tensors` that skips verification.
234    /// Used by `verify_internal_consistency` to avoid infinite recursion.
235    fn from_tensors_unchecked(tensors: Vec<T>, node_names: Vec<V>) -> Result<Self>
236    where
237        <T::Index as IndexLike>::Id:
238            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
239    {
240        // Validate input lengths
241        if tensors.len() != node_names.len() {
242            return Err(anyhow::anyhow!(
243                "Length mismatch: {} tensors but {} node names",
244                tensors.len(),
245                node_names.len()
246            ))
247            .context("TreeTN::from_tensors: tensors and node_names must have the same length");
248        }
249
250        // Create empty TreeTN
251        let mut treetn = Self::new();
252
253        // Step 1: Add all tensors as nodes and collect NodeIndex mappings
254        let mut node_indices = Vec::with_capacity(tensors.len());
255        for (tensor, node_name) in tensors.into_iter().zip(node_names) {
256            let node_idx = treetn.add_tensor_internal(node_name, tensor)?;
257            node_indices.push(node_idx);
258        }
259
260        // Step 2: Build a map from index ID to (node_index, index) pairs in O(n) time
261        // Key: index ID, Value: vector of (NodeIndex, Index) pairs
262        #[allow(clippy::type_complexity)]
263        let mut index_map: HashMap<
264            <T::Index as IndexLike>::Id,
265            Vec<(NodeIndex, T::Index)>,
266        > = HashMap::new();
267
268        for node_idx in &node_indices {
269            let tensor = treetn
270                .tensor(*node_idx)
271                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_idx))?;
272
273            for index in tensor.external_indices() {
274                index_map
275                    .entry(index.id().clone())
276                    .or_insert_with(Vec::new)
277                    .push((*node_idx, index.clone()));
278            }
279        }
280
281        // Step 3: Connect nodes that share the same index ID
282        // For TreeTN (tree structure), each index ID should appear in exactly 2 tensors
283        for (index_id, nodes_with_index) in index_map {
284            match nodes_with_index.len() {
285                0 => unreachable!(),
286                1 => {
287                    // Index appears in only one tensor - this is a physical index, no connection needed
288                    continue;
289                }
290                2 => {
291                    // Index appears in exactly 2 tensors - connect them
292                    let (node_a, index_a) = &nodes_with_index[0];
293                    let (node_b, index_b) = &nodes_with_index[1];
294
295                    treetn
296                        .connect_internal(*node_a, index_a, *node_b, index_b)
297                        .with_context(|| {
298                            format!(
299                                "Failed to connect nodes {:?} and {:?} via index ID {:?}",
300                                node_a, node_b, index_id
301                            )
302                        })?;
303                }
304                n => {
305                    // Index appears in more than 2 tensors - this violates tree structure
306                    return Err(anyhow::anyhow!(
307                        "Index ID {:?} appears in {} tensors, but TreeTN requires exactly 2 (tree structure)",
308                        index_id, n
309                    ))
310                    .context("TreeTN::from_tensors: each bond index must connect exactly 2 nodes");
311                }
312            }
313        }
314
315        Ok(treetn)
316    }
317
318    /// Add a tensor to the network with a node name.
319    ///
320    /// Returns the NodeIndex for the newly added tensor.
321    ///
322    /// Also updates the site_index_network with the physical indices (all indices initially,
323    /// as no connections exist yet).
324    pub fn add_tensor(&mut self, node_name: V, tensor: T) -> Result<NodeIndex> {
325        self.add_tensor_internal(node_name, tensor)
326    }
327
328    /// Add a tensor to the network using NodeIndex as the node name.
329    ///
330    /// This method only works when `V = NodeIndex`.
331    ///
332    /// Returns the NodeIndex for the newly added tensor.
333    pub fn add_tensor_auto_name(&mut self, tensor: T) -> NodeIndex
334    where
335        V: From<NodeIndex> + Into<NodeIndex>,
336    {
337        // We need to add with a temporary name first, then get the actual NodeIndex
338        let temp_idx = self.graph.graph_mut().add_node(tensor.clone());
339        let node_name = V::from(temp_idx);
340
341        // Remove the temporary node and add properly with name
342        self.graph.graph_mut().remove_node(temp_idx);
343
344        // Re-add with the correct name
345        self.add_tensor_internal(node_name, tensor)
346            .expect("add_tensor_internal failed for auto-named tensor")
347    }
348
349    /// Connect two tensors via a specified pair of indices.
350    ///
351    /// The indices must have the same ID (Einsum mode).
352    ///
353    /// # Arguments
354    /// * `node_a` - First node
355    /// * `index_a` - Index on first node to use for connection
356    /// * `node_b` - Second node
357    /// * `index_b` - Index on second node to use for connection
358    ///
359    /// # Returns
360    /// The EdgeIndex of the new connection, or an error if validation fails.
361    pub fn connect(
362        &mut self,
363        node_a: NodeIndex,
364        index_a: &T::Index,
365        node_b: NodeIndex,
366        index_b: &T::Index,
367    ) -> Result<EdgeIndex> {
368        self.connect_internal(node_a, index_a, node_b, index_b)
369    }
370}
371
372// ============================================================================
373// Common implementation
374// ============================================================================
375
376impl<T, V> TreeTN<T, V>
377where
378    T: TensorLike,
379    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
380{
381    // ------------------------------------------------------------------------
382    // Internal methods (used by mode-specific methods)
383    // ------------------------------------------------------------------------
384
385    /// Internal method to add a tensor with a node name.
386    pub(crate) fn add_tensor_internal(&mut self, node_name: V, tensor: T) -> Result<NodeIndex> {
387        // Extract physical indices: initially all indices are physical (no connections yet)
388        let physical_indices: HashSet<T::Index> = tensor.external_indices().into_iter().collect();
389
390        // Add to graph
391        let node_idx = self
392            .graph
393            .add_node(node_name.clone(), tensor)
394            .map_err(|e| anyhow::anyhow!(e))?;
395
396        // Add to site_index_network
397        self.site_index_network
398            .add_node(node_name, physical_indices)
399            .map_err(|e| anyhow::anyhow!("Failed to add node to site_index_network: {}", e))?;
400
401        Ok(node_idx)
402    }
403
404    /// Internal method to connect two tensors.
405    ///
406    /// In Einsum mode, `index_a` and `index_b` must have the same ID.
407    pub(crate) fn connect_internal(
408        &mut self,
409        node_a: NodeIndex,
410        index_a: &T::Index,
411        node_b: NodeIndex,
412        index_b: &T::Index,
413    ) -> Result<EdgeIndex> {
414        // Validate that indices have the same ID (Einsum mode requirement)
415        if index_a.id() != index_b.id() {
416            return Err(anyhow::anyhow!(
417                "Index IDs must match in Einsum mode: {:?} != {:?}",
418                index_a.id(),
419                index_b.id()
420            ))
421            .context("Failed to connect tensors");
422        }
423
424        // Validate that nodes exist
425        if !self.graph.contains_node(node_a) || !self.graph.contains_node(node_b) {
426            return Err(anyhow::anyhow!("One or both nodes do not exist"))
427                .context("Failed to connect tensors");
428        }
429
430        // Validate that indices exist in respective tensors
431        let tensor_a = self
432            .tensor(node_a)
433            .ok_or_else(|| anyhow::anyhow!("Tensor for node_a not found"))?;
434        let tensor_b = self
435            .tensor(node_b)
436            .ok_or_else(|| anyhow::anyhow!("Tensor for node_b not found"))?;
437
438        // Check that indices exist in tensors
439        let has_index_a = tensor_a.external_indices().iter().any(|idx| idx == index_a);
440        let has_index_b = tensor_b.external_indices().iter().any(|idx| idx == index_b);
441
442        if !has_index_a {
443            return Err(anyhow::anyhow!("Index not found in tensor_a"))
444                .context("Failed to connect: index_a must exist in tensor_a");
445        }
446        if !has_index_b {
447            return Err(anyhow::anyhow!("Index not found in tensor_b"))
448                .context("Failed to connect: index_b must exist in tensor_b");
449        }
450
451        // Clone the bond index (same ID, use index_a)
452        let bond_index = tensor_a
453            .external_indices()
454            .iter()
455            .find(|idx| idx.same_id(index_a))
456            .unwrap()
457            .clone();
458
459        // Get node names for site_index_network (before mutable borrow)
460        let node_name_a = self
461            .graph
462            .node_name(node_a)
463            .ok_or_else(|| anyhow::anyhow!("Node name for node_a not found"))?
464            .clone();
465        let node_name_b = self
466            .graph
467            .node_name(node_b)
468            .ok_or_else(|| anyhow::anyhow!("Node name for node_b not found"))?
469            .clone();
470
471        // Add edge to graph with the bond index directly
472        let edge_idx = self
473            .graph
474            .graph_mut()
475            .add_edge(node_a, node_b, bond_index.clone());
476
477        // Add edge to site_index_network
478        self.site_index_network
479            .add_edge(&node_name_a, &node_name_b)
480            .map_err(|e| anyhow::anyhow!("Failed to add edge to site_index_network: {}", e))?;
481
482        // Update physical indices: remove bond index from physical indices
483        // Use remove_site_index to also update the index_to_node reverse lookup
484        let _ = self
485            .site_index_network
486            .remove_site_index(&node_name_a, &bond_index);
487        let _ = self
488            .site_index_network
489            .remove_site_index(&node_name_b, &bond_index);
490
491        // Register bond index in link_index_network for reverse lookup
492        self.link_index_network.insert(edge_idx, &bond_index);
493
494        Ok(edge_idx)
495    }
496
497    /// Prepare context for sweep-to-center operations.
498    ///
499    /// This method:
500    /// 1. Validates tree structure
501    /// 2. Sets canonical_region and validates connectivity
502    /// 3. Computes edges from leaves towards center using edges_to_canonicalize_to_region
503    ///
504    /// # Arguments
505    /// * `canonical_region` - The node names that will serve as centers
506    /// * `context_name` - Name for error context (e.g., "canonicalize_with")
507    ///
508    /// # Returns
509    /// A SweepContext if successful, or an error if validation fails.
510    pub(crate) fn prepare_sweep_to_center(
511        &mut self,
512        canonical_region: impl IntoIterator<Item = V>,
513        context_name: &str,
514    ) -> Result<Option<SweepContext>> {
515        // 1. Validate tree structure
516        self.validate_tree()
517            .with_context(|| format!("{}: graph must be a tree", context_name))?;
518
519        // 2. Set canonical_region
520        let canonical_region_v: Vec<V> = canonical_region.into_iter().collect();
521        self.set_canonical_region(canonical_region_v)
522            .with_context(|| format!("{}: failed to set canonical_region", context_name))?;
523
524        if self.canonical_region.is_empty() {
525            return Ok(None); // Nothing to do if no centers
526        }
527
528        // 3. Convert canonical_region names to NodeIndex set
529        let center_indices: HashSet<NodeIndex> = self
530            .canonical_region
531            .iter()
532            .filter_map(|name| self.graph.node_index(name))
533            .collect();
534
535        // 4. Validate canonical_region connectivity
536        if !self.site_index_network.is_connected_subset(&center_indices) {
537            return Err(anyhow::anyhow!(
538                "canonical_region is not connected: {} centers but not all reachable",
539                self.canonical_region.len()
540            ))
541            .with_context(|| {
542                format!(
543                    "{}: canonical_region must form a connected subtree",
544                    context_name
545                )
546            });
547        }
548
549        // 5. Get ordered edges from leaves towards center
550        let canonicalize_edges = self
551            .site_index_network
552            .edges_to_canonicalize_to_region(&center_indices);
553        let edges: Vec<(NodeIndex, NodeIndex)> = canonicalize_edges.into_iter().collect();
554
555        Ok(Some(SweepContext { edges }))
556    }
557
558    /// Process one edge during a sweep operation.
559    ///
560    /// Factorizes the tensor at `src` node, absorbs the right factor into `dst` (parent),
561    /// and updates the edge bond and ortho_towards.
562    ///
563    /// # Arguments
564    /// * `src` - The source node to factorize (further from center)
565    /// * `dst` - The destination/parent node (closer to center)
566    /// * `factorize_options` - Options for factorization (algorithm, rtol, max_rank)
567    /// * `context_name` - Name for error context
568    ///
569    /// # Returns
570    /// `Ok(())` if successful, or an error if any step fails.
571    pub(crate) fn sweep_edge(
572        &mut self,
573        src: NodeIndex,
574        dst: NodeIndex,
575        factorize_options: &FactorizeOptions,
576        context_name: &str,
577    ) -> Result<()> {
578        // Find edge between src and dst
579        let edge = {
580            let g = self.graph.graph();
581            g.edges_connecting(src, dst)
582                .next()
583                .ok_or_else(|| {
584                    anyhow::anyhow!("No edge found between node {:?} and {:?}", src, dst)
585                })
586                .with_context(|| format!("{}: edge not found", context_name))?
587                .id()
588        };
589
590        // Get bond index on src-side (the index we will factorize over)
591        let bond_on_src = self
592            .bond_index(edge)
593            .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge"))
594            .with_context(|| format!("{}: failed to get bond index on src", context_name))?
595            .clone();
596
597        // Get tensor at src node
598        let tensor_src = self
599            .tensor(src)
600            .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", src))
601            .with_context(|| format!("{}: tensor not found", context_name))?;
602
603        // Build left_inds = all indices except dst bond
604        let left_inds: Vec<T::Index> = tensor_src
605            .external_indices()
606            .iter()
607            .filter(|idx| idx.id() != bond_on_src.id())
608            .cloned()
609            .collect();
610
611        let tensor_external_indices = tensor_src.external_indices();
612        if left_inds.is_empty() {
613            let tensor_dst = self
614                .tensor(dst)
615                .ok_or_else(|| anyhow::anyhow!("Tensor not found for dst node {:?}", dst))
616                .with_context(|| format!("{}: dst tensor not found", context_name))?;
617
618            let src_norm = tensor_src.norm();
619            let updated_src_tensor = if src_norm > 0.0 {
620                tensor_src
621                    .scale(tensor4all_core::AnyScalar::new_real(1.0 / src_norm))
622                    .with_context(|| format!("{}: failed to normalize src tensor", context_name))?
623            } else {
624                tensor_src.clone()
625            };
626            let updated_dst_tensor = if src_norm > 0.0 {
627                tensor_dst
628                    .scale(tensor4all_core::AnyScalar::new_real(src_norm))
629                    .with_context(|| format!("{}: failed to scale dst tensor", context_name))?
630            } else {
631                tensor_dst.clone()
632            };
633
634            self.replace_tensor(src, updated_src_tensor)
635                .with_context(|| {
636                    format!("{}: failed to replace tensor at src node", context_name)
637                })?;
638            self.replace_tensor(dst, updated_dst_tensor)
639                .with_context(|| {
640                    format!("{}: failed to replace tensor at dst node", context_name)
641                })?;
642
643            let dst_name = self
644                .graph
645                .node_name(dst)
646                .ok_or_else(|| anyhow::anyhow!("Dst node name not found"))?
647                .clone();
648            self.set_edge_ortho_towards(edge, Some(dst_name))
649                .with_context(|| format!("{}: failed to set ortho_towards", context_name))?;
650
651            return Ok(());
652        }
653
654        if left_inds.len() == tensor_external_indices.len() {
655            return Err(anyhow::anyhow!(
656                "Cannot process node {:?}: need at least one left index and one right index",
657                src
658            ))
659            .with_context(|| format!("{}: invalid tensor rank for factorization", context_name));
660        }
661
662        // Perform factorization
663        let factorize_result = tensor_src
664            .factorize(&left_inds, factorize_options)
665            .map_err(|e| anyhow::anyhow!("Factorization failed: {}", e))
666            .with_context(|| format!("{}: factorization failed", context_name))?;
667
668        let left_tensor = factorize_result.left;
669        let right_tensor = factorize_result.right;
670
671        // Absorb right_tensor into dst
672        let tensor_dst = self
673            .tensor(dst)
674            .ok_or_else(|| anyhow::anyhow!("Tensor not found for dst node {:?}", dst))
675            .with_context(|| format!("{}: dst tensor not found", context_name))?;
676
677        let updated_dst_tensor = T::contract(&[tensor_dst, &right_tensor], AllowedPairs::All)
678            .with_context(|| {
679                format!(
680                    "{}: failed to absorb right factor into dst tensor",
681                    context_name
682                )
683            })?;
684
685        // Update bond index FIRST, so replace_tensor validation matches
686        let new_bond_index = factorize_result.bond_index;
687        self.replace_edge_bond(edge, new_bond_index.clone())
688            .with_context(|| format!("{}: failed to update edge bond index", context_name))?;
689
690        // Update tensors
691        self.replace_tensor(src, left_tensor)
692            .with_context(|| format!("{}: failed to replace tensor at src node", context_name))?;
693        self.replace_tensor(dst, updated_dst_tensor)
694            .with_context(|| format!("{}: failed to replace tensor at dst node", context_name))?;
695
696        // Set ortho_towards to point towards dst (canonical_region direction)
697        let dst_name = self
698            .graph
699            .node_name(dst)
700            .ok_or_else(|| anyhow::anyhow!("Dst node name not found"))?
701            .clone();
702        self.set_edge_ortho_towards(edge, Some(dst_name))
703            .with_context(|| format!("{}: failed to set ortho_towards", context_name))?;
704
705        Ok(())
706    }
707
708    // ------------------------------------------------------------------------
709    // Public accessors
710    // ------------------------------------------------------------------------
711
712    /// Get a reference to a tensor by NodeIndex.
713    pub fn tensor(&self, node: NodeIndex) -> Option<&T> {
714        self.graph.graph().node_weight(node)
715    }
716
717    /// Get a mutable reference to a tensor by NodeIndex.
718    pub fn tensor_mut(&mut self, node: NodeIndex) -> Option<&mut T> {
719        self.graph.graph_mut().node_weight_mut(node)
720    }
721
722    /// Replace a tensor at the given node with a new tensor.
723    ///
724    /// Validates that the new tensor contains all indices used in connections
725    /// to this node. Returns an error if any connection index is missing.
726    ///
727    /// Returns the old tensor if the node exists and validation passes.
728    pub fn replace_tensor(&mut self, node: NodeIndex, new_tensor: T) -> Result<Option<T>> {
729        // Check if node exists
730        if !self.graph.contains_node(node) {
731            return Ok(None);
732        }
733
734        // Validate that all connection indices exist in the new tensor
735        let edges = self.edges_for_node(node);
736        let connection_indices: Vec<T::Index> = edges
737            .iter()
738            .filter_map(|(edge_idx, _neighbor)| self.bond_index(*edge_idx).cloned())
739            .collect();
740
741        // Check if all connection indices are present in the new tensor
742        let new_tensor_indices = new_tensor.external_indices();
743        let common = common_inds(&connection_indices, &new_tensor_indices);
744        if common.len() != connection_indices.len() {
745            return Err(anyhow::anyhow!(
746                "New tensor is missing {} connection index(es): found {} out of {} required indices",
747                connection_indices.len() - common.len(),
748                common.len(),
749                connection_indices.len()
750            ))
751            .context("replace_tensor: new tensor must contain all indices used in connections");
752        }
753
754        // Get node name for site_index_network update
755        let node_name = self
756            .graph
757            .node_name(node)
758            .ok_or_else(|| anyhow::anyhow!("Node name not found"))?
759            .clone();
760
761        // Calculate new physical indices: all indices minus connection indices
762        let connection_indices_set: HashSet<T::Index> =
763            connection_indices.iter().cloned().collect();
764        let new_physical_indices: HashSet<T::Index> = new_tensor_indices
765            .iter()
766            .filter(|idx| !connection_indices_set.contains(idx))
767            .cloned()
768            .collect();
769
770        // All validations passed, replace the tensor
771        let old_tensor = self
772            .graph
773            .graph_mut()
774            .node_weight_mut(node)
775            .map(|old| std::mem::replace(old, new_tensor));
776
777        // Update site_index_network with new physical indices
778        // This properly updates both the site_space and the index_to_node mapping
779        self.site_index_network
780            .set_site_space(&node_name, new_physical_indices)
781            .map_err(|e| anyhow::anyhow!("Failed to update site_index_network: {}", e))?;
782
783        Ok(old_tensor)
784    }
785
786    /// Get the bond index for a given edge.
787    pub fn bond_index(&self, edge: EdgeIndex) -> Option<&T::Index> {
788        self.graph.graph().edge_weight(edge)
789    }
790
791    /// Get a mutable reference to the bond index for a given edge.
792    pub fn bond_index_mut(&mut self, edge: EdgeIndex) -> Option<&mut T::Index> {
793        self.graph.graph_mut().edge_weight_mut(edge)
794    }
795
796    /// Get all edges connected to a node.
797    pub fn edges_for_node(&self, node: NodeIndex) -> Vec<(EdgeIndex, NodeIndex)> {
798        self.graph
799            .graph()
800            .edges(node)
801            .map(|edge| {
802                let target = edge.target();
803                (edge.id(), target)
804            })
805            .collect()
806    }
807
    /// Replace the bond index for an edge (e.g., after SVD creates a new bond index).
    ///
    /// Updates, in order:
    /// 1. the edge weight (the bond index itself),
    /// 2. `link_index_network` (old bond -> new bond for this edge),
    /// 3. the `ortho_towards` entry keyed by the old bond (re-keyed to the new bond),
    /// 4. both endpoints' site spaces: the old bond index is inserted as a site
    ///    (physical) index and the new bond index is removed.
    ///
    /// Step 4 keeps "site space = tensor indices minus connection indices" true
    /// while the endpoint tensors still carry the old bond. `swap_on_edge` calls
    /// this first and then `replace_tensor` on both endpoints, which recomputes
    /// the site spaces.
    ///
    /// # Errors
    /// Returns an error if the edge, its bond index, or an endpoint name is
    /// missing, or if `link_index_network.replace_index` fails. A failure in
    /// `replace_index` happens after the edge weight has already been
    /// overwritten; no rollback is performed.
    pub fn replace_edge_bond(&mut self, edge: EdgeIndex, new_bond_index: T::Index) -> Result<()> {
        // Validate edge exists and get endpoints
        let (source, target) = self
            .graph
            .graph()
            .edge_endpoints(edge)
            .ok_or_else(|| anyhow::anyhow!("Edge does not exist"))?;

        // Get old bond index before updating
        let old_bond_index = self
            .bond_index(edge)
            .ok_or_else(|| anyhow::anyhow!("Bond index not found"))?
            .clone();

        // Get node names for site_index_network update
        let node_name_a = self
            .graph
            .node_name(source)
            .ok_or_else(|| anyhow::anyhow!("Node name for source not found"))?
            .clone();
        let node_name_b = self
            .graph
            .node_name(target)
            .ok_or_else(|| anyhow::anyhow!("Node name for target not found"))?
            .clone();

        // Update the bond index
        *self
            .bond_index_mut(edge)
            .ok_or_else(|| anyhow::anyhow!("Bond index not found"))? = new_bond_index.clone();

        // Update link_index_network: old id -> new id
        self.link_index_network
            .replace_index(&old_bond_index, &new_bond_index, edge)
            .map_err(|e| anyhow::anyhow!("{}", e))?;

        // Update ortho_towards key if present (exact-key match on the old bond)
        if let Some(dir) = self.ortho_towards.remove(&old_bond_index) {
            self.ortho_towards.insert(new_bond_index.clone(), dir);
        }

        // Update site_index_network:
        // - Old bond index becomes physical again
        // - New bond index is removed from physical
        // NOTE(review): this edits the site spaces directly via `site_space_mut`
        // rather than `set_site_space`. `replace_tensor` notes that
        // `set_site_space` also maintains the index_to_node mapping, which may
        // therefore not be updated here. Confirm that callers always follow up
        // with `replace_tensor`.
        if let Some(site_space_a) = self.site_index_network.site_space_mut(&node_name_a) {
            site_space_a.insert(old_bond_index.clone());
            site_space_a.remove(&new_bond_index);
        }
        if let Some(site_space_b) = self.site_index_network.site_space_mut(&node_name_b) {
            site_space_b.insert(old_bond_index);
            site_space_b.remove(&new_bond_index);
        }

        Ok(())
    }
867
868    // ------------------------------------------------------------------------
869    // ITensorMPS-like index relabeling helpers
870    // ------------------------------------------------------------------------
871
872    /// Return a copy with all link/bond indices replaced by fresh IDs.
873    ///
874    /// This is analogous to ITensorMPS.jl's `sim(linkinds, M)` / `sim!(linkinds, M)`,
875    /// and is mainly useful to avoid accidental index-ID collisions when combining
876    /// multiple networks.
877    ///
878    /// Notes:
879    /// - This keeps dimensions and conjugate states, but changes identities.
880    /// - This updates both endpoint tensors and internal bookkeeping.
881    pub fn sim_linkinds(&self) -> Result<Self>
882    where
883        T::Index: IndexLike,
884    {
885        let mut result = self.clone();
886        result.sim_linkinds_mut()?;
887        Ok(result)
888    }
889
    /// Replace all link/bond indices with fresh IDs in-place.
    ///
    /// For each edge:
    /// 1. a new bond is derived from the old one via `IndexLike::sim`,
    /// 2. the edge weight is overwritten with the new bond,
    /// 3. in both endpoint tensors, the index whose ID matches the old bond is
    ///    replaced by the new bond, and the tensor is re-installed through
    ///    [`Self::replace_tensor`],
    /// 4. an `ortho_towards` entry whose key has the old bond's ID is re-keyed,
    /// 5. `link_index_network` is updated.
    ///
    /// See [`Self::sim_linkinds`] for details.
    ///
    /// # Errors
    /// Returns an error if an edge's bond or endpoints cannot be found, if an
    /// endpoint tensor has no index with the old bond's ID, or if
    /// `replaceind`, `replace_tensor`, or `link_index_network.replace_index`
    /// fails. There is no rollback: edges already processed keep their new
    /// indices.
    pub fn sim_linkinds_mut(&mut self) -> Result<()>
    where
        T::Index: IndexLike,
    {
        // Snapshot edges first since replacements may touch internal maps.
        let edges: Vec<EdgeIndex> = self.graph.graph().edge_indices().collect();
        for edge in edges {
            let old_bond = self
                .bond_index(edge)
                .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge {:?}", edge))?
                .clone();
            let new_bond = old_bond.sim();

            // Update edge weight first so endpoint tensors can be validated against the new bond.
            // (replace_tensor requires every current connection index to be present
            // in the new tensor, and reads connection indices from edge weights.)
            *self
                .bond_index_mut(edge)
                .ok_or_else(|| anyhow::anyhow!("Bond index not found for edge {:?}", edge))? =
                new_bond.clone();

            // Update endpoint tensors by matching the old bond by ID.
            // NOTE(review): both endpoints receive the same `new_bond`, derived from
            // the edge's copy of the bond. If an endpoint stored this bond with a
            // different conjugate state, that state is not carried over. Confirm
            // this cannot happen.
            let (node_a, node_b) = self
                .graph
                .graph()
                .edge_endpoints(edge)
                .ok_or_else(|| anyhow::anyhow!("Edge {:?} not found", edge))?;
            for node in [node_a, node_b] {
                let tensor = self
                    .tensor(node)
                    .ok_or_else(|| anyhow::anyhow!("Tensor not found"))?;
                let old_in_tensor = tensor
                    .external_indices()
                    .iter()
                    .find(|idx| idx.id() == old_bond.id())
                    .ok_or_else(|| anyhow::anyhow!("Bond index not found in endpoint tensor"))?
                    .clone();
                let new_tensor = tensor.replaceind(&old_in_tensor, &new_bond)?;
                self.replace_tensor(node, new_tensor)?;
            }

            // Update ortho_towards key for this bond (if present), matched by ID.
            // This is a linear scan of the map for every edge.
            if let Some((key, dir)) = self
                .ortho_towards
                .iter()
                .find(|(k, _)| k.id() == old_bond.id())
                .map(|(k, v)| (k.clone(), v.clone()))
            {
                self.ortho_towards.remove(&key);
                self.ortho_towards.insert(new_bond.clone(), dir);
            }

            // Update reverse lookup map (id -> edge).
            self.link_index_network
                .replace_index(&old_bond, &new_bond, edge)
                .map_err(|e| anyhow::anyhow!("{}", e))?;
        }
        Ok(())
    }
950
951    /// Set the orthogonalization direction for an index (bond or site).
952    ///
953    /// The direction is specified as a node name (or None to clear).
954    ///
955    /// # Arguments
956    /// * `index` - The index to set ortho direction for
957    /// * `dir` - The node name that the ortho points towards, or None to clear
958    pub fn set_ortho_towards(&mut self, index: &T::Index, dir: Option<V>) {
959        match dir {
960            Some(node_name) => {
961                self.ortho_towards.insert(index.clone(), node_name);
962            }
963            None => {
964                self.ortho_towards.remove(index);
965            }
966        }
967    }
968
969    /// Get the node name that the orthogonalization points towards for an index.
970    ///
971    /// Returns None if ortho_towards is not set for this index.
972    pub fn ortho_towards_for_index(&self, index: &T::Index) -> Option<&V> {
973        self.ortho_towards.get(index)
974    }
975
976    /// Set the orthogonalization direction for an edge (by EdgeIndex).
977    ///
978    /// This is a convenience method that looks up the bond index and calls `set_ortho_towards`.
979    ///
980    /// The direction is specified as a node name (or None to clear).
981    /// The node must be one of the edge's endpoints.
982    pub fn set_edge_ortho_towards(
983        &mut self,
984        edge: petgraph::stable_graph::EdgeIndex,
985        dir: Option<V>,
986    ) -> Result<()> {
987        // Get the bond index for this edge
988        let bond = self
989            .bond_index(edge)
990            .ok_or_else(|| anyhow::anyhow!("Edge does not exist"))?
991            .clone();
992
993        // Validate that the node (if any) is one of the edge endpoints
994        if let Some(ref node_name) = dir {
995            let (source, target) = self
996                .graph
997                .graph()
998                .edge_endpoints(edge)
999                .ok_or_else(|| anyhow::anyhow!("Edge does not exist"))?;
1000
1001            let source_name = self.graph.node_name(source);
1002            let target_name = self.graph.node_name(target);
1003
1004            if source_name != Some(node_name) && target_name != Some(node_name) {
1005                return Err(anyhow::anyhow!(
1006                    "ortho_towards node {:?} must be one of the edge endpoints",
1007                    node_name
1008                ))
1009                .context("set_edge_ortho_towards: invalid node");
1010            }
1011        }
1012
1013        self.set_ortho_towards(&bond, dir);
1014        Ok(())
1015    }
1016
1017    /// Get the node name that the orthogonalization points towards for an edge.
1018    ///
1019    /// Returns None if ortho_towards is not set for this edge's bond index.
1020    pub fn ortho_towards_node(&self, edge: petgraph::stable_graph::EdgeIndex) -> Option<&V> {
1021        self.bond_index(edge)
1022            .and_then(|bond| self.ortho_towards.get(bond))
1023    }
1024
1025    /// Get the NodeIndex that the orthogonalization points towards for an edge.
1026    ///
1027    /// Returns None if ortho_towards is not set for this edge's bond index.
1028    pub fn ortho_towards_node_index(
1029        &self,
1030        edge: petgraph::stable_graph::EdgeIndex,
1031    ) -> Option<NodeIndex> {
1032        self.ortho_towards_node(edge)
1033            .and_then(|name| self.graph.node_index(name))
1034    }
1035
1036    /// Validate that the graph is a tree (or forest).
1037    ///
1038    /// Checks:
1039    /// - The graph is connected (all nodes reachable from the first node)
1040    /// - For each connected component: edges = nodes - 1 (tree condition)
1041    pub fn validate_tree(&self) -> Result<()> {
1042        let g = self.graph.graph();
1043        if g.node_count() == 0 {
1044            return Ok(()); // Empty graph is trivially valid
1045        }
1046
1047        // Check if graph is connected
1048        let mut visited = std::collections::HashSet::new();
1049        let start_node = g
1050            .node_indices()
1051            .next()
1052            .ok_or_else(|| anyhow::anyhow!("Graph has no nodes"))?;
1053
1054        // DFS to count reachable nodes
1055        let mut dfs = Dfs::new(g, start_node);
1056        while let Some(node) = dfs.next(g) {
1057            visited.insert(node);
1058        }
1059
1060        if visited.len() != g.node_count() {
1061            return Err(anyhow::anyhow!(
1062                "Graph is not connected: {} nodes reachable out of {}",
1063                visited.len(),
1064                g.node_count()
1065            ))
1066            .context("validate_tree: graph must be connected");
1067        }
1068
1069        // Check tree condition: edges = nodes - 1
1070        let node_count = g.node_count();
1071        let edge_count = g.edge_count();
1072
1073        if edge_count != node_count - 1 {
1074            return Err(anyhow::anyhow!(
1075                "Graph does not satisfy tree condition: {} edges != {} nodes - 1",
1076                edge_count,
1077                node_count
1078            ))
1079            .context("validate_tree: tree must have edges = nodes - 1");
1080        }
1081
1082        Ok(())
1083    }
1084
1085    /// Get the number of nodes in the network.
1086    pub fn node_count(&self) -> usize {
1087        self.graph.graph().node_count()
1088    }
1089
1090    /// Get the number of edges in the network.
1091    pub fn edge_count(&self) -> usize {
1092        self.graph.graph().edge_count()
1093    }
1094
1095    /// Get the NodeIndex for a node by name.
1096    pub fn node_index(&self, node_name: &V) -> Option<NodeIndex> {
1097        self.graph.node_index(node_name)
1098    }
1099
1100    /// Rename an existing node while preserving topology, site space, and
1101    /// orthogonality metadata.
1102    pub fn rename_node(&mut self, old_name: &V, new_name: V) -> Result<()> {
1103        if old_name == &new_name {
1104            return Ok(());
1105        }
1106
1107        self.graph
1108            .rename_node(old_name, new_name.clone())
1109            .map_err(|e| anyhow::anyhow!(e))
1110            .context("rename_node: failed to rename graph node")?;
1111        self.site_index_network
1112            .rename_node(old_name, new_name.clone())
1113            .map_err(|e| anyhow::anyhow!(e))
1114            .context("rename_node: failed to rename site-index node")?;
1115
1116        if self.canonical_region.remove(old_name) {
1117            self.canonical_region.insert(new_name.clone());
1118        }
1119
1120        for target in self.ortho_towards.values_mut() {
1121            if target == old_name {
1122                *target = new_name.clone();
1123            }
1124        }
1125
1126        Ok(())
1127    }
1128
1129    /// Get the EdgeIndex for the edge between two nodes by name.
1130    ///
1131    /// Returns `None` if either node doesn't exist or there's no edge between them.
1132    pub fn edge_between(&self, node_a: &V, node_b: &V) -> Option<EdgeIndex> {
1133        let idx_a = self.graph.node_index(node_a)?;
1134        let idx_b = self.graph.node_index(node_b)?;
1135        self.graph
1136            .graph()
1137            .find_edge(idx_a, idx_b)
1138            .or_else(|| self.graph.graph().find_edge(idx_b, idx_a))
1139    }
1140
1141    /// Get all node indices in the tree tensor network.
1142    pub fn node_indices(&self) -> Vec<NodeIndex> {
1143        self.graph.graph().node_indices().collect()
1144    }
1145
1146    /// Get all node names in the tree tensor network.
1147    pub fn node_names(&self) -> Vec<V> {
1148        self.graph
1149            .graph()
1150            .node_indices()
1151            .filter_map(|idx| self.graph.node_name(idx).cloned())
1152            .collect()
1153    }
1154
1155    /// Compute edges to canonicalize from leaves to target, returning node names.
1156    ///
1157    /// Returns `(from, to)` pairs in the order they should be processed:
1158    /// - `from` is the node being factorized
1159    /// - `to` is the parent node (towards target)
1160    ///
1161    /// This is useful for contract_zipup and similar algorithms that work with
1162    /// node names rather than NodeIndex.
1163    ///
1164    /// # Arguments
1165    /// * `target` - Target node name for the orthogonality center
1166    ///
1167    /// # Returns
1168    /// `None` if target node doesn't exist, otherwise a vector of `(from, to)` pairs.
1169    pub fn edges_to_canonicalize_by_names(&self, target: &V) -> Option<Vec<(V, V)>> {
1170        self.site_index_network
1171            .edges_to_canonicalize_by_names(target)
1172    }
1173
1174    /// Get a reference to the orthogonalization region (using node names).
1175    ///
1176    /// When empty, the network is not canonicalized.
1177    pub fn canonical_region(&self) -> &HashSet<V> {
1178        &self.canonical_region
1179    }
1180
1181    /// Check if the network is canonicalized.
1182    ///
1183    /// Returns `true` if `canonical_region` is non-empty, `false` otherwise.
1184    pub fn is_canonicalized(&self) -> bool {
1185        !self.canonical_region.is_empty()
1186    }
1187
1188    /// Set the orthogonalization region (using node names).
1189    ///
1190    /// Validates that all specified nodes exist in the graph.
1191    pub fn set_canonical_region(&mut self, region: impl IntoIterator<Item = V>) -> Result<()> {
1192        let region: HashSet<V> = region.into_iter().collect();
1193
1194        // Validate that all nodes exist in the graph
1195        for node_name in &region {
1196            if !self.graph.has_node(node_name) {
1197                return Err(anyhow::anyhow!(
1198                    "Node {:?} does not exist in the graph",
1199                    node_name
1200                ))
1201                .context("set_canonical_region: all nodes must be valid");
1202            }
1203        }
1204
1205        self.canonical_region = region;
1206        Ok(())
1207    }
1208
1209    /// Clear the orthogonalization region (mark network as not canonicalized).
1210    ///
1211    /// Also clears the canonical form.
1212    pub fn clear_canonical_region(&mut self) {
1213        self.canonical_region.clear();
1214        self.canonical_form = None;
1215    }
1216
1217    /// Get the current canonical form.
1218    ///
1219    /// Returns `None` if not canonicalized.
1220    pub fn canonical_form(&self) -> Option<CanonicalForm> {
1221        self.canonical_form
1222    }
1223
1224    /// Add a node to the orthogonalization region.
1225    ///
1226    /// Validates that the node exists in the graph.
1227    pub fn add_to_canonical_region(&mut self, node_name: V) -> Result<()> {
1228        if !self.graph.has_node(&node_name) {
1229            return Err(anyhow::anyhow!(
1230                "Node {:?} does not exist in the graph",
1231                node_name
1232            ))
1233            .context("add_to_canonical_region: node must be valid");
1234        }
1235        self.canonical_region.insert(node_name);
1236        Ok(())
1237    }
1238
1239    /// Remove a node from the orthogonalization region.
1240    ///
1241    /// Returns `true` if the node was in the region, `false` otherwise.
1242    pub fn remove_from_canonical_region(&mut self, node_name: &V) -> bool {
1243        self.canonical_region.remove(node_name)
1244    }
1245
1246    /// Get a reference to the site index network.
1247    ///
1248    /// The site index network contains both topology (graph structure) and site space (physical indices).
1249    pub fn site_index_network(&self) -> &SiteIndexNetwork<V, T::Index> {
1250        &self.site_index_network
1251    }
1252
1253    /// Get a mutable reference to the site index network.
1254    pub fn site_index_network_mut(&mut self) -> &mut SiteIndexNetwork<V, T::Index> {
1255        &mut self.site_index_network
1256    }
1257
1258    /// Get a reference to the site space (physical indices) for a node.
1259    pub fn site_space(&self, node_name: &V) -> Option<&std::collections::HashSet<T::Index>> {
1260        self.site_index_network.site_space(node_name)
1261    }
1262
1263    /// Get a mutable reference to the site space (physical indices) for a node.
1264    pub fn site_space_mut(
1265        &mut self,
1266        node_name: &V,
1267    ) -> Option<&mut std::collections::HashSet<T::Index>> {
1268        self.site_index_network.site_space_mut(node_name)
1269    }
1270
1271    /// Check if two TreeTNs share equivalent site index network structure.
1272    ///
1273    /// Two TreeTNs share equivalent structure if:
1274    /// - Same topology (nodes and edges)
1275    /// - Same site space for each node
1276    ///
1277    /// This is used to verify that two TreeTNs can be added or contracted.
1278    ///
1279    /// # Arguments
1280    /// * `other` - The other TreeTN to check against
1281    ///
1282    /// # Returns
1283    /// `true` if the networks share equivalent site index structure, `false` otherwise.
1284    pub fn share_equivalent_site_index_network(&self, other: &Self) -> bool
1285    where
1286        <T::Index as IndexLike>::Id: Ord,
1287    {
1288        self.site_index_network
1289            .share_equivalent_site_index_network(&other.site_index_network)
1290    }
1291
1292    /// Check if two TreeTNs have the same topology (graph structure).
1293    ///
1294    /// This only checks that both networks have the same nodes and edges,
1295    /// not that they have the same site indices.
1296    ///
1297    /// Useful for operations like `contract_zipup` where we need networks
1298    /// with the same structure but possibly different site indices.
1299    pub fn same_topology(&self, other: &Self) -> bool {
1300        self.site_index_network
1301            .topology()
1302            .same_topology(other.site_index_network.topology())
1303    }
1304
1305    /// Check if two TreeTNs have the same "appearance".
1306    ///
1307    /// Two TreeTNs have the same appearance if:
1308    /// 1. They have the same topology (same nodes and edges)
1309    /// 2. They have the same external indices (physical indices) at each node
1310    ///    (compared as sets, so order within a node doesn't matter)
1311    /// 3. They have the same orthogonalization direction (ortho_towards) on each edge
1312    ///
1313    /// This is a weaker check than `share_equivalent_site_index_network`:
1314    /// - `share_equivalent_site_index_network`: checks topology + site space (indices)
1315    /// - `same_appearance`: checks topology + site space + ortho_towards directions
1316    ///
1317    /// Note: This does NOT compare tensor data, only structural information.
1318    /// Note: Bond index IDs may differ between the two TreeTNs (e.g., after independent
1319    ///       canonicalization), so we compare ortho_towards by edge position, not by index ID.
1320    ///
1321    /// # Arguments
1322    /// * `other` - The other TreeTN to compare against
1323    ///
1324    /// # Returns
1325    /// `true` if both TreeTNs have the same appearance, `false` otherwise.
1326    pub fn same_appearance(&self, other: &Self) -> bool
1327    where
1328        <T::Index as IndexLike>::Id: Ord,
1329        V: Ord,
1330    {
1331        // Step 1: Check topology and site space
1332        if !self.share_equivalent_site_index_network(other) {
1333            return false;
1334        }
1335
1336        // Step 2: Check ortho_towards on each edge by position (node pair)
1337        // Bond index IDs may differ, so we compare by edge location (node_a, node_b)
1338        let mut self_bond_ortho_count = 0;
1339        let mut other_bond_ortho_count = 0;
1340
1341        // Count bond index entries in self
1342        for node_name in self.node_names() {
1343            let self_neighbors: Vec<V> = self.site_index_network.neighbors(&node_name).collect();
1344
1345            for neighbor_name in self_neighbors {
1346                // Only process each edge once (when node_name < neighbor_name)
1347                if node_name >= neighbor_name {
1348                    continue;
1349                }
1350
1351                // Get edge and bond in self
1352                let self_edge = match self.edge_between(&node_name, &neighbor_name) {
1353                    Some(e) => e,
1354                    None => continue,
1355                };
1356                let self_bond = match self.bond_index(self_edge) {
1357                    Some(b) => b,
1358                    None => continue,
1359                };
1360
1361                // Get edge and bond in other
1362                let other_edge = match other.edge_between(&node_name, &neighbor_name) {
1363                    Some(e) => e,
1364                    None => return false, // Edge exists in self but not in other
1365                };
1366                let other_bond = match other.bond_index(other_edge) {
1367                    Some(b) => b,
1368                    None => return false,
1369                };
1370
1371                // Compare ortho_towards for this edge
1372                let self_ortho = self.ortho_towards.get(self_bond);
1373                let other_ortho = other.ortho_towards.get(other_bond);
1374
1375                match (self_ortho, other_ortho) {
1376                    (None, None) => {} // Both have no direction - OK
1377                    (Some(self_dir), Some(other_dir)) => {
1378                        // Both have direction - must be the same
1379                        if self_dir != other_dir {
1380                            return false;
1381                        }
1382                        self_bond_ortho_count += 1;
1383                        other_bond_ortho_count += 1;
1384                    }
1385                    _ => return false, // One has direction, other doesn't
1386                }
1387            }
1388        }
1389
1390        // Verify we compared all bond ortho_towards entries
1391        // (site index ortho_towards are not compared here as they're implied by topology)
1392        // Count actual bond index entries in each ortho_towards map
1393        let self_total_bond_entries: usize = self
1394            .graph
1395            .graph()
1396            .edge_indices()
1397            .filter_map(|e| self.bond_index(e))
1398            .filter(|b| self.ortho_towards.contains_key(b))
1399            .count();
1400        let other_total_bond_entries: usize = other
1401            .graph
1402            .graph()
1403            .edge_indices()
1404            .filter_map(|e| other.bond_index(e))
1405            .filter(|b| other.ortho_towards.contains_key(b))
1406            .count();
1407
1408        if self_bond_ortho_count != self_total_bond_entries
1409            || other_bond_ortho_count != other_total_bond_entries
1410        {
1411            return false;
1412        }
1413
1414        true
1415    }
1416
1417    /// Perform an in-place adjacent swap on the edge (node_a, node_b).
1418    ///
1419    /// Contracts the two tensors, uses the explicitly scheduled site partition,
1420    /// then factorizes back in-place with `Canonical::Left` so the new center
1421    /// lands on `node_b`.
1422    pub(crate) fn swap_on_edge(
1423        &mut self,
1424        node_a_idx: NodeIndex,
1425        node_b_idx: NodeIndex,
1426        a_side_sites: &HashSet<<T::Index as IndexLike>::Id>,
1427        b_side_sites: &HashSet<<T::Index as IndexLike>::Id>,
1428        factorize_options: &FactorizeOptions,
1429    ) -> Result<()>
1430    where
1431        <T::Index as IndexLike>::Id: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
1432    {
1433        let node_b_name = self
1434            .graph
1435            .node_name(node_b_idx)
1436            .ok_or_else(|| anyhow::anyhow!("swap_on_edge: node_b not found"))?
1437            .clone();
1438
1439        let edge = {
1440            let g = self.graph.graph();
1441            g.edges_connecting(node_a_idx, node_b_idx)
1442                .next()
1443                .ok_or_else(|| anyhow::anyhow!("swap_on_edge: no edge between nodes"))?
1444                .id()
1445        };
1446        let bond_ab = self
1447            .bond_index(edge)
1448            .ok_or_else(|| anyhow::anyhow!("swap_on_edge: bond not found"))?
1449            .clone();
1450
1451        // Structural bond ids of A and B (bonds other than bond_ab)
1452        let other_bond_ids_a: HashSet<<T::Index as IndexLike>::Id> = self
1453            .edges_for_node(node_a_idx)
1454            .iter()
1455            .filter_map(|(e, _)| self.bond_index(*e).cloned())
1456            .filter(|b| b.id() != bond_ab.id())
1457            .map(|b| b.id().to_owned())
1458            .collect();
1459        let other_bond_ids_b: HashSet<<T::Index as IndexLike>::Id> = self
1460            .edges_for_node(node_b_idx)
1461            .iter()
1462            .filter_map(|(e, _)| self.bond_index(*e).cloned())
1463            .filter(|b| b.id() != bond_ab.id())
1464            .map(|b| b.id().to_owned())
1465            .collect();
1466
1467        let tensor_a = self
1468            .tensor(node_a_idx)
1469            .ok_or_else(|| anyhow::anyhow!("swap_on_edge: tensor_a not found"))?
1470            .clone();
1471        let tensor_b = self
1472            .tensor(node_b_idx)
1473            .ok_or_else(|| anyhow::anyhow!("swap_on_edge: tensor_b not found"))?
1474            .clone();
1475
1476        // Site ids currently at each node (all non-bond indices)
1477        let site_ids_a: HashSet<<T::Index as IndexLike>::Id> = tensor_a
1478            .external_indices()
1479            .iter()
1480            .filter(|i| i.id() != bond_ab.id() && !other_bond_ids_a.contains(i.id()))
1481            .map(|i| i.id().to_owned())
1482            .collect();
1483        let site_ids_b: HashSet<<T::Index as IndexLike>::Id> = tensor_b
1484            .external_indices()
1485            .iter()
1486            .filter(|i| i.id() != bond_ab.id() && !other_bond_ids_b.contains(i.id()))
1487            .map(|i| i.id().to_owned())
1488            .collect();
1489        let all_site_ids: HashSet<_> = site_ids_a.union(&site_ids_b).cloned().collect();
1490        let assigned_site_ids: HashSet<_> = a_side_sites.union(b_side_sites).cloned().collect();
1491
1492        if !a_side_sites.is_disjoint(b_side_sites) {
1493            return Err(anyhow::anyhow!(
1494                "swap_on_edge: a_side_sites and b_side_sites overlap"
1495            ));
1496        }
1497        if assigned_site_ids != all_site_ids {
1498            return Err(anyhow::anyhow!(
1499                "swap_on_edge: scheduled site partition does not match current edge sites"
1500            ));
1501        }
1502
1503        let tensor_ab = T::contract(&[&tensor_a, &tensor_b], AllowedPairs::All)
1504            .context("swap_on_edge: contract")?;
1505
1506        let ab_indices = tensor_ab.external_indices();
1507        let left_inds: Vec<T::Index> = ab_indices
1508            .iter()
1509            .filter(|i| other_bond_ids_a.contains(i.id()) || a_side_sites.contains(i.id()))
1510            .cloned()
1511            .collect();
1512
1513        let result =
1514            swap::factorize_or_trivial(&tensor_ab, &left_inds, &ab_indices, factorize_options)
1515                .context("swap_on_edge: factorize")?;
1516
1517        self.replace_edge_bond(edge, result.bond_index)
1518            .context("swap_on_edge: replace_edge_bond")?;
1519        self.replace_tensor(node_a_idx, result.left)
1520            .context("swap_on_edge: replace tensor_a")?;
1521        self.replace_tensor(node_b_idx, result.right)
1522            .context("swap_on_edge: replace tensor_b")?;
1523        self.set_edge_ortho_towards(edge, Some(node_b_name))
1524            .context("swap_on_edge: set_edge_ortho_towards")?;
1525
1526        Ok(())
1527    }
1528
1529    /// Reorder site indices so that each index id ends up at the target node.
1530    ///
1531    /// Builds a pre-computed schedule from the topology plus current and target
1532    /// site assignments, canonicalizes the network to the schedule root, then
1533    /// executes the scheduled transport and swap steps.
1534    /// Partial assignment is supported: indices not listed in
1535    /// `target_assignment` stay on their current side of every visited edge.
1536    ///
1537    /// # Arguments
1538    /// * `target_assignment` - Map from site index id to target node name.
1539    /// * `options` - Truncation options for each SVD (default: no truncation, exact).
1540    ///
1541    /// # Errors
1542    /// Returns an error if target nodes are missing, an index id is unknown, or sweep fails.
1543    pub fn swap_site_indices(
1544        &mut self,
1545        target_assignment: &HashMap<<T::Index as IndexLike>::Id, V>,
1546        options: &swap::SwapOptions,
1547    ) -> Result<()>
1548    where
1549        <T::Index as IndexLike>::Id:
1550            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1551        V: Ord,
1552    {
1553        if target_assignment.is_empty() {
1554            return Ok(());
1555        }
1556
1557        let current = swap::current_site_assignment(self);
1558        let root = self
1559            .node_names()
1560            .into_iter()
1561            .min()
1562            .ok_or_else(|| anyhow::anyhow!("swap_site_indices: empty network"))?;
1563        let schedule = swap::SwapSchedule::build(
1564            self.site_index_network().topology(),
1565            &current,
1566            target_assignment,
1567            &root,
1568        )
1569        .context("swap_site_indices: build schedule")?;
1570
1571        if schedule.steps.is_empty() {
1572            return Ok(());
1573        }
1574
1575        self.canonicalize_mut(
1576            std::iter::once(schedule.root.clone()),
1577            crate::options::CanonicalizationOptions::default(),
1578        )
1579        .context("swap_site_indices: canonicalize")?;
1580
1581        let mut swap_factorize_options = FactorizeOptions::svd().with_canonical(Canonical::Left);
1582        if let Some(mr) = options.max_rank {
1583            swap_factorize_options = swap_factorize_options.with_max_rank(mr);
1584        }
1585        if let Some(rtol) = options.rtol {
1586            swap_factorize_options = swap_factorize_options
1587                .with_svd_policy(tensor4all_core::SvdTruncationPolicy::new(rtol));
1588        }
1589        let transport_factorize_options = FactorizeOptions::svd().with_canonical(Canonical::Left);
1590
1591        for step in &schedule.steps {
1592            for edge in step.transport_path.windows(2) {
1593                let src_name = &edge[0];
1594                let dst_name = &edge[1];
1595                let src_idx = self.node_index(src_name).ok_or_else(|| {
1596                    anyhow::anyhow!("swap_site_indices: transport node {:?} not found", src_name)
1597                })?;
1598                let dst_idx = self.node_index(dst_name).ok_or_else(|| {
1599                    anyhow::anyhow!("swap_site_indices: transport node {:?} not found", dst_name)
1600                })?;
1601                self.sweep_edge(
1602                    src_idx,
1603                    dst_idx,
1604                    &transport_factorize_options,
1605                    "swap_transport",
1606                )
1607                .context("swap_site_indices: transport")?;
1608            }
1609
1610            let a_idx = self.node_index(&step.node_a).ok_or_else(|| {
1611                anyhow::anyhow!("swap_site_indices: node {:?} not found", step.node_a)
1612            })?;
1613            let b_idx = self.node_index(&step.node_b).ok_or_else(|| {
1614                anyhow::anyhow!("swap_site_indices: node {:?} not found", step.node_b)
1615            })?;
1616            self.swap_on_edge(
1617                a_idx,
1618                b_idx,
1619                &step.a_side_sites,
1620                &step.b_side_sites,
1621                &swap_factorize_options,
1622            )
1623            .context("swap_site_indices: swap_on_edge")?;
1624            self.set_canonical_region([step.node_b.clone()])
1625                .context("swap_site_indices: set_canonical_region")?;
1626        }
1627
1628        Ok(())
1629    }
1630
1631    /// Reorder site indices so that each index ends up at the target node.
1632    ///
1633    /// Index-based version of [`swap_site_indices`](Self::swap_site_indices).
1634    /// Accepts `T::Index` keys instead of `T::Index::Id`.
1635    ///
1636    /// # Examples
1637    ///
1638    /// ```
1639    /// use std::collections::HashMap;
1640    ///
1641    /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen};
1642    /// use tensor4all_treetn::{SwapOptions, TreeTN};
1643    ///
1644    /// # fn main() -> anyhow::Result<()> {
1645    /// let node_name_a = "A".to_string();
1646    /// let node_name_b = "B".to_string();
1647    /// let idx_a = DynIndex::new_dyn(2);
1648    /// let idx_b = DynIndex::new_dyn(2);
1649    /// let bond = DynIndex::new_dyn(1);
1650    /// let t0 = TensorDynLen::from_dense(vec![idx_a.clone(), bond.clone()], vec![1.0, 0.0])?;
1651    /// let t1 = TensorDynLen::from_dense(vec![bond, idx_b.clone()], vec![1.0, 0.0])?;
1652    /// let mut treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1653    ///     vec![t0, t1],
1654    ///     vec![node_name_a.clone(), node_name_b.clone()],
1655    /// )?;
1656    ///
1657    /// let mut target = HashMap::new();
1658    /// target.insert(idx_a.clone(), node_name_b.clone());
1659    ///
1660    /// treetn.swap_site_indices_by_index(&target, &SwapOptions::default())?;
1661    ///
1662    /// assert_eq!(
1663    ///     treetn
1664    ///         .site_index_network()
1665    ///         .find_node_by_index_id(idx_a.id())
1666    ///         .map(|name| name.as_str()),
1667    ///     Some(node_name_b.as_str())
1668    /// );
1669    /// assert!(treetn.is_canonicalized());
1670    /// # Ok::<(), anyhow::Error>(())
1671    /// # }
1672    /// ```
1673    pub fn swap_site_indices_by_index(
1674        &mut self,
1675        target_assignment: &HashMap<T::Index, V>,
1676        options: &swap::SwapOptions,
1677    ) -> Result<()>
1678    where
1679        <T::Index as IndexLike>::Id:
1680            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1681        T::Index: Hash + Eq,
1682        V: Ord,
1683    {
1684        let id_assignment: HashMap<_, _> = target_assignment
1685            .iter()
1686            .map(|(idx, node)| (idx.id().clone(), node.clone()))
1687            .collect();
1688
1689        self.swap_site_indices(&id_assignment, options)
1690    }
1691
1692    /// Verify internal data consistency by checking structural invariants and reconstructing the TreeTN.
1693    ///
1694    /// This function performs two categories of checks:
1695    ///
1696    /// ## Structural invariants (fail-fast checks):
1697    /// 0a. **Connectivity**: All tensors must form a single connected component
1698    /// 0b. **Index sharing**: Only edge-connected (adjacent) nodes may share index IDs.
1699    ///     Non-adjacent nodes sharing an index ID violates tree structure assumptions.
1700    ///
1701    /// ## Reconstruction consistency:
1702    /// After structural checks pass, clones all tensors and node names, reconstructs
1703    /// a new TreeTN using `from_tensors`, and verifies:
1704    /// 1. **Topology**: Same nodes and edges
1705    /// 2. **Site space**: Same physical indices for each node
1706    /// 3. **Tensors**: Same tensor data at each node
1707    ///
1708    /// This is useful for debugging and testing to ensure that the internal state
1709    /// of a TreeTN is consistent after complex operations.
1710    ///
1711    /// # Returns
1712    /// `Ok(())` if the internal data is consistent, or `Err` with details about the inconsistency.
1713    pub fn verify_internal_consistency(&self) -> Result<()>
1714    where
1715        <T::Index as IndexLike>::Id:
1716            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1717        V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
1718    {
1719        // Step 0a: Verify all tensors are connected (form a single connected component)
1720        // Use DFS to check connectivity since StableGraph doesn't support connected_components
1721        let num_nodes = self.graph.graph().node_count();
1722        if num_nodes > 1 {
1723            // Start DFS from any node
1724            if let Some(start_node) = self.graph.graph().node_indices().next() {
1725                let mut dfs = Dfs::new(self.graph.graph(), start_node);
1726                let mut visited_count = 0;
1727                while dfs.next(self.graph.graph()).is_some() {
1728                    visited_count += 1;
1729                }
1730                if visited_count != num_nodes {
1731                    return Err(anyhow::anyhow!(
1732                        "TreeTN is disconnected: DFS visited {} of {} nodes. All tensors must be connected.",
1733                        visited_count,
1734                        num_nodes
1735                    ))
1736                    .context("verify_internal_consistency: graph must be connected");
1737                }
1738            }
1739        }
1740
1741        // Step 0b: Verify non-adjacent tensors don't share index IDs
1742        // Build a map from index ID to nodes that have that index
1743        let mut index_id_to_nodes: HashMap<<T::Index as IndexLike>::Id, Vec<NodeIndex>> =
1744            HashMap::new();
1745        for node_idx in self.graph.graph().node_indices() {
1746            if let Some(tensor) = self.tensor(node_idx) {
1747                for index in tensor.external_indices() {
1748                    index_id_to_nodes
1749                        .entry(index.id().clone())
1750                        .or_default()
1751                        .push(node_idx);
1752                }
1753            }
1754        }
1755
1756        // Check each index ID - if shared by multiple nodes, they must be adjacent
1757        for (index_id, nodes) in &index_id_to_nodes {
1758            if nodes.len() > 2 {
1759                // More than 2 nodes share the same index ID - always invalid for tree structure
1760                return Err(anyhow::anyhow!(
1761                    "Index ID {:?} is shared by {} nodes, but tree structure allows at most 2",
1762                    index_id,
1763                    nodes.len()
1764                ))
1765                .context("verify_internal_consistency: index ID shared by too many nodes");
1766            }
1767            if nodes.len() == 2 {
1768                // Two nodes share the index - they must be adjacent (connected by an edge)
1769                let node_a = nodes[0];
1770                let node_b = nodes[1];
1771                if self.graph.graph().find_edge(node_a, node_b).is_none()
1772                    && self.graph.graph().find_edge(node_b, node_a).is_none()
1773                {
1774                    let name_a = self.graph.node_name(node_a);
1775                    let name_b = self.graph.node_name(node_b);
1776                    return Err(anyhow::anyhow!(
1777                        "Non-adjacent nodes {:?} and {:?} share index ID {:?}. \
1778                        Only adjacent (edge-connected) nodes may share index IDs.",
1779                        name_a,
1780                        name_b,
1781                        index_id
1782                    ))
1783                    .context("verify_internal_consistency: non-adjacent nodes share index ID");
1784                }
1785            }
1786        }
1787
1788        // Step 1: Clone all tensors and node names
1789        let node_names: Vec<V> = self.node_names();
1790        let tensors: Vec<T> = node_names
1791            .iter()
1792            .filter_map(|name| {
1793                let idx = self.graph.node_index(name)?;
1794                self.tensor(idx).cloned()
1795            })
1796            .collect();
1797
1798        if tensors.len() != node_names.len() {
1799            return Err(anyhow::anyhow!(
1800                "Internal inconsistency: {} node names but {} tensors found",
1801                node_names.len(),
1802                tensors.len()
1803            ));
1804        }
1805
1806        // Step 2: Reconstruct TreeTN from scratch using from_tensors_unchecked
1807        // (use unchecked version to avoid infinite recursion)
1808        let reconstructed = TreeTN::<T, V>::from_tensors_unchecked(tensors, node_names)
1809            .context("verify_internal_consistency: failed to reconstruct TreeTN")?;
1810
1811        // Step 3: Verify topology matches
1812        if !self.same_topology(&reconstructed) {
1813            return Err(anyhow::anyhow!(
1814                "Internal inconsistency: topology does not match after reconstruction"
1815            ))
1816            .context("verify_internal_consistency: topology mismatch");
1817        }
1818
1819        // Step 4: Verify site index network matches
1820        if !self
1821            .site_index_network
1822            .share_equivalent_site_index_network(&reconstructed.site_index_network)
1823        {
1824            return Err(anyhow::anyhow!(
1825                "Internal inconsistency: site index network does not match after reconstruction"
1826            ))
1827            .context("verify_internal_consistency: site space mismatch");
1828        }
1829
1830        // Step 5: Verify tensor data matches at each node
1831        for node_name in self.node_names() {
1832            let idx_self = self
1833                .graph
1834                .node_index(&node_name)
1835                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in original", node_name))?;
1836            let idx_reconstructed =
1837                reconstructed.graph.node_index(&node_name).ok_or_else(|| {
1838                    anyhow::anyhow!("Node {:?} not found in reconstructed", node_name)
1839                })?;
1840
1841            let tensor_self = self.tensor(idx_self).ok_or_else(|| {
1842                anyhow::anyhow!("Tensor not found for node {:?} in original", node_name)
1843            })?;
1844            let tensor_reconstructed =
1845                reconstructed.tensor(idx_reconstructed).ok_or_else(|| {
1846                    anyhow::anyhow!("Tensor not found for node {:?} in reconstructed", node_name)
1847                })?;
1848
1849            // Compare tensor indices (as sets, since order may differ)
1850            let indices_self: HashSet<_> = tensor_self.external_indices().into_iter().collect();
1851            let indices_reconstructed: HashSet<_> = tensor_reconstructed
1852                .external_indices()
1853                .into_iter()
1854                .collect();
1855            if indices_self != indices_reconstructed {
1856                return Err(anyhow::anyhow!(
1857                    "Internal inconsistency: tensor indices differ at node {:?}",
1858                    node_name
1859                ))
1860                .context("verify_internal_consistency: tensor index mismatch");
1861            }
1862
1863            // Compare tensor dimensions
1864            if tensor_self.num_external_indices() != tensor_reconstructed.num_external_indices() {
1865                return Err(anyhow::anyhow!(
1866                    "Internal inconsistency: tensor dimensions differ at node {:?}: {} vs {}",
1867                    node_name,
1868                    tensor_self.num_external_indices(),
1869                    tensor_reconstructed.num_external_indices()
1870                ))
1871                .context("verify_internal_consistency: tensor dimension mismatch");
1872            }
1873        }
1874
1875        Ok(())
1876    }
1877}
1878
1879// ============================================================================
1880// Helper functions
1881// ============================================================================
1882
1883/// Find common indices between two slices of indices.
1884pub(crate) fn common_inds<I: IndexLike>(inds_a: &[I], inds_b: &[I]) -> Vec<I> {
1885    let set_b: HashSet<_> = inds_b.iter().map(|idx| idx.id()).collect();
1886    inds_a
1887        .iter()
1888        .filter(|idx| set_b.contains(idx.id()))
1889        .cloned()
1890        .collect()
1891}