1use crate::fuse::compute_costs;
4use crate::{BLOCK_MEMORY_SIZE, CACHE_LINE_SIZE};
5use strided_view::auxiliary::index_order;
6
7pub fn compute_block_sizes(
9 dims: &[usize],
10 order: &[usize],
11 strides_list: &[&[isize]],
12 elem_size: usize,
13) -> Vec<usize> {
14 if order.is_empty() {
15 return Vec::new();
16 }
17
18 let ordered_dims: Vec<usize> = order.iter().map(|&i| dims[i]).collect();
20
21 let byte_strides: Vec<Vec<isize>> = strides_list
23 .iter()
24 .map(|strides| {
25 order
26 .iter()
27 .map(|&i| strides[i] * elem_size as isize)
28 .collect()
29 })
30 .collect();
31
32 let stride_orders: Vec<Vec<usize>> = byte_strides.iter().map(|bs| index_order(bs)).collect();
34
35 let reordered_strides: Vec<Vec<isize>> = strides_list
37 .iter()
38 .map(|strides| order.iter().map(|&i| strides[i]).collect())
39 .collect();
40
41 let reordered_refs: Vec<&[isize]> = reordered_strides.iter().map(|s| s.as_slice()).collect();
42 let costs = compute_costs(&reordered_refs);
43
44 let byte_stride_refs: Vec<&[isize]> = byte_strides.iter().map(|s| s.as_slice()).collect();
46 let stride_order_refs: Vec<&[usize]> = stride_orders.iter().map(|s| s.as_slice()).collect();
47
48 compute_blocks(
49 &ordered_dims,
50 &costs,
51 &byte_stride_refs,
52 &stride_order_refs,
53 BLOCK_MEMORY_SIZE,
54 )
55}
56
57fn compute_blocks(
58 dims: &[usize],
59 costs: &[isize],
60 byte_strides: &[&[isize]],
61 stride_orders: &[&[usize]],
62 block_size: usize,
63) -> Vec<usize> {
64 let n = dims.len();
65 if n == 0 {
66 return vec![];
67 }
68
69 if total_memory_region(dims, byte_strides) <= block_size {
70 return dims.to_vec();
71 }
72
73 let min_order = stride_orders
74 .iter()
75 .filter_map(|orders| orders.iter().min().copied())
76 .min()
77 .unwrap_or(1);
78
79 if stride_orders
80 .iter()
81 .all(|orders| !orders.is_empty() && orders[0] == min_order)
82 {
83 let tail_dims: Vec<usize> = dims[1..].to_vec();
84 let tail_costs: Vec<isize> = costs[1..].to_vec();
85 let tail_byte_strides: Vec<&[isize]> = byte_strides.iter().map(|s| &s[1..]).collect();
86 let tail_stride_orders: Vec<&[usize]> = stride_orders.iter().map(|s| &s[1..]).collect();
87
88 let tail_blocks = compute_blocks(
89 &tail_dims,
90 &tail_costs,
91 &tail_byte_strides,
92 &tail_stride_orders,
93 block_size,
94 );
95
96 let mut result = vec![dims[0]];
97 result.extend(tail_blocks);
98 return result;
99 }
100
101 let min_stride = byte_strides
102 .iter()
103 .filter_map(|s| s.iter().map(|x| x.unsigned_abs()).min())
104 .min()
105 .unwrap_or(0);
106
107 if min_stride > block_size {
108 return vec![1; n];
109 }
110
111 let mut blocks = dims.to_vec();
112
113 while total_memory_region(&blocks, byte_strides) >= 2 * block_size {
115 let i = last_argmax_weighted(&blocks, costs);
116 if i.is_none() || blocks[i.unwrap()] <= 1 {
117 break;
118 }
119 let i = i.unwrap();
120 blocks[i] = blocks[i].div_ceil(2);
121 }
122
123 while total_memory_region(&blocks, byte_strides) > block_size {
125 let i = last_argmax_weighted(&blocks, costs);
126 if i.is_none() || blocks[i.unwrap()] <= 1 {
127 break;
128 }
129 let i = i.unwrap();
130 blocks[i] -= 1;
131 }
132
133 blocks
134}
135
136fn total_memory_region(dims: &[usize], byte_strides: &[&[isize]]) -> usize {
137 let cache_line = CACHE_LINE_SIZE;
138 let mut memory_region = 0usize;
139
140 for strides in byte_strides {
141 let mut num_contiguous_cache_lines = 0isize;
142 let mut num_cache_line_blocks = 1usize;
143
144 for (&d, &s) in dims.iter().zip(strides.iter()) {
145 let s_abs = s.unsigned_abs();
146 if s_abs < cache_line {
147 num_contiguous_cache_lines += (d.saturating_sub(1) as isize) * (s_abs as isize);
148 } else {
149 num_cache_line_blocks *= d;
150 }
151 }
152
153 let contiguous_lines = (num_contiguous_cache_lines as usize / cache_line) + 1;
154 memory_region += cache_line * contiguous_lines * num_cache_line_blocks;
155 }
156
157 memory_region
158}
159
/// Picks the dimension that the blocking loops in `compute_blocks` should shrink next.
///
/// Scoring and eligibility:
/// * A dimension `i` with `blocks[i] > 1` is scored as `(blocks[i] - 1) * costs[i]`.
/// * Dimensions with `blocks[i] <= 1` cannot shrink and are skipped.
/// * Only non-negative scores are eligible.
///
/// The highest eligible score wins, and ties go to the *last* such index.
///
/// `blocks` and `costs` are zipped, so entries beyond the shorter slice are ignored.
/// Returns `None` when nothing is eligible, including when `blocks` is empty.
fn last_argmax_weighted(blocks: &[usize], costs: &[isize]) -> Option<usize> {
    blocks
        .iter()
        .zip(costs)
        .enumerate()
        .filter(|&(_, (&len, _))| len > 1)
        .map(|(idx, (&len, &weight))| (idx, (len as isize - 1) * weight))
        .filter(|&(_, score)| score >= 0)
        // `max_by_key` yields the last of several equal maxima, which gives the
        // "last argmax" tie-breaking.
        .max_by_key(|&(_, score)| score)
        .map(|(idx, _)| idx)
}
181
#[cfg(test)]
mod tests {
    use super::*;

