1use crate::fuse::fuse_dims_bilateral;
7use crate::hptt::micro_kernel::{MicroKernel, ScalarKernel};
8
/// One level of the loop-node chain produced by `build_compute_nodes`.
///
/// Each node covers a single fused dimension. `next` holds the level nested
/// directly inside this one, or `None` when this is the last level of the chain.
#[derive(Debug, Clone)]
pub struct ComputeNode {
    /// Extent of this dimension (`dims[d]` at construction time).
    pub end: usize,
    /// Source stride along this dimension.
    pub lda: isize,
    /// Destination stride along this dimension.
    pub ldb: isize,
    /// The next, more deeply nested level, if any.
    pub next: Option<Box<ComputeNode>>,
}
24
/// Inner-kernel strategy selected by `build_permute_plan`.
///
/// All dimension indices refer to the *fused* dimensions of the plan.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecMode {
    /// Source and destination have different smallest-stride dimensions.
    ///
    /// Both dimensions are excluded from the outer loop nest and are handled by
    /// the transpose path.
    Transpose {
        /// Fused dimension with the smallest |source stride| (among extents > 1).
        dim_a: usize,
        /// Fused dimension with the smallest |destination stride| (among extents > 1).
        dim_b: usize,
    },
    /// Source and destination share the same smallest-stride dimension.
    ConstStride1 {
        /// The shared fused dimension; it is excluded from the outer loop nest.
        inner_dim: usize,
    },
    /// Mode used when the fused rank is zero.
    Scalar,
}
43
44#[derive(Debug)]
46pub struct PermutePlan {
47 pub fused_dims: Vec<usize>,
49 pub src_strides: Vec<isize>,
51 pub dst_strides: Vec<isize>,
53 pub root: Option<ComputeNode>,
55 pub mode: ExecMode,
57 pub lda_inner: isize,
61 pub ldb_inner: isize,
63 pub block: usize,
65}
66
67pub fn build_permute_plan(
72 dims: &[usize],
73 src_strides: &[isize],
74 dst_strides: &[isize],
75 elem_size: usize,
76) -> PermutePlan {
77 let (fused_dims, fused_src, fused_dst) = fuse_dims_bilateral(dims, src_strides, dst_strides);
79
80 let rank = fused_dims.len();
81 if rank == 0 {
82 return PermutePlan {
83 fused_dims,
84 src_strides: fused_src,
85 dst_strides: fused_dst,
86 root: None,
87 mode: ExecMode::Scalar,
88 lda_inner: 0,
89 ldb_inner: 0,
90 block: 0,
91 };
92 }
93
94 let dim_a = find_stride1_dim(&fused_dims, &fused_src);
96 let dim_b = find_stride1_dim(&fused_dims, &fused_dst);
97
98 let block = block_for_elem_size(elem_size);
100
101 if dim_a == dim_b {
102 let inner_dim = dim_a;
104 let mode = ExecMode::ConstStride1 { inner_dim };
105
106 let loop_order = compute_loop_order_const(&fused_dims, &fused_src, &fused_dst, inner_dim);
107 let root = build_compute_nodes(&fused_dims, &fused_src, &fused_dst, &loop_order);
108
109 PermutePlan {
110 fused_dims,
111 src_strides: fused_src.clone(),
112 dst_strides: fused_dst.clone(),
113 root,
114 mode,
115 lda_inner: fused_src[inner_dim],
116 ldb_inner: fused_dst[inner_dim],
117 block: 0,
118 }
119 } else {
120 let mode = ExecMode::Transpose { dim_a, dim_b };
122
123 let lda_inner = fused_src[dim_b];
126 let ldb_inner = fused_dst[dim_a];
127
128 let loop_order =
129 compute_loop_order_transpose(&fused_dims, &fused_src, &fused_dst, dim_a, dim_b);
130 let root = build_compute_nodes(&fused_dims, &fused_src, &fused_dst, &loop_order);
131
132 PermutePlan {
133 fused_dims,
134 src_strides: fused_src,
135 dst_strides: fused_dst,
136 root,
137 mode,
138 lda_inner,
139 ldb_inner,
140 block,
141 }
142 }
143}
144
/// Returns the index of the dimension whose stride has the smallest magnitude.
///
/// - Only dimensions with extent > 1 are candidates.
/// - Ties resolve to the lowest index.
/// - If no dimension qualifies, the result is 0.
/// - Only the first `min(dims.len(), strides.len())` entries are examined.
fn find_stride1_dim(dims: &[usize], strides: &[isize]) -> usize {
    // (index, |stride|) of the best candidate seen so far.
    let mut best: Option<(usize, usize)> = None;

    for (idx, (&extent, &stride)) in dims.iter().zip(strides).enumerate() {
        if extent <= 1 {
            continue;
        }
        let magnitude = stride.unsigned_abs();
        match best {
            // Keep the earlier candidate on ties so the lowest index wins.
            Some((_, best_magnitude)) if best_magnitude <= magnitude => {}
            _ => best = Some((idx, magnitude)),
        }
    }

    best.map_or(0, |(idx, _)| idx)
}
155
156fn block_for_elem_size(elem_size: usize) -> usize {
158 match elem_size {
159 8 => <ScalarKernel as MicroKernel<f64>>::BLOCK, 4 => <ScalarKernel as MicroKernel<f32>>::BLOCK, _ => 16, }
163}
164
/// Orders the loop dimensions that surround the transpose kernel.
///
/// The first entry becomes the head of the chain built by `build_compute_nodes`.
///
/// Which dimensions get a loop:
/// - Every dimension except `dim_a` and `dim_b` gets one.
/// - Dimensions with extent exactly 1 are dropped, since their loop runs once.
/// - Zero-extent dimensions are kept. Dropping them would hide from the loop-node
///   chain that the iteration space is empty.
///
/// Loops are sorted by descending `|src stride| + |dst stride|`; ties keep
/// ascending index order. The sum saturates instead of overflowing for extreme
/// strides.
fn compute_loop_order_transpose(
    dims: &[usize],
    src_strides: &[isize],
    dst_strides: &[isize],
    dim_a: usize,
    dim_b: usize,
) -> Vec<usize> {
    let cost = |d: usize| {
        src_strides[d]
            .unsigned_abs()
            .saturating_add(dst_strides[d].unsigned_abs())
    };

    let mut loop_dims: Vec<usize> = (0..dims.len())
        .filter(|&d| d != dim_a && d != dim_b && dims[d] != 1)
        .collect();
    // `sort_by_key` is stable, so equal-cost dimensions stay in index order.
    loop_dims.sort_by_key(|&d| std::cmp::Reverse(cost(d)));
    loop_dims
}
186
/// Orders the loop dimensions for `ExecMode::ConstStride1` plans.
///
/// The first entry becomes the head of the chain built by `build_compute_nodes`.
///
/// Which dimensions get a loop:
/// - Every dimension except `inner_dim` gets one.
/// - Dimensions with extent exactly 1 are dropped.
/// - Zero-extent dimensions are kept, so the loop-node chain still records that
///   there is nothing to iterate.
///
/// Loops are sorted by descending `|dst stride|`; ties keep ascending index order.
/// The source strides are currently not consulted.
fn compute_loop_order_const(
    dims: &[usize],
    _src_strides: &[isize],
    dst_strides: &[isize],
    inner_dim: usize,
) -> Vec<usize> {
    let mut loop_dims: Vec<usize> = (0..dims.len())
        .filter(|&d| d != inner_dim && dims[d] != 1)
        .collect();
    // `sort_by_key` is stable, so equal strides stay in index order.
    loop_dims.sort_by_key(|&d| std::cmp::Reverse(dst_strides[d].unsigned_abs()));
    loop_dims
}
211
212fn build_compute_nodes(
218 dims: &[usize],
219 src_strides: &[isize],
220 dst_strides: &[isize],
221 loop_order: &[usize],
222) -> Option<ComputeNode> {
223 let mut current: Option<ComputeNode> = None;
224
225 for &d in loop_order.iter().rev() {
227 let node = ComputeNode {
228 end: dims[d],
229 lda: src_strides[d],
230 ldb: dst_strides[d],
231 next: current.map(Box::new),
232 };
233 current = Some(node);
234 }
235
236 current
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_stride1_dim_basic() {
        assert_eq!(find_stride1_dim(&[4, 5], &[1, 4]), 0);
        assert_eq!(find_stride1_dim(&[4, 5], &[5, 1]), 1);
    }

    #[test]
    fn test_find_stride1_dim_skips_size1() {
        // The size-1 dimension has the smallest stride but must be ignored.
        assert_eq!(find_stride1_dim(&[1, 5], &[1, 2]), 1);
    }

    #[test]
    fn test_find_stride1_dim_ties_and_signs() {
        // Equal magnitudes resolve to the lowest index.
        assert_eq!(find_stride1_dim(&[3, 4], &[2, 2]), 0);
        // Magnitude, not sign, decides.
        assert_eq!(find_stride1_dim(&[3, 4], &[-1, 2]), 0);
        assert_eq!(find_stride1_dim(&[3, 4], &[4, -1]), 1);
        // No dimension with extent > 1: falls back to 0.
        assert_eq!(find_stride1_dim(&[1, 1], &[1, 2]), 0);
    }

    #[test]
    fn test_build_plan_identity() {
        // Contiguous on both sides, so everything fuses into one dimension.
        let plan = build_permute_plan(&[2, 3, 4], &[1, 2, 6], &[1, 2, 6], 8);
        assert_eq!(plan.fused_dims, vec![24]);
        assert!(matches!(plan.mode, ExecMode::ConstStride1 { .. }));
    }

    #[test]
    fn test_build_plan_transpose_2d() {
        let plan = build_permute_plan(&[4, 5], &[1, 4], &[5, 1], 8);
        assert_eq!(plan.fused_dims, vec![4, 5]);
        match plan.mode {
            ExecMode::Transpose { dim_a, dim_b } => {
                assert_eq!(dim_a, 0);
                assert_eq!(dim_b, 1);
            }
            _ => panic!("expected Transpose mode"),
        }
        assert_eq!(plan.block, 16);
        assert_eq!(plan.lda_inner, 4);
        assert_eq!(plan.ldb_inner, 5);
        // Both dimensions are consumed by the transpose path; no outer loops remain.
        assert!(plan.root.is_none());
    }

    #[test]
    fn test_build_plan_3d_permute() {
        // Dimensions 1 and 2 are contiguous with each other on both sides and fuse to 6.
        let plan = build_permute_plan(&[4, 2, 3], &[6, 1, 2], &[1, 4, 8], 8);
        assert_eq!(plan.fused_dims, vec![4, 6]);
        match plan.mode {
            ExecMode::Transpose { dim_a, dim_b } => {
                assert_eq!(dim_a, 1);
                assert_eq!(dim_b, 0);
            }
            _ => panic!("expected Transpose mode"),
        }
        assert!(plan.root.is_none());
    }

    #[test]
    fn test_build_plan_scattered_strides() {
        let dims = vec![2, 2, 2, 2];
        let src_strides = vec![1, 8, 2, 4];
        let dst_strides = vec![1, 2, 4, 8];
        let plan = build_permute_plan(&dims, &src_strides, &dst_strides, 8);

        assert_eq!(plan.fused_dims.len(), 3);
        assert!(
            matches!(plan.mode, ExecMode::Transpose { .. } | ExecMode::ConstStride1 { .. }),
            "unexpected mode: {:?}",
            plan.mode
        );
    }

    #[test]
    fn test_build_plan_rank0() {
        let plan = build_permute_plan(&[], &[], &[], 8);
        assert!(matches!(plan.mode, ExecMode::Scalar));
        assert!(plan.root.is_none());
    }

    #[test]
    fn test_compute_loop_order_transpose() {
        let dims = [4, 5, 3, 7];
        let src_s = [1isize, 4, 100, 300];
        let dst_s = [35isize, 1, 7, 21];
        let order = compute_loop_order_transpose(&dims, &src_s, &dst_s, 0, 1);
        assert_eq!(order, vec![3, 2]);
    }

    #[test]
    fn test_compute_loop_order_transpose_ties_keep_index_order() {
        let dims = [4, 5, 3, 7];
        let src_s = [1isize, 4, 10, 20];
        let dst_s = [35isize, 1, 20, 10];
        let order = compute_loop_order_transpose(&dims, &src_s, &dst_s, 0, 1);
        assert_eq!(order, vec![2, 3]);
    }

    #[test]
    fn test_compute_loop_order_const() {
        let dims = [4, 5, 3, 7];
        let src_s = [1isize, 4, 20, 60];
        let dst_s = [1isize, 4, 20, 60];
        let order = compute_loop_order_const(&dims, &src_s, &dst_s, 0);
        assert_eq!(order, vec![3, 2, 1]);
    }

    #[test]
    fn test_build_compute_nodes_chain() {
        let dims = [10, 5, 3];
        let src_s = [1isize, 10, 50];
        let dst_s = [15isize, 1, 5];
        let loop_order = vec![2];
        let root = build_compute_nodes(&dims, &src_s, &dst_s, &loop_order);
        assert!(root.is_some());
        let root = root.unwrap();
        assert_eq!(root.end, 3);
        assert_eq!(root.lda, 50);
        assert_eq!(root.ldb, 5);
        assert!(root.next.is_none());
    }

    #[test]
    fn test_build_compute_nodes_multi_level() {
        let dims = [10, 5, 3];
        let src_s = [1isize, 10, 50];
        let dst_s = [15isize, 1, 5];
        let root = build_compute_nodes(&dims, &src_s, &dst_s, &[2, 0]).expect("non-empty order");
        assert_eq!((root.end, root.lda, root.ldb), (3, 50, 5));
        let inner = root.next.expect("second level");
        assert_eq!((inner.end, inner.lda, inner.ldb), (10, 1, 15));
        assert!(inner.next.is_none());
    }

    #[test]
    fn test_build_compute_nodes_empty_order() {
        assert!(build_compute_nodes(&[4], &[1], &[1], &[]).is_none());
    }
}