strided_kernel/
simd.rs

/// Runs `f` and returns its result.
///
/// With the `simd` feature enabled, `f` is routed through `pulp::Arch::dispatch`, which
/// selects an instruction set at runtime. Without the feature, `f` is called directly.
#[inline(always)]
pub(crate) fn dispatch<R>(f: impl FnOnce() -> R) -> R {
    #[cfg(not(feature = "simd"))]
    {
        f()
    }
    #[cfg(feature = "simd")]
    {
        let arch = pulp::Arch::new();
        arch.dispatch(f)
    }
}
12
13#[inline(always)]
14pub(crate) fn dispatch_if_large<R>(len: usize, f: impl FnOnce() -> R) -> R {
15    // Avoid runtime-dispatch overhead for tiny loops (especially common for small-array cases).
16    // This is a heuristic; correctness does not depend on it.
17    if len >= 64 {
18        dispatch(f)
19    } else {
20        f()
21    }
22}
23
/// Optional SIMD fast paths for sum/dot reductions over slices.
///
/// Both hooks default to returning `None`, meaning "no SIMD kernel for this type".
/// When the `simd` feature is enabled, `f32` and `f64` override them with SIMD kernels.
pub trait MaybeSimdOps: Copy + Sized {
    /// Sum of all elements of `_values`, or `None` if `Self` has no SIMD kernel.
    fn try_simd_sum(_values: &[Self]) -> Option<Self> {
        None
    }

    /// Dot product of `_lhs` and `_rhs`, or `None` if `Self` has no SIMD kernel.
    fn try_simd_dot(_lhs: &[Self], _rhs: &[Self]) -> Option<Self> {
        None
    }
}
36
37// Default (no-op) impls for integer types and Complex
38macro_rules! impl_no_simd {
39    ($($t:ty),*) => {
40        $(impl MaybeSimdOps for $t {})*
41    };
42}
43
44impl_no_simd!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
45
46impl<T: num_traits::Num + Copy + Clone + std::ops::Neg<Output = T>> MaybeSimdOps
47    for num_complex::Complex<T>
48{
49}
50
51// f32/f64: SIMD-accelerated when feature enabled, no-op otherwise
52#[cfg(not(feature = "simd"))]
53impl MaybeSimdOps for f32 {}
54
55#[cfg(not(feature = "simd"))]
56impl MaybeSimdOps for f64 {}
57
#[cfg(feature = "simd")]
mod simd_impls {
    //! pulp-backed `MaybeSimdOps` kernels for `f32` and `f64`.
    //!
    //! Every kernel has the same shape:
    //! 1. split the input into whole SIMD registers (`head`) and a scalar remainder (`tail`);
    //! 2. feed the head into four independent vector accumulators, so consecutive
    //!    adds/FMAs do not form one long dependency chain;
    //! 3. combine the accumulators, reduce horizontally, then add the scalar tail.
    //!
    //! Because of this reassociation (and the fused multiply-add in the dot kernels),
    //! results may differ in the last bits from a strictly sequential scalar loop.

    use super::MaybeSimdOps;
    use pulp::{Simd, WithSimd};

    impl MaybeSimdOps for f32 {
        /// Sums `src` using SIMD. Always returns `Some`.
        fn try_simd_sum(src: &[f32]) -> Option<f32> {
            struct Sum<'a>(&'a [f32]);

            impl WithSimd for Sum<'_> {
                type Output = f32;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    let (head, tail) = S::as_simd_f32s(self.0);

                    let mut acc0 = simd.splat_f32s(0.0);
                    let mut acc1 = simd.splat_f32s(0.0);
                    let mut acc2 = simd.splat_f32s(0.0);
                    let mut acc3 = simd.splat_f32s(0.0);

