Skip to main content

tensor4all_treetn/treetn/
truncate.rs

//! Truncation methods for TreeTN.
//!
//! This module provides methods for truncating tree tensor networks.
//!
//! The truncation algorithm uses a two-site sweep (nsite=2) based on Euler tour traversal:
//! 1. Canonicalize the network towards the specified center.
//! 2. Generate a sweep plan that visits each edge twice (forward and backward).
//! 3. At each step, extract the two adjacent nodes, truncate their shared bond with an SVD,
//!    and write the updated tensors back into the network.
//! 4. The canonical center moves along the sweep path.
//!
//! This ensures every bond is truncated in both sweep directions. Each step is a local
//! SVD truncation, so the overall result is not guaranteed to be globally optimal.
12
13use std::collections::HashSet;
14use std::hash::Hash;
15
16use anyhow::{Context, Result};
17
18use tensor4all_core::SvdTruncationPolicy;
19use tensor4all_core::{IndexLike, TensorLike};
20
21use super::localupdate::{apply_local_update_sweep, LocalUpdateSweepPlan, TruncateUpdater};
22use super::TreeTN;
23use crate::options::{CanonicalizationOptions, TruncationOptions};
24use crate::CanonicalForm;
25
26impl<T, V> TreeTN<T, V>
27where
28    T: TensorLike,
29    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
30{
31    /// Truncate the network towards the specified center using options.
32    ///
33    /// This is the recommended unified API for truncation. It accepts:
34    /// - Center nodes specified by their node names (V)
35    /// - [`TruncationOptions`] to control the SVD policy and max_rank
36    ///
37    /// # Algorithm
38    /// 1. Canonicalize the network towards the center (required for truncation)
39    /// 2. Generate a two-site sweep plan using Euler tour traversal
40    /// 3. Apply SVD-based truncation at each step, visiting each edge twice
41    ///
42    /// # Examples
43    ///
44    /// ```
45    /// use tensor4all_treetn::{TreeTN, TruncationOptions};
46    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
47    ///
48    /// // Build a 2-node chain
49    /// let s0 = DynIndex::new_dyn(2);
50    /// let bond = DynIndex::new_dyn(3);
51    /// let s1 = DynIndex::new_dyn(2);
52    ///
53    /// let t0 = TensorDynLen::from_dense(
54    ///     vec![s0.clone(), bond.clone()],
55    ///     vec![1.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0],
56    /// ).unwrap();
57    /// let t1 = TensorDynLen::from_dense(
58    ///     vec![bond.clone(), s1.clone()],
59    ///     vec![1.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0],
60    /// ).unwrap();
61    ///
62    /// let tn = TreeTN::<_, String>::from_tensors(
63    ///     vec![t0, t1],
64    ///     vec!["A".to_string(), "B".to_string()],
65    /// ).unwrap();
66    ///
67    /// // Truncate with max rank 2 towards node "A"
68    /// let tn = tn.truncate(
69    ///     ["A".to_string()],
70    ///     TruncationOptions::default().with_max_rank(2),
71    /// ).unwrap();
72    ///
73    /// // Bond dimension is now at most 2
74    /// assert_eq!(tn.node_count(), 2);
75    /// ```
76    pub fn truncate(
77        mut self,
78        canonical_region: impl IntoIterator<Item = V>,
79        options: TruncationOptions,
80    ) -> Result<Self>
81    where
82        V: Ord,
83        <T::Index as IndexLike>::Id: Ord,
84    {
85        self.truncate_impl(
86            canonical_region,
87            options.svd_policy(),
88            options.truncation.max_rank,
89            "truncate",
90        )?;
91        Ok(self)
92    }
93
94    /// Truncate the network in-place towards the specified center using options.
95    ///
96    /// This is the `&mut self` version of [`Self::truncate`].
97    pub fn truncate_mut(
98        &mut self,
99        canonical_region: impl IntoIterator<Item = V>,
100        options: TruncationOptions,
101    ) -> Result<()>
102    where
103        V: Ord,
104        <T::Index as IndexLike>::Id: Ord,
105    {
106        self.truncate_impl(
107            canonical_region,
108            options.svd_policy(),
109            options.truncation.max_rank,
110            "truncate_mut",
111        )
112    }
113
114    /// Internal implementation for truncation.
115    ///
116    /// Uses LocalUpdateSweepPlan with TruncateUpdater for full two-site sweeps.
117    pub(crate) fn truncate_impl(
118        &mut self,
119        canonical_region: impl IntoIterator<Item = V>,
120        svd_policy: Option<SvdTruncationPolicy>,
121        max_rank: Option<usize>,
122        context_name: &str,
123    ) -> Result<()>
124    where
125        V: Ord,
126        <T::Index as IndexLike>::Id: Ord,
127    {
128        // Collect center nodes
129        let center_nodes: HashSet<V> = canonical_region.into_iter().collect();
130
131        if center_nodes.is_empty() {
132            return Ok(()); // Nothing to do
133        }
134
135        // Currently only single-node center is supported for truncation
136        if center_nodes.len() != 1 {
137            return Err(anyhow::anyhow!(
138                "truncate currently requires a single-node center, got {} nodes",
139                center_nodes.len()
140            ))
141            .context(format!("{}: multi-node center not supported", context_name));
142        }
143
144        let center_node = center_nodes.iter().next().unwrap().clone();
145
146        // Step 1: Canonicalize towards the center (required before truncation sweep)
147        let canonicalize_options =
148            CanonicalizationOptions::default().with_form(CanonicalForm::Unitary);
149        self.canonicalize_impl(
150            [center_node.clone()],
151            canonicalize_options.form,
152            &format!("{}: pre-canonicalize", context_name),
153        )?;
154
155        // Step 2: Generate sweep plan (nsite=2 for two-site truncation)
156        let plan = LocalUpdateSweepPlan::from_treetn(self, &center_node, 2)
157            .ok_or_else(|| {
158                anyhow::anyhow!("Failed to create sweep plan from center {:?}", center_node)
159            })
160            .context(format!("{}: sweep plan creation failed", context_name))?;
161
162        // If no steps (single node network), nothing more to do
163        if plan.is_empty() {
164            return Ok(());
165        }
166
167        // Step 3: Apply truncation sweep
168        let mut updater = TruncateUpdater::new(max_rank, svd_policy);
169        apply_local_update_sweep(self, &plan, &mut updater)
170            .context(format!("{}: truncation sweep failed", context_name))?;
171
172        // The canonical form is maintained by the sweep
173        self.canonical_form = Some(CanonicalForm::Unitary);
174
175        Ok(())
176    }
177}