tensor4all_core/defaults/
svd.rs1use crate::defaults::tensordynlen::unfold_split_inner;
11use crate::defaults::DynIndex;
12use crate::index_like::IndexLike;
13use crate::truncation::{
14 validate_svd_truncation_policy, SingularValueMeasure, SvdTruncationPolicy, ThresholdScale,
15 TruncationRule,
16};
17use crate::TensorDynLen;
18use std::sync::Mutex;
19use tenferro::DType;
20use tenferro_ad::EagerTensor;
21use tenferro_linalg::eager_tensor::svd as eager_svd;
22use tensor4all_tensorbackend::{
23 native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
24};
25use thiserror::Error;
26
27#[derive(Debug, Error)]
29pub enum SvdError {
30 #[error("SVD computation failed: {0}")]
32 ComputationError(#[from] anyhow::Error),
33 #[error("Invalid SVD truncation threshold: {0}. Threshold must be finite and non-negative.")]
35 InvalidThreshold(f64),
36}
37
38#[derive(Debug, Clone, Copy)]
59pub struct SvdOptions {
60 pub max_rank: Option<usize>,
62 pub policy: Option<SvdTruncationPolicy>,
65 truncate: bool,
66}
67
68impl Default for SvdOptions {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl SvdOptions {
75 #[must_use]
77 pub fn new() -> Self {
78 Self {
79 max_rank: None,
80 policy: None,
81 truncate: true,
82 }
83 }
84
85 #[must_use]
87 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
88 self.max_rank = Some(max_rank);
89 self
90 }
91
92 #[must_use]
94 pub fn with_policy(mut self, policy: SvdTruncationPolicy) -> Self {
95 self.policy = Some(policy);
96 self
97 }
98
99 pub(crate) fn full_rank() -> Self {
100 Self {
101 max_rank: None,
102 policy: None,
103 truncate: false,
104 }
105 }
106}
107
108fn default_policy_guard() -> std::sync::MutexGuard<'static, SvdTruncationPolicy> {
109 match DEFAULT_SVD_TRUNCATION_POLICY.lock() {
110 Ok(guard) => guard,
111 Err(poisoned) => poisoned.into_inner(),
112 }
113}
114
115static DEFAULT_SVD_TRUNCATION_POLICY: Mutex<SvdTruncationPolicy> =
117 Mutex::new(SvdTruncationPolicy::new(1e-12));
118
119#[must_use]
123pub fn default_svd_truncation_policy() -> SvdTruncationPolicy {
124 *default_policy_guard()
125}
126
127pub fn set_default_svd_truncation_policy(policy: SvdTruncationPolicy) -> Result<(), SvdError> {
135 validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
136 *default_policy_guard() = policy;
137 Ok(())
138}
139
140fn singular_value_measure(value: f64, measure: SingularValueMeasure) -> f64 {
141 match measure {
142 SingularValueMeasure::Value => value,
143 SingularValueMeasure::SquaredValue => value * value,
144 }
145}
146
147fn compute_retained_rank(s_vec: &[f64], policy: &SvdTruncationPolicy) -> usize {
149 if s_vec.is_empty() {
150 return 1;
151 }
152
153 let measured: Vec<f64> = s_vec
154 .iter()
155 .map(|&value| singular_value_measure(value, policy.measure))
156 .collect();
157 if measured.iter().all(|&value| value == 0.0) {
158 return 1;
159 }
160
161 let retained = match (policy.scale, policy.rule) {
162 (ThresholdScale::Relative, TruncationRule::PerValue) => {
163 let reference = measured.iter().copied().fold(0.0_f64, f64::max);
164 measured
165 .iter()
166 .take_while(|&&value| reference > 0.0 && value / reference > policy.threshold)
167 .count()
168 }
169 (ThresholdScale::Absolute, TruncationRule::PerValue) => measured
170 .iter()
171 .take_while(|&&value| value > policy.threshold)
172 .count(),
173 (ThresholdScale::Relative, TruncationRule::DiscardedTailSum) => {
174 let total: f64 = measured.iter().sum();
175 if total == 0.0 {
176 1
177 } else {
178 let mut discarded = 0.0;
179 let mut keep = measured.len();
180 for (i, value) in measured.iter().enumerate().rev() {
181 if (discarded + value) / total <= policy.threshold {
182 discarded += value;
183 keep = i;
184 } else {
185 break;
186 }
187 }
188 keep
189 }
190 }
191 (ThresholdScale::Absolute, TruncationRule::DiscardedTailSum) => {
192 let mut discarded = 0.0;
193 let mut keep = measured.len();
194 for (i, value) in measured.iter().enumerate().rev() {
195 if discarded + value <= policy.threshold {
196 discarded += value;
197 keep = i;
198 } else {
199 break;
200 }
201 }
202 keep
203 }
204 };
205
206 retained.max(1)
207}
208
209fn singular_values_from_native(tensor: &tenferro::Tensor) -> Result<Vec<f64>, SvdError> {
210 match tensor.dtype() {
211 DType::F64 => {
212 native_tensor_primal_to_dense_f64_col_major(tensor).map_err(SvdError::ComputationError)
213 }
214 DType::C64 => native_tensor_primal_to_dense_c64_col_major(tensor)
215 .map(|values| values.into_iter().map(|value| value.re).collect())
216 .map_err(SvdError::ComputationError),
217 other => Err(SvdError::ComputationError(anyhow::anyhow!(
218 "native SVD returned unsupported singular-value scalar type {other:?}"
219 ))),
220 }
221}
222
223type SvdTruncatedEagerResult = (
224 EagerTensor,
225 EagerTensor,
226 EagerTensor,
227 Vec<f64>,
228 DynIndex,
229 Vec<DynIndex>,
230 Vec<DynIndex>,
231);
232
233fn svd_truncated_inner(
234 t: &TensorDynLen,
235 left_inds: &[DynIndex],
236 options: &SvdOptions,
237) -> Result<SvdTruncatedEagerResult, SvdError> {
238 let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
239 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
240 .map_err(SvdError::ComputationError)?;
241 let k = m.min(n);
242
243 let (mut u_inner, mut s_inner, mut vt_inner) =
244 eager_svd(&matrix_inner).map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
245 let s_full = singular_values_from_native(s_inner.data())?;
246 let mut r = if options.truncate {
247 let policy = options.policy.unwrap_or_else(default_svd_truncation_policy);
248 validate_svd_truncation_policy(policy).map_err(|e| SvdError::InvalidThreshold(e.0))?;
249
250 let mut retained = compute_retained_rank(&s_full, &policy);
251 if let Some(max_rank) = options.max_rank {
252 retained = retained.min(max_rank);
253 }
254 retained.max(1)
255 } else {
256 k.max(1)
257 };
258 r = r.min(s_full.len());
259 if r < k {
260 let keep: Vec<usize> = (0..r).collect();
261 u_inner = u_inner
262 .take_axis(1, &keep)
263 .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
264 s_inner = s_inner
265 .take_axis(0, &keep)
266 .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
267 vt_inner = vt_inner
268 .take_axis(0, &keep)
269 .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?;
270 }
271
272 let bond_index = DynIndex::new_bond(r)
273 .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
274 .map_err(SvdError::ComputationError)?;
275 let singular_values = s_full[..r].to_vec();
276
277 Ok((
278 u_inner,
279 s_inner,
280 vt_inner,
281 singular_values,
282 bond_index,
283 left_indices,
284 right_indices,
285 ))
286}
287
288pub fn svd<T>(
311 t: &TensorDynLen,
312 left_inds: &[DynIndex],
313) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
314 svd_with::<T>(t, left_inds, &SvdOptions::default())
315}
316
317pub fn svd_with<T>(
348 t: &TensorDynLen,
349 left_inds: &[DynIndex],
350 options: &SvdOptions,
351) -> Result<(TensorDynLen, TensorDynLen, TensorDynLen), SvdError> {
352 let (u_inner, s_inner, vt_inner, _singular_values, bond_index, left_indices, right_indices) =
353 svd_truncated_inner(t, left_inds, options)?;
354
355 let mut u_indices = left_indices;
356 u_indices.push(bond_index.clone());
357 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
358 let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
359 SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
360 })?;
361 let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
362
363 let s_indices = vec![bond_index.clone(), bond_index.sim()];
364 let s =
365 TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
366
367 let mut vh_indices = vec![bond_index.clone()];
368 vh_indices.extend(right_indices);
369 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
370 let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
371 SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
372 })?;
373 let vh =
374 TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
375 let perm: Vec<usize> = (1..vh.indices.len()).chain(std::iter::once(0)).collect();
376 let v = vh
377 .conj()
378 .permute(&perm)
379 .map_err(SvdError::ComputationError)?;
380
381 Ok((u, s, v))
382}
383
384pub(crate) struct SvdFactorizeResult {
386 pub u: TensorDynLen,
387 pub s: TensorDynLen,
388 pub vh: TensorDynLen,
389 pub bond_index: DynIndex,
390 pub singular_values: Vec<f64>,
391 pub rank: usize,
392}
393
394pub(crate) fn svd_for_factorize(
396 t: &TensorDynLen,
397 left_inds: &[DynIndex],
398 options: &SvdOptions,
399) -> Result<SvdFactorizeResult, SvdError> {
400 let (u_inner, s_inner, vt_inner, singular_values, bond_index, left_indices, right_indices) =
401 svd_truncated_inner(t, left_inds, options)?;
402 let rank = singular_values.len();
403
404 let mut u_indices = left_indices;
405 u_indices.push(bond_index.clone());
406 let u_dims: Vec<usize> = u_indices.iter().map(|idx| idx.dim).collect();
407 let u_reshaped = u_inner.reshape(&u_dims).map_err(|e| {
408 SvdError::ComputationError(anyhow::anyhow!("eager SVD U reshape failed: {e}"))
409 })?;
410 let u = TensorDynLen::from_inner(u_indices, u_reshaped).map_err(SvdError::ComputationError)?;
411
412 let s_indices = vec![bond_index.clone(), bond_index.sim()];
413 let s =
414 TensorDynLen::from_diag_inner(s_indices, s_inner).map_err(SvdError::ComputationError)?;
415
416 let mut vh_indices = vec![bond_index.clone()];
417 vh_indices.extend(right_indices);
418 let vh_dims: Vec<usize> = vh_indices.iter().map(|idx| idx.dim).collect();
419 let vt_reshaped = vt_inner.reshape(&vh_dims).map_err(|e| {
420 SvdError::ComputationError(anyhow::anyhow!("eager SVD V^T reshape failed: {e}"))
421 })?;
422 let vh =
423 TensorDynLen::from_inner(vh_indices, vt_reshaped).map_err(SvdError::ComputationError)?;
424
425 Ok(SvdFactorizeResult {
426 u,
427 s,
428 vh,
429 bond_index,
430 singular_values,
431 rank,
432 })
433}
434
435#[cfg(test)]
436mod tests;