tensor4all_treetn/options.rs
//! Option types for TreeTN operations.
2//!
3//! Provides:
4//! - [`CanonicalizationOptions`]: Options for canonicalization
5//! - [`TruncationOptions`]: Options for truncation
6//! - [`SplitOptions`]: Options for split operations
7//! - [`RestructureOptions`]: Options for multi-phase restructure operations
8
9use crate::algorithm::CanonicalForm;
10use crate::treetn::SwapOptions;
11use tensor4all_core::SvdTruncationPolicy;
12
/// Factorization limits shared by [`TruncationOptions`] and [`SplitOptions`].
///
/// Every field is `None` under the derived `Default`, meaning no limit of
/// that kind has been requested.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub(crate) struct FactorizationToleranceOptions {
    /// Upper bound on the rank kept by a factorization, if any.
    pub max_rank: Option<usize>,
    /// Truncation policy for unitary/SVD-based factorizations, if any.
    pub svd_policy: Option<SvdTruncationPolicy>,
    /// Relative tolerance for QR-based factorizations, if any.
    pub qr_rtol: Option<f64>,
}
19
20/// Options for canonicalization operations.
21///
22/// # Builder Pattern
23///
24/// ```
25/// use tensor4all_treetn::{CanonicalForm, CanonicalizationOptions};
26///
27/// let options = CanonicalizationOptions::default()
28/// .with_form(CanonicalForm::LU)
29/// .force();
30///
31/// assert!(matches!(options.form, CanonicalForm::LU));
32/// assert!(options.force);
33/// ```
34#[derive(Debug, Clone, Copy)]
35pub struct CanonicalizationOptions {
36 /// Canonical form to use (QR, LU, or CI)
37 pub form: CanonicalForm,
38 /// If true, always performs full canonicalization.
39 /// If false, checks current state and may skip or optimize.
40 pub force: bool,
41}
42
43impl Default for CanonicalizationOptions {
44 fn default() -> Self {
45 Self {
46 form: CanonicalForm::Unitary,
47 force: false,
48 }
49 }
50}
51
52impl CanonicalizationOptions {
53 /// Create options with default settings.
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 /// Create options that force full canonicalization.
59 pub fn forced() -> Self {
60 Self {
61 form: CanonicalForm::Unitary,
62 force: true,
63 }
64 }
65
66 /// Set the canonical form.
67 pub fn with_form(mut self, form: CanonicalForm) -> Self {
68 self.form = form;
69 self
70 }
71
72 /// Set force mode (always perform full canonicalization).
73 pub fn force(mut self) -> Self {
74 self.force = true;
75 self
76 }
77
78 /// Disable force mode (check current state before canonicalizing).
79 pub fn smart(mut self) -> Self {
80 self.force = false;
81 self
82 }
83}
84
85/// Options for truncation operations.
86///
87/// # Builder Pattern
88///
89/// ```
90/// use tensor4all_core::SvdTruncationPolicy;
91/// use tensor4all_treetn::TruncationOptions;
92///
93/// let options = TruncationOptions::default()
94/// .with_max_rank(50)
95/// .with_svd_policy(SvdTruncationPolicy::new(1e-10));
96///
97/// assert_eq!(options.max_rank(), Some(50));
98/// assert_eq!(options.svd_policy(), Some(SvdTruncationPolicy::new(1e-10)));
99/// ```
100#[derive(Debug, Clone, Copy)]
101pub struct TruncationOptions {
102 #[allow(dead_code)]
103 pub(crate) form: CanonicalForm,
104 pub(crate) truncation: FactorizationToleranceOptions,
105}
106
107impl Default for TruncationOptions {
108 fn default() -> Self {
109 Self {
110 form: CanonicalForm::Unitary,
111 truncation: FactorizationToleranceOptions::default(),
112 }
113 }
114}
115
116impl TruncationOptions {
117 /// Create options with default settings (no truncation limits).
118 pub fn new() -> Self {
119 Self::default()
120 }
121
122 /// Create options with a maximum rank.
123 pub fn with_max_rank(mut self, rank: usize) -> Self {
124 self.truncation.max_rank = Some(rank);
125 self
126 }
127
128 /// Set the SVD truncation policy used during the truncation sweep.
129 pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
130 self.truncation.svd_policy = Some(policy);
131 self
132 }
133
134 /// Get the SVD truncation policy.
135 pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
136 self.truncation.svd_policy
137 }
138
139 /// Get max_rank.
140 pub fn max_rank(&self) -> Option<usize> {
141 self.truncation.max_rank
142 }
143}
144
145/// Options for split operations.
146///
147/// # Builder Pattern
148///
149/// ```
150/// use tensor4all_core::SvdTruncationPolicy;
151/// use tensor4all_treetn::{CanonicalForm, SplitOptions};
152///
153/// let options = SplitOptions::default()
154/// .with_max_rank(50)
155/// .with_svd_policy(SvdTruncationPolicy::new(1e-10))
156/// .with_qr_rtol(1e-12)
157/// .with_final_sweep(true);
158///
159/// assert!(matches!(options.form, CanonicalForm::Unitary));
160/// assert_eq!(options.max_rank(), Some(50));
161/// assert_eq!(options.svd_policy(), Some(SvdTruncationPolicy::new(1e-10)));
162/// assert_eq!(options.qr_rtol(), Some(1e-12));
163/// assert!(options.final_sweep);
164/// ```
165#[derive(Debug, Clone, Copy)]
166pub struct SplitOptions {
167 /// Canonical form / algorithm to use (SVD, QR, etc.)
168 pub form: CanonicalForm,
169 /// Algorithm-aware factorization tolerances.
170 pub(crate) truncation: FactorizationToleranceOptions,
171 /// Whether to perform a final sweep for global bond dimension optimization
172 pub final_sweep: bool,
173}
174
175impl Default for SplitOptions {
176 fn default() -> Self {
177 Self {
178 form: CanonicalForm::Unitary,
179 truncation: FactorizationToleranceOptions::default(),
180 final_sweep: false,
181 }
182 }
183}
184
185impl SplitOptions {
186 /// Create options with default settings.
187 pub fn new() -> Self {
188 Self::default()
189 }
190
191 /// Create options with a maximum rank.
192 pub fn with_max_rank(mut self, rank: usize) -> Self {
193 self.truncation.max_rank = Some(rank);
194 self
195 }
196
197 /// Set the SVD truncation policy used when `form` is unitary/SVD-based.
198 pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
199 self.truncation.svd_policy = Some(policy);
200 self
201 }
202
203 /// Set the QR-specific relative tolerance used when `form` is QR-based.
204 pub fn with_qr_rtol(mut self, rtol: f64) -> Self {
205 self.truncation.qr_rtol = Some(rtol);
206 self
207 }
208
209 /// Set the canonical form / algorithm.
210 pub fn with_form(mut self, form: CanonicalForm) -> Self {
211 self.form = form;
212 self
213 }
214
215 /// Enable or disable final sweep for global optimization.
216 pub fn with_final_sweep(mut self, final_sweep: bool) -> Self {
217 self.final_sweep = final_sweep;
218 self
219 }
220
221 /// Get the SVD truncation policy.
222 pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
223 self.truncation.svd_policy
224 }
225
226 /// Get the QR-specific relative tolerance.
227 pub fn qr_rtol(&self) -> Option<f64> {
228 self.truncation.qr_rtol
229 }
230
231 /// Get max_rank.
232 pub fn max_rank(&self) -> Option<usize> {
233 self.truncation.max_rank
234 }
235}
236
237/// Options for `TreeTN::restructure_to` multi-phase restructures.
238///
239/// `RestructureOptions` combines the three phases in the approved B2a design:
240/// a split/refinement phase, a site-transport phase, and an optional final
241/// truncation sweep after the target structure has been assembled.
242///
243/// Related types:
244/// - [`SplitOptions`] controls exact splitting plus any optional final sweep
245/// inside the split primitive.
246/// - [`SwapOptions`] controls bond truncation during site-index transport.
247/// - [`TruncationOptions`] can be applied once at the end of the full
248/// restructure to clean up bond dimensions on the final topology.
249///
250/// When in doubt, start with `RestructureOptions::default()`: exact splitting,
251/// exact transport, and no extra final truncation sweep.
252///
253/// # Examples
254///
255/// ```
256/// use tensor4all_treetn::{
257/// RestructureOptions, SplitOptions, SwapOptions, TruncationOptions,
258/// };
259/// use tensor4all_core::SvdTruncationPolicy;
260///
261/// let options = RestructureOptions::new()
262/// .with_split(SplitOptions::new().with_max_rank(32))
263/// .with_swap(SwapOptions {
264/// max_rank: Some(16),
265/// rtol: Some(1e-10),
266/// })
267/// .with_final_truncation(
268/// TruncationOptions::new().with_svd_policy(SvdTruncationPolicy::new(1e-12)),
269/// );
270///
271/// assert_eq!(options.split.max_rank(), Some(32));
272/// assert!(!options.split.final_sweep);
273/// assert_eq!(options.swap.max_rank, Some(16));
274/// assert_eq!(options.swap.rtol, Some(1e-10));
275/// assert_eq!(
276/// options
277/// .final_truncation
278/// .as_ref()
279/// .and_then(TruncationOptions::svd_policy),
280/// Some(SvdTruncationPolicy::new(1e-12))
281/// );
282/// ```
283#[derive(Debug, Clone, Default)]
284pub struct RestructureOptions {
285 /// Options for the split/refinement phase.
286 ///
287 /// These settings matter when a current node must be factored into multiple
288 /// fragments before any fragment movement can happen. Higher `max_rank`,
289 /// stricter `svd_policy`, and smaller `qr_rtol` preserve more fidelity but
290 /// can increase intermediate bond dimensions. `final_sweep` should usually
291 /// remain `false` here unless a split-only workflow is being optimized in
292 /// isolation.
293 pub split: SplitOptions,
294 /// Options for the site-transport / swap phase.
295 ///
296 /// This phase only moves already planned fragments across existing edges.
297 /// Leaving both fields unset keeps swaps exact. Setting `max_rank` or
298 /// `rtol` can control intermediate rank growth, but may introduce
299 /// approximation earlier than the optional final truncation sweep.
300 pub swap: SwapOptions,
301 /// Optional final truncation sweep on the fully restructured network.
302 ///
303 /// Use this when the split and swap phases should remain as exact as
304 /// possible, and compression should happen only after the target topology
305 /// and grouping have been reached. `None` disables this cleanup pass.
306 pub final_truncation: Option<TruncationOptions>,
307}
308
309impl RestructureOptions {
310 /// Create options with exact split/swap phases and no final cleanup sweep.
311 ///
312 /// # Returns
313 /// A `RestructureOptions` value equivalent to [`Default::default`].
314 ///
315 /// # Examples
316 ///
317 /// ```
318 /// use tensor4all_treetn::RestructureOptions;
319 ///
320 /// let options = RestructureOptions::new();
321 ///
322 /// assert!(!options.split.final_sweep);
323 /// assert_eq!(options.swap.max_rank, None);
324 /// assert_eq!(options.swap.rtol, None);
325 /// assert!(options.final_truncation.is_none());
326 /// ```
327 pub fn new() -> Self {
328 Self::default()
329 }
330
331 /// Replace the split-phase options.
332 ///
333 /// # Arguments
334 /// * `split` - Split/refinement settings used before fragment transport.
335 ///
336 /// # Returns
337 /// Updated restructure options using the provided split settings.
338 ///
339 /// # Examples
340 ///
341 /// ```
342 /// use tensor4all_treetn::{RestructureOptions, SplitOptions};
343 ///
344 /// let options = RestructureOptions::new()
345 /// .with_split(SplitOptions::new().with_max_rank(24).with_final_sweep(true));
346 ///
347 /// assert_eq!(options.split.max_rank(), Some(24));
348 /// assert!(options.split.final_sweep);
349 /// ```
350 pub fn with_split(mut self, split: SplitOptions) -> Self {
351 self.split = split;
352 self
353 }
354
355 /// Replace the swap/transport options.
356 ///
357 /// # Arguments
358 /// * `swap` - Truncation settings applied during fragment movement.
359 ///
360 /// # Returns
361 /// Updated restructure options using the provided swap settings.
362 ///
363 /// # Examples
364 ///
365 /// ```
366 /// use tensor4all_treetn::{RestructureOptions, SwapOptions};
367 ///
368 /// let options = RestructureOptions::new().with_swap(SwapOptions {
369 /// max_rank: Some(12),
370 /// rtol: Some(1e-8),
371 /// });
372 ///
373 /// assert_eq!(options.swap.max_rank, Some(12));
374 /// assert_eq!(options.swap.rtol, Some(1e-8));
375 /// ```
376 pub fn with_swap(mut self, swap: SwapOptions) -> Self {
377 self.swap = swap;
378 self
379 }
380
381 /// Set the optional final truncation sweep.
382 ///
383 /// # Arguments
384 /// * `final_truncation` - Final cleanup sweep to run on the target
385 /// topology.
386 ///
387 /// # Returns
388 /// Updated restructure options using the provided final sweep settings.
389 ///
390 /// # Examples
391 ///
392 /// ```
393 /// use tensor4all_treetn::{RestructureOptions, TruncationOptions};
394 /// use tensor4all_core::SvdTruncationPolicy;
395 ///
396 /// let options = RestructureOptions::new()
397 /// .with_final_truncation(
398 /// TruncationOptions::new()
399 /// .with_max_rank(10)
400 /// .with_svd_policy(SvdTruncationPolicy::new(1e-10)),
401 /// );
402 ///
403 /// assert_eq!(
404 /// options
405 /// .final_truncation
406 /// .as_ref()
407 /// .and_then(TruncationOptions::max_rank),
408 /// Some(10)
409 /// );
410 /// ```
411 pub fn with_final_truncation(mut self, final_truncation: TruncationOptions) -> Self {
412 self.final_truncation = Some(final_truncation);
413 self
414 }
415}
416
417#[cfg(test)]
418mod tests;