tensor4all_core/truncation.rs
1//! Common truncation options and traits.
2//!
3//! This module provides shared types and traits for truncation parameters
4//! used across tensor operations like SVD, QR, and tensor train compression.
5
/// Decomposition/factorization algorithm.
///
/// One enum used to express the algorithm choice uniformly across the
/// different crates. [`DecompositionAlg::SVD`] is the `Default`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DecompositionAlg {
    /// Singular Value Decomposition (optimal truncation).
    #[default]
    SVD,
    /// Randomized SVD (faster for large matrices).
    RSVD,
    /// QR decomposition.
    QR,
    /// Rank-revealing LU decomposition.
    LU,
    /// Cross Interpolation.
    CI,
}

impl DecompositionAlg {
    /// Returns `true` for the SVD family: [`SVD`](Self::SVD) and
    /// [`RSVD`](Self::RSVD). All other variants return `false`.
    #[must_use]
    pub fn is_svd_based(&self) -> bool {
        // Every variant is listed explicitly (no `_` arm) so that adding a new
        // algorithm forces a conscious decision here.
        match self {
            Self::SVD | Self::RSVD => true,
            Self::QR | Self::LU | Self::CI => false,
        }
    }

    /// Returns `true` when the algorithm yields orthogonal factors:
    /// [`SVD`](Self::SVD), [`RSVD`](Self::RSVD) and [`QR`](Self::QR).
    /// [`LU`](Self::LU) and [`CI`](Self::CI) return `false`.
    #[must_use]
    pub fn is_orthogonal(&self) -> bool {
        // Exhaustive for the same reason as `is_svd_based`.
        match self {
            Self::SVD | Self::RSVD | Self::QR => true,
            Self::LU | Self::CI => false,
        }
    }
}
37
/// Common truncation parameters.
///
/// This struct contains the core parameters used for rank truncation
/// across various tensor decomposition and compression operations.
///
/// # Semantics
///
/// This crate uses **relative tolerance** (`rtol`) semantics:
/// - Singular values are truncated when `σ_i / σ_max < rtol`
///
/// ITensorMPS.jl uses **cutoff** semantics:
/// - Singular values are truncated when `σ_i² < cutoff`
///
/// **Conversion**: For normalized tensors (where `σ_max = 1`):
/// - ITensorMPS.jl's `cutoff` = tensor4all-rs's `rtol²`
/// - To match ITensorMPS.jl behavior: use `rtol = sqrt(cutoff)`
/// - Example: ITensorMPS.jl `cutoff=1e-10` ↔ tensor4all-rs `rtol=1e-5`
///
/// # Consistency of `rtol` and `cutoff`
///
/// The builder methods and [`merge`](TruncationParams::merge) keep the two
/// tolerance fields consistent: whenever `cutoff` is `Some(c)`, `rtol` is
/// `Some(√c)`. The fields are public, so code that assigns them directly is
/// responsible for keeping this relationship itself.
#[derive(Debug, Clone, Copy, Default)]
pub struct TruncationParams {
    /// Relative tolerance for truncation.
    ///
    /// Singular values satisfying `σ_i / σ_max < rtol` are truncated,
    /// where `σ_max` is the largest singular value.
    ///
    /// If `None`, uses the algorithm's default tolerance.
    pub rtol: Option<f64>,

    /// Maximum rank (bond dimension).
    ///
    /// If `None`, no rank limit is applied.
    pub max_rank: Option<usize>,

    /// Cutoff value (ITensorMPS.jl convention).
    ///
    /// When set via [`with_cutoff`](TruncationParams::with_cutoff), `rtol` is
    /// automatically set to `√cutoff`. This field tracks the original cutoff
    /// value for inspection; `rtol` is always the authoritative tolerance.
    ///
    /// If `None`, cutoff was not used.
    pub cutoff: Option<f64>,
}

impl TruncationParams {
    /// Create new truncation parameters with default values (all `None`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the relative tolerance.
    ///
    /// Clears any previously recorded `cutoff`, because the new `rtol` no
    /// longer derives from it.
    #[must_use]
    pub fn with_rtol(mut self, rtol: f64) -> Self {
        self.rtol = Some(rtol);
        self.cutoff = None;
        self
    }

