/// Folds neighbouring axes together wherever every stride set allows it.
///
/// Axis `i` is absorbed into axis `i - 1` when, for each slice in
/// `all_strides`, `strides[i] == dims[i - 1] * strides[i - 1]`. The result
/// has the same length as `dims`: an absorbed axis keeps its slot with extent
/// `1`, and its extent is multiplied into the preceding slot.
///
/// `dims` is returned unchanged (as a copy) when it has at most one axis or
/// when `all_strides` is empty.
///
/// # Panics
///
/// Stride slices are indexed by axis, so an inspected slice that is shorter
/// than `dims` causes an out-of-bounds panic.
pub fn fuse_dims(dims: &[usize], all_strides: &[&[isize]]) -> Vec<usize> {
    if dims.len() <= 1 || all_strides.is_empty() {
        return dims.to_vec();
    }

    let mut fused = dims.to_vec();

    // Walk from the last axis towards the first: when `upper` is examined,
    // `fused[lower]` has not been touched yet, while `fused[upper]` already
    // carries everything that was folded into it from higher axes.
    for upper in (1..fused.len()).rev() {
        let lower = upper - 1;
        let lower_extent = fused[lower] as isize;

        let mergeable = all_strides
            .iter()
            .all(|strides| lower_extent * strides[lower] == strides[upper]);

        if mergeable {
            fused[lower] *= fused[upper];
            fused[upper] = 1;
        }
    }

    fused
}
42
/// Drops every axis whose extent is exactly `1`, from `dims` and from each
/// stride vector in `all_strides`.
///
/// Returns the surviving extents together with the matching strides, in their
/// original relative order.
///
/// Special cases:
/// * `dims` empty: returns an empty dimension list and a copy of
///   `all_strides` as given.
/// * every extent equals `1`: returns a single axis of extent `1`, keeping
///   only the first stride of each stride vector.
///
/// # Panics
///
/// Stride vectors are indexed by axis, so a vector shorter than required
/// causes an out-of-bounds panic.
pub fn compress_dims(dims: &[usize], all_strides: &[Vec<isize>]) -> (Vec<usize>, Vec<Vec<isize>>) {
    if dims.is_empty() {
        return (Vec::new(), all_strides.to_vec());
    }

    // Positions of the axes that are not unit-sized.
    let survivors: Vec<usize> = dims
        .iter()
        .enumerate()
        .filter(|&(_, &extent)| extent != 1)
        .map(|(axis, _)| axis)
        .collect();

    if survivors.is_empty() {
        // All axes are unit-sized: keep one representative axis.
        let leading = all_strides.iter().map(|strides| vec![strides[0]]).collect();
        return (vec![1], leading);
    }

    let gather = |strides: &Vec<isize>| -> Vec<isize> {
        survivors.iter().map(|&axis| strides[axis]).collect()
    };

    (
        survivors.iter().map(|&axis| dims[axis]).collect(),
        all_strides.iter().map(gather).collect(),
    )
}
71
/// Scores each axis by combining its position in every array's index order.
///
/// With `n = dims.len()` and `m = all_strides.len()`, each array `k`
/// contributes `1 << (g * (n - index_orders[k][i]))` to axis `i`, where `g`
/// is the bit length of `m + 1`. The contribution of array `0` is counted
/// twice. Axes whose extent is `0` or `1` always end up with a score of `0`.
///
/// Only the *number* of stride sets in `all_strides` is used; the stride
/// values themselves are not read here.
///
/// Returns an empty vector when there are no axes or no stride sets.
///
/// # Panics
///
/// Panics if `index_orders` has fewer than `m` entries or an entry has fewer
/// than `n` elements. The arithmetic is unchecked, so builds with overflow
/// checks enabled also panic when an order value exceeds `n` or a shift does
/// not fit in 64 bits.
pub fn compute_importance(
    dims: &[usize],
    all_strides: &[&[isize]],
    index_orders: &[Vec<usize>],
) -> Vec<u64> {
    let rank = dims.len();
    let arrays = all_strides.len();

    if rank == 0 || arrays == 0 {
        return Vec::new();
    }

    // Bits reserved per order step: the bit length of `arrays + 1`.
    let bits = u64::from(64 - (arrays as u64 + 1).leading_zeros());

    // Contribution of a single array to a single axis.
    let weight = |order: &Vec<usize>, axis: usize| -> u64 {
        1u64 << (bits * (rank - order[axis]) as u64)
    };

    // The first array is weighted double.
    let mut importance: Vec<u64> = (0..rank)
        .map(|axis| 2 * weight(&index_orders[0], axis))
        .collect();

    // Every remaining array adds its plain weight.
    for order in &index_orders[1..arrays] {
        for (axis, score) in importance.iter_mut().enumerate() {
            *score += weight(order, axis);
        }
    }

    // Axes with at most one element carry no importance.
    for (score, &extent) in importance.iter_mut().zip(dims) {
        if extent <= 1 {
            *score = 0;
        }
    }

    importance
}
117
/// Returns the indices `0..importance.len()` ordered from the highest score to
/// the lowest.
///
/// The sort is stable, so indices with equal scores keep their ascending
/// order.
pub fn sort_by_importance(importance: &[u64]) -> Vec<usize> {
    let mut order = (0..importance.len()).collect::<Vec<usize>>();
    // Wrapping the key in `Reverse` turns the ascending stable sort into a
    // descending one.
    order.sort_by_key(|&axis| std::cmp::Reverse(importance[axis]));
    order
}
124
/// Derives a per-axis cost from the smallest absolute stride across all
/// stride sets.
///
/// For each axis the minimum of `|stride|` over every entry of `all_strides`
/// is taken; the cost is `1` when that minimum is zero and twice the minimum
/// otherwise. The number of axes is taken from the first stride set.
///
/// Returns an empty vector when `all_strides` is empty.
///
/// # Panics
///
/// Panics if a later stride set is shorter than the first one.
pub fn compute_costs<S: AsRef<[isize]>>(all_strides: &[S]) -> Vec<isize> {
    let rank = match all_strides.first() {
        Some(first) => first.as_ref().len(),
        None => return Vec::new(),
    };

    // Phase 1: running minimum of the absolute strides, axis by axis.
    let mut smallest = vec![isize::MAX; rank];
    for entry in all_strides {
        let strides: &[isize] = entry.as_ref();
        for (axis, slot) in smallest.iter_mut().enumerate() {
            *slot = (*slot).min(strides[axis].abs());
        }
    }

    // Phase 2: a zero stride costs 1, anything else costs twice its stride.
    smallest
        .into_iter()
        .map(|stride| if stride == 0 { 1 } else { stride * 2 })
        .collect()
}
152
/// Merges consecutive axes that are contiguous in both the source and the
/// destination layout.
///
/// Axes are scanned from first to last while a group is kept open. The next
/// axis joins the open group when its stride equals the group's first stride
/// multiplied by the group's accumulated extent, for `src_strides` and for
/// `dst_strides` alike. Otherwise the group is closed and the axis starts a
/// new one. Each group is emitted with its accumulated extent and the strides
/// of its first axis.
///
/// Returns `(dims, src_strides, dst_strides)` of the fused layout. Inputs with
/// at most one axis are returned as plain copies.
///
/// # Panics
///
/// When `dims` has two or more axes, both stride slices are indexed by axis,
/// so a slice shorter than `dims` causes an out-of-bounds panic.
pub fn fuse_dims_bilateral(
    dims: &[usize],
    src_strides: &[isize],
    dst_strides: &[isize],
) -> (Vec<usize>, Vec<isize>, Vec<isize>) {
    let rank = dims.len();
    if rank <= 1 {
        return (dims.to_vec(), src_strides.to_vec(), dst_strides.to_vec());
    }

    let mut out_dims = Vec::with_capacity(rank);
    let mut out_src = Vec::with_capacity(rank);
    let mut out_dst = Vec::with_capacity(rank);

    // The currently open group: accumulated extent plus the strides of the
    // axis that opened it.
    let mut extent = dims[0];
    let mut src_base = src_strides[0];
    let mut dst_base = dst_strides[0];

    for axis in 1..rank {
        let span = extent as isize;
        let src_follows = src_strides[axis] == src_base * span;
        let dst_follows = dst_strides[axis] == dst_base * span;

        if src_follows && dst_follows {
            // The axis continues the open group in both layouts.
            extent *= dims[axis];
            continue;
        }

        // Close the open group and start a new one at this axis.
        out_dims.push(extent);
        out_src.push(src_base);
        out_dst.push(dst_base);

        extent = dims[axis];
        src_base = src_strides[axis];
        dst_base = dst_strides[axis];
    }

    // Flush the group that was still open when the scan ended.
    out_dims.push(extent);
    out_src.push(src_base);
    out_dst.push(dst_base);

    (out_dims, out_src, out_dst)
}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical stride sets in which axis 1 directly follows axis 0:
    /// the axes collapse into a single axis of extent 12.
    #[test]
    fn test_fuse_dims_contiguous() {
        let extents = [3, 4];
        let first = [1isize, 3];
        let second = [1isize, 3];
        let stride_sets: Vec<&[isize]> = vec![&first, &second];

        let result = fuse_dims(&extents, &stride_sets);

        assert_eq!(result, vec![12, 1]);
    }

    /// A gap between the axes (stride 10 instead of 3) prevents fusion.
    #[test]
    fn test_fuse_dims_non_contiguous() {
        let extents = [3, 4];
        let gapped = [1isize, 10];
        let stride_sets: Vec<&[isize]> = vec![&gapped];

        let result = fuse_dims(&extents, &stride_sets);

        assert_eq!(result, vec![3, 4]);
    }

    /// Four axes that are contiguous on both sides fuse into one axis.
    #[test]
    fn test_fuse_dims_bilateral_all_contiguous() {
        let extents = vec![2, 2, 2, 2];
        let src = vec![1, 2, 4, 8];
        let dst = vec![1, 2, 4, 8];

        let (out_dims, out_src, out_dst) = fuse_dims_bilateral(&extents, &src, &dst);

        assert_eq!(out_dims, vec![16]);
        assert_eq!(out_src, vec![1]);
        assert_eq!(out_dst, vec![1]);
    }

    /// The first two axes fuse; the third is contiguous only in the
    /// destination, so it stays separate.
    #[test]
    fn test_fuse_dims_bilateral_partial() {
        let extents = vec![2, 3, 4];
        let src = vec![1, 2, 100];
        let dst = vec![1, 2, 6];

        let (out_dims, out_src, out_dst) = fuse_dims_bilateral(&extents, &src, &dst);

        assert_eq!(out_dims, vec![6, 4]);
        assert_eq!(out_src, vec![1, 100]);
        assert_eq!(out_dst, vec![1, 6]);
    }

    /// The source strides are not contiguous anywhere, so nothing fuses even
    /// though the destination is fully contiguous.
    #[test]
    fn test_fuse_dims_bilateral_scattered() {
        let extents = vec![2, 2, 2];
        let src = vec![1, 4194304, 2];
        let dst = vec![1, 2, 4];

        let (out_dims, out_src, out_dst) = fuse_dims_bilateral(&extents, &src, &dst);

        assert_eq!(out_dims, vec![2, 2, 2]);
        assert_eq!(out_src, vec![1, 4194304, 2]);
        assert_eq!(out_dst, vec![1, 2, 4]);
    }

    /// The unit-sized axis is removed together with its stride.
    #[test]
    fn test_compress_dims_removes_fused() {
        let extents = vec![12usize, 1];
        let stride_sets = vec![vec![1isize, 3]];

        let (out_dims, out_strides) = compress_dims(&extents, &stride_sets);

        assert_eq!(out_dims, vec![12]);
        assert_eq!(out_strides, vec![vec![1]]);
    }

    /// Costs are twice the smallest absolute stride per axis, or 1 when that
    /// stride is zero.
    #[test]
    fn test_compute_costs() {
        let first = [1isize, 4, 0];
        let second = [2isize, 1, 0];
        let stride_sets: Vec<&[isize]> = vec![&first, &second];

        let result = compute_costs(&stride_sets);

        assert_eq!(result, vec![2, 2, 1]);
    }
}