strided_einsum2/
util.rs

//! Shared helpers for strided-einsum2: permutation inversion, row-major multi-index iteration, and stride-based dimension fusion.
2
/// Compute the inverse of a permutation.
///
/// For every position `i`, the returned vector satisfies `result[perm[i]] == i`, so for a
/// valid permutation, composing `perm` with its inverse gives the identity.
///
/// # Panics
///
/// Panics if any entry of `perm` is `>= perm.len()`. Duplicate entries are not detected:
/// the later position overwrites the earlier one, and unmapped slots stay `0`.
pub fn invert_perm(perm: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0_usize; perm.len()];
    perm.iter()
        .copied()
        .enumerate()
        .for_each(|(pos, target)| inverse[target] = pos);
    inverse
}
11
/// Iterator over multi-dimensional index tuples within given dimensions.
///
/// Iterates in row-major order (last index varies fastest). Each call to
/// [`Iterator::next`] advances the internal index tuple and yields `()`; read the tuple
/// with [`MultiIndex::current`] or turn it into a linear offset with [`MultiIndex::offset`].
///
/// An empty `dims` slice describes a scalar and yields exactly one (empty) index tuple;
/// any zero-sized dimension makes the iterator yield nothing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiIndex {
    /// Extent of each dimension.
    dims: Vec<usize>,
    /// Index tuple of the most recently yielded position (all zeros before the first `next`).
    current: Vec<usize>,
    /// Total number of tuples: the product of `dims`.
    total: usize,
    /// Number of tuples yielded so far.
    count: usize,
}

impl MultiIndex {
    /// Create a new iterator over all index tuples for the given dimensions.
    pub fn new(dims: &[usize]) -> Self {
        Self {
            dims: dims.to_vec(),
            current: vec![0; dims.len()],
            total: dims.iter().product(),
            count: 0,
        }
    }

    /// Index tuple of the position most recently yielded by `next`.
    ///
    /// Before the first call to `next` (or right after [`MultiIndex::reset`]) this is all zeros.
    pub fn current(&self) -> &[usize] {
        &self.current
    }

    /// Compute the linear offset of the current index tuple: `sum(current[k] * strides[k])`.
    ///
    /// The result is in the same units as `strides` (elements when the strides are element
    /// strides, bytes when they are byte strides). If `strides` is shorter than the number
    /// of dimensions, the trailing dimensions are ignored.
    pub fn offset(&self, strides: &[isize]) -> isize {
        self.current
            .iter()
            .zip(strides)
            .map(|(&i, &s)| i as isize * s)
            .sum()
    }

    /// Reset the iterator to the beginning.
    pub fn reset(&mut self) {
        self.current.fill(0);
        self.count = 0;
    }
}

impl Iterator for MultiIndex {
    type Item = ();

