tensor4all_core/truncation.rs
1//! Truncation policy types for decomposition algorithms.
2//!
3//! This module keeps algorithm selection separate from algorithm-specific
4//! truncation semantics. SVD-based routines use [`SvdTruncationPolicy`],
5//! while QR and other decompositions keep their own option types.
6
7use thiserror::Error;
8
/// Decomposition/factorization algorithm.
///
/// This enum unifies the algorithm choices across different crates
/// (`tensor4all-core`, `tensor4all-treetn`, etc.). The classification helpers
/// [`DecompositionAlg::is_svd_based`] and [`DecompositionAlg::is_orthogonal`]
/// are `const fn`, so they can also be evaluated in constant contexts.
///
/// # Examples
///
/// ```
/// use tensor4all_core::DecompositionAlg;
///
/// assert_eq!(DecompositionAlg::default(), DecompositionAlg::SVD);
/// assert!(DecompositionAlg::SVD.is_svd_based());
/// assert!(DecompositionAlg::SVD.is_orthogonal());
/// assert!(!DecompositionAlg::LU.is_orthogonal());
///
/// // The helpers are usable in const contexts.
/// const QR_IS_ORTHOGONAL: bool = DecompositionAlg::QR.is_orthogonal();
/// assert!(QR_IS_ORTHOGONAL);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DecompositionAlg {
    /// Singular Value Decomposition (optimal truncation).
    #[default]
    SVD,
    /// Randomized SVD (faster for large matrices).
    RSVD,
    /// QR decomposition.
    QR,
    /// Rank-revealing LU decomposition.
    LU,
    /// Cross Interpolation.
    CI,
}

impl DecompositionAlg {
    /// Check if this algorithm is SVD-based.
    ///
    /// Returns `true` for [`Self::SVD`] and [`Self::RSVD`], and `false` for
    /// every other variant.
    #[must_use]
    pub const fn is_svd_based(&self) -> bool {
        matches!(self, Self::SVD | Self::RSVD)
    }

