/// Runs `f` and returns its result.
///
/// With the `simd` feature enabled, `f` is executed through
/// `pulp::Arch::new().dispatch`, pulp's runtime architecture dispatch.
/// Without the feature, `f` is simply called directly.
#[inline(always)]
pub(crate) fn dispatch<R>(f: impl FnOnce() -> R) -> R {
    #[cfg(not(feature = "simd"))]
    {
        // Scalar build: there is nothing to dispatch on.
        f()
    }
    #[cfg(feature = "simd")]
    {
        // SIMD build: hand the closure to pulp's runtime dispatcher.
        let arch = pulp::Arch::new();
        arch.dispatch(f)
    }
}
12
13#[inline(always)]
14pub(crate) fn dispatch_if_large<R>(len: usize, f: impl FnOnce() -> R) -> R {
15 if len >= 64 {
18 dispatch(f)
19 } else {
20 f()
21 }
22}
23
/// Optional SIMD fast paths for reductions over slices of `Self`.
///
/// Each method returns `Some(result)` when the implementing type supplies an
/// accelerated kernel. It returns `None` when no such kernel exists;
/// presumably callers then fall back to a generic path (not visible here).
/// The provided methods always return `None`, so a type without kernels only
/// needs an empty `impl`.
pub trait MaybeSimdOps: Copy + Sized {
    /// Attempts a SIMD sum of every element of the slice.
    ///
    /// The default implementation always returns `None`.
    fn try_simd_sum(_src: &[Self]) -> Option<Self> {
        Option::None
    }

    /// Attempts a SIMD dot product of the two slices.
    ///
    /// The default implementation always returns `None`.
    fn try_simd_dot(_a: &[Self], _b: &[Self]) -> Option<Self> {
        Option::None
    }
}
36
/// Emits an empty `impl MaybeSimdOps for T {}` for each listed type `T`.
///
/// Those types therefore use the trait's default methods, which always
/// return `None`.
macro_rules! impl_no_simd {
    ($($ty:ty),*) => {
        $(
            impl MaybeSimdOps for $ty {}
        )*
    };
}
43
44impl_no_simd!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
45
46impl<T: num_traits::Num + Copy + Clone + std::ops::Neg<Output = T>> MaybeSimdOps
47 for num_complex::Complex<T>
48{
49}
50
51#[cfg(not(feature = "simd"))]
53impl MaybeSimdOps for f32 {}
54
55#[cfg(not(feature = "simd"))]
56impl MaybeSimdOps for f64 {}
57
#[cfg(feature = "simd")]
mod simd_impls {
    use super::MaybeSimdOps;
    use pulp::{Simd, WithSimd};

