Skip to main content

tensor4all_core/defaults/
qr.rs

1//! QR decomposition for tensors.
2//!
3//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
4
5use crate::defaults::DynIndex;
6use crate::global_default::GlobalDefault;
7use crate::{unfold_split, TensorDynLen};
8use num_complex::ComplexFloat;
9use tenferro::DType;
10use tensor4all_tensorbackend::{
11    dense_native_tensor_from_col_major, native_tensor_primal_to_dense_c64_col_major,
12    native_tensor_primal_to_dense_f64_col_major, qr_native_tensor, reshape_col_major_native_tensor,
13    TensorElement,
14};
15use thiserror::Error;
16
17/// Error type for QR operations in tensor4all-linalg.
18#[derive(Debug, Error)]
19pub enum QrError {
20    /// QR computation failed.
21    #[error("QR computation failed: {0}")]
22    ComputationError(#[from] anyhow::Error),
23    /// Invalid relative tolerance value (must be finite and non-negative).
24    #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
25    InvalidRtol(f64),
26}
27
/// Per-call settings for a QR decomposition, controlling rank truncation.
///
/// # Examples
///
/// ```
/// use tensor4all_core::qr::{QrOptions, qr_with};
/// use tensor4all_core::{DynIndex, TensorDynLen};
///
/// let i = DynIndex::new_dyn(3);
/// let j = DynIndex::new_dyn(3);
/// let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
/// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
///
/// let opts = QrOptions::new().with_rtol(1e-10);
/// let (q, r) = qr_with::<f64>(&tensor, &[i], &opts).unwrap();
///
/// // Contracting Q with R gives back the input tensor.
/// let recovered = q.contract(&r);
/// assert!(tensor.distance(&recovered) < 1e-12);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct QrOptions {
    /// Relative tolerance applied to the row norms of `R` when truncating.
    /// `None` means "use the global default" (see `default_qr_rtol`).
    pub rtol: Option<f64>,
    /// Whether truncation is attempted at all. When `false`, `qr_with` keeps
    /// the full bond dimension `min(m, n)` and never consults `rtol`.
    truncate: bool,
}
55
56impl Default for QrOptions {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl QrOptions {
63    /// Create new QR options with no overrides.
64    #[must_use]
65    pub fn new() -> Self {
66        Self {
67            rtol: None,
68            truncate: true,
69        }
70    }
71
72    /// Set the QR truncation tolerance.
73    #[must_use]
74    pub fn with_rtol(mut self, rtol: f64) -> Self {
75        self.rtol = Some(rtol);
76        self
77    }
78
79    pub(crate) fn full_rank() -> Self {
80        Self {
81            rtol: None,
82            truncate: false,
83        }
84    }
85}
86
87// Global default rtol using the unified GlobalDefault type
88// Default value: 1e-15 (very strict, near machine precision)
89static DEFAULT_QR_RTOL: GlobalDefault = GlobalDefault::new(1e-15);
90
91/// Get the global default rtol for QR truncation.
92///
93/// The default value is 1e-15 (very strict, near machine precision).
94pub fn default_qr_rtol() -> f64 {
95    DEFAULT_QR_RTOL.get()
96}
97
98/// Set the global default rtol for QR truncation.
99///
100/// # Arguments
101/// * `rtol` - Relative tolerance (must be finite and non-negative)
102///
103/// # Errors
104/// Returns `QrError::InvalidRtol` if rtol is not finite or is negative.
105pub fn set_default_qr_rtol(rtol: f64) -> Result<(), QrError> {
106    DEFAULT_QR_RTOL
107        .set(rtol)
108        .map_err(|e| QrError::InvalidRtol(e.0))
109}
110
111fn compute_retained_rank_qr_from_dense<T>(
112    r_full: &[T],
113    k: usize,
114    n: usize,
115    rtol: f64,
116) -> Result<usize, QrError>
117where
118    T: ComplexFloat,
119    <T as ComplexFloat>::Real: Into<f64>,
120{
121    if k == 0 || n == 0 {
122        return Ok(1);
123    }
124
125    let max_diag = k.min(n);
126
127    // Compute row norms of R (upper triangular: row i has entries from column i..n).
128    // Use relative comparison against the maximum row norm, matching
129    // compute_retained_rank_qr. The previous implementation compared diagonal
130    // elements absolutely and broke at the first small value, which is incorrect
131    // for non-pivoted QR where diagonal elements are not necessarily in
132    // decreasing order.
133    let row_norms: Vec<f64> = (0..max_diag)
134        .map(|i| {
135            let mut norm_sq: f64 = 0.0;
136            for j in i..n {
137                let val: f64 = r_full[i + j * k].abs().into();
138                norm_sq += val * val;
139            }
140            norm_sq.sqrt()
141        })
142        .collect();
143
144    let max_row_norm = row_norms.iter().cloned().fold(0.0_f64, f64::max);
145    if max_row_norm == 0.0 {
146        return Ok(1);
147    }
148
149    let threshold = rtol * max_row_norm;
150    let r = row_norms.iter().filter(|&&norm| norm >= threshold).count();
151    Ok(r.max(1))
152}
153
154fn truncate_matrix_cols<T: TensorElement>(
155    data: &[T],
156    rows: usize,
157    keep_cols: usize,
158) -> anyhow::Result<tenferro::Tensor> {
159    dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
160}
161
162fn truncate_matrix_rows<T: TensorElement>(
163    data: &[T],
164    rows: usize,
165    cols: usize,
166    keep_rows: usize,
167) -> anyhow::Result<tenferro::Tensor> {
168    let mut truncated = Vec::with_capacity(keep_rows * cols);
169    for col in 0..cols {
170        let start = col * rows;
171        truncated.extend_from_slice(&data[start..start + keep_rows]);
172    }
173    dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
174}
175
176/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
177///
178/// This function uses the global default rtol for truncation.
179/// See `qr_with` for per-call rtol control.
180///
181/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
182/// we return Q (m×k) and R (k×n) with k = min(m, n).
183///
184/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
185/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
186///
187/// Truncation is performed based on R's row norms: rows whose norm is below
188/// `rtol * max_row_norm` are discarded.
189///
190/// For the mathematical convention:
191/// \[ A = Q * R \]
192/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
193///
194/// # Arguments
195/// * `t` - Input tensor with DenseF64 or DenseC64 storage
196/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
197///
198/// # Returns
199/// A tuple `(Q, R)` where:
200/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
201/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
202///   where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
203///
204/// # Errors
205/// Returns `QrError` if:
206/// - The tensor rank is < 2
207/// - Storage is not DenseF64 or DenseC64
208/// - `left_inds` is empty or contains all indices
209/// - `left_inds` contains indices not in the tensor or duplicates
210/// - The QR computation fails
211///
212/// # Examples
213///
214/// ```
215/// use tensor4all_core::{TensorDynLen, DynIndex, qr};
216///
217/// // Create a 4x3 matrix
218/// let i = DynIndex::new_dyn(4);
219/// let j = DynIndex::new_dyn(3);
220/// // Identity-like data (4x3 column-major)
221/// let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
222/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
223///
224/// let (q, r) = qr::<f64>(&t, &[i.clone()]).unwrap();
225///
226/// // Q has shape (4, bond) and R has shape (bond, 3)
227/// assert_eq!(q.dims()[0], 4);
228/// assert_eq!(r.dims()[r.dims().len() - 1], 3);
229/// ```
230pub fn qr<T>(
231    t: &TensorDynLen,
232    left_inds: &[DynIndex],
233) -> Result<(TensorDynLen, TensorDynLen), QrError> {
234    qr_with::<T>(t, left_inds, &QrOptions::default())
235}
236
237/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
238///
239/// This function allows per-call control of the truncation tolerance via `QrOptions`.
240/// If `options.rtol` is `None`, uses the global default rtol.
241///
242/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
243/// we return Q (m×k) and R (k×n) with k = min(m, n).
244///
245/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
246/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
247///
248/// Truncation is performed based on R's row norms: rows whose norm is below
249/// `rtol * max_row_norm` are discarded.
250///
251/// For the mathematical convention:
252/// \[ A = Q * R \]
253/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
254///
255/// # Arguments
256/// * `t` - Input tensor with DenseF64 or DenseC64 storage
257/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
258/// * `options` - QR options including rtol for truncation control
259///
260/// # Returns
261/// A tuple `(Q, R)` where:
262/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
263/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
264///   where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
265///
266/// # Errors
267/// Returns `QrError` if:
268/// - The tensor rank is < 2
269/// - Storage is not DenseF64 or DenseC64
270/// - `left_inds` is empty or contains all indices
271/// - `left_inds` contains indices not in the tensor or duplicates
272/// - The QR computation fails
273/// - `options.rtol` is invalid (not finite or negative)
274pub fn qr_with<T>(
275    t: &TensorDynLen,
276    left_inds: &[DynIndex],
277    options: &QrOptions,
278) -> Result<(TensorDynLen, TensorDynLen), QrError> {
279    // Unfold tensor into a native rank-2 tensor.
280    let (matrix_native, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
281        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
282        .map_err(QrError::ComputationError)?;
283    let k = m.min(n);
284    let (mut q_native, mut r_native) =
285        qr_native_tensor(&matrix_native).map_err(QrError::ComputationError)?;
286
287    let r = if options.truncate {
288        // Determine rtol to use
289        let rtol = options.rtol.unwrap_or(default_qr_rtol());
290        if !rtol.is_finite() || rtol < 0.0 {
291            return Err(QrError::InvalidRtol(rtol));
292        }
293
294        match r_native.dtype() {
295            DType::F64 => {
296                let values = native_tensor_primal_to_dense_f64_col_major(&r_native)
297                    .map_err(QrError::ComputationError)?;
298                compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
299            }
300            DType::C64 => {
301                let values = native_tensor_primal_to_dense_c64_col_major(&r_native)
302                    .map_err(QrError::ComputationError)?;
303                compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
304            }
305            other => {
306                return Err(QrError::ComputationError(anyhow::anyhow!(
307                    "native QR returned unsupported scalar type {other:?}"
308                )));
309            }
310        }
311    } else {
312        k
313    };
314    if r < k {
315        match q_native.dtype() {
316            DType::F64 => {
317                let q_values = native_tensor_primal_to_dense_f64_col_major(&q_native)
318                    .map_err(QrError::ComputationError)?;
319                let r_values = native_tensor_primal_to_dense_f64_col_major(&r_native)
320                    .map_err(QrError::ComputationError)?;
321                q_native =
322                    truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
323                r_native =
324                    truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
325            }
326            DType::C64 => {
327                let q_values = native_tensor_primal_to_dense_c64_col_major(&q_native)
328                    .map_err(QrError::ComputationError)?;
329                let r_values = native_tensor_primal_to_dense_c64_col_major(&r_native)
330                    .map_err(QrError::ComputationError)?;
331                q_native =
332                    truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
333                r_native =
334                    truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
335            }
336            other => {
337                return Err(QrError::ComputationError(anyhow::anyhow!(
338                    "native QR returned unsupported scalar type {other:?}"
339                )));
340            }
341        }
342    }
343
344    let bond_index = DynIndex::new_bond(r)
345        .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
346        .map_err(QrError::ComputationError)?;
347
348    let mut q_indices = left_indices.clone();
349    q_indices.push(bond_index.clone());
350    let q_dims: Vec<usize> = q_indices.iter().map(|idx| idx.dim).collect();
351    let q_reshaped = reshape_col_major_native_tensor(&q_native, &q_dims).map_err(|e| {
352        QrError::ComputationError(anyhow::anyhow!("native QR Q reshape failed: {e}"))
353    })?;
354    let q = TensorDynLen::from_native(q_indices, q_reshaped).map_err(QrError::ComputationError)?;
355
356    let mut r_indices = vec![bond_index.clone()];
357    r_indices.extend_from_slice(&right_indices);
358    let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
359    let r_reshaped = reshape_col_major_native_tensor(&r_native, &r_dims).map_err(|e| {
360        QrError::ComputationError(anyhow::anyhow!("native QR R reshape failed: {e}"))
361    })?;
362    let r = TensorDynLen::from_native(r_indices, r_reshaped).map_err(QrError::ComputationError)?;
363
364    Ok((q, r))
365}
366
367#[cfg(test)]
368mod tests;