tensor4all_core/defaults/qr.rs
1//! QR decomposition for tensors.
2//!
3//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
4
5use crate::defaults::DynIndex;
6use crate::global_default::GlobalDefault;
7use crate::{unfold_split, TensorDynLen};
8use num_complex::ComplexFloat;
9use tenferro::DType;
10use tensor4all_tensorbackend::{
11 dense_native_tensor_from_col_major, native_tensor_primal_to_dense_c64_col_major,
12 native_tensor_primal_to_dense_f64_col_major, qr_native_tensor, reshape_col_major_native_tensor,
13 TensorElement,
14};
15use thiserror::Error;
16
17/// Error type for QR operations in tensor4all-linalg.
18#[derive(Debug, Error)]
19pub enum QrError {
20 /// QR computation failed.
21 #[error("QR computation failed: {0}")]
22 ComputationError(#[from] anyhow::Error),
23 /// Invalid relative tolerance value (must be finite and non-negative).
24 #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
25 InvalidRtol(f64),
26}
27
/// Options for QR decomposition with truncation control.
///
/// # Examples
///
/// ```
/// use tensor4all_core::qr::{QrOptions, qr_with};
/// use tensor4all_core::{DynIndex, TensorDynLen};
///
/// let i = DynIndex::new_dyn(3);
/// let j = DynIndex::new_dyn(3);
/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
///
/// let opts = QrOptions::new().with_rtol(1e-10);
/// let (q, r) = qr_with::<f64>(&tensor, &[i], &opts).unwrap();
///
/// // Q * R recovers the original tensor
/// let recovered = q.contract(&r);
/// assert!(tensor.distance(&recovered) < 1e-12);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct QrOptions {
    /// Relative tolerance for QR row-norm truncation.
    /// If `None`, uses the global default.
    pub rtol: Option<f64>,
    // Whether rank truncation runs at all; `full_rank()` switches it off.
    truncate: bool,
}

impl Default for QrOptions {
    /// Same as [`QrOptions::new`]: truncation enabled, global default rtol.
    fn default() -> Self {
        QrOptions::new()
    }
}

impl QrOptions {
    /// Create new QR options with no overrides (truncation on, global rtol).
    #[must_use]
    pub fn new() -> Self {
        Self {
            rtol: None,
            truncate: true,
        }
    }

    /// Override the QR truncation tolerance for this set of options.
    #[must_use]
    pub fn with_rtol(self, rtol: f64) -> Self {
        Self {
            rtol: Some(rtol),
            ..self
        }
    }

    /// Internal: options that disable truncation entirely (keep full rank).
    pub(crate) fn full_rank() -> Self {
        Self {
            truncate: false,
            ..Self::new()
        }
    }
}
86
// Global default rtol using the unified GlobalDefault type.
// Default value: 1e-15 (very strict, near machine precision).
// Shared process-wide; overridable per call through `QrOptions::with_rtol`.
static DEFAULT_QR_RTOL: GlobalDefault = GlobalDefault::new(1e-15);

