strided_kernel/reduce_view.rs

//! Reduce operations on dynamic-rank strided views.

#[cfg(feature = "parallel")]
use crate::kernel::same_contiguous_layout;
use crate::kernel::{
    build_plan_fused, for_each_inner_block_preordered, sequential_contiguous_layout, total_len,
};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::simd;
use crate::view::{col_major_strides, StridedArray, StridedView};
use crate::{Result, StridedError};
use strided_view::ElementOp;

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{
    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
};

/// Full reduction with map function: `reduce(init, op, map.(src))`.
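///
/// # Example
///
/// A minimal sketch of a full sum reduction. It assumes a column-major array built with
/// `StridedArray::from_parts` (used elsewhere in this module) and a hypothetical `view()`
/// accessor on `StridedArray`; only `view_mut()` appears in this file, so adjust to the
/// actual view API.
///
/// ```ignore
/// let dims = [2usize, 3];
/// let strides = col_major_strides(&dims);
/// let data: Vec<f64> = (1..=6).map(|x| x as f64).collect();
/// let arr = StridedArray::from_parts(data, &dims, &strides, 0)?;
/// // Identity map, addition reduce: 1 + 2 + ... + 6 = 21.
/// let sum = reduce(&arr.view(), |x| x, |a, b| a + b, 0.0)?;
/// assert_eq!(sum, 21.0);
/// ```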
pub fn reduce<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
    src: &StridedView<T, Op>,
    map_fn: M,
    reduce_fn: R,
    init: U,
) -> Result<U>
where
    M: Fn(T) -> U + MaybeSync,
    R: Fn(U, U) -> U + MaybeSync,
    U: Clone + MaybeSendSync,
{
    let src_ptr = src.ptr();
    let src_dims = src.dims();
    let src_strides = src.strides();

    if sequential_contiguous_layout(src_dims, &[src_strides]).is_some() {
        let len = total_len(src_dims);
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        return Ok(simd::dispatch_if_large(len, || {
            let mut acc = init;
            for &val in src.iter() {
                acc = reduce_fn(acc, map_fn(Op::apply(val)));
            }
            acc
        }));
    }

    // Parallel contiguous fast path: split into rayon chunks with slice-based iteration.
    // This enables LLVM auto-vectorization on each chunk, unlike the general threaded path
    // which uses scalar pointer-offset loops.
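    // E.g. a contiguous 1_000_000-element view on 8 rayon threads is split into
    // ceil(1_000_000 / 8) = 125_000-element chunks, each reduced as a plain slice.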
    #[cfg(feature = "parallel")]
    {
        let total = total_len(src_dims);
        if total > MINTHREADLENGTH && same_contiguous_layout(src_dims, &[src_strides]).is_some() {
            let src_slice = unsafe { std::slice::from_raw_parts(src_ptr, total) };
            use rayon::prelude::*;
            let nthreads = rayon::current_num_threads();
            let chunk_size = (total + nthreads - 1) / nthreads;
            let result = src_slice
                .par_chunks(chunk_size)
                .map(|chunk| {
                    simd::dispatch_if_large(chunk.len(), || {
                        let mut acc = init.clone();
                        for &val in chunk.iter() {
                            acc = reduce_fn(acc, map_fn(Op::apply(val)));
                        }
                        acc
                    })
                })
                .reduce(|| init.clone(), |a, b| reduce_fn(a, b));
            return Ok(result);
        }
    }

    // General strided path: fuse dimensions and build a blocked iteration plan shared by the
    // threaded and sequential branches below.
    let strides_list: [&[isize]; 1] = [src_strides];

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(src_dims, &strides_list, None, std::mem::size_of::<T>());

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let nthreads = rayon::current_num_threads();
            // False sharing avoidance: space output slots by cache line size.
            // The `.max(1)` guards keep the spacing valid when `U` is zero-sized or larger
            // than a cache line.
            let spacing = (64 / std::mem::size_of::<U>().max(1)).max(1);
            let mut threadedout = vec![init.clone(); spacing * nthreads];
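            // E.g. with an f64 accumulator, spacing = 64 / 8 = 8, so task k touches only
            // threadedout[8 * (k - 1)] and no two tasks share a cache line.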
            let threadedout_ptr = SendPtr(threadedout.as_mut_ptr());
            let src_send = SendPtr(src_ptr as *mut T);

            let costs = compute_costs(&ordered_strides);

            // For a complete reduction, `threaded_strides` has 2 entries:
            // [0] = threadedout (stride 0 in every dimension, i.e. fully broadcast), [1] = src.
            // The spacing/taskindex mechanism addresses the per-task output slots.
            let ndim = fused_dims.len();
            let mut threaded_strides = Vec::with_capacity(ordered_strides.len() + 1);
            threaded_strides.push(vec![0isize; ndim]); // threadedout: stride 0 (broadcast)
            for s in &ordered_strides {
                threaded_strides.push(s.clone());
            }
            let initial_offsets = vec![0isize; threaded_strides.len()];

            // Cost masking would normally prevent splitting on dims where the output stride
            // is 0, but for a complete reduction every output stride is 0, so all costs would
            // be masked to 0 and no dim could be split. The Julia implementation handles this
            // with the spacing mechanism (each task writes to its own slot), so we keep the
            // costs unmasked and splitting still works.

            mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &threaded_strides,
                &initial_offsets,
                &costs,
                nthreads,
                spacing as isize,
                1,
                &|dims, blocks, strides_list, offsets| {
                    // offsets[0] = spacing * (taskindex - 1) for threadedout
                    // offsets[1] = offset into src
                    let out_offset = offsets[0] as usize;
                    let src_offsets = &offsets[1..];

                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        &strides_list[1..],
                        src_offsets,
                        |offsets, len, strides| {
                            let mut ptr = unsafe { src_send.as_const().offset(offsets[0]) };
                            let stride = strides[0];
                            let slot = unsafe { &mut *threadedout_ptr.as_ptr().add(out_offset) };
                            for _ in 0..len {
                                let val = Op::apply(unsafe { *ptr });
                                let mapped = map_fn(val);
                                *slot = reduce_fn(slot.clone(), mapped);
                                unsafe {
                                    ptr = ptr.offset(stride);
                                }
                            }
                            Ok(())
                        },
                    )
                },
            )?;

            // Merge thread-local results
            let mut result = init;
            for i in 0..nthreads {
                result = reduce_fn(result, threadedout[i * spacing].clone());
            }
            return Ok(result);
        }
    }

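    // Sequential general path: blocked iteration over the fused dims, stepping a raw
    // pointer through each inner block.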
    let mut acc = init;
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let mut ptr = unsafe { src_ptr.offset(offsets[0]) };
            let stride = strides[0];
            for _ in 0..len {
                let val = Op::apply(unsafe { *ptr });
                let mapped = map_fn(val);
                acc = reduce_fn(acc.clone(), mapped);
                unsafe {
                    ptr = ptr.offset(stride);
                }
            }
            Ok(())
        },
    )?;

    Ok(acc)
}

/// Reduce along a single axis, returning a new StridedArray.
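///
/// # Example
///
/// A minimal sketch under the same assumption as the [`reduce`] example (a hypothetical
/// `view()` accessor on `StridedArray`): summing a column-major 2×3 array along axis 0
/// yields the three column sums.
///
/// ```ignore
/// let dims = [2usize, 3];
/// let strides = col_major_strides(&dims);
/// // Column-major data [1, 2, 3, 4, 5, 6] is the matrix [[1, 3, 5], [2, 4, 6]].
/// let arr = StridedArray::from_parts((1..=6).map(|x| x as f64).collect(), &dims, &strides, 0)?;
/// let col_sums = reduce_axis(&arr.view(), 0, |x| x, |a, b| a + b, 0.0)?;
/// // `col_sums` is a length-3 StridedArray holding [3.0, 7.0, 11.0].
/// ```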
pub fn reduce_axis<T: Copy + MaybeSendSync, Op: ElementOp<T>, M, R, U>(
    src: &StridedView<T, Op>,
    axis: usize,
    map_fn: M,
    reduce_fn: R,
    init: U,
) -> Result<StridedArray<U>>
where
    M: Fn(T) -> U + MaybeSync,
    R: Fn(U, U) -> U + MaybeSync,
    U: Clone + MaybeSendSync,
{
    let rank = src.ndim();
    if axis >= rank {
        return Err(StridedError::InvalidAxis { axis, rank });
    }

    let src_dims = src.dims();
    let src_strides = src.strides();
    let src_ptr = src.ptr();

    let out_dims: Vec<usize> = src_dims
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &d)| d)
        .collect();

    let axis_len = src_dims[axis];
    let axis_stride = src_strides[axis];

    if out_dims.is_empty() {
        // All dimensions reduced away: return a single-element array holding the scalar result.
        let mut acc = init;
        let mut offset = 0isize;
        for _ in 0..axis_len {
            let val = Op::apply(unsafe { *src_ptr.offset(offset) });
            let mapped = map_fn(val);
            acc = reduce_fn(acc, mapped);
            offset += axis_stride;
        }
        let strides = col_major_strides(&[1]);
        return StridedArray::from_parts(vec![acc], &[1], &strides, 0);
    }

    let total_out: usize = out_dims.iter().product();
    let out_strides = col_major_strides(&out_dims);
    let mut out =
        StridedArray::from_parts(vec![init.clone(); total_out], &out_dims, &out_strides, 0)?;

    // Build source strides for iteration over non-axis dimensions (same rank as out_dims)
    let src_kept_strides: Vec<isize> = src_strides
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &s)| s)
        .collect();

    let elem_size = std::mem::size_of::<T>().max(std::mem::size_of::<U>());
    let strides_list: [&[isize]; 2] = [&out_strides, &src_kept_strides];
    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(&out_dims, &strides_list, Some(0), elem_size);

    let out_ptr = out.view_mut().as_mut_ptr();

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let out_step = strides[0];
            let src_step = strides[1];

            // Fast path: when both output and source have stride 1, swap to
            // reduction-outer / output-inner with slices so LLVM can
            // auto-vectorize the contiguous inner loop.
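            // E.g. reducing a column-major (m, k) matrix along axis 1: the kept strides of
            // both output and source are 1, so each of the k reduction steps sweeps a
            // contiguous run of output elements.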
            if out_step == 1 && src_step == 1 && axis_len > 1 {
                let n = len as usize;
                let out_slice =
                    unsafe { std::slice::from_raw_parts_mut(out_ptr.offset(offsets[0]), n) };
                // Each output element already holds `init`, so every reduction step (including
                // the first) folds into it; this keeps the fast path consistent with the
                // general path below and the scalar case above.
                for k in 0..axis_len {
                    let src_k = unsafe {
                        std::slice::from_raw_parts(
                            src_ptr.offset(offsets[1] + k as isize * axis_stride),
                            n,
                        )
                    };
                    for i in 0..n {
                        out_slice[i] = reduce_fn(out_slice[i].clone(), map_fn(Op::apply(src_k[i])));
                    }
                }
                return Ok(());
            }

            // General path: output-outer, reduction-inner
            let mut out_off = offsets[0];
            let mut src_off = offsets[1];
            for _ in 0..len {
                let mut acc = init.clone();
                let mut ptr = unsafe { src_ptr.offset(src_off) };
                for _ in 0..axis_len {
                    let val = Op::apply(unsafe { *ptr });
                    let mapped = map_fn(val);
                    acc = reduce_fn(acc, mapped);
                    unsafe {
                        ptr = ptr.offset(axis_stride);
                    }
                }
                unsafe {
                    *out_ptr.offset(out_off) = acc;
                }
                out_off += out_step;
                src_off += src_step;
            }
            Ok(())
        },
    )?;

    Ok(out)
}