Skip to main content

tensor4all_core/
truncation.rs

1//! Truncation policy types for decomposition algorithms.
2//!
3//! This module keeps algorithm selection separate from algorithm-specific
4//! truncation semantics. SVD-based routines use [`SvdTruncationPolicy`],
5//! while QR and other decompositions keep their own option types.
6
7use thiserror::Error;
8
/// Decomposition/factorization algorithm.
///
/// This enum unifies the algorithm choices across different crates
/// (`tensor4all-core`, `tensor4all-treetn`, etc.).
///
/// It is a fieldless `Copy` tag. Besides equality it implements `Hash`, so it
/// can be used directly as a key in hash-based collections.
///
/// # Examples
///
/// ```
/// use std::collections::HashSet;
/// use tensor4all_core::DecompositionAlg;
///
/// assert_eq!(DecompositionAlg::default(), DecompositionAlg::SVD);
/// assert!(DecompositionAlg::SVD.is_svd_based());
/// assert!(DecompositionAlg::SVD.is_orthogonal());
/// assert!(!DecompositionAlg::LU.is_orthogonal());
///
/// // `Eq + Hash` make the algorithm usable as a set/map key.
/// let mut seen = HashSet::new();
/// seen.insert(DecompositionAlg::QR);
/// seen.insert(DecompositionAlg::QR);
/// assert_eq!(seen.len(), 1);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DecompositionAlg {
    /// Singular Value Decomposition (optimal truncation).
    #[default]
    SVD,
    /// Randomized SVD (faster for large matrices).
    RSVD,
    /// QR decomposition.
    QR,
    /// Rank-revealing LU decomposition.
    LU,
    /// Cross Interpolation.
    CI,
}
38
39impl DecompositionAlg {
40    /// Check if this algorithm is SVD-based (SVD or RSVD).
41    #[must_use]
42    pub fn is_svd_based(&self) -> bool {
43        matches!(self, Self::SVD | Self::RSVD)
44    }
45
46    /// Check if this algorithm provides orthogonal factors.
47    #[must_use]
48    pub fn is_orthogonal(&self) -> bool {
49        matches!(self, Self::SVD | Self::RSVD | Self::QR)
50    }
51}
52
/// Threshold scaling for SVD truncation.
///
/// Relative thresholds compare against a scale derived from the singular values.
/// Absolute thresholds compare directly against the configured cutoff.
///
/// This is a fieldless `Copy` tag that implements `Eq` and `Hash`, so it can
/// be used as a key in hash-based collections.
///
/// # Examples
///
/// ```
/// use tensor4all_core::ThresholdScale;
///
/// assert_eq!(ThresholdScale::default(), ThresholdScale::Relative);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ThresholdScale {
    /// Compare against a singular-value-derived reference scale.
    #[default]
    Relative,
    /// Compare directly against the configured threshold.
    Absolute,
}
73
/// Singular-value-derived quantity used for truncation.
///
/// This is a fieldless `Copy` tag that implements `Eq` and `Hash`, so it can
/// be used as a key in hash-based collections.
///
/// # Examples
///
/// ```
/// use tensor4all_core::SingularValueMeasure;
///
/// assert_eq!(SingularValueMeasure::default(), SingularValueMeasure::Value);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SingularValueMeasure {
    /// Compare using singular values `σ_i`.
    #[default]
    Value,
    /// Compare using squared singular values `σ_i²`.
    SquaredValue,
}
91
/// Rule used to map singular values to a retained rank.
///
/// This is a fieldless `Copy` tag that implements `Eq` and `Hash`, so it can
/// be used as a key in hash-based collections.
///
/// # Examples
///
/// ```
/// use tensor4all_core::TruncationRule;
///
/// assert_eq!(TruncationRule::default(), TruncationRule::PerValue);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TruncationRule {
    /// Keep values whose individual measure exceeds the threshold rule.
    #[default]
    PerValue,
    /// Discard a suffix while the cumulative discarded measure stays below
    /// the threshold rule.
    DiscardedTailSum,
}
110
111/// Explicit truncation policy for SVD-based decompositions.
112///
113/// Use this type when you need to describe how singular values are measured,
114/// scaled, and turned into a retained rank. [`SvdOptions`](crate::SvdOptions)
115/// carries this policy plus an independent `max_rank` cap.
116///
117/// # Examples
118///
119/// ```
120/// use tensor4all_core::{
121///     SingularValueMeasure, SvdTruncationPolicy, ThresholdScale, TruncationRule,
122/// };
123///
124/// let policy = SvdTruncationPolicy::new(1e-12);
125/// assert_eq!(policy.scale, ThresholdScale::Relative);
126/// assert_eq!(policy.measure, SingularValueMeasure::Value);
127/// assert_eq!(policy.rule, TruncationRule::PerValue);
128///
129/// let tail_policy = SvdTruncationPolicy::new(1e-8)
130///     .with_absolute()
131///     .with_squared_values()
132///     .with_discarded_tail_sum();
133/// assert_eq!(tail_policy.scale, ThresholdScale::Absolute);
134/// assert_eq!(tail_policy.measure, SingularValueMeasure::SquaredValue);
135/// assert_eq!(tail_policy.rule, TruncationRule::DiscardedTailSum);
136/// ```
137#[derive(Debug, Clone, Copy, PartialEq)]
138pub struct SvdTruncationPolicy {
139    /// Threshold value used by the selected scale/rule combination.
140    pub threshold: f64,
141    /// Whether the threshold is interpreted relatively or absolutely.
142    pub scale: ThresholdScale,
143    /// Whether the policy measures singular values or squared singular values.
144    pub measure: SingularValueMeasure,
145    /// Whether truncation is per value or based on a discarded tail sum.
146    pub rule: TruncationRule,
147}
148
149impl SvdTruncationPolicy {
150    /// Create a policy with the default semantics:
151    /// relative threshold, singular values, and per-value truncation.
152    #[must_use]
153    pub const fn new(threshold: f64) -> Self {
154        Self {
155            threshold,
156            scale: ThresholdScale::Relative,
157            measure: SingularValueMeasure::Value,
158            rule: TruncationRule::PerValue,
159        }
160    }
161
162    /// Use relative threshold scaling.
163    #[must_use]
164    pub const fn with_relative(mut self) -> Self {
165        self.scale = ThresholdScale::Relative;
166        self
167    }
168
169    /// Use absolute threshold scaling.
170    #[must_use]
171    pub const fn with_absolute(mut self) -> Self {
172        self.scale = ThresholdScale::Absolute;
173        self
174    }
175
176    /// Measure singular values directly.
177    #[must_use]
178    pub const fn with_values(mut self) -> Self {
179        self.measure = SingularValueMeasure::Value;
180        self
181    }
182
183    /// Measure squared singular values.
184    #[must_use]
185    pub const fn with_squared_values(mut self) -> Self {
186        self.measure = SingularValueMeasure::SquaredValue;
187        self
188    }
189
190    /// Apply the threshold independently to each singular value.
191    #[must_use]
192    pub const fn with_per_value(mut self) -> Self {
193        self.rule = TruncationRule::PerValue;
194        self
195    }
196
197    /// Apply the threshold to the cumulative discarded tail.
198    #[must_use]
199    pub const fn with_discarded_tail_sum(mut self) -> Self {
200        self.rule = TruncationRule::DiscardedTailSum;
201        self
202    }
203}
204
205/// Error for invalid SVD truncation thresholds.
206#[derive(Debug, Error, Clone, Copy, PartialEq)]
207#[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
208pub struct InvalidThresholdError(pub f64);
209
210/// Validate one threshold value.
211pub(crate) fn validate_threshold_value(threshold: f64) -> Result<(), InvalidThresholdError> {
212    if !threshold.is_finite() || threshold < 0.0 {
213        return Err(InvalidThresholdError(threshold));
214    }
215    Ok(())
216}
217
218/// Validate one full SVD truncation policy.
219pub(crate) fn validate_svd_truncation_policy(
220    policy: SvdTruncationPolicy,
221) -> Result<(), InvalidThresholdError> {
222    validate_threshold_value(policy.threshold)
223}
224
225#[cfg(test)]
226mod tests;