tensor4all_core/defaults/qr.rs
1//! QR decomposition for tensors.
2//!
3//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
4
5use crate::defaults::tensordynlen::unfold_split_inner;
6use crate::defaults::DynIndex;
7use crate::global_default::GlobalDefault;
8use crate::TensorDynLen;
9use num_complex::ComplexFloat;
10use tenferro::DType;
11use tensor4all_tensorbackend::{
12 native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
13};
14use thiserror::Error;
15
16/// Error type for QR operations in tensor4all-linalg.
17#[derive(Debug, Error)]
18pub enum QrError {
19 /// QR computation failed.
20 #[error("QR computation failed: {0}")]
21 ComputationError(#[from] anyhow::Error),
22 /// Invalid relative tolerance value (must be finite and non-negative).
23 #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
24 InvalidRtol(f64),
25}
26
/// Per-call configuration for [`qr_with`].
///
/// Decides whether the thin QR factorization is truncated by the row norms of R
/// and, if so, which relative tolerance is used.
///
/// # Examples
///
/// ```
/// use tensor4all_core::qr::{QrOptions, qr_with};
/// use tensor4all_core::{DynIndex, TensorContractionLike, TensorDynLen};
///
/// let i = DynIndex::new_dyn(3);
/// let j = DynIndex::new_dyn(3);
/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
///
/// let opts = QrOptions::new().with_rtol(1e-10);
/// let (q, r) = qr_with::<f64>(&tensor, &[i], &opts).unwrap();
///
/// // Q * R recovers the original tensor
/// let recovered = q.contract_pair(&r).unwrap();
/// assert!(tensor.distance(&recovered).unwrap() < 1e-12);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct QrOptions {
    /// Relative tolerance for QR row-norm truncation.
    /// If `None`, uses the global default.
    pub rtol: Option<f64>,
    /// Whether row-norm truncation is applied at all. It is `true` for options
    /// built with [`QrOptions::new`] / `Default`, and `false` for
    /// [`QrOptions::full_rank`].
    truncate: bool,
}

impl Default for QrOptions {
    /// Same as [`QrOptions::new`]: truncation on, tolerance taken from the global default.
    fn default() -> Self {
        QrOptions::new()
    }
}

impl QrOptions {
    /// Options with truncation enabled and no per-call tolerance override.
    #[must_use]
    pub fn new() -> Self {
        QrOptions {
            truncate: true,
            rtol: None,
        }
    }

    /// Return a copy of these options that uses `rtol` as the truncation tolerance.
    ///
    /// The value is not checked here; `qr_with` rejects a non-finite or negative
    /// tolerance when truncation is enabled.
    #[must_use]
    pub fn with_rtol(self, rtol: f64) -> Self {
        QrOptions {
            rtol: Some(rtol),
            ..self
        }
    }

