/// Computes the inverse of a permutation.
///
/// For every position `src` with `perm[src] == dst`, the result holds
/// `result[dst] == src`. The returned vector has the same length as `perm`.
///
/// # Panics
///
/// Panics with an out-of-bounds index if any entry of `perm` is
/// `>= perm.len()`. Entries are not checked for uniqueness; when a value is
/// repeated, the later position overwrites the earlier one.
pub fn invert_perm(perm: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0usize; perm.len()];
    perm.iter()
        .enumerate()
        .for_each(|(src, &dst)| inverse[dst] = src);
    inverse
}
11
/// Odometer-style counter over a multi-dimensional index space.
///
/// Every successful call to [`Iterator::next`] leaves `current` on the next
/// index tuple, with the last dimension varying fastest. The iterator itself
/// yields `()`; callers read the position through [`MultiIndex::offset`].
pub struct MultiIndex {
    /// Extent of each dimension.
    dims: Vec<usize>,
    /// Index tuple of the most recently yielded position.
    current: Vec<usize>,
    /// Product of all extents, i.e. how many positions will be yielded.
    total: usize,
    /// Number of positions yielded since construction or the last `reset`.
    count: usize,
}

impl MultiIndex {
    /// Creates a counter over the index space described by `dims`.
    ///
    /// An empty `dims` has an empty product of 1, so the counter yields exactly
    /// once. If any extent is 0 the product is 0 and the counter yields nothing.
    ///
    /// # Panics
    ///
    /// The extent product is computed with `Iterator::product`, which panics on
    /// overflow when overflow checks are enabled.
    pub fn new(dims: &[usize]) -> Self {
        Self {
            dims: dims.to_vec(),
            current: vec![0; dims.len()],
            total: dims.iter().product(),
            count: 0,
        }
    }

    /// Returns the dot product of the current index tuple with `strides`.
    ///
    /// Indices and strides are paired positionally; if `strides` is shorter or
    /// longer than the index tuple, the unpaired tail is ignored.
    pub fn offset(&self, strides: &[isize]) -> isize {
        self.current
            .iter()
            .zip(strides)
            .map(|(&index, &stride)| index as isize * stride)
            .sum()
    }

    /// Rewinds the counter so the next call to `next` yields the all-zero
    /// index tuple again.
    pub fn reset(&mut self) {
        self.current.fill(0);
        self.count = 0;
    }
}

impl Iterator for MultiIndex {
    type Item = ();

    /// Advances to the next index tuple, returning `None` once `total`
    /// positions have been yielded.
    fn next(&mut self) -> Option<()> {
        if self.count >= self.total {
            return None;
        }
        // The first position is the all-zero tuple already stored in
        // `current`, so the increment is skipped on the first call.
        if self.count > 0 {
            // Add one to the last axis and ripple the carry toward the first,
            // stopping as soon as an axis absorbs it.
            for (index, &extent) in self.current.iter_mut().zip(&self.dims).rev() {
                *index += 1;
                if *index < extent {
                    break;
                }
                *index = 0;
            }
        }
        self.count += 1;
        Some(())
    }

