Skip to main content

strided_perm/
block.rs

1//! Block size computation ported from Strided.jl
2
3use crate::fuse::compute_costs;
4use crate::{BLOCK_MEMORY_SIZE, CACHE_LINE_SIZE};
5use strided_view::auxiliary::index_order;
6
7/// Compute block sizes for tiled iteration.
8pub fn compute_block_sizes(
9    dims: &[usize],
10    order: &[usize],
11    strides_list: &[&[isize]],
12    elem_size: usize,
13) -> Vec<usize> {
14    if order.is_empty() {
15        return Vec::new();
16    }
17
18    // Reorder dims to iteration order
19    let ordered_dims: Vec<usize> = order.iter().map(|&i| dims[i]).collect();
20
21    // Compute byte strides in iteration order
22    let byte_strides: Vec<Vec<isize>> = strides_list
23        .iter()
24        .map(|strides| {
25            order
26                .iter()
27                .map(|&i| strides[i] * elem_size as isize)
28                .collect()
29        })
30        .collect();
31
32    // Compute stride orders (in iteration order)
33    let stride_orders: Vec<Vec<usize>> = byte_strides.iter().map(|bs| index_order(bs)).collect();
34
35    // Reorder strides for cost computation
36    let reordered_strides: Vec<Vec<isize>> = strides_list
37        .iter()
38        .map(|strides| order.iter().map(|&i| strides[i]).collect())
39        .collect();
40
41    let reordered_refs: Vec<&[isize]> = reordered_strides.iter().map(|s| s.as_slice()).collect();
42    let costs = compute_costs(&reordered_refs);
43
44    // Convert byte_strides to slices
45    let byte_stride_refs: Vec<&[isize]> = byte_strides.iter().map(|s| s.as_slice()).collect();
46    let stride_order_refs: Vec<&[usize]> = stride_orders.iter().map(|s| s.as_slice()).collect();
47
48    compute_blocks(
49        &ordered_dims,
50        &costs,
51        &byte_stride_refs,
52        &stride_order_refs,
53        BLOCK_MEMORY_SIZE,
54    )
55}
56
57fn compute_blocks(
58    dims: &[usize],
59    costs: &[isize],
60    byte_strides: &[&[isize]],
61    stride_orders: &[&[usize]],
62    block_size: usize,
63) -> Vec<usize> {
64    let n = dims.len();
65    if n == 0 {
66        return vec![];
67    }
68
69    if total_memory_region(dims, byte_strides) <= block_size {
70        return dims.to_vec();
71    }
72
73    let min_order = stride_orders
74        .iter()
75        .filter_map(|orders| orders.iter().min().copied())
76        .min()
77        .unwrap_or(1);
78
79    if stride_orders
80        .iter()
81        .all(|orders| !orders.is_empty() && orders[0] == min_order)
82    {
83        let tail_dims: Vec<usize> = dims[1..].to_vec();
84        let tail_costs: Vec<isize> = costs[1..].to_vec();
85        let tail_byte_strides: Vec<&[isize]> = byte_strides.iter().map(|s| &s[1..]).collect();
86        let tail_stride_orders: Vec<&[usize]> = stride_orders.iter().map(|s| &s[1..]).collect();
87
88        let tail_blocks = compute_blocks(
89            &tail_dims,
90            &tail_costs,
91            &tail_byte_strides,
92            &tail_stride_orders,
93            block_size,
94        );
95
96        let mut result = vec![dims[0]];
97        result.extend(tail_blocks);
98        return result;
99    }
100
101    let min_stride = byte_strides
102        .iter()
103        .filter_map(|s| s.iter().map(|x| x.unsigned_abs()).min())
104        .min()
105        .unwrap_or(0);
106
107    if min_stride > block_size {
108        return vec![1; n];
109    }
110
111    let mut blocks = dims.to_vec();
112
113    // Phase 1: Halve until within 2x of target
114    while total_memory_region(&blocks, byte_strides) >= 2 * block_size {
115        let i = last_argmax_weighted(&blocks, costs);
116        if i.is_none() || blocks[i.unwrap()] <= 1 {
117            break;
118        }
119        let i = i.unwrap();
120        blocks[i] = blocks[i].div_ceil(2);
121    }
122
123    // Phase 2: Decrement until within target
124    while total_memory_region(&blocks, byte_strides) > block_size {
125        let i = last_argmax_weighted(&blocks, costs);
126        if i.is_none() || blocks[i.unwrap()] <= 1 {
127            break;
128        }
129        let i = i.unwrap();
130        blocks[i] -= 1;
131    }
132
133    blocks
134}
135
136fn total_memory_region(dims: &[usize], byte_strides: &[&[isize]]) -> usize {
137    let cache_line = CACHE_LINE_SIZE;
138    let mut memory_region = 0usize;
139
140    for strides in byte_strides {
141        let mut num_contiguous_cache_lines = 0isize;
142        let mut num_cache_line_blocks = 1usize;
143
144        for (&d, &s) in dims.iter().zip(strides.iter()) {
145            let s_abs = s.unsigned_abs();
146            if s_abs < cache_line {
147                num_contiguous_cache_lines += (d.saturating_sub(1) as isize) * (s_abs as isize);
148            } else {
149                num_cache_line_blocks *= d;
150            }
151        }
152
153        let contiguous_lines = (num_contiguous_cache_lines as usize / cache_line) + 1;
154        memory_region += cache_line * contiguous_lines * num_cache_line_blocks;
155    }
156
157    memory_region
158}
159
/// Pick the dimension to shrink next: the index maximising
/// `(blocks[i] - 1) * costs[i]`.
///
/// Only entries with `blocks[i] > 1` and a non-negative score are candidates.
/// When several candidates share the maximum score, the last one wins.
/// Returns `None` when there is no candidate (including empty input).
/// `blocks` and `costs` are paired positionally; surplus entries are ignored.
///
/// The score saturates at the `isize` limits, so a very large extent cannot
/// wrap to a negative value and be skipped.
fn last_argmax_weighted(blocks: &[usize], costs: &[isize]) -> Option<usize> {
    blocks
        .iter()
        .zip(costs)
        .enumerate()
        .filter(|&(_, (&extent, _))| extent > 1)
        .map(|(idx, (&extent, &cost))| {
            // Clamp before converting so the cast can never change the sign.
            let steps = (extent - 1).min(isize::MAX as usize) as isize;
            (idx, steps.saturating_mul(cost))
        })
        .filter(|&(_, score)| score >= 0)
        // `max_by_key` returns the last of several equal maxima.
        .max_by_key(|&(_, score)| score)
        .map(|(idx, _)| idx)
}
181
#[cfg(test)]
mod tests {
    use super::*;

