tensor4all_itensorlike/
options.rs1use std::ops::Range;
4
5use tensor4all_core::SvdTruncationPolicy;
6
7use crate::error::{Result, TensorTrainError};
8
9pub use tensor4all_treetn::algorithm::CanonicalForm;
11
12pub(crate) fn validate_svd_truncation_options(
13 max_rank: Option<usize>,
14 svd_policy: Option<SvdTruncationPolicy>,
15) -> Result<()> {
16 if let Some(policy) = svd_policy {
17 if !policy.threshold.is_finite() || policy.threshold < 0.0 {
18 return Err(TensorTrainError::OperationError {
19 message: format!(
20 "svd_policy.threshold must be finite and >= 0, got {}",
21 policy.threshold
22 ),
23 });
24 }
25 }
26
27 if let Some(max_rank) = max_rank {
28 if max_rank == 0 {
29 return Err(TensorTrainError::OperationError {
30 message: "max_rank/maxdim must be >= 1".to_string(),
31 });
32 }
33 }
34
35 Ok(())
36}
37
38#[derive(Debug, Clone, Default)]
59pub struct TruncateOptions {
60 max_rank: Option<usize>,
61 svd_policy: Option<SvdTruncationPolicy>,
62 site_range: Option<Range<usize>>,
63}
64
65impl TruncateOptions {
66 pub fn svd() -> Self {
68 Self::default()
69 }
70
71 pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
73 self.svd_policy = Some(policy);
74 self
75 }
76
77 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
79 self.max_rank = Some(max_rank);
80 self
81 }
82
83 pub fn with_site_range(mut self, range: Range<usize>) -> Self {
87 self.site_range = Some(range);
88 self
89 }
90
91 #[inline]
93 pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
94 self.svd_policy
95 }
96
97 #[inline]
99 pub fn max_rank(&self) -> Option<usize> {
100 self.max_rank
101 }
102
103 #[inline]
105 pub fn site_range(&self) -> Option<Range<usize>> {
106 self.site_range.clone()
107 }
108}
109
/// Selects which contraction algorithm to use.
///
/// [`ContractMethod::Zipup`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ContractMethod {
    /// Zip-up contraction (the default variant).
    #[default]
    Zipup,
    /// Fit-based contraction.
    Fit,
    /// Naive contraction.
    Naive,
}
122
123#[derive(Debug, Clone)]
141pub struct ContractOptions {
142 method: ContractMethod,
143 max_rank: Option<usize>,
144 svd_policy: Option<SvdTruncationPolicy>,
145 nhalfsweeps: usize,
146}
147
148impl Default for ContractOptions {
149 fn default() -> Self {
150 Self {
151 method: ContractMethod::default(),
152 max_rank: None,
153 svd_policy: None,
154 nhalfsweeps: 2,
155 }
156 }
157}
158
159impl ContractOptions {
160 pub fn zipup() -> Self {
162 Self {
163 method: ContractMethod::Zipup,
164 ..Default::default()
165 }
166 }
167
168 pub fn fit() -> Self {
170 Self {
171 method: ContractMethod::Fit,
172 ..Default::default()
173 }
174 }
175
176 pub fn naive() -> Self {
178 Self {
179 method: ContractMethod::Naive,
180 ..Default::default()
181 }
182 }
183
184 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
186 self.max_rank = Some(max_rank);
187 self
188 }
189
190 pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
192 self.svd_policy = Some(policy);
193 self
194 }
195
196 pub fn with_nhalfsweeps(mut self, nhalfsweeps: usize) -> Self {
198 self.nhalfsweeps = nhalfsweeps;
199 self
200 }
201
202 pub fn with_nsweeps(mut self, nsweeps: usize) -> Self {
206 self.nhalfsweeps = nsweeps * 2;
207 self
208 }
209
210 #[inline]
212 pub fn method(&self) -> ContractMethod {
213 self.method
214 }
215
216 #[inline]
218 pub fn max_rank(&self) -> Option<usize> {
219 self.max_rank
220 }
221
222 #[inline]
224 pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
225 self.svd_policy
226 }
227
228 #[inline]
230 pub fn nhalfsweeps(&self) -> usize {
231 self.nhalfsweeps
232 }
233}
234
235#[derive(Debug, Clone)]
256pub struct LinsolveOptions {
257 nhalfsweeps: usize,
258 max_rank: Option<usize>,
259 svd_policy: Option<SvdTruncationPolicy>,
260 krylov_tol: f64,
261 krylov_maxiter: usize,
262 krylov_dim: usize,
263 a0: f64,
264 a1: f64,
265 convergence_tol: Option<f64>,
266}
267
268impl Default for LinsolveOptions {
269 fn default() -> Self {
270 Self {
271 nhalfsweeps: 10,
272 max_rank: None,
273 svd_policy: None,
274 krylov_tol: 1e-10,
275 krylov_maxiter: 100,
276 krylov_dim: 30,
277 a0: 0.0,
278 a1: 1.0,
279 convergence_tol: None,
280 }
281 }
282}
283
284impl LinsolveOptions {
285 pub fn new(nsweeps: usize) -> Self {
287 Self {
288 nhalfsweeps: nsweeps * 2,
289 ..Default::default()
290 }
291 }
292
293 pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
295 self.svd_policy = Some(policy);
296 self
297 }
298
299 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
301 self.max_rank = Some(max_rank);
302 self
303 }
304
305 pub fn with_nhalfsweeps(mut self, nhalfsweeps: usize) -> Self {
307 self.nhalfsweeps = nhalfsweeps;
308 self
309 }
310
311 pub fn with_nsweeps(mut self, nsweeps: usize) -> Self {
313 self.nhalfsweeps = nsweeps * 2;
314 self
315 }
316
317 pub fn with_krylov_tol(mut self, tol: f64) -> Self {
319 self.krylov_tol = tol;
320 self
321 }
322
323 pub fn with_krylov_maxiter(mut self, maxiter: usize) -> Self {
325 self.krylov_maxiter = maxiter;
326 self
327 }
328
329 pub fn with_krylov_dim(mut self, dim: usize) -> Self {
331 self.krylov_dim = dim;
332 self
333 }
334
335 pub fn with_coefficients(mut self, a0: f64, a1: f64) -> Self {
337 self.a0 = a0;
338 self.a1 = a1;
339 self
340 }
341
342 pub fn with_convergence_tol(mut self, tol: f64) -> Self {
344 self.convergence_tol = Some(tol);
345 self
346 }
347
348 #[inline]
350 pub fn max_rank(&self) -> Option<usize> {
351 self.max_rank
352 }
353
354 #[inline]
356 pub fn svd_policy(&self) -> Option<SvdTruncationPolicy> {
357 self.svd_policy
358 }
359
360 #[inline]
362 pub fn nhalfsweeps(&self) -> usize {
363 self.nhalfsweeps
364 }
365
366 #[inline]
368 pub fn krylov_tol(&self) -> f64 {
369 self.krylov_tol
370 }
371
372 #[inline]
374 pub fn krylov_maxiter(&self) -> usize {
375 self.krylov_maxiter
376 }
377
378 #[inline]
380 pub fn krylov_dim(&self) -> usize {
381 self.krylov_dim
382 }
383
384 #[inline]
386 pub fn coefficients(&self) -> (f64, f64) {
387 (self.a0, self.a1)
388 }
389
390 #[inline]
392 pub fn convergence_tol(&self) -> Option<f64> {
393 self.convergence_tol
394 }
395}
396
397#[cfg(test)]
398mod tests;