    /// Reports the exact number of positions still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `count` never exceeds `total`; the saturating form just keeps this
        // method panic-free regardless.
        let remaining = self.total.saturating_sub(self.count);
        (remaining, Some(remaining))
    }
}
75
/// Tries to collapse a group of dimensions into a single `(len, stride)` pair.
///
/// The group is fusable when its non-singleton axes (extent greater than 1),
/// taken in order of increasing stride magnitude, form an exact chain: each
/// axis's `|stride|` equals the previous axis's `|stride|` times its extent.
/// The order of the axes in the input does not matter, and axes with extent
/// 0 or 1 are ignored by the chain check.
///
/// # Returns
///
/// * `Some((1, 0))` for an empty group.
/// * `Some((len, stride))` for a fusable group, where `len` is the product of
///   all extents and `stride` is the stride of the non-singleton axis with the
///   smallest magnitude. When no axis has extent greater than 1, `stride` is
///   the smallest-magnitude entry of `strides`.
/// * `None` when `dims` and `strides` differ in length, when a non-singleton
///   axis has stride 0, when the chain is broken, or when the chain product
///   overflows `usize`.
///
/// NOTE(review): strides are matched by magnitude only, so a group whose
/// strides have different signs is still reported as fusable, carrying the
/// sign of the smallest-magnitude axis. Confirm that callers expect this.
pub fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    // Reject mismatched inputs before any indexing, so the rank-0 and rank-1
    // shortcuts cannot panic or silently ignore surplus strides.
    if dims.len() != strides.len() {
        return None;
    }
    match dims {
        [] => return Some((1, 0)),
        [extent] => return Some((*extent, strides[0])),
        _ => {}
    }

    // Single scan: refuse zero strides on axes that actually iterate, count
    // those axes, and remember the first one with the smallest |stride|.
    let mut smallest: Option<(usize, usize)> = None;
    let mut live_axes = 0usize;
    for (axis, (&extent, &stride)) in dims.iter().zip(strides).enumerate() {
        if extent <= 1 {
            continue;
        }
        if stride == 0 {
            return None;
        }
        live_axes += 1;
        let magnitude = stride.unsigned_abs();
        if smallest.map_or(true, |(_, best)| magnitude < best) {
            smallest = Some((axis, magnitude));
        }
    }

    let Some((base, base_abs)) = smallest else {
        // Every extent is 0 or 1: nothing constrains the layout, so report the
        // smallest-magnitude stride (the first one on ties).
        let stride = strides
            .iter()
            .copied()
            .min_by_key(|s| s.unsigned_abs())
            .unwrap_or(0);
        return Some((dims.iter().product(), stride));
    };

    // Walk the chain outward from the base axis. `expected_abs` at least
    // doubles on every step (extent > 1, magnitude >= 1), so an axis that was
    // already matched can never match again and needs no bookkeeping.
    let mut expected_abs = base_abs.checked_mul(dims[base])?;
    for _ in 1..live_axes {
        let axis = (0..dims.len())
            .find(|&i| dims[i] > 1 && strides[i].unsigned_abs() == expected_abs)?;
        expected_abs = expected_abs.checked_mul(dims[axis])?;
    }

    Some((dims.iter().product(), strides[base]))
}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `iter` to exhaustion and records `probe(&iter)` after each step.
    fn walk<T>(mut iter: MultiIndex, probe: impl Fn(&MultiIndex) -> T) -> Vec<T> {
        let mut seen = Vec::new();
        while iter.next().is_some() {
            seen.push(probe(&iter));
        }
        seen
    }

    #[test]
    fn test_invert_perm() {
        // A 3-cycle inverts to the opposite rotation.
        let rotated = invert_perm(&[2, 0, 1]);
        assert_eq!(rotated, vec![1, 2, 0]);

        // The identity permutation is its own inverse.
        let identity = invert_perm(&[0, 1, 2]);
        assert_eq!(identity, vec![0, 1, 2]);
    }

    #[test]
    fn test_multi_index_2d() {
        let visited = walk(MultiIndex::new(&[2, 3]), |it| it.current.clone());

        // Last axis varies fastest.
        let expected: Vec<Vec<usize>> = (0..2)
            .flat_map(|row| (0..3).map(move |col| vec![row, col]))
            .collect();
        assert_eq!(visited, expected);
    }

    #[test]
    fn test_multi_index_offset() {
        let strides: [isize; 2] = [3, 1];
        let offsets = walk(MultiIndex::new(&[2, 3]), |it| it.offset(&strides));

        let expected: Vec<isize> = (0..6).collect();
        assert_eq!(offsets, expected);
    }

    #[test]
    fn test_multi_index_empty() {
        // A rank-0 index space holds exactly one position.
        let mut scalar = MultiIndex::new(&[]);
        assert_eq!(scalar.next(), Some(()));
        assert_eq!(scalar.next(), None);
    }

    #[test]
    fn test_try_fuse_group_empty() {
        let fused = try_fuse_group(&[], &[]);
        assert_eq!(fused, Some((1, 0)));
    }

    #[test]
    fn test_try_fuse_group_single() {
        let fused = try_fuse_group(&[5], &[2]);
        assert_eq!(fused, Some((5, 2)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        let (dims, strides) = ([3, 4], [4, 1]);
        assert_eq!(try_fuse_group(&dims, &strides), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        let (dims, strides) = ([3, 4], [1, 3]);
        assert_eq!(try_fuse_group(&dims, &strides), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_3d() {
        let (dims, strides) = ([2, 3, 4], [12, 4, 1]);
        assert_eq!(try_fuse_group(&dims, &strides), Some((24, 1)));
    }

    #[test]
    fn test_try_fuse_group_non_contiguous() {
        // A row stride of 8 leaves a gap after each 4-element row.
        let (dims, strides) = ([3, 4], [8, 1]);
        assert_eq!(try_fuse_group(&dims, &strides), None);
    }
}