tensor4all_core/defaults/qr.rs
1//! QR decomposition for tensors.
2//!
3//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
4
5use crate::defaults::tensordynlen::unfold_split_inner;
6use crate::defaults::DynIndex;
7use crate::global_default::GlobalDefault;
8use crate::TensorDynLen;
9use num_complex::ComplexFloat;
10use tenferro::DType;
11use tenferro_linalg::eager_tensor::qr as eager_qr;
12use tensor4all_tensorbackend::{
13 native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
14};
15use thiserror::Error;
16
17/// Error type for QR operations in tensor4all-linalg.
18#[derive(Debug, Error)]
19pub enum QrError {
20 /// QR computation failed.
21 #[error("QR computation failed: {0}")]
22 ComputationError(#[from] anyhow::Error),
23 /// Invalid relative tolerance value (must be finite and non-negative).
24 #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
25 InvalidRtol(f64),
26}
27
28/// Options for QR decomposition with truncation control.
29///
30/// # Examples
31///
32/// ```
33/// use tensor4all_core::qr::{QrOptions, qr_with};
34/// use tensor4all_core::{DynIndex, TensorContractionLike, TensorDynLen};
35///
36/// let i = DynIndex::new_dyn(3);
37/// let j = DynIndex::new_dyn(3);
38/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
39/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
40///
41/// let opts = QrOptions::new().with_rtol(1e-10);
42/// let (q, r) = qr_with::<f64>(&tensor, &[i], &opts).unwrap();
43///
44/// // Q * R recovers the original tensor
45/// let recovered = q.contract_pair(&r).unwrap();
46/// assert!(tensor.distance(&recovered).unwrap() < 1e-12);
47/// ```
48#[derive(Debug, Clone, Copy)]
49pub struct QrOptions {
50 /// Relative tolerance for QR row-norm truncation.
51 /// If `None`, uses the global default.
52 pub rtol: Option<f64>,
53 truncate: bool,
54}
55
56impl Default for QrOptions {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl QrOptions {
63 /// Create new QR options with no overrides.
64 #[must_use]
65 pub fn new() -> Self {
66 Self {
67 rtol: None,
68 truncate: true,
69 }
70 }
71
72 /// Set the QR truncation tolerance.
73 #[must_use]
74 pub fn with_rtol(mut self, rtol: f64) -> Self {
75 self.rtol = Some(rtol);
76 self
77 }
78
79 pub(crate) fn full_rank() -> Self {
80 Self {
81 rtol: None,
82 truncate: false,
83 }
84 }
85}
86
87// Global default rtol using the unified GlobalDefault type
88// Default value: 1e-15 (very strict, near machine precision)
89static DEFAULT_QR_RTOL: GlobalDefault = GlobalDefault::new(1e-15);
90
91/// Get the global default rtol for QR truncation.
92///
93/// The default value is 1e-15 (very strict, near machine precision).
94pub fn default_qr_rtol() -> f64 {
95 DEFAULT_QR_RTOL.get()
96}
97
98/// Set the global default rtol for QR truncation.
99///
100/// # Arguments
101/// * `rtol` - Relative tolerance (must be finite and non-negative)
102///
103/// # Errors
104/// Returns `QrError::InvalidRtol` if rtol is not finite or is negative.
105pub fn set_default_qr_rtol(rtol: f64) -> Result<(), QrError> {
106 DEFAULT_QR_RTOL
107 .set(rtol)
108 .map_err(|e| QrError::InvalidRtol(e.0))
109}
110
111fn compute_retained_rank_qr_from_dense<T>(
112 r_full: &[T],
113 k: usize,
114 n: usize,
115 rtol: f64,
116) -> Result<usize, QrError>
117where
118 T: ComplexFloat,
119 <T as ComplexFloat>::Real: Into<f64>,
120{
121 if k == 0 || n == 0 {
122 return Ok(1);
123 }
124
125 let max_diag = k.min(n);
126
127 // Compute row norms of R (upper triangular: row i has entries from column i..n).
128 // Use relative comparison against the maximum row norm, matching
129 // compute_retained_rank_qr. The previous implementation compared diagonal
130 // elements absolutely and broke at the first small value, which is incorrect
131 // for non-pivoted QR where diagonal elements are not necessarily in
132 // decreasing order.
133 let row_norms: Vec<f64> = (0..max_diag)
134 .map(|i| {
135 let mut norm_sq: f64 = 0.0;
136 for j in i..n {
137 let val: f64 = r_full[i + j * k].abs().into();
138 norm_sq += val * val;
139 }
140 norm_sq.sqrt()
141 })
142 .collect();
143
144 let max_row_norm = row_norms.iter().cloned().fold(0.0_f64, f64::max);
145 if max_row_norm == 0.0 {
146 return Ok(1);
147 }
148
149 let threshold = rtol * max_row_norm;
150 let r = row_norms.iter().filter(|&&norm| norm >= threshold).count();
151 Ok(r.max(1))
152}
153
154/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
155///
156/// This function uses the global default rtol for truncation.
157/// See `qr_with` for per-call rtol control.
158///
159/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
160/// we return Q (m×k) and R (k×n) with k = min(m, n).
161///
162/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
163/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
164///
165/// Truncation is performed based on R's row norms: rows whose norm is below
166/// `rtol * max_row_norm` are discarded.
167///
168/// For the mathematical convention:
169/// \[ A = Q * R \]
170/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
171///
172/// # Arguments
173/// * `t` - Input tensor with DenseF64 or DenseC64 storage
174/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
175///
176/// # Returns
177/// A tuple `(Q, R)` where:
178/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
179/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
180/// where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
181///
182/// # Errors
183/// Returns `QrError` if:
184/// - The tensor rank is < 2
185/// - Storage is not DenseF64 or DenseC64
186/// - `left_inds` is empty or contains all indices
187/// - `left_inds` contains indices not in the tensor or duplicates
188/// - The QR computation fails
189///
190/// # Examples
191///
192/// ```
193/// use tensor4all_core::{TensorDynLen, DynIndex, qr};
194///
195/// // Create a 4x3 matrix
196/// let i = DynIndex::new_dyn(4);
197/// let j = DynIndex::new_dyn(3);
198/// // Identity-like data (4x3 column-major)
199/// let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
200/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
201///
202/// let (q, r) = qr::<f64>(&t, &[i.clone()]).unwrap();
203///
204/// // Q has shape (4, bond) and R has shape (bond, 3)
205/// assert_eq!(q.dims()[0], 4);
206/// assert_eq!(r.dims()[r.dims().len() - 1], 3);
207/// ```
208pub fn qr<T>(
209 t: &TensorDynLen,
210 left_inds: &[DynIndex],
211) -> Result<(TensorDynLen, TensorDynLen), QrError> {
212 qr_with::<T>(t, left_inds, &QrOptions::default())
213}
214
215/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
216///
217/// This function allows per-call control of the truncation tolerance via `QrOptions`.
218/// If `options.rtol` is `None`, uses the global default rtol.
219///
220/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
221/// we return Q (m×k) and R (k×n) with k = min(m, n).
222///
223/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
224/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
225///
226/// Truncation is performed based on R's row norms: rows whose norm is below
227/// `rtol * max_row_norm` are discarded.
228///
229/// For the mathematical convention:
230/// \[ A = Q * R \]
231/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
232///
233/// # Arguments
234/// * `t` - Input tensor with DenseF64 or DenseC64 storage
235/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
236/// * `options` - QR options including rtol for truncation control
237///
238/// # Returns
239/// A tuple `(Q, R)` where:
240/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
241/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
242/// where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
243///
244/// # Errors
245/// Returns `QrError` if:
246/// - The tensor rank is < 2
247/// - Storage is not DenseF64 or DenseC64
248/// - `left_inds` is empty or contains all indices
249/// - `left_inds` contains indices not in the tensor or duplicates
250/// - The QR computation fails
251/// - `options.rtol` is invalid (not finite or negative)
252pub fn qr_with<T>(
253 t: &TensorDynLen,
254 left_inds: &[DynIndex],
255 options: &QrOptions,
256) -> Result<(TensorDynLen, TensorDynLen), QrError> {
257 // Unfold tensor into an eager rank-2 tensor so linalg AD nodes stay connected.
258 let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
259 .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
260 .map_err(QrError::ComputationError)?;
261 let k = m.min(n);
262 let (mut q_inner, mut r_inner) =
263 eager_qr(&matrix_inner).map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
264
265 let r = if options.truncate {
266 // Determine rtol to use
267 let rtol = options.rtol.unwrap_or(default_qr_rtol());
268 if !rtol.is_finite() || rtol < 0.0 {
269 return Err(QrError::InvalidRtol(rtol));
270 }
271
272 match r_inner.data().dtype() {
273 DType::F64 => {
274 let values = native_tensor_primal_to_dense_f64_col_major(r_inner.data())
275 .map_err(QrError::ComputationError)?;
276 compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
277 }
278 DType::C64 => {
279 let values = native_tensor_primal_to_dense_c64_col_major(r_inner.data())
280 .map_err(QrError::ComputationError)?;
281 compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
282 }
283 other => {
284 return Err(QrError::ComputationError(anyhow::anyhow!(
285 "native QR returned unsupported scalar type {other:?}"
286 )));
287 }
288 }
289 } else {
290 k
291 };
292 if r < k {
293 let keep: Vec<usize> = (0..r).collect();
294 q_inner = q_inner
295 .take_axis(1, &keep)
296 .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
297 r_inner = r_inner
298 .take_axis(0, &keep)
299 .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
300 }
301
302 let bond_index = DynIndex::new_bond(r)
303 .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
304 .map_err(QrError::ComputationError)?;
305
306 let mut q_indices = left_indices.clone();
307 q_indices.push(bond_index.clone());
308 let q_dims: Vec<usize> = q_indices.iter().map(|idx| idx.dim).collect();
309 let q_reshaped = q_inner.reshape(&q_dims).map_err(|e| {
310 QrError::ComputationError(anyhow::anyhow!("eager QR Q reshape failed: {e}"))
311 })?;
312 let q = TensorDynLen::from_inner(q_indices, q_reshaped).map_err(QrError::ComputationError)?;
313
314 let mut r_indices = vec![bond_index.clone()];
315 r_indices.extend_from_slice(&right_indices);
316 let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
317 let r_reshaped = r_inner.reshape(&r_dims).map_err(|e| {
318 QrError::ComputationError(anyhow::anyhow!("eager QR R reshape failed: {e}"))
319 })?;
320 let r = TensorDynLen::from_inner(r_indices, r_reshaped).map_err(QrError::ComputationError)?;
321
322 Ok((q, r))
323}
324
325#[cfg(test)]
326mod tests;