Skip to main content

tensor4all_tcicore/
scalar.rs

//! Common scalar trait for matrix and tensor operations.
//!
//! The [`Scalar`] trait provides a unified interface for numeric types used
//! in matrix cross interpolation and tensor train operations. It is
//! implemented for [`f64`], [`f32`], [`Complex64`], and [`Complex32`].
//!
//! The `scalar_tests!` macro generates dual `f64`/`Complex64` test
//! variants from a single generic test function.
9
10use crate::matrix::BlasMul;
11use num_complex::{Complex32, Complex64};
12use num_traits::{Float, One, Zero};
13
14/// Common scalar trait for matrix and tensor operations.
15///
16/// Defines the minimal requirements for scalar types used in matrix cross
17/// interpolation and tensor train operations. Implemented for `f64`, `f32`,
18/// `Complex64`, and `Complex32`.
19///
20/// # Examples
21///
22/// ```
23/// use tensor4all_tcicore::Scalar;
24///
25/// // f64
26/// let x = 3.0_f64;
27/// assert_eq!(x.abs_sq(), 9.0);
28/// assert_eq!(x.conj(), 3.0);
29///
30/// // Complex64
31/// use num_complex::Complex64;
32/// let z = Complex64::new(3.0, 4.0);
33/// assert!((z.abs_sq() - 25.0).abs() < 1e-10);
34/// assert_eq!(z.conj(), Complex64::new(3.0, -4.0));
35///
36/// // Construction from f64
37/// let val = f64::from_f64(2.5);
38/// assert_eq!(val, 2.5);
39/// ```
40pub trait Scalar:
41    Clone
42    + Copy
43    + Zero
44    + One
45    + std::ops::Add<Output = Self>
46    + std::ops::Sub<Output = Self>
47    + std::ops::Mul<Output = Self>
48    + std::ops::Div<Output = Self>
49    + std::ops::Neg<Output = Self>
50    + Default
51    + Send
52    + Sync
53    + BlasMul
54    + 'static
55{
56    /// Complex conjugate of the value.
57    fn conj(self) -> Self;
58
59    /// Square of the absolute value (for complex numbers, |z|^2).
60    fn abs_sq(self) -> f64;
61
62    /// Absolute value as Self type.
63    ///
64    /// For real types, this returns the absolute value.
65    /// For complex types, this returns a real-valued complex (re=|z|, im=0).
66    fn abs(self) -> Self;
67
68    /// Absolute value as f64.
69    fn abs_val(self) -> f64 {
70        self.abs_sq().sqrt()
71    }
72
73    /// Create from f64 value.
74    fn from_f64(val: f64) -> Self;
75
76    /// Check if value is NaN.
77    fn is_nan(self) -> bool;
78
79    /// Machine epsilon for numerical comparisons.
80    ///
81    /// Returns `f64::EPSILON` (~2.2e-16) by default. This is the smallest value
82    /// such that `1.0 + epsilon != 1.0`.
83    fn epsilon() -> f64 {
84        f64::EPSILON
85    }
86}
87
88impl Scalar for f64 {
89    #[inline]
90    fn conj(self) -> Self {
91        self
92    }
93
94    #[inline]
95    fn abs_sq(self) -> f64 {
96        self * self
97    }
98
99    #[inline]
100    fn abs(self) -> Self {
101        Float::abs(self)
102    }
103
104    #[inline]
105    fn abs_val(self) -> f64 {
106        Float::abs(self)
107    }
108
109    #[inline]
110    fn from_f64(val: f64) -> Self {
111        val
112    }
113
114    #[inline]
115    fn is_nan(self) -> bool {
116        Float::is_nan(self)
117    }
118}
119
120impl Scalar for f32 {
121    #[inline]
122    fn conj(self) -> Self {
123        self
124    }
125
126    #[inline]
127    fn abs_sq(self) -> f64 {
128        (self * self) as f64
129    }
130
131    #[inline]
132    fn abs(self) -> Self {
133        Float::abs(self)
134    }
135
136    #[inline]
137    fn abs_val(self) -> f64 {
138        Float::abs(self) as f64
139    }
140
141    #[inline]
142    fn from_f64(val: f64) -> Self {
143        val as f32
144    }
145
146    #[inline]
147    fn is_nan(self) -> bool {
148        Float::is_nan(self)
149    }
150}
151
152impl Scalar for Complex64 {
153    #[inline]
154    fn conj(self) -> Self {
155        Complex64::conj(&self)
156    }
157
158    #[inline]
159    fn abs_sq(self) -> f64 {
160        self.norm_sqr()
161    }
162
163    #[inline]
164    fn abs(self) -> Self {
165        Complex64::new(self.norm(), 0.0)
166    }
167
168    #[inline]
169    fn abs_val(self) -> f64 {
170        self.norm()
171    }
172
173    #[inline]
174    fn from_f64(val: f64) -> Self {
175        Complex64::new(val, 0.0)
176    }
177
178    #[inline]
179    fn is_nan(self) -> bool {
180        self.re.is_nan() || self.im.is_nan()
181    }
182}
183
184impl Scalar for Complex32 {
185    #[inline]
186    fn conj(self) -> Self {
187        Complex32::conj(&self)
188    }
189
190    #[inline]
191    fn abs_sq(self) -> f64 {
192        self.norm_sqr() as f64
193    }
194
195    #[inline]
196    fn abs(self) -> Self {
197        Complex32::new(self.norm(), 0.0)
198    }
199
200    #[inline]
201    fn abs_val(self) -> f64 {
202        self.norm() as f64
203    }
204
205    #[inline]
206    fn from_f64(val: f64) -> Self {
207        Complex32::new(val as f32, 0.0)
208    }
209
210    #[inline]
211    fn is_nan(self) -> bool {
212        self.re.is_nan() || self.im.is_nan()
213    }
214}
215
/// Generates `f64` and `Complex64` test variants from a generic test function.
///
/// `scalar_tests!(name, test_fn)` expands to two `#[test]` functions,
/// `name_f64` and `name_c64`. They call `test_fn::<f64>()` and
/// `test_fn::<num_complex::Complex64>()` respectively. A trailing comma after
/// the arguments is accepted.
///
/// The expansion refers to the external crates `paste` and `num_complex` by
/// absolute path, so a local item named `paste` or `num_complex` cannot
/// shadow them. The crate invoking the macro must depend on both crates.
///
/// # Example
///
/// ```
/// fn test_operation_generic<T: tensor4all_tcicore::Scalar>() {
///     let value = T::from_f64(2.0);
///     assert_eq!(value.abs_sq(), 4.0);
/// }
///
/// # fn main() {}
/// tensor4all_tcicore::scalar_tests!(test_operation, test_operation_generic);
/// ```
#[macro_export]
macro_rules! scalar_tests {
    ($name:ident, $test_fn:ident $(,)?) => {
        ::paste::paste! {
            #[test]
            fn [<$name _f64>]() {
                $test_fn::<f64>();
            }

            #[test]
            fn [<$name _c64>]() {
                $test_fn::<::num_complex::Complex64>();
            }
        }
    };
}
245
246#[cfg(test)]
247mod tests;