tensor4all_core/
truncation.rs

1//! Common truncation options and traits.
2//!
3//! This module provides shared types and traits for truncation parameters
4//! used across tensor operations like SVD, QR, and tensor train compression.
5
/// Algorithm used to decompose or factorize a tensor (viewed as a matrix).
///
/// One enum shared by every crate that needs to pick a factorization, so the
/// choice is spelled the same way everywhere. The default is
/// [`DecompositionAlg::SVD`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DecompositionAlg {
    /// Singular Value Decomposition (optimal truncation); the default choice.
    #[default]
    SVD,
    /// Randomized SVD (faster for large matrices).
    RSVD,
    /// QR decomposition.
    QR,
    /// Rank-revealing LU decomposition.
    LU,
    /// Cross Interpolation.
    CI,
}
23
24impl DecompositionAlg {
25    /// Check if this algorithm is SVD-based (SVD or RSVD).
26    #[must_use]
27    pub fn is_svd_based(&self) -> bool {
28        matches!(self, Self::SVD | Self::RSVD)
29    }
30
31    /// Check if this algorithm provides orthogonal factors.
32    #[must_use]
33    pub fn is_orthogonal(&self) -> bool {
34        matches!(self, Self::SVD | Self::RSVD | Self::QR)
35    }
36}
37
/// Shared rank-truncation parameters.
///
/// Bundles the settings that control how much rank a decomposition or
/// compression step may discard. Every field is optional; `None` means
/// "not specified by the caller".
///
/// # Semantics
///
/// tensor4all-rs uses **relative tolerance** (`rtol`):
/// - a singular value `σ_i` is dropped when `σ_i / σ_max < rtol`.
///
/// ITensorMPS.jl uses **cutoff** semantics, phrased on squared values:
/// - a singular value is dropped when `σ_i² < cutoff`.
///
/// **Converting between them** (for normalized tensors, i.e. `σ_max = 1`):
/// - ITensorMPS.jl `cutoff` corresponds to tensor4all-rs `rtol²`;
/// - to reproduce ITensorMPS.jl behavior, use `rtol = sqrt(cutoff)`;
/// - e.g. ITensorMPS.jl `cutoff=1e-10` ↔ tensor4all-rs `rtol=1e-5`.
///
/// NOTE(review): ITensors.jl documents `cutoff` in terms of the summed squares
/// of the discarded singular values (truncation error), so this mapping may
/// only be approximate — confirm against the ITensor docs.
#[derive(Debug, Clone, Copy, Default)]
pub struct TruncationParams {
    /// Relative tolerance for truncation.
    ///
    /// A singular value is discarded when `σ_i / σ_max < rtol`, with `σ_max`
    /// the largest singular value.
    ///
    /// `None` defers to the algorithm's own default tolerance.
    pub rtol: Option<f64>,

    /// Upper bound on the kept rank (bond dimension).
    ///
    /// `None` means the rank is not capped.
    pub max_rank: Option<usize>,

