1use crate::fuse::{compress_dims, fuse_dims};
4use crate::{block, order};
5use strided_view::Result;
6
/// Size threshold for the "small tensor" planning path (`1 << 10` = 1024).
///
/// NOTE(review): this constant is not referenced in this file; presumably
/// callers compare an element count (e.g. `total_len(dims)`) against it to
/// choose `build_plan_fused_small` over `build_plan_fused` — confirm at the
/// call sites.
pub const SMALL_TENSOR_THRESHOLD: usize = 1 << 10;
9
/// Loop-nest plan returned alongside the fused dims/strides by the plan
/// builders (`build_plan_fused`, `build_plan_fused_small`).
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so plans can be logged, copied
/// and compared in tests like any other plain-data value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelPlan {
    /// Axis visiting order. Both builders in this module return dims/strides
    /// that are already permuted, so they set this to the identity `0..rank`.
    pub order: Vec<usize>,
    /// Per-axis block (tile) extent, indexed like the returned dims.
    pub block: Vec<usize>,
}
14
15pub fn build_plan_fused(
19 dims: &[usize],
20 strides_list: &[&[isize]],
21 dest_index: Option<usize>,
22 elem_size: usize,
23) -> (Vec<usize>, Vec<Vec<isize>>, KernelPlan) {
24 let order = order::compute_order(dims, strides_list, dest_index);
25
26 let ordered_dims: Vec<usize> = order.iter().map(|&d| dims[d]).collect();
27 let ordered_strides: Vec<Vec<isize>> = strides_list
28 .iter()
29 .map(|strides| order.iter().map(|&d| strides[d]).collect())
30 .collect();
31 let ordered_strides_refs: Vec<&[isize]> =
32 ordered_strides.iter().map(|s| s.as_slice()).collect();
33
34 let fused_dims = fuse_dims(&ordered_dims, &ordered_strides_refs);
35 let (compressed_dims, compressed_strides) = compress_dims(&fused_dims, &ordered_strides);
36 let compressed_strides_refs: Vec<&[isize]> =
37 compressed_strides.iter().map(|s| s.as_slice()).collect();
38
39 let identity: Vec<usize> = (0..compressed_dims.len()).collect();
40 let block = block::compute_block_sizes(
41 &compressed_dims,
42 &identity,
43 &compressed_strides_refs,
44 elem_size,
45 );
46
47 (
48 compressed_dims,
49 compressed_strides,
50 KernelPlan {
51 order: identity,
52 block,
53 },
54 )
55}
56
57pub fn build_plan_fused_small(
59 dims: &[usize],
60 strides_list: &[&[isize]],
61) -> (Vec<usize>, Vec<Vec<isize>>, KernelPlan) {
62 let strides_owned: Vec<Vec<isize>> = strides_list.iter().map(|s| s.to_vec()).collect();
63
64 let fused = fuse_dims(dims, strides_list);
65 let (fused_dims, fused_strides) = compress_dims(&fused, &strides_owned);
66
67 let block = fused_dims.clone();
68 let identity: Vec<usize> = (0..fused_dims.len()).collect();
69
70 (
71 fused_dims,
72 fused_strides,
73 KernelPlan {
74 order: identity,
75 block,
76 },
77 )
78}
79
/// Expands to the nested element loops *inside* one block of the blocked
/// traversal.
///
/// Invoked as `elem_loops!(offsets, strides, f, blens, is; L_top, ..., 1)`.
/// The first literal is the outermost loop and the last (axis 1) the
/// innermost. Axis 0 is not looped here. Each innermost iteration makes a
/// single call `f(offsets, blens[0], &is)`, handing the callback a run of
/// `blens[0]` elements along axis 0.
///
/// Offset bookkeeping: a level advances every `offsets[a]` by
/// `strides[a][lv]` per iteration. After each inner level completes, that
/// level's `blens[next] * strides[a][next]` advance is rewound. The net effect
/// of the whole expansion is `offsets += blens[L_top] * strides[L_top]`, which
/// the caller (`block_loop!`) undoes.
///
/// Errors from `f` propagate with `?`, so this must expand inside a function
/// returning a compatible `Result`.
macro_rules! elem_loops {
    // Innermost element level: call `f` once per index along `$lv`.
    ($offsets:ident, $strides:ident, $f:ident, $blens:ident, $is:ident; $lv:literal) => {
        for _ in 0..$blens[$lv] {
            $f($offsets, $blens[0], &$is)?;
            for (o, s) in $offsets.iter_mut().zip($strides.iter()) {
                *o += s[$lv];
            }
        }
    };
    // Outer element level: recurse into `$next`, then rewind `$next`'s
    // advance and step one index along `$lv`.
    ($offsets:ident, $strides:ident, $f:ident, $blens:ident, $is:ident;
     $lv:literal, $next:literal $(, $rest:literal)*) => {
        for _ in 0..$blens[$lv] {
            elem_loops!($offsets, $strides, $f, $blens, $is; $next $(, $rest)*);
            for (o, s) in $offsets.iter_mut().zip($strides.iter()) {
                *o -= ($blens[$next] as isize) * s[$next];
                *o += s[$lv];
            }
        }
    };
}
104
/// Expands to the nested *block* loops of the blocked traversal, with
/// `elem_loops!` at the innermost level.
///
/// Invoked as `block_loop!(dims, blocks, strides, offsets, f, blens, is;
/// elem=[...]; L_top, ..., 0; top=L_top)`.
///
/// Each level walks its axis in steps of `blocks[lv]`, clamped to at least 1
/// and at most the remaining extent. The current block length is stored in
/// `blens[lv]`.
///
/// The final level (axis 0) runs `elem_loops!` for each axis-0 block. It then
/// rewinds the `blens[top] * strides[top]` advance that `elem_loops!` leaves
/// behind and steps to the next axis-0 block.
///
/// Each outer level rewinds the full `dims[next] * strides[next]` advance of
/// the level inside it. The whole expansion therefore nets to
/// `offsets += dims[top] * strides[top]`, which `make_kernel!` undoes.
macro_rules! block_loop {
    // Innermost block level (axis 0): run the element loops for each block.
    ($dims:ident, $blocks:ident, $strides:ident, $offsets:ident, $f:ident,
     $blens:ident, $is:ident; elem=[$($el:literal),+]; $lv0:literal; top=$top:literal) => {{
        let mut _j = 0usize;
        while _j < $dims[$lv0] {
            // Block length: requested size, at least 1, truncated at the edge.
            $blens[$lv0] = $blocks[$lv0].max(1).min($dims[$lv0]).min($dims[$lv0] - _j);
            elem_loops!($offsets, $strides, $f, $blens, $is; $($el),+);
            for (o, s) in $offsets.iter_mut().zip($strides.iter()) {
                // Undo the element loops' net advance, then move to the next axis-0 block.
                *o -= ($blens[$top] as isize) * s[$top];
                *o += ($blens[$lv0] as isize) * s[$lv0];
            }
            _j += $blens[$lv0];
        }
    }};
    // Outer block level: recurse into `$next`, then rewind its full sweep and
    // step to the next block along `$lv`.
    ($dims:ident, $blocks:ident, $strides:ident, $offsets:ident, $f:ident,
     $blens:ident, $is:ident; elem=[$($el:literal),+];
     $lv:literal, $next:literal $(, $rest:literal)*; top=$top:literal) => {{
        let mut _j = 0usize;
        while _j < $dims[$lv] {
            $blens[$lv] = $blocks[$lv].max(1).min($dims[$lv]).min($dims[$lv] - _j);
            block_loop!($dims, $blocks, $strides, $offsets, $f, $blens, $is;
                        elem=[$($el),+]; $next $(, $rest)*; top=$top);
            for (o, s) in $offsets.iter_mut().zip($strides.iter()) {
                *o -= ($dims[$next] as isize) * s[$next];
                *o += ($blens[$lv] as isize) * s[$lv];
            }
            _j += $blens[$lv];
        }
    }};
}
135
/// Generates a fixed-rank traversal kernel
/// `fn $name<F>(dims, blocks, strides, offsets, f) -> Result<()>`.
///
/// Arms:
/// * `rank=1`: a dedicated arm that walks axis 0 in blocks of `blocks[0]`
///   (clamped to at least 1 and at most `dims[0]`), calling `f` once per block.
/// * General arm:
///   - `rank` sizes the `blens` scratch array.
///   - `block` lists the axes for `block_loop!`, outermost first and ending in 0.
///   - `elem` lists the axes for `elem_loops!`, ending in 1.
///   - `top` is the outermost axis.
///
/// In both arms `f` receives three arguments:
/// * the current per-array offsets,
/// * the run length along axis 0,
/// * each array's axis-0 stride.
///
/// On success `offsets` are restored to their entry values. If `f` returns an
/// error it is propagated immediately and `offsets` are left mid-traversal.
macro_rules! make_kernel {
    ($name:ident, rank=1) => {
        #[inline]
        fn $name<F>(
            dims: &[usize],
            blocks: &[usize],
            strides: &[Vec<isize>],
            offsets: &mut [isize],
            f: &mut F,
        ) -> Result<()>
        where
            F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
        {
            let d0 = dims[0];
            let b0 = blocks[0].max(1).min(d0);
            let inner_strides: Vec<isize> = strides.iter().map(|s| s[0]).collect();
            let mut j0 = 0usize;
            while j0 < d0 {
                let blen0 = b0.min(d0 - j0);
                f(offsets, blen0, &inner_strides)?;
                for (o, s) in offsets.iter_mut().zip(strides.iter()) {
                    *o += (blen0 as isize) * s[0];
                }
                j0 += blen0;
            }
            // Rewind so the caller's offsets end where they started.
            for (o, s) in offsets.iter_mut().zip(strides.iter()) {
                *o -= (d0 as isize) * s[0];
            }
            Ok(())
        }
    };
    ($name:ident, rank=$rank:literal,
     block=[$($blk:literal),+], elem=[$($el:literal),+], top=$top:literal) => {
        #[inline]
        fn $name<F>(
            dims: &[usize],
            blocks: &[usize],
            strides: &[Vec<isize>],
            offsets: &mut [isize],
            f: &mut F,
        ) -> Result<()>
        where
            F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
        {
            let inner_strides: Vec<isize> = strides.iter().map(|s| s[0]).collect();
            // Current block length per axis, written by block_loop!, read by elem_loops!.
            let mut blens = [0usize; $rank];
            block_loop!(dims, blocks, strides, offsets, f, blens, inner_strides;
                        elem=[$($el),+]; $($blk),+; top=$top);
            // block_loop! leaves offsets advanced by one full sweep of `top`; undo it.
            for (o, s) in offsets.iter_mut().zip(strides.iter()) {
                *o -= (dims[$top] as isize) * s[$top];
            }
            Ok(())
        }
    };
}
191
// Unrolled kernels for ranks 1 through 8.
//
// For rank R the general arm gets:
// - `block = [R-1, ..., 0]`: the block loops, outermost first;
// - `elem = [R-1, ..., 1]`: the element loops run inside each block;
// - `top = R-1`.
// Axis 0 is always the run handed to the callback.
//
// Ranks 9 and above use `kernel_nd_inner_iterative` instead (see the
// dispatch in `for_each_inner_block_preordered`).
make_kernel!(kernel_1d_inner, rank = 1);
make_kernel!(
    kernel_2d_inner,
    rank = 2,
    block = [1, 0],
    elem = [1],
    top = 1
);
make_kernel!(
    kernel_3d_inner,
    rank = 3,
    block = [2, 1, 0],
    elem = [2, 1],
    top = 2
);
make_kernel!(
    kernel_4d_inner,
    rank = 4,
    block = [3, 2, 1, 0],
    elem = [3, 2, 1],
    top = 3
);
make_kernel!(
    kernel_5d_inner,
    rank = 5,
    block = [4, 3, 2, 1, 0],
    elem = [4, 3, 2, 1],
    top = 4
);
make_kernel!(
    kernel_6d_inner,
    rank = 6,
    block = [5, 4, 3, 2, 1, 0],
    elem = [5, 4, 3, 2, 1],
    top = 5
);
make_kernel!(
    kernel_7d_inner,
    rank = 7,
    block = [6, 5, 4, 3, 2, 1, 0],
    elem = [6, 5, 4, 3, 2, 1],
    top = 6
);
make_kernel!(
    kernel_8d_inner,
    rank = 8,
    block = [7, 6, 5, 4, 3, 2, 1, 0],
    elem = [7, 6, 5, 4, 3, 2, 1],
    top = 7
);
242
243#[inline]
245fn kernel_nd_inner_iterative<F>(
246 dims: &[usize],
247 blocks: &[usize],
248 strides: &[Vec<isize>],
249 offsets: &mut [isize],
250 f: &mut F,
251) -> Result<()>
252where
253 F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
254{
255 let rank = dims.len();
256 debug_assert!(rank >= 9);
257
258 let d0 = dims[0];
259 let b0 = blocks[0].max(1).min(d0);
260 let inner_strides: Vec<isize> = strides.iter().map(|s| s[0]).collect();
261
262 let mut idx = vec![0usize; rank];
263
264 loop {
265 let mut j0 = 0usize;
266 while j0 < d0 {
267 let blen0 = b0.min(d0 - j0);
268 f(offsets, blen0, &inner_strides)?;
269 for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
270 *offset += (blen0 as isize) * s[0];
271 }
272 j0 += blen0;
273 }
274 for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
275 *offset -= (d0 as isize) * s[0];
276 }
277
278 let mut level = 1usize;
279 loop {
280 for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
281 *offset += s[level];
282 }
283 idx[level] += 1;
284 if idx[level] < dims[level] {
285 break;
286 }
287
288 idx[level] = 0;
289 for (offset, s) in offsets.iter_mut().zip(strides.iter()) {
290 *offset -= (dims[level] as isize) * s[level];
291 }
292 level += 1;
293 if level == rank {
294 return Ok(());
295 }
296 }
297 }
298}
299
300#[inline]
302pub fn for_each_inner_block_preordered<F>(
303 dims: &[usize],
304 blocks: &[usize],
305 strides: &[Vec<isize>],
306 initial_offsets: &[isize],
307 mut f: F,
308) -> Result<()>
309where
310 F: FnMut(&[isize], usize, &[isize]) -> Result<()>,
311{
312 let rank = dims.len();
313 if rank == 0 {
314 return f(initial_offsets, 1, &[]);
315 }
316
317 let mut offsets = initial_offsets.to_vec();
318
319 match rank {
320 1 => kernel_1d_inner(dims, blocks, strides, &mut offsets, &mut f),
321 2 => kernel_2d_inner(dims, blocks, strides, &mut offsets, &mut f),
322 3 => kernel_3d_inner(dims, blocks, strides, &mut offsets, &mut f),
323 4 => kernel_4d_inner(dims, blocks, strides, &mut offsets, &mut f),
324 5 => kernel_5d_inner(dims, blocks, strides, &mut offsets, &mut f),
325 6 => kernel_6d_inner(dims, blocks, strides, &mut offsets, &mut f),
326 7 => kernel_7d_inner(dims, blocks, strides, &mut offsets, &mut f),
327 8 => kernel_8d_inner(dims, blocks, strides, &mut offsets, &mut f),
328 _ => kernel_nd_inner_iterative(dims, blocks, strides, &mut offsets, &mut f),
329 }
330}
331
/// Number of elements in a dense array with extents `dims`.
///
/// The empty shape (rank 0, a scalar) has exactly one element, and any zero
/// extent yields zero.
#[inline]
pub fn total_len(dims: &[usize]) -> usize {
    // The product of an empty iterator is 1, which already covers rank 0.
    dims.iter().copied().product()
}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense column-major strides (axis 0 fastest) for an array of shape `dims`.
    fn col_major_strides(dims: &[usize]) -> Vec<isize> {
        let mut strides = Vec::with_capacity(dims.len());
        let mut stride = 1isize;
        for &d in dims {
            strides.push(stride);
            stride *= d as isize;
        }
        strides
    }

    /// Runs the traversal from zero offsets and returns, per array, every
    /// element offset handed to the callback, in visiting order.
    ///
    /// Not usable for rank 0, where the callback receives an empty stride slice.
    fn visited_offsets(
        dims: &[usize],
        blocks: &[usize],
        strides: &[Vec<isize>],
    ) -> Vec<Vec<isize>> {
        let start = vec![0isize; strides.len()];
        let mut seen: Vec<Vec<isize>> = vec![Vec::new(); strides.len()];
        for_each_inner_block_preordered(dims, blocks, strides, &start, |offsets, len, inner| {
            for (array, out) in seen.iter_mut().enumerate() {
                out.extend((0..len as isize).map(|k| offsets[array] + k * inner[array]));
            }
            Ok(())
        })
        .unwrap();
        seen
    }

    /// Asserts that a single dense column-major array is visited exactly once
    /// per element, i.e. the sorted offsets are `0..total_len(dims)`.
    fn assert_visits_each_element_once(dims: &[usize], blocks: &[usize]) {
        let strides = vec![col_major_strides(dims)];
        let mut seen = visited_offsets(dims, blocks, &strides).remove(0);
        seen.sort_unstable();
        let expected: Vec<isize> = (0..total_len(dims) as isize).collect();
        assert_eq!(seen, expected, "dims={:?} blocks={:?}", dims, blocks);
    }

    /// Total run length reported by the traversal of a dense column-major array.
    fn count_elements(dims: &[usize], blocks: &[usize]) -> usize {
        let strides = vec![col_major_strides(dims)];
        let offsets = vec![0isize];
        let mut total = 0usize;
        for_each_inner_block_preordered(dims, blocks, &strides, &offsets, |_, len, _| {
            total += len;
            Ok(())
        })
        .unwrap();
        total
    }

    #[test]
    fn test_build_plan_fused_compresses() {
        let dims = [2usize, 3];
        let strides = [1isize, 2];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let (fused_dims, fused_strides, plan) = build_plan_fused(&dims, &strides_list, Some(0), 8);
        assert_eq!(fused_dims, vec![6]);
        assert_eq!(fused_strides.len(), 1);
        assert_eq!(fused_strides[0], vec![1]);
        assert_eq!(plan.block.len(), 1);
    }

    #[test]
    fn test_for_each_inner_block_preordered_total() {
        let dims = vec![3, 4, 2];
        let blocks = vec![3, 4, 2];
        let strides = vec![vec![1, 3, 12], vec![1, 3, 12]];
        let offsets = vec![0, 0];
        let mut total = 0usize;
        for_each_inner_block_preordered(&dims, &blocks, &strides, &offsets, |_, len, _| {
            total += len;
            Ok(())
        })
        .unwrap();
        assert_eq!(total, 24);
    }

    #[test]
    fn test_build_plan_fused_non_contiguous() {
        let dims = [4usize, 5];
        // Row-major 4x5: the planner should reorder and fuse to one 20-run.
        let strides = [5isize, 1];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let (fused_dims, fused_strides, plan) = build_plan_fused(&dims, &strides_list, Some(0), 8);
        assert_eq!(fused_dims, vec![20]);
        assert_eq!(fused_strides[0], vec![1]);
        assert_eq!(plan.block.len(), 1);
    }

    #[test]
    fn test_build_plan_fused_multi_array() {
        let dims = [4usize, 5];
        // Conflicting layouts (column- vs row-major) cannot be fused.
        let dst_strides = [1isize, 4];
        let src_strides = [5isize, 1];
        let strides_list: Vec<&[isize]> = vec![&dst_strides, &src_strides];
        let (fused_dims, fused_strides, plan) = build_plan_fused(&dims, &strides_list, Some(0), 8);
        assert_eq!(fused_dims.len(), 2);
        assert_eq!(fused_strides.len(), 2);
        assert!(!plan.block.is_empty());
    }

    #[test]
    fn test_build_plan_fused_small_basic() {
        let dims = [2usize, 3];
        let strides = [1isize, 2];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let (fused_dims, fused_strides, plan) = build_plan_fused_small(&dims, &strides_list);
        assert_eq!(fused_dims, vec![6]);
        assert_eq!(fused_strides[0], vec![1]);
        assert_eq!(plan.block, fused_dims);
    }

    #[test]
    fn test_build_plan_fused_small_non_contiguous() {
        let dims = [4usize, 5];
        let strides = [5isize, 1];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let (fused_dims, fused_strides, plan) = build_plan_fused_small(&dims, &strides_list);
        assert!(!fused_dims.is_empty());
        assert_eq!(fused_strides.len(), 1);
        assert_eq!(plan.block, fused_dims);
    }

    #[test]
    fn test_for_each_rank0() {
        let dims = vec![];
        let blocks = vec![];
        let strides: Vec<Vec<isize>> = vec![vec![]];
        let offsets = vec![0isize];
        let mut called = false;
        for_each_inner_block_preordered(&dims, &blocks, &strides, &offsets, |_, len, _| {
            called = true;
            assert_eq!(len, 1);
            Ok(())
        })
        .unwrap();
        assert!(called);
    }

    #[test]
    fn test_for_each_rank4() {
        assert_eq!(count_elements(&[2, 3, 4, 5], &[2, 3, 4, 5]), 120);
    }

    #[test]
    fn test_for_each_rank5() {
        assert_eq!(count_elements(&[2, 2, 2, 2, 2], &[2, 2, 2, 2, 2]), 32);
    }

    #[test]
    fn test_for_each_rank6() {
        assert_eq!(count_elements(&[2, 2, 2, 2, 2, 3], &[2, 2, 2, 2, 2, 3]), 96);
    }

    #[test]
    fn test_for_each_rank7() {
        assert_eq!(
            count_elements(&[2, 2, 2, 2, 2, 2, 3], &[2, 2, 2, 2, 2, 2, 3]),
            192
        );
    }

    #[test]
    fn test_for_each_rank8() {
        assert_eq!(
            count_elements(&[2, 2, 2, 2, 2, 2, 2, 3], &[2, 2, 2, 2, 2, 2, 2, 3]),
            384
        );
    }

    #[test]
    fn test_for_each_rank9_iterative() {
        // Rank 9 is the first rank handled by the iterative fallback.
        assert_eq!(
            count_elements(&[2, 2, 2, 2, 2, 2, 2, 2, 3], &[2, 2, 2, 2, 2, 2, 2, 2, 3]),
            768
        );
    }

    #[test]
    fn test_for_each_rank10_iterative() {
        assert_eq!(
            count_elements(
                &[2, 2, 2, 2, 2, 2, 2, 2, 2, 3],
                &[2, 2, 2, 2, 2, 2, 2, 2, 2, 3]
            ),
            1536
        );
    }

    #[test]
    fn test_blocked_traversal_visits_each_element_once() {
        // Block extents smaller than, larger than, and zero (treated as 1)
        // relative to dims. Covers the dedicated 1-D kernel, the unrolled
        // kernels and the iterative fallback.
        assert_visits_each_element_once(&[5], &[2]);
        assert_visits_each_element_once(&[3, 4], &[2, 3]);
        assert_visits_each_element_once(&[3, 4], &[0, 0]);
        assert_visits_each_element_once(&[3, 4], &[10, 10]);
        assert_visits_each_element_once(&[3, 4, 5], &[2, 3, 2]);
        assert_visits_each_element_once(&[3, 2, 3, 2], &[2, 1, 2, 1]);
        assert_visits_each_element_once(&[2, 3, 2, 3, 2], &[1, 2, 1, 2, 1]);
        assert_visits_each_element_once(&[2, 2, 2, 2, 2, 2, 2, 3], &[1, 1, 2, 1, 2, 1, 2, 2]);
        assert_visits_each_element_once(
            &[3, 2, 2, 2, 2, 2, 2, 2, 2],
            &[2, 1, 1, 1, 1, 1, 1, 1, 1],
        );
    }

    #[test]
    fn test_blocked_traversal_keeps_arrays_in_sync() {
        // Array 0 is column-major and array 1 row-major over the same 3x4
        // shape, so element (i, j) sits at i + 3j in one and 4i + j in the other.
        let dims = [3usize, 4];
        let strides = vec![vec![1isize, 3], vec![4, 1]];
        let seen = visited_offsets(&dims, &[2, 3], &strides);
        assert_eq!(seen[0].len(), 12);
        assert_eq!(seen[1].len(), 12);
        for (&col, &row) in seen[0].iter().zip(seen[1].iter()) {
            let (i, j) = (col % 3, col / 3);
            assert_eq!(row, 4 * i + j, "col-major offset {}", col);
        }
    }

    #[test]
    fn test_zero_extent_axis_makes_no_calls() {
        let shapes: [&[usize]; 4] = [&[0, 3], &[3, 0], &[2, 0, 4], &[2, 3, 0, 2]];
        for &dims in shapes.iter() {
            let strides = vec![col_major_strides(dims)];
            let mut calls = 0usize;
            for_each_inner_block_preordered(dims, dims, &strides, &[0], |_, _, _| {
                calls += 1;
                Ok(())
            })
            .unwrap();
            assert_eq!(calls, 0, "dims={:?}", dims);
        }
    }

    #[test]
    fn test_total_len_empty() {
        assert_eq!(total_len(&[]), 1);
    }

    #[test]
    fn test_total_len_basic() {
        assert_eq!(total_len(&[2, 3, 4]), 24);
    }
}