/// Builds the inverse of the permutation `perm`.
///
/// The returned vector `inv` has the same length as `perm` and satisfies
/// `inv[perm[i]] == i` for every position `i`.
///
/// If `perm` contains a repeated value, it is not a valid inverse: the later
/// position overwrites the earlier one, and any slot that no entry targets
/// stays `0`.
///
/// # Panics
///
/// Panics if any entry of `perm` is `>= perm.len()`.
pub fn invert_perm(perm: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0; perm.len()];
    perm.iter()
        .zip(0usize..)
        .for_each(|(&slot, source)| inverse[slot] = source);
    inverse
}
11
/// Odometer-style iterator over every index tuple of a dense multi-dimensional shape.
///
/// Tuples are visited in row-major order, so the last axis varies fastest.
/// Each successful call to [`Iterator::next`] yields `()`:
///
/// - the first call leaves the tuple at all zeros;
/// - every later call advances it by one.
///
/// Use [`MultiIndex::offset`] to turn the current tuple into a linear offset.
#[derive(Debug, Clone)]
pub struct MultiIndex {
    /// Extent of each axis.
    dims: Vec<usize>,
    /// Index tuple selected by the most recent successful `next` call.
    current: Vec<usize>,
    /// Number of tuples in the shape (product of `dims`).
    total: usize,
    /// Number of tuples yielded so far; never exceeds `total`.
    count: usize,
}

impl MultiIndex {
    /// Creates an iterator over the shape `dims`.
    ///
    /// An empty `dims` describes a single (scalar) tuple and yields exactly once.
    /// A zero extent on any axis makes the shape empty, so nothing is yielded.
    pub fn new(dims: &[usize]) -> Self {
        let total: usize = dims.iter().product();
        Self {
            dims: dims.to_vec(),
            current: vec![0; dims.len()],
            total,
            count: 0,
        }
    }

    /// Returns the linear offset of the current tuple, `sum(current[k] * strides[k])`.
    ///
    /// Axes are paired with `strides` positionally. If the two lengths differ,
    /// the surplus entries on the longer side are ignored.
    pub fn offset(&self, strides: &[isize]) -> isize {
        self.current
            .iter()
            .zip(strides.iter())
            .map(|(&i, &s)| i as isize * s)
            .sum()
    }

    /// Rewinds to the all-zero tuple so the full sequence can be iterated again.
    pub fn reset(&mut self) {
        self.current.fill(0);
        self.count = 0;
    }
}

impl Iterator for MultiIndex {
    type Item = ();

    fn next(&mut self) -> Option<()> {
        if self.count >= self.total {
            return None;
        }
        // The first call reports the initial all-zero tuple. Later calls
        // increment it like an odometer, carrying from the last axis leftward
        // and stopping as soon as an axis absorbs the carry.
        if self.count > 0 {
            for (idx, &extent) in self.current.iter_mut().zip(&self.dims).rev() {
                *idx += 1;
                if *idx < extent {
                    break;
                }
                *idx = 0;
            }
        }
        self.count += 1;
        Some(())
    }

    /// Reports the exact number of tuples still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.total.saturating_sub(self.count);
        (remaining, Some(remaining))
    }
}