    // The exact byte counts below assume a 64-byte cache line.

    #[test]
    fn test_total_memory_region_contiguous() {
        // span = 99 * 8 = 792 bytes -> 792 / 64 + 1 = 13 lines -> 13 * 64.
        let dims = [100usize];
        let strides = [8isize];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let region = total_memory_region(&dims, &byte_strides);
        assert_eq!(region, 832);
    }

    #[test]
    fn test_compute_blocks_small() {
        // 1280 bytes in total: expected to fit in one block, so dims come back unchanged.
        let dims = [10usize, 10];
        let costs = [2isize, 2];
        let strides = [8isize, 80];
        let orders = [1usize, 2];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let stride_orders: Vec<&[usize]> = vec![&orders];
        let blocks = compute_blocks(
            &dims,
            &costs,
            &byte_strides,
            &stride_orders,
            BLOCK_MEMORY_SIZE,
        );
        assert_eq!(blocks, vec![10, 10]);
    }

    #[test]
    fn test_compute_blocks_large() {
        let dims = [1000usize, 1000];
        let costs = [2isize, 2];
        let strides = [8isize, 8000];
        let orders = [1usize, 2];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let stride_orders: Vec<&[usize]> = vec![&orders];
        let blocks = compute_blocks(
            &dims,
            &costs,
            &byte_strides,
            &stride_orders,
            BLOCK_MEMORY_SIZE,
        );
        assert_eq!(blocks.len(), dims.len());
        for (&block, &dim) in blocks.iter().zip(dims.iter()) {
            assert!(block >= 1 && block <= dim);
        }
    }