    /// Options that disable truncation, so the full thin factorization
    /// (bond dimension `min(m, n)`) is returned.
    pub(crate) fn full_rank() -> Self {
        QrOptions {
            truncate: false,
            ..QrOptions::new()
        }
    }
}
85
86// Global default rtol using the unified GlobalDefault type
87// Default value: 1e-15 (very strict, near machine precision)
88static DEFAULT_QR_RTOL: GlobalDefault = GlobalDefault::new(1e-15);
89
90/// Get the global default rtol for QR truncation.
91///
92/// The default value is 1e-15 (very strict, near machine precision).
93pub fn default_qr_rtol() -> f64 {
94 DEFAULT_QR_RTOL.get()
95}
96
97/// Set the global default rtol for QR truncation.
98///
99/// # Arguments
100/// * `rtol` - Relative tolerance (must be finite and non-negative)
101///
102/// # Errors
103/// Returns `QrError::InvalidRtol` if rtol is not finite or is negative.
104pub fn set_default_qr_rtol(rtol: f64) -> Result<(), QrError> {
105 DEFAULT_QR_RTOL
106 .set(rtol)
107 .map_err(|e| QrError::InvalidRtol(e.0))
108}
109
110fn compute_retained_rank_qr_from_dense<T>(
111 r_full: &[T],
112 k: usize,
113 n: usize,
114 rtol: f64,
115) -> Result<usize, QrError>
116where
117 T: ComplexFloat,
118 <T as ComplexFloat>::Real: Into<f64>,
119{
120 if k == 0 || n == 0 {
121 return Ok(1);
122 }
123
124 let max_diag = k.min(n);
125
126 // Compute row norms of R (upper triangular: row i has entries from column i..n).
127 // Use relative comparison against the maximum row norm, matching
128 // compute_retained_rank_qr. The previous implementation compared diagonal
129 // elements absolutely and broke at the first small value, which is incorrect
130 // for non-pivoted QR where diagonal elements are not necessarily in
131 // decreasing order.
132 let row_norms: Vec<f64> = (0..max_diag)
133 .map(|i| {
134 let mut norm_sq: f64 = 0.0;
135 for j in i..n {
136 let val: f64 = r_full[i + j * k].abs().into();
137 norm_sq += val * val;
138 }
139 norm_sq.sqrt()
140 })
141 .collect();
142
143 let max_row_norm = row_norms.iter().cloned().fold(0.0_f64, f64::max);
144 if max_row_norm == 0.0 {
145 return Ok(1);
146 }
147
148 let threshold = rtol * max_row_norm;
149 let r = row_norms.iter().filter(|&&norm| norm >= threshold).count();
150 Ok(r.max(1))
151}
152
153/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
154///
155/// This function uses the global default rtol for truncation.
156/// See `qr_with` for per-call rtol control.
157///
158/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
159/// we return Q (m×k) and R (k×n) with k = min(m, n).
160///
161/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
162/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
163///
164/// Truncation is performed based on R's row norms: rows whose norm is below
165/// `rtol * max_row_norm` are discarded.
166///
167/// For the mathematical convention:
168/// \[ A = Q * R \]
169/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
170///
171/// # Arguments
172/// * `t` - Input tensor with DenseF64 or DenseC64 storage
173/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
174///
175/// # Returns
176/// A tuple `(Q, R)` where:
177/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
178/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
179/// where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
180///
181/// # Errors
182/// Returns `QrError` if:
183/// - The tensor rank is < 2
184/// - Storage is not DenseF64 or DenseC64
185/// - `left_inds` is empty or contains all indices
186/// - `left_inds` contains indices not in the tensor or duplicates
187/// - The QR computation fails
188///
189/// # Examples
190///
191/// ```
192/// use tensor4all_core::{TensorDynLen, DynIndex, qr};
193///
194/// // Create a 4x3 matrix
195/// let i = DynIndex::new_dyn(4);
196/// let j = DynIndex::new_dyn(3);
197/// // Identity-like data (4x3 column-major)
198/// let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
199/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
200///
201/// let (q, r) = qr::<f64>(&t, &[i.clone()]).unwrap();
202///
203/// // Q has shape (4, bond) and R has shape (bond, 3)
204/// assert_eq!(q.dims()[0], 4);
205/// assert_eq!(r.dims()[r.dims().len() - 1], 3);
206/// ```
207pub fn qr<T>(
208 t: &TensorDynLen,
209 left_inds: &[DynIndex],
210) -> Result<(TensorDynLen, TensorDynLen), QrError> {
211 qr_with::<T>(t, left_inds, &QrOptions::default())
212}
213
214/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
215///
216/// This function allows per-call control of the truncation tolerance via `QrOptions`.
217/// If `options.rtol` is `None`, uses the global default rtol.
218///
219/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
220/// we return Q (m×k) and R (k×n) with k = min(m, n).
221///
222/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
223/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
224///
225/// Truncation is performed based on R's row norms: rows whose norm is below
226/// `rtol * max_row_norm` are discarded.
227///
228/// For the mathematical convention:
229/// \[ A = Q * R \]
230/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
231///
232/// # Arguments
233/// * `t` - Input tensor with DenseF64 or DenseC64 storage
234/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
235/// * `options` - QR options including rtol for truncation control
236///
237/// # Returns
238/// A tuple `(Q, R)` where:
239/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
240/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
241/// where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
242///
243/// # Errors
244/// Returns `QrError` if:
245/// - The tensor rank is < 2
246/// - Storage is not DenseF64 or DenseC64
247/// - `left_inds` is empty or contains all indices
248/// - `left_inds` contains indices not in the tensor or duplicates
249/// - The QR computation fails
250/// - `options.rtol` is invalid (not finite or negative)
251pub fn qr_with<T>(
252 t: &TensorDynLen,
253 left_inds: &[DynIndex],
254 options: &QrOptions,
255) -> Result<(TensorDynLen, TensorDynLen), QrError> {
256 // Unfold tensor into an eager rank-2 tensor so linalg AD nodes stay connected.
257 let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
258 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
259 .map_err(QrError::ComputationError)?;
260 let k = m.min(n);
261 let (mut q_inner, mut r_inner) = matrix_inner
262 .qr()
263 .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
264
265 let r = if options.truncate {
266 // Determine rtol to use
267 let rtol = options.rtol.unwrap_or(default_qr_rtol());
268 if !rtol.is_finite() || rtol < 0.0 {
269 return Err(QrError::InvalidRtol(rtol));
270 }
271
272 match r_inner.data().dtype() {
273 DType::F64 => {
274 let values = native_tensor_primal_to_dense_f64_col_major(r_inner.data())
275 .map_err(QrError::ComputationError)?;
276 compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
277 }
278 DType::C64 => {
279 let values = native_tensor_primal_to_dense_c64_col_major(r_inner.data())
280 .map_err(QrError::ComputationError)?;
281 compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
282 }
283 other => {
284 return Err(QrError::ComputationError(anyhow::anyhow!(
285 "native QR returned unsupported scalar type {other:?}"
286 )));
287 }
288 }
289 } else {
290 k
291 };
292 if r < k {
293 let keep: Vec<usize> = (0..r).collect();
294 q_inner = q_inner
295 .take_axis(1, &keep)
296 .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
297 r_inner = r_inner
298 .take_axis(0, &keep)
299 .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
300 }
301
302 let bond_index = DynIndex::new_bond(r)
303 .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
304 .map_err(QrError::ComputationError)?;
305
306 let mut q_indices = left_indices.clone();
307 q_indices.push(bond_index.clone());
308 let q_dims: Vec<usize> = q_indices.iter().map(|idx| idx.dim).collect();
309 let q_reshaped = q_inner.reshape(&q_dims).map_err(|e| {
310 QrError::ComputationError(anyhow::anyhow!("eager QR Q reshape failed: {e}"))
311 })?;
312 let q = TensorDynLen::from_inner(q_indices, q_reshaped).map_err(QrError::ComputationError)?;
313
314 let mut r_indices = vec![bond_index.clone()];
315 r_indices.extend_from_slice(&right_indices);
316 let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
317 let r_reshaped = r_inner.reshape(&r_dims).map_err(|e| {
318 QrError::ComputationError(anyhow::anyhow!("eager QR R reshape failed: {e}"))
319 })?;
320 let r = TensorDynLen::from_inner(r_indices, r_reshaped).map_err(QrError::ComputationError)?;
321
322 Ok((q, r))
323}
324
// Unit tests for this module live in the child module `tests`. Rust resolves it
// to `tests.rs` (or `tests/mod.rs`) in this module's directory, and it is
// compiled only for test builds.
#[cfg(test)]
mod tests;