tensor4all_treetn/treetn/truncate.rs
1//! Truncation methods for TreeTN.
2//!
3//! This module provides methods for truncating tree tensor networks.
4//!
5//! The truncation algorithm uses a two-site sweep (nsite=2) based on Euler tour traversal:
6//! 1. First, canonicalize the network towards the specified center
7//! 2. Generate a sweep plan that visits each edge twice (forward and backward)
8//! 3. For each step: extract two adjacent nodes, perform SVD-based truncation, and write the truncated tensors back into the network
9//! 4. The canonical center moves along the sweep path
10//!
11//! Because every bond is visited in both directions, each bond is truncated while the canonical center sits on its two adjacent nodes.
12
13use std::collections::HashSet;
14use std::hash::Hash;
15
16use anyhow::{Context, Result};
17
18use tensor4all_core::SvdTruncationPolicy;
19use tensor4all_core::{IndexLike, TensorLike};
20
21use super::localupdate::{apply_local_update_sweep, LocalUpdateSweepPlan, TruncateUpdater};
22use super::TreeTN;
23use crate::options::{CanonicalizationOptions, TruncationOptions};
24use crate::CanonicalForm;
25
26impl<T, V> TreeTN<T, V>
27where
28 T: TensorLike,
29 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
30{
31 /// Truncate the network towards the specified center using options.
32 ///
33 /// This is the recommended unified API for truncation. It accepts:
34 /// - Center nodes specified by their node names (V)
35 /// - [`TruncationOptions`] to control the SVD policy and max_rank
36 ///
37 /// # Algorithm
38 /// 1. Canonicalize the network towards the center (required for truncation)
39 /// 2. Generate a two-site sweep plan using Euler tour traversal
40 /// 3. Apply SVD-based truncation at each step, visiting each edge twice
41 ///
42 /// # Examples
43 ///
44 /// ```
45 /// use tensor4all_treetn::{TreeTN, TruncationOptions};
46 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
47 ///
48 /// // Build a 2-node chain
49 /// let s0 = DynIndex::new_dyn(2);
50 /// let bond = DynIndex::new_dyn(3);
51 /// let s1 = DynIndex::new_dyn(2);
52 ///
53 /// let t0 = TensorDynLen::from_dense(
54 /// vec![s0.clone(), bond.clone()],
55 /// vec![1.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0],
56 /// ).unwrap();
57 /// let t1 = TensorDynLen::from_dense(
58 /// vec![bond.clone(), s1.clone()],
59 /// vec![1.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0],
60 /// ).unwrap();
61 ///
62 /// let tn = TreeTN::<_, String>::from_tensors(
63 /// vec![t0, t1],
64 /// vec!["A".to_string(), "B".to_string()],
65 /// ).unwrap();
66 ///
67 /// // Truncate with max rank 2 towards node "A"
68 /// let tn = tn.truncate(
69 /// ["A".to_string()],
70 /// TruncationOptions::default().with_max_rank(2),
71 /// ).unwrap();
72 ///
73 /// // Bond dimension is now at most 2
74 /// assert_eq!(tn.node_count(), 2);
75 /// ```
76 pub fn truncate(
77 mut self,
78 canonical_region: impl IntoIterator<Item = V>,
79 options: TruncationOptions,
80 ) -> Result<Self>
81 where
82 V: Ord,
83 <T::Index as IndexLike>::Id: Ord,
84 {
85 self.truncate_impl(
86 canonical_region,
87 options.svd_policy(),
88 options.truncation.max_rank,
89 "truncate",
90 )?;
91 Ok(self)
92 }
93
94 /// Truncate the network in-place towards the specified center using options.
95 ///
96 /// This is the `&mut self` version of [`Self::truncate`].
97 pub fn truncate_mut(
98 &mut self,
99 canonical_region: impl IntoIterator<Item = V>,
100 options: TruncationOptions,
101 ) -> Result<()>
102 where
103 V: Ord,
104 <T::Index as IndexLike>::Id: Ord,
105 {
106 self.truncate_impl(
107 canonical_region,
108 options.svd_policy(),
109 options.truncation.max_rank,
110 "truncate_mut",
111 )
112 }
113
114 /// Internal implementation for truncation.
115 ///
116 /// Uses LocalUpdateSweepPlan with TruncateUpdater for full two-site sweeps.
117 pub(crate) fn truncate_impl(
118 &mut self,
119 canonical_region: impl IntoIterator<Item = V>,
120 svd_policy: Option<SvdTruncationPolicy>,
121 max_rank: Option<usize>,
122 context_name: &str,
123 ) -> Result<()>
124 where
125 V: Ord,
126 <T::Index as IndexLike>::Id: Ord,
127 {
128 // Collect center nodes
129 let center_nodes: HashSet<V> = canonical_region.into_iter().collect();
130
131 if center_nodes.is_empty() {
132 return Ok(()); // Nothing to do
133 }
134
135 // Currently only single-node center is supported for truncation
136 if center_nodes.len() != 1 {
137 return Err(anyhow::anyhow!(
138 "truncate currently requires a single-node center, got {} nodes",
139 center_nodes.len()
140 ))
141 .context(format!("{}: multi-node center not supported", context_name));
142 }
143
144 let center_node = center_nodes.iter().next().unwrap().clone();
145
146 // Step 1: Canonicalize towards the center (required before truncation sweep)
147 let canonicalize_options =
148 CanonicalizationOptions::default().with_form(CanonicalForm::Unitary);
149 self.canonicalize_impl(
150 [center_node.clone()],
151 canonicalize_options.form,
152 &format!("{}: pre-canonicalize", context_name),
153 )?;
154
155 // Step 2: Generate sweep plan (nsite=2 for two-site truncation)
156 let plan = LocalUpdateSweepPlan::from_treetn(self, ¢er_node, 2)
157 .ok_or_else(|| {
158 anyhow::anyhow!("Failed to create sweep plan from center {:?}", center_node)
159 })
160 .context(format!("{}: sweep plan creation failed", context_name))?;
161
162 // If no steps (single node network), nothing more to do
163 if plan.is_empty() {
164 return Ok(());
165 }
166
167 // Step 3: Apply truncation sweep
168 let mut updater = TruncateUpdater::new(max_rank, svd_policy);
169 apply_local_update_sweep(self, &plan, &mut updater)
170 .context(format!("{}: truncation sweep failed", context_name))?;
171
172 // The canonical form is maintained by the sweep
173 self.canonical_form = Some(CanonicalForm::Unitary);
174
175 Ok(())
176 }
177}