tenferro_tensor/validate/mod.rs
1//! Validation helpers shared across backends and exec layers.
2//!
3//! # Examples
4//!
5//! ```rust
6//! use tenferro_tensor::validate::validate_nonsingular_u;
7//! use tenferro_tensor::{Tensor, TypedTensor};
8//!
9//! let t = Tensor::F64(TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]).unwrap());
10//! assert!(validate_nonsingular_u(&t).is_ok());
11//! ```
12
13use num_complex::{Complex32, Complex64};
14
15use crate::{DType, Error, Result, Tensor, TypedTensor};
16
17/// Promote two dtypes according to tenferro's public dtype-promotion lattice.
18///
19/// # Examples
20///
21/// ```rust
22/// use tenferro_tensor::validate::promote_dtype;
23/// use tenferro_tensor::DType;
24///
25/// assert_eq!(promote_dtype(DType::I32, DType::F32), DType::F64);
26/// ```
27pub fn promote_dtype(lhs: DType, rhs: DType) -> DType {
28 use DType::*;
29 match (lhs, rhs) {
30 (Bool, Bool) => Bool,
31 (Bool, other) | (other, Bool) => other,
32 (I32, I32) => I32,
33 (I32, I64) | (I64, I32) | (I64, I64) => I64,
34 (I32 | I64, F32 | F64) | (F32 | F64, I32 | I64) => F64,
35 (I32 | I64, C32 | C64) | (C32 | C64, I32 | I64) => C64,
36 (F32, F32) => F32,
37 (F32, F64) | (F64, F32) | (F64, F64) => F64,
38 (F32, C32) | (C32, F32) | (C32, C32) => C32,
39 (F32, C64) | (C64, F32) => C64,
40 (F64, C32 | C64) | (C32 | C64, F64) => C64,
41 (C32, C64) | (C64, C32) | (C64, C64) => C64,
42 }
43}
44
45/// Return whether public `convert` may change `from` into `to`.
46///
47/// Checked conversion follows the same dtype lattice as implicit promotion.
48/// Use explicit `cast` for value-changing projections outside this lattice.
49///
50/// # Examples
51///
52/// ```rust
53/// use tenferro_tensor::validate::can_convert_dtype;
54/// use tenferro_tensor::DType;
55///
56/// assert!(can_convert_dtype(DType::F32, DType::F64));
57/// assert!(!can_convert_dtype(DType::F64, DType::I32));
58/// ```
59pub fn can_convert_dtype(from: DType, to: DType) -> bool {
60 promote_dtype(from, to) == to
61}
62
63/// Validate a public checked dtype conversion.
64///
65/// # Examples
66///
67/// ```rust
68/// use tenferro_tensor::validate::validate_convert_dtype;
69/// use tenferro_tensor::DType;
70///
71/// assert!(validate_convert_dtype("convert", DType::F32, DType::F64).is_ok());
72/// assert!(validate_convert_dtype("convert", DType::C64, DType::F64).is_err());
73/// ```
74pub fn validate_convert_dtype(op: &'static str, from: DType, to: DType) -> Result<()> {
75 if can_convert_dtype(from, to) {
76 return Ok(());
77 }
78
79 Err(Error::UnsupportedDTypeConversion {
80 op,
81 from,
82 to,
83 message: "checked convert only accepts conversions allowed by dtype promotion; use explicit cast for lossy dtype projection".to_string(),
84 })
85}
86
/// Classifies a scalar diagonal entry as singular or non-finite.
///
/// Implemented for `f32`, `f64`, `Complex32`, and `Complex64`. A value is
/// treated as singular when it is zero or not finite (NaN or ±∞); a complex
/// value is non-finite when either of its components is.
pub trait DiagSingularity {
    /// Returns `true` when the value is zero or non-finite.
    fn is_singular_or_nonfinite(&self) -> bool;
}

// Implements `DiagSingularity` for primitive floating-point types. Both signed
// zeros compare equal to `0.0`, so `-0.0` is reported as singular as well.
macro_rules! impl_diag_singularity_float {
    ($($float:ty),* $(,)?) => {
        $(
            impl DiagSingularity for $float {
                fn is_singular_or_nonfinite(&self) -> bool {
                    let value = *self;
                    value == 0.0 || !value.is_finite()
                }
            }
        )*
    };
}

