tensor4all_treetn/treetn/canonicalize.rs
1//! Canonicalization methods for TreeTN.
2//!
3//! This module provides methods for canonicalizing tree tensor networks.
4
5use std::collections::HashSet;
6use std::hash::Hash;
7
8use anyhow::{Context, Result};
9
10use crate::algorithm::CanonicalForm;
11use tensor4all_core::{Canonical, FactorizeAlg, FactorizeOptions, TensorLike};
12
13use super::TreeTN;
14use crate::options::CanonicalizationOptions;
15
16impl<T, V> TreeTN<T, V>
17where
18 T: TensorLike,
19 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
20{
21 /// Canonicalize the network towards the specified center using options.
22 ///
23 /// This is the recommended unified API for canonicalization. It accepts:
24 /// - Center nodes specified by their node names (V)
25 /// - [`CanonicalizationOptions`] to control the form and force behavior
26 ///
27 /// # Behavior
28 /// - If `options.force` is false (default):
29 /// - Already at target with same form: returns unchanged (no-op)
30 /// - Different form: returns an error (use `options.force()` to override)
31 /// - If `options.force` is true:
32 /// - Always performs full canonicalization
33 ///
34 /// # Examples
35 ///
36 /// ```
37 /// use tensor4all_treetn::{TreeTN, CanonicalizationOptions};
38 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
39 ///
40 /// let s0 = DynIndex::new_dyn(2);
41 /// let bond = DynIndex::new_dyn(3);
42 /// let s1 = DynIndex::new_dyn(2);
43 ///
44 /// let t0 = TensorDynLen::from_dense(
45 /// vec![s0.clone(), bond.clone()],
46 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
47 /// ).unwrap();
48 /// let t1 = TensorDynLen::from_dense(
49 /// vec![bond.clone(), s1.clone()],
50 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
51 /// ).unwrap();
52 ///
53 /// let tn = TreeTN::<_, String>::from_tensors(
54 /// vec![t0, t1],
55 /// vec!["A".to_string(), "B".to_string()],
56 /// ).unwrap();
57 ///
58 /// // Canonicalize towards node "A"
59 /// let tn = tn.canonicalize(["A".to_string()], CanonicalizationOptions::default()).unwrap();
60 /// assert!(tn.is_canonicalized());
61 /// ```
62 pub fn canonicalize(
63 mut self,
64 canonical_region: impl IntoIterator<Item = V>,
65 options: CanonicalizationOptions,
66 ) -> Result<Self> {
67 let center_v: HashSet<V> = canonical_region.into_iter().collect();
68
69 // Smart behavior when not forced
70 if !options.force {
71 // Check if already canonicalized with a different form
72 if let Some(current_form) = self.canonical_form {
73 if current_form != options.form {
74 return Err(anyhow::anyhow!(
75 "Cannot move ortho center: current form is {:?} but {:?} was requested. \
76 Use CanonicalizationOptions::forced() to re-canonicalize with a different form.",
77 current_form,
78 options.form
79 ))
80 .context("canonicalize: form mismatch");
81 }
82 }
83
84 // Check if already at target
85 if self.canonical_region == center_v && self.canonical_form == Some(options.form) {
86 return Ok(self);
87 }
88 }
89
90 // Perform canonicalization
91 self.canonicalize_impl(center_v, options.form, "canonicalize")?;
92 Ok(self)
93 }
94
95 /// Canonicalize the network in-place towards the specified center using options.
96 ///
97 /// This is the `&mut self` version of [`Self::canonicalize`].
98 pub fn canonicalize_mut(
99 &mut self,
100 canonical_region: impl IntoIterator<Item = V>,
101 options: CanonicalizationOptions,
102 ) -> Result<()>
103 where
104 Self: Default,
105 {
106 let taken = std::mem::take(self);
107 match taken.canonicalize(canonical_region, options) {
108 Ok(result) => {
109 *self = result;
110 Ok(())
111 }
112 Err(e) => Err(e),
113 }
114 }
115
116 /// Internal implementation for canonicalization.
117 ///
118 /// This is the core canonicalization logic that public methods delegate to.
119 pub(crate) fn canonicalize_impl(
120 &mut self,
121 canonical_region: impl IntoIterator<Item = V>,
122 form: CanonicalForm,
123 context_name: &str,
124 ) -> Result<()> {
125 // Determine algorithm from form
126 let alg = match form {
127 CanonicalForm::Unitary => FactorizeAlg::QR,
128 CanonicalForm::LU => FactorizeAlg::LU,
129 CanonicalForm::CI => FactorizeAlg::CI,
130 };
131
132 // Prepare sweep context
133 let sweep_ctx = self.prepare_sweep_to_center(canonical_region, context_name)?;
134
135 // If no centers (empty), nothing to do
136 let sweep_ctx = match sweep_ctx {
137 Some(ctx) => ctx,
138 None => return Ok(()),
139 };
140
141 // Set up factorization options (no truncation for canonicalization)
142 let factorize_options = FactorizeOptions {
143 alg,
144 canonical: Canonical::Left,
145 max_rank: None,
146 svd_policy: None,
147 qr_rtol: None,
148 };
149
150 // Process edges in order (leaves towards center)
151 for (src, dst) in &sweep_ctx.edges {
152 self.sweep_edge(*src, *dst, &factorize_options, context_name)?;
153 }
154
155 // Set the canonical form
156 self.canonical_form = Some(form);
157
158 Ok(())
159 }
160}