tensor4all_core/defaults/
qr.rs

1//! QR decomposition for tensors.
2//!
3//! This module works with concrete types (`DynIndex`, `TensorDynLen`) only.
4
5use crate::defaults::DynIndex;
6use crate::global_default::GlobalDefault;
7use crate::truncation::TruncationParams;
8use crate::{unfold_split, TensorDynLen};
9use num_complex::ComplexFloat;
10use tensor4all_tensorbackend::{
11    dense_native_tensor_from_col_major, native_tensor_primal_to_dense_c64_col_major,
12    native_tensor_primal_to_dense_f64_col_major, qr_native_tensor, reshape_col_major_native_tensor,
13    TensorElement,
14};
15use thiserror::Error;
16
/// Error type for QR operations in tensor4all-linalg.
#[derive(Debug, Error)]
pub enum QrError {
    /// QR computation failed (unfolding, the native QR call, reshaping, or
    /// tensor reconstruction). Wraps the underlying `anyhow::Error`.
    #[error("QR computation failed: {0}")]
    ComputationError(#[from] anyhow::Error),
    /// Invalid relative tolerance value (must be finite and non-negative).
    /// Carries the offending rtol value.
    #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
    InvalidRtol(f64),
}
27
/// Options for QR decomposition with truncation control.
///
/// With `Default`, `qr_with` resolves the effective rtol from the global
/// default (see `default_qr_rtol`).
#[derive(Debug, Clone, Copy, Default)]
pub struct QrOptions {
    /// Truncation parameters (only rtol is used by QR).
    pub truncation: TruncationParams,
}
34
35impl QrOptions {
36    /// Create new QR options with the specified rtol.
37    pub fn with_rtol(rtol: f64) -> Self {
38        Self {
39            truncation: TruncationParams::new().with_rtol(rtol),
40        }
41    }
42
43    /// Get rtol from options (for backwards compatibility).
44    pub fn rtol(&self) -> Option<f64> {
45        self.truncation.rtol
46    }
47}
48
// Process-wide default rtol for QR truncation, stored in the unified
// GlobalDefault type. Read via default_qr_rtol() and updated via
// set_default_qr_rtol().
// Default value: 1e-15 (very strict, near machine precision).
static DEFAULT_QR_RTOL: GlobalDefault = GlobalDefault::new(1e-15);
52
53/// Get the global default rtol for QR truncation.
54///
55/// The default value is 1e-15 (very strict, near machine precision).
56pub fn default_qr_rtol() -> f64 {
57    DEFAULT_QR_RTOL.get()
58}
59
60/// Set the global default rtol for QR truncation.
61///
62/// # Arguments
63/// * `rtol` - Relative tolerance (must be finite and non-negative)
64///
65/// # Errors
66/// Returns `QrError::InvalidRtol` if rtol is not finite or is negative.
67pub fn set_default_qr_rtol(rtol: f64) -> Result<(), QrError> {
68    DEFAULT_QR_RTOL
69        .set(rtol)
70        .map_err(|e| QrError::InvalidRtol(e.0))
71}
72
73fn compute_retained_rank_qr_from_dense<T>(
74    r_full: &[T],
75    k: usize,
76    n: usize,
77    rtol: f64,
78) -> Result<usize, QrError>
79where
80    T: ComplexFloat,
81    <T as ComplexFloat>::Real: Into<f64>,
82{
83    if k == 0 || n == 0 {
84        return Ok(1);
85    }
86
87    let max_diag = k.min(n);
88
89    // Compute row norms of R (upper triangular: row i has entries from column i..n).
90    // Use relative comparison against the maximum row norm, matching
91    // compute_retained_rank_qr. The previous implementation compared diagonal
92    // elements absolutely and broke at the first small value, which is incorrect
93    // for non-pivoted QR where diagonal elements are not necessarily in
94    // decreasing order.
95    let row_norms: Vec<f64> = (0..max_diag)
96        .map(|i| {
97            let mut norm_sq: f64 = 0.0;
98            for j in i..n {
99                let val: f64 = r_full[i + j * k].abs().into();
100                norm_sq += val * val;
101            }
102            norm_sq.sqrt()
103        })
104        .collect();
105
106    let max_row_norm = row_norms.iter().cloned().fold(0.0_f64, f64::max);
107    if max_row_norm == 0.0 {
108        return Ok(1);
109    }
110
111    let threshold = rtol * max_row_norm;
112    let r = row_norms.iter().filter(|&&norm| norm >= threshold).count();
113    Ok(r.max(1))
114}
115
116fn truncate_matrix_cols<T: TensorElement>(
117    data: &[T],
118    rows: usize,
119    keep_cols: usize,
120) -> anyhow::Result<tenferro::Tensor> {
121    dense_native_tensor_from_col_major(&data[..rows * keep_cols], &[rows, keep_cols])
122}
123
124fn truncate_matrix_rows<T: TensorElement>(
125    data: &[T],
126    rows: usize,
127    cols: usize,
128    keep_rows: usize,
129) -> anyhow::Result<tenferro::Tensor> {
130    let mut truncated = Vec::with_capacity(keep_rows * cols);
131    for col in 0..cols {
132        let start = col * rows;
133        truncated.extend_from_slice(&data[start..start + keep_rows]);
134    }
135    dense_native_tensor_from_col_major(&truncated, &[keep_rows, cols])
136}
137
138/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
139///
140/// This function uses the global default rtol for truncation.
141/// See `qr_with` for per-call rtol control.
142///
143/// This function computes the thin QR decomposition, where for an unfolded matrix A (m×n),
144/// we return Q (m×k) and R (k×n) with k = min(m, n).
145///
146/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
147/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
148///
149/// Truncation is performed based on R's row norms: rows whose norm is below
150/// `rtol * max_row_norm` are discarded.
151///
152/// For the mathematical convention:
153/// \[ A = Q * R \]
154/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
155///
156/// # Arguments
157/// * `t` - Input tensor with DenseF64 or DenseC64 storage
158/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
159///
160/// # Returns
161/// A tuple `(Q, R)` where:
162/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
163/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
164///   where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
165///
166/// # Errors
167/// Returns `QrError` if:
168/// - The tensor rank is < 2
169/// - Storage is not DenseF64 or DenseC64
170/// - `left_inds` is empty or contains all indices
171/// - `left_inds` contains indices not in the tensor or duplicates
172/// - The QR computation fails
173///
174/// # Examples
175///
176/// ```
177/// use tensor4all_core::{TensorDynLen, DynIndex, qr};
178///
179/// // Create a 4x3 matrix
180/// let i = DynIndex::new_dyn(4);
181/// let j = DynIndex::new_dyn(3);
182/// // Identity-like data (4x3 column-major)
183/// let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
184/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
185///
186/// let (q, r) = qr::<f64>(&t, &[i.clone()]).unwrap();
187///
188/// // Q has shape (4, bond) and R has shape (bond, 3)
189/// assert_eq!(q.dims()[0], 4);
190/// assert_eq!(r.dims()[r.dims().len() - 1], 3);
191/// ```
192pub fn qr<T>(
193    t: &TensorDynLen,
194    left_inds: &[DynIndex],
195) -> Result<(TensorDynLen, TensorDynLen), QrError> {
196    qr_with::<T>(t, left_inds, &QrOptions::default())
197}
198
/// Compute QR decomposition of a tensor with arbitrary rank, returning (Q, R).
///
/// This function allows per-call control of the truncation tolerance via `QrOptions`.
/// If `options.rtol` is `None`, uses the global default rtol.
///
/// This function computes the thin QR decomposition: for an unfolded matrix A (m×n),
/// the full factors are Q (m×k) and R (k×n) with k = min(m, n); truncation may then
/// reduce the bond dimension to r ≤ k (see below).
///
/// The input tensor can have any rank >= 2, and indices are split into left and right groups.
/// The tensor is unfolded into a matrix by grouping left indices as rows and right indices as columns.
///
/// Truncation is performed based on R's row norms: rows whose norm is below
/// `rtol * max_row_norm` are discarded.
///
/// For the mathematical convention:
/// \[ A = Q * R \]
/// where Q is orthogonal (or unitary for complex) and R is upper triangular.
///
/// # Arguments
/// * `t` - Input tensor with DenseF64 or DenseC64 storage
/// * `left_inds` - Indices to place on the left (row) side of the unfolded matrix
/// * `options` - QR options including rtol for truncation control
///
/// # Returns
/// A tuple `(Q, R)` where:
/// - `Q` is a tensor with indices `[left_inds..., bond_index]` and dimensions `[left_dims..., r]`
/// - `R` is a tensor with indices `[bond_index, right_inds...]` and dimensions `[r, right_dims...]`
///   where `r` is the retained rank (≤ min(m, n)) determined by rtol truncation.
///
/// # Errors
/// Returns `QrError` if:
/// - The tensor rank is < 2
/// - Storage is not DenseF64 or DenseC64
/// - `left_inds` is empty or contains all indices
/// - `left_inds` contains indices not in the tensor or duplicates
/// - The QR computation fails
/// - `options.rtol` is invalid (not finite or negative)
pub fn qr_with<T>(
    t: &TensorDynLen,
    left_inds: &[DynIndex],
    options: &QrOptions,
) -> Result<(TensorDynLen, TensorDynLen), QrError> {
    // Determine rtol to use: the per-call value, falling back to the global
    // default. Reject NaN/inf/negative tolerances up front.
    let rtol = options.truncation.effective_rtol(default_qr_rtol());
    if !rtol.is_finite() || rtol < 0.0 {
        return Err(QrError::InvalidRtol(rtol));
    }

    // Unfold tensor into a native rank-2 tensor. m = product of left dims,
    // n = product of right dims.
    let (matrix_native, _, m, n, left_indices, right_indices) = unfold_split(t, left_inds)
        .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e))
        .map_err(QrError::ComputationError)?;
    // Thin-QR inner dimension before any truncation.
    let k = m.min(n);
    let (mut q_native, mut r_native) =
        qr_native_tensor(&matrix_native).map_err(QrError::ComputationError)?;
    // Decide the retained rank r from R's row norms, dispatching on the
    // scalar type the native QR actually produced.
    let r = match r_native.scalar_type() {
        tenferro::ScalarType::F64 => {
            let values = native_tensor_primal_to_dense_f64_col_major(&r_native)
                .map_err(QrError::ComputationError)?;
            compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
        }
        tenferro::ScalarType::C64 => {
            let values = native_tensor_primal_to_dense_c64_col_major(&r_native)
                .map_err(QrError::ComputationError)?;
            compute_retained_rank_qr_from_dense(&values, k, n, rtol)?
        }
        other => {
            return Err(QrError::ComputationError(anyhow::anyhow!(
                "native QR returned unsupported scalar type {other:?}"
            )));
        }
    };
    // If the rank was reduced, drop trailing columns of Q and trailing rows
    // of R. NOTE(review): compute_retained_rank counts rows anywhere above
    // the threshold but truncation keeps the *leading* r rows; for
    // non-pivoted QR these can differ in pathological cases — confirm this
    // matches the intended convention.
    if r < k {
        match q_native.scalar_type() {
            tenferro::ScalarType::F64 => {
                let q_values = native_tensor_primal_to_dense_f64_col_major(&q_native)
                    .map_err(QrError::ComputationError)?;
                let r_values = native_tensor_primal_to_dense_f64_col_major(&r_native)
                    .map_err(QrError::ComputationError)?;
                q_native =
                    truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
                r_native =
                    truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
            }
            tenferro::ScalarType::C64 => {
                let q_values = native_tensor_primal_to_dense_c64_col_major(&q_native)
                    .map_err(QrError::ComputationError)?;
                let r_values = native_tensor_primal_to_dense_c64_col_major(&r_native)
                    .map_err(QrError::ComputationError)?;
                q_native =
                    truncate_matrix_cols(&q_values, m, r).map_err(QrError::ComputationError)?;
                r_native =
                    truncate_matrix_rows(&r_values, k, n, r).map_err(QrError::ComputationError)?;
            }
            other => {
                return Err(QrError::ComputationError(anyhow::anyhow!(
                    "native QR returned unsupported scalar type {other:?}"
                )));
            }
        }
    }

    // New bond (Link) index of dimension r connecting Q and R.
    let bond_index = DynIndex::new_bond(r)
        .map_err(|e| anyhow::anyhow!("Failed to create Link index: {:?}", e))
        .map_err(QrError::ComputationError)?;

    // Refold Q back to [left_dims..., r] and wrap it as a tensor.
    let mut q_indices = left_indices.clone();
    q_indices.push(bond_index.clone());
    let q_dims: Vec<usize> = q_indices.iter().map(|idx| idx.dim).collect();
    let q_reshaped = reshape_col_major_native_tensor(&q_native, &q_dims).map_err(|e| {
        QrError::ComputationError(anyhow::anyhow!("native QR Q reshape failed: {e}"))
    })?;
    let q = TensorDynLen::from_native(q_indices, q_reshaped).map_err(QrError::ComputationError)?;

    // Refold R back to [r, right_dims...] and wrap it as a tensor.
    let mut r_indices = vec![bond_index.clone()];
    r_indices.extend_from_slice(&right_indices);
    let r_dims: Vec<usize> = r_indices.iter().map(|idx| idx.dim).collect();
    let r_reshaped = reshape_col_major_native_tensor(&r_native, &r_dims).map_err(|e| {
        QrError::ComputationError(anyhow::anyhow!("native QR R reshape failed: {e}"))
    })?;
    let r = TensorDynLen::from_native(r_indices, r_reshaped).map_err(QrError::ComputationError)?;

    Ok((q, r))
}
323
// Unit tests live in the sibling `tests` module file; compiled only in
// test builds.
#[cfg(test)]
mod tests;