tensor4all_core/defaults/
svd.rs1use crate::defaults::tensordynlen::unfold_split_inner;
11use crate::defaults::DynIndex;
12use crate::index_like::IndexLike;
13use crate::truncation::{
14 validate_svd_truncation_policy, SingularValueMeasure, SvdTruncationPolicy, ThresholdScale,
15 TruncationRule,
16};
17use crate::TensorDynLen;
18use std::sync::Mutex;
19use tenferro::{CpuBackend, DType, EagerTensor};
20use tensor4all_tensorbackend::{
21 native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
22};
23use thiserror::Error;
24
/// Errors produced by the SVD routines in this module.
#[derive(Debug, Error)]
pub enum SvdError {
    /// Any failure while unfolding the tensor, running the SVD kernel,
    /// extracting singular values, slicing/reshaping factors, or building
    /// indices/tensors. Converts automatically from `anyhow::Error`.
    #[error("SVD computation failed: {0}")]
    ComputationError(#[from] anyhow::Error),
    /// A truncation policy was rejected by `validate_svd_truncation_policy`.
    /// The payload is the value reported by the validator (the threshold,
    /// per the error message).
    #[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
    InvalidThreshold(f64),
}
35
/// Options controlling how [`svd_with`] truncates the singular-value spectrum.
///
/// Build with [`SvdOptions::new`] (or `Default`) and the `with_*` methods.
#[derive(Debug, Clone, Copy)]
pub struct SvdOptions {
    /// Upper bound on the retained rank, applied after the truncation policy.
    /// `None` means no explicit cap. The truncated rank is still clamped to at
    /// least 1. Ignored when truncation is disabled.
    pub max_rank: Option<usize>,
    /// Truncation policy to use; `None` falls back to
    /// [`default_svd_truncation_policy`] when the SVD runs.
    pub policy: Option<SvdTruncationPolicy>,
    /// When `false`, no truncation is applied and all `min(m, n)` singular
    /// values are kept (only settable crate-internally via `full_rank`).
    truncate: bool,
}
65
66impl Default for SvdOptions {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl SvdOptions {
73 #[must_use]
75 pub fn new() -> Self {
76 Self {
77 max_rank: None,
78 policy: None,
79 truncate: true,
80 }
81 }
82
83 #[must_use]
85 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
86 self.max_rank = Some(max_rank);
87 self
88 }
89
90 #[must_use]
92 pub fn with_policy(mut self, policy: SvdTruncationPolicy) -> Self {
93 self.policy = Some(policy);
94 self
95 }
96
97 pub(crate) fn full_rank() -> Self {
98 Self {
99 max_rank: None,
100 policy: None,
101 truncate: false,
102 }
103 }
104}
105
106fn default_policy_guard() -> std::sync::MutexGuard<'static, SvdTruncationPolicy> {
107 match DEFAULT_SVD_TRUNCATION_POLICY.lock() {
108 Ok(guard) => guard,
109 Err(poisoned) => poisoned.into_inner(),
110 }
111}
112
/// Process-wide default truncation policy, used when `SvdOptions::policy` is `None`.
///
/// Initialized with `SvdTruncationPolicy::new(1e-12)`. NOTE(review): presumably
/// `1e-12` is the threshold, with scale/rule/measure taken from the constructor's
/// defaults in `crate::truncation` — confirm there.
static DEFAULT_SVD_TRUNCATION_POLICY: Mutex<SvdTruncationPolicy> =
    Mutex::new(SvdTruncationPolicy::new(1e-12));
