Skip to main content

tensor4all_core/defaults/
qr.rs

1//! QR decomposition for tensors.
2//!
3//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
4
5use crate::defaults::tensordynlen::unfold_split_inner;
6use crate::defaults::DynIndex;
7use crate::global_default::GlobalDefault;
8use crate::TensorDynLen;
9use num_complex::ComplexFloat;
10use tenferro::DType;
11use tensor4all_tensorbackend::{
12    native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major,
13};
14use thiserror::Error;
15
/// Error type returned by the QR routines in this module (`qr`, `qr_with`,
/// `set_default_qr_rtol`).
#[derive(Debug, Error)]
pub enum QrError {
    /// The QR pipeline failed: unfolding the tensor, the factorization itself, converting
    /// `R` to dense values, truncating, creating the bond index, or reshaping/wrapping the
    /// factors. Wraps the originating `anyhow::Error`; `#[from]` lets `?` convert an
    /// `anyhow::Error` into this variant directly.
    #[error("QR computation failed: {0}")]
    ComputationError(#[from] anyhow::Error),
    /// Invalid relative tolerance value (must be finite and non-negative).
    /// Carries the rejected value.
    #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
    InvalidRtol(f64),
}
26
/// Options controlling how [`qr_with`] truncates the QR factorization.
///
/// Build with [`QrOptions::new`] (or `Default::default()`) and refine with builder-style
/// methods such as [`QrOptions::with_rtol`].
///
/// # Examples
///
/// ```
/// use tensor4all_core::qr::{QrOptions, qr_with};
/// use tensor4all_core::{DynIndex, TensorContractionLike, TensorDynLen};
///
/// let i = DynIndex::new_dyn(3);
/// let j = DynIndex::new_dyn(3);
/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
///
/// let opts = QrOptions::new().with_rtol(1e-10);
/// let (q, r) = qr_with::<f64>(&tensor, &[i], &opts).unwrap();
///
/// // Contracting Q with R recovers the original tensor.
/// let recovered = q.contract_pair(&r).unwrap();
/// assert!(tensor.distance(&recovered).unwrap() < 1e-12);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct QrOptions {
    /// Relative tolerance used to truncate rows of `R` by norm.
    /// `None` means "fall back to the process-wide value from [`default_qr_rtol`]".
    pub rtol: Option<f64>,
    /// Whether rank truncation is applied at all; when `false` the full `min(m, n)` rank
    /// of the thin factorization is kept and `rtol` is ignored.
    truncate: bool,
}

impl QrOptions {
    /// Options that truncate using the global default tolerance.
    #[must_use]
    pub fn new() -> Self {
        Self {
            truncate: true,
            rtol: None,
        }
    }

    /// Return these options with the truncation tolerance overridden by `rtol`.
    ///
    /// The value is not checked here; it is validated when the options are used.
    #[must_use]
    pub fn with_rtol(self, rtol: f64) -> Self {
        Self {
            rtol: Some(rtol),
            ..self
        }
    }

    /// Crate-internal options that disable truncation entirely.
    pub(crate) fn full_rank() -> Self {
        Self {
            truncate: false,
            ..Self::new()
        }
    }
}

impl Default for QrOptions {
    /// Same as [`QrOptions::new`].
    fn default() -> Self {
        QrOptions::new()
    }
}
85
// Process-wide default relative tolerance for QR truncation, held in the crate's shared
// `GlobalDefault` wrapper. Initial value: 1e-15 (very strict, near f64 machine precision).
// Changed through `set_default_qr_rtol`.
static DEFAULT_QR_RTOL: GlobalDefault = GlobalDefault::new(1e-15);

