1#[cfg(feature = "parallel")]
4use crate::kernel::same_contiguous_layout;
5use crate::kernel::{
6 build_plan_fused, for_each_inner_block_preordered, sequential_contiguous_layout, total_len,
7};
8use crate::maybe_sync::{MaybeSendSync, MaybeSync};
9use crate::simd;
10use crate::view::{col_major_strides, StridedArray, StridedView};
11use crate::{Result, StridedError};
12use strided_view::ElementOp;
13
14#[cfg(feature = "parallel")]
15use crate::fuse::compute_costs;
16#[cfg(feature = "parallel")]
17use crate::threading::{
18 for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
19};
20
21pub fn reduce<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
23 src: &StridedView<T, Op>,
24 map_fn: M,
25 reduce_fn: R,
26 init: U,
27) -> Result<U>
28where
29 M: Fn(T) -> U + MaybeSync,
30 R: Fn(U, U) -> U + MaybeSync,
31 U: Clone + MaybeSendSync,
32{
33 let src_ptr = src.ptr();
34 let src_dims = src.dims();
35 let src_strides = src.strides();
36
37 if sequential_contiguous_layout(src_dims, &[src_strides]).is_some() {
38 let len = total_len(src_dims);
39 let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
40 return Ok(simd::dispatch_if_large(len, || {
41 let mut acc = init;
42 for &val in src.iter() {
43 acc = reduce_fn(acc, map_fn(Op::apply(val)));
44 }
45 acc
46 }));
47 }
48
49 #[cfg(feature = "parallel")]
53 {
54 let total = total_len(src_dims);
55 if total > MINTHREADLENGTH
56 && rayon::current_num_threads() > 1
57 && same_contiguous_layout(src_dims, &[src_strides]).is_some()
58 {
59 let src_slice = unsafe { std::slice::from_raw_parts(src_ptr, total) };
60 use rayon::prelude::*;
61 let nthreads = rayon::current_num_threads();
62 let chunk_size = (total + nthreads - 1) / nthreads;
63 let result = src_slice
64 .par_chunks(chunk_size)
65 .map(|chunk| {
66 simd::dispatch_if_large(chunk.len(), || {
67 let mut acc = init.clone();
68 for &val in chunk.iter() {
69 acc = reduce_fn(acc, map_fn(Op::apply(val)));
70 }
71 acc
72 })
73 })
74 .reduce(|| init.clone(), |a, b| reduce_fn(a, b));
75 return Ok(result);
76 }
77 }
78
79 let strides_list: [&[isize]; 1] = [src_strides];
80
81 let (fused_dims, ordered_strides, plan) =
82 build_plan_fused(src_dims, &strides_list, None, std::mem::size_of::<T>());
83
84 #[cfg(feature = "parallel")]
85 {
86 let total: usize = fused_dims.iter().product();
87 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
88 let nthreads = rayon::current_num_threads();
89 let spacing = (64 / std::mem::size_of::<U>()).max(1);
91 let mut threadedout = vec![init.clone(); spacing * nthreads];
92 let threadedout_ptr = SendPtr(threadedout.as_mut_ptr());
93 let src_send = SendPtr(src_ptr as *mut T);
94
95 let costs = compute_costs(&ordered_strides);
96
97 let ndim = fused_dims.len();
101 let mut threaded_strides = Vec::with_capacity(ordered_strides.len() + 1);
102 threaded_strides.push(vec![0isize; ndim]); for s in &ordered_strides {
104 threaded_strides.push(s.clone());
105 }
106 let initial_offsets = vec![0isize; threaded_strides.len()];
107
108 mapreduce_threaded(
115 &fused_dims,
116 &plan.block,
117 &threaded_strides,
118 &initial_offsets,
119 &costs,
120 nthreads,
121 spacing as isize,
122 1,
123 &|dims, blocks, strides_list, offsets| {
124 let out_offset = offsets[0] as usize;
127 let src_offsets = &offsets[1..];
128
129 for_each_inner_block_with_offsets(
130 dims,
131 blocks,
132 &strides_list[1..],
133 src_offsets,
134 |offsets, len, strides| {
135 let mut ptr = unsafe { src_send.as_const().offset(offsets[0]) };
136 let stride = strides[0];
137 let slot = unsafe { &mut *threadedout_ptr.as_ptr().add(out_offset) };
138 for _ in 0..len {
139 let val = Op::apply(unsafe { *ptr });
140 let mapped = map_fn(val);
141 *slot = reduce_fn(slot.clone(), mapped);
142 unsafe {
143 ptr = ptr.offset(stride);
144 }
145 }
146 Ok(())
147 },
148 )
149 },
150 )?;
151
152 let mut result = init;
154 for i in 0..nthreads {
155 result = reduce_fn(result, threadedout[i * spacing].clone());
156 }
157 return Ok(result);
158 }
159 }
160
161 let mut acc = init;
162 let initial_offsets = vec![0isize; ordered_strides.len()];
163 for_each_inner_block_preordered(
164 &fused_dims,
165 &plan.block,
166 &ordered_strides,
167 &initial_offsets,
168 |offsets, len, strides| {
169 let mut ptr = unsafe { src_ptr.offset(offsets[0]) };
170 let stride = strides[0];
171 for _ in 0..len {
172 let val = Op::apply(unsafe { *ptr });
173 let mapped = map_fn(val);
174 acc = reduce_fn(acc.clone(), mapped);
175 unsafe {
176 ptr = ptr.offset(stride);
177 }
178 }
179 Ok(())
180 },
181 )?;
182
183 Ok(acc)
184}
185
186pub fn reduce_axis<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
188 src: &StridedView<T, Op>,
189 axis: usize,
190 map_fn: M,
191 reduce_fn: R,
192 init: U,
193) -> Result<StridedArray<U>>
194where
195 M: Fn(T) -> U + MaybeSync,
196 R: Fn(U, U) -> U + MaybeSync,
197 U: Clone + MaybeSendSync,
198{
199 let rank = src.ndim();
200 if axis >= rank {
201 return Err(StridedError::InvalidAxis { axis, rank });
202 }
203
204 let src_dims = src.dims();
205 let src_strides = src.strides();
206 let src_ptr = src.ptr();
207
208 let out_dims: Vec<usize> = src_dims
209 .iter()
210 .enumerate()
211 .filter(|(i, _)| *i != axis)
212 .map(|(_, &d)| d)
213 .collect();
214
215 let axis_len = src_dims[axis];
216 let axis_stride = src_strides[axis];
217
218 if out_dims.is_empty() {
219 let mut acc = init;
221 let mut offset = 0isize;
222 for _ in 0..axis_len {
223 let val = Op::apply(unsafe { *src_ptr.offset(offset) });
224 let mapped = map_fn(val);
225 acc = reduce_fn(acc, mapped);
226 offset += axis_stride;
227 }
228 let strides = col_major_strides(&[1]);
229 return StridedArray::from_parts(vec![acc], &[1], &strides, 0);
230 }
231
232 let total_out: usize = out_dims.iter().product();
233 let out_strides = col_major_strides(&out_dims);
234 let mut out =
235 StridedArray::from_parts(vec![init.clone(); total_out], &out_dims, &out_strides, 0)?;
236
237 let src_kept_strides: Vec<isize> = src_strides
239 .iter()
240 .enumerate()
241 .filter(|(i, _)| *i != axis)
242 .map(|(_, &s)| s)
243 .collect();
244
245 let elem_size = std::mem::size_of::<T>().max(std::mem::size_of::<U>());
246 let strides_list: [&[isize]; 2] = [&out_strides, &src_kept_strides];
247 let (fused_dims, ordered_strides, plan) =
248 build_plan_fused(&out_dims, &strides_list, Some(0), elem_size);
249
250 let out_ptr = out.view_mut().as_mut_ptr();
251
252 let initial_offsets = vec![0isize; ordered_strides.len()];
253 for_each_inner_block_preordered(
254 &fused_dims,
255 &plan.block,
256 &ordered_strides,
257 &initial_offsets,
258 |offsets, len, strides| {
259 let out_step = strides[0];
260 let src_step = strides[1];
261
262 if out_step == 1 && src_step == 1 && axis_len > 1 {
266 let n = len as usize;
267 let out_slice =
268 unsafe { std::slice::from_raw_parts_mut(out_ptr.offset(offsets[0]), n) };
269 let src0 = unsafe { std::slice::from_raw_parts(src_ptr.offset(offsets[1]), n) };
271 for i in 0..n {
272 out_slice[i] = map_fn(Op::apply(src0[i]));
273 }
274 for k in 1..axis_len {
276 let src_k = unsafe {
277 std::slice::from_raw_parts(
278 src_ptr.offset(offsets[1] + k as isize * axis_stride),
279 n,
280 )
281 };
282 for i in 0..n {
283 out_slice[i] = reduce_fn(out_slice[i].clone(), map_fn(Op::apply(src_k[i])));
284 }
285 }
286 return Ok(());
287 }
288
289 let mut out_off = offsets[0];
291 let mut src_off = offsets[1];
292 for _ in 0..len {
293 let mut acc = init.clone();
294 let mut ptr = unsafe { src_ptr.offset(src_off) };
295 for _ in 0..axis_len {
296 let val = Op::apply(unsafe { *ptr });
297 let mapped = map_fn(val);
298 acc = reduce_fn(acc, mapped);
299 unsafe {
300 ptr = ptr.offset(axis_stride);
301 }
302 }
303 unsafe {
304 *out_ptr.offset(out_off) = acc;
305 }
306 out_off += out_step;
307 src_off += src_step;
308 }
309 Ok(())
310 },
311 )?;
312
313 Ok(out)
314}