Skip to main content

tenferro_tensor/validate/
mod.rs

1//! Singular-matrix validation helpers shared across backends and exec layers.
2//!
3//! # Examples
4//!
5//! ```ignore
6//! use tenferro_tensor::validate::validate_nonsingular_u;
7//! use tenferro_tensor::{Tensor, TypedTensor};
8//!
9//! let t = Tensor::F64(TypedTensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]));
10//! assert!(validate_nonsingular_u(&t).is_ok());
11//! ```
12
13use num_complex::{Complex32, Complex64};
14
15use crate::{Error, Result, Tensor, TypedTensor};
16
/// Predicate used to reject unusable pivots found on a factor's diagonal.
///
/// This module implements it for `f32`, `f64`, `Complex32`, and `Complex64`.
/// An entry counts as singular when it is zero or when it is not finite (NaN or
/// infinite); a complex entry is also rejected when either its real or its
/// imaginary part is non-finite.
pub trait DiagSingularity {
    /// Reports whether `self` is zero or non-finite, i.e. unusable as a divisor.
    fn is_singular_or_nonfinite(&self) -> bool;
}
26
27macro_rules! impl_diag_singularity_float {
28    ($($t:ty),* $(,)?) => {
29        $(
30            impl DiagSingularity for $t {
31                fn is_singular_or_nonfinite(&self) -> bool {
32                    !self.is_finite() || *self == 0.0
33                }
34            }
35        )*
36    };
37}
38
39impl_diag_singularity_float!(f64, f32);
40
41macro_rules! impl_diag_singularity_complex {
42    ($($t:ty),* $(,)?) => {
43        $(
44            impl DiagSingularity for $t {
45                fn is_singular_or_nonfinite(&self) -> bool {
46                    !self.re.is_finite() || !self.im.is_finite() || self.norm_sqr() == 0.0
47                }
48            }
49        )*
50    };
51}
52
53impl_diag_singularity_complex!(Complex64, Complex32);
54
55/// Checks that every diagonal element of a (possibly batched) upper-triangular
56/// factor is non-singular and finite.
57///
58/// Iterates over all batch slices and inspects the diagonal entries
59/// `data[i + i * rows]` for `i` in `0..min(rows, cols)`. Returns
60/// [`Error::BackendFailure`] with `op: "solve"` on the first offending entry.
61///
62/// # Examples
63///
64/// ```ignore
65/// use tenferro_tensor::validate::check_singular_diagonal;
66/// use tenferro_tensor::TypedTensor;
67///
68/// let t = TypedTensor::from_vec(vec![2, 2], vec![1.0f32, 0.0, 0.0, 2.0]);
69/// assert!(check_singular_diagonal(&t).is_ok());
70/// ```
71pub fn check_singular_diagonal<T: DiagSingularity + Copy + std::fmt::Debug>(
72    t: &TypedTensor<T>,
73) -> Result<()> {
74    let rows = t.shape[0];
75    let cols = t.shape[1];
76    let n = rows.min(cols);
77    let batch_total: usize = t.shape[2..].iter().product();
78    let batch_total = batch_total.max(1);
79    let slice_size = rows * cols;
80    for batch_idx in 0..batch_total {
81        let batch = &t.host_data()[batch_idx * slice_size..(batch_idx + 1) * slice_size];
82        for i in 0..n {
83            let diag = batch[i + i * rows];
84            if diag.is_singular_or_nonfinite() {
85                return Err(Error::BackendFailure {
86                    op: "solve",
87                    message: if batch_total > 1 {
88                        format!(
89                            "singular matrix: non-finite or zero diagonal at batch {}, position [{},{}] = {:?}",
90                            batch_idx, i, i, diag
91                        )
92                    } else {
93                        format!(
94                            "singular matrix: non-finite or zero diagonal at position [{},{}] = {:?}",
95                            i, i, diag
96                        )
97                    },
98                });
99            }
100        }
101    }
102    Ok(())
103}
104
105/// Validates that the upper-triangular factor `u` of a matrix decomposition
106/// has no singular (zero) or non-finite diagonal entries.
107///
108/// Dispatches to [`check_singular_diagonal`] after unpacking the concrete
109/// tensor variant. Returns `Ok(())` when all diagonal entries are valid.
110///
111/// # Examples
112///
113/// ```ignore
114/// use tenferro_tensor::validate::validate_nonsingular_u;
115/// use tenferro_tensor::{Tensor, TypedTensor};
116///
117/// let t = Tensor::F64(TypedTensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]));
118/// assert!(validate_nonsingular_u(&t).is_ok());
119/// ```
120pub fn validate_nonsingular_u(u: &Tensor) -> Result<()> {
121    match u {
122        Tensor::F64(t) => check_singular_diagonal(t),
123        Tensor::F32(t) => check_singular_diagonal(t),
124        Tensor::C64(t) => check_singular_diagonal(t),
125        Tensor::C32(t) => check_singular_diagonal(t),
126    }
127}
128
129#[cfg(test)]
130mod tests;