    fn next(&mut self) -> Option<()> {
        if self.count >= self.total {
            return None;
        }
        // The first tuple is the all-zeros initial state; only later calls advance it.
        if self.count > 0 {
            // Odometer increment: bump the last index and carry into earlier ones on wrap.
            for (idx, &dim) in self.current.iter_mut().zip(&self.dims).rev() {
                *idx += 1;
                if *idx < dim {
                    break;
                }
                *idx = 0;
            }
        }
        self.count += 1;
        Some(())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `count` never exceeds `total`: `next` returns before incrementing once they meet.
        let remaining = self.total - self.count;
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for MultiIndex {}
75
/// Try to fuse a group of dimensions into a single `(total_size, innermost_stride)` pair.
///
/// The group is fusable when its elements form one evenly strided run. Sort the
/// non-trivial dimensions (size > 1) by increasing `|stride|`. In that sorted order, each
/// stride must equal the previous stride times the previous size, i.e.
/// `stride[k+1] == stride[k] * dim[k]` (signs included). Dimensions of size 0 or 1 never
/// contribute a step, so their strides are ignored. Because the check runs in sorted
/// order, both row-major and column-major groups fuse. The fused index then runs in order
/// of increasing `|stride|`, which is not necessarily the order of `dims`.
///
/// NOTE(review): since fusion ignores the order of `dims`, callers that fuse the same
/// logical group in several operands presumably need those groups to share a memory
/// order — verify at call sites.
///
/// Returns `None` when:
/// - `dims` and `strides` have different lengths,
/// - the group has two or more dimensions and one of size > 1 has stride 0 (broadcast),
/// - the non-trivial strides do not form a contiguous chain (including mixed signs), or
/// - an intermediate product overflows.
///
/// An empty group fuses to `(1, 0)`; a single dimension is returned unchanged.
pub fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    if dims.len() != strides.len() {
        return None;
    }
    match dims.len() {
        0 => Some((1, 0)),
        1 => Some((dims[0], strides[0])),
        _ => {
            // A size > 1 dimension with stride 0 revisits the same element, which the
            // chain check below cannot express; refuse to fuse such groups.
            if dims.iter().zip(strides).any(|(&d, &s)| d > 1 && s == 0) {
                return None;
            }

            // Non-trivial dims first (by |stride|, then size); size-0/1 dims go last so
            // their arbitrary strides can neither break the chain nor become the inner stride.
            let mut pairs: Vec<(usize, isize)> =
                dims.iter().copied().zip(strides.iter().copied()).collect();
            pairs.sort_by_key(|&(d, s)| (d <= 1, s.unsigned_abs(), d));
            let nontrivial = pairs.partition_point(|&(d, _)| d > 1);

            for step in pairs[..nontrivial].windows(2) {
                let (dim_inner, stride_inner) = step[0];
                let (_, stride_outer) = step[1];
                let expected = stride_inner.unsigned_abs().checked_mul(dim_inner)?;
                // Strides here are non-zero, so `signum` is +/-1; mixed signs cannot chain.
                if stride_outer.unsigned_abs() != expected
                    || stride_outer.signum() != stride_inner.signum()
                {
                    return None;
                }
            }

            let total = dims.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))?;
            Some((total, pairs[0].1))
        }
    }
}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Drain `iter`, recording a copy of the index tuple after each step.
    fn collect_tuples(iter: &mut MultiIndex) -> Vec<Vec<usize>> {
        let mut tuples = Vec::new();
        while iter.next().is_some() {
            tuples.push(iter.current.clone());
        }
        tuples
    }

    #[test]
    fn test_invert_perm() {
        assert_eq!(invert_perm(&[2, 0, 1]), vec![1, 2, 0]);
        assert_eq!(invert_perm(&[0, 1, 2]), vec![0, 1, 2]);
        assert_eq!(invert_perm(&[]), Vec::<usize>::new());
    }

    #[test]
    fn test_invert_perm_round_trip() {
        let perm = [3, 0, 4, 1, 2];
        let inv = invert_perm(&perm);
        for (i, &p) in perm.iter().enumerate() {
            assert_eq!(inv[p], i);
        }
        assert_eq!(invert_perm(&inv), perm.to_vec());
    }

    #[test]
    fn test_multi_index_2d() {
        let mut iter = MultiIndex::new(&[2, 3]);
        assert_eq!(
            collect_tuples(&mut iter),
            vec![
                vec![0, 0],
                vec![0, 1],
                vec![0, 2],
                vec![1, 0],
                vec![1, 1],
                vec![1, 2],
            ]
        );
    }

    #[test]
    fn test_multi_index_3d_carry_through_unit_dim() {
        let mut iter = MultiIndex::new(&[2, 1, 2]);
        assert_eq!(
            collect_tuples(&mut iter),
            vec![vec![0, 0, 0], vec![0, 0, 1], vec![1, 0, 0], vec![1, 0, 1]]
        );
    }

    #[test]
    fn test_multi_index_offset() {
        let mut iter = MultiIndex::new(&[2, 3]);
        let strides = [3, 1]; // row-major strides
        let mut offsets = vec![];
        while iter.next().is_some() {
            offsets.push(iter.offset(&strides));
        }
        assert_eq!(offsets, vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_multi_index_empty() {
        let mut iter = MultiIndex::new(&[]);
        assert!(iter.next().is_some()); // single scalar iteration
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_multi_index_zero_sized_dim() {
        let mut iter = MultiIndex::new(&[3, 0]);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_multi_index_reset() {
        let mut iter = MultiIndex::new(&[2, 2]);
        let first_pass = collect_tuples(&mut iter);
        assert!(iter.next().is_none());
        iter.reset();
        assert_eq!(collect_tuples(&mut iter), first_pass);
    }

    #[test]
    fn test_try_fuse_group_empty() {
        assert_eq!(try_fuse_group(&[], &[]), Some((1, 0)));
    }

    #[test]
    fn test_try_fuse_group_single() {
        assert_eq!(try_fuse_group(&[5], &[2]), Some((5, 2)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        // 3x4 row-major: strides [4, 1]
        assert_eq!(try_fuse_group(&[3, 4], &[4, 1]), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        // 3x4 col-major: strides [1, 3]
        assert_eq!(try_fuse_group(&[3, 4], &[1, 3]), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_3d() {
        // 2x3x4 row-major: strides [12, 4, 1]
        assert_eq!(try_fuse_group(&[2, 3, 4], &[12, 4, 1]), Some((24, 1)));
    }

    #[test]
    fn test_try_fuse_group_negative_strides() {
        // 2x3 row-major walked backwards along both axes still forms one run.
        assert_eq!(try_fuse_group(&[2, 3], &[-3, -1]), Some((6, -1)));
    }

    #[test]
    fn test_try_fuse_group_rejects_broadcast() {
        // A size > 1 dim with stride 0 cannot be fused.
        assert_eq!(try_fuse_group(&[2, 3], &[0, 1]), None);
    }

    #[test]
    fn test_try_fuse_group_non_contiguous() {
        // strides don't follow contiguity rule
        assert_eq!(try_fuse_group(&[3, 4], &[8, 1]), None);
    }
}