    /// Check if this algorithm provides orthogonal factors.
    ///
    /// Returns `true` for [`Self::SVD`], [`Self::RSVD`], and [`Self::QR`], and
    /// `false` for [`Self::LU`] and [`Self::CI`].
    #[must_use]
    pub const fn is_orthogonal(&self) -> bool {
        matches!(self, Self::SVD | Self::RSVD | Self::QR)
    }
}
52
/// How an SVD truncation threshold is scaled before it is compared.
///
/// With [`ThresholdScale::Relative`] the threshold is compared against a
/// reference scale derived from the singular values. With
/// [`ThresholdScale::Absolute`] the configured cutoff is used as-is.
/// `Relative` is the default.
///
/// # Examples
///
/// ```
/// use tensor4all_core::ThresholdScale;
///
/// assert_eq!(ThresholdScale::default(), ThresholdScale::Relative);
/// ```
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ThresholdScale {
    /// Scale the threshold by a reference value derived from the singular values.
    #[default]
    Relative,
    /// Use the configured threshold unchanged.
    Absolute,
}
73
/// Quantity derived from each singular value that truncation compares.
///
/// The default, [`SingularValueMeasure::Value`], uses `σ_i` itself.
/// [`SingularValueMeasure::SquaredValue`] uses `σ_i²` instead.
///
/// # Examples
///
/// ```
/// use tensor4all_core::SingularValueMeasure;
///
/// assert_eq!(SingularValueMeasure::default(), SingularValueMeasure::Value);
/// ```
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SingularValueMeasure {
    /// Use the singular values `σ_i` directly.
    #[default]
    Value,
    /// Use the squared singular values `σ_i²`.
    SquaredValue,
}
91
/// Strategy that turns per-singular-value measures into a retained rank.
///
/// [`TruncationRule::PerValue`] (the default) tests each value on its own.
/// [`TruncationRule::DiscardedTailSum`] tests the accumulated measure of the
/// values that would be dropped from the end.
///
/// # Examples
///
/// ```
/// use tensor4all_core::TruncationRule;
///
/// assert_eq!(TruncationRule::default(), TruncationRule::PerValue);
/// ```
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TruncationRule {
    /// Keep each value whose own measure passes the threshold test.
    #[default]
    PerValue,
    /// Drop a trailing run of values for as long as their cumulative
    /// measure stays below the threshold.
    DiscardedTailSum,
}
110
111/// Explicit truncation policy for SVD-based decompositions.
112///
113/// Use this type when you need to describe how singular values are measured,
114/// scaled, and turned into a retained rank. [`SvdOptions`](crate::SvdOptions)
115/// carries this policy plus an independent `max_rank` cap.
116///
117/// # Examples
118///
119/// ```
120/// use tensor4all_core::{
121/// SingularValueMeasure, SvdTruncationPolicy, ThresholdScale, TruncationRule,
122/// };
123///
124/// let policy = SvdTruncationPolicy::new(1e-12);
125/// assert_eq!(policy.scale, ThresholdScale::Relative);
126/// assert_eq!(policy.measure, SingularValueMeasure::Value);
127/// assert_eq!(policy.rule, TruncationRule::PerValue);
128///
129/// let tail_policy = SvdTruncationPolicy::new(1e-8)
130/// .with_absolute()
131/// .with_squared_values()
132/// .with_discarded_tail_sum();
133/// assert_eq!(tail_policy.scale, ThresholdScale::Absolute);
134/// assert_eq!(tail_policy.measure, SingularValueMeasure::SquaredValue);
135/// assert_eq!(tail_policy.rule, TruncationRule::DiscardedTailSum);
136/// ```
137#[derive(Debug, Clone, Copy, PartialEq)]
138pub struct SvdTruncationPolicy {
139 /// Threshold value used by the selected scale/rule combination.
140 pub threshold: f64,
141 /// Whether the threshold is interpreted relatively or absolutely.
142 pub scale: ThresholdScale,
143 /// Whether the policy measures singular values or squared singular values.
144 pub measure: SingularValueMeasure,
145 /// Whether truncation is per value or based on a discarded tail sum.
146 pub rule: TruncationRule,
147}
148
149impl SvdTruncationPolicy {
150 /// Create a policy with the default semantics:
151 /// relative threshold, singular values, and per-value truncation.
152 #[must_use]
153 pub const fn new(threshold: f64) -> Self {
154 Self {
155 threshold,
156 scale: ThresholdScale::Relative,
157 measure: SingularValueMeasure::Value,
158 rule: TruncationRule::PerValue,
159 }
160 }
161
162 /// Use relative threshold scaling.
163 #[must_use]
164 pub const fn with_relative(mut self) -> Self {
165 self.scale = ThresholdScale::Relative;
166 self
167 }
168
169 /// Use absolute threshold scaling.
170 #[must_use]
171 pub const fn with_absolute(mut self) -> Self {
172 self.scale = ThresholdScale::Absolute;
173 self
174 }
175
176 /// Measure singular values directly.
177 #[must_use]
178 pub const fn with_values(mut self) -> Self {
179 self.measure = SingularValueMeasure::Value;
180 self
181 }
182
183 /// Measure squared singular values.
184 #[must_use]
185 pub const fn with_squared_values(mut self) -> Self {
186 self.measure = SingularValueMeasure::SquaredValue;
187 self
188 }
189
190 /// Apply the threshold independently to each singular value.
191 #[must_use]
192 pub const fn with_per_value(mut self) -> Self {
193 self.rule = TruncationRule::PerValue;
194 self
195 }
196
197 /// Apply the threshold to the cumulative discarded tail.
198 #[must_use]
199 pub const fn with_discarded_tail_sum(mut self) -> Self {
200 self.rule = TruncationRule::DiscardedTailSum;
201 self
202 }
203}
204
205/// Error for invalid SVD truncation thresholds.
206#[derive(Debug, Error, Clone, Copy, PartialEq)]
207#[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
208pub struct InvalidThresholdError(pub f64);
209
210/// Validate one threshold value.
211pub(crate) fn validate_threshold_value(threshold: f64) -> Result<(), InvalidThresholdError> {
212 if !threshold.is_finite() || threshold < 0.0 {
213 return Err(InvalidThresholdError(threshold));
214 }
215 Ok(())
216}
217
218/// Validate one full SVD truncation policy.
219pub(crate) fn validate_svd_truncation_policy(
220 policy: SvdTruncationPolicy,
221) -> Result<(), InvalidThresholdError> {
222 validate_threshold_value(policy.threshold)
223}
224
225#[cfg(test)]
226mod tests;