Skip to main content

tensor4all_treetn/treetn/
localupdate.rs

1//! Local update operations for TreeTN.
2//!
3//! This module provides APIs for:
4//! - Extracting a sub-tree from a TreeTN (creating a new TreeTN object)
5//! - Replacing a sub-tree with another TreeTN of the same appearance
6//! - Generating sweep plans for local update algorithms (truncation, fitting)
7//!
8//! These operations are fundamental for local update algorithms in tensor networks.
9
10use std::collections::HashSet;
11use std::fmt::Debug;
12use std::hash::Hash;
13
14use anyhow::{Context, Result};
15
16use tensor4all_core::{AllowedPairs, IndexLike, TensorLike};
17
18use super::TreeTN;
19use crate::node_name_network::NodeNameNetwork;
20
21// ============================================================================
22// Local Update Sweep Plan
23// ============================================================================
24
/// One update in a local-update sweep.
///
/// A step names the nodes that are pulled out of the network and updated
/// together, plus the node that should hold the canonical center once the
/// updated tensors have been written back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalUpdateStep<V> {
    /// Nodes extracted and updated together in this step.
    ///
    /// Holds a single node for one-site sweeps (`nsite == 1`) and the two
    /// endpoints of an edge for two-site sweeps (`nsite == 2`).
    pub nodes: Vec<V>,

    /// Node that becomes the canonical center after this step.
    ///
    /// For one-site steps this is `nodes[0]`. For two-site steps it is the
    /// endpoint the sweep is moving towards.
    pub new_center: V,
}
42
43/// A complete sweep plan for local updates.
44///
45/// Generated from an Euler tour, this plan specifies the sequence of
46/// local update operations for algorithms like truncation and fitting.
47///
48/// # Sweep Direction
49/// The sweep follows an Euler tour starting from the root, visiting each
50/// edge twice (forward and backward). This ensures all bonds are updated
51/// in both directions, which is essential for algorithms like DMRG/TEBD.
52///
53/// # nsite Parameter
54/// - `nsite=1`: Single-site updates. Each step extracts one node.
55/// - `nsite=2`: Two-site updates. Each step extracts two adjacent nodes (an edge).
56///
57/// Two-site updates are more expensive but can change bond dimensions and
58/// are necessary for algorithms like TDVP-2 or two-site DMRG.
59#[derive(Debug, Clone)]
60pub struct LocalUpdateSweepPlan<V> {
61    /// The sequence of update steps.
62    pub steps: Vec<LocalUpdateStep<V>>,
63
64    /// Number of sites per update (1 or 2).
65    pub nsite: usize,
66}
67
68impl<V> LocalUpdateSweepPlan<V>
69where
70    V: Clone + Hash + Eq + Send + Sync + Debug,
71{
72    /// Generate a sweep plan from a TreeTN's topology.
73    ///
74    /// Convenience method that extracts the NodeNameNetwork topology from a TreeTN.
75    pub fn from_treetn<T>(treetn: &TreeTN<T, V>, root: &V, nsite: usize) -> Option<Self>
76    where
77        T: TensorLike,
78    {
79        Self::new(treetn.site_index_network().topology(), root, nsite)
80    }
81
82    /// Generate a sweep plan from a NodeNameNetwork.
83    ///
84    /// Uses Euler tour traversal to visit all edges in both directions.
85    ///
86    /// # Arguments
87    /// * `network` - The network topology
88    /// * `root` - The starting node for the sweep
89    /// * `nsite` - Number of sites per update (1 or 2)
90    ///
91    /// # Returns
92    /// A sweep plan, or `None` if the root doesn't exist or nsite is invalid.
93    ///
94    /// # Example
95    /// For nsite=1 on chain A-B-C with root B:
96    /// - Euler tour vertices: [B, A, B, C, B]
97    /// - Steps: [(B, B), (A, A), (B, B), (C, C)] (each vertex except last)
98    ///
99    /// For nsite=2 on chain A-B-C with root B:
100    /// - Euler tour edges: [(B,A), (A,B), (B,C), (C,B)]
101    /// - Steps: [({B,A}, A), ({A,B}, B), ({B,C}, C), ({C,B}, B)]
102    pub fn new(network: &NodeNameNetwork<V>, root: &V, nsite: usize) -> Option<Self> {
103        if nsite != 1 && nsite != 2 {
104            return None;
105        }
106
107        let root_idx = network.node_index(root)?;
108
109        match nsite {
110            1 => {
111                // nsite=1: Use vertex sequence from Euler tour
112                let vertices = network.euler_tour_vertices_by_index(root_idx);
113                if vertices.is_empty() {
114                    return Some(Self::empty(nsite));
115                }
116
117                // Each vertex (except the last return to root) is a step
118                // The new_center is the current vertex itself
119                let steps: Vec<_> = vertices
120                    .iter()
121                    .take(vertices.len().saturating_sub(1))
122                    .filter_map(|&v| {
123                        let name = network.node_name(v)?.clone();
124                        Some(LocalUpdateStep {
125                            nodes: vec![name.clone()],
126                            new_center: name,
127                        })
128                    })
129                    .collect();
130
131                Some(Self { steps, nsite })
132            }
133            2 => {
134                // nsite=2: Use edge sequence from Euler tour
135                let edges = network.euler_tour_edges_by_index(root_idx);
136                if edges.is_empty() {
137                    // Single node: no edges to update
138                    return Some(Self::empty(nsite));
139                }
140
141                // Each edge (u, v) becomes a step with nodes [u, v]
142                // The new_center is v (the direction we're moving)
143                let steps: Vec<_> = edges
144                    .iter()
145                    .filter_map(|&(u, v)| {
146                        let u_name = network.node_name(u)?.clone();
147                        let v_name = network.node_name(v)?.clone();
148                        Some(LocalUpdateStep {
149                            nodes: vec![u_name, v_name.clone()],
150                            new_center: v_name,
151                        })
152                    })
153                    .collect();
154
155                Some(Self { steps, nsite })
156            }
157            _ => None,
158        }
159    }
160
161    /// Create an empty sweep plan.
162    pub fn empty(nsite: usize) -> Self {
163        Self {
164            steps: Vec::new(),
165            nsite,
166        }
167    }
168
169    /// Check if the plan is empty.
170    pub fn is_empty(&self) -> bool {
171        self.steps.is_empty()
172    }
173
174    /// Number of update steps.
175    pub fn len(&self) -> usize {
176        self.steps.len()
177    }
178
179    /// Iterate over the steps.
180    pub fn iter(&self) -> impl Iterator<Item = &LocalUpdateStep<V>> {
181        self.steps.iter()
182    }
183}
184
185// ============================================================================
186// Boundary edge/bond utilities
187// ============================================================================
188
189/// Boundary edge information: (node_in_region, neighbor_outside, bond_index).
190///
191/// Represents an edge connecting a node inside the region to a neighbor outside the region.
192#[derive(Debug, Clone)]
193pub struct BoundaryEdge<T, V>
194where
195    T: TensorLike,
196    V: Clone + Hash + Eq,
197{
198    /// Node inside the region
199    pub node_in_region: V,
200    /// Neighbor outside the region
201    pub neighbor_outside: V,
202    /// Bond index connecting node_in_region to neighbor_outside
203    pub bond_index: T::Index,
204}
205
206/// Get all boundary edges for a given region in a TreeTN.
207///
208/// Returns edges connecting nodes inside the region to neighbors outside the region.
209/// This is useful for maintaining stable bond IDs across updates (e.g., for environment cache consistency).
210///
211/// # Arguments
212/// * `treetn` - The TreeTN to analyze
213/// * `region` - Nodes that are inside the region
214///
215/// # Returns
216/// Vector of boundary edges, each containing the node in region, neighbor outside, and bond index.
217pub fn get_boundary_edges<T, V>(
218    treetn: &TreeTN<T, V>,
219    region: &[V],
220) -> Result<Vec<BoundaryEdge<T, V>>>
221where
222    T: TensorLike,
223    T::Index: IndexLike,
224    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
225{
226    let mut boundary_edges = Vec::new();
227    let region_set: HashSet<&V> = region.iter().collect();
228
229    for node in region {
230        for neighbor in treetn.site_index_network().neighbors(node) {
231            if !region_set.contains(&neighbor) {
232                // This is a boundary edge: node is in region, neighbor is outside
233                if let Some(edge) = treetn.edge_between(node, &neighbor) {
234                    if let Some(bond) = treetn.bond_index(edge) {
235                        boundary_edges.push(BoundaryEdge {
236                            node_in_region: node.clone(),
237                            neighbor_outside: neighbor.clone(),
238                            bond_index: bond.clone(),
239                        });
240                    }
241                }
242            }
243        }
244    }
245
246    Ok(boundary_edges)
247}
248
249// ============================================================================
250// LocalUpdater trait
251// ============================================================================
252
253/// Trait for local update operations during a sweep.
254///
255/// Implementors of this trait provide the actual update logic that transforms
256/// a local subtree into an updated version. This allows different algorithms
257/// (truncation, fitting, DMRG, TDVP) to share the same sweep infrastructure.
258///
259/// # Type Parameters
260/// - `T`: Tensor type implementing TensorLike
261/// - `V`: Node name type
262///
263/// # Workflow
264/// During `apply_local_update_sweep`:
265/// 1. For each step in the sweep plan:
266///    a. Extract the subtree containing `step.nodes`
267///    b. Call `update()` with the extracted subtree and step info
268///    c. Replace the subtree in the original TreeTN with the updated one
269///    d. Update the canonical center to `step.new_center`
270pub trait LocalUpdater<T, V>
271where
272    T: TensorLike,
273    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
274{
275    /// Optional hook called before performing an update step.
276    ///
277    /// This is called with the full TreeTN state *before* the update is applied.
278    /// Implementors can use it to validate assumptions or prefetch/update caches.
279    fn before_step(
280        &mut self,
281        _step: &LocalUpdateStep<V>,
282        _full_treetn_before: &TreeTN<T, V>,
283    ) -> Result<()> {
284        Ok(())
285    }
286
287    /// Update a local subtree.
288    ///
289    /// # Arguments
290    /// * `subtree` - The extracted subtree to update
291    /// * `step` - The current step information (nodes and new_center)
292    /// * `full_treetn` - Reference to the full (global) TreeTN. This provides global context
293    ///   (e.g., topology, neighbor relations, and index/bond metadata) that some update
294    ///   algorithms may need. It may be unused by simple updaters.
295    ///
296    /// # Returns
297    /// The updated subtree, which must have the same "appearance" as the input
298    /// (same nodes, same external indices, same ortho_towards structure).
299    ///
300    /// # Errors
301    /// Returns an error if the update fails (e.g., SVD doesn't converge).
302    fn update(
303        &mut self,
304        subtree: TreeTN<T, V>,
305        step: &LocalUpdateStep<V>,
306        full_treetn: &TreeTN<T, V>,
307    ) -> Result<TreeTN<T, V>>;
308
309    /// Optional hook called after an update step has been applied to the full TreeTN.
310    ///
311    /// This is called after:
312    /// - The updated subtree has been inserted back into the full TreeTN
313    /// - The canonical center has been moved to `step.new_center`
314    ///
315    /// Implementors can use this to update caches that must see the post-update state.
316    fn after_step(
317        &mut self,
318        _step: &LocalUpdateStep<V>,
319        _full_treetn_after: &TreeTN<T, V>,
320    ) -> Result<()> {
321        Ok(())
322    }
323}
324
325/// Apply a local update sweep to a TreeTN.
326///
327/// This function orchestrates the sweep by:
328/// 1. Iterating through the sweep plan
329/// 2. For each step:
330///    a. Validate that the canonical center is a single node within the extracted subtree
331///    b. Extract the local subtree
332///    c. Call the updater to transform it
333///    d. Replace the subtree back into the TreeTN
334///
335/// # Arguments
336/// * `treetn` - The TreeTN to update (modified in place)
337/// * `plan` - The sweep plan specifying the update order
338/// * `updater` - The local updater implementation
339///
340/// # Preconditions
341/// - The TreeTN must be canonicalized with a single-node canonical center
342/// - The canonical center must be within the first step's nodes
343///
344/// # Returns
345/// `Ok(())` if the sweep completes successfully.
346///
347/// # Errors
348/// Returns an error if:
349/// - TreeTN is not canonicalized (canonical_region is empty)
350/// - canonical_region is not a single node
351/// - canonical_region is not within the extracted subtree
352/// - Subtree extraction fails
353/// - The updater returns an error
354/// - Subtree replacement fails
355pub fn apply_local_update_sweep<T, V, U>(
356    treetn: &mut TreeTN<T, V>,
357    plan: &LocalUpdateSweepPlan<V>,
358    updater: &mut U,
359) -> Result<()>
360where
361    T: TensorLike,
362    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
363    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
364    U: LocalUpdater<T, V>,
365{
366    for step in plan.iter() {
367        // Validate: canonical_region must be a single node within the step's nodes
368        let canonical_region = treetn.canonical_region();
369        if canonical_region.is_empty() {
370            return Err(anyhow::anyhow!(
371                "TreeTN is not canonicalized: canonical_region is empty"
372            ))
373            .context("apply_local_update_sweep: TreeTN must be canonicalized before sweep");
374        }
375        if canonical_region.len() != 1 {
376            return Err(anyhow::anyhow!(
377                "canonical_region must be a single node, got {} nodes",
378                canonical_region.len()
379            ))
380            .context("apply_local_update_sweep: canonical_region must be a single node");
381        }
382        let center_node = canonical_region.iter().next().unwrap();
383        let step_nodes_set: HashSet<V> = step.nodes.iter().cloned().collect();
384        if !step_nodes_set.contains(center_node) {
385            return Err(anyhow::anyhow!(
386                "canonical_region {:?} is not within the extracted subtree {:?}",
387                center_node,
388                step.nodes
389            ))
390            .context(
391                "apply_local_update_sweep: canonical_region must be within extracted subtree",
392            );
393        }
394
395        updater
396            .before_step(step, treetn)
397            .context("apply_local_update_sweep: LocalUpdater::before_step failed")?;
398
399        // Extract subtree for the nodes in this step
400        let subtree = treetn.extract_subtree(&step.nodes)?;
401
402        // Apply the update
403        let updated_subtree = updater.update(subtree, step, treetn)?;
404
405        // Replace the subtree back
406        treetn.replace_subtree(&step.nodes, &updated_subtree)?;
407
408        // Update canonical center
409        treetn.set_canonical_region([step.new_center.clone()])?;
410
411        updater
412            .after_step(step, treetn)
413            .context("apply_local_update_sweep: LocalUpdater::after_step failed")?;
414    }
415
416    Ok(())
417}
418
419// ============================================================================
420// TruncateUpdater - LocalUpdater implementation for truncation
421// ============================================================================
422
423use tensor4all_core::{Canonical, FactorizeOptions, SvdTruncationPolicy};
424
425/// Truncation updater for nsite=2 sweeps.
426///
427/// This updater performs SVD-based truncation on two-site subtrees,
428/// compressing bond dimensions while preserving the tensor train structure.
429///
430/// # Algorithm
431/// For each step with nodes [A, B] where B is the new center:
432/// 1. Contract tensors A and B into a single tensor AB
433/// 2. Factorize AB using SVD with truncation (left indices = A's external + bond to A's other neighbors)
434/// 3. The left tensor becomes the new A, the right tensor becomes the new B
435/// 4. B is the orthogonality center (isometry pointing towards B)
436///
437/// # Usage
438/// ```
439/// use tensor4all_core::{DynIndex, TensorDynLen};
440/// use tensor4all_treetn::{apply_local_update_sweep, LocalUpdateSweepPlan, TreeTN, TruncateUpdater};
441///
442/// # fn main() -> anyhow::Result<()> {
443/// let s0 = DynIndex::new_dyn(2);
444/// let bond = DynIndex::new_dyn(1);
445/// let s1 = DynIndex::new_dyn(2);
446/// let t0 = TensorDynLen::from_dense(vec![s0, bond.clone()], vec![1.0, 0.0])?;
447/// let t1 = TensorDynLen::from_dense(vec![bond, s1], vec![1.0, 0.0])?;
448/// let mut treetn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1])?;
449/// treetn.canonicalize_mut(std::iter::once(0usize), Default::default())?;
450///
451/// let plan = LocalUpdateSweepPlan::from_treetn(&treetn, &0usize, 2).unwrap();
452/// let mut updater = TruncateUpdater::new(
453///     Some(4),
454///     Some(tensor4all_core::SvdTruncationPolicy::new(1e-10)),
455/// );
456/// apply_local_update_sweep(&mut treetn, &plan, &mut updater)?;
457///
458/// assert_eq!(treetn.node_count(), 2);
459/// # Ok(())
460/// # }
461/// ```
462#[derive(Debug, Clone)]
463pub struct TruncateUpdater {
464    /// Maximum bond dimension after truncation.
465    pub max_rank: Option<usize>,
466    /// Explicit SVD truncation policy.
467    pub svd_policy: Option<SvdTruncationPolicy>,
468}
469
470impl TruncateUpdater {
471    /// Create a new truncation updater.
472    ///
473    /// # Arguments
474    /// * `max_rank` - Maximum bond dimension (None for no limit)
475    /// * `svd_policy` - SVD truncation policy override (None uses the global default)
476    pub fn new(max_rank: Option<usize>, svd_policy: Option<SvdTruncationPolicy>) -> Self {
477        Self {
478            max_rank,
479            svd_policy,
480        }
481    }
482}
483
484impl<T, V> LocalUpdater<T, V> for TruncateUpdater
485where
486    T: TensorLike,
487    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
488    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
489{
490    fn update(
491        &mut self,
492        mut subtree: TreeTN<T, V>,
493        step: &LocalUpdateStep<V>,
494        _full_treetn: &TreeTN<T, V>,
495    ) -> Result<TreeTN<T, V>> {
496        // TruncateUpdater is designed for nsite=2
497        if step.nodes.len() != 2 {
498            return Err(anyhow::anyhow!(
499                "TruncateUpdater requires exactly 2 nodes, got {}",
500                step.nodes.len()
501            ));
502        }
503
504        let node_a = &step.nodes[0];
505        let node_b = &step.nodes[1];
506
507        // Get node indices
508        let idx_a = subtree
509            .node_index(node_a)
510            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node_a))?;
511        let idx_b = subtree
512            .node_index(node_b)
513            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node_b))?;
514
515        // Get the bond between A and B
516        let edge_ab = subtree
517            .edge_between(node_a, node_b)
518            .ok_or_else(|| anyhow::anyhow!("No edge between {:?} and {:?}", node_a, node_b))?;
519        let bond_ab = subtree
520            .bond_index(edge_ab)
521            .ok_or_else(|| anyhow::anyhow!("Bond index not found"))?
522            .clone();
523
524        // Contract A and B
525        let tensor_a = subtree.tensor(idx_a).unwrap();
526        let tensor_b = subtree.tensor(idx_b).unwrap();
527        let tensor_ab = T::contract(&[tensor_a, tensor_b], AllowedPairs::All)
528            .context("Failed to contract A and B")?;
529
530        // Determine left indices (indices that will remain on A after factorization)
531        // These are: all indices of A except the bond to B
532        let left_inds: Vec<_> = tensor_a
533            .external_indices()
534            .iter()
535            .filter(|idx| idx.id() != bond_ab.id())
536            .cloned()
537            .collect();
538
539        // Set up factorization options
540        let mut options = FactorizeOptions::svd().with_canonical(Canonical::Left); // Left canonical: A is isometry, B has the norm
541
542        if let Some(max_rank) = self.max_rank {
543            options = options.with_max_rank(max_rank);
544        }
545        if let Some(policy) = self.svd_policy {
546            options = options.with_svd_policy(policy);
547        }
548
549        // Factorize
550        let factorize_result = tensor_ab
551            .factorize(&left_inds, &options)
552            .map_err(|e| anyhow::anyhow!("Factorization failed: {}", e))?;
553
554        let new_tensor_a = factorize_result.left;
555        let new_tensor_b = factorize_result.right;
556        let new_bond = factorize_result.bond_index;
557
558        // Update the subtree - first update the edge bond, then the tensors
559        // The factorize result creates a new bond index, so we update the edge to use it
560        subtree.replace_edge_bond(edge_ab, new_bond.clone())?;
561        subtree.replace_tensor(idx_a, new_tensor_a)?;
562        subtree.replace_tensor(idx_b, new_tensor_b)?;
563
564        // Set ortho_towards: bond points towards new_center (B)
565        subtree.set_ortho_towards(&new_bond, Some(step.new_center.clone()));
566
567        // Set canonical center to the new center
568        subtree.set_canonical_region([step.new_center.clone()])?;
569
570        Ok(subtree)
571    }
572}
573
574// ============================================================================
575// Sub-tree extraction
576// ============================================================================
577
578impl<T, V> TreeTN<T, V>
579where
580    T: TensorLike,
581    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
582{
583    /// Extract a sub-tree from this TreeTN.
584    ///
585    /// Creates a new TreeTN containing only the specified nodes and their
586    /// connecting edges. Tensors are cloned into the new TreeTN.
587    ///
588    /// # Arguments
589    /// * `node_names` - The names of nodes to include in the sub-tree
590    ///
591    /// # Returns
592    /// A new TreeTN containing the specified sub-tree, or an error if:
593    /// - Any specified node doesn't exist
594    /// - The specified nodes don't form a connected subtree
595    ///
596    /// # Notes
597    /// - Bond indices between included nodes are preserved
598    /// - Bond indices to excluded nodes become external (site) indices in the sub-tree
599    /// - ortho_towards directions are copied for edges within the sub-tree
600    /// - canonical_region is intersected with the extracted nodes
601    pub fn extract_subtree(&self, node_names: &[V]) -> Result<Self>
602    where
603        <T::Index as IndexLike>::Id:
604            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
605        V: Ord,
606    {
607        if node_names.is_empty() {
608            return Err(anyhow::anyhow!("Cannot extract empty subtree"));
609        }
610
611        // Validate all nodes exist
612        for name in node_names {
613            if self.graph.node_index(name).is_none() {
614                return Err(anyhow::anyhow!("Node {:?} does not exist", name))
615                    .context("extract_subtree: invalid node name");
616            }
617        }
618
619        // Check connectivity: the specified nodes must form a connected subtree
620        let node_indices: HashSet<_> = node_names
621            .iter()
622            .filter_map(|n| self.graph.node_index(n))
623            .collect();
624
625        if !self.site_index_network.is_connected_subset(&node_indices) {
626            return Err(anyhow::anyhow!(
627                "Specified nodes do not form a connected subtree"
628            ))
629            .context("extract_subtree: nodes must be connected");
630        }
631
632        let node_name_set: HashSet<V> = node_names.iter().cloned().collect();
633
634        // Create new TreeTN with extracted tensors
635        let mut subtree = TreeTN::<T, V>::new();
636
637        // Step 1: Add all nodes with their tensors
638        for name in node_names {
639            let node_idx = self.graph.node_index(name).unwrap();
640            let tensor = self
641                .tensor(node_idx)
642                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", name))?
643                .clone();
644
645            subtree
646                .add_tensor(name.clone(), tensor)
647                .context("extract_subtree: failed to add tensor")?;
648        }
649
650        // Step 2: Add edges between nodes in the subtree
651        // Track which edges we've already added to avoid duplicates
652        let mut added_edges: HashSet<(V, V)> = HashSet::new();
653
654        for name in node_names {
655            let neighbors: Vec<V> = self.site_index_network.neighbors(name).collect();
656
657            for neighbor in neighbors {
658                // Only add edge if neighbor is also in the subtree
659                if !node_name_set.contains(&neighbor) {
660                    continue;
661                }
662
663                // Avoid adding the same edge twice (undirected)
664                let edge_key = if *name < neighbor {
665                    (name.clone(), neighbor.clone())
666                } else {
667                    (neighbor.clone(), name.clone())
668                };
669
670                if added_edges.contains(&edge_key) {
671                    continue;
672                }
673                added_edges.insert(edge_key);
674
675                // Get bond index from original TreeTN
676                let orig_edge = self.edge_between(name, &neighbor).ok_or_else(|| {
677                    anyhow::anyhow!("Edge not found between {:?} and {:?}", name, neighbor)
678                })?;
679
680                let bond_index = self
681                    .bond_index(orig_edge)
682                    .ok_or_else(|| anyhow::anyhow!("Bond index not found"))?
683                    .clone();
684
685                // Get node indices in new subtree
686                let subtree_node_a = subtree
687                    .graph
688                    .node_index(name)
689                    .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", name))?;
690                let subtree_node_b = subtree
691                    .graph
692                    .node_index(&neighbor)
693                    .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", neighbor))?;
694
695                // Connect in subtree
696                subtree
697                    .connect(subtree_node_a, &bond_index, subtree_node_b, &bond_index)
698                    .context("extract_subtree: failed to connect nodes")?;
699
700                // Copy ortho_towards if it exists (keyed by full bond index)
701                if let Some(ortho_dir) = self.ortho_towards.get(&bond_index) {
702                    // Only copy if the direction node is in the subtree
703                    if node_name_set.contains(ortho_dir) {
704                        subtree
705                            .ortho_towards
706                            .insert(bond_index.clone(), ortho_dir.clone());
707                    }
708                }
709            }
710        }
711
712        // Step 3: Set canonical_region to intersection with extracted nodes
713        let new_center: HashSet<V> = self
714            .canonical_region
715            .intersection(&node_name_set)
716            .cloned()
717            .collect();
718        subtree.canonical_region = new_center;
719
720        // Copy canonical_form if any center nodes were included
721        if !subtree.canonical_region.is_empty() {
722            subtree.canonical_form = self.canonical_form;
723        }
724
725        Ok(subtree)
726    }
727}
728
729// ============================================================================
730// Sub-tree replacement
731// ============================================================================
732
733impl<T, V> TreeTN<T, V>
734where
735    T: TensorLike,
736    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
737{
738    /// Replace a sub-tree with another TreeTN of the same topology.
739    ///
740    /// This method replaces the tensors and ortho_towards directions for a subset
741    /// of nodes with those from another TreeTN. The replacement TreeTN must have
742    /// the same topology (nodes and edges) as the sub-tree being replaced.
743    ///
744    /// # Arguments
745    /// * `node_names` - The names of nodes to replace
746    /// * `replacement` - The TreeTN to use as replacement
747    ///
748    /// # Returns
749    /// `Ok(())` if the replacement succeeds, or an error if:
750    /// - Any specified node doesn't exist
751    /// - The replacement doesn't have the same topology as the extracted sub-tree
752    /// - Tensor replacement fails
753    ///
754    /// # Notes
755    /// - The replacement TreeTN must have the same nodes, edges, and site indices
756    /// - Bond dimensions may differ (this is the typical use case for truncation)
757    /// - ortho_towards may differ (will be copied from replacement)
758    /// - The original TreeTN is modified in-place
759    pub fn replace_subtree(&mut self, node_names: &[V], replacement: &Self) -> Result<()>
760    where
761        <T::Index as IndexLike>::Id:
762            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
763        V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
764    {
765        if node_names.is_empty() {
766            return Ok(()); // Nothing to replace
767        }
768
769        // Extract current subtree for comparison
770        let current_subtree = self.extract_subtree(node_names)?;
771
772        // Verify that replacement has the same topology (nodes and edges)
773        // Note: site index network may differ due to bond dimension changes in truncation
774        if !current_subtree.same_topology(replacement) {
775            return Err(anyhow::anyhow!(
776                "Replacement TreeTN does not have the same topology as the current subtree"
777            ))
778            .context("replace_subtree: topology mismatch");
779        }
780
781        let node_name_set: HashSet<V> = node_names.iter().cloned().collect();
782        let mut processed_edges: HashSet<(V, V)> = HashSet::new();
783
784        // Step 1: Update edge bond indices FIRST (before replacing tensors)
785        // This is crucial because replace_tensor validates that tensors contain connection indices
786        for name in node_names {
787            let neighbors: Vec<V> = self.site_index_network.neighbors(name).collect();
788
789            for neighbor in neighbors {
790                // Only process edges within the subtree
791                if !node_name_set.contains(&neighbor) {
792                    continue;
793                }
794
795                let edge_key = if *name < neighbor {
796                    (name.clone(), neighbor.clone())
797                } else {
798                    (neighbor.clone(), name.clone())
799                };
800
801                if processed_edges.contains(&edge_key) {
802                    continue;
803                }
804                processed_edges.insert(edge_key.clone());
805
806                // Get edges in both self and replacement
807                let self_edge = self
808                    .edge_between(name, &neighbor)
809                    .ok_or_else(|| anyhow::anyhow!("Edge not found in self"))?;
810                let replacement_edge = replacement
811                    .edge_between(name, &neighbor)
812                    .ok_or_else(|| anyhow::anyhow!("Edge not found in replacement"))?;
813
814                // Get new bond index from replacement
815                let new_bond = replacement
816                    .bond_index(replacement_edge)
817                    .ok_or_else(|| anyhow::anyhow!("Bond index not found in replacement"))?
818                    .clone();
819
820                // Update bond index in self
821                self.replace_edge_bond(self_edge, new_bond.clone())
822                    .with_context(|| {
823                        format!(
824                            "replace_subtree: failed to update bond between {:?} and {:?}",
825                            name, neighbor
826                        )
827                    })?;
828
829                // Copy ortho_towards from replacement (using the new bond)
830                match replacement.ortho_towards.get(&new_bond) {
831                    Some(dir) => {
832                        self.ortho_towards.insert(new_bond, dir.clone());
833                    }
834                    None => {
835                        self.ortho_towards.remove(&new_bond);
836                    }
837                }
838            }
839        }
840
841        // Step 2: Replace tensors (now bond indices match)
842        for name in node_names {
843            let self_node_idx = self
844                .graph
845                .node_index(name)
846                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", name))?;
847            let replacement_node_idx = replacement
848                .graph
849                .node_index(name)
850                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in replacement", name))?;
851
852            let new_tensor = replacement
853                .tensor(replacement_node_idx)
854                .ok_or_else(|| {
855                    anyhow::anyhow!("Tensor not found for node {:?} in replacement", name)
856                })?
857                .clone();
858
859            self.replace_tensor(self_node_idx, new_tensor)
860                .with_context(|| {
861                    format!(
862                        "replace_subtree: failed to replace tensor at node {:?}",
863                        name
864                    )
865                })?;
866        }
867
868        // Update canonical_region: remove old nodes, add from replacement
869        for name in node_names {
870            self.canonical_region.remove(name);
871        }
872        for name in &replacement.canonical_region {
873            if node_name_set.contains(name) {
874                self.canonical_region.insert(name.clone());
875            }
876        }
877
878        // Update canonical_form if replacement has one
879        if replacement.canonical_form.is_some() {
880            self.canonical_form = replacement.canonical_form;
881        }
882
883        Ok(())
884    }
885}
886
887#[cfg(test)]
888mod tests;