Skip to main content

tenferro_tensor/validate/
mod.rs

1//! Validation helpers shared across backends and exec layers.
2//!
3//! # Examples
4//!
5//! ```rust
6//! use tenferro_tensor::validate::validate_nonsingular_u;
7//! use tenferro_tensor::{Tensor, TypedTensor};
8//!
9//! let t = Tensor::F64(TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]).unwrap());
10//! assert!(validate_nonsingular_u(&t).is_ok());
11//! ```
12
13use num_complex::{Complex32, Complex64};
14
15use crate::{DType, Error, Result, Tensor, TypedTensor};
16
17/// Promote two dtypes according to tenferro's public dtype-promotion lattice.
18///
19/// # Examples
20///
21/// ```rust
22/// use tenferro_tensor::validate::promote_dtype;
23/// use tenferro_tensor::DType;
24///
25/// assert_eq!(promote_dtype(DType::I32, DType::F32), DType::F64);
26/// ```
27pub fn promote_dtype(lhs: DType, rhs: DType) -> DType {
28    use DType::*;
29    match (lhs, rhs) {
30        (Bool, Bool) => Bool,
31        (Bool, other) | (other, Bool) => other,
32        (I32, I32) => I32,
33        (I32, I64) | (I64, I32) | (I64, I64) => I64,
34        (I32 | I64, F32 | F64) | (F32 | F64, I32 | I64) => F64,
35        (I32 | I64, C32 | C64) | (C32 | C64, I32 | I64) => C64,
36        (F32, F32) => F32,
37        (F32, F64) | (F64, F32) | (F64, F64) => F64,
38        (F32, C32) | (C32, F32) | (C32, C32) => C32,
39        (F32, C64) | (C64, F32) => C64,
40        (F64, C32 | C64) | (C32 | C64, F64) => C64,
41        (C32, C64) | (C64, C32) | (C64, C64) => C64,
42    }
43}
44
45/// Return whether public `convert` may change `from` into `to`.
46///
47/// Checked conversion follows the same dtype lattice as implicit promotion.
48/// Use explicit `cast` for value-changing projections outside this lattice.
49///
50/// # Examples
51///
52/// ```rust
53/// use tenferro_tensor::validate::can_convert_dtype;
54/// use tenferro_tensor::DType;
55///
56/// assert!(can_convert_dtype(DType::F32, DType::F64));
57/// assert!(!can_convert_dtype(DType::F64, DType::I32));
58/// ```
59pub fn can_convert_dtype(from: DType, to: DType) -> bool {
60    promote_dtype(from, to) == to
61}
62
63/// Validate a public checked dtype conversion.
64///
65/// # Examples
66///
67/// ```rust
68/// use tenferro_tensor::validate::validate_convert_dtype;
69/// use tenferro_tensor::DType;
70///
71/// assert!(validate_convert_dtype("convert", DType::F32, DType::F64).is_ok());
72/// assert!(validate_convert_dtype("convert", DType::C64, DType::F64).is_err());
73/// ```
74pub fn validate_convert_dtype(op: &'static str, from: DType, to: DType) -> Result<()> {
75    if can_convert_dtype(from, to) {
76        return Ok(());
77    }
78
79    Err(Error::UnsupportedDTypeConversion {
80        op,
81        from,
82        to,
83        message: "checked convert only accepts conversions allowed by dtype promotion; use explicit cast for lossy dtype projection".to_string(),
84    })
85}
86
/// Detects diagonal entries that make a triangular factor unusable.
///
/// This module implements it for `f32`, `f64`, `Complex32` and `Complex64`.
/// An entry is flagged when it is zero or non-finite (NaN or ±infinity). For
/// complex types, an entry is also flagged when either component is
/// non-finite. See each implementation for its exact zero test.
pub trait DiagSingularity {
    /// Returns `true` if the value is singular or non-finite.
    fn is_singular_or_nonfinite(&self) -> bool;
}

macro_rules! impl_diag_singularity_float {
    ($($t:ty),* $(,)?) => {
        $(
            impl DiagSingularity for $t {
                fn is_singular_or_nonfinite(&self) -> bool {
                    // `-0.0 == 0.0` holds, so both signed zeros are caught;
                    // NaN and ±inf fail `is_finite`.
                    let value = *self;
                    value == 0.0 || !value.is_finite()
                }
            }
        )*
    };
}

