Skip to main content

strided_perm/
fuse.rs

1//! Dimension fusion logic ported from Strided.jl/src/mapreduce.jl
2//!
3//! This module implements the core dimension fusion algorithm that merges
4//! contiguous dimensions to reduce iteration complexity.
5
/// Merge neighbouring dimensions that are laid out contiguously in every array.
///
/// Axis `i` can be folded into axis `i - 1` when, for every stride list `s` in
/// `all_strides`, `s[i] == dims[i - 1] * s[i - 1]`: one step along axis `i`
/// lands exactly where a full sweep of axis `i - 1` ends. Axes are visited from
/// the last one down to axis 1 (Julia: `for i in length(dims):-1:2`), so a chain
/// of contiguous axes collapses into its first member. A fused-away axis keeps
/// its slot with extent 1; the strides themselves are not modified.
///
/// Inputs with fewer than two dimensions, or with no stride lists at all, are
/// returned unchanged.
///
/// Indexing assumes every stride list has at least `dims.len()` entries; a
/// shorter list may cause a panic.
pub fn fuse_dims(dims: &[usize], all_strides: &[&[isize]]) -> Vec<usize> {
    let mut fused = dims.to_vec();
    let rank = fused.len();
    if rank < 2 || all_strides.is_empty() {
        return fused;
    }

    for hi in (1..rank).rev() {
        let lo = hi - 1;
        let extent = fused[lo] as isize;
        // Every array must agree that `hi` continues where `lo` ends.
        let contiguous = all_strides.iter().all(|s| {
            let expected = extent * s[lo];
            s[hi] == expected
        });
        if contiguous {
            fused[lo] *= fused[hi];
            fused[hi] = 1;
        }
    }

    fused
}
42
/// Drop every extent-1 dimension from `dims` and from each stride list.
///
/// Such dimensions never change an address, whether they started out as size 1
/// or were emptied by [`fuse_dims`], so removing them only lowers loop depth.
/// Extent-0 dimensions are kept.
///
/// Edge cases:
/// - `dims` empty: returns no dimensions and a copy of `all_strides`.
/// - every extent is 1: returns a single dimension of extent 1, keeping each
///   stride list's first entry so the kernel still has one axis to iterate.
pub fn compress_dims(
    dims: &[usize],
    all_strides: &[Vec<isize>],
) -> (Vec<usize>, Vec<Vec<isize>>) {
    if dims.is_empty() {
        return (vec![], all_strides.to_vec());
    }

    let survivors: Vec<usize> = dims
        .iter()
        .enumerate()
        .filter_map(|(axis, &extent)| if extent == 1 { None } else { Some(axis) })
        .collect();

    if survivors.is_empty() {
        let trivial: Vec<Vec<isize>> = all_strides.iter().map(|s| vec![s[0]]).collect();
        return (vec![1], trivial);
    }

    let new_dims: Vec<usize> = survivors.iter().map(|&axis| dims[axis]).collect();
    let new_strides: Vec<Vec<isize>> = all_strides
        .iter()
        .map(|s| survivors.iter().map(|&axis| s[axis]).collect())
        .collect();

    (new_dims, new_strides)
}
71
/// Compute the "importance" of each dimension for loop ordering.
///
/// `index_orders[k][i]` is the stride rank of dimension `i` in array `k`. Only
/// the first `all_strides.len()` rank vectors are used, and array 0 is the
/// output. Every array adds `2^(g * (n - rank))` to a dimension's importance,
/// with the output array's contribution weighted 2x. Here
/// `g = ceil(log2(m + 2))` bits per rank are enough for the per-rank sum
/// (at most `m + 1`) never to carry into the next rank. Dimensions of extent 0
/// or 1 get importance 0, so a descending sort puts them at the back.
///
/// When every packed value fits in a `u64`, those exact values are returned.
/// For high ranks (roughly `g * n > 64`: about 33 dims with two arrays, about
/// 22 with three) the packed encoding would overflow `u64`. `1u64 << shift`
/// with `shift >= 64` panics in debug builds and silently masks the shift in
/// release builds. In that case dense ranks `1..=r` are returned instead; they
/// induce exactly the same descending order (including ties) as the unbounded
/// packed values would.
///
/// # Panics
///
/// Panics if `index_orders` has fewer than `all_strides.len()` entries.
/// Rank vectors shorter than `dims`, or ranks greater than `dims.len()`, are
/// caller bugs and may also panic.
pub fn compute_importance(
    dims: &[usize],
    all_strides: &[&[isize]],
    index_orders: &[Vec<usize>],
) -> Vec<u64> {
    let n = dims.len();
    let m = all_strides.len();

    if n == 0 || m == 0 {
        return vec![];
    }

    // One rank vector per array; extra entries are ignored.
    let orders = &index_orders[..m];

    // g = ceil(log2(M + 2)) = number of bits needed to encode array count
    let g = u64::from(64 - (m as u64 + 1).leading_zeros());

    packed_importance(dims, orders, g).unwrap_or_else(|| ranked_importance(dims, orders))
}

