Skip to main content

strided_perm/hptt/
plan.rs

1//! Plan construction for HPTT-faithful tensor permutation.
2//!
3//! Mirrors HPTT C++'s plan construction: bilateral fusion → identify stride-1
4//! dims → determine execution mode → compute loop order → build ComputeNode chain.
5
6use crate::fuse::fuse_dims_bilateral;
7use crate::hptt::micro_kernel::{MicroKernel, ScalarKernel};
8
/// One level of the loop nest that drives a permutation.
///
/// Mirrors HPTT's `ComputeNode` linked list: the head of the chain is the
/// outermost loop and the node whose `next` is `None` is the innermost one.
/// Each level iterates `0..end`; `lda` / `ldb` are the source / destination
/// strides of the dimension that level walks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComputeNode {
    /// Exclusive upper bound of this loop (the loop runs `0..end`).
    pub end: usize,
    /// Source stride of the dimension this level iterates over.
    pub lda: isize,
    /// Destination stride of the dimension this level iterates over.
    pub ldb: isize,
    /// Next (more inner) level; `None` means the leaf, which calls the
    /// macro-kernel or the contiguous copy.
    pub next: Option<Box<ComputeNode>>,
}
24
/// How the plan executes; decided once when the plan is built.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExecMode {
    /// The source and destination stride-1 dimensions differ (`dim_A != dim_B`),
    /// so the leaf runs the 2D micro-kernel transpose.
    Transpose {
        /// Dimension (extent > 1) with the smallest |source stride|.
        dim_a: usize,
        /// Dimension (extent > 1) with the smallest |destination stride|.
        dim_b: usize,
    },
    /// Source and destination share their stride-1 dimension (`dim_A == dim_B`,
    /// HPTT's `perm[0] == 0` case), so the leaf is a memcpy / strided copy.
    ConstStride1 {
        /// The dimension that is stride-1 in both layouts.
        inner_dim: usize,
    },
    /// The fused shape has rank 0: a single element is copied.
    Scalar,
}
43
44/// Complete permutation plan.
45#[derive(Debug)]
46pub struct PermutePlan {
47    /// Fused dimensions (after bilateral fusion).
48    pub fused_dims: Vec<usize>,
49    /// Fused source strides.
50    pub src_strides: Vec<isize>,
51    /// Fused destination strides.
52    pub dst_strides: Vec<isize>,
53    /// Root of the recursive loop structure (None for Scalar mode).
54    pub root: Option<ComputeNode>,
55    /// Execution mode.
56    pub mode: ExecMode,
57    /// Source stride along dim_B — the "lda" for macro_kernel.
58    /// (In the 2D view for the macro-kernel, this is the stride that
59    /// steps between columns of the source tile.)
60    pub lda_inner: isize,
61    /// Dest stride along dim_A — the "ldb" for macro_kernel.
62    pub ldb_inner: isize,
63    /// Macro-kernel tile size (= BLOCK, e.g. 16 for f64).
64    pub block: usize,
65}
66
67/// Build a permutation plan using bilateral fusion and HPTT-style blocking.
68///
69/// This is the main entry point. The returned plan is consumed by
70/// `execute_permute_blocked`.
71pub fn build_permute_plan(
72    dims: &[usize],
73    src_strides: &[isize],
74    dst_strides: &[isize],
75    elem_size: usize,
76) -> PermutePlan {
77    // Phase 1: Bilateral dimension fusion
78    let (fused_dims, fused_src, fused_dst) = fuse_dims_bilateral(dims, src_strides, dst_strides);
79
80    let rank = fused_dims.len();
81    if rank == 0 {
82        return PermutePlan {
83            fused_dims,
84            src_strides: fused_src,
85            dst_strides: fused_dst,
86            root: None,
87            mode: ExecMode::Scalar,
88            lda_inner: 0,
89            ldb_inner: 0,
90            block: 0,
91        };
92    }
93
94    // Phase 2: Identify stride-1 dimensions
95    let dim_a = find_stride1_dim(&fused_dims, &fused_src);
96    let dim_b = find_stride1_dim(&fused_dims, &fused_dst);
97
98    // Phase 3: Determine execution mode and blocking
99    let block = block_for_elem_size(elem_size);
100
101    if dim_a == dim_b {
102        // ConstStride1 path: both stride-1 dims are the same
103        let inner_dim = dim_a;
104        let mode = ExecMode::ConstStride1 { inner_dim };
105
106        let loop_order = compute_loop_order_const(&fused_dims, &fused_src, &fused_dst, inner_dim);
107        let root = build_compute_nodes(&fused_dims, &fused_src, &fused_dst, &loop_order);
108
109        PermutePlan {
110            fused_dims,
111            src_strides: fused_src.clone(),
112            dst_strides: fused_dst.clone(),
113            root,
114            mode,
115            lda_inner: fused_src[inner_dim],
116            ldb_inner: fused_dst[inner_dim],
117            block: 0,
118        }
119    } else {
120        // Transpose path: 2D micro-kernel
121        let mode = ExecMode::Transpose { dim_a, dim_b };
122
123        // lda_inner = src stride along dim_B (steps between rows in the 2D micro-kernel view)
124        // ldb_inner = dst stride along dim_A (steps between rows in the transposed view)
125        let lda_inner = fused_src[dim_b];
126        let ldb_inner = fused_dst[dim_a];
127
128        let loop_order =
129            compute_loop_order_transpose(&fused_dims, &fused_src, &fused_dst, dim_a, dim_b);
130        let root = build_compute_nodes(&fused_dims, &fused_src, &fused_dst, &loop_order);
131
132        PermutePlan {
133            fused_dims,
134            src_strides: fused_src,
135            dst_strides: fused_dst,
136            root,
137            mode,
138            lda_inner,
139            ldb_inner,
140            block,
141        }
142    }
143}
144
/// Return the index of the dimension with the smallest `|stride|` among
/// dimensions whose extent is greater than 1.
///
/// On ties the lowest index wins. If no dimension has extent > 1, returns 0.
fn find_stride1_dim(dims: &[usize], strides: &[isize]) -> usize {
    // (index, |stride|) of the best candidate seen so far.
    let mut best: Option<(usize, usize)> = None;
    for (idx, (&extent, &stride)) in dims.iter().zip(strides).enumerate() {
        if extent <= 1 {
            continue;
        }
        let magnitude = stride.unsigned_abs();
        match best {
            // Only a strictly smaller magnitude replaces the current pick, so
            // the first of several equal minima is kept.
            Some((_, current)) if current <= magnitude => {}
            _ => best = Some((idx, magnitude)),
        }
    }
    best.map_or(0, |(idx, _)| idx)
}
155
156/// BLOCK size for a given element size (matches HPTT's blocking_ = micro * 4).
157fn block_for_elem_size(elem_size: usize) -> usize {
158    match elem_size {
159        8 => <ScalarKernel as MicroKernel<f64>>::BLOCK, // 16
160        4 => <ScalarKernel as MicroKernel<f32>>::BLOCK, // 32
161        _ => 16,                                        // default
162    }
163}
164
/// Compute the outer-loop order for Transpose mode.
///
/// `dim_a` and `dim_b` are excluded because the macro-kernel consumes them.
///
/// Extent-1 dimensions are dropped, since they contribute a single iteration.
/// Zero-extent dimensions are kept: their loop has `end == 0`, so an empty
/// tensor never reaches the leaf kernel through this loop nest.
///
/// The remaining dims are sorted by `|src_stride| + |dst_stride|` descending
/// (largest strides outermost). The sort is stable, so ties keep ascending
/// index order.
fn compute_loop_order_transpose(
    dims: &[usize],
    src_strides: &[isize],
    dst_strides: &[isize],
    dim_a: usize,
    dim_b: usize,
) -> Vec<usize> {
    // Combined stride magnitude. Saturating so extreme strides cannot
    // overflow (which would panic in debug builds).
    let cost = |d: usize| {
        src_strides[d]
            .unsigned_abs()
            .saturating_add(dst_strides[d].unsigned_abs())
    };
    let mut loop_dims: Vec<usize> = (0..dims.len())
        .filter(|&d| d != dim_a && d != dim_b && dims[d] != 1)
        .collect();
    loop_dims.sort_by_key(|&d| std::cmp::Reverse(cost(d)));
    loop_dims
}
186
/// Compute the outer-loop order for ConstStride1 mode.
///
/// `inner_dim` is excluded because the leaf copy handles it.
///
/// Extent-1 dimensions are dropped. Zero-extent dimensions are kept: their
/// loop has `end == 0`, so an empty tensor never reaches the leaf copy.
///
/// The remaining dims are sorted by `|dst_stride|` descending: largest
/// destination stride outermost, smallest innermost. The innermost loops
/// therefore advance by the smallest destination offsets, intended to build
/// up contiguous runs that tile with the stride-1 inner copy (e.g. sequential
/// writes for a column-major destination). The sort is stable, so ties keep
/// ascending index order.
fn compute_loop_order_const(
    dims: &[usize],
    _src_strides: &[isize],
    dst_strides: &[isize],
    inner_dim: usize,
) -> Vec<usize> {
    let mut loop_dims: Vec<usize> = (0..dims.len())
        .filter(|&d| d != inner_dim && dims[d] != 1)
        .collect();
    loop_dims.sort_by_key(|&d| std::cmp::Reverse(dst_strides[d].unsigned_abs()));
    loop_dims
}
211
212/// Build a linked-list ComputeNode chain from loop_order.
213///
214/// All nodes have inc=1 (the two stride-1 dims are not in the chain;
215/// they are handled by macro_kernel or memcpy at the leaf).
216/// Returns None if loop_order is empty (all work done at the leaf).
217fn build_compute_nodes(
218    dims: &[usize],
219    src_strides: &[isize],
220    dst_strides: &[isize],
221    loop_order: &[usize],
222) -> Option<ComputeNode> {
223    let mut current: Option<ComputeNode> = None;
224
225    // Build from innermost (last in loop_order) to outermost (first)
226    for &d in loop_order.iter().rev() {
227        let node = ComputeNode {
228            end: dims[d],
229            lda: src_strides[d],
230            ldb: dst_strides[d],
231            next: current.map(Box::new),
232        };
233        current = Some(node);
234    }
235
236    current
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_stride1_dim_basic() {
        assert_eq!(find_stride1_dim(&[4, 5], &[1, 4]), 0);
        assert_eq!(find_stride1_dim(&[4, 5], &[5, 1]), 1);
    }

    #[test]
    fn test_find_stride1_dim_skips_size1() {
        // dim 0 has stride 1 but size 1 — should pick dim 1
        assert_eq!(find_stride1_dim(&[1, 5], &[1, 2]), 1);
    }

    #[test]
    fn test_find_stride1_dim_ties_and_fallback() {
        // Equal |stride| (sign ignored): the lowest index wins.
        assert_eq!(find_stride1_dim(&[3, 3], &[2, -2]), 0);
        // Negative strides compare by magnitude.
        assert_eq!(find_stride1_dim(&[3, 3], &[4, -1]), 1);
        // No dimension with extent > 1: falls back to 0.
        assert_eq!(find_stride1_dim(&[1, 1], &[7, 3]), 0);
    }

    #[test]
    fn test_build_plan_identity() {
        // Identity: src and dst both col-major → fuses to single dim → ConstStride1
        let plan = build_permute_plan(&[2, 3, 4], &[1, 2, 6], &[1, 2, 6], 8);
        assert_eq!(plan.fused_dims, vec![24]);
        assert!(matches!(plan.mode, ExecMode::ConstStride1 { .. }));
    }

    #[test]
    fn test_build_plan_transpose_2d() {
        // 2D transpose: src [1, 4], dst [5, 1]
        let plan = build_permute_plan(&[4, 5], &[1, 4], &[5, 1], 8);
        assert_eq!(plan.fused_dims, vec![4, 5]);
        match plan.mode {
            ExecMode::Transpose { dim_a, dim_b } => {
                assert_eq!(dim_a, 0); // src stride-1
                assert_eq!(dim_b, 1); // dst stride-1
            }
            _ => panic!("expected Transpose mode"),
        }
        assert_eq!(plan.block, 16); // f64 BLOCK
        assert_eq!(plan.lda_inner, 4); // src stride along dim_b
        assert_eq!(plan.ldb_inner, 5); // dst stride along dim_a
        // No loop nodes (only 2 dims, both consumed by macro_kernel)
        assert!(plan.root.is_none());
    }

    #[test]
    fn test_build_plan_3d_permute() {
        // 3D: dims [4,2,3], src strides [6,1,2], dst [1,4,8]
        // Bilateral fusion: dims 1-2 fuse (src: 2*1=2 == strides[2], dst: 2*4=8 == strides[2])
        // After fusion: dims [4, 6], src [6, 1], dst [1, 4]
        let plan = build_permute_plan(&[4, 2, 3], &[6, 1, 2], &[1, 4, 8], 8);
        assert_eq!(plan.fused_dims, vec![4, 6]);
        match plan.mode {
            ExecMode::Transpose { dim_a, dim_b } => {
                // dim_a: min |src_stride| → dim 1 (stride 1)
                assert_eq!(dim_a, 1);
                // dim_b: min |dst_stride| → dim 0 (stride 1)
                assert_eq!(dim_b, 0);
            }
            _ => panic!("expected Transpose mode"),
        }
        // Only 2 fused dims, both consumed by macro_kernel → no outer loops
        assert!(plan.root.is_none());
    }

    #[test]
    fn test_build_plan_scattered_strides() {
        // Scattered source strides: 4 dims of size 2
        let dims = vec![2, 2, 2, 2];
        let src_strides = vec![1, 8, 2, 4]; // scattered
        let dst_strides = vec![1, 2, 4, 8]; // col-major

        let plan = build_permute_plan(&dims, &src_strides, &dst_strides, 8);

        // Bilateral fusion: dims 2-3 fuse (src: 2*2 == 4, dst: 4*2 == 8); no
        // other adjacent pair is contiguous in both layouts.
        assert_eq!(plan.fused_dims, vec![2, 2, 4]);

        // Dim 0 has stride 1 in both src and dst → shared stride-1 dim.
        assert_eq!(plan.mode, ExecMode::ConstStride1 { inner_dim: 0 });
        assert_eq!(plan.block, 0);

        // Outer loops sorted by |dst_stride| descending: fused dim 2 (dst 4),
        // then dim 1 (dst 2).
        let root = plan.root.expect("expected two outer loops");
        assert_eq!((root.end, root.lda, root.ldb), (4, 2, 4));
        let inner = root.next.expect("expected an inner loop");
        assert_eq!((inner.end, inner.lda, inner.ldb), (2, 8, 2));
        assert!(inner.next.is_none());
    }

    #[test]
    fn test_build_plan_rank0() {
        let plan = build_permute_plan(&[], &[], &[], 8);
        assert!(matches!(plan.mode, ExecMode::Scalar));
        assert!(plan.root.is_none());
    }

    #[test]
    fn test_compute_loop_order_transpose() {
        let dims = [4, 5, 3, 7];
        let src_s = [1isize, 4, 100, 300];
        let dst_s = [35isize, 1, 7, 21];
        // dim_a=0 (min src stride), dim_b=1 (min dst stride)
        let order = compute_loop_order_transpose(&dims, &src_s, &dst_s, 0, 1);
        // Remaining: dims 2 and 3
        // cost[2] = 100 + 7 = 107, cost[3] = 300 + 21 = 321
        // Descending: [3, 2]
        assert_eq!(order, vec![3, 2]);
    }

    #[test]
    fn test_build_compute_nodes_chain() {
        let dims = [10, 5, 3];
        let src_s = [1isize, 10, 50];
        let dst_s = [15isize, 1, 5];
        let loop_order = vec![2]; // only dim 2 in the loop

        let root = build_compute_nodes(&dims, &src_s, &dst_s, &loop_order);
        assert!(root.is_some());
        let root = root.unwrap();
        assert_eq!(root.end, 3);
        assert_eq!(root.lda, 50);
        assert_eq!(root.ldb, 5);
        assert!(root.next.is_none());
    }
}