strided_traits/
scalar.rs

//! Scalar type bounds for strided operations and einsum.
2
3/// Shared trait bounds for all element types usable with einsum, independent
4/// of GEMM backend.
5///
6/// Unlike the previous design, `ScalarBase` does **not** require
7/// `ElementOpApply`. This allows custom scalar types (e.g., tropical
8/// semiring types) to satisfy einsum bounds without implementing
9/// conj/transpose/adjoint.
10pub trait ScalarBase:
11    Copy
12    + Send
13    + Sync
14    + 'static
15    + std::ops::Mul<Output = Self>
16    + std::ops::Add<Output = Self>
17    + num_traits::Zero
18    + num_traits::One
19    + PartialEq
20{
21}
22
23impl<T> ScalarBase for T where
24    T: Copy
25        + Send
26        + Sync
27        + 'static
28        + std::ops::Mul<Output = T>
29        + std::ops::Add<Output = T>
30        + num_traits::Zero
31        + num_traits::One
32        + PartialEq
33{
34}
35
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{One, Zero};

    /// Compiles only when `T: ScalarBase`; the check happens at type-check time.
    fn assert_scalar_base<T: ScalarBase>() {}

    #[test]
    fn test_standard_types() {
        // Floating point
        assert_scalar_base::<f32>();
        assert_scalar_base::<f64>();
        // Signed integers
        assert_scalar_base::<i32>();
        assert_scalar_base::<i64>();
        // Unsigned integers
        assert_scalar_base::<u8>();
        assert_scalar_base::<u32>();
        assert_scalar_base::<u64>();
        assert_scalar_base::<usize>();
        // Complex numbers
        assert_scalar_base::<num_complex::Complex32>();
        assert_scalar_base::<num_complex::Complex64>();
    }

    #[test]
    fn test_custom_type_without_element_op_apply() {
        // A custom type that implements the arithmetic traits
        // but NOT ElementOpApply — this should still satisfy ScalarBase
        #[derive(Debug, Clone, Copy, PartialEq)]
        struct TropicalLike(f64);

        impl std::ops::Add for TropicalLike {
            type Output = Self;
            fn add(self, rhs: Self) -> Self {
                // tropical add = max
                TropicalLike(self.0.max(rhs.0))
            }
        }

        impl std::ops::Mul for TropicalLike {
            type Output = Self;
            fn mul(self, rhs: Self) -> Self {
                // tropical mul = add
                TropicalLike(self.0 + rhs.0)
            }
        }

        impl num_traits::Zero for TropicalLike {
            fn zero() -> Self {
                // -inf is the identity for max
                TropicalLike(f64::NEG_INFINITY)
            }
            fn is_zero(&self) -> bool {
                self.0 == f64::NEG_INFINITY
            }
        }

        impl num_traits::One for TropicalLike {
            fn one() -> Self {
                // 0.0 is the identity for ordinary addition
                TropicalLike(0.0)
            }
        }

        assert_scalar_base::<TropicalLike>();

        // Exercise the actual operations so coverage sees them
        let a = TropicalLike(3.0);
        let b = TropicalLike(5.0);
        assert_eq!((a + b).0, 5.0); // max(3, 5) = 5
        assert_eq!((a * b).0, 8.0); // 3 + 5 = 8
        assert_eq!(TropicalLike::zero().0, f64::NEG_INFINITY);
        assert!(TropicalLike::zero().is_zero());
        assert!(!a.is_zero());
        assert_eq!(TropicalLike::one().0, 0.0);

        // Semiring identity laws:
        // - `zero` is the additive identity and absorbs under `*`;
        // - `one` is the multiplicative identity.
        let zero = TropicalLike::zero();
        let one = TropicalLike::one();
        assert_eq!(zero + a, a);
        assert_eq!(a + zero, a);
        assert_eq!(one * a, a);
        assert_eq!(a * one, a);
        assert_eq!(zero * a, zero); // -inf + 3 = -inf
    }
}