impl_diag_singularity_float!(f32, f64);
110
111macro_rules! impl_diag_singularity_complex {
112 ($($t:ty),* $(,)?) => {
113 $(
114 impl DiagSingularity for $t {
115 fn is_singular_or_nonfinite(&self) -> bool {
116 !self.re.is_finite() || !self.im.is_finite() || self.norm_sqr() == 0.0
117 }
118 }
119 )*
120 };
121}
122
123impl_diag_singularity_complex!(Complex64, Complex32);
124
125/// Checks that every diagonal element of a (possibly batched) upper-triangular
126/// factor is non-singular and finite.
127///
128/// Iterates over all batch slices and inspects the diagonal entries
129/// `data[i + i * rows]` for `i` in `0..min(rows, cols)`. Returns
130/// [`Error::BackendFailure`] with `op: "solve"` on the first offending entry,
131/// or [`Error::RankMismatch`] when `t` has rank less than two.
132///
133/// # Examples
134///
135/// ```rust
136/// use tenferro_tensor::validate::check_singular_diagonal;
137/// use tenferro_tensor::TypedTensor;
138///
139/// let t = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0f32, 0.0, 0.0, 2.0]).unwrap();
140/// assert!(check_singular_diagonal(&t).is_ok());
141/// ```
142pub fn check_singular_diagonal<T: DiagSingularity + Copy + std::fmt::Debug>(
143 t: &TypedTensor<T>,
144) -> Result<()> {
145 if t.shape().len() < 2 {
146 return Err(Error::RankMismatch {
147 op: "solve",
148 expected: 2,
149 actual: t.shape().len(),
150 });
151 }
152 let rows = t.shape()[0];
153 let cols = t.shape()[1];
154 let n = rows.min(cols);
155 let batch_total: usize = t.shape()[2..].iter().product();
156 let batch_total = batch_total.max(1);
157 let slice_size = rows * cols;
158 let data = t.host_data()?;
159 for batch_idx in 0..batch_total {
160 let batch = &data[batch_idx * slice_size..(batch_idx + 1) * slice_size];
161 for i in 0..n {
162 let diag = batch[i + i * rows];
163 if diag.is_singular_or_nonfinite() {
164 return Err(Error::backend_failure(
165 "solve",
166 if batch_total > 1 {
167 format!(
168 "singular matrix: non-finite or zero diagonal at batch {}, position [{},{}] = {:?}",
169 batch_idx, i, i, diag
170 )
171 } else {
172 format!(
173 "singular matrix: non-finite or zero diagonal at position [{},{}] = {:?}",
174 i, i, diag
175 )
176 },
177 ));
178 }
179 }
180 }
181 Ok(())
182}
183
184/// Validates that the upper-triangular factor `u` of a matrix decomposition
185/// has no singular (zero) or non-finite diagonal entries.
186///
187/// Dispatches to [`check_singular_diagonal`] after unpacking the concrete
188/// tensor variant. Returns `Ok(())` when all diagonal entries are valid.
189///
190/// # Examples
191///
192/// ```rust
193/// use tenferro_tensor::validate::validate_nonsingular_u;
194/// use tenferro_tensor::{Tensor, TypedTensor};
195///
196/// let t = Tensor::F64(TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]).unwrap());
197/// assert!(validate_nonsingular_u(&t).is_ok());
198/// ```
199pub fn validate_nonsingular_u(u: &Tensor) -> Result<()> {
200 match u {
201 Tensor::F64(t) => check_singular_diagonal(t),
202 Tensor::F32(t) => check_singular_diagonal(t),
203 Tensor::C64(t) => check_singular_diagonal(t),
204 Tensor::C32(t) => check_singular_diagonal(t),
205 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(Error::backend_failure(
206 "solve",
207 format!("unsupported dtype {:?}", u.dtype()),
208 )),
209 }
210}
211
212#[cfg(test)]
213mod tests;