tensor4all_core/defaults/
svd.rs

1//! SVD decomposition for tensors.
2//!
3//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
4
5use crate::defaults::DynIndex;
6use crate::global_default::GlobalDefault;
7use crate::index_like::IndexLike;
8use crate::truncation::{HasTruncationParams, TruncationParams};
9use crate::{unfold_split, TensorDynLen};
10use tensor4all_tensorbackend::{
11    dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
12    native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
13    reshape_col_major_native_tensor, svd_native_tensor, TensorElement,
14};
15use thiserror::Error;
16
/// Error type for SVD operations.
///
/// NOTE(review): this doc previously said "tensor4all-linalg", but the module
/// path indicates it now lives in `tensor4all_core` — confirm and drop stale
/// crate references.
#[derive(Debug, Error)]
pub enum SvdError {
    /// SVD computation failed.
    ///
    /// The `#[from]` attribute lets `?` convert any `anyhow::Error` into this
    /// variant, which the rest of this module relies on.
    #[error("SVD computation failed: {0}")]
    ComputationError(#[from] anyhow::Error),
    /// Invalid relative tolerance value (must be finite and non-negative).
    #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
    InvalidRtol(f64),
}
27
/// Options for SVD decomposition with truncation control.
///
/// `Default` leaves both truncation knobs unset; in that case the effective
/// rtol falls back to the global default (see `effective_rtol` usage in
/// `svd_truncated_native`).
#[derive(Debug, Clone, Copy, Default)]
pub struct SvdOptions {
    /// Truncation parameters (rtol, max_rank).
    pub truncation: TruncationParams,
}
34
35impl SvdOptions {
36    /// Create new SVD options with the specified rtol.
37    pub fn with_rtol(rtol: f64) -> Self {
38        Self {
39            truncation: TruncationParams::new().with_rtol(rtol),
40        }
41    }
42
43    /// Create new SVD options with the specified max_rank.
44    pub fn with_max_rank(max_rank: usize) -> Self {
45        Self {
46            truncation: TruncationParams::new().with_max_rank(max_rank),
47        }
48    }
49
50    /// Get rtol from options (for backwards compatibility).
51    pub fn rtol(&self) -> Option<f64> {
52        self.truncation.rtol
53    }
54
55    /// Get max_rank from options (for backwards compatibility).
56    pub fn max_rank(&self) -> Option<usize> {
57        self.truncation.max_rank
58    }
59}
60
// Expose the embedded `TruncationParams` through the shared trait so generic
// truncation-aware code can work with `SvdOptions`.
impl HasTruncationParams for SvdOptions {
    fn truncation_params(&self) -> &TruncationParams {
        &self.truncation
    }

