Skip to main content

strided_perm/
order.rs

//! Loop ordering algorithm ported from Strided.jl.
2
3use crate::fuse::{compute_importance, sort_by_importance};
4use strided_view::auxiliary::index_order;
5
6/// Compute the optimal iteration order for dimensions.
7///
8/// This implementation follows Julia's `_mapreduce_order!` algorithm:
9/// 1. Compute `index_order` for each array's strides
10/// 2. Compute importance scores using bit-packing with output weighted 2x
11/// 3. Sort dimensions by importance (descending)
12pub fn compute_order(
13    dims: &[usize],
14    strides_list: &[&[isize]],
15    dest_index: Option<usize>,
16) -> Vec<usize> {
17    let rank = dims.len();
18    if rank == 0 {
19        return Vec::new();
20    }
21
22    if strides_list.is_empty() {
23        return (0..rank).collect();
24    }
25
26    // Compute index_order for each stride array
27    let mut index_orders: Vec<Vec<usize>> = Vec::with_capacity(strides_list.len());
28    for strides in strides_list {
29        index_orders.push(index_order(strides));
30    }
31
32    // Reorder so destination array is first (gets 2x weight)
33    let reordered_strides: Vec<&[isize]>;
34    let reordered_orders: Vec<Vec<usize>>;
35
36    if let Some(dest_idx) = dest_index {
37        if dest_idx < strides_list.len() && dest_idx != 0 {
38            let mut strides_vec: Vec<&[isize]> = strides_list.to_vec();
39            let mut orders_vec = index_orders;
40
41            let dest_strides = strides_vec.remove(dest_idx);
42            let dest_order = orders_vec.remove(dest_idx);
43
44            strides_vec.insert(0, dest_strides);
45            orders_vec.insert(0, dest_order);
46
47            reordered_strides = strides_vec;
48            reordered_orders = orders_vec;
49        } else {
50            reordered_strides = strides_list.to_vec();
51            reordered_orders = index_orders;
52        }
53    } else {
54        reordered_strides = strides_list.to_vec();
55        reordered_orders = index_orders;
56    }
57
58    // Compute importance using the Julia algorithm
59    let importance = compute_importance(dims, &reordered_strides, &reordered_orders);
60
61    // Sort by importance (descending)
62    sort_by_importance(&importance)
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `order` contains each index in `0..rank` exactly once.
    fn assert_is_permutation(order: &[usize], rank: usize) {
        let mut sorted = order.to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..rank).collect::<Vec<_>>());
    }

    #[test]
    fn test_compute_order_column_major() {
        let dims = [4usize, 5];
        let strides = [1isize, 4];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let order = compute_order(&dims, &strides_list, Some(0));
        assert_eq!(order, vec![0, 1]);
    }

    #[test]
    fn test_compute_order_row_major() {
        let dims = [4usize, 5];
        let strides = [5isize, 1];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let order = compute_order(&dims, &strides_list, Some(0));
        assert_eq!(order, vec![1, 0]);
    }

    #[test]
    fn test_compute_order_mixed() {
        let dims = [4usize, 5];
        let out_strides = [1isize, 4];
        let in_strides = [5isize, 1];
        let strides_list: Vec<&[isize]> = vec![&out_strides, &in_strides];
        let order = compute_order(&dims, &strides_list, Some(0));
        assert_eq!(order, vec![0, 1]);
    }

    #[test]
    fn test_compute_order_empty() {
        let dims: [usize; 0] = [];
        let strides: [isize; 0] = [];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let order = compute_order(&dims, &strides_list, Some(0));
        assert!(order.is_empty());
    }

    #[test]
    fn test_compute_order_dest_reorder() {
        // dest_index=1: should swap so destination is first (gets 2x weight)
        let dims = [4usize, 5];
        let src_strides = [1isize, 4]; // col-major
        let dst_strides = [5isize, 1]; // row-major
        let strides_list: Vec<&[isize]> = vec![&src_strides, &dst_strides];
        let order = compute_order(&dims, &strides_list, Some(1));
        // With dst (row-major) weighted 2x, dim 1 (stride 1 in dst) should be innermost
        assert_eq!(order, vec![1, 0]);
    }

    #[test]
    fn test_compute_order_dest_moved_to_front() {
        // Passing dest_index = k must be equivalent to passing the same arrays
        // with the destination already first and the others in their original
        // relative order.
        let dims = [4usize, 5, 6];
        let a = [1isize, 4, 20];
        let b = [30isize, 6, 1];
        let c = [6isize, 1, 30];
        let given: Vec<&[isize]> = vec![&a, &b, &c];
        let moved: Vec<&[isize]> = vec![&c, &a, &b];
        let order = compute_order(&dims, &given, Some(2));
        assert_is_permutation(&order, 3);
        assert_eq!(order, compute_order(&dims, &moved, Some(0)));
        assert_eq!(order, compute_order(&dims, &moved, None));
    }

    #[test]
    fn test_compute_order_no_dest() {
        let dims = [4usize, 5];
        let strides = [1isize, 4];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let order = compute_order(&dims, &strides_list, None);
        assert_eq!(order, vec![0, 1]);
    }

    #[test]
    fn test_compute_order_dest_out_of_bounds() {
        // dest_index >= strides_list.len(): should skip reordering
        let dims = [4usize, 5];
        let strides = [1isize, 4];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let order = compute_order(&dims, &strides_list, Some(5));
        assert_eq!(order, compute_order(&dims, &strides_list, None));
        assert_eq!(order, vec![0, 1]);
    }

    #[test]
    fn test_compute_order_empty_strides_list() {
        let dims = [4usize, 5, 6];
        let strides_list: Vec<&[isize]> = vec![];
        let order = compute_order(&dims, &strides_list, None);
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_compute_order_3d() {
        // 3D col-major: innermost should be dim 0, then 1, then 2
        let dims = [4usize, 5, 6];
        let strides = [1isize, 4, 20];
        let strides_list: Vec<&[isize]> = vec![&strides];
        let order = compute_order(&dims, &strides_list, Some(0));
        assert_eq!(order, vec![0, 1, 2]);
    }
}