    /// Cutoff value in the ITensorMPS.jl convention.
    ///
    /// Setting it through [`with_cutoff`](TruncationParams::with_cutoff) also
    /// sets `rtol` to `√cutoff`. This field only records that original value
    /// for inspection; `rtol` remains the tolerance that is actually used.
    ///
    /// `None` means no cutoff was supplied.
    pub cutoff: Option<f64>,
}
79
80impl TruncationParams {
81    /// Create new truncation parameters with default values.
82    #[must_use]
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Set the relative tolerance.
88    ///
89    /// Clears any previously set cutoff origin.
90    #[must_use]
91    pub fn with_rtol(mut self, rtol: f64) -> Self {
92        self.rtol = Some(rtol);
93        self.cutoff = None;
94        self
95    }
96
97    /// Set the maximum rank.
98    #[must_use]
99    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
100        self.max_rank = Some(max_rank);
101        self
102    }
103
104    /// Set cutoff (ITensorMPS.jl convention).
105    ///
106    /// Internally converted to `rtol = √cutoff`. Clears any previously set
107    /// `rtol` origin so `cutoff` becomes the authoritative tolerance source.
108    #[must_use]
109    pub fn with_cutoff(mut self, cutoff: f64) -> Self {
110        self.cutoff = Some(cutoff);
111        self.rtol = Some(cutoff.sqrt());
112        self
113    }
114
115    /// Set maxdim (alias for [`with_max_rank`](Self::with_max_rank)).
116    ///
117    /// This is provided for ITensorMPS.jl compatibility.
118    #[must_use]
119    pub fn with_maxdim(mut self, maxdim: usize) -> Self {
120        self.max_rank = Some(maxdim);
121        self
122    }
123
124    /// Get the effective rtol, using the provided default if not set.
125    #[must_use]
126    pub fn effective_rtol(&self, default: f64) -> f64 {
127        self.rtol.unwrap_or(default)
128    }
129
130    /// Get the effective max_rank, using usize::MAX if not set.
131    #[must_use]
132    pub fn effective_max_rank(&self) -> usize {
133        self.max_rank.unwrap_or(usize::MAX)
134    }
135
136    /// Merge with another set of parameters, preferring self's values.
137    #[must_use]
138    pub fn merge(&self, other: &Self) -> Self {
139        Self {
140            rtol: self.rtol.or(other.rtol),
141            max_rank: self.max_rank.or(other.max_rank),
142            cutoff: self.cutoff.or(other.cutoff),
143        }
144    }
145}
146
147/// Trait for types that contain truncation parameters.
148///
149/// This trait provides a common interface for accessing and modifying
150/// truncation parameters in various options structs.
151pub trait HasTruncationParams {
152    /// Get a reference to the truncation parameters.
153    fn truncation_params(&self) -> &TruncationParams;
154
155    /// Get a mutable reference to the truncation parameters.
156    fn truncation_params_mut(&mut self) -> &mut TruncationParams;
157
158    /// Get the rtol value.
159    fn rtol(&self) -> Option<f64> {
160        self.truncation_params().rtol
161    }
162
163    /// Get the max_rank value.
164    fn max_rank(&self) -> Option<usize> {
165        self.truncation_params().max_rank
166    }
167
168    /// Set the rtol value (builder pattern).
169    ///
170    /// Clears any previously set cutoff origin.
171    fn with_rtol(mut self, rtol: f64) -> Self
172    where
173        Self: Sized,
174    {
175        let p = self.truncation_params_mut();
176        p.rtol = Some(rtol);
177        p.cutoff = None;
178        self
179    }
180
181    /// Set the max_rank value (builder pattern).
182    fn with_max_rank(mut self, max_rank: usize) -> Self
183    where
184        Self: Sized,
185    {
186        self.truncation_params_mut().max_rank = Some(max_rank);
187        self
188    }
189
190    /// Set cutoff (ITensorMPS.jl convention, builder pattern).
191    ///
192    /// Internally converted to `rtol = √cutoff`.
193    fn with_cutoff(mut self, cutoff: f64) -> Self
194    where
195        Self: Sized,
196    {
197        let p = self.truncation_params_mut();
198        p.cutoff = Some(cutoff);
199        p.rtol = Some(cutoff.sqrt());
200        self
201    }
202
203    /// Set maxdim (alias for max_rank, builder pattern).
204    fn with_maxdim(mut self, maxdim: usize) -> Self
205    where
206        Self: Sized,
207    {
208        self.truncation_params_mut().max_rank = Some(maxdim);
209        self
210    }
211
212    /// Set cutoff via mutable reference.
213    fn set_cutoff(&mut self, cutoff: f64) {
214        let p = self.truncation_params_mut();
215        p.cutoff = Some(cutoff);
216        p.rtol = Some(cutoff.sqrt());
217    }
218
219    /// Set maxdim via mutable reference (alias for max_rank).
220    fn set_maxdim(&mut self, maxdim: usize) {
221        self.truncation_params_mut().max_rank = Some(maxdim);
222    }
223}
224
225// Implement HasTruncationParams for TruncationParams itself
226impl HasTruncationParams for TruncationParams {
227    fn truncation_params(&self) -> &TruncationParams {
228        self
229    }
230
231    fn truncation_params_mut(&mut self) -> &mut TruncationParams {
232        self
233    }
234}
235
// Unit tests live in an out-of-line `tests` submodule, compiled only for test builds.
#[cfg(test)]
mod tests;