tensor4all_treetn/treetn/localupdate.rs
1//! Local update operations for TreeTN.
2//!
3//! This module provides APIs for:
4//! - Extracting a sub-tree from a TreeTN (creating a new TreeTN object)
5//! - Replacing a sub-tree with another TreeTN of the same appearance
6//! - Generating sweep plans for local update algorithms (truncation, fitting)
7//!
8//! These operations are fundamental for local update algorithms in tensor networks.
9
10use std::collections::HashSet;
11use std::fmt::Debug;
12use std::hash::Hash;
13
14use anyhow::{Context, Result};
15
16use tensor4all_core::{AllowedPairs, IndexLike, TensorLike};
17
18use super::TreeTN;
19use crate::node_name_network::NodeNameNetwork;
20
21// ============================================================================
22// Local Update Sweep Plan
23// ============================================================================
24
/// One unit of work inside a local update sweep.
///
/// A step records two things:
/// - the nodes that are pulled out of the network for the update, and
/// - the node that holds the canonical center once the step has finished.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LocalUpdateStep<V> {
    /// Names of the nodes extracted for this update.
    ///
    /// - `nsite = 1`: a single node.
    /// - `nsite = 2`: two adjacent nodes (the endpoints of an edge).
    pub nodes: Vec<V>,

    /// Node that is the canonical center after this step.
    ///
    /// - `nsite = 1`: the same node as `nodes[0]`.
    /// - `nsite = 2`: the node lying in the direction of the next step.
    pub new_center: V,
}
42
43/// A complete sweep plan for local updates.
44///
45/// Generated from an Euler tour, this plan specifies the sequence of
46/// local update operations for algorithms like truncation and fitting.
47///
48/// # Sweep Direction
49/// The sweep follows an Euler tour starting from the root, visiting each
50/// edge twice (forward and backward). This ensures all bonds are updated
51/// in both directions, which is essential for algorithms like DMRG/TEBD.
52///
53/// # nsite Parameter
54/// - `nsite=1`: Single-site updates. Each step extracts one node.
55/// - `nsite=2`: Two-site updates. Each step extracts two adjacent nodes (an edge).
56///
57/// Two-site updates are more expensive but can change bond dimensions and
58/// are necessary for algorithms like TDVP-2 or two-site DMRG.
59#[derive(Debug, Clone)]
60pub struct LocalUpdateSweepPlan<V> {
61 /// The sequence of update steps.
62 pub steps: Vec<LocalUpdateStep<V>>,
63
64 /// Number of sites per update (1 or 2).
65 pub nsite: usize,
66}
67
68impl<V> LocalUpdateSweepPlan<V>
69where
70 V: Clone + Hash + Eq + Send + Sync + Debug,
71{
72 /// Generate a sweep plan from a TreeTN's topology.
73 ///
74 /// Convenience method that extracts the NodeNameNetwork topology from a TreeTN.
75 pub fn from_treetn<T>(treetn: &TreeTN<T, V>, root: &V, nsite: usize) -> Option<Self>
76 where
77 T: TensorLike,
78 {
79 Self::new(treetn.site_index_network().topology(), root, nsite)
80 }
81
82 /// Generate a sweep plan from a NodeNameNetwork.
83 ///
84 /// Uses Euler tour traversal to visit all edges in both directions.
85 ///
86 /// # Arguments
87 /// * `network` - The network topology
88 /// * `root` - The starting node for the sweep
89 /// * `nsite` - Number of sites per update (1 or 2)
90 ///
91 /// # Returns
92 /// A sweep plan, or `None` if the root doesn't exist or nsite is invalid.
93 ///
94 /// # Example
95 /// For nsite=1 on chain A-B-C with root B:
96 /// - Euler tour vertices: [B, A, B, C, B]
97 /// - Steps: [(B, B), (A, A), (B, B), (C, C)] (each vertex except last)
98 ///
99 /// For nsite=2 on chain A-B-C with root B:
100 /// - Euler tour edges: [(B,A), (A,B), (B,C), (C,B)]
101 /// - Steps: [({B,A}, A), ({A,B}, B), ({B,C}, C), ({C,B}, B)]
102 pub fn new(network: &NodeNameNetwork<V>, root: &V, nsite: usize) -> Option<Self> {
103 if nsite != 1 && nsite != 2 {
104 return None;
105 }
106
107 let root_idx = network.node_index(root)?;
108
109 match nsite {
110 1 => {
111 // nsite=1: Use vertex sequence from Euler tour
112 let vertices = network.euler_tour_vertices_by_index(root_idx);
113 if vertices.is_empty() {
114 return Some(Self::empty(nsite));
115 }
116
117 // Each vertex (except the last return to root) is a step
118 // The new_center is the current vertex itself
119 let steps: Vec<_> = vertices
120 .iter()
121 .take(vertices.len().saturating_sub(1))
122 .filter_map(|&v| {
123 let name = network.node_name(v)?.clone();
124 Some(LocalUpdateStep {
125 nodes: vec![name.clone()],
126 new_center: name,
127 })
128 })
129 .collect();
130
131 Some(Self { steps, nsite })
132 }
133 2 => {
134 // nsite=2: Use edge sequence from Euler tour
135 let edges = network.euler_tour_edges_by_index(root_idx);
136 if edges.is_empty() {
137 // Single node: no edges to update
138 return Some(Self::empty(nsite));
139 }
140
141 // Each edge (u, v) becomes a step with nodes [u, v]
142 // The new_center is v (the direction we're moving)
143 let steps: Vec<_> = edges
144 .iter()
145 .filter_map(|&(u, v)| {
146 let u_name = network.node_name(u)?.clone();
147 let v_name = network.node_name(v)?.clone();
148 Some(LocalUpdateStep {
149 nodes: vec![u_name, v_name.clone()],
150 new_center: v_name,
151 })
152 })
153 .collect();
154
155 Some(Self { steps, nsite })
156 }
157 _ => None,
158 }
159 }
160
161 /// Create an empty sweep plan.
162 pub fn empty(nsite: usize) -> Self {
163 Self {
164 steps: Vec::new(),
165 nsite,
166 }
167 }
168
169 /// Check if the plan is empty.
170 pub fn is_empty(&self) -> bool {
171 self.steps.is_empty()
172 }
173
174 /// Number of update steps.
175 pub fn len(&self) -> usize {
176 self.steps.len()
177 }
178
179 /// Iterate over the steps.
180 pub fn iter(&self) -> impl Iterator<Item = &LocalUpdateStep<V>> {
181 self.steps.iter()
182 }
183}
184
185// ============================================================================
186// Boundary edge/bond utilities
187// ============================================================================
188
189/// Boundary edge information: (node_in_region, neighbor_outside, bond_index).
190///
191/// Represents an edge connecting a node inside the region to a neighbor outside the region.
192#[derive(Debug, Clone)]
193pub struct BoundaryEdge<T, V>
194where
195 T: TensorLike,
196 V: Clone + Hash + Eq,
197{
198 /// Node inside the region
199 pub node_in_region: V,
200 /// Neighbor outside the region
201 pub neighbor_outside: V,
202 /// Bond index connecting node_in_region to neighbor_outside
203 pub bond_index: T::Index,
204}
205
206/// Get all boundary edges for a given region in a TreeTN.
207///
208/// Returns edges connecting nodes inside the region to neighbors outside the region.
209/// This is useful for maintaining stable bond IDs across updates (e.g., for environment cache consistency).
210///
211/// # Arguments
212/// * `treetn` - The TreeTN to analyze
213/// * `region` - Nodes that are inside the region
214///
215/// # Returns
216/// Vector of boundary edges, each containing the node in region, neighbor outside, and bond index.
217pub fn get_boundary_edges<T, V>(
218 treetn: &TreeTN<T, V>,
219 region: &[V],
220) -> Result<Vec<BoundaryEdge<T, V>>>
221where
222 T: TensorLike,
223 T::Index: IndexLike,
224 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
225{
226 let mut boundary_edges = Vec::new();
227 let region_set: HashSet<&V> = region.iter().collect();
228
229 for node in region {
230 for neighbor in treetn.site_index_network().neighbors(node) {
231 if !region_set.contains(&neighbor) {
232 // This is a boundary edge: node is in region, neighbor is outside
233 if let Some(edge) = treetn.edge_between(node, &neighbor) {
234 if let Some(bond) = treetn.bond_index(edge) {
235 boundary_edges.push(BoundaryEdge {
236 node_in_region: node.clone(),
237 neighbor_outside: neighbor.clone(),
238 bond_index: bond.clone(),
239 });
240 }
241 }
242 }
243 }
244 }
245
246 Ok(boundary_edges)
247}
248
249// ============================================================================
250// LocalUpdater trait
251// ============================================================================
252
253/// Trait for local update operations during a sweep.
254///
255/// Implementors of this trait provide the actual update logic that transforms
256/// a local subtree into an updated version. This allows different algorithms
257/// (truncation, fitting, DMRG, TDVP) to share the same sweep infrastructure.
258///
259/// # Type Parameters
260/// - `T`: Tensor type implementing TensorLike
261/// - `V`: Node name type
262///
263/// # Workflow
264/// During `apply_local_update_sweep`:
265/// 1. For each step in the sweep plan:
266/// a. Extract the subtree containing `step.nodes`
267/// b. Call `update()` with the extracted subtree and step info
268/// c. Replace the subtree in the original TreeTN with the updated one
269/// d. Update the canonical center to `step.new_center`
270pub trait LocalUpdater<T, V>
271where
272 T: TensorLike,
273 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
274{
275 /// Optional hook called before performing an update step.
276 ///
277 /// This is called with the full TreeTN state *before* the update is applied.
278 /// Implementors can use it to validate assumptions or prefetch/update caches.
279 fn before_step(
280 &mut self,
281 _step: &LocalUpdateStep<V>,
282 _full_treetn_before: &TreeTN<T, V>,
283 ) -> Result<()> {
284 Ok(())
285 }
286
287 /// Update a local subtree.
288 ///
289 /// # Arguments
290 /// * `subtree` - The extracted subtree to update
291 /// * `step` - The current step information (nodes and new_center)
292 /// * `full_treetn` - Reference to the full (global) TreeTN. This provides global context
293 /// (e.g., topology, neighbor relations, and index/bond metadata) that some update
294 /// algorithms may need. It may be unused by simple updaters.
295 ///
296 /// # Returns
297 /// The updated subtree, which must have the same "appearance" as the input
298 /// (same nodes, same external indices, same ortho_towards structure).
299 ///
300 /// # Errors
301 /// Returns an error if the update fails (e.g., SVD doesn't converge).
302 fn update(
303 &mut self,
304 subtree: TreeTN<T, V>,
305 step: &LocalUpdateStep<V>,
306 full_treetn: &TreeTN<T, V>,
307 ) -> Result<TreeTN<T, V>>;
308
309 /// Optional hook called after an update step has been applied to the full TreeTN.
310 ///
311 /// This is called after:
312 /// - The updated subtree has been inserted back into the full TreeTN
313 /// - The canonical center has been moved to `step.new_center`
314 ///
315 /// Implementors can use this to update caches that must see the post-update state.
316 fn after_step(
317 &mut self,
318 _step: &LocalUpdateStep<V>,
319 _full_treetn_after: &TreeTN<T, V>,
320 ) -> Result<()> {
321 Ok(())
322 }
323}
324
325/// Apply a local update sweep to a TreeTN.
326///
327/// This function orchestrates the sweep by:
328/// 1. Iterating through the sweep plan
329/// 2. For each step:
330/// a. Validate that the canonical center is a single node within the extracted subtree
331/// b. Extract the local subtree
332/// c. Call the updater to transform it
333/// d. Replace the subtree back into the TreeTN
334///
335/// # Arguments
336/// * `treetn` - The TreeTN to update (modified in place)
337/// * `plan` - The sweep plan specifying the update order
338/// * `updater` - The local updater implementation
339///
340/// # Preconditions
341/// - The TreeTN must be canonicalized with a single-node canonical center
342/// - The canonical center must be within the first step's nodes
343///
344/// # Returns
345/// `Ok(())` if the sweep completes successfully.
346///
347/// # Errors
348/// Returns an error if:
349/// - TreeTN is not canonicalized (canonical_region is empty)
350/// - canonical_region is not a single node
351/// - canonical_region is not within the extracted subtree
352/// - Subtree extraction fails
353/// - The updater returns an error
354/// - Subtree replacement fails
355pub fn apply_local_update_sweep<T, V, U>(
356 treetn: &mut TreeTN<T, V>,
357 plan: &LocalUpdateSweepPlan<V>,
358 updater: &mut U,
359) -> Result<()>
360where
361 T: TensorLike,
362 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
363 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
364 U: LocalUpdater<T, V>,
365{
366 for step in plan.iter() {
367 // Validate: canonical_region must be a single node within the step's nodes
368 let canonical_region = treetn.canonical_region();
369 if canonical_region.is_empty() {
370 return Err(anyhow::anyhow!(
371 "TreeTN is not canonicalized: canonical_region is empty"
372 ))
373 .context("apply_local_update_sweep: TreeTN must be canonicalized before sweep");
374 }
375 if canonical_region.len() != 1 {
376 return Err(anyhow::anyhow!(
377 "canonical_region must be a single node, got {} nodes",
378 canonical_region.len()
379 ))
380 .context("apply_local_update_sweep: canonical_region must be a single node");
381 }
382 let center_node = canonical_region.iter().next().unwrap();
383 let step_nodes_set: HashSet<V> = step.nodes.iter().cloned().collect();
384 if !step_nodes_set.contains(center_node) {
385 return Err(anyhow::anyhow!(
386 "canonical_region {:?} is not within the extracted subtree {:?}",
387 center_node,
388 step.nodes
389 ))
390 .context(
391 "apply_local_update_sweep: canonical_region must be within extracted subtree",
392 );
393 }
394
395 updater
396 .before_step(step, treetn)
397 .context("apply_local_update_sweep: LocalUpdater::before_step failed")?;
398
399 // Extract subtree for the nodes in this step
400 let subtree = treetn.extract_subtree(&step.nodes)?;
401
402 // Apply the update
403 let updated_subtree = updater.update(subtree, step, treetn)?;
404
405 // Replace the subtree back
406 treetn.replace_subtree(&step.nodes, &updated_subtree)?;
407
408 // Update canonical center
409 treetn.set_canonical_region([step.new_center.clone()])?;
410
411 updater
412 .after_step(step, treetn)
413 .context("apply_local_update_sweep: LocalUpdater::after_step failed")?;
414 }
415
416 Ok(())
417}
418
419// ============================================================================
420// TruncateUpdater - LocalUpdater implementation for truncation
421// ============================================================================
422
423use tensor4all_core::{Canonical, FactorizeOptions, SvdTruncationPolicy};
424
425/// Truncation updater for nsite=2 sweeps.
426///
427/// This updater performs SVD-based truncation on two-site subtrees,
428/// compressing bond dimensions while preserving the tensor train structure.
429///
430/// # Algorithm
431/// For each step with nodes [A, B] where B is the new center:
432/// 1. Contract tensors A and B into a single tensor AB
433/// 2. Factorize AB using SVD with truncation (left indices = A's external + bond to A's other neighbors)
434/// 3. The left tensor becomes the new A, the right tensor becomes the new B
435/// 4. B is the orthogonality center (isometry pointing towards B)
436///
437/// # Usage
438/// ```
439/// use tensor4all_core::{DynIndex, TensorDynLen};
440/// use tensor4all_treetn::{apply_local_update_sweep, LocalUpdateSweepPlan, TreeTN, TruncateUpdater};
441///
442/// # fn main() -> anyhow::Result<()> {
443/// let s0 = DynIndex::new_dyn(2);
444/// let bond = DynIndex::new_dyn(1);
445/// let s1 = DynIndex::new_dyn(2);
446/// let t0 = TensorDynLen::from_dense(vec![s0, bond.clone()], vec![1.0, 0.0])?;
447/// let t1 = TensorDynLen::from_dense(vec![bond, s1], vec![1.0, 0.0])?;
448/// let mut treetn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1])?;
449/// treetn.canonicalize_mut(std::iter::once(0usize), Default::default())?;
450///
451/// let plan = LocalUpdateSweepPlan::from_treetn(&treetn, &0usize, 2).unwrap();
452/// let mut updater = TruncateUpdater::new(
453/// Some(4),
454/// Some(tensor4all_core::SvdTruncationPolicy::new(1e-10)),
455/// );
456/// apply_local_update_sweep(&mut treetn, &plan, &mut updater)?;
457///
458/// assert_eq!(treetn.node_count(), 2);
459/// # Ok(())
460/// # }
461/// ```
462#[derive(Debug, Clone)]
463pub struct TruncateUpdater {
464 /// Maximum bond dimension after truncation.
465 pub max_rank: Option<usize>,
466 /// Explicit SVD truncation policy.
467 pub svd_policy: Option<SvdTruncationPolicy>,
468}
469
470impl TruncateUpdater {
471 /// Create a new truncation updater.
472 ///
473 /// # Arguments
474 /// * `max_rank` - Maximum bond dimension (None for no limit)
475 /// * `svd_policy` - SVD truncation policy override (None uses the global default)
476 pub fn new(max_rank: Option<usize>, svd_policy: Option<SvdTruncationPolicy>) -> Self {
477 Self {
478 max_rank,
479 svd_policy,
480 }
481 }
482}
483
484impl<T, V> LocalUpdater<T, V> for TruncateUpdater
485where
486 T: TensorLike,
487 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
488 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
489{
490 fn update(
491 &mut self,
492 mut subtree: TreeTN<T, V>,
493 step: &LocalUpdateStep<V>,
494 _full_treetn: &TreeTN<T, V>,
495 ) -> Result<TreeTN<T, V>> {
496 // TruncateUpdater is designed for nsite=2
497 if step.nodes.len() != 2 {
498 return Err(anyhow::anyhow!(
499 "TruncateUpdater requires exactly 2 nodes, got {}",
500 step.nodes.len()
501 ));
502 }
503
504 let node_a = &step.nodes[0];
505 let node_b = &step.nodes[1];
506
507 // Get node indices
508 let idx_a = subtree
509 .node_index(node_a)
510 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node_a))?;
511 let idx_b = subtree
512 .node_index(node_b)
513 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", node_b))?;
514
515 // Get the bond between A and B
516 let edge_ab = subtree
517 .edge_between(node_a, node_b)
518 .ok_or_else(|| anyhow::anyhow!("No edge between {:?} and {:?}", node_a, node_b))?;
519 let bond_ab = subtree
520 .bond_index(edge_ab)
521 .ok_or_else(|| anyhow::anyhow!("Bond index not found"))?
522 .clone();
523
524 // Contract A and B
525 let tensor_a = subtree.tensor(idx_a).unwrap();
526 let tensor_b = subtree.tensor(idx_b).unwrap();
527 let tensor_ab = T::contract(&[tensor_a, tensor_b], AllowedPairs::All)
528 .context("Failed to contract A and B")?;
529
530 // Determine left indices (indices that will remain on A after factorization)
531 // These are: all indices of A except the bond to B
532 let left_inds: Vec<_> = tensor_a
533 .external_indices()
534 .iter()
535 .filter(|idx| idx.id() != bond_ab.id())
536 .cloned()
537 .collect();
538
539 // Set up factorization options
540 let mut options = FactorizeOptions::svd().with_canonical(Canonical::Left); // Left canonical: A is isometry, B has the norm
541
542 if let Some(max_rank) = self.max_rank {
543 options = options.with_max_rank(max_rank);
544 }
545 if let Some(policy) = self.svd_policy {
546 options = options.with_svd_policy(policy);
547 }
548
549 // Factorize
550 let factorize_result = tensor_ab
551 .factorize(&left_inds, &options)
552 .map_err(|e| anyhow::anyhow!("Factorization failed: {}", e))?;
553
554 let new_tensor_a = factorize_result.left;
555 let new_tensor_b = factorize_result.right;
556 let new_bond = factorize_result.bond_index;
557
558 // Update the subtree - first update the edge bond, then the tensors
559 // The factorize result creates a new bond index, so we update the edge to use it
560 subtree.replace_edge_bond(edge_ab, new_bond.clone())?;
561 subtree.replace_tensor(idx_a, new_tensor_a)?;
562 subtree.replace_tensor(idx_b, new_tensor_b)?;
563
564 // Set ortho_towards: bond points towards new_center (B)
565 subtree.set_ortho_towards(&new_bond, Some(step.new_center.clone()));
566
567 // Set canonical center to the new center
568 subtree.set_canonical_region([step.new_center.clone()])?;
569
570 Ok(subtree)
571 }
572}
573
574// ============================================================================
575// Sub-tree extraction
576// ============================================================================
577
578impl<T, V> TreeTN<T, V>
579where
580 T: TensorLike,
581 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
582{
583 /// Extract a sub-tree from this TreeTN.
584 ///
585 /// Creates a new TreeTN containing only the specified nodes and their
586 /// connecting edges. Tensors are cloned into the new TreeTN.
587 ///
588 /// # Arguments
589 /// * `node_names` - The names of nodes to include in the sub-tree
590 ///
591 /// # Returns
592 /// A new TreeTN containing the specified sub-tree, or an error if:
593 /// - Any specified node doesn't exist
594 /// - The specified nodes don't form a connected subtree
595 ///
596 /// # Notes
597 /// - Bond indices between included nodes are preserved
598 /// - Bond indices to excluded nodes become external (site) indices in the sub-tree
599 /// - ortho_towards directions are copied for edges within the sub-tree
600 /// - canonical_region is intersected with the extracted nodes
601 pub fn extract_subtree(&self, node_names: &[V]) -> Result<Self>
602 where
603 <T::Index as IndexLike>::Id:
604 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
605 V: Ord,
606 {
607 if node_names.is_empty() {
608 return Err(anyhow::anyhow!("Cannot extract empty subtree"));
609 }
610
611 // Validate all nodes exist
612 for name in node_names {
613 if self.graph.node_index(name).is_none() {
614 return Err(anyhow::anyhow!("Node {:?} does not exist", name))
615 .context("extract_subtree: invalid node name");
616 }
617 }
618
619 // Check connectivity: the specified nodes must form a connected subtree
620 let node_indices: HashSet<_> = node_names
621 .iter()
622 .filter_map(|n| self.graph.node_index(n))
623 .collect();
624
625 if !self.site_index_network.is_connected_subset(&node_indices) {
626 return Err(anyhow::anyhow!(
627 "Specified nodes do not form a connected subtree"
628 ))
629 .context("extract_subtree: nodes must be connected");
630 }
631
632 let node_name_set: HashSet<V> = node_names.iter().cloned().collect();
633
634 // Create new TreeTN with extracted tensors
635 let mut subtree = TreeTN::<T, V>::new();
636
637 // Step 1: Add all nodes with their tensors
638 for name in node_names {
639 let node_idx = self.graph.node_index(name).unwrap();
640 let tensor = self
641 .tensor(node_idx)
642 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", name))?
643 .clone();
644
645 subtree
646 .add_tensor(name.clone(), tensor)
647 .context("extract_subtree: failed to add tensor")?;
648 }
649
650 // Step 2: Add edges between nodes in the subtree
651 // Track which edges we've already added to avoid duplicates
652 let mut added_edges: HashSet<(V, V)> = HashSet::new();
653
654 for name in node_names {
655 let neighbors: Vec<V> = self.site_index_network.neighbors(name).collect();
656
657 for neighbor in neighbors {
658 // Only add edge if neighbor is also in the subtree
659 if !node_name_set.contains(&neighbor) {
660 continue;
661 }
662
663 // Avoid adding the same edge twice (undirected)
664 let edge_key = if *name < neighbor {
665 (name.clone(), neighbor.clone())
666 } else {
667 (neighbor.clone(), name.clone())
668 };
669
670 if added_edges.contains(&edge_key) {
671 continue;
672 }
673 added_edges.insert(edge_key);
674
675 // Get bond index from original TreeTN
676 let orig_edge = self.edge_between(name, &neighbor).ok_or_else(|| {
677 anyhow::anyhow!("Edge not found between {:?} and {:?}", name, neighbor)
678 })?;
679
680 let bond_index = self
681 .bond_index(orig_edge)
682 .ok_or_else(|| anyhow::anyhow!("Bond index not found"))?
683 .clone();
684
685 // Get node indices in new subtree
686 let subtree_node_a = subtree
687 .graph
688 .node_index(name)
689 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", name))?;
690 let subtree_node_b = subtree
691 .graph
692 .node_index(&neighbor)
693 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in subtree", neighbor))?;
694
695 // Connect in subtree
696 subtree
697 .connect(subtree_node_a, &bond_index, subtree_node_b, &bond_index)
698 .context("extract_subtree: failed to connect nodes")?;
699
700 // Copy ortho_towards if it exists (keyed by full bond index)
701 if let Some(ortho_dir) = self.ortho_towards.get(&bond_index) {
702 // Only copy if the direction node is in the subtree
703 if node_name_set.contains(ortho_dir) {
704 subtree
705 .ortho_towards
706 .insert(bond_index.clone(), ortho_dir.clone());
707 }
708 }
709 }
710 }
711
712 // Step 3: Set canonical_region to intersection with extracted nodes
713 let new_center: HashSet<V> = self
714 .canonical_region
715 .intersection(&node_name_set)
716 .cloned()
717 .collect();
718 subtree.canonical_region = new_center;
719
720 // Copy canonical_form if any center nodes were included
721 if !subtree.canonical_region.is_empty() {
722 subtree.canonical_form = self.canonical_form;
723 }
724
725 Ok(subtree)
726 }
727}
728
729// ============================================================================
730// Sub-tree replacement
731// ============================================================================
732
733impl<T, V> TreeTN<T, V>
734where
735 T: TensorLike,
736 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
737{
738 /// Replace a sub-tree with another TreeTN of the same topology.
739 ///
740 /// This method replaces the tensors and ortho_towards directions for a subset
741 /// of nodes with those from another TreeTN. The replacement TreeTN must have
742 /// the same topology (nodes and edges) as the sub-tree being replaced.
743 ///
744 /// # Arguments
745 /// * `node_names` - The names of nodes to replace
746 /// * `replacement` - The TreeTN to use as replacement
747 ///
748 /// # Returns
749 /// `Ok(())` if the replacement succeeds, or an error if:
750 /// - Any specified node doesn't exist
751 /// - The replacement doesn't have the same topology as the extracted sub-tree
752 /// - Tensor replacement fails
753 ///
754 /// # Notes
755 /// - The replacement TreeTN must have the same nodes, edges, and site indices
756 /// - Bond dimensions may differ (this is the typical use case for truncation)
757 /// - ortho_towards may differ (will be copied from replacement)
758 /// - The original TreeTN is modified in-place
759 pub fn replace_subtree(&mut self, node_names: &[V], replacement: &Self) -> Result<()>
760 where
761 <T::Index as IndexLike>::Id:
762 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
763 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
764 {
765 if node_names.is_empty() {
766 return Ok(()); // Nothing to replace
767 }
768
769 // Extract current subtree for comparison
770 let current_subtree = self.extract_subtree(node_names)?;
771
772 // Verify that replacement has the same topology (nodes and edges)
773 // Note: site index network may differ due to bond dimension changes in truncation
774 if !current_subtree.same_topology(replacement) {
775 return Err(anyhow::anyhow!(
776 "Replacement TreeTN does not have the same topology as the current subtree"
777 ))
778 .context("replace_subtree: topology mismatch");
779 }
780
781 let node_name_set: HashSet<V> = node_names.iter().cloned().collect();
782 let mut processed_edges: HashSet<(V, V)> = HashSet::new();
783
784 // Step 1: Update edge bond indices FIRST (before replacing tensors)
785 // This is crucial because replace_tensor validates that tensors contain connection indices
786 for name in node_names {
787 let neighbors: Vec<V> = self.site_index_network.neighbors(name).collect();
788
789 for neighbor in neighbors {
790 // Only process edges within the subtree
791 if !node_name_set.contains(&neighbor) {
792 continue;
793 }
794
795 let edge_key = if *name < neighbor {
796 (name.clone(), neighbor.clone())
797 } else {
798 (neighbor.clone(), name.clone())
799 };
800
801 if processed_edges.contains(&edge_key) {
802 continue;
803 }
804 processed_edges.insert(edge_key.clone());
805
806 // Get edges in both self and replacement
807 let self_edge = self
808 .edge_between(name, &neighbor)
809 .ok_or_else(|| anyhow::anyhow!("Edge not found in self"))?;
810 let replacement_edge = replacement
811 .edge_between(name, &neighbor)
812 .ok_or_else(|| anyhow::anyhow!("Edge not found in replacement"))?;
813
814 // Get new bond index from replacement
815 let new_bond = replacement
816 .bond_index(replacement_edge)
817 .ok_or_else(|| anyhow::anyhow!("Bond index not found in replacement"))?
818 .clone();
819
820 // Update bond index in self
821 self.replace_edge_bond(self_edge, new_bond.clone())
822 .with_context(|| {
823 format!(
824 "replace_subtree: failed to update bond between {:?} and {:?}",
825 name, neighbor
826 )
827 })?;
828
829 // Copy ortho_towards from replacement (using the new bond)
830 match replacement.ortho_towards.get(&new_bond) {
831 Some(dir) => {
832 self.ortho_towards.insert(new_bond, dir.clone());
833 }
834 None => {
835 self.ortho_towards.remove(&new_bond);
836 }
837 }
838 }
839 }
840
841 // Step 2: Replace tensors (now bond indices match)
842 for name in node_names {
843 let self_node_idx = self
844 .graph
845 .node_index(name)
846 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", name))?;
847 let replacement_node_idx = replacement
848 .graph
849 .node_index(name)
850 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in replacement", name))?;
851
852 let new_tensor = replacement
853 .tensor(replacement_node_idx)
854 .ok_or_else(|| {
855 anyhow::anyhow!("Tensor not found for node {:?} in replacement", name)
856 })?
857 .clone();
858
859 self.replace_tensor(self_node_idx, new_tensor)
860 .with_context(|| {
861 format!(
862 "replace_subtree: failed to replace tensor at node {:?}",
863 name
864 )
865 })?;
866 }
867
868 // Update canonical_region: remove old nodes, add from replacement
869 for name in node_names {
870 self.canonical_region.remove(name);
871 }
872 for name in &replacement.canonical_region {
873 if node_name_set.contains(name) {
874 self.canonical_region.insert(name.clone());
875 }
876 }
877
878 // Update canonical_form if replacement has one
879 if replacement.canonical_form.is_some() {
880 self.canonical_form = replacement.canonical_form;
881 }
882
883 Ok(())
884 }
885}
886
887#[cfg(test)]
888mod tests;