#[cfg(feature = "parallel")]
use crate::kernel::same_contiguous_layout;
use crate::kernel::{
    build_plan_fused, for_each_inner_block_preordered, sequential_contiguous_layout, total_len,
};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::simd;
use crate::view::{col_major_strides, StridedArray, StridedView};
use crate::{Result, StridedError};
use strided_view::ElementOp;

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{
    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
};

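/// Reduce all elements of `src` to a single value.
///
/// Each element is passed through `Op::apply` and `map_fn`, then folded into an
/// accumulator with `reduce_fn`, starting from `init`. Traversal order is
/// unspecified (blocks may be fused and, with the `parallel` feature, split
/// across threads that each start from a clone of `init`), so `reduce_fn`
/// should be associative and commutative and `init` should act as an identity
/// for it.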
pub fn reduce<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
    src: &StridedView<T, Op>,
    map_fn: M,
    reduce_fn: R,
    init: U,
) -> Result<U>
where
    M: Fn(T) -> U + MaybeSync,
    R: Fn(U, U) -> U + MaybeSync,
    U: Clone + MaybeSendSync,
{
    let src_ptr = src.ptr();
    let src_dims = src.dims();
    let src_strides = src.strides();

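    // Fast path: the view is one contiguous, sequentially ordered block, so it
    // can be reduced as a flat slice.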
    if sequential_contiguous_layout(src_dims, &[src_strides]).is_some() {
        let len = total_len(src_dims);
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        return Ok(simd::dispatch_if_large(len, || {
            let mut acc = init;
            for &val in src.iter() {
                acc = reduce_fn(acc, map_fn(Op::apply(val)));
            }
            acc
        }));
    }

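    // Parallel fast path: the elements still cover a single contiguous block of
    // memory, so it can be split into flat chunks and reduced per chunk.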
    #[cfg(feature = "parallel")]
    {
        let total = total_len(src_dims);
        if total > MINTHREADLENGTH && same_contiguous_layout(src_dims, &[src_strides]).is_some() {
            let src_slice = unsafe { std::slice::from_raw_parts(src_ptr, total) };
            use rayon::prelude::*;
            let nthreads = rayon::current_num_threads();
            let chunk_size = (total + nthreads - 1) / nthreads;
            let result = src_slice
                .par_chunks(chunk_size)
                .map(|chunk| {
                    simd::dispatch_if_large(chunk.len(), || {
                        let mut acc = init.clone();
                        for &val in chunk.iter() {
                            acc = reduce_fn(acc, map_fn(Op::apply(val)));
                        }
                        acc
                    })
                })
                .reduce(|| init.clone(), |a, b| reduce_fn(a, b));
            return Ok(result);
        }
    }

    let strides_list: [&[isize]; 1] = [src_strides];

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(src_dims, &strides_list, None, std::mem::size_of::<T>());

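    // Threaded path for non-contiguous data: each thread folds its share of the
    // fused blocks into a private slot, and the slots are combined at the end.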
    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let nthreads = rayon::current_num_threads();
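            // Give each thread its own accumulator slot, spaced roughly a cache
            // line (64 bytes) apart to avoid false sharing.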
            // The inner `.max(1)` avoids dividing by zero when `U` is zero-sized.
            let spacing = (64 / std::mem::size_of::<U>().max(1)).max(1);
            let mut threadedout = vec![init.clone(); spacing * nthreads];
            let threadedout_ptr = SendPtr(threadedout.as_mut_ptr());
            let src_send = SendPtr(src_ptr as *mut T);

            let costs = compute_costs(&ordered_strides);

            let ndim = fused_dims.len();
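            // Prepend an all-zero stride array so the block walk never advances
            // offsets[0]; that slot carries each thread's offset into
            // `threadedout` (read back as `out_offset` below).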
            let mut threaded_strides = Vec::with_capacity(ordered_strides.len() + 1);
            threaded_strides.push(vec![0isize; ndim]);
            for s in &ordered_strides {
                threaded_strides.push(s.clone());
            }
            let initial_offsets = vec![0isize; threaded_strides.len()];

            mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &threaded_strides,
                &initial_offsets,
                &costs,
                nthreads,
                spacing as isize,
                1,
                &|dims, blocks, strides_list, offsets| {
                    let out_offset = offsets[0] as usize;
                    let src_offsets = &offsets[1..];

                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        &strides_list[1..],
                        src_offsets,
                        |offsets, len, strides| {
                            let mut ptr = unsafe { src_send.as_const().offset(offsets[0]) };
                            let stride = strides[0];
                            let slot = unsafe { &mut *threadedout_ptr.as_ptr().add(out_offset) };
                            for _ in 0..len {
                                let val = Op::apply(unsafe { *ptr });
                                let mapped = map_fn(val);
                                *slot = reduce_fn(slot.clone(), mapped);
                                unsafe {
                                    ptr = ptr.offset(stride);
                                }
                            }
                            Ok(())
                        },
                    )
                },
            )?;

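            // Combine the per-thread partial results.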
            let mut result = init;
            for i in 0..nthreads {
                result = reduce_fn(result, threadedout[i * spacing].clone());
            }
            return Ok(result);
        }
    }

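    // Sequential fallback: walk the fused blocks and fold every element into `acc`.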
    let mut acc = init;
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let mut ptr = unsafe { src_ptr.offset(offsets[0]) };
            let stride = strides[0];
            for _ in 0..len {
                let val = Op::apply(unsafe { *ptr });
                let mapped = map_fn(val);
                acc = reduce_fn(acc.clone(), mapped);
                unsafe {
                    ptr = ptr.offset(stride);
                }
            }
            Ok(())
        },
    )?;

    Ok(acc)
}

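/// Reduce `src` along `axis`, returning an array with that axis removed
/// (a rank-1 input yields a single-element array).
///
/// Each output element starts from `init` and accumulates
/// `reduce_fn(acc, map_fn(Op::apply(x)))` over the elements along `axis`.
/// Returns `StridedError::InvalidAxis` if `axis` is out of range.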
pub fn reduce_axis<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
    src: &StridedView<T, Op>,
    axis: usize,
    map_fn: M,
    reduce_fn: R,
    init: U,
) -> Result<StridedArray<U>>
where
    M: Fn(T) -> U + MaybeSync,
    R: Fn(U, U) -> U + MaybeSync,
    U: Clone + MaybeSendSync,
{
    let rank = src.ndim();
    if axis >= rank {
        return Err(StridedError::InvalidAxis { axis, rank });
    }

    let src_dims = src.dims();
    let src_strides = src.strides();
    let src_ptr = src.ptr();

    let out_dims: Vec<usize> = src_dims
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &d)| d)
        .collect();

    let axis_len = src_dims[axis];
    let axis_stride = src_strides[axis];

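    // Rank-1 input: reducing its only axis leaves a single element.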
    if out_dims.is_empty() {
        let mut acc = init;
        let mut offset = 0isize;
        for _ in 0..axis_len {
            let val = Op::apply(unsafe { *src_ptr.offset(offset) });
            let mapped = map_fn(val);
            acc = reduce_fn(acc, mapped);
            offset += axis_stride;
        }
        let strides = col_major_strides(&[1]);
        return StridedArray::from_parts(vec![acc], &[1], &strides, 0);
    }

    let total_out: usize = out_dims.iter().product();
    let out_strides = col_major_strides(&out_dims);
    let mut out =
        StridedArray::from_parts(vec![init.clone(); total_out], &out_dims, &out_strides, 0)?;

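    // Strides of `src` over the kept (non-reduced) axes, aligned with `out_dims`.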
    let src_kept_strides: Vec<isize> = src_strides
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &s)| s)
        .collect();

    let elem_size = std::mem::size_of::<T>().max(std::mem::size_of::<U>());
    let strides_list: [&[isize]; 2] = [&out_strides, &src_kept_strides];
    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(&out_dims, &strides_list, Some(0), elem_size);

    let out_ptr = out.view_mut().as_mut_ptr();

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let out_step = strides[0];
            let src_step = strides[1];

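            // Fast path: the output and source inner blocks are both unit-stride,
            // so the axis can be reduced over contiguous row slices.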
            if out_step == 1 && src_step == 1 && axis_len > 1 {
                let n = len as usize;
                let out_slice =
                    unsafe { std::slice::from_raw_parts_mut(out_ptr.offset(offsets[0]), n) };
                let src0 = unsafe { std::slice::from_raw_parts(src_ptr.offset(offsets[1]), n) };
                // Fold the first slice into the `init` values already stored in
                // `out`, so this path agrees with the scalar path below.
                for i in 0..n {
                    out_slice[i] = reduce_fn(out_slice[i].clone(), map_fn(Op::apply(src0[i])));
                }
                for k in 1..axis_len {
                    let src_k = unsafe {
                        std::slice::from_raw_parts(
                            src_ptr.offset(offsets[1] + k as isize * axis_stride),
                            n,
                        )
                    };
                    for i in 0..n {
                        out_slice[i] =
                            reduce_fn(out_slice[i].clone(), map_fn(Op::apply(src_k[i])));
                    }
                }
                return Ok(());
            }

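            // General path: reduce the axis one output element at a time.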
            let mut out_off = offsets[0];
            let mut src_off = offsets[1];
            for _ in 0..len {
                let mut acc = init.clone();
                let mut ptr = unsafe { src_ptr.offset(src_off) };
                for _ in 0..axis_len {
                    let val = Op::apply(unsafe { *ptr });
                    let mapped = map_fn(val);
                    acc = reduce_fn(acc, mapped);
                    unsafe {
                        ptr = ptr.offset(axis_stride);
                    }
                }
                unsafe {
                    *out_ptr.offset(out_off) = acc;
                }
                out_off += out_step;
                src_off += src_step;
            }
            Ok(())
        },
    )?;

    Ok(out)
}