1use crate::defaults::DynIndex;
11use crate::index_like::IndexLike;
12use crate::truncation::{
13 validate_svd_truncation_policy, SingularValueMeasure, SvdTruncationPolicy, ThresholdScale,
14 TruncationRule,
15};
16use crate::{unfold_split, TensorDynLen};
17use std::sync::Mutex;
18use tenferro::DType;
19use tensor4all_tensorbackend::{
20 dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
21 native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
22 reshape_col_major_native_tensor, svd_native_tensor, TensorElement,
23};
24use thiserror::Error;
25
/// Errors returned by the SVD routines in this module.
#[derive(Debug, Error)]
pub enum SvdError {
    /// A failure in the underlying computation: unfolding the tensor, the native SVD,
    /// scalar-type conversion, or building the result tensors.
    #[error("SVD computation failed: {0}")]
    ComputationError(#[from] anyhow::Error),
    /// The truncation policy was rejected by `validate_svd_truncation_policy`; carries the
    /// `f64` value reported by that validation error.
    #[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
    InvalidThreshold(f64),
}
36
/// Options controlling how an SVD is truncated.
///
/// Build with `SvdOptions::new()` (or `Default::default()`) and the `with_*` builder methods.
#[derive(Debug, Clone, Copy)]
pub struct SvdOptions {
    /// Upper bound on the number of retained singular values, applied after the truncation
    /// policy. The retained rank is still at least 1 for a non-empty spectrum, even with
    /// `Some(0)`.
    pub max_rank: Option<usize>,
    /// Truncation policy to apply. `None` means the process-wide default returned by
    /// `default_svd_truncation_policy()` at the time the SVD is computed.
    pub policy: Option<SvdTruncationPolicy>,
    // When `false`, no truncation happens and `policy`/`max_rank` are ignored. Kept private
    // so that it can only be disabled through `SvdOptions::full_rank()`.
    truncate: bool,
}
66
67impl Default for SvdOptions {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl SvdOptions {
74 #[must_use]
76 pub fn new() -> Self {
77 Self {
78 max_rank: None,
79 policy: None,
80 truncate: true,
81 }
82 }
83
84 #[must_use]
86 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
87 self.max_rank = Some(max_rank);
88 self
89 }
90
91 #[must_use]
93 pub fn with_policy(mut self, policy: SvdTruncationPolicy) -> Self {
94 self.policy = Some(policy);
95 self
96 }
97
98 pub(crate) fn full_rank() -> Self {
99 Self {
100 max_rank: None,
101 policy: None,
102 truncate: false,
103 }
104 }
105}
106
107fn default_policy_guard() -> std::sync::MutexGuard<'static, SvdTruncationPolicy> {
108 match DEFAULT_SVD_TRUNCATION_POLICY.lock() {
109 Ok(guard) => guard,
110 Err(poisoned) => poisoned.into_inner(),
111 }
112}
113
/// Process-wide default truncation policy, used when `SvdOptions::policy` is `None`.
///
/// Initialized with `SvdTruncationPolicy::new(1e-12)`. Presumably this is a threshold of
/// 1e-12 with the policy type's default measure/scale/rule; confirm in `crate::truncation`.
/// Access it through `default_policy_guard()`, which tolerates lock poisoning.
static DEFAULT_SVD_TRUNCATION_POLICY: Mutex<SvdTruncationPolicy> =
    Mutex::new(SvdTruncationPolicy::new(1e-12));
