Skip to main content

tensor4all_treetn/treetn/
canonicalize.rs

1//! Canonicalization methods for TreeTN.
2//!
3//! This module provides methods for canonicalizing tree tensor networks.
4
5use std::collections::HashSet;
6use std::hash::Hash;
7
8use anyhow::{Context, Result};
9
10use crate::algorithm::CanonicalForm;
11use tensor4all_core::{Canonical, FactorizeAlg, FactorizeOptions, TensorLike};
12
13use super::TreeTN;
14use crate::options::CanonicalizationOptions;
15
16impl<T, V> TreeTN<T, V>
17where
18    T: TensorLike,
19    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
20{
21    /// Canonicalize the network towards the specified center using options.
22    ///
23    /// This is the recommended unified API for canonicalization. It accepts:
24    /// - Center nodes specified by their node names (V)
25    /// - [`CanonicalizationOptions`] to control the form and force behavior
26    ///
27    /// # Behavior
28    /// - If `options.force` is false (default):
29    ///   - Already at target with same form: returns unchanged (no-op)
30    ///   - Different form: returns an error (use `options.force()` to override)
31    /// - If `options.force` is true:
32    ///   - Always performs full canonicalization
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// use tensor4all_treetn::{TreeTN, CanonicalizationOptions};
38    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
39    ///
40    /// let s0 = DynIndex::new_dyn(2);
41    /// let bond = DynIndex::new_dyn(3);
42    /// let s1 = DynIndex::new_dyn(2);
43    ///
44    /// let t0 = TensorDynLen::from_dense(
45    ///     vec![s0.clone(), bond.clone()],
46    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
47    /// ).unwrap();
48    /// let t1 = TensorDynLen::from_dense(
49    ///     vec![bond.clone(), s1.clone()],
50    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
51    /// ).unwrap();
52    ///
53    /// let tn = TreeTN::<_, String>::from_tensors(
54    ///     vec![t0, t1],
55    ///     vec!["A".to_string(), "B".to_string()],
56    /// ).unwrap();
57    ///
58    /// // Canonicalize towards node "A"
59    /// let tn = tn.canonicalize(["A".to_string()], CanonicalizationOptions::default()).unwrap();
60    /// assert!(tn.is_canonicalized());
61    /// ```
62    pub fn canonicalize(
63        mut self,
64        canonical_region: impl IntoIterator<Item = V>,
65        options: CanonicalizationOptions,
66    ) -> Result<Self> {
67        let center_v: HashSet<V> = canonical_region.into_iter().collect();
68
69        // Smart behavior when not forced
70        if !options.force {
71            // Check if already canonicalized with a different form
72            if let Some(current_form) = self.canonical_form {
73                if current_form != options.form {
74                    return Err(anyhow::anyhow!(
75                        "Cannot move ortho center: current form is {:?} but {:?} was requested. \
76                         Use CanonicalizationOptions::forced() to re-canonicalize with a different form.",
77                        current_form,
78                        options.form
79                    ))
80                    .context("canonicalize: form mismatch");
81                }
82            }
83
84            // Check if already at target
85            if self.canonical_region == center_v && self.canonical_form == Some(options.form) {
86                return Ok(self);
87            }
88        }
89
90        // Perform canonicalization
91        self.canonicalize_impl(center_v, options.form, "canonicalize")?;
92        Ok(self)
93    }
94
95    /// Canonicalize the network in-place towards the specified center using options.
96    ///
97    /// This is the `&mut self` version of [`Self::canonicalize`].
98    pub fn canonicalize_mut(
99        &mut self,
100        canonical_region: impl IntoIterator<Item = V>,
101        options: CanonicalizationOptions,
102    ) -> Result<()>
103    where
104        Self: Default,
105    {
106        let taken = std::mem::take(self);
107        match taken.canonicalize(canonical_region, options) {
108            Ok(result) => {
109                *self = result;
110                Ok(())
111            }
112            Err(e) => Err(e),
113        }
114    }
115
116    /// Internal implementation for canonicalization.
117    ///
118    /// This is the core canonicalization logic that public methods delegate to.
119    pub(crate) fn canonicalize_impl(
120        &mut self,
121        canonical_region: impl IntoIterator<Item = V>,
122        form: CanonicalForm,
123        context_name: &str,
124    ) -> Result<()> {
125        // Determine algorithm from form
126        let alg = match form {
127            CanonicalForm::Unitary => FactorizeAlg::QR,
128            CanonicalForm::LU => FactorizeAlg::LU,
129            CanonicalForm::CI => FactorizeAlg::CI,
130        };
131
132        // Prepare sweep context
133        let sweep_ctx = self.prepare_sweep_to_center(canonical_region, context_name)?;
134
135        // If no centers (empty), nothing to do
136        let sweep_ctx = match sweep_ctx {
137            Some(ctx) => ctx,
138            None => return Ok(()),
139        };
140
141        // Set up factorization options (no truncation for canonicalization)
142        let factorize_options = FactorizeOptions {
143            alg,
144            canonical: Canonical::Left,
145            max_rank: None,
146            svd_policy: None,
147            qr_rtol: None,
148        };
149
150        // Process edges in order (leaves towards center)
151        for (src, dst) in &sweep_ctx.edges {
152            self.sweep_edge(*src, *dst, &factorize_options, context_name)?;
153        }
154
155        // Set the canonical form
156        self.canonical_form = Some(form);
157
158        Ok(())
159    }
160}