    // Exact byte counts below assume CACHE_LINE_SIZE == 64, consistent with the
    // 832-byte expectation in `test_total_memory_region_contiguous`.

    #[test]
    fn test_total_memory_region_contiguous() {
        // (99 * 8) / 64 + 1 = 13 cache lines -> 13 * 64 = 832 bytes.
        let dims = [100usize];
        let strides = [8isize];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let region = total_memory_region(&dims, &byte_strides);
        assert_eq!(region, 832);
    }

    #[test]
    fn test_compute_blocks_small() {
        let dims = [10usize, 10];
        let costs = [2isize, 2];
        let strides = [8isize, 80];
        let orders = [1usize, 2];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let stride_orders: Vec<&[usize]> = vec![&orders];
        let blocks = compute_blocks(
            &dims,
            &costs,
            &byte_strides,
            &stride_orders,
            BLOCK_MEMORY_SIZE,
        );
        assert_eq!(blocks, vec![10, 10]);
    }

    #[test]
    fn test_compute_blocks_large() {
        let dims = [1000usize, 1000];
        let costs = [2isize, 2];
        let strides = [8isize, 8000];
        let orders = [1usize, 2];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let stride_orders: Vec<&[usize]> = vec![&orders];
        let blocks = compute_blocks(
            &dims,
            &costs,
            &byte_strides,
            &stride_orders,
            BLOCK_MEMORY_SIZE,
        );
        assert!(blocks[0] <= dims[0]);
        assert!(blocks[1] <= dims[1]);
        assert!(blocks[0] >= 1);
        assert!(blocks[1] >= 1);
    }

