1use crate::hptt::macro_kernel::{
7 const_stride1_copy, macro_kernel_f32, macro_kernel_f64, macro_kernel_fallback,
8};
9use crate::hptt::plan::{ComputeNode, ExecMode, PermutePlan};
10
11#[cfg(feature = "parallel")]
12use rayon::iter::{IntoParallelIterator, ParallelIterator};
13
14#[cfg(feature = "parallel")]
16const MINTHREADLENGTH: usize = 1 << 15; pub unsafe fn execute_permute_blocked<T: Copy>(src: *const T, dst: *mut T, plan: &PermutePlan) {
25 match plan.mode {
26 ExecMode::Scalar => {
27 *dst = *src;
28 }
29 ExecMode::ConstStride1 { inner_dim } => {
30 let count = plan.fused_dims[inner_dim];
31 let src_stride = plan.src_strides[inner_dim];
32 let dst_stride = plan.dst_strides[inner_dim];
33 match &plan.root {
34 Some(root) => {
35 const_stride1_recursive(src, dst, root, count, src_stride, dst_stride);
36 }
37 None => {
38 const_stride1_copy(src, dst, count, src_stride, dst_stride);
39 }
40 }
41 }
42 ExecMode::Transpose { dim_a, dim_b } => {
43 let size_a = plan.fused_dims[dim_a];
44 let size_b = plan.fused_dims[dim_b];
45 let lda = plan.lda_inner;
46 let ldb = plan.ldb_inner;
47 let block = plan.block;
48 let elem_size = std::mem::size_of::<T>();
49
50 match &plan.root {
51 Some(root) => {
52 transpose_recursive(src, dst, root, size_a, size_b, lda, ldb, block, elem_size);
53 }
54 None => {
55 dispatch_blocked_2d(src, dst, size_a, size_b, lda, ldb, block, elem_size);
57 }
58 }
59 }
60 }
61}
62
/// Parallel counterpart of `execute_permute_blocked`.
///
/// The following cases are delegated unchanged to the sequential executor:
///
/// - permutations with fewer than `MINTHREADLENGTH` elements;
/// - plans without a loop nest (`plan.root` is `None`);
/// - plans whose outermost loop runs at most once.
///
/// Otherwise the iterations of the outermost loop node are distributed over the rayon thread
/// pool. Iteration `i` starts at `src + i * root.lda` / `dst + i * root.ldb` (in elements) and
/// runs the remaining loop nest, or the leaf kernel, for that slice.
///
/// # Safety
///
/// Same contract as `execute_permute_blocked`. In addition, different outer iterations write
/// through `dst` concurrently, so they must address disjoint destination elements. This is
/// presumably guaranteed for any plan describing a genuine permutation — confirm against the
/// plan builder.
#[cfg(feature = "parallel")]
pub unsafe fn execute_permute_blocked_par<T: Copy + Send + Sync>(
    src: *const T,
    dst: *mut T,
    plan: &PermutePlan,
) {
    /// Raw pointer that may be handed to rayon worker threads.
    struct SharedPtr<P>(*mut P);

    impl<P> Clone for SharedPtr<P> {
        fn clone(&self) -> Self {
            *self
        }
    }

    impl<P> Copy for SharedPtr<P> {}

    // SAFETY: workers only read through the source pointer and write disjoint destination
    // elements (see `# Safety`); `P: Send + Sync` allows the pointee values to cross threads.
    unsafe impl<P: Send + Sync> Send for SharedPtr<P> {}
    // SAFETY: as for `Send` above.
    unsafe impl<P: Send + Sync> Sync for SharedPtr<P> {}

    impl<P> SharedPtr<P> {
        // A by-value accessor (rather than `.0`) makes closures capture the whole wrapper;
        // with edition-2021 disjoint captures, `.0` would capture only the non-`Send` pointer.
        #[inline]
        fn get(self) -> *mut P {
            self.0
        }
    }

    let total: usize = plan.fused_dims.iter().product();
    if total < MINTHREADLENGTH {
        execute_permute_blocked(src, dst, plan);
        return;
    }

    // Only the outermost loop node is split across threads; without one (or with a single
    // iteration) there is nothing to distribute.
    let root = match &plan.root {
        Some(root) if root.end > 1 => root,
        _ => {
            execute_permute_blocked(src, dst, plan);
            return;
        }
    };

    let outer_dim = root.end;
    let lda_root = root.lda;
    let ldb_root = root.ldb;
    let inner = &root.next;

    let src_ptr = SharedPtr(src as *mut T);
    let dst_ptr = SharedPtr(dst);

    // Base pointers of the `i`-th outermost slice. `offset` keeps pointer provenance (unlike a
    // round trip through `usize`) and cannot overflow on buffers at high addresses.
    let slice_ptrs = move |i: usize| -> (*const T, *mut T) {
        let i = i as isize;
        // SAFETY: `i < outer_dim`, and every such slice start lies inside the caller-provided
        // buffers per this function's safety contract.
        unsafe {
            (
                src_ptr.get().offset(i * lda_root) as *const T,
                dst_ptr.get().offset(i * ldb_root),
            )
        }
    };

    match plan.mode {
        ExecMode::Transpose { dim_a, dim_b } => {
            let size_a = plan.fused_dims[dim_a];
            let size_b = plan.fused_dims[dim_b];
            let lda = plan.lda_inner;
            let ldb = plan.ldb_inner;
            let block = plan.block;
            let elem_size = std::mem::size_of::<T>();

            (0..outer_dim).into_par_iter().for_each(|i| {
                let (s, d) = slice_ptrs(i);
                // SAFETY: `s`/`d` start the `i`-th outer slice; element validity and disjoint
                // destination writes are the caller's contract.
                unsafe {
                    match inner {
                        Some(next) => {
                            transpose_recursive(
                                s, d, next, size_a, size_b, lda, ldb, block, elem_size,
                            );
                        }
                        None => {
                            dispatch_blocked_2d(s, d, size_a, size_b, lda, ldb, block, elem_size);
                        }
                    }
                }
            });
        }
        ExecMode::ConstStride1 { inner_dim } => {
            let count = plan.fused_dims[inner_dim];
            let src_stride = plan.src_strides[inner_dim];
            let dst_stride = plan.dst_strides[inner_dim];

            (0..outer_dim).into_par_iter().for_each(|i| {
                let (s, d) = slice_ptrs(i);
                // SAFETY: as in the transpose arm above.
                unsafe {
                    match inner {
                        Some(next) => {
                            const_stride1_recursive(s, d, next, count, src_stride, dst_stride);
                        }
                        None => {
                            const_stride1_copy(s, d, count, src_stride, dst_stride);
                        }
                    }
                }
            });
        }
        ExecMode::Scalar => {
            execute_permute_blocked(src, dst, plan);
        }
    }
}
160
161unsafe fn transpose_recursive<T: Copy>(
170 src: *const T,
171 dst: *mut T,
172 node: &ComputeNode,
173 size_a: usize,
174 size_b: usize,
175 lda: isize,
176 ldb: isize,
177 block: usize,
178 elem_size: usize,
179) {
180 let end = node.end;
181 let node_lda = node.lda;
182 let node_ldb = node.ldb;
183
184 match &node.next {
185 Some(next) => {
186 let mut s = src;
187 let mut d = dst;
188 for _ in 0..end {
189 transpose_recursive(s, d, next, size_a, size_b, lda, ldb, block, elem_size);
190 s = s.offset(node_lda);
191 d = d.offset(node_ldb);
192 }
193 }
194 None => {
195 let mut s = src;
197 let mut d = dst;
198 for _ in 0..end {
199 dispatch_blocked_2d(s, d, size_a, size_b, lda, ldb, block, elem_size);
200 s = s.offset(node_lda);
201 d = d.offset(node_ldb);
202 }
203 }
204 }
205}
206
207#[inline]
211unsafe fn dispatch_blocked_2d<T: Copy>(
212 src: *const T,
213 dst: *mut T,
214 size_a: usize,
215 size_b: usize,
216 lda: isize,
217 ldb: isize,
218 block: usize,
219 elem_size: usize,
220) {
221 match elem_size {
222 8 => blocked_transpose_2d_f64(
223 src as *const f64,
224 dst as *mut f64,
225 size_a,
226 size_b,
227 lda,
228 ldb,
229 block,
230 ),
231 4 => blocked_transpose_2d_f32(
232 src as *const f32,
233 dst as *mut f32,
234 size_a,
235 size_b,
236 lda,
237 ldb,
238 block,
239 ),
240 _ => blocked_transpose_2d_fallback(src, dst, size_a, size_b, lda, ldb, block),
241 }
242}
243
244#[inline]
245unsafe fn blocked_transpose_2d_f64(
246 src: *const f64,
247 dst: *mut f64,
248 size_a: usize,
249 size_b: usize,
250 lda: isize,
251 ldb: isize,
252 block: usize,
253) {
254 let mut ib = 0usize;
255 while ib < size_b {
256 let bb = block.min(size_b - ib);
257 let mut ia = 0usize;
258 while ia < size_a {
259 let ba = block.min(size_a - ia);
260 macro_kernel_f64(
261 src.offset(ia as isize + ib as isize * lda),
262 lda,
263 ba,
264 dst.offset(ib as isize + ia as isize * ldb),
265 ldb,
266 bb,
267 );
268 ia += block;
269 }
270 ib += block;
271 }
272}
273
274#[inline]
275unsafe fn blocked_transpose_2d_f32(
276 src: *const f32,
277 dst: *mut f32,
278 size_a: usize,
279 size_b: usize,
280 lda: isize,
281 ldb: isize,
282 block: usize,
283) {
284 let mut ib = 0usize;
285 while ib < size_b {
286 let bb = block.min(size_b - ib);
287 let mut ia = 0usize;
288 while ia < size_a {
289 let ba = block.min(size_a - ia);
290 macro_kernel_f32(
291 src.offset(ia as isize + ib as isize * lda),
292 lda,
293 ba,
294 dst.offset(ib as isize + ia as isize * ldb),
295 ldb,
296 bb,
297 );
298 ia += block;
299 }
300 ib += block;
301 }
302}
303
304#[inline]
305unsafe fn blocked_transpose_2d_fallback<T: Copy>(
306 src: *const T,
307 dst: *mut T,
308 size_a: usize,
309 size_b: usize,
310 lda: isize,
311 ldb: isize,
312 block: usize,
313) {
314 let mut ib = 0usize;
315 while ib < size_b {
316 let bb = block.min(size_b - ib);
317 let mut ia = 0usize;
318 while ia < size_a {
319 let ba = block.min(size_a - ia);
320 macro_kernel_fallback(
321 src.offset(ia as isize + ib as isize * lda),
322 lda,
323 ba,
324 dst.offset(ib as isize + ia as isize * ldb),
325 ldb,
326 bb,
327 );
328 ia += block;
329 }
330 ib += block;
331 }
332}
333
334unsafe fn const_stride1_recursive<T: Copy>(
343 src: *const T,
344 dst: *mut T,
345 node: &ComputeNode,
346 count: usize,
347 src_stride: isize,
348 dst_stride: isize,
349) {
350 let end = node.end;
351 let node_lda = node.lda;
352 let node_ldb = node.ldb;
353
354 match &node.next {
355 Some(next) => {
356 let mut s = src;
357 let mut d = dst;
358 for _ in 0..end {
359 const_stride1_recursive(s, d, next, count, src_stride, dst_stride);
360 s = s.offset(node_lda);
361 d = d.offset(node_ldb);
362 }
363 }
364 None => {
365 let mut s = src;
366 let mut d = dst;
367 for _ in 0..end {
368 const_stride1_copy(s, d, count, src_stride, dst_stride);
369 s = s.offset(node_lda);
370 d = d.offset(node_ldb);
371 }
372 }
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hptt::plan::build_permute_plan;

    /// Identical source and destination strides: the permutation is a plain copy.
    #[test]
    fn test_execute_identity_copy() {
        let src = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut dst = vec![0.0f64; 6];
        let plan = build_permute_plan(&[2, 3], &[1, 2], &[1, 2], 8);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }
        assert_eq!(dst, src);
    }

    /// 2-D transpose of `f64` elements (8-byte kernel).
    #[test]
    fn test_execute_transpose_2d() {
        let src = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut dst = vec![0.0f64; 6];
        let plan = build_permute_plan(&[2, 3], &[3, 1], &[1, 2], 8);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }
        assert_eq!(dst, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// Same 2-D transpose with `f32` elements (4-byte kernel).
    #[test]
    fn test_execute_transpose_2d_f32() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut dst = vec![0.0f32; 6];
        let plan = build_permute_plan(&[2, 3], &[3, 1], &[1, 2], 4);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }
        assert_eq!(dst, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// Same 2-D transpose with 16-byte elements (generic fallback kernel).
    #[test]
    fn test_execute_transpose_2d_wide_elements() {
        let src: Vec<[u64; 2]> = (1..=6u64).map(|v| [v, 100 + v]).collect();
        let mut dst = vec![[0u64; 2]; 6];
        let plan = build_permute_plan(&[2, 3], &[3, 1], &[1, 2], 16);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }
        let expected: Vec<[u64; 2]> = [1u64, 4, 2, 5, 3, 6].iter().map(|&v| [v, 100 + v]).collect();
        assert_eq!(dst, expected);
    }

    /// `[u8; 8]` has an `f64`'s size but only byte alignment. The data is placed off an
    /// `f64` boundary, so any reinterpretation as `f64` would be a misaligned access (which
    /// Miri reports); the executor must still permute it correctly.
    #[test]
    fn test_execute_transpose_2d_underaligned_8byte_elements() {
        let f64_align = std::mem::align_of::<f64>();
        let mut src_buf = vec![0u8; 6 * 8 + 1];
        let mut dst_buf = vec![0u8; 6 * 8 + 1];
        // Shift by one byte only when the buffer start happens to be `f64`-aligned.
        let src_off = usize::from(src_buf.as_ptr() as usize % f64_align == 0);
        let dst_off = usize::from(dst_buf.as_ptr() as usize % f64_align == 0);

        // Element `i` is filled with the byte `i + 1`, mirroring the `f64` test data.
        for (i, elem) in src_buf[src_off..src_off + 48].chunks_exact_mut(8).enumerate() {
            elem.fill(i as u8 + 1);
        }

        let plan = build_permute_plan(&[2, 3], &[3, 1], &[1, 2], 8);
        unsafe {
            execute_permute_blocked(
                src_buf.as_ptr().add(src_off).cast::<[u8; 8]>(),
                dst_buf.as_mut_ptr().add(dst_off).cast::<[u8; 8]>(),
                &plan,
            );
        }

        let got: Vec<u8> = dst_buf[dst_off..dst_off + 48]
            .chunks_exact(8)
            .map(|elem| {
                assert!(elem.iter().all(|&b| b == elem[0]), "element bytes were mixed: {elem:?}");
                elem[0]
            })
            .collect();
        assert_eq!(got, vec![1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_execute_3d_permute() {
        let dims = [2usize, 3, 4];
        let total: usize = dims.iter().product();
        let src: Vec<f64> = (0..total).map(|i| i as f64).collect();
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[4, 2, 3], &[6, 1, 2], &[1, 4, 8], 8);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }

        for k in 0..4 {
            for i in 0..2 {
                for j in 0..3 {
                    let dst_idx = k + i * 4 + j * 8;
                    let src_idx = i + j * 2 + k * 6;
                    assert_eq!(
                        dst[dst_idx], src[src_idx],
                        "mismatch at k={k}, i={i}, j={j}"
                    );
                }
            }
        }
    }

    #[test]
    fn test_execute_4d_permute() {
        let dims = [2usize, 3, 4, 5];
        let total: usize = dims.iter().product();
        let src: Vec<f64> = (0..total).map(|i| i as f64).collect();
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[5, 3, 2, 4], &[24, 2, 1, 6], &[1, 5, 15, 30], 8);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }

        for i0 in 0..5 {
            for i1 in 0..3 {
                for i2 in 0..2 {
                    for i3 in 0..4 {
                        let src_idx = i0 * 24 + i1 * 2 + i2 + i3 * 6;
                        let dst_idx = i0 + i1 * 5 + i2 * 15 + i3 * 30;
                        assert_eq!(
                            dst[dst_idx], src[src_idx],
                            "4D mismatch at ({i0},{i1},{i2},{i3})"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_execute_5d_permute() {
        let dims = [2usize, 2, 2, 2, 3];
        let total: usize = dims.iter().product();
        let src: Vec<f64> = (0..total).map(|i| i as f64).collect();
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[3, 2, 2, 2, 2], &[16, 1, 2, 4, 8], &[1, 3, 6, 12, 24], 8);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }

        for i0 in 0..3 {
            for i1 in 0..2 {
                for i2 in 0..2 {
                    for i3 in 0..2 {
                        for i4 in 0..2 {
                            let src_idx = i0 * 16 + i1 + i2 * 2 + i3 * 4 + i4 * 8;
                            let dst_idx = i0 + i1 * 3 + i2 * 6 + i3 * 12 + i4 * 24;
                            assert_eq!(
                                dst[dst_idx], src[src_idx],
                                "5D mismatch at ({i0},{i1},{i2},{i3},{i4})"
                            );
                        }
                    }
                }
            }
        }
    }

    /// A rank-0 plan copies exactly one element.
    #[test]
    fn test_execute_rank0_scalar() {
        let src = vec![42.0f64];
        let mut dst = vec![0.0f64];
        let plan = build_permute_plan(&[], &[], &[], 8);
        unsafe {
            execute_permute_blocked(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }
        assert_eq!(dst[0], 42.0);
    }

    /// Small input: the parallel entry point falls back to the sequential executor.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_execute_par_transpose_2d() {
        let src = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut dst = vec![0.0f64; 6];
        let plan = build_permute_plan(&[2, 3], &[3, 1], &[1, 2], 8);
        unsafe {
            execute_permute_blocked_par(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }
        assert_eq!(dst, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// Large enough to take the multi-threaded path. A 64^3 cube exceeds `MINTHREADLENGTH`
    /// while keeping each buffer at 2 MiB; a 256^3 cube needs 128 MiB per buffer.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_execute_par_large() {
        let n = 64usize;
        let s = n as isize;
        let total = n * n * n;
        assert!(total >= MINTHREADLENGTH, "test must exercise the parallel path");

        let src: Vec<f64> = (0..total).map(|i| i as f64).collect();
        let mut dst = vec![0.0f64; total];

        let plan = build_permute_plan(&[n, n, n], &[s * s, 1, s], &[1, s, s * s], 8);
        unsafe {
            execute_permute_blocked_par(src.as_ptr(), dst.as_mut_ptr(), &plan);
        }

        for i0 in 0..n {
            for i1 in 0..n {
                for i2 in 0..n {
                    let dst_idx = i0 + i1 * n + i2 * n * n;
                    let src_idx = i0 * n * n + i1 + i2 * n;
                    assert_eq!(
                        dst[dst_idx], src[src_idx],
                        "mismatch at i0={i0}, i1={i1}, i2={i2}"
                    );
                }
            }
        }
    }
}