Skip to main content

tensor4all_core/defaults/svd.rs

1//! SVD decomposition for tensors.
2//!
3//! Provides [`svd`] and [`svd_with`] for computing truncated SVD of
4//! [`TensorDynLen`] values. The tensor is unfolded into a matrix by
5//! splitting its indices into left and right groups, then the standard
6//! matrix SVD is computed and truncated according to [`SvdOptions`].
7//!
8//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
9
10use crate::defaults::tensordynlen::unfold_split_inner;
11use crate::defaults::DynIndex;
12use crate::index_like::IndexLike;
13use crate::truncation::{
14    validate_svd_truncation_policy, SingularValueMeasure, SvdTruncationPolicy, ThresholdScale,
15    TruncationRule,
16};
17use crate::TensorDynLen;
18use std::sync::Mutex;
19use tenferro::DType;
20use tenferro_ad::EagerTensor;
21use tenferro_linalg::eager_tensor::svd as eager_svd;
22use tensor4all_tensorbackend::{
23    native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
24};
25use thiserror::Error;
26
27/// Error type for SVD operations in tensor4all-linalg.
28#[derive(Debug, Error)]
29pub enum SvdError {
30    /// SVD computation failed.
31    #[error("SVD computation failed: {0}")]
32    ComputationError(#[from] anyhow::Error),
33    /// Invalid truncation threshold value (must be finite and non-negative).
34    #[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
35    InvalidThreshold(f64),
36}
37
38/// Options for SVD decomposition with truncation control.
39///
40/// # Examples
41///
42/// ```
43/// use tensor4all_core::svd::{SvdOptions, svd_with};
44/// use tensor4all_core::{DynIndex, SvdTruncationPolicy, TensorDynLen};
45///
46/// let i = DynIndex::new_dyn(3);
47/// let j = DynIndex::new_dyn(3);
48/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
49/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
50///
51/// let opts = SvdOptions::new().with_policy(SvdTruncationPolicy::new(1e-10));
52/// let (u, s, v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
53///
54/// // U has left index + bond, S is diagonal bond x bond, V has right index + bond
55/// assert_eq!(u.dims()[0], 3);
56/// assert_eq!(s.dims().len(), 2);
57/// ```
58#[derive(Debug, Clone, Copy)]
59pub struct SvdOptions {
60    /// Maximum retained rank after policy-based truncation.
61    pub max_rank: Option<usize>,
62    /// Per-call SVD truncation policy.
63    /// If `None`, the global default policy is used.
64    pub policy: Option<SvdTruncationPolicy>,
65    truncate: bool,
66}
67
68impl Default for SvdOptions {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl SvdOptions {
75    /// Create new SVD options with no overrides.
76    #[must_use]
77    pub fn new() -> Self {
78        Self {
79            max_rank: None,
80            policy: None,
81            truncate: true,
82        }
83    }
84
85    /// Set the maximum retained rank.
86    #[must_use]
87    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
88        self.max_rank = Some(max_rank);
89        self
90    }
91
92    /// Set the SVD truncation policy override.
93    #[must_use]
94    pub fn with_policy(mut self, policy: SvdTruncationPolicy) -> Self {
95        self.policy = Some(policy);
96        self
97    }
98
99    pub(crate) fn full_rank() -> Self {
100        Self {
101            max_rank: None,
102            policy: None,
103            truncate: false,
104        }
105    }
106}
107
108fn default_policy_guard() -> std::sync::MutexGuard<'static, SvdTruncationPolicy> {
109    match DEFAULT_SVD_TRUNCATION_POLICY.lock() {
110        Ok(guard) => guard,
111        Err(poisoned) => poisoned.into_inner(),
112    }
113}
114
115// Default value: relative per-value threshold 1e-12.
116static DEFAULT_SVD_TRUNCATION_POLICY: Mutex<SvdTruncationPolicy> =
117    Mutex::new(SvdTruncationPolicy::new(1e-12));
118
119/// Get the global default truncation policy for SVD.
120///
121/// The default policy is `SvdTruncationPolicy::new(1e-12)`.
122#[must_use]
123pub fn default_svd_truncation_policy() -> SvdTruncationPolicy {
124    *default_policy_guard()
125}
126
127/// Set the global default truncation policy for SVD.
128///
129/// # Arguments
130/// * `policy` - SVD truncation policy to use when `SvdOptions::policy` is `None`
131///
132/// # Errors
133/// Returns `SvdError::InvalidThreshold` if `policy.threshold` is invalid.
134pub fn set_default_svd_truncation_policy(policy: SvdTruncationPolicy) -> Result<(), SvdError> {
135    validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
136    *default_policy_guard() = policy;
137    Ok(())
138}
139
140fn singular_value_measure(value: f64, measure: SingularValueMeasure) -> f64 {
141    match measure {
142        SingularValueMeasure::Value => value,
143        SingularValueMeasure::SquaredValue => value * value,
144    }
145}
146
147/// Compute the retained rank based on an explicit SVD truncation policy.
148fn compute_retained_rank(s_vec: &[f64], policy: &SvdTruncationPolicy) -> usize {
149    if s_vec.is_empty() {
150        return 1;
151    }
152
153    let measured: Vec<f64> = s_vec
154        .iter()
155        .map(|&value| singular_value_measure(value, policy.measure))
156        .collect();
157    if measured.iter().all(|&value| value == 0.0) {
158        return 1;
159    }
160
161    let retained = match (policy.scale, policy.rule) {
162        (ThresholdScale::Relative, TruncationRule::PerValue) => {
163            let reference = measured.iter().copied().fold(0.0_f64, f64::max);
164            measured
165                .iter()
166                .take_while(|&&value| reference > 0.0 && value / reference > policy.threshold)
167                .count()
168        }
169        (ThresholdScale::Absolute, TruncationRule::PerValue) => measured
170            .iter()
171            .take_while(|&&value| value > policy.threshold)
172            .count(),
173        (ThresholdScale::Relative, TruncationRule::DiscardedTailSum) => {
174            let total: f64 = measured.iter().sum();
175            if total == 0.0 {
176                1
177            } else {
178                let mut discarded = 0.0;
179                let mut keep = measured.len();
180                for (i, value) in measured.iter().enumerate().rev() {
181                    if (discarded + value) / total <= policy.threshold {
182                        discarded += value;
183                        keep = i;
184                    } else {
185                        break;
186                    }
187                }
188                keep
189            }
190        }
191        (ThresholdScale::Absolute, TruncationRule::DiscardedTailSum) => {
192            let mut discarded = 0.0;
193            let mut keep = measured.len();
194            for (i, value) in measured.iter().enumerate().rev() {
195                if discarded + value <= policy.threshold {
196                    discarded += value;
197                    keep = i;
198                } else {
199                    break;
200                }
201            }
202            keep
203        }
204    };
205
206    retained.max(1)
207}
208
209fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
210    match tensor.dtype() {
211        DType::F64 => {
212            native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
213        }
214        DType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
215            .map(|values| values.into_iter().map(|value| value.re).collect())
216            .map_err(SvdError::ComputationError),
217        other => Err(SvdError::ComputationError(anyhow::anyhow!(
218            "native SVD returned unsupported singular-value scalar type {other:?}"
219        ))),
220    }
221}
222
223type SvdTruncatedEagerResult = (
224    EagerTensor,
225    EagerTensor,
226    EagerTensor,
227    Vec<f64>,
228    DynIndex,
229    Vec<DynIndex>,
230    Vec<DynIndex>,
231);
232
233fn svd_truncated_inner(
234    t: &TensorDynLen,
235    left_inds: &[DynIndex],
236    options: &SvdOptions,
237) -> Result<SvdTruncatedEagerResult, SvdError> {
238    let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
239        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
240        .map_err(SvdError::ComputationError)?;
241    let k = m.min(n);
242
243    let (mut u_inner, mut s_inner, mut vt_inner) =
244        eager_svd(&matrix_inner).map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
245    let s_full = singular_values_from_native(s_inner.data())?;
246    let mut r = if options.truncate {
247        let policy = options.policy.unwrap_or_else(default_svd_truncation_policy);
248        validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
249
250        let mut retained = compute_retained_rank(&s_full, &policy);
251        if let Some(max_rank) = options.max_rank {
252            retained = retained.min(max_rank);
253        }
254        retained.max(1)
255    } else {
256        k.max(1)
257    };
258    r = r.min(s_full.len());
259    if r < k {
260        let keep: Vec<usize> = (0..r).collect();
261        u_inner = u_inner
262            .take_axis(1, &keep)
263            .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
264        s_inner = s_inner
265            .take_axis(0, &keep)
266            .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
267        vt_inner = vt_inner
268            .take_axis(0, &keep)
269            .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
270    }
271
272    let bond_index = DynIndex::new_bond(r)
273        .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
274        .map_err(SvdError::ComputationError)?;
275    let singular_values = s_full[..r].to_vec();
276
277    Ok((
278        u_inner,
279        s_inner,
280        vt_inner,
281        singular_values,
282        bond_index,
283        left_indices,
284        right_indices,
285    ))
286}
287
288/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
289///
290/// # Examples
291///
292/// ```
293/// use tensor4all_core::{TensorDynLen, DynIndex, svd};
294///
295/// // Create a 2x3 matrix (rank-1 outer product: all-ones)
296/// let i = DynIndex::new_dyn(2);
297/// let j = DynIndex::new_dyn(3);
298/// let data = vec![1.0_f64; 6]; // all-ones 2x3 matrix
299/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
300///
301/// let (u, s, v) = svd::<f64>(&t, &[i.clone()]).unwrap();
302///
303/// // U: shape (left_dim, bond) = (2, bond)
304/// assert_eq!(u.dims()[0], 2);
305/// // V: shape (right_dim, bond) = (3, bond)
306/// assert_eq!(v.dims()[0], 3);
307/// // S is a diagonal matrix (bond × bond)
308/// assert_eq!(s.dims().len(), 2);
309/// ```
310pub fn svd<T>(
311    t: &TensorDynLen,
312    left_inds: &[DynIndex],
313) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
314    svd_with::<T>(t, left_inds, &SvdOptions::default())
315}
316
317/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
318///
319/// This function allows per-call control of the truncation policy via `SvdOptions`.
320/// If `options.policy` is `None`, it uses the global default policy.
321///
322/// # Examples
323///
324/// ```
325/// use tensor4all_core::{DynIndex, TensorDynLen};
326/// use tensor4all_core::svd::{SvdOptions, svd_with};
327///
328/// let i = DynIndex::new_dyn(4);
329/// let j = DynIndex::new_dyn(4);
330/// // Rank-1 matrix
331/// let mut data = vec![0.0_f64; 16];
332/// data[0] = 1.0;
333/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
334///
335/// use tensor4all_core::SvdTruncationPolicy;
336///
337/// // Truncate with a relative per-value threshold => rank 1
338/// let opts = SvdOptions::new().with_policy(SvdTruncationPolicy::new(1e-10));
339/// let (u, s, _v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
340/// assert_eq!(s.dims()[0], 1);  // rank-1
341///
342/// // Truncate with max_rank => capped
343/// let opts = SvdOptions::new().with_max_rank(2);
344/// let (_u, s, _v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
345/// assert!(s.dims()[0] <= 2);
346/// ```
347pub fn svd_with<T>(
348    t: &TensorDynLen,
349    left_inds: &[DynIndex],
350    options: &SvdOptions,
351) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
352    let (u_inner, s_inner, vt_inner, _singular_values, bond_index, left_indices, right_indices) =
353        svd_truncated_inner(t, left_inds, options)?;
354
355    let mut u_indices = left_indices;
356    u_indices.push(bond_index.clone());
357    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
358    let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
359        SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
360    })?;
361    let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
362
363    let s_indices = vec![bond_index.clone(), bond_index.sim()];
364    let s =
365        TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
366
367    let mut vh_indices = vec![bond_index.clone()];
368    vh_indices.extend(right_indices);
369    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
370    let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
371        SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
372    })?;
373    let vh =
374        TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
375    let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
376    let v = vh
377        .conj()
378        .permute(&perm)
379        .map_err(SvdError::ComputationError)?;
380
381    Ok((u, s, v))
382}
383
384/// SVD result for factorization, returning `V^H` directly.
385pub(crate) struct SvdFactorizeResult {
386    pub u: TensorDynLen,
387    pub s: TensorDynLen,
388    pub vh: TensorDynLen,
389    pub bond_index: DynIndex,
390    pub singular_values: Vec<f64>,
391    pub rank: usize,
392}
393
394/// Compute truncated SVD for factorization, returning `V^H` instead of `V`.
395pub(crate) fn svd_for_factorize(
396    t: &TensorDynLen,
397    left_inds: &[DynIndex],
398    options: &SvdOptions,
399) -> Result<SvdFactorizeResult, SvdError> {
400    let (u_inner, s_inner, vt_inner, singular_values, bond_index, left_indices, right_indices) =
401        svd_truncated_inner(t, left_inds, options)?;
402    let rank = singular_values.len();
403
404    let mut u_indices = left_indices;
405    u_indices.push(bond_index.clone());
406    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
407    let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
408        SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
409    })?;
410    let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
411
412    let s_indices = vec![bond_index.clone(), bond_index.sim()];
413    let s =
414        TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
415
416    let mut vh_indices = vec![bond_index.clone()];
417    vh_indices.extend(right_indices);
418    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
419    let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
420        SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
421    })?;
422    let vh =
423        TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
424
425    Ok(SvdFactorizeResult {
426        u,
427        s,
428        vh,
429        bond_index,
430        singular_values,
431        rank,
432    })
433}
434
// Unit tests for this module are declared in a separate `tests` module file.
#[cfg(test)]
mod tests;