Skip to main content

tensor4all_core/defaults/
svd.rs

//! SVD decomposition for tensors.
//!
//! Provides [`svd`] and [`svd_with`] for computing truncated SVD of
//! [`TensorDynLen`] values. The tensor is unfolded into a matrix by
//! splitting its indices into left and right groups, then the standard
//! matrix SVD is computed and truncated according to [`SvdOptions`].
//!
//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
9
10use crate::defaults::DynIndex;
11use crate::index_like::IndexLike;
12use crate::truncation::{
13    validate_svd_truncation_policy, SingularValueMeasure, SvdTruncationPolicy, ThresholdScale,
14    TruncationRule,
15};
16use crate::{unfold_split, TensorDynLen};
17use std::sync::Mutex;
18use tenferro::DType;
19use tensor4all_tensorbackend::{
20    dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
21    native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
22    reshape_col_major_native_tensor, svd_native_tensor, TensorElement,
23};
24use thiserror::Error;
25
26/// Error type for SVD operations in tensor4all-linalg.
27#[derive(Debug, Error)]
28pub enum SvdError {
29    /// SVD computation failed.
30    #[error("SVD computation failed: {0}")]
31    ComputationError(#[from] anyhow::Error),
32    /// Invalid truncation threshold value (must be finite and non-negative).
33    #[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
34    InvalidThreshold(f64),
35}
36
37/// Options for SVD decomposition with truncation control.
38///
39/// # Examples
40///
41/// ```
42/// use tensor4all_core::svd::{SvdOptions, svd_with};
43/// use tensor4all_core::{DynIndex, SvdTruncationPolicy, TensorDynLen};
44///
45/// let i = DynIndex::new_dyn(3);
46/// let j = DynIndex::new_dyn(3);
47/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
48/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
49///
50/// let opts = SvdOptions::new().with_policy(SvdTruncationPolicy::new(1e-10));
51/// let (u, s, v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
52///
53/// // U has left index + bond, S is diagonal bond x bond, V has right index + bond
54/// assert_eq!(u.dims()[0], 3);
55/// assert_eq!(s.dims().len(), 2);
56/// ```
57#[derive(Debug, Clone, Copy)]
58pub struct SvdOptions {
59    /// Maximum retained rank after policy-based truncation.
60    pub max_rank: Option<usize>,
61    /// Per-call SVD truncation policy.
62    /// If `None`, the global default policy is used.
63    pub policy: Option<SvdTruncationPolicy>,
64    truncate: bool,
65}
66
67impl Default for SvdOptions {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl SvdOptions {
74    /// Create new SVD options with no overrides.
75    #[must_use]
76    pub fn new() -> Self {
77        Self {
78            max_rank: None,
79            policy: None,
80            truncate: true,
81        }
82    }
83
84    /// Set the maximum retained rank.
85    #[must_use]
86    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
87        self.max_rank = Some(max_rank);
88        self
89    }
90
91    /// Set the SVD truncation policy override.
92    #[must_use]
93    pub fn with_policy(mut self, policy: SvdTruncationPolicy) -> Self {
94        self.policy = Some(policy);
95        self
96    }
97
98    pub(crate) fn full_rank() -> Self {
99        Self {
100            max_rank: None,
101            policy: None,
102            truncate: false,
103        }
104    }
105}
106
107fn default_policy_guard() -> std::sync::MutexGuard<'static, SvdTruncationPolicy> {
108    match DEFAULT_SVD_TRUNCATION_POLICY.lock() {
109        Ok(guard) => guard,
110        Err(poisoned) => poisoned.into_inner(),
111    }
112}
113
114// Default value: relative per-value threshold 1e-12.
115static DEFAULT_SVD_TRUNCATION_POLICY: Mutex<SvdTruncationPolicy> =
116    Mutex::new(SvdTruncationPolicy::new(1e-12));
117
118/// Get the global default truncation policy for SVD.
119///
120/// The default policy is `SvdTruncationPolicy::new(1e-12)`.
121#[must_use]
122pub fn default_svd_truncation_policy() -> SvdTruncationPolicy {
123    *default_policy_guard()
124}
125
126/// Set the global default truncation policy for SVD.
127///
128/// # Arguments
129/// * `policy` - SVD truncation policy to use when `SvdOptions::policy` is `None`
130///
131/// # Errors
132/// Returns `SvdError::InvalidThreshold` if `policy.threshold` is invalid.
133pub fn set_default_svd_truncation_policy(policy: SvdTruncationPolicy) -> Result<(), SvdError> {
134    validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
135    *default_policy_guard() = policy;
136    Ok(())
137}
138
139fn singular_value_measure(value: f64, measure: SingularValueMeasure) -> f64 {
140    match measure {
141        SingularValueMeasure::Value => value,
142        SingularValueMeasure::SquaredValue => value * value,
143    }
144}
145
146/// Compute the retained rank based on an explicit SVD truncation policy.
147fn compute_retained_rank(s_vec: &[f64], policy: &SvdTruncationPolicy) -> usize {
148    if s_vec.is_empty() {
149        return 1;
150    }
151
152    let measured: Vec<f64> = s_vec
153        .iter()
154        .map(|&value| singular_value_measure(value, policy.measure))
155        .collect();
156    if measured.iter().all(|&value| value == 0.0) {
157        return 1;
158    }
159
160    let retained = match (policy.scale, policy.rule) {
161        (ThresholdScale::Relative, TruncationRule::PerValue) => {
162            let reference = measured.iter().copied().fold(0.0_f64, f64::max);
163            measured
164                .iter()
165                .take_while(|&&value| reference > 0.0 && value / reference > policy.threshold)
166                .count()
167        }
168        (ThresholdScale::Absolute, TruncationRule::PerValue) => measured
169            .iter()
170            .take_while(|&&value| value > policy.threshold)
171            .count(),
172        (ThresholdScale::Relative, TruncationRule::DiscardedTailSum) => {
173            let total: f64 = measured.iter().sum();
174            if total == 0.0 {
175                1
176            } else {
177                let mut discarded = 0.0;
178                let mut keep = measured.len();
179                for (i, value) in measured.iter().enumerate().rev() {
180                    if (discarded + value) / total <= policy.threshold {
181                        discarded += value;
182                        keep = i;
183                    } else {
184                        break;
185                    }
186                }
187                keep
188            }
189        }
190        (ThresholdScale::Absolute, TruncationRule::DiscardedTailSum) => {
191            let mut discarded = 0.0;
192            let mut keep = measured.len();
193            for (i, value) in measured.iter().enumerate().rev() {
194                if discarded + value <= policy.threshold {
195                    discarded += value;
196                    keep = i;
197                } else {
198                    break;
199                }
200            }
201            keep
202        }
203    };
204
205    retained.max(1)
206}
207
208fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
209    match tensor.dtype() {
210        DType::F64 => {
211            native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
212        }
213        DType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
214            .map(|values| values.into_iter().map(|value| value.re).collect())
215            .map_err(SvdError::ComputationError),
216        other => Err(SvdError::ComputationError(anyhow::anyhow!(
217            "native SVD returned unsupported singular-value scalar type {other:?}"
218        ))),
219    }
220}
221
222fn truncate_matrix_cols<T: TensorElement>(
223    data: &[T],
224    rows: usize,
225    keep_cols: usize,
226) -> anyhow::Result<tenferro::Tensor> {
227    dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
228}
229
230fn truncate_matrix_rows<T: TensorElement>(
231    data: &[T],
232    rows: usize,
233    cols: usize,
234    keep_rows: usize,
235) -> anyhow::Result<tenferro::Tensor> {
236    let mut truncated = Vec::with_capacity(keep_rows * cols);
237    for col in 0..cols {
238        let start = col * rows;
239        truncated.extend_from_slice(&data[start..start + keep_rows]);
240    }
241    dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
242}
243
244type SvdTruncatedNativeResult = (
245    tenferro::Tensor,
246    tenferro::Tensor,
247    tenferro::Tensor,
248    Vec<f64>,
249    DynIndex,
250    Vec<DynIndex>,
251    Vec<DynIndex>,
252);
253
254fn svd_truncated_native(
255    t: &TensorDynLen,
256    left_inds: &[DynIndex],
257    options: &SvdOptions,
258) -> Result<SvdTruncatedNativeResult, SvdError> {
259    let (matrix_native, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
260        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
261        .map_err(SvdError::ComputationError)?;
262    let k = m.min(n);
263
264    let (mut u_native, mut s_native, mut vt_native) =
265        svd_native_tensor(&matrix_native).map_err(SvdError::ComputationError)?;
266    let s_full = singular_values_from_native(&s_native)?;
267    let mut r = if options.truncate {
268        let policy = options.policy.unwrap_or_else(default_svd_truncation_policy);
269        validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
270
271        let mut retained = compute_retained_rank(&s_full, &policy);
272        if let Some(max_rank) = options.max_rank {
273            retained = retained.min(max_rank);
274        }
275        retained.max(1)
276    } else {
277        k.max(1)
278    };
279    r = r.min(s_full.len());
280    if r < k {
281        match u_native.dtype() {
282            DType::F64 => {
283                let u_values = native_tensor_primal_to_dense_f64_col_major(&u_native)
284                    .map_err(SvdError::ComputationError)?;
285                let vt_values = native_tensor_primal_to_dense_f64_col_major(&vt_native)
286                    .map_err(SvdError::ComputationError)?;
287                u_native =
288                    truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
289                vt_native = truncate_matrix_rows(&vt_values, k, n, r)
290                    .map_err(SvdError::ComputationError)?;
291            }
292            DType::C64 => {
293                let u_values = native_tensor_primal_to_dense_c64_col_major(&u_native)
294                    .map_err(SvdError::ComputationError)?;
295                let vt_values = native_tensor_primal_to_dense_c64_col_major(&vt_native)
296                    .map_err(SvdError::ComputationError)?;
297                u_native =
298                    truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
299                vt_native = truncate_matrix_rows(&vt_values, k, n, r)
300                    .map_err(SvdError::ComputationError)?;
301            }
302            other => {
303                return Err(SvdError::ComputationError(anyhow::anyhow!(
304                    "native SVD returned unsupported singular-vector scalar type {other:?}"
305                )));
306            }
307        }
308        s_native = dense_native_tensor_from_col_major(&s_full[..r], &[r])
309            .map_err(SvdError::ComputationError)?;
310    }
311
312    let bond_index = DynIndex::new_bond(r)
313        .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
314        .map_err(SvdError::ComputationError)?;
315    let singular_values = s_full[..r].to_vec();
316
317    Ok((
318        u_native,
319        s_native,
320        vt_native,
321        singular_values,
322        bond_index,
323        left_indices,
324        right_indices,
325    ))
326}
327
328/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
329///
330/// # Examples
331///
332/// ```
333/// use tensor4all_core::{TensorDynLen, DynIndex, svd};
334///
335/// // Create a 2x3 matrix (rank-1 outer product: all-ones)
336/// let i = DynIndex::new_dyn(2);
337/// let j = DynIndex::new_dyn(3);
338/// let data = vec![1.0_f64; 6]; // all-ones 2x3 matrix
339/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
340///
341/// let (u, s, v) = svd::<f64>(&t, &[i.clone()]).unwrap();
342///
343/// // U: shape (left_dim, bond) = (2, bond)
344/// assert_eq!(u.dims()[0], 2);
345/// // V: shape (right_dim, bond) = (3, bond)
346/// assert_eq!(v.dims()[0], 3);
347/// // S is a diagonal matrix (bond × bond)
348/// assert_eq!(s.dims().len(), 2);
349/// ```
350pub fn svd<T>(
351    t: &TensorDynLen,
352    left_inds: &[DynIndex],
353) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
354    svd_with::<T>(t, left_inds, &SvdOptions::default())
355}
356
357/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
358///
359/// This function allows per-call control of the truncation policy via `SvdOptions`.
360/// If `options.policy` is `None`, it uses the global default policy.
361///
362/// # Examples
363///
364/// ```
365/// use tensor4all_core::{DynIndex, TensorDynLen};
366/// use tensor4all_core::svd::{SvdOptions, svd_with};
367///
368/// let i = DynIndex::new_dyn(4);
369/// let j = DynIndex::new_dyn(4);
370/// // Rank-1 matrix
371/// let mut data = vec![0.0_f64; 16];
372/// data[0] = 1.0;
373/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
374///
375/// use tensor4all_core::SvdTruncationPolicy;
376///
377/// // Truncate with a relative per-value threshold => rank 1
378/// let opts = SvdOptions::new().with_policy(SvdTruncationPolicy::new(1e-10));
379/// let (u, s, _v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
380/// assert_eq!(s.dims()[0], 1);  // rank-1
381///
382/// // Truncate with max_rank => capped
383/// let opts = SvdOptions::new().with_max_rank(2);
384/// let (_u, s, _v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
385/// assert!(s.dims()[0] <= 2);
386/// ```
387pub fn svd_with<T>(
388    t: &TensorDynLen,
389    left_inds: &[DynIndex],
390    options: &SvdOptions,
391) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
392    let (u_native, s_native, vt_native, _singular_values, bond_index, left_indices, right_indices) =
393        svd_truncated_native(t, left_inds, options)?;
394
395    let mut u_indices = left_indices;
396    u_indices.push(bond_index.clone());
397    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
398    let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
399        SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
400    })?;
401    let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
402
403    let s_indices = vec![bond_index.clone(), bond_index.sim()];
404    let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
405        .map_err(SvdError::ComputationError)?;
406    let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
407
408    let mut vh_indices = vec![bond_index.clone()];
409    vh_indices.extend(right_indices);
410    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
411    let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
412        SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
413    })?;
414    let vh =
415        TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
416    let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
417    let v = vh.conj().permute(&perm);
418
419    Ok((u, s, v))
420}
421
422/// SVD result for factorization, returning `V^H` directly.
423pub(crate) struct SvdFactorizeResult {
424    pub u: TensorDynLen,
425    pub s: TensorDynLen,
426    pub vh: TensorDynLen,
427    pub bond_index: DynIndex,
428    pub singular_values: Vec<f64>,
429    pub rank: usize,
430}
431
432/// Compute truncated SVD for factorization, returning `V^H` instead of `V`.
433pub(crate) fn svd_for_factorize(
434    t: &TensorDynLen,
435    left_inds: &[DynIndex],
436    options: &SvdOptions,
437) -> Result<SvdFactorizeResult, SvdError> {
438    let (u_native, s_native, vt_native, singular_values, bond_index, left_indices, right_indices) =
439        svd_truncated_native(t, left_inds, options)?;
440    let rank = singular_values.len();
441
442    let mut u_indices = left_indices;
443    u_indices.push(bond_index.clone());
444    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
445    let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
446        SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
447    })?;
448    let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
449
450    let s_indices = vec![bond_index.clone(), bond_index.sim()];
451    let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
452        .map_err(SvdError::ComputationError)?;
453    let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
454
455    let mut vh_indices = vec![bond_index.clone()];
456    vh_indices.extend(right_indices);
457    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
458    let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
459        SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
460    })?;
461    let vh =
462        TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
463
464    Ok(SvdFactorizeResult {
465        u,
466        s,
467        vh,
468        bond_index,
469        singular_values,
470        rank,
471    })
472}
473
474#[cfg(test)]
475mod tests;