/// Get the global default rtol for QR truncation.
///
/// The default value is 1e-15 (very strict, near machine precision).
/// Returns whatever was last stored via [`set_default_qr_rtol`], or the
/// built-in default if never set.
pub fn default_qr_rtol() -> f64 {
    DEFAULT_QR_RTOL.get()
}
97
98/// Set the global default rtol for QR truncation.
99///
100/// # Arguments
101/// * `rtol` - Relative tolerance (must be finite and non-negative)
102///
103/// # Errors
104/// Returns `QrError::InvalidRtol` if rtol is not finite or is negative.
105pub fn set_default_qr_rtol(rtol: f64) -> Result<(), QrError> {
106 DEFAULT_QR_RTOL
107 .set(rtol)
108 .map_err(|e| QrError::InvalidRtol(e.0))
109}
110
111fn compute_retained_rank_qr_from_dense<T>(
112 r_full: &[T],
113 k: usize,
114 n: usize,
115 rtol: f64,
116) -> Result<usize, QrError>
117where
118 T: ComplexFloat,
119 <T as ComplexFloat>::Real: Into<f64>,
120{
121 if k == 0 || n == 0 {
122 return Ok(1);
123 }
124
125 let max_diag = k.min(n);
126
127 // Compute row norms of R (upper triangular: row i has entries from column i..n).
128 // Use relative comparison against the maximum row norm, matching
129 // compute_retained_rank_qr. The previous implementation compared diagonal
130 // elements absolutely and broke at the first small value, which is incorrect
131 // for non-pivoted QR where diagonal elements are not necessarily in
132 // decreasing order.
133 let row_norms: Vec<f64> = (0..max_diag)
134 .map(|i| {
135 let mut norm_sq: f64 = 0.0;
136 for j in i..n {
137 let val: f64 = r_full[i + j * k].abs().into();
138 norm_sq += val * val;
139 }
140 norm_sq.sqrt()
141 })
142 .collect();
143
144 let max_row_norm = row_norms.iter().cloned().fold(0.0_f64, f64::max);
145 if max_row_norm == 0.0 {
146 return Ok(1);
147 }
148
149 let threshold = rtol * max_row_norm;
150 let r = row_norms.iter().filter(|&&norm| norm >= threshold).count();
151 Ok(r.max(1))
152}
153
154fn truncate_matrix_cols<T: TensorElement>(
155 data: &[T],
156 rows: usize,
157 keep_cols: usize,
158) -> anyhow::Result<tenferro::Tensor> {
159 dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
160}
161
162fn truncate_matrix_rows<T: TensorElement>(
163 data: &[T],
164 rows: usize,
165 cols: usize,
166 keep_rows: usize,
167) -> anyhow::Result<tenferro::Tensor> {
168 let mut truncated = Vec::with_capacity(keep_rows * cols);
169 for col in 0..cols {
170 let start = col * rows;
171 truncated.extend_from_slice(&data[start..start + keep_rows]);
172 }
173 dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
174}
175
176/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
177///
178/// This function uses the global default rtol for truncation.
179/// See `qr_with` for per-call rtol control.
180///
181/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
182/// we return Q (m×k) and R (k×n) with k = min(m, n).
183///
184/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
185/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
186///
187/// Truncation is performed based on R's row norms: rows whose norm is below
188/// `rtol * max_row_norm` are discarded.
189///
190/// For the mathematical convention:
191/// \[ A = Q * R \]
192/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
193///
194/// # Arguments
195/// * `t` - Input tensor with DenseF64 or DenseC64 storage
196/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
197///
198/// # Returns
199/// A tuple `(Q, R)` where:
200/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
201/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
202/// where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
203///
204/// # Errors
205/// Returns `QrError` if:
206/// - The tensor rank is < 2
207/// - Storage is not DenseF64 or DenseC64
208/// - `left_inds` is empty or contains all indices
209/// - `left_inds` contains indices not in the tensor or duplicates
210/// - The QR computation fails
211///
212/// # Examples
213///
214/// ```
215/// use tensor4all_core::{TensorDynLen, DynIndex, qr};
216///
217/// // Create a 4x3 matrix
218/// let i = DynIndex::new_dyn(4);
219/// let j = DynIndex::new_dyn(3);
220/// // Identity-like data (4x3 column-major)
221/// let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
222/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
223///
224/// let (q, r) = qr::<f64>(&t, &[i.clone()]).unwrap();
225///
226/// // Q has shape (4, bond) and R has shape (bond, 3)
227/// assert_eq!(q.dims()[0], 4);
228/// assert_eq!(r.dims()[r.dims().len() - 1], 3);
229/// ```
230pub fn qr<T>(
231 t: &TensorDynLen,
232 left_inds: &[DynIndex],
233) -> Result<(TensorDynLen, TensorDynLen), QrError> {
234 qr_with::<T>(t, left_inds, &QrOptions::default())
235}
236
237/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
238///
239/// This function allows per-call control of the truncation tolerance via `QrOptions`.
240/// If `options.rtol` is `None`, uses the global default rtol.
241///
242/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
243/// we return Q (m×k) and R (k×n) with k = min(m, n).
244///
245/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
246/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
247///
248/// Truncation is performed based on R's row norms: rows whose norm is below
249/// `rtol * max_row_norm` are discarded.
250///
251/// For the mathematical convention:
252/// \[ A = Q * R \]
253/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
254///
255/// # Arguments
256/// * `t` - Input tensor with DenseF64 or DenseC64 storage
257/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
258/// * `options` - QR options including rtol for truncation control
259///
260/// # Returns
261/// A tuple `(Q, R)` where:
262/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
263/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
264/// where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
265///
266/// # Errors
267/// Returns `QrError` if:
268/// - The tensor rank is < 2
269/// - Storage is not DenseF64 or DenseC64
270/// - `left_inds` is empty or contains all indices
271/// - `left_inds` contains indices not in the tensor or duplicates
272/// - The QR computation fails
273/// - `options.rtol` is invalid (not finite or negative)
274pub fn qr_with<T>(
275 t: &TensorDynLen,
276 left_inds: &[DynIndex],
277 options: &QrOptions,
278) -> Result<(TensorDynLen, TensorDynLen), QrError> {
279 // Unfold tensor into a native rank-2 tensor.
280 let (matrix_native, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
281 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
282 .map_err(QrError::ComputationError)?;
283 let k = m.min(n);
284 let (mut q_native, mut r_native) =
285 qr_native_tensor(&matrix_native).map_err(QrError::ComputationError)?;
286
287 let r = if options.truncate {
288 // Determine rtol to use
289 let rtol = options.rtol.unwrap_or(default_qr_rtol());
290 if !rtol.is_finite() || rtol < 0.0 {
291 return Err(QrError::InvalidRtol(rtol));
292 }
293
294 match r_native.dtype() {
295 DType::F64 => {
296 let values = native_tensor_primal_to_dense_f64_col_major(&r_native)
297 .map_err(QrError::ComputationError)?;
298 compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
299 }
300 DType::C64 => {
301 let values = native_tensor_primal_to_dense_c64_col_major(&r_native)
302 .map_err(QrError::ComputationError)?;
303 compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
304 }
305 other => {
306 return Err(QrError::ComputationError(anyhow::anyhow!(
307 "native QR returned unsupported scalar type {other:?}"
308 )));
309 }
310 }
311 } else {
312 k
313 };
314 if r < k {
315 match q_native.dtype() {
316 DType::F64 => {
317 let q_values = native_tensor_primal_to_dense_f64_col_major(&q_native)
318 .map_err(QrError::ComputationError)?;
319 let r_values = native_tensor_primal_to_dense_f64_col_major(&r_native)
320 .map_err(QrError::ComputationError)?;
321 q_native =
322 truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
323 r_native =
324 truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
325 }
326 DType::C64 => {
327 let q_values = native_tensor_primal_to_dense_c64_col_major(&q_native)
328 .map_err(QrError::ComputationError)?;
329 let r_values = native_tensor_primal_to_dense_c64_col_major(&r_native)
330 .map_err(QrError::ComputationError)?;
331 q_native =
332 truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
333 r_native =
334 truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
335 }
336 other => {
337 return Err(QrError::ComputationError(anyhow::anyhow!(
338 "native QR returned unsupported scalar type {other:?}"
339 )));
340 }
341 }
342 }
343
344 let bond_index = DynIndex::new_bond(r)
345 .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
346 .map_err(QrError::ComputationError)?;
347
348 let mut q_indices = left_indices.clone();
349 q_indices.push(bond_index.clone());
350 let q_dims: Vec<usize> = q_indices.iter().map(|idx| idx.dim).collect();
351 let q_reshaped = reshape_col_major_native_tensor(&q_native, &q_dims).map_err(|e| {
352 QrError::ComputationError(anyhow::anyhow!("native QR Q reshape failed: {e}"))
353 })?;
354 let q = TensorDynLen::from_native(q_indices, q_reshaped).map_err(QrError::ComputationError)?;
355
356 let mut r_indices = vec![bond_index.clone()];
357 r_indices.extend_from_slice(&right_indices);
358 let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
359 let r_reshaped = reshape_col_major_native_tensor(&r_native, &r_dims).map_err(|e| {
360 QrError::ComputationError(anyhow::anyhow!("native QR R reshape failed: {e}"))
361 })?;
362 let r = TensorDynLen::from_native(r_indices, r_reshaped).map_err(QrError::ComputationError)?;
363
364 Ok((q, r))
365}
366
// Unit tests live in the sibling `tests` module file; compiled only for `cargo test`.
#[cfg(test)]
mod tests;