Skip to main content

strided_kernel/
reduce_view.rs

1//! Reduce operations on dynamic-rank strided views.
2
3#[cfg(feature = "parallel")]
4use crate::kernel::same_contiguous_layout;
5use crate::kernel::{
6    build_plan_fused, for_each_inner_block_preordered, sequential_contiguous_layout, total_len,
7};
8use crate::maybe_sync::{MaybeSendSync, MaybeSync};
9use crate::simd;
10use crate::view::{col_major_strides, StridedArray, StridedView};
11use crate::{Result, StridedError};
12use strided_view::ElementOp;
13
14#[cfg(feature = "parallel")]
15use crate::fuse::compute_costs;
16#[cfg(feature = "parallel")]
17use crate::threading::{
18    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
19};
20
21/// Full reduction with map function: `reduce(init, op, map.(src))`.
22pub fn reduce<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
23    src: &StridedView<T, Op>,
24    map_fn: M,
25    reduce_fn: R,
26    init: U,
27) -> Result<U>
28where
29    M: Fn(T) -> U + MaybeSync,
30    R: Fn(U, U) -> U + MaybeSync,
31    U: Clone + MaybeSendSync,
32{
33    let src_ptr = src.ptr();
34    let src_dims = src.dims();
35    let src_strides = src.strides();
36
37    if sequential_contiguous_layout(src_dims, &[src_strides]).is_some() {
38        let len = total_len(src_dims);
39        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
40        return Ok(simd::dispatch_if_large(len, || {
41            let mut acc = init;
42            for &val in src.iter() {
43                acc = reduce_fn(acc, map_fn(Op::apply(val)));
44            }
45            acc
46        }));
47    }
48
49    // Parallel contiguous fast path: split into rayon chunks with slice-based iteration.
50    // This enables LLVM auto-vectorization on each chunk, unlike the general threaded path
51    // which uses scalar pointer-offset loops.
52    #[cfg(feature = "parallel")]
53    {
54        let total = total_len(src_dims);
55        if total > MINTHREADLENGTH
56            && rayon::current_num_threads() > 1
57            && same_contiguous_layout(src_dims, &[src_strides]).is_some()
58        {
59            let src_slice = unsafe { std::slice::from_raw_parts(src_ptr, total) };
60            use rayon::prelude::*;
61            let nthreads = rayon::current_num_threads();
62            let chunk_size = (total + nthreads - 1) / nthreads;
63            let result = src_slice
64                .par_chunks(chunk_size)
65                .map(|chunk| {
66                    simd::dispatch_if_large(chunk.len(), || {
67                        let mut acc = init.clone();
68                        for &val in chunk.iter() {
69                            acc = reduce_fn(acc, map_fn(Op::apply(val)));
70                        }
71                        acc
72                    })
73                })
74                .reduce(|| init.clone(), |a, b| reduce_fn(a, b));
75            return Ok(result);
76        }
77    }
78
79    let strides_list: [&[isize]; 1] = [src_strides];
80
81    let (fused_dims, ordered_strides, plan) =
82        build_plan_fused(src_dims, &strides_list, None, std::mem::size_of::<T>());
83
84    #[cfg(feature = "parallel")]
85    {
86        let total: usize = fused_dims.iter().product();
87        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
88            let nthreads = rayon::current_num_threads();
89            // False sharing avoidance: space output slots by cache line size
90            let spacing = (64 / std::mem::size_of::<U>()).max(1);
91            let mut threadedout = vec![init.clone(); spacing * nthreads];
92            let threadedout_ptr = SendPtr(threadedout.as_mut_ptr());
93            let src_send = SendPtr(src_ptr as *mut T);
94
95            let costs = compute_costs(&ordered_strides);
96
97            // For complete reduction, strides_list has 2 entries:
98            // [0] = threadedout (stride 0 everywhere — broadcasting), [1] = src
99            // The spacing/taskindex mechanism addresses output slots.
100            let ndim = fused_dims.len();
101            let mut threaded_strides = Vec::with_capacity(ordered_strides.len() + 1);
102            threaded_strides.push(vec![0isize; ndim]); // threadedout: stride 0 (broadcast)
103            for s in &ordered_strides {
104                threaded_strides.push(s.clone());
105            }
106            let initial_offsets = vec![0isize; threaded_strides.len()];
107
108            // Mask costs for threadedout stride=0 dims (all dims, since it's fully broadcast)
109            // This means: do NOT split on dims where output stride is 0 — but for complete
110            // reduction, ALL output strides are 0, so costs would all be masked to 0.
111            // Julia handles this with the spacing mechanism: each task writes to its own slot.
112            // We keep costs unmasked so splitting still works.
113
114            mapreduce_threaded(
115                &fused_dims,
116                &plan.block,
117                &threaded_strides,
118                &initial_offsets,
119                &costs,
120                nthreads,
121                spacing as isize,
122                1,
123                &|dims, blocks, strides_list, offsets| {
124                    // offsets[0] = spacing * (taskindex - 1) for threadedout
125                    // offsets[1] = offset into src
126                    let out_offset = offsets[0] as usize;
127                    let src_offsets = &offsets[1..];
128
129                    for_each_inner_block_with_offsets(
130                        dims,
131                        blocks,
132                        &strides_list[1..],
133                        src_offsets,
134                        |offsets, len, strides| {
135                            let mut ptr = unsafe { src_send.as_const().offset(offsets[0]) };
136                            let stride = strides[0];
137                            let slot = unsafe { &mut *threadedout_ptr.as_ptr().add(out_offset) };
138                            for _ in 0..len {
139                                let val = Op::apply(unsafe { *ptr });
140                                let mapped = map_fn(val);
141                                *slot = reduce_fn(slot.clone(), mapped);
142                                unsafe {
143                                    ptr = ptr.offset(stride);
144                                }
145                            }
146                            Ok(())
147                        },
148                    )
149                },
150            )?;
151
152            // Merge thread-local results
153            let mut result = init;
154            for i in 0..nthreads {
155                result = reduce_fn(result, threadedout[i * spacing].clone());
156            }
157            return Ok(result);
158        }
159    }
160
161    let mut acc = init;
162    let initial_offsets = vec![0isize; ordered_strides.len()];
163    for_each_inner_block_preordered(
164        &fused_dims,
165        &plan.block,
166        &ordered_strides,
167        &initial_offsets,
168        |offsets, len, strides| {
169            let mut ptr = unsafe { src_ptr.offset(offsets[0]) };
170            let stride = strides[0];
171            for _ in 0..len {
172                let val = Op::apply(unsafe { *ptr });
173                let mapped = map_fn(val);
174                acc = reduce_fn(acc.clone(), mapped);
175                unsafe {
176                    ptr = ptr.offset(stride);
177                }
178            }
179            Ok(())
180        },
181    )?;
182
183    Ok(acc)
184}
185
186/// Reduce along a single axis, returning a new StridedArray.
187pub fn reduce_axis<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
188    src: &StridedView<T, Op>,
189    axis: usize,
190    map_fn: M,
191    reduce_fn: R,
192    init: U,
193) -> Result<StridedArray<U>>
194where
195    M: Fn(T) -> U + MaybeSync,
196    R: Fn(U, U) -> U + MaybeSync,
197    U: Clone + MaybeSendSync,
198{
199    let rank = src.ndim();
200    if axis >= rank {
201        return Err(StridedError::InvalidAxis { axis, rank });
202    }
203
204    let src_dims = src.dims();
205    let src_strides = src.strides();
206    let src_ptr = src.ptr();
207
208    let out_dims: Vec<usize> = src_dims
209        .iter()
210        .enumerate()
211        .filter(|(i, _)| *i != axis)
212        .map(|(_, &d)| d)
213        .collect();
214
215    let axis_len = src_dims[axis];
216    let axis_stride = src_strides[axis];
217
218    if out_dims.is_empty() {
219        // Reduce to scalar
220        let mut acc = init;
221        let mut offset = 0isize;
222        for _ in 0..axis_len {
223            let val = Op::apply(unsafe { *src_ptr.offset(offset) });
224            let mapped = map_fn(val);
225            acc = reduce_fn(acc, mapped);
226            offset += axis_stride;
227        }
228        let strides = col_major_strides(&[1]);
229        return StridedArray::from_parts(vec![acc], &[1], &strides, 0);
230    }
231
232    let total_out: usize = out_dims.iter().product();
233    let out_strides = col_major_strides(&out_dims);
234    let mut out =
235        StridedArray::from_parts(vec![init.clone(); total_out], &out_dims, &out_strides, 0)?;
236
237    // Build source strides for iteration over non-axis dimensions (same rank as out_dims)
238    let src_kept_strides: Vec<isize> = src_strides
239        .iter()
240        .enumerate()
241        .filter(|(i, _)| *i != axis)
242        .map(|(_, &s)| s)
243        .collect();
244
245    let elem_size = std::mem::size_of::<T>().max(std::mem::size_of::<U>());
246    let strides_list: [&[isize]; 2] = [&out_strides, &src_kept_strides];
247    let (fused_dims, ordered_strides, plan) =
248        build_plan_fused(&out_dims, &strides_list, Some(0), elem_size);
249
250    let out_ptr = out.view_mut().as_mut_ptr();
251
252    let initial_offsets = vec![0isize; ordered_strides.len()];
253    for_each_inner_block_preordered(
254        &fused_dims,
255        &plan.block,
256        &ordered_strides,
257        &initial_offsets,
258        |offsets, len, strides| {
259            let out_step = strides[0];
260            let src_step = strides[1];
261
262            // Fast path: when both output and source have stride 1, swap to
263            // reduction-outer / output-inner with slices so LLVM can
264            // auto-vectorize the contiguous inner loop.
265            if out_step == 1 && src_step == 1 && axis_len > 1 {
266                let n = len as usize;
267                let out_slice =
268                    unsafe { std::slice::from_raw_parts_mut(out_ptr.offset(offsets[0]), n) };
269                // First reduction element → initialize output
270                let src0 = unsafe { std::slice::from_raw_parts(src_ptr.offset(offsets[1]), n) };
271                for i in 0..n {
272                    out_slice[i] = map_fn(Op::apply(src0[i]));
273                }
274                // Remaining reduction elements → accumulate
275                for k in 1..axis_len {
276                    let src_k = unsafe {
277                        std::slice::from_raw_parts(
278                            src_ptr.offset(offsets[1] + k as isize * axis_stride),
279                            n,
280                        )
281                    };
282                    for i in 0..n {
283                        out_slice[i] = reduce_fn(out_slice[i].clone(), map_fn(Op::apply(src_k[i])));
284                    }
285                }
286                return Ok(());
287            }
288
289            // General path: output-outer, reduction-inner
290            let mut out_off = offsets[0];
291            let mut src_off = offsets[1];
292            for _ in 0..len {
293                let mut acc = init.clone();
294                let mut ptr = unsafe { src_ptr.offset(src_off) };
295                for _ in 0..axis_len {
296                    let val = Op::apply(unsafe { *ptr });
297                    let mapped = map_fn(val);
298                    acc = reduce_fn(acc, mapped);
299                    unsafe {
300                        ptr = ptr.offset(axis_stride);
301                    }
302                }
303                unsafe {
304                    *out_ptr.offset(out_off) = acc;
305                }
306                out_off += out_step;
307                src_off += src_step;
308            }
309            Ok(())
310        },
311    )?;
312
313    Ok(out)
314}