116
117#[must_use]
121pub fn default_svd_truncation_policy() -> SvdTruncationPolicy {
122 *default_policy_guard()
123}
124
125pub fn set_default_svd_truncation_policy(policy: SvdTruncationPolicy) -> Result<(), SvdError> {
133 validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
134 *default_policy_guard() = policy;
135 Ok(())
136}
137
138fn singular_value_measure(value: f64, measure: SingularValueMeasure) -> f64 {
139 match measure {
140 SingularValueMeasure::Value => value,
141 SingularValueMeasure::SquaredValue => value * value,
142 }
143}
144
145fn compute_retained_rank(s_vec: &[f64], policy: &SvdTruncationPolicy) -> usize {
147 if s_vec.is_empty() {
148 return 1;
149 }
150
151 let measured: Vec<f64> = s_vec
152 .iter()
153 .map(|&value| singular_value_measure(value, policy.measure))
154 .collect();
155 if measured.iter().all(|&value| value == 0.0) {
156 return 1;
157 }
158
159 let retained = match (policy.scale, policy.rule) {
160 (ThresholdScale::Relative, TruncationRule::PerValue) => {
161 let reference = measured.iter().copied().fold(0.0_f64, f64::max);
162 measured
163 .iter()
164 .take_while(|&&value| reference > 0.0 && value / reference > policy.threshold)
165 .count()
166 }
167 (ThresholdScale::Absolute, TruncationRule::PerValue) => measured
168 .iter()
169 .take_while(|&&value| value > policy.threshold)
170 .count(),
171 (ThresholdScale::Relative, TruncationRule::DiscardedTailSum) => {
172 let total: f64 = measured.iter().sum();
173 if total == 0.0 {
174 1
175 } else {
176 let mut discarded = 0.0;
177 let mut keep = measured.len();
178 for (i, value) in measured.iter().enumerate().rev() {
179 if (discarded + value) / total <= policy.threshold {
180 discarded += value;
181 keep = i;
182 } else {
183 break;
184 }
185 }
186 keep
187 }
188 }
189 (ThresholdScale::Absolute, TruncationRule::DiscardedTailSum) => {
190 let mut discarded = 0.0;
191 let mut keep = measured.len();
192 for (i, value) in measured.iter().enumerate().rev() {
193 if discarded + value <= policy.threshold {
194 discarded += value;
195 keep = i;
196 } else {
197 break;
198 }
199 }
200 keep
201 }
202 };
203
204 retained.max(1)
205}
206
207fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
208 match tensor.dtype() {
209 DType::F64 => {
210 native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
211 }
212 DType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
213 .map(|values| values.into_iter().map(|value| value.re).collect())
214 .map_err(SvdError::ComputationError),
215 other => Err(SvdError::ComputationError(anyhow::anyhow!(
216 "native SVD returned unsupported singular-value scalar type {other:?}"
217 ))),
218 }
219}
220
/// Raw output of [`svd_truncated_inner`], in order:
/// `(U, S, V^T, retained singular values, bond index, left indices, right indices)`.
///
/// The three eager tensors are already sliced along their rank axis to the
/// retained rank, but are still in matrix/vector form; callers reshape them
/// into tensors using the returned indices.
type SvdTruncatedEagerResult = (
    EagerTensor<CpuBackend>,
    EagerTensor<CpuBackend>,
    EagerTensor<CpuBackend>,
    Vec<f64>,
    DynIndex,
    Vec<DynIndex>,
    Vec<DynIndex>,
);
230
231fn svd_truncated_inner(
232 t: &TensorDynLen,
233 left_inds: &[DynIndex],
234 options: &SvdOptions,
235) -> Result<SvdTruncatedEagerResult, SvdError> {
236 let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
237 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
238 .map_err(SvdError::ComputationError)?;
239 let k = m.min(n);
240
241 let (mut u_inner, mut s_inner, mut vt_inner) = matrix_inner
242 .svd()
243 .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
244 let s_full = singular_values_from_native(s_inner.data())?;
245 let mut r = if options.truncate {
246 let policy = options.policy.unwrap_or_else(default_svd_truncation_policy);
247 validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
248
249 let mut retained = compute_retained_rank(&s_full, &policy);
250 if let Some(max_rank) = options.max_rank {
251 retained = retained.min(max_rank);
252 }
253 retained.max(1)
254 } else {
255 k.max(1)
256 };
257 r = r.min(s_full.len());
258 if r < k {
259 let keep: Vec<usize> = (0..r).collect();
260 u_inner = u_inner
261 .take_axis(1, &keep)
262 .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
263 s_inner = s_inner
264 .take_axis(0, &keep)
265 .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
266 vt_inner = vt_inner
267 .take_axis(0, &keep)
268 .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
269 }
270
271 let bond_index = DynIndex::new_bond(r)
272 .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
273 .map_err(SvdError::ComputationError)?;
274 let singular_values = s_full[..r].to_vec();
275
276 Ok((
277 u_inner,
278 s_inner,
279 vt_inner,
280 singular_values,
281 bond_index,
282 left_indices,
283 right_indices,
284 ))
285}
286
287pub fn svd<T>(
310 t: &TensorDynLen,
311 left_inds: &[DynIndex],
312) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
313 svd_with::<T>(t, left_inds, &SvdOptions::default())
314}
315
316pub fn svd_with<T>(
347 t: &TensorDynLen,
348 left_inds: &[DynIndex],
349 options: &SvdOptions,
350) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
351 let (u_inner, s_inner, vt_inner, _singular_values, bond_index, left_indices, right_indices) =
352 svd_truncated_inner(t, left_inds, options)?;
353
354 let mut u_indices = left_indices;
355 u_indices.push(bond_index.clone());
356 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
357 let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
358 SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
359 })?;
360 let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
361
362 let s_indices = vec![bond_index.clone(), bond_index.sim()];
363 let s =
364 TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
365
366 let mut vh_indices = vec![bond_index.clone()];
367 vh_indices.extend(right_indices);
368 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
369 let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
370 SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
371 })?;
372 let vh =
373 TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
374 let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
375 let v = vh
376 .conj()
377 .permute(&perm)
378 .map_err(SvdError::ComputationError)?;
379
380 Ok((u, s, v))
381}
382
/// Output of [`svd_for_factorize`]: the SVD factors plus truncation metadata.
pub(crate) struct SvdFactorizeResult {
    /// Left factor with indices `[left indices..., bond_index]`.
    pub u: TensorDynLen,
    /// Diagonal singular-value tensor with indices `[bond_index, bond_index.sim()]`.
    pub s: TensorDynLen,
    /// Right factor exactly as returned by the SVD (not conjugated or permuted),
    /// with indices `[bond_index, right indices...]`.
    pub vh: TensorDynLen,
    /// Bond index whose dimension is the retained rank.
    pub bond_index: DynIndex,
    /// Retained singular values as `f64` (real parts for complex data).
    pub singular_values: Vec<f64>,
    /// Number of retained singular values; equals `singular_values.len()`.
    pub rank: usize,
}
392
393pub(crate) fn svd_for_factorize(
395 t: &TensorDynLen,
396 left_inds: &[DynIndex],
397 options: &SvdOptions,
398) -> Result<SvdFactorizeResult, SvdError> {
399 let (u_inner, s_inner, vt_inner, singular_values, bond_index, left_indices, right_indices) =
400 svd_truncated_inner(t, left_inds, options)?;
401 let rank = singular_values.len();
402
403 let mut u_indices = left_indices;
404 u_indices.push(bond_index.clone());
405 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
406 let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
407 SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
408 })?;
409 let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
410
411 let s_indices = vec![bond_index.clone(), bond_index.sim()];
412 let s =
413 TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
414
415 let mut vh_indices = vec![bond_index.clone()];
416 vh_indices.extend(right_indices);
417 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
418 let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
419 SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
420 })?;
421 let vh =
422 TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
423
424 Ok(SvdFactorizeResult {
425 u,
426 s,
427 vh,
428 bond_index,
429 singular_values,
430 rank,
431 })
432}
433
434#[cfg(test)]
435mod tests;