Skip to main content

tensor4all_core/defaults/
svd.rs

1//! SVD decomposition for tensors.
2//!
3//! Provides [`svd`] and [`svd_with`] for computing truncated SVD of
4//! [`TensorDynLen`] values. The tensor is unfolded into a matrix by
5//! splitting its indices into left and right groups, then the standard
6//! matrix SVD is computed and truncated according to [`SvdOptions`].
7//!
8//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
9
10use crate::defaults::tensordynlen::unfold_split_inner;
11use crate::defaults::DynIndex;
12use crate::index_like::IndexLike;
13use crate::truncation::{
14    validate_svd_truncation_policy, SingularValueMeasure, SvdTruncationPolicy, ThresholdScale,
15    TruncationRule,
16};
17use crate::TensorDynLen;
18use std::sync::Mutex;
19use tenferro::{CpuBackend, DType, EagerTensor};
20use tensor4all_tensorbackend::{
21    native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
22};
23use thiserror::Error;
24
/// Error type for SVD operations in this module.
///
/// Every fallible step of the decomposition (unfolding, the native SVD,
/// reading singular values, slicing/reshaping factors, building the bond
/// index) reports through `ComputationError`. Policy validation failures
/// report through `InvalidThreshold`.
#[derive(Debug, Error)]
pub enum SvdError {
    /// SVD computation failed.
    #[error("SVD computation failed: {0}")]
    ComputationError(#[from] anyhow::Error),
    /// Invalid truncation threshold value (must be finite and non-negative).
    ///
    /// Carries the offending value reported by policy validation.
    #[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
    InvalidThreshold(f64),
}
35
/// Options for SVD decomposition with truncation control.
///
/// Options built with [`SvdOptions::new`] (or `Default`) have truncation
/// enabled. The retained rank is first chosen by the truncation policy: the
/// per-call `policy`, or the global default when it is `None`. It is then
/// capped by `max_rank`.
///
/// # Examples
///
/// ```
/// use tensor4all_core::svd::{SvdOptions, svd_with};
/// use tensor4all_core::{DynIndex, SvdTruncationPolicy, TensorDynLen};
///
/// let i = DynIndex::new_dyn(3);
/// let j = DynIndex::new_dyn(3);
/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
///
/// let opts = SvdOptions::new().with_policy(SvdTruncationPolicy::new(1e-10));
/// let (u, s, v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
///
/// // U has left index + bond, S is diagonal bond x bond, V has right index + bond
/// assert_eq!(u.dims()[0], 3);
/// assert_eq!(s.dims().len(), 2);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct SvdOptions {
    /// Maximum retained rank after policy-based truncation.
    ///
    /// Only consulted when truncation is enabled. The retained rank is never
    /// reduced below 1, so `Some(0)` behaves like `Some(1)`.
    pub max_rank: Option<usize>,
    /// Per-call SVD truncation policy.
    /// If `None`, the global default policy is used.
    pub policy: Option<SvdTruncationPolicy>,
    /// Whether any truncation is applied.
    ///
    /// This is `true` for `SvdOptions::new()`. Options from `full_rank()` set it
    /// to `false`; all `min(m, n)` singular values are then kept, and `policy`
    /// and `max_rank` are ignored.
    truncate: bool,
}
60    /// Per-call SVD truncation policy.
61    /// If `None`, the global default policy is used.
62    pub policy: Option<SvdTruncationPolicy>,
63    truncate: bool,
64}
65
66impl Default for SvdOptions {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl SvdOptions {
73    /// Create new SVD options with no overrides.
74    #[must_use]
75    pub fn new() -> Self {
76        Self {
77            max_rank: None,
78            policy: None,
79            truncate: true,
80        }
81    }
82
83    /// Set the maximum retained rank.
84    #[must_use]
85    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
86        self.max_rank = Some(max_rank);
87        self
88    }
89
90    /// Set the SVD truncation policy override.
91    #[must_use]
92    pub fn with_policy(mut self, policy: SvdTruncationPolicy) -> Self {
93        self.policy = Some(policy);
94        self
95    }
96
97    pub(crate) fn full_rank() -> Self {
98        Self {
99            max_rank: None,
100            policy: None,
101            truncate: false,
102        }
103    }
104}
105
106fn default_policy_guard() -> std::sync::MutexGuard<'static, SvdTruncationPolicy> {
107    match DEFAULT_SVD_TRUNCATION_POLICY.lock() {
108        Ok(guard) => guard,
109        Err(poisoned) => poisoned.into_inner(),
110    }
111}
112
113// Default value: relative per-value threshold 1e-12.
114static DEFAULT_SVD_TRUNCATION_POLICY: Mutex<SvdTruncationPolicy> =
115    Mutex::new(SvdTruncationPolicy::new(1e-12));
116
117/// Get the global default truncation policy for SVD.
118///
119/// The default policy is `SvdTruncationPolicy::new(1e-12)`.
120#[must_use]
121pub fn default_svd_truncation_policy() -> SvdTruncationPolicy {
122    *default_policy_guard()
123}
124
125/// Set the global default truncation policy for SVD.
126///
127/// # Arguments
128/// * `policy` - SVD truncation policy to use when `SvdOptions::policy` is `None`
129///
130/// # Errors
131/// Returns `SvdError::InvalidThreshold` if `policy.threshold` is invalid.
132pub fn set_default_svd_truncation_policy(policy: SvdTruncationPolicy) -> Result<(), SvdError> {
133    validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
134    *default_policy_guard() = policy;
135    Ok(())
136}
137
138fn singular_value_measure(value: f64, measure: SingularValueMeasure) -> f64 {
139    match measure {
140        SingularValueMeasure::Value => value,
141        SingularValueMeasure::SquaredValue => value * value,
142    }
143}
144
145/// Compute the retained rank based on an explicit SVD truncation policy.
146fn compute_retained_rank(s_vec: &[f64], policy: &SvdTruncationPolicy) -> usize {
147    if s_vec.is_empty() {
148        return 1;
149    }
150
151    let measured: Vec<f64> = s_vec
152        .iter()
153        .map(|&value| singular_value_measure(value, policy.measure))
154        .collect();
155    if measured.iter().all(|&value| value == 0.0) {
156        return 1;
157    }
158
159    let retained = match (policy.scale, policy.rule) {
160        (ThresholdScale::Relative, TruncationRule::PerValue) => {
161            let reference = measured.iter().copied().fold(0.0_f64, f64::max);
162            measured
163                .iter()
164                .take_while(|&&value| reference > 0.0 && value / reference > policy.threshold)
165                .count()
166        }
167        (ThresholdScale::Absolute, TruncationRule::PerValue) => measured
168            .iter()
169            .take_while(|&&value| value > policy.threshold)
170            .count(),
171        (ThresholdScale::Relative, TruncationRule::DiscardedTailSum) => {
172            let total: f64 = measured.iter().sum();
173            if total == 0.0 {
174                1
175            } else {
176                let mut discarded = 0.0;
177                let mut keep = measured.len();
178                for (i, value) in measured.iter().enumerate().rev() {
179                    if (discarded + value) / total <= policy.threshold {
180                        discarded += value;
181                        keep = i;
182                    } else {
183                        break;
184                    }
185                }
186                keep
187            }
188        }
189        (ThresholdScale::Absolute, TruncationRule::DiscardedTailSum) => {
190            let mut discarded = 0.0;
191            let mut keep = measured.len();
192            for (i, value) in measured.iter().enumerate().rev() {
193                if discarded + value <= policy.threshold {
194                    discarded += value;
195                    keep = i;
196                } else {
197                    break;
198                }
199            }
200            keep
201        }
202    };
203
204    retained.max(1)
205}
206
207fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
208    match tensor.dtype() {
209        DType::F64 => {
210            native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
211        }
212        DType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
213            .map(|values| values.into_iter().map(|value| value.re).collect())
214            .map_err(SvdError::ComputationError),
215        other => Err(SvdError::ComputationError(anyhow::anyhow!(
216            "native SVD returned unsupported singular-value scalar type {other:?}"
217        ))),
218    }
219}
220
/// Raw output of `svd_truncated_inner`, with the factors still in eager
/// matrix form and already sliced to the retained rank `r`.
///
/// Downstream code reshapes `U` to `[left dims..., r]` and `V^T` to
/// `[r, right dims...]`.
type SvdTruncatedEagerResult = (
    EagerTensor<CpuBackend>, // U: first `r` columns kept
    EagerTensor<CpuBackend>, // S: singular values, first `r` entries along axis 0
    EagerTensor<CpuBackend>, // V^T: first `r` rows kept
    Vec<f64>,                // the `r` retained singular values (real parts for complex input)
    DynIndex,                // new bond index, created via `DynIndex::new_bond(r)`
    Vec<DynIndex>,           // left index group, as returned by `unfold_split_inner`
    Vec<DynIndex>,           // right index group, as returned by `unfold_split_inner`
);
230
231fn svd_truncated_inner(
232    t: &TensorDynLen,
233    left_inds: &[DynIndex],
234    options: &SvdOptions,
235) -> Result<SvdTruncatedEagerResult, SvdError> {
236    let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
237        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
238        .map_err(SvdError::ComputationError)?;
239    let k = m.min(n);
240
241    let (mut u_inner, mut s_inner, mut vt_inner) = matrix_inner
242        .svd()
243        .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
244    let s_full = singular_values_from_native(s_inner.data())?;
245    let mut r = if options.truncate {
246        let policy = options.policy.unwrap_or_else(default_svd_truncation_policy);
247        validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
248
249        let mut retained = compute_retained_rank(&s_full, &policy);
250        if let Some(max_rank) = options.max_rank {
251            retained = retained.min(max_rank);
252        }
253        retained.max(1)
254    } else {
255        k.max(1)
256    };
257    r = r.min(s_full.len());
258    if r < k {
259        let keep: Vec<usize> = (0..r).collect();
260        u_inner = u_inner
261            .take_axis(1, &keep)
262            .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
263        s_inner = s_inner
264            .take_axis(0, &keep)
265            .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
266        vt_inner = vt_inner
267            .take_axis(0, &keep)
268            .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
269    }
270
271    let bond_index = DynIndex::new_bond(r)
272        .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
273        .map_err(SvdError::ComputationError)?;
274    let singular_values = s_full[..r].to_vec();
275
276    Ok((
277        u_inner,
278        s_inner,
279        vt_inner,
280        singular_values,
281        bond_index,
282        left_indices,
283        right_indices,
284    ))
285}
286
287/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
288///
289/// # Examples
290///
291/// ```
292/// use tensor4all_core::{TensorDynLen, DynIndex, svd};
293///
294/// // Create a 2x3 matrix (rank-1 outer product: all-ones)
295/// let i = DynIndex::new_dyn(2);
296/// let j = DynIndex::new_dyn(3);
297/// let data = vec![1.0_f64; 6]; // all-ones 2x3 matrix
298/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
299///
300/// let (u, s, v) = svd::<f64>(&t, &[i.clone()]).unwrap();
301///
302/// // U: shape (left_dim, bond) = (2, bond)
303/// assert_eq!(u.dims()[0], 2);
304/// // V: shape (right_dim, bond) = (3, bond)
305/// assert_eq!(v.dims()[0], 3);
306/// // S is a diagonal matrix (bond × bond)
307/// assert_eq!(s.dims().len(), 2);
308/// ```
309pub fn svd<T>(
310    t: &TensorDynLen,
311    left_inds: &[DynIndex],
312) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
313    svd_with::<T>(t, left_inds, &SvdOptions::default())
314}
315
316/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
317///
318/// This function allows per-call control of the truncation policy via `SvdOptions`.
319/// If `options.policy` is `None`, it uses the global default policy.
320///
321/// # Examples
322///
323/// ```
324/// use tensor4all_core::{DynIndex, TensorDynLen};
325/// use tensor4all_core::svd::{SvdOptions, svd_with};
326///
327/// let i = DynIndex::new_dyn(4);
328/// let j = DynIndex::new_dyn(4);
329/// // Rank-1 matrix
330/// let mut data = vec![0.0_f64; 16];
331/// data[0] = 1.0;
332/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
333///
334/// use tensor4all_core::SvdTruncationPolicy;
335///
336/// // Truncate with a relative per-value threshold => rank 1
337/// let opts = SvdOptions::new().with_policy(SvdTruncationPolicy::new(1e-10));
338/// let (u, s, _v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
339/// assert_eq!(s.dims()[0], 1);  // rank-1
340///
341/// // Truncate with max_rank => capped
342/// let opts = SvdOptions::new().with_max_rank(2);
343/// let (_u, s, _v) = svd_with::<f64>(&tensor, &[i.clone()], &opts).unwrap();
344/// assert!(s.dims()[0] <= 2);
345/// ```
346pub fn svd_with<T>(
347    t: &TensorDynLen,
348    left_inds: &[DynIndex],
349    options: &SvdOptions,
350) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
351    let (u_inner, s_inner, vt_inner, _singular_values, bond_index, left_indices, right_indices) =
352        svd_truncated_inner(t, left_inds, options)?;
353
354    let mut u_indices = left_indices;
355    u_indices.push(bond_index.clone());
356    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
357    let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
358        SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
359    })?;
360    let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
361
362    let s_indices = vec![bond_index.clone(), bond_index.sim()];
363    let s =
364        TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
365
366    let mut vh_indices = vec![bond_index.clone()];
367    vh_indices.extend(right_indices);
368    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
369    let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
370        SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
371    })?;
372    let vh =
373        TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
374    let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
375    let v = vh
376        .conj()
377        .permute(&perm)
378        .map_err(SvdError::ComputationError)?;
379
380    Ok((u, s, v))
381}
382
/// SVD result for factorization, returning `V^H` directly.
///
/// Produced by `svd_for_factorize`.
pub(crate) struct SvdFactorizeResult {
    /// Left factor with indices `[left..., bond_index]`.
    pub u: TensorDynLen,
    /// Diagonal singular-value tensor with indices `[bond_index, bond_index.sim()]`.
    pub s: TensorDynLen,
    /// Right factor `V^H` with indices `[bond_index, right...]` (not conjugated
    /// or permuted).
    pub vh: TensorDynLen,
    /// Bond index shared by `u`, the first axis of `s`, and `vh`.
    pub bond_index: DynIndex,
    /// Retained singular values as `f64` (real parts for complex input).
    pub singular_values: Vec<f64>,
    /// Retained rank; equal to `singular_values.len()`.
    pub rank: usize,
}
392
393/// Compute truncated SVD for factorization, returning `V^H` instead of `V`.
394pub(crate) fn svd_for_factorize(
395    t: &TensorDynLen,
396    left_inds: &[DynIndex],
397    options: &SvdOptions,
398) -> Result<SvdFactorizeResult, SvdError> {
399    let (u_inner, s_inner, vt_inner, singular_values, bond_index, left_indices, right_indices) =
400        svd_truncated_inner(t, left_inds, options)?;
401    let rank = singular_values.len();
402
403    let mut u_indices = left_indices;
404    u_indices.push(bond_index.clone());
405    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
406    let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
407        SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
408    })?;
409    let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
410
411    let s_indices = vec![bond_index.clone(), bond_index.sim()];
412    let s =
413        TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
414
415    let mut vh_indices = vec![bond_index.clone()];
416    vh_indices.extend(right_indices);
417    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
418    let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
419        SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
420    })?;
421    let vh =
422        TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
423
424    Ok(SvdFactorizeResult {
425        u,
426        s,
427        vh,
428        bond_index,
429        singular_values,
430        rank,
431    })
432}
433
// Unit tests for this module. The module is declared without a body, so its
// contents live in a separate `tests` module file.
#[cfg(test)]
mod tests;