tenferro_tensor/validate/mod.rs
1//! Singular-matrix validation helpers shared across backends and exec layers.
2//!
3//! # Examples
4//!
5//! ```ignore
6//! use tenferro_tensor::validate::validate_nonsingular_u;
7//! use tenferro_tensor::{Tensor, TypedTensor};
8//!
9//! let t = Tensor::F64(TypedTensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]));
10//! assert!(validate_nonsingular_u(&t).is_ok());
11//! ```
12
13use num_complex::{Complex32, Complex64};
14
15use crate::{Error, Result, Tensor, TypedTensor};
16
/// Classifies a scalar as unusable as a pivot / diagonal entry.
///
/// Implementations exist for `f32`, `f64`, `Complex32`, and `Complex64`.
/// A value is rejected when it is zero or when it is not finite (NaN or
/// infinite); complex values are checked for finiteness component-wise.
pub trait DiagSingularity {
    /// Reports whether `self` is zero, NaN, or infinite.
    fn is_singular_or_nonfinite(&self) -> bool;
}

/// Implements [`DiagSingularity`] for primitive floating-point types.
///
/// A float is acceptable only when it is both finite and non-zero; the
/// predicate below is the negation of that condition (`-0.0 == 0.0`, so
/// negative zero is rejected as well).
macro_rules! impl_diag_singularity_float {
    ($($float:ty),* $(,)?) => {
        $(
            impl DiagSingularity for $float {
                fn is_singular_or_nonfinite(&self) -> bool {
                    let acceptable = self.is_finite() && *self != 0.0;
                    !acceptable
                }
            }
        )*
    };
}

impl_diag_singularity_float!(f64, f32);
40
41macro_rules! impl_diag_singularity_complex {
42 ($($t:ty),* $(,)?) => {
43 $(
44 impl DiagSingularity for $t {
45 fn is_singular_or_nonfinite(&self) -> bool {
46 !self.re.is_finite() || !self.im.is_finite() || self.norm_sqr() == 0.0
47 }
48 }
49 )*
50 };
51}
52
53impl_diag_singularity_complex!(Complex64, Complex32);
54
55/// Checks that every diagonal element of a (possibly batched) upper-triangular
56/// factor is non-singular and finite.
57///
58/// Iterates over all batch slices and inspects the diagonal entries
59/// `data[i + i * rows]` for `i` in `0..min(rows, cols)`. Returns
60/// [`Error::BackendFailure`] with `op: "solve"` on the first offending entry.
61///
62/// # Examples
63///
64/// ```ignore
65/// use tenferro_tensor::validate::check_singular_diagonal;
66/// use tenferro_tensor::TypedTensor;
67///
68/// let t = TypedTensor::from_vec(vec![2, 2], vec![1.0f32, 0.0, 0.0, 2.0]);
69/// assert!(check_singular_diagonal(&t).is_ok());
70/// ```
71pub fn check_singular_diagonal<T: DiagSingularity + Copy + std::fmt::Debug>(
72 t: &TypedTensor<T>,
73) -> Result<()> {
74 let rows = t.shape[0];
75 let cols = t.shape[1];
76 let n = rows.min(cols);
77 let batch_total: usize = t.shape[2..].iter().product();
78 let batch_total = batch_total.max(1);
79 let slice_size = rows * cols;
80 for batch_idx in 0..batch_total {
81 let batch = &t.host_data()[batch_idx * slice_size..(batch_idx + 1) * slice_size];
82 for i in 0..n {
83 let diag = batch[i + i * rows];
84 if diag.is_singular_or_nonfinite() {
85 return Err(Error::BackendFailure {
86 op: "solve",
87 message: if batch_total > 1 {
88 format!(
89 "singular matrix: non-finite or zero diagonal at batch {}, position [{},{}] = {:?}",
90 batch_idx, i, i, diag
91 )
92 } else {
93 format!(
94 "singular matrix: non-finite or zero diagonal at position [{},{}] = {:?}",
95 i, i, diag
96 )
97 },
98 });
99 }
100 }
101 }
102 Ok(())
103}
104
105/// Validates that the upper-triangular factor `u` of a matrix decomposition
106/// has no singular (zero) or non-finite diagonal entries.
107///
108/// Dispatches to [`check_singular_diagonal`] after unpacking the concrete
109/// tensor variant. Returns `Ok(())` when all diagonal entries are valid.
110///
111/// # Examples
112///
113/// ```ignore
114/// use tenferro_tensor::validate::validate_nonsingular_u;
115/// use tenferro_tensor::{Tensor, TypedTensor};
116///
117/// let t = Tensor::F64(TypedTensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]));
118/// assert!(validate_nonsingular_u(&t).is_ok());
119/// ```
120pub fn validate_nonsingular_u(u: &Tensor) -> Result<()> {
121 match u {
122 Tensor::F64(t) => check_singular_diagonal(t),
123 Tensor::F32(t) => check_singular_diagonal(t),
124 Tensor::C64(t) => check_singular_diagonal(t),
125 Tensor::C32(t) => check_singular_diagonal(t),
126 }
127}
128
129#[cfg(test)]
130mod tests;