strided_traits/
scalar.rs

//! Scalar type bounds for strided operations and einsum.

3/// Shared trait bounds for all element types usable with einsum, independent
4/// of GEMM backend.
5///
6/// Unlike the previous design, `ScalarBase` does **not** require
7/// `ElementOpApply`. This allows custom scalar types (e.g., tropical
8/// semiring types) to satisfy einsum bounds without implementing
9/// conj/transpose/adjoint.
10pub trait ScalarBase:
11    Copy
12    + Send
13    + Sync
14    + std::ops::Mul<Output = Self>
15    + std::ops::Add<Output = Self>
16    + num_traits::Zero
17    + num_traits::One
18    + PartialEq
19{
20}
21
22impl<T> ScalarBase for T where
23    T: Copy
24        + Send
25        + Sync
26        + std::ops::Mul<Output = T>
27        + std::ops::Add<Output = T>
28        + num_traits::Zero
29        + num_traits::One
30        + PartialEq
31{
32}
33
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{One, Zero};

    /// Compile-time check: instantiating this only type-checks when `T`
    /// satisfies every `ScalarBase` bound.
    fn assert_scalar_base<T: ScalarBase>() {}

    /// Inner product written purely against `ScalarBase`. It starts from
    /// `T::zero()` and accumulates `lhs[i] * rhs[i]` with `+`, so it checks
    /// that the bounds are enough for generic sum-of-products code.
    /// Pairs beyond the shorter slice are ignored, as with `zip`.
    fn dot<T: ScalarBase>(lhs: &[T], rhs: &[T]) -> T {
        lhs.iter()
            .zip(rhs)
            .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
    }

    #[test]
    fn test_standard_types() {
        assert_scalar_base::<f32>();
        assert_scalar_base::<f64>();
        assert_scalar_base::<i32>();
        assert_scalar_base::<i64>();
        assert_scalar_base::<num_complex::Complex64>();
    }

    #[test]
    fn test_generic_reduction_standard_types() {
        // 1*4 + 2*5 + 3*6 = 32
        assert_eq!(dot(&[1.0_f64, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
        // 1*3 + (-2)*4 = -5
        assert_eq!(dot(&[1_i64, -2], &[3, 4]), -5);
        // An empty reduction yields the additive identity.
        assert_eq!(dot::<i32>(&[], &[]), 0);
    }

    #[test]
    fn test_custom_type_without_element_op_apply() {
        // A max-plus ("tropical") scalar. It implements the arithmetic
        // traits but NOT ElementOpApply, and must still satisfy ScalarBase.
        #[derive(Debug, Clone, Copy, PartialEq)]
        struct TropicalLike(f64);

        impl std::ops::Add for TropicalLike {
            type Output = Self;
            fn add(self, rhs: Self) -> Self {
                // tropical add = max
                TropicalLike(self.0.max(rhs.0))
            }
        }

        impl std::ops::Mul for TropicalLike {
            type Output = Self;
            fn mul(self, rhs: Self) -> Self {
                // tropical mul = ordinary addition
                TropicalLike(self.0 + rhs.0)
            }
        }

        impl num_traits::Zero for TropicalLike {
            // -inf is the identity element for max.
            fn zero() -> Self {
                TropicalLike(f64::NEG_INFINITY)
            }
            fn is_zero(&self) -> bool {
                self.0 == f64::NEG_INFINITY
            }
        }

        impl num_traits::One for TropicalLike {
            // 0 is the identity element for ordinary addition.
            fn one() -> Self {
                TropicalLike(0.0)
            }
        }

        assert_scalar_base::<TropicalLike>();

        // Exercise the actual operations so coverage sees them.
        let a = TropicalLike(3.0);
        let b = TropicalLike(5.0);
        assert_eq!((a + b).0, 5.0); // max(3, 5) = 5
        assert_eq!((a * b).0, 8.0); // 3 + 5 = 8
        assert_eq!(TropicalLike::zero().0, f64::NEG_INFINITY);
        assert!(TropicalLike::zero().is_zero());
        assert!(!a.is_zero());
        assert_eq!(TropicalLike::one().0, 0.0);

        // Semiring identities: zero() is neutral for +, one() is neutral
        // for *, and zero() annihilates under *.
        assert_eq!(a + TropicalLike::zero(), a);
        assert_eq!(a * TropicalLike::one(), a);
        assert_eq!(a * TropicalLike::zero(), TropicalLike::zero());

        // Code that is generic over ScalarBase alone must pick up the
        // tropical semantics: max(1+4, 2+5, 3+6) = 9.
        let lhs = [TropicalLike(1.0), TropicalLike(2.0), TropicalLike(3.0)];
        let rhs = [TropicalLike(4.0), TropicalLike(5.0), TropicalLike(6.0)];
        assert_eq!(dot(&lhs, &rhs), TropicalLike(9.0));
        // An empty reduction yields the tropical additive identity (-inf).
        assert!(dot::<TropicalLike>(&[], &[]).is_zero());
    }
}