Skip to main content

strided_perm/
kernel.rs

1//! Block-based iteration engine for strided permutation operations.
2
3use crate::fuse::{compress_dims, fuse_dims};
4use crate::{block, order};
5use strided_view::Result;
6
/// Largest total element count that still qualifies for the small-tensor
/// fast path (1024 elements).
pub const SMALL_TENSOR_THRESHOLD: usize = 1 << 10;
9
/// Iteration plan produced by the plan builders in this module.
///
/// Both vectors have one entry per (fused) dimension of the dimensions that
/// are returned alongside the plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelPlan {
    /// Iteration order of the dimensions. The builders in this file always
    /// store the identity permutation `0..rank` here.
    pub order: Vec<usize>,
    /// Block length selected for each dimension.
    pub block: Vec<usize>,
}
14
15/// Build an execution plan with dimension fusion.
16///
17/// Pipeline: order -> reorder -> fuse -> block.
18pub fn build_plan_fused(
19    dims: &[usize],
20    strides_list: &[&[isize]],
21    dest_index: Option<usize>,
22    elem_size: usize,
23) -> (Vec<usize>, Vec<Vec<isize>>, KernelPlan) {
24    let order = order::compute_order(dims, strides_list, dest_index);
25
26    let ordered_dims: Vec<usize> = order.iter().map(|&d| dims[d]).collect();
27    let ordered_strides: Vec<Vec<isize>> = strides_list
28        .iter()
29        .map(|strides| order.iter().map(|&d| strides[d]).collect())
30        .collect();
31    let ordered_strides_refs: Vec<&[isize]> =
32        ordered_strides.iter().map(|s| s.as_slice()).collect();
33
34    let fused_dims = fuse_dims(&ordered_dims, &ordered_strides_refs);
35    let (compressed_dims, compressed_strides) = compress_dims(&fused_dims, &ordered_strides);
36    let compressed_strides_refs: Vec<&[isize]> =
37        compressed_strides.iter().map(|s| s.as_slice()).collect();
38
39    let identity: Vec<usize> = (0..compressed_dims.len()).collect();
40    let block = block::compute_block_sizes(
41        &compressed_dims,
42        &identity,
43        &compressed_strides_refs,
44        elem_size,
45    );
46
47    (
48        compressed_dims,
49        compressed_strides,
50        KernelPlan {
51            order: identity,
52            block,
53        },
54    )
55}
56
57/// Simplified plan for small tensors that fit in L1 cache.
58pub fn build_plan_fused_small(
59    dims: &[usize],
60    strides_list: &[&[isize]],
61) -> (Vec<usize>, Vec<Vec<isize>>, KernelPlan) {
62    let strides_owned: Vec<Vec<isize>> = strides_list.iter().map(|s| s.to_vec()).collect();
63
64    let fused = fuse_dims(dims, strides_list);
65    let (fused_dims, fused_strides) = compress_dims(&fused, &strides_owned);
66
67    let block = fused_dims.clone();
68    let identity: Vec<usize> = (0..fused_dims.len()).collect();
69
70    (
71        fused_dims,
72        fused_strides,
73        KernelPlan {
74            order: identity,
75            block,
76        },
77    )
78}
79
80// ============================================================================
81// Macro-generated rank-specialized kernels
82// ============================================================================
83
// Expands to the nest of per-element loops executed inside one block.
//
// Invocation: `elem_loops!(offsets, strides, f, blens, inner_strides; Ltop, ..., L1)`
// with the levels listed outermost first. Every listed level `L` iterates
// `blens[L]` times. The innermost listed level calls
// `f(offsets, blens[0], &inner_strides)?` once per iteration and then moves
// each offset by that array's stride for the level. An enclosing level rewinds
// the level directly below it before taking its own single step.
//
// After the expansion has run, each offset has advanced by
// `blens[Ltop] * stride[Ltop]`; undoing that is left to the caller.
macro_rules! elem_loops {
    // Innermost listed level: hand one run to the callback, then step.
    ($off:ident, $st:ident, $call:ident, $len:ident, $unit:ident; $level:literal) => {
        for _ in 0..$len[$level] {
            $call($off, $len[0], &$unit)?;
            for (pos, step) in $off.iter_mut().zip($st.iter()) {
                *pos += step[$level];
            }
        }
    };
    // Enclosing level: run the inner nest, rewind it, then step this level.
    ($off:ident, $st:ident, $call:ident, $len:ident, $unit:ident;
     $level:literal, $inner:literal $(, $deeper:literal)*) => {
        for _ in 0..$len[$level] {
            elem_loops!($off, $st, $call, $len, $unit; $inner $(, $deeper)*);
            for (pos, step) in $off.iter_mut().zip($st.iter()) {
                *pos -= ($len[$inner] as isize) * step[$inner];
                *pos += step[$level];
            }
        }
    };
}
104
// Expands to the nest of block loops, one `while` per level.
//
// Invocation:
// `block_loop!(dims, blocks, strides, offsets, f, blens, inner_strides;
//              elem=[...]; Ltop, ..., L0; top=Ltop)`
// with the block levels listed outermost first. Each level walks its dimension
// in chunks of `blocks[L]` (at least 1, clipped to what is left of the
// dimension) and records the current chunk length in `blens[L]`. The last
// listed level runs `elem_loops!` over the `elem` levels for every chunk.
//
// Offset bookkeeping: after its inner work, a level rewinds whatever that
// inner work left behind and then advances by its own chunk. Once the
// expansion has run, the outermost level is left advanced by its full extent;
// undoing that is left to the caller.
macro_rules! block_loop {
    // Last block level: run the element loops for each chunk.
    ($extent:ident, $blk:ident, $st:ident, $off:ident, $call:ident,
     $len:ident, $unit:ident; elem=[$($e:literal),+]; $last:literal; top=$hi:literal) => {{
        let mut done = 0usize;
        while done < $extent[$last] {
            // `extent - done` never exceeds `extent`, so one clip is enough.
            $len[$last] = $blk[$last].max(1).min($extent[$last] - done);
            elem_loops!($off, $st, $call, $len, $unit; $($e),+);
            for (pos, step) in $off.iter_mut().zip($st.iter()) {
                // The element loops leave the `top` level advanced by one block.
                *pos -= ($len[$hi] as isize) * step[$hi];
                *pos += ($len[$last] as isize) * step[$last];
            }
            done += $len[$last];
        }
    }};
    // Enclosing block level: recurse, rewind the inner level, advance this one.
    ($extent:ident, $blk:ident, $st:ident, $off:ident, $call:ident,
     $len:ident, $unit:ident; elem=[$($e:literal),+];
     $level:literal, $inner:literal $(, $deeper:literal)*; top=$hi:literal) => {{
        let mut done = 0usize;
        while done < $extent[$level] {
            $len[$level] = $blk[$level].max(1).min($extent[$level] - done);
            block_loop!($extent, $blk, $st, $off, $call, $len, $unit;
                elem=[$($e),+]; $inner $(, $deeper)*; top=$hi);
            for (pos, step) in $off.iter_mut().zip($st.iter()) {
                // The inner level has walked its entire dimension.
                *pos -= ($extent[$inner] as isize) * step[$inner];
                *pos += ($len[$level] as isize) * step[$level];
            }
            done += $len[$level];
        }
    }};
}
135
// Generates one rank-specialised kernel function.
//
// Every generated function has the signature
// `fn name<F>(dims, blocks, strides, offsets, f) -> Result<()>`, where `f` is
// called as `f(offsets, run_len, inner_strides)`: the current per-array
// offsets, the number of elements in the run along dimension 0, and each
// array's stride along dimension 0. If `f` fails, the error is returned at
// once and `offsets` is left wherever the walk stopped; on success `offsets`
// is restored to its value on entry.
macro_rules! make_kernel {
    // Rank 1: a single chunked sweep over dimension 0.
    ($name:ident, rank=1) => {
        #[inline]
        fn $name<F>(
            dims: &[usize],
            blocks: &[usize],
            strides: &[Vec<isize>],
            offsets: &mut [isize],
            f: &mut F,
        ) -> Result<()>
        where
            F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
        {
            let extent = dims[0];
            // Chunk length: at least 1, never longer than the dimension.
            let chunk = blocks[0].max(1).min(extent);
            let unit_strides: Vec<isize> = strides.iter().map(|per_array| per_array[0]).collect();

            let mut done = 0usize;
            while done < extent {
                let run = chunk.min(extent - done);
                f(offsets, run, &unit_strides)?;
                for (pos, per_array) in offsets.iter_mut().zip(strides.iter()) {
                    *pos += (run as isize) * per_array[0];
                }
                done += run;
            }

            // Put the offsets back where they started.
            for (pos, per_array) in offsets.iter_mut().zip(strides.iter()) {
                *pos -= (extent as isize) * per_array[0];
            }
            Ok(())
        }
    };
    // Rank >= 2: nested block loops around nested element loops.
    ($name:ident, rank=$rank:literal,
     block=[$($blk:literal),+], elem=[$($el:literal),+], top=$top:literal) => {
        #[inline]
        fn $name<F>(
            dims: &[usize],
            blocks: &[usize],
            strides: &[Vec<isize>],
            offsets: &mut [isize],
            f: &mut F,
        ) -> Result<()>
        where
            F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
        {
            let unit_strides: Vec<isize> = strides.iter().map(|per_array| per_array[0]).collect();
            // Current block length per level, filled in by `block_loop!`.
            let mut block_len = [0usize; $rank];
            block_loop!(dims, blocks, strides, offsets, f, block_len, unit_strides;
                elem=[$($el),+]; $($blk),+; top=$top);
            // `block_loop!` leaves the outermost level advanced by its extent.
            for (pos, per_array) in offsets.iter_mut().zip(strides.iter()) {
                *pos -= (dims[$top] as isize) * per_array[$top];
            }
            Ok(())
        }
    };
}
191
192make_kernel!(kernel_1d_inner, rank = 1);
193make_kernel!(
194    kernel_2d_inner,
195    rank = 2,
196    block = [1, 0],
197    elem = [1],
198    top = 1
199);
200make_kernel!(
201    kernel_3d_inner,
202    rank = 3,
203    block = [2, 1, 0],
204    elem = [2, 1],
205    top = 2
206);
207make_kernel!(
208    kernel_4d_inner,
209    rank = 4,
210    block = [3, 2, 1, 0],
211    elem = [3, 2, 1],
212    top = 3
213);
214make_kernel!(
215    kernel_5d_inner,
216    rank = 5,
217    block = [4, 3, 2, 1, 0],
218    elem = [4, 3, 2, 1],
219    top = 4
220);
221make_kernel!(
222    kernel_6d_inner,
223    rank = 6,
224    block = [5, 4, 3, 2, 1, 0],
225    elem = [5, 4, 3, 2, 1],
226    top = 5
227);
228make_kernel!(
229    kernel_7d_inner,
230    rank = 7,
231    block = [6, 5, 4, 3, 2, 1, 0],
232    elem = [6, 5, 4, 3, 2, 1],
233    top = 6
234);
235make_kernel!(
236    kernel_8d_inner,
237    rank = 8,
238    block = [7, 6, 5, 4, 3, 2, 1, 0],
239    elem = [7, 6, 5, 4, 3, 2, 1],
240    top = 7
241);
242
243/// N-dimensional kernel (iterative form, fallback for rank >= 9).
244#[inline]
245fn kernel_nd_inner_iterative<F>(
246    dims: &[usize],
247    blocks: &[usize],
248    strides: &[Vec<isize>],
249    offsets: &mut [isize],
250    f: &mut F,
251) -> Result<()>
252where
253    F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
254{
255    let rank = dims.len();
256    debug_assert!(rank >= 9);
257
258    let d0 = dims[0];
259    let b0 = blocks[0].max(1).min(d0);
260    let inner_strides: Vec<isize> = strides.iter().map(|s| s[0]).collect();
261
262    let mut idx = vec![0usize; rank];
263
264    loop {
265        let mut j0 = 0usize;
266        while j0 < d0 {
267            let blen0 = b0.min(d0 - j0);
268            f(offsets, blen0, &inner_strides)?;
269            for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
270                *offset += (blen0 as isize) * s[0];
271            }
272            j0 += blen0;
273        }
274        for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
275            *offset -= (d0 as isize) * s[0];
276        }
277
278        let mut level = 1usize;
279        loop {
280            for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
281                *offset += s[level];
282            }
283            idx[level] += 1;
284            if idx[level] < dims[level] {
285                break;
286            }
287
288            idx[level] = 0;
289            for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
290                *offset -= (dims[level] as isize) * s[level];
291            }
292            level += 1;
293            if level == rank {
294                return Ok(());
295            }
296        }
297    }
298}
299
300/// Iterate over blocks with pre-ordered dimensions and initial offsets.
301#[inline]
302pub fn for_each_inner_block_preordered<F>(
303    dims: &[usize],
304    blocks: &[usize],
305    strides: &[Vec<isize>],
306    initial_offsets: &[isize],
307    mut f: F,
308) -> Result<()>
309where
310    F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
311{
312    let rank = dims.len();
313    if rank == 0 {
314        return f(initial_offsets, 1, &[]);
315    }
316
317    let mut offsets = initial_offsets.to_vec();
318
319    match rank {
320        1 => kernel_1d_inner(dims, blocks, strides, &mut offsets, &mut f),
321        2 => kernel_2d_inner(dims, blocks, strides, &mut offsets, &mut f),
322        3 => kernel_3d_inner(dims, blocks, strides, &mut offsets, &mut f),
323        4 => kernel_4d_inner(dims, blocks, strides, &mut offsets, &mut f),
324        5 => kernel_5d_inner(dims, blocks, strides, &mut offsets, &mut f),
325        6 => kernel_6d_inner(dims, blocks, strides, &mut offsets, &mut f),
326        7 => kernel_7d_inner(dims, blocks, strides, &mut offsets, &mut f),
327        8 => kernel_8d_inner(dims, blocks, strides, &mut offsets, &mut f),
328        _ => kernel_nd_inner_iterative(dims, blocks, strides, &mut offsets, &mut f),
329    }
330}
331
/// Utility: total number of elements, i.e. the product of all extents.
///
/// An empty `dims` (rank 0) yields 1, because the product of an empty
/// sequence is 1.
#[inline]
pub fn total_len(dims: &[usize]) -> usize {
    dims.iter().copied().product()
}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a single column-major array of the given shape and returns the
    /// sum of all run lengths reported to the callback.
    fn count_elements(dims: &[usize], blocks: &[usize]) -> usize {
        // Column-major strides: 1, d0, d0*d1, ...
        let mut next_stride = 1isize;
        let column_major: Vec<isize> = dims
            .iter()
            .map(|&extent| {
                let here = next_stride;
                next_stride *= extent as isize;
                here
            })
            .collect();
        let strides = vec![column_major];
        let offsets = vec![0isize];

        let mut visited = 0usize;
        for_each_inner_block_preordered(dims, blocks, &strides, &offsets, |_, run, _| {
            visited += run;
            Ok(())
        })
        .unwrap();
        visited
    }

    // ---- plan builders -----------------------------------------------------

    #[test]
    fn test_build_plan_fused_compresses() {
        let shape = [2usize, 3];
        let col_major = [1isize, 2];
        let arrays: Vec<&[isize]> = vec![&col_major];
        let (out_dims, out_strides, plan) = build_plan_fused(&shape, &arrays, Some(0), 8);
        assert_eq!(out_dims, vec![6]);
        assert_eq!(out_strides.len(), 1);
        assert_eq!(out_strides[0], vec![1]);
        assert_eq!(plan.block.len(), 1);
    }

    #[test]
    fn test_build_plan_fused_non_contiguous() {
        // Row-major [4,5]: strides [5, 1]
        let shape = [4usize, 5];
        let row_major = [5isize, 1];
        let arrays: Vec<&[isize]> = vec![&row_major];
        let (out_dims, out_strides, plan) = build_plan_fused(&shape, &arrays, Some(0), 8);
        // Row-major: should reorder so stride-1 is first, then fuse to 20
        assert_eq!(out_dims, vec![20]);
        assert_eq!(out_strides[0], vec![1]);
        assert_eq!(plan.block.len(), 1);
    }

    #[test]
    fn test_build_plan_fused_multi_array() {
        // Two arrays: one col-major, one row-major
        let shape = [4usize, 5];
        let dst = [1isize, 4];
        let src = [5isize, 1];
        let arrays: Vec<&[isize]> = vec![&dst, &src];
        let (out_dims, out_strides, plan) = build_plan_fused(&shape, &arrays, Some(0), 8);
        // Conflicting strides means no fusion possible
        assert_eq!(out_dims.len(), 2);
        assert_eq!(out_strides.len(), 2);
        assert!(!plan.block.is_empty());
    }

    #[test]
    fn test_build_plan_fused_small_basic() {
        let shape = [2usize, 3];
        let col_major = [1isize, 2];
        let arrays: Vec<&[isize]> = vec![&col_major];
        let (out_dims, out_strides, plan) = build_plan_fused_small(&shape, &arrays);
        assert_eq!(out_dims, vec![6]);
        assert_eq!(out_strides[0], vec![1]);
        // Small plan: block = dims (no blocking)
        assert_eq!(plan.block, out_dims);
    }

    #[test]
    fn test_build_plan_fused_small_non_contiguous() {
        let shape = [4usize, 5];
        let row_major = [5isize, 1];
        let arrays: Vec<&[isize]> = vec![&row_major];
        let (out_dims, out_strides, plan) = build_plan_fused_small(&shape, &arrays);
        // No ordering in small path, but fusion should still work on contiguous groups
        assert!(!out_dims.is_empty());
        assert_eq!(out_strides.len(), 1);
        assert_eq!(plan.block, out_dims);
    }

    // ---- block iteration ---------------------------------------------------

    #[test]
    fn test_for_each_inner_block_preordered_total() {
        let shape = [3usize, 4, 2];
        let block_sizes = [3usize, 4, 2];
        let strides: Vec<Vec<isize>> = vec![vec![1, 3, 12]; 2];
        let start = [0isize, 0];
        let mut visited = 0usize;
        for_each_inner_block_preordered(&shape, &block_sizes, &strides, &start, |_, run, _| {
            visited += run;
            Ok(())
        })
        .unwrap();
        assert_eq!(visited, 24);
    }

    #[test]
    fn test_for_each_rank0() {
        let shape: [usize; 0] = [];
        let block_sizes: [usize; 0] = [];
        let strides: Vec<Vec<isize>> = vec![Vec::new()];
        let start = [0isize];
        let mut invoked = false;
        for_each_inner_block_preordered(&shape, &block_sizes, &strides, &start, |_, run, _| {
            invoked = true;
            assert_eq!(run, 1);
            Ok(())
        })
        .unwrap();
        assert!(invoked);
    }

    #[test]
    fn test_for_each_rank4() {
        let shape = [2, 3, 4, 5];
        assert_eq!(count_elements(&shape, &shape), 120);
    }

    #[test]
    fn test_for_each_rank5() {
        let shape = [2, 2, 2, 2, 2];
        assert_eq!(count_elements(&shape, &shape), 32);
    }

    #[test]
    fn test_for_each_rank6() {
        let shape = [2, 2, 2, 2, 2, 3];
        assert_eq!(count_elements(&shape, &shape), 96);
    }

    #[test]
    fn test_for_each_rank7() {
        let shape = [2, 2, 2, 2, 2, 2, 3];
        assert_eq!(count_elements(&shape, &shape), 192);
    }

    #[test]
    fn test_for_each_rank8() {
        let shape = [2, 2, 2, 2, 2, 2, 2, 3];
        assert_eq!(count_elements(&shape, &shape), 384);
    }

    #[test]
    fn test_for_each_rank9_iterative() {
        // Rank 9 triggers kernel_nd_inner_iterative
        let shape = [2, 2, 2, 2, 2, 2, 2, 2, 3];
        assert_eq!(count_elements(&shape, &shape), 768);
    }

    #[test]
    fn test_for_each_rank10_iterative() {
        let shape = [2, 2, 2, 2, 2, 2, 2, 2, 2, 3];
        assert_eq!(count_elements(&shape, &shape), 1536);
    }

    // ---- total_len ---------------------------------------------------------

    #[test]
    fn test_total_len_empty() {
        assert_eq!(total_len(&[]), 1);
    }

    #[test]
    fn test_total_len_basic() {
        assert_eq!(total_len(&[2, 3, 4]), 24);
    }
}
523}