impl_diag_singularity_float!(f32, f64);
110
111macro_rules! impl_diag_singularity_complex {
112    ($($t:ty),* $(,)?) => {
113        $(
114            impl DiagSingularity for $t {
115                fn is_singular_or_nonfinite(&self) -> bool {
116                    !self.re.is_finite() || !self.im.is_finite() || self.norm_sqr() == 0.0
117                }
118            }
119        )*
120    };
121}
122
123impl_diag_singularity_complex!(Complex64, Complex32);
124
125/// Checks that every diagonal element of a (possibly batched) upper-triangular
126/// factor is non-singular and finite.
127///
128/// Iterates over all batch slices and inspects the diagonal entries
129/// `data[i + i * rows]` for `i` in `0..min(rows, cols)`. Returns
130/// [`Error::BackendFailure`] with `op: "solve"` on the first offending entry,
131/// or [`Error::RankMismatch`] when `t` has rank less than two.
132///
133/// # Examples
134///
135/// ```rust
136/// use tenferro_tensor::validate::check_singular_diagonal;
137/// use tenferro_tensor::TypedTensor;
138///
139/// let t = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0f32, 0.0, 0.0, 2.0]).unwrap();
140/// assert!(check_singular_diagonal(&t).is_ok());
141/// ```
142pub fn check_singular_diagonal<T: DiagSingularity + Copy + std::fmt::Debug>(
143    t: &TypedTensor<T>,
144) -> Result<()> {
145    if t.shape().len() < 2 {
146        return Err(Error::RankMismatch {
147            op: "solve",
148            expected: 2,
149            actual: t.shape().len(),
150        });
151    }
152    let rows = t.shape()[0];
153    let cols = t.shape()[1];
154    let n = rows.min(cols);
155    let batch_total: usize = t.shape()[2..].iter().product();
156    let batch_total = batch_total.max(1);
157    let slice_size = rows * cols;
158    let data = t.host_data()?;
159    for batch_idx in 0..batch_total {
160        let batch = &data[batch_idx * slice_size..(batch_idx + 1) * slice_size];
161        for i in 0..n {
162            let diag = batch[i + i * rows];
163            if diag.is_singular_or_nonfinite() {
164                return Err(Error::backend_failure(
165                    "solve",
166                    if batch_total > 1 {
167                        format!(
168                            "singular matrix: non-finite or zero diagonal at batch {}, position [{},{}] = {:?}",
169                            batch_idx, i, i, diag
170                        )
171                    } else {
172                        format!(
173                            "singular matrix: non-finite or zero diagonal at position [{},{}] = {:?}",
174                            i, i, diag
175                        )
176                    },
177                ));
178            }
179        }
180    }
181    Ok(())
182}
183
184/// Validates that the upper-triangular factor `u` of a matrix decomposition
185/// has no singular (zero) or non-finite diagonal entries.
186///
187/// Dispatches to [`check_singular_diagonal`] after unpacking the concrete
188/// tensor variant. Returns `Ok(())` when all diagonal entries are valid.
189///
190/// # Examples
191///
192/// ```rust
193/// use tenferro_tensor::validate::validate_nonsingular_u;
194/// use tenferro_tensor::{Tensor, TypedTensor};
195///
196/// let t = Tensor::F64(TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]).unwrap());
197/// assert!(validate_nonsingular_u(&t).is_ok());
198/// ```
199pub fn validate_nonsingular_u(u: &Tensor) -> Result<()> {
200    match u {
201        Tensor::F64(t) => check_singular_diagonal(t),
202        Tensor::F32(t) => check_singular_diagonal(t),
203        Tensor::C64(t) => check_singular_diagonal(t),
204        Tensor::C32(t) => check_singular_diagonal(t),
205        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(Error::backend_failure(
206            "solve",
207            format!("unsupported dtype {:?}", u.dtype()),
208        )),
209    }
210}
211
212#[cfg(test)]
213mod tests;