    #[test]
    fn test_compute_blocks_transposed_pair_respects_budget() {
        // Two arrays traversed in opposite memory orders: no leading dimension is the
        // tightest for both, so the shrinking loops have to run. Shrinking stops at the
        // latest when every block is 1 (region = 2 cache lines). The final assertion
        // therefore assumes BLOCK_MEMORY_SIZE is at least two cache lines.
        let dims = [1000usize, 1000];
        let costs = [1isize, 1];
        let s1 = [8isize, 8000];
        let s2 = [8000isize, 8];
        let o1 = [0usize, 1];
        let o2 = [1usize, 0];
        let byte_strides: Vec<&[isize]> = vec![&s1, &s2];
        let stride_orders: Vec<&[usize]> = vec![&o1, &o2];
        let blocks = compute_blocks(
            &dims,
            &costs,
            &byte_strides,
            &stride_orders,
            BLOCK_MEMORY_SIZE,
        );
        assert_eq!(blocks.len(), dims.len());
        for (&b, &d) in blocks.iter().zip(dims.iter()) {
            assert!((1..=d).contains(&b));
        }
        assert!(total_memory_region(&blocks, &byte_strides) <= BLOCK_MEMORY_SIZE);
    }

    #[test]
    fn test_compute_blocks_empty() {
        let blocks = compute_blocks(&[], &[], &[], &[], BLOCK_MEMORY_SIZE);
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_last_argmax_weighted_basic() {
        assert_eq!(last_argmax_weighted(&[10, 5], &[2, 4]), Some(0));
    }

    #[test]
    fn test_last_argmax_weighted_ties() {
        assert_eq!(last_argmax_weighted(&[5, 5], &[2, 2]), Some(1));
    }

    #[test]
    fn test_last_argmax_weighted_all_one() {
        assert_eq!(last_argmax_weighted(&[1, 1, 1], &[1, 1, 1]), None);
    }

    #[test]
    fn test_last_argmax_weighted_empty() {
        assert_eq!(last_argmax_weighted(&[], &[]), None);
    }

    #[test]
    fn test_total_memory_region_multi_array() {
        // Each array: 13 contiguous cache lines x 100 runs x 64 bytes = 83 200 bytes.
        let dims = [100usize, 100];
        let s1 = [8isize, 800];
        let s2 = [800isize, 8];
        let byte_strides: Vec<&[isize]> = vec![&s1, &s2];
        let region = total_memory_region(&dims, &byte_strides);
        assert_eq!(region, 166_400);
    }

    #[test]
    fn test_total_memory_region_large_stride() {
        // (9 * 8) / 64 + 1 = 2 contiguous lines, times 10 runs of stride 800 -> 1280 bytes.
        let dims = [10usize, 10];
        let strides = [8isize, 800];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let region = total_memory_region(&dims, &byte_strides);
        assert_eq!(region, 1280);
    }

    #[test]
    fn test_compute_blocks_min_stride_exceeds_block_size() {
        let dims = [10usize, 10];
        let costs = [1isize, 1];
        let s1 = [100000isize, 1000000];
        let s2 = [1000000isize, 100000];
        let o1 = [0usize, 1];
        let o2 = [1usize, 0];
        let byte_strides: Vec<&[isize]> = vec![&s1, &s2];
        let stride_orders: Vec<&[usize]> = vec![&o1, &o2];
        let blocks = compute_blocks(&dims, &costs, &byte_strides, &stride_orders, 64);
        assert_eq!(blocks, vec![1, 1]);
    }

    #[test]
    fn test_compute_block_sizes_3d() {
        let dims = [10usize, 20, 30];
        let order = [0usize, 1, 2];
        let strides = [8isize, 80, 1600];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let blocks = compute_block_sizes(&dims, &order, &strides_list, 8);
        assert_eq!(blocks.len(), 3);
        for (&b, &d) in blocks.iter().zip(dims.iter()) {
            assert!((1..=d).contains(&b));
        }
    }

    #[test]
    fn test_compute_blocks_first_stride_order_matches() {
        let dims = [4usize, 100, 100];
        let costs = [1isize, 10, 10];
        let strides = [8isize, 32, 3200];
        let orders = [0usize, 1, 2];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let stride_orders: Vec<&[usize]> = vec![&orders];
        let blocks = compute_blocks(
            &dims,
            &costs,
            &byte_strides,
            &stride_orders,
            BLOCK_MEMORY_SIZE,
        );
        assert_eq!(blocks[0], 4);
    }
}