                    // Every `chunks_exact(4)` block holds exactly four registers, which lets
                    // the compiler drop the bounds checks on the fixed indices below.
                    let mut blocks = head.chunks_exact(4);
                    for block in &mut blocks {
                        acc0 = simd.add_f32s(acc0, block[0]);
                        acc1 = simd.add_f32s(acc1, block[1]);
                        acc2 = simd.add_f32s(acc2, block[2]);
                        acc3 = simd.add_f32s(acc3, block[3]);
                    }
                    for &v in blocks.remainder() {
                        acc0 = simd.add_f32s(acc0, v);
                    }

                    let acc = simd.add_f32s(simd.add_f32s(acc0, acc1), simd.add_f32s(acc2, acc3));
                    let mut sum = simd.reduce_sum_f32s(acc);
                    for &x in tail {
                        sum += x;
                    }
                    sum
                }
            }

            Some(pulp::Arch::new().dispatch(Sum(src)))
        }

        /// Dot product of `a` and `b` using SIMD. Always returns `Some`.
        ///
        /// `a` and `b` are expected to have equal lengths; this is checked in debug builds.
        /// In release builds both are truncated to the shorter length (`zip` semantics),
        /// so the result never mixes elements at different positions.
        fn try_simd_dot(a: &[f32], b: &[f32]) -> Option<f32> {
            struct Dot<'a> {
                a: &'a [f32],
                b: &'a [f32],
            }

            impl WithSimd for Dot<'_> {
                type Output = f32;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    // `try_simd_dot` passes equal-length slices, so both head/tail
                    // splits line up element-for-element.
                    let (a_head, a_tail) = S::as_simd_f32s(self.a);
                    let (b_head, b_tail) = S::as_simd_f32s(self.b);
                    debug_assert_eq!(a_head.len(), b_head.len());
                    debug_assert_eq!(a_tail.len(), b_tail.len());

                    let mut acc0 = simd.splat_f32s(0.0);
                    let mut acc1 = simd.splat_f32s(0.0);
                    let mut acc2 = simd.splat_f32s(0.0);
                    let mut acc3 = simd.splat_f32s(0.0);

                    let mut a_blocks = a_head.chunks_exact(4);
                    let mut b_blocks = b_head.chunks_exact(4);
                    for (xa, xb) in (&mut a_blocks).zip(&mut b_blocks) {
                        acc0 = simd.mul_add_f32s(xa[0], xb[0], acc0);
                        acc1 = simd.mul_add_f32s(xa[1], xb[1], acc1);
                        acc2 = simd.mul_add_f32s(xa[2], xb[2], acc2);
                        acc3 = simd.mul_add_f32s(xa[3], xb[3], acc3);
                    }
                    for (&xa, &xb) in a_blocks.remainder().iter().zip(b_blocks.remainder()) {
                        acc0 = simd.mul_add_f32s(xa, xb, acc0);
                    }

                    let acc = simd.add_f32s(simd.add_f32s(acc0, acc1), simd.add_f32s(acc2, acc3));
                    let mut sum = simd.reduce_sum_f32s(acc);
                    for (&x, &y) in a_tail.iter().zip(b_tail) {
                        sum += x * y;
                    }
                    sum
                }
            }

            debug_assert_eq!(a.len(), b.len());
            // Without this truncation, a length mismatch in release builds would give `a`
            // and `b` different head/tail splits. The kernel could then index past the
            // end of `b`'s head (panic), or the tail pass could pair unrelated elements.
            let len = a.len().min(b.len());
            Some(pulp::Arch::new().dispatch(Dot {
                a: &a[..len],
                b: &b[..len],
            }))
        }
    }

    impl MaybeSimdOps for f64 {
        /// Sums `src` using SIMD. Always returns `Some`.
        fn try_simd_sum(src: &[f64]) -> Option<f64> {
            struct Sum<'a>(&'a [f64]);

            impl WithSimd for Sum<'_> {
                type Output = f64;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    let (head, tail) = S::as_simd_f64s(self.0);

                    let mut acc0 = simd.splat_f64s(0.0);
                    let mut acc1 = simd.splat_f64s(0.0);
                    let mut acc2 = simd.splat_f64s(0.0);
                    let mut acc3 = simd.splat_f64s(0.0);

                    // Every `chunks_exact(4)` block holds exactly four registers, which lets
                    // the compiler drop the bounds checks on the fixed indices below.
                    let mut blocks = head.chunks_exact(4);
                    for block in &mut blocks {
                        acc0 = simd.add_f64s(acc0, block[0]);
                        acc1 = simd.add_f64s(acc1, block[1]);
                        acc2 = simd.add_f64s(acc2, block[2]);
                        acc3 = simd.add_f64s(acc3, block[3]);
                    }
                    for &v in blocks.remainder() {
                        acc0 = simd.add_f64s(acc0, v);
                    }

                    let acc = simd.add_f64s(simd.add_f64s(acc0, acc1), simd.add_f64s(acc2, acc3));
                    let mut sum = simd.reduce_sum_f64s(acc);
                    for &x in tail {
                        sum += x;
                    }
                    sum
                }
            }

            Some(pulp::Arch::new().dispatch(Sum(src)))
        }

        /// Dot product of `a` and `b` using SIMD. Always returns `Some`.
        ///
        /// `a` and `b` are expected to have equal lengths; this is checked in debug builds.
        /// In release builds both are truncated to the shorter length (`zip` semantics),
        /// so the result never mixes elements at different positions.
        fn try_simd_dot(a: &[f64], b: &[f64]) -> Option<f64> {
            struct Dot<'a> {
                a: &'a [f64],
                b: &'a [f64],
            }

            impl WithSimd for Dot<'_> {
                type Output = f64;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    // `try_simd_dot` passes equal-length slices, so both head/tail
                    // splits line up element-for-element.
                    let (a_head, a_tail) = S::as_simd_f64s(self.a);
                    let (b_head, b_tail) = S::as_simd_f64s(self.b);
                    debug_assert_eq!(a_head.len(), b_head.len());
                    debug_assert_eq!(a_tail.len(), b_tail.len());

                    let mut acc0 = simd.splat_f64s(0.0);
                    let mut acc1 = simd.splat_f64s(0.0);
                    let mut acc2 = simd.splat_f64s(0.0);
                    let mut acc3 = simd.splat_f64s(0.0);

                    let mut a_blocks = a_head.chunks_exact(4);
                    let mut b_blocks = b_head.chunks_exact(4);
                    for (xa, xb) in (&mut a_blocks).zip(&mut b_blocks) {
                        acc0 = simd.mul_add_f64s(xa[0], xb[0], acc0);
                        acc1 = simd.mul_add_f64s(xa[1], xb[1], acc1);
                        acc2 = simd.mul_add_f64s(xa[2], xb[2], acc2);
                        acc3 = simd.mul_add_f64s(xa[3], xb[3], acc3);
                    }
                    for (&xa, &xb) in a_blocks.remainder().iter().zip(b_blocks.remainder()) {
                        acc0 = simd.mul_add_f64s(xa, xb, acc0);
                    }

                    let acc = simd.add_f64s(simd.add_f64s(acc0, acc1), simd.add_f64s(acc2, acc3));
                    let mut sum = simd.reduce_sum_f64s(acc);
                    for (&x, &y) in a_tail.iter().zip(b_tail) {
                        sum += x * y;
                    }
                    sum
                }
            }

            debug_assert_eq!(a.len(), b.len());
            // Without this truncation, a length mismatch in release builds would give `a`
            // and `b` different head/tail splits. The kernel could then index past the
            // end of `b`'s head (panic), or the tail pass could pair unrelated elements.
            let len = a.len().min(b.len());
            Some(pulp::Arch::new().dispatch(Dot {
                a: &a[..len],
                b: &b[..len],
            }))
        }
    }
}