strided_einsum2/
util.rs

1//! Shared helpers for strided-einsum2.
2
/// Compute the inverse of a permutation.
///
/// For every position `i`, the returned vector satisfies
/// `result[perm[i]] == i`.
///
/// `perm` is expected to be a permutation of `0..perm.len()`. If an entry is
/// repeated, the later position overwrites the earlier one and the slots that
/// are never hit stay `0`.
///
/// # Panics
///
/// Panics if any entry of `perm` is `>= perm.len()`.
pub fn invert_perm(perm: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0; perm.len()];
    perm.iter()
        .enumerate()
        .for_each(|(src, &dst)| inverse[dst] = src);
    inverse
}
11
/// Iterator over multi-dimensional index tuples within given dimensions.
///
/// Iterates in row-major order (last index varies fastest).
///
/// The iterator yields `()`; after each successful `next()` the position it
/// stands on is read through [`MultiIndex::offset`]. An empty `dims` yields
/// exactly one (scalar) position, and any zero-sized dimension yields none.
#[derive(Debug, Clone)]
pub struct MultiIndex {
    /// Extent of every axis.
    dims: Vec<usize>,
    /// Index tuple of the position most recently yielded (all zeros initially).
    current: Vec<usize>,
    /// Total number of positions, i.e. the product of `dims`.
    total: usize,
    /// Number of positions yielded so far.
    count: usize,
}

impl MultiIndex {
    /// Create a new iterator over all index tuples for the given dimensions.
    ///
    /// The number of positions is the plain `usize` product of `dims`.
    pub fn new(dims: &[usize]) -> Self {
        let total: usize = dims.iter().product();
        Self {
            dims: dims.to_vec(),
            current: vec![0; dims.len()],
            total,
            count: 0,
        }
    }

    /// Compute the linear offset of the current indices for the given strides.
    ///
    /// The result is `sum(current[k] * strides[k])`, expressed in whatever
    /// unit `strides` uses. Indices and strides are paired positionally; if
    /// the lengths differ, the surplus entries of the longer one are ignored.
    pub fn offset(&self, strides: &[isize]) -> isize {
        self.current
            .iter()
            .zip(strides.iter())
            .map(|(&i, &s)| i as isize * s)
            .sum()
    }

    /// Reset the iterator to the beginning.
    pub fn reset(&mut self) {
        self.current.fill(0);
        self.count = 0;
    }
}

impl Iterator for MultiIndex {
    type Item = ();

    /// Advance to the next index tuple, or return `None` once every position
    /// has been visited.
    fn next(&mut self) -> Option<()> {
        if self.count >= self.total {
            return None;
        }
        // The first call yields the all-zero tuple without incrementing.
        if self.count > 0 {
            // Odometer step from the last axis; stop as soon as one axis
            // absorbs the carry.
            for (idx, &dim) in self.current.iter_mut().zip(&self.dims).rev() {
                *idx += 1;
                if *idx < dim {
                    break;
                }
                *idx = 0;
            }
        }
        self.count += 1;
        Some(())
    }