// `size_hint` above is exact, which is the contract `ExactSizeIterator` requires.
impl ExactSizeIterator for MultiIndex {}
75
/// Tries to collapse a group of strided axes into one equivalent strided axis.
///
/// `dims[k]` and `strides[k]` describe axis `k` of the group. On success it
/// returns `(count, stride)` such that the set of offsets
/// `{ sum_k i_k * strides[k] : 0 <= i_k < dims[k] }` equals
/// `{ j * stride : 0 <= j < count }`.
///
/// The axes are treated as an unordered set (they are sorted by stride
/// magnitude). The fused axis therefore covers the same offsets, but not
/// necessarily in the group's original row-major visiting order.
///
/// Special cases:
/// - an empty group yields `(1, 0)`;
/// - a single axis is returned unchanged as `(dims[0], strides[0])`;
/// - with two or more axes, any zero-length axis yields `(0, 0)`, because no
///   offset is ever produced;
/// - with two or more axes that all have length 1, the result is `(1, s)`,
///   where `s` is the stride of smallest magnitude.
///
/// Returns `None` when:
/// - `dims.len() != strides.len()`;
/// - the element count overflows `usize`;
/// - an axis of length > 1 has stride 0 (a broadcast axis) in a group of two
///   or more axes;
/// - the axes of length > 1 mix positive and negative strides;
/// - the axes of length > 1 do not tile one contiguous strided run (a gap or
///   an overlap between consecutive axes).
pub fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    // Checked up front so the single-axis arm can never index past `strides`.
    if dims.len() != strides.len() {
        return None;
    }
    match dims.len() {
        0 => Some((1, 0)),
        1 => Some((dims[0], strides[0])),
        _ => {
            // A zero-length axis means nothing is addressed, so any stride works.
            if dims.contains(&0) {
                return Some((0, 0));
            }
            let total = dims
                .iter()
                .try_fold(1usize, |acc, &d| acc.checked_mul(d))?;

            let mut pairs: Vec<(usize, isize)> =
                dims.iter().copied().zip(strides.iter().copied()).collect();
            pairs.sort_by_key(|&(d, s)| (s.unsigned_abs(), d));
            // Only used when every axis has length 1 (exactly one element).
            let smallest_stride = pairs[0].1;

            // Length-1 axes only ever take index 0, so their strides never
            // influence an offset. Dropping them before the chain check means
            // they can neither hide a gap between two real axes nor supply
            // the base stride.
            pairs.retain(|&(d, _)| d > 1);
            let base = match pairs.first() {
                Some(&(_, s)) => s,
                None => return Some((total, smallest_stride)),
            };

            // A broadcast axis revisits offsets and cannot be part of a single
            // strided run together with other axes.
            if pairs.iter().any(|&(_, s)| s == 0) {
                return None;
            }
            // Mixed directions give a set of offsets on both sides of 0, which
            // one (count, stride) run starting at offset 0 cannot express.
            let negative = base < 0;
            if pairs.iter().any(|&(_, s)| (s < 0) != negative) {
                return None;
            }
            // Each axis must step exactly over the full extent of the previous one.
            for window in pairs.windows(2) {
                let (dim, stride) = window[0];
                let (_, next_stride) = window[1];
                let expected = stride.unsigned_abs().checked_mul(dim)?;
                if next_stride.unsigned_abs() != expected {
                    return None;
                }
            }
            Some((total, base))
        }
    }
}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains `iter`, recording the index tuple after every successful step.
    fn collect_indices(iter: &mut MultiIndex) -> Vec<Vec<usize>> {
        let mut indices = Vec::new();
        while iter.next().is_some() {
            indices.push(iter.current.clone());
        }
        indices
    }

    #[test]
    fn test_invert_perm() {
        assert_eq!(invert_perm(&[2, 0, 1]), vec![1, 2, 0]);
        assert_eq!(invert_perm(&[0, 1, 2]), vec![0, 1, 2]);
        assert!(invert_perm(&[]).is_empty());
    }

    #[test]
    fn test_invert_perm_round_trip() {
        let perm = [3, 0, 2, 1];
        let inv = invert_perm(&perm);
        for (i, &p) in perm.iter().enumerate() {
            assert_eq!(inv[p], i);
        }
        assert_eq!(invert_perm(&inv), perm.to_vec());
    }

    #[test]
    fn test_multi_index_2d() {
        let mut iter = MultiIndex::new(&[2, 3]);
        assert_eq!(
            collect_indices(&mut iter),
            vec![
                vec![0, 0],
                vec![0, 1],
                vec![0, 2],
                vec![1, 0],
                vec![1, 1],
                vec![1, 2],
            ]
        );
    }

    #[test]
    fn test_multi_index_offset() {
        let mut iter = MultiIndex::new(&[2, 3]);
        let strides = [3, 1];
        let mut offsets = vec![];
        while iter.next().is_some() {
            offsets.push(iter.offset(&strides));
        }
        assert_eq!(offsets, vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_multi_index_empty() {
        let mut iter = MultiIndex::new(&[]);
        assert!(iter.next().is_some());
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_multi_index_zero_extent() {
        let mut iter = MultiIndex::new(&[2, 0, 3]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_multi_index_reset() {
        let mut iter = MultiIndex::new(&[2, 2]);
        let first_pass = collect_indices(&mut iter);
        assert_eq!(first_pass.len(), 4);
        iter.reset();
        assert_eq!(collect_indices(&mut iter), first_pass);
    }

    #[test]
    fn test_try_fuse_group_empty() {
        assert_eq!(try_fuse_group(&[], &[]), Some((1, 0)));
    }

    #[test]
    fn test_try_fuse_group_single() {
        assert_eq!(try_fuse_group(&[5], &[2]), Some((5, 2)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        assert_eq!(try_fuse_group(&[3, 4], &[4, 1]), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        assert_eq!(try_fuse_group(&[3, 4], &[1, 3]), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_3d() {
        assert_eq!(try_fuse_group(&[2, 3, 4], &[12, 4, 1]), Some((24, 1)));
    }

    #[test]
    fn test_try_fuse_group_unit_axis_in_contiguous_group() {
        // A length-1 axis with an arbitrary stride must not block fusion.
        assert_eq!(try_fuse_group(&[3, 1, 4], &[4, 100, 1]), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_non_contiguous() {
        assert_eq!(try_fuse_group(&[3, 4], &[8, 1]), None);
    }

    #[test]
    fn test_try_fuse_group_length_mismatch() {
        assert_eq!(try_fuse_group(&[2, 3], &[3]), None);
    }
}