    /// Implements `MaybeSimdOps` for one float type on top of `pulp`.
    ///
    /// The f32 and f64 kernels are identical apart from which family of
    /// `pulp::Simd` methods they call. The method names are therefore passed
    /// in, instead of duplicating the whole implementation for each type.
    macro_rules! impl_simd_float {
        (
            $ty:ty,
            as_simd: $as_simd:ident,
            splat: $splat:ident,
            add: $add:ident,
            mul_add: $mul_add:ident,
            reduce_sum: $reduce_sum:ident $(,)?
        ) => {
            impl MaybeSimdOps for $ty {
                /// SIMD sum of `src`; always returns `Some`.
                ///
                /// Whole vectors are accumulated into four independent
                /// accumulators, so consecutive vector adds do not depend on
                /// each other. The accumulators are then combined and reduced
                /// horizontally. Finally, the scalar tail that does not fill a
                /// vector is added.
                fn try_simd_sum(src: &[$ty]) -> Option<$ty> {
                    struct Sum<'a>(&'a [$ty]);

                    impl<'a> WithSimd for Sum<'a> {
                        type Output = $ty;

                        #[inline(always)]
                        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                            let (head, tail) = S::$as_simd(self.0);

                            let mut acc0 = simd.$splat(0.0);
                            let mut acc1 = simd.$splat(0.0);
                            let mut acc2 = simd.$splat(0.0);
                            let mut acc3 = simd.$splat(0.0);

                            // Each chunk has exactly four vectors, which lets
                            // the compiler drop per-index bounds checks.
                            let quads = head.chunks_exact(4);
                            let leftover = quads.remainder();
                            for quad in quads {
                                acc0 = simd.$add(acc0, quad[0]);
                                acc1 = simd.$add(acc1, quad[1]);
                                acc2 = simd.$add(acc2, quad[2]);
                                acc3 = simd.$add(acc3, quad[3]);
                            }
                            for &v in leftover {
                                acc0 = simd.$add(acc0, v);
                            }

                            let acc = simd.$add(simd.$add(acc0, acc1), simd.$add(acc2, acc3));
                            let mut sum = simd.$reduce_sum(acc);
                            for &x in tail {
                                sum += x;
                            }
                            sum
                        }
                    }

                    Some(pulp::Arch::new().dispatch(Sum(src)))
                }

                /// SIMD dot product of `a` and `b`; always returns `Some`.
                ///
                /// The slices are expected to have equal length, which is
                /// checked by `debug_assert_eq!`. In release builds, mismatched
                /// inputs are truncated to their common length, so the same
                /// element pairs are used as `a.iter().zip(b)` would produce.
                /// Without that truncation, the vector loop could panic on an
                /// out-of-range index, or part of the longer tail could be
                /// silently dropped.
                fn try_simd_dot(a: &[$ty], b: &[$ty]) -> Option<$ty> {
                    struct Dot<'a> {
                        a: &'a [$ty],
                        b: &'a [$ty],
                    }

                    impl<'a> WithSimd for Dot<'a> {
                        type Output = $ty;

                        #[inline(always)]
                        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                            // Both slices have the same length here, so they
                            // split into the same number of vectors and tail
                            // elements.
                            let (a_head, a_tail) = S::$as_simd(self.a);
                            let (b_head, b_tail) = S::$as_simd(self.b);
                            debug_assert_eq!(a_head.len(), b_head.len());
                            debug_assert_eq!(a_tail.len(), b_tail.len());

                            let mut acc0 = simd.$splat(0.0);
                            let mut acc1 = simd.$splat(0.0);
                            let mut acc2 = simd.$splat(0.0);
                            let mut acc3 = simd.$splat(0.0);

                            let a_quads = a_head.chunks_exact(4);
                            let b_quads = b_head.chunks_exact(4);
                            let (a_rest, b_rest) = (a_quads.remainder(), b_quads.remainder());
                            for (qa, qb) in a_quads.zip(b_quads) {
                                acc0 = simd.$mul_add(qa[0], qb[0], acc0);
                                acc1 = simd.$mul_add(qa[1], qb[1], acc1);
                                acc2 = simd.$mul_add(qa[2], qb[2], acc2);
                                acc3 = simd.$mul_add(qa[3], qb[3], acc3);
                            }
                            for (&x, &y) in a_rest.iter().zip(b_rest) {
                                acc0 = simd.$mul_add(x, y, acc0);
                            }

                            let acc = simd.$add(simd.$add(acc0, acc1), simd.$add(acc2, acc3));
                            let mut sum = simd.$reduce_sum(acc);
                            for (&x, &y) in a_tail.iter().zip(b_tail) {
                                sum += x * y;
                            }
                            sum
                        }
                    }

                    debug_assert_eq!(a.len(), b.len());
                    // Release-mode guard: only the common prefix is paired.
                    let len = a.len().min(b.len());
                    let (a, b) = (&a[..len], &b[..len]);

                    Some(pulp::Arch::new().dispatch(Dot { a, b }))
                }
            }
        };
    }

    impl_simd_float!(
        f32,
        as_simd: as_simd_f32s,
        splat: splat_f32s,
        add: add_f32s,
        mul_add: mul_add_f32s,
        reduce_sum: reduce_sum_f32s,
    );

    impl_simd_float!(
        f64,
        as_simd: as_simd_f64s,
        splat: splat_f64s,
        add: add_f64s,
        mul_add: mul_add_f64s,
        reduce_sum: reduce_sum_f64s,
    );
}