Skip to main content

strided_perm/hptt/
execute.rs

1//! Execution engine: recursive loop nest dispatching to macro_kernel.
2//!
3//! Mirrors HPTT C++'s `transpose_int` (lines 602-681) and
4//! `transpose_int_constStride1` (lines 683-720).
5
6use crate::hptt::macro_kernel::{
7    const_stride1_copy, macro_kernel_f32, macro_kernel_f64, macro_kernel_fallback,
8};
9use crate::hptt::plan::{ComputeNode, ExecMode, PermutePlan};
10
11#[cfg(feature = "parallel")]
12use rayon::iter::{IntoParallelIterator, ParallelIterator};
13
/// Element-count threshold for going parallel: a plan whose fused dimensions
/// multiply to fewer elements than this is executed on the calling thread.
#[cfg(feature = "parallel")]
const MINTHREADLENGTH: usize = 32_768; // 2^15
17
18/// Execute the permutation plan (single-threaded).
19///
20/// # Safety
21/// - `src` must be valid for reads at all offsets determined by dims/src_strides
22/// - `dst` must be valid for writes at all offsets determined by dims/dst_strides
23/// - src and dst must not overlap
24pub unsafe fn execute_permute_blocked<T: Copy>(src: *const T, dst: *mut T, plan: &PermutePlan) {
25    match plan.mode {
26        ExecMode::Scalar => {
27            *dst = *src;
28        }
29        ExecMode::ConstStride1 { inner_dim } => {
30            let count = plan.fused_dims[inner_dim];
31            let src_stride = plan.src_strides[inner_dim];
32            let dst_stride = plan.dst_strides[inner_dim];
33            match &plan.root {
34                Some(root) => {
35                    const_stride1_recursive(src, dst, root, count, src_stride, dst_stride);
36                }
37                None => {
38                    const_stride1_copy(src, dst, count, src_stride, dst_stride);
39                }
40            }
41        }
42        ExecMode::Transpose { dim_a, dim_b } => {
43            let size_a = plan.fused_dims[dim_a];
44            let size_b = plan.fused_dims[dim_b];
45            let lda = plan.lda_inner;
46            let ldb = plan.ldb_inner;
47            let block = plan.block;
48            let elem_size = std::mem::size_of::<T>();
49
50            match &plan.root {
51                Some(root) => {
52                    transpose_recursive(src, dst, root, size_a, size_b, lda, ldb, block, elem_size);
53                }
54                None => {
55                    // No outer loops — just the 2D blocked transpose
56                    dispatch_blocked_2d(src, dst, size_a, size_b, lda, ldb, block, elem_size);
57                }
58            }
59        }
60    }
61}
62
/// Execute the permutation plan, distributing the outermost loop over Rayon.
///
/// The `end` iterations of the root [`ComputeNode`] are spread over a parallel
/// iterator. Iteration `i` runs the rest of the loop nest (or the leaf kernel
/// directly when the root has no successor) on `src + i * root.lda` and
/// `dst + i * root.ldb`, both counted in elements.
///
/// Runs [`execute_permute_blocked`] on the calling thread instead when:
/// - the product of `plan.fused_dims` is below [`MINTHREADLENGTH`],
/// - the plan has no loop nest (`plan.root` is `None`),
/// - the root loop has at most one iteration, or
/// - the plan is in [`ExecMode::Scalar`] mode.
///
/// # Safety
/// Same requirements as [`execute_permute_blocked`]. Root-loop iterations run
/// concurrently, so they must additionally write disjoint destination elements.
/// NOTE(review): presumably any plan describing a permutation guarantees this —
/// confirm against the planner.
#[cfg(feature = "parallel")]
pub unsafe fn execute_permute_blocked_par<T: Copy + Send + Sync>(
    src: *const T,
    dst: *mut T,
    plan: &PermutePlan,
) {
    /// Base pointers bundled so the worker closure can capture them: raw
    /// pointers are neither `Send` nor `Sync`. Keeping them as pointers
    /// (instead of round-tripping through `usize`) preserves provenance.
    struct BasePtrs<P> {
        src: *const P,
        dst: *mut P,
    }

    // SAFETY: the wrapper only hands the raw pointers back out; every
    // dereference happens inside `unsafe` code covered by this function's
    // safety contract (valid pointees, disjoint writes across iterations).
    unsafe impl<P: Send + Sync> Send for BasePtrs<P> {}
    unsafe impl<P: Send + Sync> Sync for BasePtrs<P> {}

    // Accessed through methods so the closure captures the whole wrapper
    // rather than its individual raw-pointer fields.
    impl<P> BasePtrs<P> {
        fn src(&self) -> *const P {
            self.src
        }

        fn dst(&self) -> *mut P {
            self.dst
        }
    }

    let total: usize = plan.fused_dims.iter().product();

    if total < MINTHREADLENGTH {
        execute_permute_blocked(src, dst, plan);
        return;
    }

    let root = match &plan.root {
        Some(r) => r,
        None => {
            execute_permute_blocked(src, dst, plan);
            return;
        }
    };

    let outer_dim = root.end;
    if outer_dim <= 1 {
        execute_permute_blocked(src, dst, plan);
        return;
    }

    let base = BasePtrs { src, dst };
    let lda_root = root.lda;
    let ldb_root = root.ldb;
    // Borrow the remainder of the loop nest: the parallel loop joins before
    // this function returns, so no clone of the nest is needed.
    let inner = &root.next;

    match plan.mode {
        ExecMode::Transpose { dim_a, dim_b } => {
            let size_a = plan.fused_dims[dim_a];
            let size_b = plan.fused_dims[dim_b];
            let lda = plan.lda_inner;
            let ldb = plan.ldb_inner;
            let block = plan.block;
            let elem_size = std::mem::size_of::<T>();

            (0..outer_dim).into_par_iter().for_each(|i| {
                // `wrapping_offset` is always defined to compute and keeps the
                // provenance of the base pointers.
                let s = base.src().wrapping_offset((i as isize) * lda_root);
                let d = base.dst().wrapping_offset((i as isize) * ldb_root);

                // SAFETY: upheld by the caller per this function's contract.
                unsafe {
                    match inner {
                        Some(next) => {
                            transpose_recursive(
                                s, d, next, size_a, size_b, lda, ldb, block, elem_size,
                            );
                        }
                        None => {
                            dispatch_blocked_2d(s, d, size_a, size_b, lda, ldb, block, elem_size);
                        }
                    }
                }
            });
        }
        ExecMode::ConstStride1 { inner_dim } => {
            let count = plan.fused_dims[inner_dim];
            let src_stride = plan.src_strides[inner_dim];
            let dst_stride = plan.dst_strides[inner_dim];

            (0..outer_dim).into_par_iter().for_each(|i| {
                let s = base.src().wrapping_offset((i as isize) * lda_root);
                let d = base.dst().wrapping_offset((i as isize) * ldb_root);

                // SAFETY: upheld by the caller per this function's contract.
                unsafe {
                    match inner {
                        Some(next) => {
                            const_stride1_recursive(s, d, next, count, src_stride, dst_stride);
                        }
                        None => {
                            const_stride1_copy(s, d, count, src_stride, dst_stride);
                        }
                    }
                }
            });
        }
        ExecMode::Scalar => {
            execute_permute_blocked(src, dst, plan);
        }
    }
}
160
161// ---------------------------------------------------------------------------
162// Transpose mode: recursive execution
163// ---------------------------------------------------------------------------
164
165/// Recursive loop nest for Transpose mode.
166///
167/// Mirrors HPTT's `transpose_int`. Each ComputeNode iterates its dimension
168/// with inc=1. At the leaf, runs the 2D blocked transpose over dim_A × dim_B.
169unsafe fn transpose_recursive<T: Copy>(
170    src: *const T,
171    dst: *mut T,
172    node: &ComputeNode,
173    size_a: usize,
174    size_b: usize,
175    lda: isize,
176    ldb: isize,
177    block: usize,
178    elem_size: usize,
179) {
180    let end = node.end;
181    let node_lda = node.lda;
182    let node_ldb = node.ldb;
183
184    match &node.next {
185        Some(next) => {
186            let mut s = src;
187            let mut d = dst;
188            for _ in 0..end {
189                transpose_recursive(s, d, next, size_a, size_b, lda, ldb, block, elem_size);
190                s = s.offset(node_lda);
191                d = d.offset(node_ldb);
192            }
193        }
194        None => {
195            // Leaf: iterate this dim, calling blocked 2D transpose at each position
196            let mut s = src;
197            let mut d = dst;
198            for _ in 0..end {
199                dispatch_blocked_2d(s, d, size_a, size_b, lda, ldb, block, elem_size);
200                s = s.offset(node_lda);
201                d = d.offset(node_ldb);
202            }
203        }
204    }
205}
206
207/// 2D blocked transpose over dim_A × dim_B.
208///
209/// Tiles both dimensions by BLOCK and calls the appropriate macro_kernel.
210#[inline]
211unsafe fn dispatch_blocked_2d<T: Copy>(
212    src: *const T,
213    dst: *mut T,
214    size_a: usize,
215    size_b: usize,
216    lda: isize,
217    ldb: isize,
218    block: usize,
219    elem_size: usize,
220) {
221    match elem_size {
222        8 => blocked_transpose_2d_f64(
223            src as *const f64,
224            dst as *mut f64,
225            size_a,
226            size_b,
227            lda,
228            ldb,
229            block,
230        ),
231        4 => blocked_transpose_2d_f32(
232            src as *const f32,
233            dst as *mut f32,
234            size_a,
235            size_b,
236            lda,
237            ldb,
238            block,
239        ),
240        _ => blocked_transpose_2d_fallback(src, dst, size_a, size_b, lda, ldb, block),
241    }
242}
243
244#[inline]
245unsafe fn blocked_transpose_2d_f64(
246    src: *const f64,
247    dst: *mut f64,
248    size_a: usize,
249    size_b: usize,
250    lda: isize,
251    ldb: isize,
252    block: usize,
253) {
254    let mut ib = 0usize;
255    while ib < size_b {
256        let bb = block.min(size_b - ib);
257        let mut ia = 0usize;
258        while ia < size_a {
259            let ba = block.min(size_a - ia);
260            macro_kernel_f64(
261                src.offset(ia as isize + ib as isize * lda),
262                lda,
263                ba,
264                dst.offset(ib as isize + ia as isize * ldb),
265                ldb,
266                bb,
267            );
268            ia += block;
269        }
270        ib += block;
271    }
272}
273
274#[inline]
275unsafe fn blocked_transpose_2d_f32(
276    src: *const f32,
277    dst: *mut f32,
278    size_a: usize,
279    size_b: usize,
280    lda: isize,
281    ldb: isize,
282    block: usize,
283) {
284    let mut ib = 0usize;
285    while ib < size_b {
286        let bb = block.min(size_b - ib);
287        let mut ia = 0usize;
288        while ia < size_a {
289            let ba = block.min(size_a - ia);
290            macro_kernel_f32(
291                src.offset(ia as isize + ib as isize * lda),
292                lda,
293                ba,
294                dst.offset(ib as isize + ia as isize * ldb),
295                ldb,
296                bb,
297            );
298            ia += block;
299        }
300        ib += block;
301    }
302}
303
304#[inline]
305unsafe fn blocked_transpose_2d_fallback<T: Copy>(
306    src: *const T,
307    dst: *mut T,
308    size_a: usize,
309    size_b: usize,
310    lda: isize,
311    ldb: isize,
312    block: usize,
313) {
314    let mut ib = 0usize;
315    while ib < size_b {
316        let bb = block.min(size_b - ib);
317        let mut ia = 0usize;
318        while ia < size_a {
319            let ba = block.min(size_a - ia);
320            macro_kernel_fallback(
321                src.offset(ia as isize + ib as isize * lda),
322                lda,
323                ba,
324                dst.offset(ib as isize + ia as isize * ldb),
325                ldb,
326                bb,
327            );
328            ia += block;
329        }
330        ib += block;
331    }
332}
333
334// ---------------------------------------------------------------------------
335// ConstStride1 mode: recursive execution
336// ---------------------------------------------------------------------------
337
338/// Recursive loop nest for ConstStride1 mode.
339///
340/// Mirrors HPTT's `transpose_int_constStride1`. Each ComputeNode iterates
341/// its dimension. At the leaf, calls `const_stride1_copy` for the inner dim.
342unsafe fn const_stride1_recursive<T: Copy>(
343    src: *const T,
344    dst: *mut T,
345    node: &ComputeNode,
346    count: usize,
347    src_stride: isize,
348    dst_stride: isize,
349) {
350    let end = node.end;
351    let node_lda = node.lda;
352    let node_ldb = node.ldb;
353
354    match &node.next {
355        Some(next) => {
356            let mut s = src;
357            let mut d = dst;
358            for _ in 0..end {
359                const_stride1_recursive(s, d, next, count, src_stride, dst_stride);
360                s = s.offset(node_lda);
361                d = d.offset(node_ldb);
362            }
363        }
364        None => {
365            let mut s = src;
366            let mut d = dst;
367            for _ in 0..end {
368                const_stride1_copy(s, d, count, src_stride, dst_stride);
369                s = s.offset(node_lda);
370                d = d.offset(node_ldb);
371            }
372        }
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hptt::plan::build_permute_plan;

    /// Source buffer holding `0.0, 1.0, …, (len - 1) as f64`.
    fn ramp(len: usize) -> Vec<f64> {
        (0..len).map(|i| i as f64).collect()
    }

    /// Equal source and destination strides: the output must equal the input.
    #[test]
    fn test_execute_identity_copy() {
        let src = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut dst = vec![0.0f64; 6];
        let plan = build_permute_plan(&[2, 3], &[1, 2], &[1, 2], 8);
        unsafe { execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan) };
        assert_eq!(dst, src);
    }

    /// 2D transpose.
    ///
    /// Source is `[3, 2]` col-major holding `[1,2,3,4,5,6]`; the permuted view
    /// has dims `[2, 3]` and strides `[3, 1]`; the destination is col-major
    /// `[2, 3]` with strides `[1, 2]`.
    #[test]
    fn test_execute_transpose_2d() {
        let src = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut dst = vec![0.0f64; 6];
        let plan = build_permute_plan(&[2, 3], &[3, 1], &[1, 2], 8);
        unsafe { execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan) };
        assert_eq!(dst, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// 3D permutation `[2,0,1]` of a col-major `[2,3,4]` source.
    ///
    /// Permuted view: dims `[4,2,3]`, strides `[6,1,2]`; destination is
    /// col-major with strides `[1,4,8]`.
    #[test]
    fn test_execute_3d_permute() {
        let total: usize = [2usize, 3, 4].iter().product();
        let src = ramp(total);
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[4, 2, 3], &[6, 1, 2], &[1, 4, 8], 8);
        unsafe { execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan) };

        for k in 0..4 {
            for i in 0..2 {
                for j in 0..3 {
                    let (dst_idx, src_idx) = (k + i * 4 + j * 8, i + j * 2 + k * 6);
                    assert_eq!(
                        dst[dst_idx], src[src_idx],
                        "mismatch at k={k}, i={i}, j={j}"
                    );
                }
            }
        }
    }

    /// 4D permutation `[3,1,0,2]` of a col-major `[2,3,4,5]` source.
    ///
    /// Permuted view: dims `[5,3,2,4]`, strides `[24,2,1,6]`; destination
    /// strides `[1,5,15,30]`.
    #[test]
    fn test_execute_4d_permute() {
        let total: usize = [2usize, 3, 4, 5].iter().product();
        let src = ramp(total);
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[5, 3, 2, 4], &[24, 2, 1, 6], &[1, 5, 15, 30], 8);
        unsafe { execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan) };

        for i0 in 0..5 {
            for i1 in 0..3 {
                for i2 in 0..2 {
                    for i3 in 0..4 {
                        let src_idx = i0 * 24 + i1 * 2 + i2 + i3 * 6;
                        let dst_idx = i0 + i1 * 5 + i2 * 15 + i3 * 30;
                        assert_eq!(
                            dst[dst_idx], src[src_idx],
                            "4D mismatch at ({i0},{i1},{i2},{i3})"
                        );
                    }
                }
            }
        }
    }

    /// 5D permutation `[4,0,1,2,3]` of a col-major `[2,2,2,2,3]` source.
    ///
    /// Permuted view: dims `[3,2,2,2,2]`, strides `[16,1,2,4,8]`; destination
    /// strides `[1,3,6,12,24]`.
    #[test]
    fn test_execute_5d_permute() {
        let total: usize = [2usize, 2, 2, 2, 3].iter().product();
        let src = ramp(total);
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[3, 2, 2, 2, 2], &[16, 1, 2, 4, 8], &[1, 3, 6, 12, 24], 8);
        unsafe { execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan) };

        for i0 in 0..3 {
            for i1 in 0..2 {
                for i2 in 0..2 {
                    for i3 in 0..2 {
                        for i4 in 0..2 {
                            let src_idx = i0 * 16 + i1 + i2 * 2 + i3 * 4 + i4 * 8;
                            let dst_idx = i0 + i1 * 3 + i2 * 6 + i3 * 12 + i4 * 24;
                            assert_eq!(
                                dst[dst_idx], src[src_idx],
                                "5D mismatch at ({i0},{i1},{i2},{i3},{i4})"
                            );
                        }
                    }
                }
            }
        }
    }

    /// Rank-0 plan: exactly one element is copied.
    #[test]
    fn test_execute_rank0_scalar() {
        let src = vec![42.0f64];
        let mut dst = vec![0.0f64];
        let plan = build_permute_plan(&[], &[], &[], 8);
        unsafe { execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan) };
        assert_eq!(dst[0], 42.0);
    }

    /// Parallel entry point on a 6-element input; this is below
    /// `MINTHREADLENGTH`, so the single-threaded path does the work.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_execute_par_transpose_2d() {
        let src = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut dst = vec![0.0f64; 6];
        let plan = build_permute_plan(&[2, 3], &[3, 1], &[1, 2], 8);
        unsafe { execute_permute_blocked_par(src.as_ptr(), dst.as_mut_ptr(), &plan) };
        assert_eq!(dst, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// Parallel entry point on a `[256, 256, 256]` col-major source with
    /// permutation `[2, 0, 1]`; spot-checks the corners and a middle index of
    /// every axis.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_execute_par_large() {
        let n = 256;
        let total = n * n * n;
        let src = ramp(total);
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[n, n, n], &[65536, 1, 256], &[1, 256, 65536], 8);
        unsafe { execute_permute_blocked_par(src.as_ptr(), dst.as_mut_ptr(), &plan) };

        for i0 in [0, 1, 127, 255] {
            for i1 in [0, 1, 127, 255] {
                for i2 in [0, 1, 127, 255] {
                    let dst_idx = i0 + i1 * n + i2 * n * n;
                    let src_idx = i0 * 65536 + i1 + i2 * 256;
                    assert_eq!(
                        dst[dst_idx], src[src_idx],
                        "mismatch at i0={i0}, i1={i1}, i2={i2}"
                    );
                }
            }
        }
    }
}