/// Get the global default rtol for QR truncation.
///
/// This is the tolerance `qr_with` uses when `QrOptions::rtol` is `None` (and therefore the
/// tolerance used by `qr`). It starts at 1e-15 (very strict, near machine precision) and can be
/// changed with `set_default_qr_rtol`.
pub fn default_qr_rtol() -> f64 {
    DEFAULT_QR_RTOL.get()
}
96
97/// Set the global default rtol for QR truncation.
98///
99/// # Arguments
100/// * `rtol` - Relative tolerance (must be finite and non-negative)
101///
102/// # Errors
103/// Returns `QrError::InvalidRtol` if rtol is not finite or is negative.
104pub fn set_default_qr_rtol(rtol: f64) -> Result<(), QrError> {
105    DEFAULT_QR_RTOL
106        .set(rtol)
107        .map_err(|e| QrError::InvalidRtol(e.0))
108}
109
110fn compute_retained_rank_qr_from_dense<T>(
111    r_full: &[T],
112    k: usize,
113    n: usize,
114    rtol: f64,
115) -> Result<usize, QrError>
116where
117    T: ComplexFloat,
118    <T as ComplexFloat>::Real: Into<f64>,
119{
120    if k == 0 || n == 0 {
121        return Ok(1);
122    }
123
124    let max_diag = k.min(n);
125
126    // Compute row norms of R (upper triangular: row i has entries from column i..n).
127    // Use relative comparison against the maximum row norm, matching
128    // compute_retained_rank_qr. The previous implementation compared diagonal
129    // elements absolutely and broke at the first small value, which is incorrect
130    // for non-pivoted QR where diagonal elements are not necessarily in
131    // decreasing order.
132    let row_norms: Vec<f64> = (0..max_diag)
133        .map(|i| {
134            let mut norm_sq: f64 = 0.0;
135            for j in i..n {
136                let val: f64 = r_full[i + j * k].abs().into();
137                norm_sq += val * val;
138            }
139            norm_sq.sqrt()
140        })
141        .collect();
142
143    let max_row_norm = row_norms.iter().cloned().fold(0.0_f64, f64::max);
144    if max_row_norm == 0.0 {
145        return Ok(1);
146    }
147
148    let threshold = rtol * max_row_norm;
149    let r = row_norms.iter().filter(|&&norm| norm >= threshold).count();
150    Ok(r.max(1))
151}
152
153/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
154///
155/// This function uses the global default rtol for truncation.
156/// See `qr_with` for per-call rtol control.
157///
158/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
159/// we return Q (m×k) and R (k×n) with k = min(m, n).
160///
161/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
162/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
163///
164/// Truncation is performed based on R's row norms: rows whose norm is below
165/// `rtol * max_row_norm` are discarded.
166///
167/// For the mathematical convention:
168/// \[ A = Q * R \]
169/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
170///
171/// # Arguments
172/// * `t` - Input tensor with DenseF64 or DenseC64 storage
173/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
174///
175/// # Returns
176/// A tuple `(Q, R)` where:
177/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
178/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
179///   where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
180///
181/// # Errors
182/// Returns `QrError` if:
183/// - The tensor rank is < 2
184/// - Storage is not DenseF64 or DenseC64
185/// - `left_inds` is empty or contains all indices
186/// - `left_inds` contains indices not in the tensor or duplicates
187/// - The QR computation fails
188///
189/// # Examples
190///
191/// ```
192/// use tensor4all_core::{TensorDynLen, DynIndex, qr};
193///
194/// // Create a 4x3 matrix
195/// let i = DynIndex::new_dyn(4);
196/// let j = DynIndex::new_dyn(3);
197/// // Identity-like data (4x3 column-major)
198/// let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
199/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
200///
201/// let (q, r) = qr::<f64>(&t, &[i.clone()]).unwrap();
202///
203/// // Q has shape (4, bond) and R has shape (bond, 3)
204/// assert_eq!(q.dims()[0], 4);
205/// assert_eq!(r.dims()[r.dims().len() - 1], 3);
206/// ```
207pub fn qr<T>(
208    t: &TensorDynLen,
209    left_inds: &[DynIndex],
210) -> Result<(TensorDynLen, TensorDynLen), QrError> {
211    qr_with::<T>(t, left_inds, &QrOptions::default())
212}
213
214/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
215///
216/// This function allows per-call control of the truncation tolerance via `QrOptions`.
217/// If `options.rtol` is `None`, uses the global default rtol.
218///
219/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
220/// we return Q (m×k) and R (k×n) with k = min(m, n).
221///
222/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
223/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
224///
225/// Truncation is performed based on R's row norms: rows whose norm is below
226/// `rtol * max_row_norm` are discarded.
227///
228/// For the mathematical convention:
229/// \[ A = Q * R \]
230/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
231///
232/// # Arguments
233/// * `t` - Input tensor with DenseF64 or DenseC64 storage
234/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
235/// * `options` - QR options including rtol for truncation control
236///
237/// # Returns
238/// A tuple `(Q, R)` where:
239/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
240/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
241///   where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
242///
243/// # Errors
244/// Returns `QrError` if:
245/// - The tensor rank is < 2
246/// - Storage is not DenseF64 or DenseC64
247/// - `left_inds` is empty or contains all indices
248/// - `left_inds` contains indices not in the tensor or duplicates
249/// - The QR computation fails
250/// - `options.rtol` is invalid (not finite or negative)
251pub fn qr_with<T>(
252    t: &TensorDynLen,
253    left_inds: &[DynIndex],
254    options: &QrOptions,
255) -> Result<(TensorDynLen, TensorDynLen), QrError> {
256    // Unfold tensor into an eager rank-2 tensor so linalg AD nodes stay connected.
257    let (matrix_inner, _, m, n, left_indices, right_indices) = unfold_split_inner(t, left_inds)
258        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
259        .map_err(QrError::ComputationError)?;
260    let k = m.min(n);
261    let (mut q_inner, mut r_inner) = matrix_inner
262        .qr()
263        .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
264
265    let r = if options.truncate {
266        // Determine rtol to use
267        let rtol = options.rtol.unwrap_or(default_qr_rtol());
268        if !rtol.is_finite() || rtol < 0.0 {
269            return Err(QrError::InvalidRtol(rtol));
270        }
271
272        match r_inner.data().dtype() {
273            DType::F64 => {
274                let values = native_tensor_primal_to_dense_f64_col_major(r_inner.data())
275                    .map_err(QrError::ComputationError)?;
276                compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
277            }
278            DType::C64 => {
279                let values = native_tensor_primal_to_dense_c64_col_major(r_inner.data())
280                    .map_err(QrError::ComputationError)?;
281                compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
282            }
283            other => {
284                return Err(QrError::ComputationError(anyhow::anyhow!(
285                    "native QR returned unsupported scalar type {other:?}"
286                )));
287            }
288        }
289    } else {
290        k
291    };
292    if r < k {
293        let keep: Vec<usize> = (0..r).collect();
294        q_inner = q_inner
295            .take_axis(1, &keep)
296            .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
297        r_inner = r_inner
298            .take_axis(0, &keep)
299            .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?;
300    }
301
302    let bond_index = DynIndex::new_bond(r)
303        .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
304        .map_err(QrError::ComputationError)?;
305
306    let mut q_indices = left_indices.clone();
307    q_indices.push(bond_index.clone());
308    let q_dims: Vec<usize> = q_indices.iter().map(|idx| idx.dim).collect();
309    let q_reshaped = q_inner.reshape(&q_dims).map_err(|e| {
310        QrError::ComputationError(anyhow::anyhow!("eager QR Q reshape failed: {e}"))
311    })?;
312    let q = TensorDynLen::from_inner(q_indices, q_reshaped).map_err(QrError::ComputationError)?;
313
314    let mut r_indices = vec![bond_index.clone()];
315    r_indices.extend_from_slice(&right_indices);
316    let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
317    let r_reshaped = r_inner.reshape(&r_dims).map_err(|e| {
318        QrError::ComputationError(anyhow::anyhow!("eager QR R reshape failed: {e}"))
319    })?;
320    let r = TensorDynLen::from_inner(r_indices, r_reshaped).map_err(QrError::ComputationError)?;
321
322    Ok((q, r))
323}
324
325#[cfg(test)]
326mod tests;