    /// Exact number of positions still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `count` never exceeds `total`, so this cannot underflow.
        let remaining = self.total - self.count;
        (remaining, Some(remaining))
    }
}
75
/// Try to fuse a contiguous dimension group into a single (total_size, innermost_stride).
///
/// Axes of extent 0 or 1 never constrain the layout and are ignored by the
/// contiguity check. The remaining axes may appear in any order (row-major,
/// column-major, or otherwise permuted). The group is fusable when, after
/// ordering those axes by increasing `|stride|`, each stride equals the
/// previous stride multiplied by the previous extent, and all of those
/// strides have the same sign.
///
/// On success the result is the product of all `dims` paired with the stride
/// of the smallest-magnitude non-singleton axis. When every axis is a
/// singleton, the stride with the smallest magnitude is reported.
///
/// Returns `None` if:
/// - `dims` has one entry and `strides` is empty, or `dims` has two or more
///   entries and the lengths differ (for zero or one dimension, surplus
///   strides are not inspected),
/// - a non-singleton axis has stride 0,
/// - non-singleton strides mix positive and negative values, since a single
///   fused stride cannot reproduce such a layout,
/// - the strides do not form a contiguous chain, or the chain overflows `usize`.
pub fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    match dims.len() {
        0 => return Some((1, 0)),
        // `first()` instead of `strides[0]`: a missing stride means "cannot
        // fuse", not a panic.
        1 => return strides.first().map(|&s| (dims[0], s)),
        n if n != strides.len() => return None,
        _ => {}
    }

    // Only axes with extent > 1 constrain the layout.
    let mut axes: Vec<(usize, isize)> = dims
        .iter()
        .copied()
        .zip(strides.iter().copied())
        .filter(|&(d, _)| d > 1)
        .collect();

    // All singleton axes: trivially fusable.
    let Some(&(_, first_stride)) = axes.first() else {
        let stride = strides
            .iter()
            .copied()
            .min_by_key(|s| s.unsigned_abs())
            .unwrap_or(0);
        return Some((dims.iter().product(), stride));
    };

    // Every constraining stride must be non-zero and share one sign. A sign
    // of 0 means the first stride is zero; any later zero or opposite-sign
    // stride shows up as a signum mismatch.
    let sign = first_stride.signum();
    if sign == 0 || axes.iter().any(|&(_, s)| s.signum() != sign) {
        return None;
    }

    // Stable sort: among equal magnitudes the earliest axis stays first. Equal
    // magnitudes on two non-singleton axes fail the chain check below anyway.
    axes.sort_by_key(|&(_, s)| s.unsigned_abs());

    let (_, base_stride) = axes[0];
    let mut expected_abs = base_stride.unsigned_abs();
    for &(d, s) in &axes {
        if s.unsigned_abs() != expected_abs {
            return None;
        }
        expected_abs = expected_abs.checked_mul(d)?;
    }

    Some((dims.iter().product(), base_stride))
}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Drive `iter` to exhaustion, recording `probe(iter)` after every step.
    fn walk<T>(mut iter: MultiIndex, probe: impl Fn(&MultiIndex) -> T) -> Vec<T> {
        let mut seen = Vec::new();
        while iter.next().is_some() {
            seen.push(probe(&iter));
        }
        seen
    }

    #[test]
    fn test_invert_perm() {
        let rotation = [2, 0, 1];
        assert_eq!(invert_perm(&rotation), vec![1, 2, 0]);

        let identity = [0, 1, 2];
        assert_eq!(invert_perm(&identity), vec![0, 1, 2]);
    }

    #[test]
    fn test_multi_index_2d() {
        // Row-major walk over a 2x3 grid: the last index varies fastest.
        let visited = walk(MultiIndex::new(&[2, 3]), |it| it.current.clone());
        let expected: Vec<Vec<usize>> = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
            .iter()
            .map(|pair| pair.to_vec())
            .collect();
        assert_eq!(visited, expected);
    }

    #[test]
    fn test_multi_index_offset() {
        // Row-major strides for a 2x3 grid give consecutive offsets.
        let row_major = [3, 1];
        let offsets = walk(MultiIndex::new(&[2, 3]), |it| it.offset(&row_major));
        assert_eq!(offsets, vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_multi_index_empty() {
        // A zero-dimensional shape has exactly one (scalar) position.
        let mut scalar = MultiIndex::new(&[]);
        assert!(scalar.next().is_some());
        assert!(scalar.next().is_none());
    }

    #[test]
    fn test_try_fuse_group_empty() {
        let fused = try_fuse_group(&[], &[]);
        assert_eq!(fused, Some((1, 0)));
    }

    #[test]
    fn test_try_fuse_group_single() {
        let fused = try_fuse_group(&[5], &[2]);
        assert_eq!(fused, Some((5, 2)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        // 3x4 row-major layout.
        let (dims, strides) = ([3, 4], [4, 1]);
        assert_eq!(try_fuse_group(&dims, &strides), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        // 3x4 column-major layout.
        let (dims, strides) = ([3, 4], [1, 3]);
        assert_eq!(try_fuse_group(&dims, &strides), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_3d() {
        // 2x3x4 row-major layout.
        let (dims, strides) = ([2, 3, 4], [12, 4, 1]);
        assert_eq!(try_fuse_group(&dims, &strides), Some((24, 1)));
    }

    #[test]
    fn test_try_fuse_group_non_contiguous() {
        // A stride of 8 leaves a gap after the 4-element inner axis.
        let (dims, strides) = ([3, 4], [8, 1]);
        assert_eq!(try_fuse_group(&dims, &strides), None);
    }
}
227}