117
118#[must_use]
122pub fn default_svd_truncation_policy() -> SvdTruncationPolicy {
123 *default_policy_guard()
124}
125
126pub fn set_default_svd_truncation_policy(policy: SvdTruncationPolicy) -> Result<(), SvdError> {
134 validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
135 *default_policy_guard() = policy;
136 Ok(())
137}
138
139fn singular_value_measure(value: f64, measure: SingularValueMeasure) -> f64 {
140 match measure {
141 SingularValueMeasure::Value => value,
142 SingularValueMeasure::SquaredValue => value * value,
143 }
144}
145
146fn compute_retained_rank(s_vec: &[f64], policy: &SvdTruncationPolicy) -> usize {
148 if s_vec.is_empty() {
149 return 1;
150 }
151
152 let measured: Vec<f64> = s_vec
153 .iter()
154 .map(|&value| singular_value_measure(value, policy.measure))
155 .collect();
156 if measured.iter().all(|&value| value == 0.0) {
157 return 1;
158 }
159
160 let retained = match (policy.scale, policy.rule) {
161 (ThresholdScale::Relative, TruncationRule::PerValue) => {
162 let reference = measured.iter().copied().fold(0.0_f64, f64::max);
163 measured
164 .iter()
165 .take_while(|&&value| reference > 0.0 && value / reference > policy.threshold)
166 .count()
167 }
168 (ThresholdScale::Absolute, TruncationRule::PerValue) => measured
169 .iter()
170 .take_while(|&&value| value > policy.threshold)
171 .count(),
172 (ThresholdScale::Relative, TruncationRule::DiscardedTailSum) => {
173 let total: f64 = measured.iter().sum();
174 if total == 0.0 {
175 1
176 } else {
177 let mut discarded = 0.0;
178 let mut keep = measured.len();
179 for (i, value) in measured.iter().enumerate().rev() {
180 if (discarded + value) / total <= policy.threshold {
181 discarded += value;
182 keep = i;
183 } else {
184 break;
185 }
186 }
187 keep
188 }
189 }
190 (ThresholdScale::Absolute, TruncationRule::DiscardedTailSum) => {
191 let mut discarded = 0.0;
192 let mut keep = measured.len();
193 for (i, value) in measured.iter().enumerate().rev() {
194 if discarded + value <= policy.threshold {
195 discarded += value;
196 keep = i;
197 } else {
198 break;
199 }
200 }
201 keep
202 }
203 };
204
205 retained.max(1)
206}
207
208fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
209 match tensor.dtype() {
210 DType::F64 => {
211 native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
212 }
213 DType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
214 .map(|values| values.into_iter().map(|value| value.re).collect())
215 .map_err(SvdError::ComputationError),
216 other => Err(SvdError::ComputationError(anyhow::anyhow!(
217 "native SVD returned unsupported singular-value scalar type {other:?}"
218 ))),
219 }
220}
221
222fn truncate_matrix_cols<T: TensorElement>(
223 data: &[T],
224 rows: usize,
225 keep_cols: usize,
226) -> anyhow::Result<tenferro::Tensor> {
227 dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
228}
229
230fn truncate_matrix_rows<T: TensorElement>(
231 data: &[T],
232 rows: usize,
233 cols: usize,
234 keep_rows: usize,
235) -> anyhow::Result<tenferro::Tensor> {
236 let mut truncated = Vec::with_capacity(keep_rows * cols);
237 for col in 0..cols {
238 let start = col * rows;
239 truncated.extend_from_slice(&data[start..start + keep_rows]);
240 }
241 dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
242}
243
/// Output of `svd_truncated_native`, in order:
/// `(U, S, Vt, singular_values, bond_index, left_indices, right_indices)`.
///
/// The first three entries are the native factor tensors. `singular_values` holds the
/// retained values as `f64`. `bond_index` has one entry per retained value. The two index
/// lists come from unfolding the input tensor.
type SvdTruncatedNativeResult = (
    tenferro::Tensor,
    tenferro::Tensor,
    tenferro::Tensor,
    Vec<f64>,
    DynIndex,
    Vec<DynIndex>,
    Vec<DynIndex>,
);
253
254fn svd_truncated_native(
255 t: &TensorDynLen,
256 left_inds: &[DynIndex],
257 options: &SvdOptions,
258) -> Result<SvdTruncatedNativeResult, SvdError> {
259 let (matrix_native, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
260 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
261 .map_err(SvdError::ComputationError)?;
262 let k = m.min(n);
263
264 let (mut u_native, mut s_native, mut vt_native) =
265 svd_native_tensor(&matrix_native).map_err(SvdError::ComputationError)?;
266 let s_full = singular_values_from_native(&s_native)?;
267 let mut r = if options.truncate {
268 let policy = options.policy.unwrap_or_else(default_svd_truncation_policy);
269 validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
270
271 let mut retained = compute_retained_rank(&s_full, &policy);
272 if let Some(max_rank) = options.max_rank {
273 retained = retained.min(max_rank);
274 }
275 retained.max(1)
276 } else {
277 k.max(1)
278 };
279 r = r.min(s_full.len());
280 if r < k {
281 match u_native.dtype() {
282 DType::F64 => {
283 let u_values = native_tensor_primal_to_dense_f64_col_major(&u_native)
284 .map_err(SvdError::ComputationError)?;
285 let vt_values = native_tensor_primal_to_dense_f64_col_major(&vt_native)
286 .map_err(SvdError::ComputationError)?;
287 u_native =
288 truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
289 vt_native = truncate_matrix_rows(&vt_values, k, n, r)
290 .map_err(SvdError::ComputationError)?;
291 }
292 DType::C64 => {
293 let u_values = native_tensor_primal_to_dense_c64_col_major(&u_native)
294 .map_err(SvdError::ComputationError)?;
295 let vt_values = native_tensor_primal_to_dense_c64_col_major(&vt_native)
296 .map_err(SvdError::ComputationError)?;
297 u_native =
298 truncate_matrix_cols(&u_values, m, r).map_err(SvdError::ComputationError)?;
299 vt_native = truncate_matrix_rows(&vt_values, k, n, r)
300 .map_err(SvdError::ComputationError)?;
301 }
302 other => {
303 return Err(SvdError::ComputationError(anyhow::anyhow!(
304 "native SVD returned unsupported singular-vector scalar type {other:?}"
305 )));
306 }
307 }
308 s_native = dense_native_tensor_from_col_major(&s_full[..r], &[r])
309 .map_err(SvdError::ComputationError)?;
310 }
311
312 let bond_index = DynIndex::new_bond(r)
313 .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
314 .map_err(SvdError::ComputationError)?;
315 let singular_values = s_full[..r].to_vec();
316
317 Ok((
318 u_native,
319 s_native,
320 vt_native,
321 singular_values,
322 bond_index,
323 left_indices,
324 right_indices,
325 ))
326}
327
328pub fn svd<T>(
351 t: &TensorDynLen,
352 left_inds: &[DynIndex],
353) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
354 svd_with::<T>(t, left_inds, &SvdOptions::default())
355}
356
357pub fn svd_with<T>(
388 t: &TensorDynLen,
389 left_inds: &[DynIndex],
390 options: &SvdOptions,
391) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
392 let (u_native, s_native, vt_native, _singular_values, bond_index, left_indices, right_indices) =
393 svd_truncated_native(t, left_inds, options)?;
394
395 let mut u_indices = left_indices;
396 u_indices.push(bond_index.clone());
397 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
398 let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
399 SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
400 })?;
401 let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
402
403 let s_indices = vec![bond_index.clone(), bond_index.sim()];
404 let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
405 .map_err(SvdError::ComputationError)?;
406 let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
407
408 let mut vh_indices = vec![bond_index.clone()];
409 vh_indices.extend(right_indices);
410 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
411 let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
412 SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
413 })?;
414 let vh =
415 TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
416 let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
417 let v = vh.conj().permute(&perm);
418
419 Ok((u, s, v))
420}
421
/// Factors produced by `svd_for_factorize`.
pub(crate) struct SvdFactorizeResult {
    /// Left factor, indexed by the left indices followed by `bond_index`.
    pub u: TensorDynLen,
    /// Diagonal singular-value tensor over `bond_index` and `bond_index.sim()`.
    pub s: TensorDynLen,
    /// Right factor as returned by the native SVD (not conjugated), indexed by `bond_index`
    /// followed by the right indices.
    pub vh: TensorDynLen,
    /// Bond index connecting the factors; its dimension equals `rank`.
    pub bond_index: DynIndex,
    /// Retained singular values, as `f64`.
    pub singular_values: Vec<f64>,
    /// Number of retained singular values (`singular_values.len()`).
    pub rank: usize,
}
431
432pub(crate) fn svd_for_factorize(
434 t: &TensorDynLen,
435 left_inds: &[DynIndex],
436 options: &SvdOptions,
437) -> Result<SvdFactorizeResult, SvdError> {
438 let (u_native, s_native, vt_native, singular_values, bond_index, left_indices, right_indices) =
439 svd_truncated_native(t, left_inds, options)?;
440 let rank = singular_values.len();
441
442 let mut u_indices = left_indices;
443 u_indices.push(bond_index.clone());
444 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
445 let u_reshaped = reshape_col_major_native_tensor(&u_native, &u_dims).map_err(|e| {
446 SvdError::ComputationError(anyhow::anyhow!("native SVD U reshape failed: {e}"))
447 })?;
448 let u = TensorDynLen::from_native(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
449
450 let s_indices = vec![bond_index.clone(), bond_index.sim()];
451 let s_diag = diag_native_tensor_from_col_major(&singular_values_from_native(&s_native)?, 2)
452 .map_err(SvdError::ComputationError)?;
453 let s = TensorDynLen::from_native(s_indices, s_diag).map_err(SvdError::ComputationError)?;
454
455 let mut vh_indices = vec![bond_index.clone()];
456 vh_indices.extend(right_indices);
457 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
458 let vt_reshaped = reshape_col_major_native_tensor(&vt_native, &vh_dims).map_err(|e| {
459 SvdError::ComputationError(anyhow::anyhow!("native SVD V^T reshape failed: {e}"))
460 })?;
461 let vh =
462 TensorDynLen::from_native(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
463
464 Ok(SvdFactorizeResult {
465 u,
466 s,
467 vh,
468 bond_index,
469 singular_values,
470 rank,
471 })
472}
473
474#[cfg(test)]
475mod tests;