Skip to main content

tensor4all_treetn/
options.rs

//! Option types for TreeTN operations.
//!
//! Provides:
//! - [`CanonicalizationOptions`]: options for canonicalization
//! - [`TruncationOptions`]: options for truncation
//! - [`SplitOptions`]: options for split operations
//! - [`RestructureOptions`]: options for multi-phase restructure operations
8
9use crate::algorithm::CanonicalForm;
10use crate::treetn::SwapOptions;
11use tensor4all_core::SvdTruncationPolicy;
12
13#[derive(Debug, Clone, Copy, Default, PartialEq)]
14pub(crate) struct FactorizationToleranceOptions {
15    pub max_rank: Option<usize>,
16    pub svd_policy: Option<SvdTruncationPolicy>,
17    pub qr_rtol: Option<f64>,
18}
19
20/// Options for canonicalization operations.
21///
22/// # Builder Pattern
23///
24/// ```
25/// use tensor4all_treetn::{CanonicalForm, CanonicalizationOptions};
26///
27/// let options = CanonicalizationOptions::default()
28///     .with_form(CanonicalForm::LU)
29///     .force();
30///
31/// assert!(matches!(options.form, CanonicalForm::LU));
32/// assert!(options.force);
33/// ```
34#[derive(Debug, Clone, Copy)]
35pub struct CanonicalizationOptions {
36    /// Canonical form to use (QR, LU, or CI)
37    pub form: CanonicalForm,
38    /// If true, always performs full canonicalization.
39    /// If false, checks current state and may skip or optimize.
40    pub force: bool,
41}
42
43impl Default for CanonicalizationOptions {
44    fn default() -> Self {
45        Self {
46            form: CanonicalForm::Unitary,
47            force: false,
48        }
49    }
50}
51
52impl CanonicalizationOptions {
53    /// Create options with default settings.
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    /// Create options that force full canonicalization.
59    pub fn forced() -> Self {
60        Self {
61            form: CanonicalForm::Unitary,
62            force: true,
63        }
64    }
65
66    /// Set the canonical form.
67    pub fn with_form(mut self, form: CanonicalForm) -> Self {
68        self.form = form;
69        self
70    }
71
72    /// Set force mode (always perform full canonicalization).
73    pub fn force(mut self) -> Self {
74        self.force = true;
75        self
76    }
77
78    /// Disable force mode (check current state before canonicalizing).
79    pub fn smart(mut self) -> Self {
80        self.force = false;
81        self
82    }
83}
84
85/// Options for truncation operations.
86///
87/// # Builder Pattern
88///
89/// ```
90/// use tensor4all_core::SvdTruncationPolicy;
91/// use tensor4all_treetn::TruncationOptions;
92///
93/// let options = TruncationOptions::default()
94///     .with_max_rank(50)
95///     .with_svd_policy(SvdTruncationPolicy::new(1e-10));
96///
97/// assert_eq!(options.max_rank(), Some(50));
98/// assert_eq!(options.svd_policy(), Some(SvdTruncationPolicy::new(1e-10)));
99/// ```
100#[derive(Debug, Clone, Copy)]
101pub struct TruncationOptions {
102    #[allow(dead_code)]
103    pub(crate) form: CanonicalForm,
104    pub(crate) truncation: FactorizationToleranceOptions,
105}
106
107impl Default for TruncationOptions {
108    fn default() -> Self {
109        Self {
110            form: CanonicalForm::Unitary,
111            truncation: FactorizationToleranceOptions::default(),
112        }
113    }
114}
115
116impl TruncationOptions {
117    /// Create options with default settings (no truncation limits).
118    pub fn new() -> Self {
119        Self::default()
120    }
121
122    /// Create options with a maximum rank.
123    pub fn with_max_rank(mut self, rank: usize) -> Self {
124        self.truncation.max_rank = Some(rank);
125        self
126    }
127
128    /// Set the SVD truncation policy used during the truncation sweep.
129    pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
130        self.truncation.svd_policy = Some(policy);
131        self
132    }
133
134    /// Get the SVD truncation policy.
135    pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
136        self.truncation.svd_policy
137    }
138
139    /// Get max_rank.
140    pub fn max_rank(&self) -> Option<usize> {
141        self.truncation.max_rank
142    }
143}
144
145/// Options for split operations.
146///
147/// # Builder Pattern
148///
149/// ```
150/// use tensor4all_core::SvdTruncationPolicy;
151/// use tensor4all_treetn::{CanonicalForm, SplitOptions};
152///
153/// let options = SplitOptions::default()
154///     .with_max_rank(50)
155///     .with_svd_policy(SvdTruncationPolicy::new(1e-10))
156///     .with_qr_rtol(1e-12)
157///     .with_final_sweep(true);
158///
159/// assert!(matches!(options.form, CanonicalForm::Unitary));
160/// assert_eq!(options.max_rank(), Some(50));
161/// assert_eq!(options.svd_policy(), Some(SvdTruncationPolicy::new(1e-10)));
162/// assert_eq!(options.qr_rtol(), Some(1e-12));
163/// assert!(options.final_sweep);
164/// ```
165#[derive(Debug, Clone, Copy)]
166pub struct SplitOptions {
167    /// Canonical form / algorithm to use (SVD, QR, etc.)
168    pub form: CanonicalForm,
169    /// Algorithm-aware factorization tolerances.
170    pub(crate) truncation: FactorizationToleranceOptions,
171    /// Whether to perform a final sweep for global bond dimension optimization
172    pub final_sweep: bool,
173}
174
175impl Default for SplitOptions {
176    fn default() -> Self {
177        Self {
178            form: CanonicalForm::Unitary,
179            truncation: FactorizationToleranceOptions::default(),
180            final_sweep: false,
181        }
182    }
183}
184
185impl SplitOptions {
186    /// Create options with default settings.
187    pub fn new() -> Self {
188        Self::default()
189    }
190
191    /// Create options with a maximum rank.
192    pub fn with_max_rank(mut self, rank: usize) -> Self {
193        self.truncation.max_rank = Some(rank);
194        self
195    }
196
197    /// Set the SVD truncation policy used when `form` is unitary/SVD-based.
198    pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
199        self.truncation.svd_policy = Some(policy);
200        self
201    }
202
203    /// Set the QR-specific relative tolerance used when `form` is QR-based.
204    pub fn with_qr_rtol(mut self, rtol: f64) -> Self {
205        self.truncation.qr_rtol = Some(rtol);
206        self
207    }
208
209    /// Set the canonical form / algorithm.
210    pub fn with_form(mut self, form: CanonicalForm) -> Self {
211        self.form = form;
212        self
213    }
214
215    /// Enable or disable final sweep for global optimization.
216    pub fn with_final_sweep(mut self, final_sweep: bool) -> Self {
217        self.final_sweep = final_sweep;
218        self
219    }
220
221    /// Get the SVD truncation policy.
222    pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
223        self.truncation.svd_policy
224    }
225
226    /// Get the QR-specific relative tolerance.
227    pub fn qr_rtol(&self) -> Option<f64> {
228        self.truncation.qr_rtol
229    }
230
231    /// Get max_rank.
232    pub fn max_rank(&self) -> Option<usize> {
233        self.truncation.max_rank
234    }
235}
236
237/// Options for `TreeTN::restructure_to` multi-phase restructures.
238///
239/// `RestructureOptions` combines the three phases in the approved B2a design:
240/// a split/refinement phase, a site-transport phase, and an optional final
241/// truncation sweep after the target structure has been assembled.
242///
243/// Related types:
244/// - [`SplitOptions`] controls exact splitting plus any optional final sweep
245///   inside the split primitive.
246/// - [`SwapOptions`] controls bond truncation during site-index transport.
247/// - [`TruncationOptions`] can be applied once at the end of the full
248///   restructure to clean up bond dimensions on the final topology.
249///
250/// When in doubt, start with `RestructureOptions::default()`: exact splitting,
251/// exact transport, and no extra final truncation sweep.
252///
253/// # Examples
254///
255/// ```
256/// use tensor4all_treetn::{
257///     RestructureOptions, SplitOptions, SwapOptions, TruncationOptions,
258/// };
259/// use tensor4all_core::SvdTruncationPolicy;
260///
261/// let options = RestructureOptions::new()
262///     .with_split(SplitOptions::new().with_max_rank(32))
263///     .with_swap(SwapOptions {
264///         max_rank: Some(16),
265///         rtol: Some(1e-10),
266///     })
267///     .with_final_truncation(
268///         TruncationOptions::new().with_svd_policy(SvdTruncationPolicy::new(1e-12)),
269///     );
270///
271/// assert_eq!(options.split.max_rank(), Some(32));
272/// assert!(!options.split.final_sweep);
273/// assert_eq!(options.swap.max_rank, Some(16));
274/// assert_eq!(options.swap.rtol, Some(1e-10));
275/// assert_eq!(
276///     options
277///         .final_truncation
278///         .as_ref()
279///         .and_then(TruncationOptions::svd_policy),
280///     Some(SvdTruncationPolicy::new(1e-12))
281/// );
282/// ```
283#[derive(Debug, Clone, Default)]
284pub struct RestructureOptions {
285    /// Options for the split/refinement phase.
286    ///
287    /// These settings matter when a current node must be factored into multiple
288    /// fragments before any fragment movement can happen. Higher `max_rank`,
289    /// stricter `svd_policy`, and smaller `qr_rtol` preserve more fidelity but
290    /// can increase intermediate bond dimensions. `final_sweep` should usually
291    /// remain `false` here unless a split-only workflow is being optimized in
292    /// isolation.
293    pub split: SplitOptions,
294    /// Options for the site-transport / swap phase.
295    ///
296    /// This phase only moves already planned fragments across existing edges.
297    /// Leaving both fields unset keeps swaps exact. Setting `max_rank` or
298    /// `rtol` can control intermediate rank growth, but may introduce
299    /// approximation earlier than the optional final truncation sweep.
300    pub swap: SwapOptions,
301    /// Optional final truncation sweep on the fully restructured network.
302    ///
303    /// Use this when the split and swap phases should remain as exact as
304    /// possible, and compression should happen only after the target topology
305    /// and grouping have been reached. `None` disables this cleanup pass.
306    pub final_truncation: Option<TruncationOptions>,
307}
308
309impl RestructureOptions {
310    /// Create options with exact split/swap phases and no final cleanup sweep.
311    ///
312    /// # Returns
313    /// A `RestructureOptions` value equivalent to [`Default::default`].
314    ///
315    /// # Examples
316    ///
317    /// ```
318    /// use tensor4all_treetn::RestructureOptions;
319    ///
320    /// let options = RestructureOptions::new();
321    ///
322    /// assert!(!options.split.final_sweep);
323    /// assert_eq!(options.swap.max_rank, None);
324    /// assert_eq!(options.swap.rtol, None);
325    /// assert!(options.final_truncation.is_none());
326    /// ```
327    pub fn new() -> Self {
328        Self::default()
329    }
330
331    /// Replace the split-phase options.
332    ///
333    /// # Arguments
334    /// * `split` - Split/refinement settings used before fragment transport.
335    ///
336    /// # Returns
337    /// Updated restructure options using the provided split settings.
338    ///
339    /// # Examples
340    ///
341    /// ```
342    /// use tensor4all_treetn::{RestructureOptions, SplitOptions};
343    ///
344    /// let options = RestructureOptions::new()
345    ///     .with_split(SplitOptions::new().with_max_rank(24).with_final_sweep(true));
346    ///
347    /// assert_eq!(options.split.max_rank(), Some(24));
348    /// assert!(options.split.final_sweep);
349    /// ```
350    pub fn with_split(mut self, split: SplitOptions) -> Self {
351        self.split = split;
352        self
353    }
354
355    /// Replace the swap/transport options.
356    ///
357    /// # Arguments
358    /// * `swap` - Truncation settings applied during fragment movement.
359    ///
360    /// # Returns
361    /// Updated restructure options using the provided swap settings.
362    ///
363    /// # Examples
364    ///
365    /// ```
366    /// use tensor4all_treetn::{RestructureOptions, SwapOptions};
367    ///
368    /// let options = RestructureOptions::new().with_swap(SwapOptions {
369    ///     max_rank: Some(12),
370    ///     rtol: Some(1e-8),
371    /// });
372    ///
373    /// assert_eq!(options.swap.max_rank, Some(12));
374    /// assert_eq!(options.swap.rtol, Some(1e-8));
375    /// ```
376    pub fn with_swap(mut self, swap: SwapOptions) -> Self {
377        self.swap = swap;
378        self
379    }
380
381    /// Set the optional final truncation sweep.
382    ///
383    /// # Arguments
384    /// * `final_truncation` - Final cleanup sweep to run on the target
385    ///   topology.
386    ///
387    /// # Returns
388    /// Updated restructure options using the provided final sweep settings.
389    ///
390    /// # Examples
391    ///
392    /// ```
393    /// use tensor4all_treetn::{RestructureOptions, TruncationOptions};
394    /// use tensor4all_core::SvdTruncationPolicy;
395    ///
396    /// let options = RestructureOptions::new()
397    ///     .with_final_truncation(
398    ///         TruncationOptions::new()
399    ///             .with_max_rank(10)
400    ///             .with_svd_policy(SvdTruncationPolicy::new(1e-10)),
401    ///     );
402    ///
403    /// assert_eq!(
404    ///     options
405    ///         .final_truncation
406    ///         .as_ref()
407    ///         .and_then(TruncationOptions::max_rank),
408    ///     Some(10)
409    /// );
410    /// ```
411    pub fn with_final_truncation(mut self, final_truncation: TruncationOptions) -> Self {
412        self.final_truncation = Some(final_truncation);
413        self
414    }
415}
416
417#[cfg(test)]
418mod tests;