tensor4all_core/defaults/
svd.rs1use crate::defaults::DynIndex;
6use crate::global_default::GlobalDefault;
7use crate::index_like::IndexLike;
8use crate::truncation::{HasTruncationParams, TruncationParams};
9use crate::{unfold_split, TensorDynLen};
10use tensor4all_tensorbackend::{
11 dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
12 native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
13 reshape_col_major_native_tensor, svd_native_tensor, TensorElement,
14};
15use thiserror::Error;
16
17#[derive(Debug, Error)]
19pub enum SvdError {
20 #[error("SVD computation failed: {0}")]
22 ComputationError(#[from] anyhow::Error),
23 #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
25 InvalidRtol(f64),
26}
27
28#[derive(Debug, Clone, Copy, Default)]
30pub struct SvdOptions {
31 pub truncation: TruncationParams,
33}
34
35impl SvdOptions {
36 pub fn with_rtol(rtol: f64) -> Self {
38 Self {
39 truncation: TruncationParams::new().with_rtol(rtol),
40 }
41 }
42
43 pub fn with_max_rank(max_rank: usize) -> Self {
45 Self {
46 truncation: TruncationParams::new().with_max_rank(max_rank),
47 }
48 }
49
50 pub fn rtol(&self) -> Option<f64> {
52 self.truncation.rtol
53 }
54
55 pub fn max_rank(&self) -> Option<usize> {
57 self.truncation.max_rank
58 }
59}
60
61impl HasTruncationParams for SvdOptions {
62 fn truncation_params(&self) -> &TruncationParams {
63 &self.truncation
64 }
65
66 fn truncation_params_mut(&mut self) -> &mut TruncationParams {
67 &mut self.truncation
68 }
69}
70
71static DEFAULT_SVD_RTOL: GlobalDefault = GlobalDefault::new(1e-12);
74
75pub fn default_svd_rtol() -> f64 {
79 DEFAULT_SVD_RTOL.get()
80}
81
82pub fn set_default_svd_rtol(rtol: f64) -> Result<(), SvdError> {
90 DEFAULT_SVD_RTOL
91 .set(rtol)
92 .map_err(|e| SvdError::InvalidRtol(e.0))
93}
94
/// Chooses how many leading singular values to keep for relative tolerance `rtol`.
///
/// Values are dropped from the end of `s_vec` for as long as the accumulated
/// squared norm of the dropped tail stays within `rtol² · Σ sᵢ²` (assumes
/// `s_vec` is sorted in descending order — TODO confirm with the native SVD).
/// The result is never below 1, even for an empty or all-zero spectrum.
fn compute_retained_rank(s_vec: &[f64], rtol: f64) -> usize {
    let total_sq_norm: f64 = s_vec.iter().map(|&s| s * s).sum();
    // An empty spectrum sums to zero, so this guard also covers `s_vec.is_empty()`.
    if total_sq_norm == 0.0 {
        return 1;
    }

    let budget = rtol * rtol * total_sq_norm;
    let mut dropped_sq_norm = 0.0;
    let dropped = s_vec
        .iter()
        .rev()
        .take_while(|&&s| {
            let s_sq = s * s;
            let fits = dropped_sq_norm + s_sq <= budget;
            if fits {
                dropped_sq_norm += s_sq;
            }
            fits
        })
        .count();

    (s_vec.len() - dropped).max(1)
}
123
124fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
125 match tensor.scalar_type() {
126 tenferro::ScalarType::F64 => {
127 native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
128 }
129 tenferro::ScalarType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
130 .map(|values| values.into_iter().map(|value| value.re).collect())
131 .map_err(SvdError::ComputationError),
132 other => Err(SvdError::ComputationError(anyhow::anyhow!(
133 "native SVD returned unsupported singular-value scalar type {other:?}"
134 ))),
135 }
136}
137
138fn truncate_matrix_cols<T: TensorElement>(
139 data: &[T],
140 rows: usize,
141 keep_cols: usize,
142) -> anyhow::Result<tenferro::Tensor> {
143 dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
144}
145
146fn truncate_matrix_rows<T: TensorElement>(
147 data: &[T],
148 rows: usize,
149 cols: usize,
150 keep_rows: usize,
151) -> anyhow::Result<tenferro::Tensor> {
152 let mut truncated = Vec::with_capacity(keep_rows * cols);
153 for col in 0..cols {
154 let start = col * rows;
155 truncated.extend_from_slice(&data[start..start + keep_rows]);
156 }
157 dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
158}
159
160type SvdTruncatedNativeResult = (
161 tenferro::Tensor,
162 tenferro::Tensor,
163 tenferro::Tensor,
164 Vec<f64>,
165 DynIndex,
166 Vec<DynIndex>,
167 Vec<DynIndex>,
168);
169
170fn svd_truncated_native(
171 t: &TensorDynLen,
172 left_inds: &[DynIndex],
173 options: &SvdOptions,
174) -> Result<SvdTruncatedNativeResult, SvdError> {
175 let rtol = options.truncation.effective_rtol(default_svd_rtol());
176 if !rtol.is_finite() || rtol < 0.0 {
177 return Err(SvdError::InvalidRtol(rtol));
178 }
179
180 let (matrix_native, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
181 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
182 .map_err(SvdError::ComputationError)?;
183 let k = m.min(n);
184
185 let (mut u_native, mut s_native, mut vt_native) =
186 svd_native_tensor(&matrix_native).map_err(SvdError::ComputationError)?;
187 let s_full = singular_values_from_native(&s_native)?;
188 let mut r = compute_retained_rank(&s_full, rtol);
189 if let Some(max_rank) = options.truncation.max_rank {
190 r = r.min(max_rank);
191 }
192 if r < k {
193 match u_native.scalar_type() {
194 tenferro::ScalarType::F64 => {
195 let u_values = native_tensor_primal_to_dense_f64_col_major(&u_native)
196 .map_err(SvdError::ComputationError)?;
197 let vt_values = native_tensor_primal_to_dense_f64_col_major(&vt_native)
198 .map_err(SvdError::ComputationError)?;
199 u_native =
200 truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
201 vt_native = truncate_matrix_rows(&vt_values, k, n, r)
202 .map_err(SvdError::ComputationError)?;
203 }
204 tenferro::ScalarType::C64 => {
205 let u_values = native_tensor_primal_to_dense_c64_col_major(&u_native)
206 .map_err(SvdError::ComputationError)?;
207 let vt_values = native_tensor_primal_to_dense_c64_col_major(&vt_native)
208 .map_err(SvdError::ComputationError)?;
209 u_native =
210 truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
211 vt_native = truncate_matrix_rows(&vt_values, k, n, r)
212 .map_err(SvdError::ComputationError)?;
213 }
214 other => {
215 return Err(SvdError::ComputationError(anyhow::anyhow!(
216 "native SVD returned unsupported singular-vector scalar type {other:?}"
217 )));
218 }
219 }
220 s_native = dense_native_tensor_from_col_major(&s_full[..r], &[r])
221 .map_err(SvdError::ComputationError)?;
222 }
223
224 let bond_index = DynIndex::new_bond(r)
225 .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
226 .map_err(SvdError::ComputationError)?;
227 let singular_values = s_full[..r].to_vec();
228
229 Ok((
230 u_native,
231 s_native,
232 vt_native,
233 singular_values,
234 bond_index,
235 left_indices,
236 right_indices,
237 ))
238}
239
240pub fn svd<T>(
263 t: &TensorDynLen,
264 left_inds: &[DynIndex],
265) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
266 svd_with::<T>(t, left_inds, &SvdOptions::default())
267}
268
269pub fn svd_with<T>(
274 t: &TensorDynLen,
275 left_inds: &[DynIndex],
276 options: &SvdOptions,
277) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
278 let (u_native, s_native, vt_native, _singular_values, bond_index, left_indices, right_indices) =
279 svd_truncated_native(t, left_inds, options)?;
280
281 let mut u_indices = left_indices;
282 u_indices.push(bond_index.clone());
283 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
284 let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
285 SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
286 })?;
287 let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
288
289 let s_indices = vec![bond_index.clone(), bond_index.sim()];
290 let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
291 .map_err(SvdError::ComputationError)?;
292 let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
293
294 let mut vh_indices = vec![bond_index.clone()];
295 vh_indices.extend(right_indices);
296 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
297 let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
298 SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
299 })?;
300 let vh =
301 TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
302 let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
303 let v = vh.conj().permute(&perm);
304
305 Ok((u, s, v))
306}
307
308pub(crate) struct SvdFactorizeResult {
310 pub u: TensorDynLen,
311 pub s: TensorDynLen,
312 pub vh: TensorDynLen,
313 pub bond_index: DynIndex,
314 pub singular_values: Vec<f64>,
315 pub rank: usize,
316}
317
318pub(crate) fn svd_for_factorize(
320 t: &TensorDynLen,
321 left_inds: &[DynIndex],
322 options: &SvdOptions,
323) -> Result<SvdFactorizeResult, SvdError> {
324 let (u_native, s_native, vt_native, singular_values, bond_index, left_indices, right_indices) =
325 svd_truncated_native(t, left_inds, options)?;
326 let rank = singular_values.len();
327
328 let mut u_indices = left_indices;
329 u_indices.push(bond_index.clone());
330 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
331 let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
332 SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
333 })?;
334 let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
335
336 let s_indices = vec![bond_index.clone(), bond_index.sim()];
337 let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
338 .map_err(SvdError::ComputationError)?;
339 let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
340
341 let mut vh_indices = vec![bond_index.clone()];
342 vh_indices.extend(right_indices);
343 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
344 let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
345 SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
346 })?;
347 let vh =
348 TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
349
350 Ok(SvdFactorizeResult {
351 u,
352 s,
353 vh,
354 bond_index,
355 singular_values,
356 rank,
357 })
358}
359
360#[cfg(test)]
361mod tests;