    fn truncation_params_mut(&mut self) -> &mut TruncationParams {
        &mut self.truncation
    }
}
70
// Global default rtol using the unified GlobalDefault type.
// Default value: 1e-12 (near machine precision).
static DEFAULT_SVD_RTOL: GlobalDefault = GlobalDefault::new(1e-12);

/// Get the global default rtol for SVD truncation.
///
/// The default value is 1e-12 (near machine precision). It can be changed at
/// runtime via [`set_default_svd_rtol`].
pub fn default_svd_rtol() -> f64 {
    DEFAULT_SVD_RTOL.get()
}
81
82/// Set the global default rtol for SVD truncation.
83///
84/// # Arguments
85/// * `rtol` - Relative Frobenius error tolerance (must be finite and non-negative)
86///
87/// # Errors
88/// Returns `SvdError::InvalidRtol` if rtol is not finite or is negative.
89pub fn set_default_svd_rtol(rtol: f64) -> Result<(), SvdError> {
90    DEFAULT_SVD_RTOL
91        .set(rtol)
92        .map_err(|e| SvdError::InvalidRtol(e.0))
93}
94
/// Compute the retained rank based on rtol (TSVD truncation).
///
/// Implements the criterion
///   sum_{i>r} σ_i² / sum_i σ_i² <= rtol²,
/// i.e. the largest tail of the spectrum whose squared norm fits within the
/// relative budget is discarded. Always keeps at least one singular value.
fn compute_retained_rank(s_vec: &[f64], rtol: f64) -> usize {
    if s_vec.is_empty() {
        return 1;
    }

    let total_sq: f64 = s_vec.iter().map(|&s| s * s).sum();
    if total_sq == 0.0 {
        return 1;
    }

    // Absolute squared-norm budget we are allowed to throw away.
    let budget = rtol * rtol * total_sq;
    let mut spent = 0.0;
    let mut kept = s_vec.len();
    // Walk the spectrum from the smallest value upward, discarding while the
    // cumulative discarded squared norm stays within the budget.
    while kept > 0 {
        let tail = s_vec[kept - 1];
        let tail_sq = tail * tail;
        if spent + tail_sq > budget {
            break;
        }
        spent += tail_sq;
        kept -= 1;
    }
    kept.max(1)
}
123
124fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
125    match tensor.scalar_type() {
126        tenferro::ScalarType::F64 => {
127            native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
128        }
129        tenferro::ScalarType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
130            .map(|values| values.into_iter().map(|value| value.re).collect())
131            .map_err(SvdError::ComputationError),
132        other => Err(SvdError::ComputationError(anyhow::anyhow!(
133            "native SVD returned unsupported singular-value scalar type {other:?}"
134        ))),
135    }
136}
137
138fn truncate_matrix_cols<T: TensorElement>(
139    data: &[T],
140    rows: usize,
141    keep_cols: usize,
142) -> anyhow::Result<tenferro::Tensor> {
143    dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
144}
145
146fn truncate_matrix_rows<T: TensorElement>(
147    data: &[T],
148    rows: usize,
149    cols: usize,
150    keep_rows: usize,
151) -> anyhow::Result<tenferro::Tensor> {
152    let mut truncated = Vec::with_capacity(keep_rows * cols);
153    for col in 0..cols {
154        let start = col * rows;
155        truncated.extend_from_slice(&data[start..start + keep_rows]);
156    }
157    dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
158}
159
/// Result tuple of `svd_truncated_native`, in order:
/// U, S, V^T (native tensors, already truncated to rank r), the retained
/// singular values, the new bond index of dimension r, and the left/right
/// index groups produced by the unfold.
type SvdTruncatedNativeResult = (
    tenferro::Tensor, // U  (m × r)
    tenferro::Tensor, // S  (length-r singular-value vector)
    tenferro::Tensor, // V^T (r × n)
    Vec<f64>,         // retained singular values (length r)
    DynIndex,         // bond index of dimension r
    Vec<DynIndex>,    // left (row) indices
    Vec<DynIndex>,    // right (column) indices
);
169
/// Core truncated SVD in native (unfolded) form, shared by [`svd_with`] and
/// [`svd_for_factorize`].
///
/// Steps:
/// 1. Resolve the effective rtol (per-call value or global default) and
///    validate it.
/// 2. Unfold `t` into an m×n matrix with `left_inds` as rows.
/// 3. Run the native SVD and determine the retained rank from rtol, further
///    capped by `max_rank` if set.
/// 4. If truncation applies (r < min(m, n)), slice U down to r columns, V^T
///    down to r rows, and the singular values down to length r, dispatching
///    on the native scalar type (F64 or C64).
///
/// # Errors
/// `SvdError::InvalidRtol` for a non-finite or negative effective rtol;
/// `SvdError::ComputationError` for unfold/SVD/conversion failures or an
/// unsupported scalar type.
fn svd_truncated_native(
    t: &TensorDynLen,
    left_inds: &[DynIndex],
    options: &SvdOptions,
) -> Result<SvdTruncatedNativeResult, SvdError> {
    // Per-call rtol wins; otherwise fall back to the process-wide default.
    let rtol = options.truncation.effective_rtol(default_svd_rtol());
    if !rtol.is_finite() || rtol < 0.0 {
        return Err(SvdError::InvalidRtol(rtol));
    }

    // Unfold the tensor into a column-major m×n matrix; `left_indices` become
    // the rows and `right_indices` the columns.
    let (matrix_native, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
        .map_err(SvdError::ComputationError)?;
    let k = m.min(n);

    let (mut u_native, mut s_native, mut vt_native) =
        svd_native_tensor(&matrix_native).map_err(SvdError::ComputationError)?;
    let s_full = singular_values_from_native(&s_native)?;
    let mut r = compute_retained_rank(&s_full, rtol);
    if let Some(max_rank) = options.truncation.max_rank {
        r = r.min(max_rank);
    }
    if r < k {
        // Truncate U (keep r columns) and V^T (keep r rows). The native SVD
        // returned full factors of rank k; slicing requires a round-trip
        // through dense column-major buffers, dispatched on scalar type.
        match u_native.scalar_type() {
            tenferro::ScalarType::F64 => {
                let u_values = native_tensor_primal_to_dense_f64_col_major(&u_native)
                    .map_err(SvdError::ComputationError)?;
                let vt_values = native_tensor_primal_to_dense_f64_col_major(&vt_native)
                    .map_err(SvdError::ComputationError)?;
                u_native =
                    truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
                vt_native = truncate_matrix_rows(&vt_values, k, n, r)
                    .map_err(SvdError::ComputationError)?;
            }
            tenferro::ScalarType::C64 => {
                let u_values = native_tensor_primal_to_dense_c64_col_major(&u_native)
                    .map_err(SvdError::ComputationError)?;
                let vt_values = native_tensor_primal_to_dense_c64_col_major(&vt_native)
                    .map_err(SvdError::ComputationError)?;
                u_native =
                    truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
                vt_native = truncate_matrix_rows(&vt_values, k, n, r)
                    .map_err(SvdError::ComputationError)?;
            }
            other => {
                return Err(SvdError::ComputationError(anyhow::anyhow!(
                    "native SVD returned unsupported singular-vector scalar type {other:?}"
                )));
            }
        }
        // Singular values are always real, so the truncated S is rebuilt from
        // the f64 copy regardless of the factor scalar type.
        s_native = dense_native_tensor_from_col_major(&s_full[..r], &[r])
            .map_err(SvdError::ComputationError)?;
    }

    let bond_index = DynIndex::new_bond(r)
        .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
        .map_err(SvdError::ComputationError)?;
    // NOTE(review): if `s_full` is empty (a zero-sized matrix dimension),
    // `compute_retained_rank` returns 1 and this slice would panic — confirm
    // zero-dimension tensors are rejected upstream.
    let singular_values = s_full[..r].to_vec();

    Ok((
        u_native,
        s_native,
        vt_native,
        singular_values,
        bond_index,
        left_indices,
        right_indices,
    ))
}
239
240/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
241///
242/// # Examples
243///
244/// ```
245/// use tensor4all_core::{TensorDynLen, DynIndex, svd};
246///
247/// // Create a 2x3 matrix (rank-1 outer product: all-ones)
248/// let i = DynIndex::new_dyn(2);
249/// let j = DynIndex::new_dyn(3);
250/// let data = vec![1.0_f64; 6]; // all-ones 2x3 matrix
251/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
252///
253/// let (u, s, v) = svd::<f64>(&t, &[i.clone()]).unwrap();
254///
255/// // U: shape (left_dim, bond) = (2, bond)
256/// assert_eq!(u.dims()[0], 2);
257/// // V: shape (right_dim, bond) = (3, bond)
258/// assert_eq!(v.dims()[0], 3);
259/// // S is a diagonal matrix (bond × bond)
260/// assert_eq!(s.dims().len(), 2);
261/// ```
262pub fn svd<T>(
263    t: &TensorDynLen,
264    left_inds: &[DynIndex],
265) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
266    svd_with::<T>(t, left_inds, &SvdOptions::default())
267}
268
/// Compute SVD decomposition of a tensor with arbitrary rank, returning (U, S, V).
///
/// This function allows per-call control of the truncation tolerance via `SvdOptions`.
/// If `options.rtol` is `None`, uses the global default rtol.
///
/// NOTE(review): the U/S/V^H construction below duplicates the body of
/// `svd_for_factorize`; consider expressing one in terms of the other.
pub fn svd_with<T>(
    t: &TensorDynLen,
    left_inds: &[DynIndex],
    options: &SvdOptions,
) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
    let (u_native, s_native, vt_native, _singular_values, bond_index, left_indices, right_indices) =
        svd_truncated_native(t, left_inds, options)?;

    // U: reshape the m×r matrix back to (left indices..., bond).
    let mut u_indices = left_indices;
    u_indices.push(bond_index.clone());
    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
    let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
        SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
    })?;
    let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;

    // S: r×r diagonal tensor over (bond, bond.sim()).
    let s_indices = vec![bond_index.clone(), bond_index.sim()];
    let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
        .map_err(SvdError::ComputationError)?;
    let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;

    // V^H: reshape the r×n matrix to (bond, right indices...).
    let mut vh_indices = vec![bond_index.clone()];
    vh_indices.extend(right_indices);
    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
    let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
        SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
    })?;
    let vh =
        TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
    // V = conj(V^H) with the bond axis rotated to the last position, so V has
    // shape (right indices..., bond).
    let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
    let v = vh.conj().permute(&perm);

    Ok((u, s, v))
}
307
/// SVD result for factorization, returning `V^H` directly.
pub(crate) struct SvdFactorizeResult {
    // Left factor with indices (left indices..., bond).
    pub u: TensorDynLen,
    // Diagonal singular-value tensor over (bond, bond.sim()).
    pub s: TensorDynLen,
    // Right factor V^H with indices (bond, right indices...).
    pub vh: TensorDynLen,
    // The bond index connecting the factors; its dimension equals `rank`.
    pub bond_index: DynIndex,
    // Retained singular values, in descending order as returned by the SVD.
    pub singular_values: Vec<f64>,
    // Retained rank (== singular_values.len()).
    pub rank: usize,
}
317
318/// Compute truncated SVD for factorization, returning `V^H` instead of `V`.
319pub(crate) fn svd_for_factorize(
320    t: &TensorDynLen,
321    left_inds: &[DynIndex],
322    options: &SvdOptions,
323) -> Result<SvdFactorizeResult, SvdError> {
324    let (u_native, s_native, vt_native, singular_values, bond_index, left_indices, right_indices) =
325        svd_truncated_native(t, left_inds, options)?;
326    let rank = singular_values.len();
327
328    let mut u_indices = left_indices;
329    u_indices.push(bond_index.clone());
330    let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
331    let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
332        SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
333    })?;
334    let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
335
336    let s_indices = vec![bond_index.clone(), bond_index.sim()];
337    let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
338        .map_err(SvdError::ComputationError)?;
339    let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
340
341    let mut vh_indices = vec![bond_index.clone()];
342    vh_indices.extend(right_indices);
343    let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
344    let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
345        SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
346    })?;
347    let vh =
348        TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
349
350    Ok(SvdFactorizeResult {
351        u,
352        s,
353        vh,
354        bond_index,
355        singular_values,
356        rank,
357    })
358}
359
360#[cfg(test)]
361mod tests;