    /// Set the maximum rank.
    #[must_use]
    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
        self.max_rank = Some(max_rank);
        self
    }

    /// Set cutoff (ITensorMPS.jl convention).
    ///
    /// Records `cutoff` and overwrites any previously set `rtol` with
    /// `√cutoff`. `rtol` remains the tolerance actually used; `cutoff` is
    /// kept only to show where that value came from.
    ///
    /// `cutoff` is expected to be non-negative: `f64::sqrt` of a negative
    /// number is NaN, which would then be stored as `rtol`.
    #[must_use]
    pub fn with_cutoff(mut self, cutoff: f64) -> Self {
        self.cutoff = Some(cutoff);
        self.rtol = Some(cutoff.sqrt());
        self
    }

    /// Set maxdim (alias for [`with_max_rank`](Self::with_max_rank)).
    ///
    /// This is provided for ITensorMPS.jl compatibility.
    #[must_use]
    pub fn with_maxdim(self, maxdim: usize) -> Self {
        self.with_max_rank(maxdim)
    }

    /// Get the effective rtol, using the provided default if not set.
    #[must_use]
    pub fn effective_rtol(&self, default: f64) -> f64 {
        self.rtol.unwrap_or(default)
    }

    /// Get the effective max_rank, using `usize::MAX` (no limit) if not set.
    #[must_use]
    pub fn effective_max_rank(&self) -> usize {
        self.max_rank.unwrap_or(usize::MAX)
    }

    /// Merge with another set of parameters, preferring `self`'s values.
    ///
    /// - `max_rank` is taken from `self` if set, otherwise from `other`.
    /// - The tolerance fields are merged as a **pair**. If `self.rtol` is
    ///   set, both `self.rtol` and `self.cutoff` are kept; otherwise both come
    ///   from `other`. Merging them field by field could combine `self`'s
    ///   explicit `rtol` with an unrelated `other.cutoff`, producing a
    ///   `cutoff` that no longer satisfies `rtol = √cutoff`.
    #[must_use]
    pub fn merge(&self, other: &Self) -> Self {
        // `rtol` is authoritative, so it decides which side supplies the tolerance.
        let (rtol, cutoff) = if self.rtol.is_some() {
            (self.rtol, self.cutoff)
        } else {
            (other.rtol, other.cutoff)
        };
        Self {
            rtol,
            max_rank: self.max_rank.or(other.max_rank),
            cutoff,
        }
    }
}
146
147/// Trait for types that contain truncation parameters.
148///
149/// This trait provides a common interface for accessing and modifying
150/// truncation parameters in various options structs.
151pub trait HasTruncationParams {
152 /// Get a reference to the truncation parameters.
153 fn truncation_params(&self) -> &TruncationParams;
154
155 /// Get a mutable reference to the truncation parameters.
156 fn truncation_params_mut(&mut self) -> &mut TruncationParams;
157
158 /// Get the rtol value.
159 fn rtol(&self) -> Option<f64> {
160 self.truncation_params().rtol
161 }
162
163 /// Get the max_rank value.
164 fn max_rank(&self) -> Option<usize> {
165 self.truncation_params().max_rank
166 }
167
168 /// Set the rtol value (builder pattern).
169 ///
170 /// Clears any previously set cutoff origin.
171 fn with_rtol(mut self, rtol: f64) -> Self
172 where
173 Self: Sized,
174 {
175 let p = self.truncation_params_mut();
176 p.rtol = Some(rtol);
177 p.cutoff = None;
178 self
179 }
180
181 /// Set the max_rank value (builder pattern).
182 fn with_max_rank(mut self, max_rank: usize) -> Self
183 where
184 Self: Sized,
185 {
186 self.truncation_params_mut().max_rank = Some(max_rank);
187 self
188 }
189
190 /// Set cutoff (ITensorMPS.jl convention, builder pattern).
191 ///
192 /// Internally converted to `rtol = √cutoff`.
193 fn with_cutoff(mut self, cutoff: f64) -> Self
194 where
195 Self: Sized,
196 {
197 let p = self.truncation_params_mut();
198 p.cutoff = Some(cutoff);
199 p.rtol = Some(cutoff.sqrt());
200 self
201 }
202
203 /// Set maxdim (alias for max_rank, builder pattern).
204 fn with_maxdim(mut self, maxdim: usize) -> Self
205 where
206 Self: Sized,
207 {
208 self.truncation_params_mut().max_rank = Some(maxdim);
209 self
210 }
211
212 /// Set cutoff via mutable reference.
213 fn set_cutoff(&mut self, cutoff: f64) {
214 let p = self.truncation_params_mut();
215 p.cutoff = Some(cutoff);
216 p.rtol = Some(cutoff.sqrt());
217 }
218
219 /// Set maxdim via mutable reference (alias for max_rank).
220 fn set_maxdim(&mut self, maxdim: usize) {
221 self.truncation_params_mut().max_rank = Some(maxdim);
222 }
223}
224
225// Implement HasTruncationParams for TruncationParams itself
226impl HasTruncationParams for TruncationParams {
227 fn truncation_params(&self) -> &TruncationParams {
228 self
229 }
230
231 fn truncation_params_mut(&mut self) -> &mut TruncationParams {
232 self
233 }
234}
235
236#[cfg(test)]
237mod tests;