/// Exact packed importances as described on [`compute_importance`], or `None`
/// if the value of any dimension with extent > 1 does not fit in a `u64`.
fn packed_importance(dims: &[usize], orders: &[Vec<usize>], g: u64) -> Option<Vec<u64>> {
    let n = dims.len();
    // 2^(g * (n - rank)); `None` once the shift leaves the 64-bit range.
    let digit = |rank: usize| -> Option<u64> {
        let shift = u32::try_from(g * (n - rank) as u64).ok()?;
        1u64.checked_shl(shift)
    };
    let (output, inputs) = (&orders[0], &orders[1..]);

    dims.iter()
        .enumerate()
        .map(|(i, &extent)| {
            if extent <= 1 {
                // Zero importance for size-0/1 dimensions (put them at the back).
                return Some(0);
            }
            // First array (output) is weighted 2x.
            let mut total = digit(output[i])?.checked_mul(2)?;
            for order in inputs {
                total = total.checked_add(digit(order[i])?)?;
            }
            Some(total)
        })
        .collect()
}

/// Order-preserving fallback for [`packed_importance`] when packed values
/// overflow `u64`.
///
/// The packed value is a base-`2^g` number whose digit for rank `r` is the
/// total weight (2 for the output array, 1 otherwise) of arrays in which the
/// dimension has rank `r`. Lower ranks are more significant, and digits never
/// carry. Comparing the digit vectors lexicographically, lowest rank first,
/// therefore yields the same order as comparing the unbounded packed values.
/// Distinct keys map to dense ranks `1..=r`; equal keys share a rank. Size-0/1
/// dimensions still get 0.
fn ranked_importance(dims: &[usize], orders: &[Vec<usize>]) -> Vec<u64> {
    let n = dims.len();
    let keys: Vec<Option<Vec<u64>>> = dims
        .iter()
        .enumerate()
        .map(|(i, &extent)| {
            (extent > 1).then(|| {
                // Index = rank, so index 0 is the most significant digit.
                let mut key = vec![0u64; n + 1];
                for (k, order) in orders.iter().enumerate() {
                    key[order[i]] += if k == 0 { 2 } else { 1 };
                }
                key
            })
        })
        .collect();

    let mut distinct: Vec<&Vec<u64>> = keys.iter().flatten().collect();
    distinct.sort_unstable();
    distinct.dedup();

    keys.iter()
        .map(|key| match key {
            None => 0,
            Some(key) => {
                let pos = distinct
                    .binary_search(&key)
                    .expect("every key was inserted into `distinct`");
                pos as u64 + 1
            }
        })
        .collect()
}
117
/// Return the dimension indices ordered from most to least important.
///
/// The sort is stable, so indices with equal importance (for example several
/// zero entries) keep their original ascending order.
pub fn sort_by_importance(importance: &[u64]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..importance.len()).collect();
    order.sort_by_key(|&axis| std::cmp::Reverse(importance[axis]));
    order
}
124
/// Per-dimension loop cost derived from the smallest absolute stride.
///
/// For each dimension `d`, takes the minimum of `|strides[d]|` over all stride
/// lists. A zero minimum (some array does not move along `d`) maps to cost 1;
/// any other minimum maps to twice its value. The number of dimensions is taken
/// from the first stride list, and an empty `all_strides` yields an empty result.
pub fn compute_costs<S: AsRef<[isize]>>(all_strides: &[S]) -> Vec<isize> {
    let rank = match all_strides.first() {
        Some(first) => first.as_ref().len(),
        None => return Vec::new(),
    };

    (0..rank)
        .map(|d| {
            let smallest = all_strides
                .iter()
                .map(|s| s.as_ref()[d].abs())
                .fold(isize::MAX, isize::min);
            match smallest {
                0 => 1,
                c => c * 2,
            }
        })
        .collect()
}
152
/// Fuse adjacent dimensions that are contiguous in both the source and the destination.
///
/// Dimensions are scanned left to right while a "run" is grown. The run is
/// described by its total extent and the strides of its first dimension. The
/// next dimension joins the run only if, in BOTH `src_strides` and
/// `dst_strides`, its stride equals the run's base stride times the run's
/// extent. Otherwise the run is emitted and a new one starts.
///
/// Returns `(fused_dims, fused_src_strides, fused_dst_strides)`. Inputs with
/// fewer than two dimensions are returned unchanged.
pub fn fuse_dims_bilateral(
    dims: &[usize],
    src_strides: &[isize],
    dst_strides: &[isize],
) -> (Vec<usize>, Vec<isize>, Vec<isize>) {
    let rank = dims.len();
    if rank < 2 {
        return (dims.to_vec(), src_strides.to_vec(), dst_strides.to_vec());
    }

    let mut out_dims = Vec::with_capacity(rank);
    let mut out_src = Vec::with_capacity(rank);
    let mut out_dst = Vec::with_capacity(rank);

    // Current run: (total extent, src base stride, dst base stride).
    let mut run = (dims[0], src_strides[0], dst_strides[0]);

    for axis in 1..rank {
        let (extent, src_base, dst_base) = run;
        let span = extent as isize;
        let src_ok = src_strides[axis] == src_base * span;
        let dst_ok = dst_strides[axis] == dst_base * span;

        if src_ok && dst_ok {
            run.0 = extent * dims[axis];
        } else {
            out_dims.push(extent);
            out_src.push(src_base);
            out_dst.push(dst_base);
            run = (dims[axis], src_strides[axis], dst_strides[axis]);
        }
    }

    out_dims.push(run.0);
    out_src.push(run.1);
    out_dst.push(run.2);

    (out_dims, out_src, out_dst)
}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fuse_dims_contiguous() {
        let dims = [3, 4];
        let strides1 = [1isize, 3];
        let strides2 = [1isize, 3];
        let all_strides: Vec<&[isize]> = vec![&strides1, &strides2];
        let fused = fuse_dims(&dims, &all_strides);
        assert_eq!(fused, vec![12, 1]);
    }

    #[test]
    fn test_fuse_dims_non_contiguous() {
        let dims = [3, 4];
        let strides1 = [1isize, 10];
        let all_strides: Vec<&[isize]> = vec![&strides1];
        let fused = fuse_dims(&dims, &all_strides);
        assert_eq!(fused, vec![3, 4]);
    }

    #[test]
    fn test_fuse_dims_chain_collapses_into_first_axis() {
        // Three mutually contiguous axes fold into axis 0; fused slots become 1.
        let dims = [3, 4, 5];
        let strides = [1isize, 3, 12];
        let all_strides: Vec<&[isize]> = vec![&strides];
        assert_eq!(fuse_dims(&dims, &all_strides), vec![60, 1, 1]);
    }

    #[test]
    fn test_fuse_dims_without_strides_is_identity() {
        let dims = [3, 4];
        assert_eq!(fuse_dims(&dims, &[]), vec![3, 4]);
    }

    #[test]
    fn test_fuse_dims_bilateral_all_contiguous() {
        // 4-dim all-size-2 col-major: all contiguous -> fuses to single dim
        let dims = vec![2, 2, 2, 2];
        let src_strides = vec![1, 2, 4, 8];
        let dst_strides = vec![1, 2, 4, 8];
        let (fd, fs, fds) = fuse_dims_bilateral(&dims, &src_strides, &dst_strides);
        assert_eq!(fd, vec![16]);
        assert_eq!(fs, vec![1]);
        assert_eq!(fds, vec![1]);
    }

    #[test]
    fn test_fuse_dims_bilateral_partial() {
        // src: contiguous 0-1, not 1-2
        // dst: contiguous 0-1-2
        let dims = vec![2, 3, 4];
        let src_strides = vec![1, 2, 100]; // 0-1 contiguous, 1-2 not
        let dst_strides = vec![1, 2, 6]; // all contiguous
        let (fd, fs, fds) = fuse_dims_bilateral(&dims, &src_strides, &dst_strides);
        assert_eq!(fd, vec![6, 4]); // first two fuse
        assert_eq!(fs, vec![1, 100]);
        assert_eq!(fds, vec![1, 6]);
    }

    #[test]
    fn test_fuse_dims_bilateral_scattered() {
        // The benchmark case: scattered strides, nothing fuses
        let dims = vec![2, 2, 2];
        let src_strides = vec![1, 4194304, 2]; // scattered
        let dst_strides = vec![1, 2, 4]; // contiguous
        let (fd, fs, fds) = fuse_dims_bilateral(&dims, &src_strides, &dst_strides);
        assert_eq!(fd, vec![2, 2, 2]); // nothing fuses
        assert_eq!(fs, vec![1, 4194304, 2]);
        assert_eq!(fds, vec![1, 2, 4]);
    }

    #[test]
    fn test_fuse_dims_bilateral_single_dim_unchanged() {
        let (fd, fs, fds) = fuse_dims_bilateral(&[5], &[3], &[7]);
        assert_eq!(fd, vec![5]);
        assert_eq!(fs, vec![3]);
        assert_eq!(fds, vec![7]);
    }

    #[test]
    fn test_compress_dims_removes_fused() {
        let dims = vec![12usize, 1];
        let strides = vec![vec![1isize, 3]];
        let (cd, cs) = compress_dims(&dims, &strides);
        assert_eq!(cd, vec![12]);
        assert_eq!(cs, vec![vec![1]]);
    }

    #[test]
    fn test_compress_dims_all_ones_keeps_one_dim() {
        // Scalar-like input keeps a single extent-1 dim and each array's first stride.
        let dims = vec![1usize, 1];
        let strides = vec![vec![5isize, 7], vec![9, 11]];
        let (cd, cs) = compress_dims(&dims, &strides);
        assert_eq!(cd, vec![1]);
        assert_eq!(cs, vec![vec![5], vec![9]]);
    }

    #[test]
    fn test_compute_costs() {
        let strides1 = [1isize, 4, 0];
        let strides2 = [2isize, 1, 0];
        let all_strides: Vec<&[isize]> = vec![&strides1, &strides2];
        let costs = compute_costs(&all_strides);
        assert_eq!(costs, vec![2, 2, 1]);
    }

    #[test]
    fn test_compute_costs_owned_and_negative_strides() {
        // Negative strides are costed by absolute value; owned Vec inputs are accepted.
        let strides = vec![vec![-3isize, 0, 5]];
        assert_eq!(compute_costs(&strides), vec![6, 1, 10]);
    }
}