    #[test]
    fn test_compute_blocks_empty() {
        let blocks = compute_blocks(&[], &[], &[], &[], BLOCK_MEMORY_SIZE);
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_compute_blocks_halve_then_decrement() {
        // Two arrays with opposite layouts: the leading dimension does not
        // carry the minimal stride order in both, so the recursive path is
        // skipped and both reduction phases run.
        // Phase 1 halves down to [63, 63]; phase 2 decrements to [43, 42],
        // whose footprint is 64 * (6 * 42 + 6 * 43) = 32640 <= 32768.
        let dims = [1000usize, 1000];
        let costs = [2isize, 2];
        let s1 = [8isize, 8000];
        let s2 = [8000isize, 8];
        let o1 = [1usize, 2];
        let o2 = [2usize, 1];
        let byte_strides: Vec<&[isize]> = vec![&s1, &s2];
        let stride_orders: Vec<&[usize]> = vec![&o1, &o2];
        let blocks = compute_blocks(&dims, &costs, &byte_strides, &stride_orders, 32768);
        assert_eq!(blocks, vec![43, 42]);
    }

    #[test]
    fn test_last_argmax_weighted_basic() {
        // blocks [10, 5], costs [2, 4]
        // scores: (10-1)*2=18, (5-1)*4=16 -> first wins
        assert_eq!(last_argmax_weighted(&[10, 5], &[2, 4]), Some(0));
    }

    #[test]
    fn test_last_argmax_weighted_ties() {
        // Equal scores: last wins (>= semantics)
        assert_eq!(last_argmax_weighted(&[5, 5], &[2, 2]), Some(1));
    }

    #[test]
    fn test_last_argmax_weighted_all_one() {
        // All blocks are 1: no valid candidate
        assert_eq!(last_argmax_weighted(&[1, 1, 1], &[1, 1, 1]), None);
    }

    #[test]
    fn test_last_argmax_weighted_empty() {
        assert_eq!(last_argmax_weighted(&[], &[]), None);
    }

    #[test]
    fn test_total_memory_region_multi_array() {
        // Two arrays with different strides
        let dims = [100usize, 100];
        let s1 = [8isize, 800]; // col-major f64
        let s2 = [800isize, 8]; // row-major f64
        let byte_strides: Vec<&[isize]> = vec![&s1, &s2];
        let region = total_memory_region(&dims, &byte_strides);
        // Each array: 13 cache lines per run * 100 runs * 64 bytes = 83200.
        assert_eq!(region, 2 * 83200);
    }

    #[test]
    fn test_total_memory_region_large_stride() {
        // Stride >= cache line triggers block multiplication
        let dims = [10usize, 10];
        let strides = [8isize, 800]; // second dim stride 800 >= 64
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let region = total_memory_region(&dims, &byte_strides);
        // span = 9 * 8 = 72 bytes -> 2 lines; 10 runs -> 64 * 2 * 10.
        assert_eq!(region, 1280);
    }

    #[test]
    fn test_compute_blocks_min_stride_exceeds_block_size() {
        // Both strides very large (> block_size=64): min_stride > block_size -> all blocks = 1
        // Two arrays with conflicting stride orders to prevent recursive tail path
        let dims = [10usize, 10];
        let costs = [1isize, 1];
        let s1 = [100000isize, 1000000]; // array 1: order [0, 1]
        let s2 = [1000000isize, 100000]; // array 2: order [1, 0]
        let o1 = [0usize, 1];
        let o2 = [1usize, 0]; // conflicting: o1[0]=0, o2[0]=1, min_order=0 but o2[0]!=0
        let byte_strides: Vec<&[isize]> = vec![&s1, &s2];
        let stride_orders: Vec<&[usize]> = vec![&o1, &o2];
        let blocks = compute_blocks(&dims, &costs, &byte_strides, &stride_orders, 64);
        assert_eq!(blocks, vec![1, 1]);
    }

    #[test]
    fn test_compute_block_sizes_3d() {
        // 3D col-major
        let dims = [10usize, 20, 30];
        let order = [0usize, 1, 2];
        let strides = [8isize, 80, 1600];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let blocks = compute_block_sizes(&dims, &order, &strides_list, 8);
        assert_eq!(blocks.len(), 3);
        for (&block, &dim) in blocks.iter().zip(dims.iter()) {
            assert!(block >= 1 && block <= dim);
        }
    }

    #[test]
    fn test_compute_blocks_first_stride_order_matches() {
        // stride_orders[0][0] == min_order -> triggers recursive tail path
        let dims = [4usize, 100, 100];
        let costs = [1isize, 10, 10];
        // Stride order: [0, 1, 2] and min_order = 0
        let strides = [8isize, 32, 3200];
        let orders = [0usize, 1, 2];
        let byte_strides: Vec<&[isize]> = vec![&strides];
        let stride_orders: Vec<&[usize]> = vec![&orders];
        let blocks = compute_blocks(
            &dims,
            &costs,
            &byte_strides,
            &stride_orders,
            BLOCK_MEMORY_SIZE,
        );
        // First dim should be kept at full extent
        assert_eq!(blocks[0], 4);
        // With a single array whose stride orders are strictly increasing,
        // every recursion level either fits or peels its leading dimension,
        // so the full extents come back whatever the memory budget is.
        assert_eq!(blocks, vec![4, 100, 100]);
    }
}