Skip to main content

tensor4all_itensorlike/
options.rs

1//! Configuration options for tensor train operations.
2
3use std::ops::Range;
4
5use tensor4all_core::SvdTruncationPolicy;
6
7use crate::error::{Result, TensorTrainError};
8
9// Re-export CanonicalForm from treetn for convenience.
10pub use tensor4all_treetn::algorithm::CanonicalForm;
11
12pub(crate) fn validate_svd_truncation_options(
13    max_rank: Option<usize>,
14    svd_policy: Option<SvdTruncationPolicy>,
15) -> Result<()> {
16    if let Some(policy) = svd_policy {
17        if !policy.threshold.is_finite() || policy.threshold < 0.0 {
18            return Err(TensorTrainError::OperationError {
19                message: format!(
20                    "svd_policy.threshold must be finite and >= 0, got {}",
21                    policy.threshold
22                ),
23            });
24        }
25    }
26
27    if let Some(max_rank) = max_rank {
28        if max_rank == 0 {
29            return Err(TensorTrainError::OperationError {
30                message: "max_rank/maxdim must be >= 1".to_string(),
31            });
32        }
33    }
34
35    Ok(())
36}
37
38/// Options for tensor train truncation.
39///
40/// Truncation is explicitly SVD-based. Canonicalization remains the API for
41/// LU/CI-style forms; truncate itself only accepts SVD truncation controls.
42///
43/// # Examples
44///
45/// ```
46/// use tensor4all_core::SvdTruncationPolicy;
47/// use tensor4all_itensorlike::TruncateOptions;
48///
49/// let opts = TruncateOptions::svd()
50///     .with_svd_policy(SvdTruncationPolicy::new(1e-10))
51///     .with_max_rank(20)
52///     .with_site_range(0..4);
53///
54/// assert_eq!(opts.svd_policy(), Some(SvdTruncationPolicy::new(1e-10)));
55/// assert_eq!(opts.max_rank(), Some(20));
56/// assert_eq!(opts.site_range(), Some(0..4));
57/// ```
58#[derive(Debug, Clone, Default)]
59pub struct TruncateOptions {
60    max_rank: Option<usize>,
61    svd_policy: Option<SvdTruncationPolicy>,
62    site_range: Option<Range<usize>>,
63}
64
65impl TruncateOptions {
66    /// Create options for SVD-based truncation.
67    pub fn svd() -> Self {
68        Self::default()
69    }
70
71    /// Set the explicit SVD truncation policy.
72    pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
73        self.svd_policy = Some(policy);
74        self
75    }
76
77    /// Set the maximum retained bond dimension.
78    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
79        self.max_rank = Some(max_rank);
80        self
81    }
82
83    /// Set the site range for truncation.
84    ///
85    /// The range is 0-indexed with exclusive end.
86    pub fn with_site_range(mut self, range: Range<usize>) -> Self {
87        self.site_range = Some(range);
88        self
89    }
90
91    /// Get the SVD truncation policy override.
92    #[inline]
93    pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
94        self.svd_policy
95    }
96
97    /// Get the maximum retained bond dimension.
98    #[inline]
99    pub fn max_rank(&self) -> Option<usize> {
100        self.max_rank
101    }
102
103    /// Get the site range for truncation.
104    #[inline]
105    pub fn site_range(&self) -> Option<Range<usize>> {
106        self.site_range.clone()
107    }
108}
109
/// Algorithm used to contract tensor trains.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ContractMethod {
    /// Zip-up contraction: one pass, the faster choice.
    ///
    /// This is the default variant.
    #[default]
    Zipup,
    /// Fit (variational) contraction, driven by iterative optimization.
    Fit,
    /// Naive contraction: contract to the full tensor, then decompose it back.
    ///
    /// Useful for debugging and testing, but O(exp(n)) in memory.
    Naive,
}
122
123/// Options for tensor train contraction.
124///
125/// # Examples
126///
127/// ```
128/// use tensor4all_core::SvdTruncationPolicy;
129/// use tensor4all_itensorlike::ContractOptions;
130///
131/// let opts = ContractOptions::fit()
132///     .with_svd_policy(SvdTruncationPolicy::new(1e-8))
133///     .with_max_rank(50)
134///     .with_nsweeps(3);
135///
136/// assert_eq!(opts.max_rank(), Some(50));
137/// assert_eq!(opts.svd_policy(), Some(SvdTruncationPolicy::new(1e-8)));
138/// assert_eq!(opts.nhalfsweeps(), 6);
139/// ```
140#[derive(Debug, Clone)]
141pub struct ContractOptions {
142    method: ContractMethod,
143    max_rank: Option<usize>,
144    svd_policy: Option<SvdTruncationPolicy>,
145    nhalfsweeps: usize,
146}
147
148impl Default for ContractOptions {
149    fn default() -> Self {
150        Self {
151            method: ContractMethod::default(),
152            max_rank: None,
153            svd_policy: None,
154            nhalfsweeps: 2,
155        }
156    }
157}
158
159impl ContractOptions {
160    /// Create options for zipup contraction.
161    pub fn zipup() -> Self {
162        Self {
163            method: ContractMethod::Zipup,
164            ..Default::default()
165        }
166    }
167
168    /// Create options for fit contraction.
169    pub fn fit() -> Self {
170        Self {
171            method: ContractMethod::Fit,
172            ..Default::default()
173        }
174    }
175
176    /// Create options for naive contraction.
177    pub fn naive() -> Self {
178        Self {
179            method: ContractMethod::Naive,
180            ..Default::default()
181        }
182    }
183
184    /// Set the maximum retained bond dimension.
185    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
186        self.max_rank = Some(max_rank);
187        self
188    }
189
190    /// Set the explicit SVD truncation policy.
191    pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
192        self.svd_policy = Some(policy);
193        self
194    }
195
196    /// Set number of half-sweeps for fit contraction.
197    pub fn with_nhalfsweeps(mut self, nhalfsweeps: usize) -> Self {
198        self.nhalfsweeps = nhalfsweeps;
199        self
200    }
201
202    /// Set number of full sweeps.
203    ///
204    /// A full sweep is two half-sweeps.
205    pub fn with_nsweeps(mut self, nsweeps: usize) -> Self {
206        self.nhalfsweeps = nsweeps * 2;
207        self
208    }
209
210    /// Get the contraction method.
211    #[inline]
212    pub fn method(&self) -> ContractMethod {
213        self.method
214    }
215
216    /// Get the maximum retained bond dimension.
217    #[inline]
218    pub fn max_rank(&self) -> Option<usize> {
219        self.max_rank
220    }
221
222    /// Get the SVD truncation policy override.
223    #[inline]
224    pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
225        self.svd_policy
226    }
227
228    /// Get number of half-sweeps.
229    #[inline]
230    pub fn nhalfsweeps(&self) -> usize {
231        self.nhalfsweeps
232    }
233}
234
235/// Options for the linear solver.
236///
237/// Solves `(a₀ + a₁ * A) * x = b` using DMRG-like sweeps with local GMRES.
238///
239/// # Examples
240///
241/// ```
242/// use tensor4all_core::SvdTruncationPolicy;
243/// use tensor4all_itensorlike::LinsolveOptions;
244///
245/// let opts = LinsolveOptions::new(5)
246///     .with_svd_policy(SvdTruncationPolicy::new(1e-10))
247///     .with_max_rank(64)
248///     .with_krylov_tol(1e-8)
249///     .with_coefficients(1.0, -1.0);
250///
251/// assert_eq!(opts.max_rank(), Some(64));
252/// assert_eq!(opts.svd_policy(), Some(SvdTruncationPolicy::new(1e-10)));
253/// assert_eq!(opts.nhalfsweeps(), 10);
254/// ```
255#[derive(Debug, Clone)]
256pub struct LinsolveOptions {
257    nhalfsweeps: usize,
258    max_rank: Option<usize>,
259    svd_policy: Option<SvdTruncationPolicy>,
260    krylov_tol: f64,
261    krylov_maxiter: usize,
262    krylov_dim: usize,
263    a0: f64,
264    a1: f64,
265    convergence_tol: Option<f64>,
266}
267
268impl Default for LinsolveOptions {
269    fn default() -> Self {
270        Self {
271            nhalfsweeps: 10,
272            max_rank: None,
273            svd_policy: None,
274            krylov_tol: 1e-10,
275            krylov_maxiter: 100,
276            krylov_dim: 30,
277            a0: 0.0,
278            a1: 1.0,
279            convergence_tol: None,
280        }
281    }
282}
283
284impl LinsolveOptions {
285    /// Create options with the specified number of full sweeps.
286    pub fn new(nsweeps: usize) -> Self {
287        Self {
288            nhalfsweeps: nsweeps * 2,
289            ..Default::default()
290        }
291    }
292
293    /// Set the explicit SVD truncation policy.
294    pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
295        self.svd_policy = Some(policy);
296        self
297    }
298
299    /// Set the maximum retained bond dimension.
300    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
301        self.max_rank = Some(max_rank);
302        self
303    }
304
305    /// Set number of half-sweeps.
306    pub fn with_nhalfsweeps(mut self, nhalfsweeps: usize) -> Self {
307        self.nhalfsweeps = nhalfsweeps;
308        self
309    }
310
311    /// Set number of full sweeps.
312    pub fn with_nsweeps(mut self, nsweeps: usize) -> Self {
313        self.nhalfsweeps = nsweeps * 2;
314        self
315    }
316
317    /// Set GMRES tolerance.
318    pub fn with_krylov_tol(mut self, tol: f64) -> Self {
319        self.krylov_tol = tol;
320        self
321    }
322
323    /// Set maximum GMRES iterations per local solve.
324    pub fn with_krylov_maxiter(mut self, maxiter: usize) -> Self {
325        self.krylov_maxiter = maxiter;
326        self
327    }
328
329    /// Set Krylov subspace dimension (restart parameter).
330    pub fn with_krylov_dim(mut self, dim: usize) -> Self {
331        self.krylov_dim = dim;
332        self
333    }
334
335    /// Set coefficients `a₀` and `a₁` in `(a₀ + a₁ * A) * x = b`.
336    pub fn with_coefficients(mut self, a0: f64, a1: f64) -> Self {
337        self.a0 = a0;
338        self.a1 = a1;
339        self
340    }
341
342    /// Set convergence tolerance for early termination.
343    pub fn with_convergence_tol(mut self, tol: f64) -> Self {
344        self.convergence_tol = Some(tol);
345        self
346    }
347
348    /// Get the maximum retained bond dimension.
349    #[inline]
350    pub fn max_rank(&self) -> Option<usize> {
351        self.max_rank
352    }
353
354    /// Get the SVD truncation policy override.
355    #[inline]
356    pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
357        self.svd_policy
358    }
359
360    /// Get number of half-sweeps.
361    #[inline]
362    pub fn nhalfsweeps(&self) -> usize {
363        self.nhalfsweeps
364    }
365
366    /// Get GMRES tolerance.
367    #[inline]
368    pub fn krylov_tol(&self) -> f64 {
369        self.krylov_tol
370    }
371
372    /// Get maximum GMRES iterations per local solve.
373    #[inline]
374    pub fn krylov_maxiter(&self) -> usize {
375        self.krylov_maxiter
376    }
377
378    /// Get Krylov subspace dimension.
379    #[inline]
380    pub fn krylov_dim(&self) -> usize {
381        self.krylov_dim
382    }
383
384    /// Get coefficients `(a0, a1)`.
385    #[inline]
386    pub fn coefficients(&self) -> (f64, f64) {
387        (self.a0, self.a1)
388    }
389
390    /// Get convergence tolerance.
391    #[inline]
392    pub fn convergence_tol(&self) -> Option<f64> {
393        self.convergence_tol
394    }
395}
396
397#[cfg(test)]
398mod tests;