strided_einsum2/
plan.rs

1//! Einsum2 plan: axis classification and permutation computation.
2
3use crate::util::invert_perm;
4use crate::AxisId;
5use crate::EinsumError;
6
7/// Pre-computed execution plan for a binary einsum contraction.
8///
9/// Classifies axes into groups and precomputes the permutations needed
10/// to arrange operands for batched matrix multiplication.
11#[derive(Debug, Clone)]
12pub struct Einsum2Plan<ID: AxisId> {
13    /// Batch axes: present in A, B, and C.
14    pub batch: Vec<ID>,
15    /// Left-output axes: present in A and C, not in B.
16    pub lo: Vec<ID>,
17    /// Right-output axes: present in B and C, not in A.
18    pub ro: Vec<ID>,
19    /// Contraction axes: present in A and B, not in C.
20    pub sum: Vec<ID>,
21    /// Left trace axes: present only in A.
22    pub left_trace: Vec<ID>,
23    /// Right trace axes: present only in B.
24    pub right_trace: Vec<ID>,
25
26    /// Permutation to reorder A to [lo, sum, batch] after trace reduction.
27    pub left_perm: Vec<usize>,
28    /// Permutation to reorder B to [sum, ro, batch] after trace reduction.
29    pub right_perm: Vec<usize>,
30    /// Permutation to reorder C from IC order to [lo, ro, batch] order.
31    pub c_to_internal_perm: Vec<usize>,
32}
33
34impl<ID: AxisId> Einsum2Plan<ID> {
35    /// Build a plan from axis labels.
36    ///
37    /// `ia`, `ib`, `ic` are the axis labels for A, B, C respectively.
38    ///
39    /// Uses linear scans instead of hash collections for axis classification,
40    /// which is faster for the small label sets typical in einsum contractions.
41    pub fn new(ia: &[ID], ib: &[ID], ic: &[ID]) -> Result<Self, EinsumError> {
42        // Validate: no duplicate axes within a single operand (linear scan)
43        for (i, id) in ia.iter().enumerate() {
44            if ia[..i].iter().any(|x| x == id) {
45                return Err(EinsumError::DuplicateAxis(
46                    "left operand has duplicate axis labels".into(),
47                ));
48            }
49        }
50        for (i, id) in ib.iter().enumerate() {
51            if ib[..i].iter().any(|x| x == id) {
52                return Err(EinsumError::DuplicateAxis(
53                    "right operand has duplicate axis labels".into(),
54                ));
55            }
56        }
57        for (i, id) in ic.iter().enumerate() {
58            if ic[..i].iter().any(|x| x == id) {
59                return Err(EinsumError::DuplicateAxis(
60                    "output has duplicate axis labels".into(),
61                ));
62            }
63        }
64
65        // Validate: every output axis must appear in at least one input
66        for id in ic {
67            if !ia.contains(id) && !ib.contains(id) {
68                return Err(EinsumError::OrphanOutputAxis(format!("{:?}", id)));
69            }
70        }
71
72        let mut batch = Vec::new();
73        let mut lo = Vec::new();
74        let mut sum = Vec::new();
75        let mut left_trace = Vec::new();
76
77        for id in ia {
78            if ib.contains(id) {
79                if ic.contains(id) {
80                    batch.push(id.clone());
81                } else {
82                    sum.push(id.clone());
83                }
84            } else if ic.contains(id) {
85                lo.push(id.clone());
86            } else {
87                left_trace.push(id.clone());
88            }
89        }
90
91        let mut ro = Vec::new();
92        let mut right_trace = Vec::new();
93
94        for id in ib {
95            if !ia.contains(id) {
96                if ic.contains(id) {
97                    ro.push(id.clone());
98                } else {
99                    right_trace.push(id.clone());
100                }
101            }
102        }
103
104        // Build left_perm: maps positions in ia (after trace removal) to [lo, sum, batch] order
105        // Use linear scan instead of HashMap — faster for small label sets.
106        let ia_after_trace: Vec<&ID> = ia.iter().filter(|id| !left_trace.contains(id)).collect();
107        let left_perm: Vec<usize> = lo
108            .iter()
109            .chain(sum.iter())
110            .chain(batch.iter())
111            .map(|id| {
112                ia_after_trace
113                    .iter()
114                    .position(|aid| *aid == id)
115                    .expect("left_perm: axis not found")
116            })
117            .collect();
118
119        // Build right_perm: maps positions in ib (after trace removal) to [sum, ro, batch] order
120        let ib_after_trace: Vec<&ID> = ib.iter().filter(|id| !right_trace.contains(id)).collect();
121        let right_perm: Vec<usize> = sum
122            .iter()
123            .chain(ro.iter())
124            .chain(batch.iter())
125            .map(|id| {
126                ib_after_trace
127                    .iter()
128                    .position(|bid| *bid == id)
129                    .expect("right_perm: axis not found")
130            })
131            .collect();
132
133        // Build c_to_internal_perm: maps IC order to [lo, ro, batch] order
134        let c_to_internal_perm: Vec<usize> = lo
135            .iter()
136            .chain(ro.iter())
137            .chain(batch.iter())
138            .map(|id| {
139                ic.iter()
140                    .position(|c_id| c_id == id)
141                    .expect("c_to_internal_perm: axis not found")
142            })
143            .collect();
144
145        Ok(Einsum2Plan {
146            batch,
147            lo,
148            ro,
149            sum,
150            left_trace,
151            right_trace,
152            left_perm,
153            right_perm,
154            c_to_internal_perm,
155        })
156    }
157
158    /// Get the indices of left_trace axes in the original `ia` array.
159    pub fn left_trace_indices(&self, ia: &[ID]) -> Vec<usize> {
160        self.left_trace
161            .iter()
162            .filter_map(|id| ia.iter().position(|x| x == id))
163            .collect()
164    }
165
166    /// Get the indices of right_trace axes in the original `ib` array.
167    pub fn right_trace_indices(&self, ib: &[ID]) -> Vec<usize> {
168        self.right_trace
169            .iter()
170            .filter_map(|id| ib.iter().position(|x| x == id))
171            .collect()
172    }
173
174    /// Get the inverse of c_to_internal_perm (maps `\[batch, lo, ro\]` back to IC order).
175    pub fn internal_to_c_perm(&self) -> Vec<usize> {
176        invert_perm(&self.c_to_internal_perm)
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_matmul() {
        // ij,jk->ik
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[0u32, 2]).unwrap();
        assert!(plan.batch.is_empty());
        assert_eq!(plan.lo, vec![0]);
        assert_eq!(plan.ro, vec![2]);
        assert_eq!(plan.sum, vec![1]);
        assert!(plan.left_trace.is_empty());
        assert!(plan.right_trace.is_empty());
    }

    #[test]
    fn test_classify_batched_matmul() {
        // bij,bjk->bik
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[0u32, 1, 3]).unwrap();
        assert_eq!(plan.batch, vec![0]);
        assert_eq!(plan.lo, vec![1]);
        assert_eq!(plan.ro, vec![3]);
        assert_eq!(plan.sum, vec![2]);
    }

    #[test]
    fn test_classify_outer_product() {
        // i,j->ij
        let plan = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 1]).unwrap();
        assert!(plan.batch.is_empty());
        assert_eq!(plan.lo, vec![0]);
        assert_eq!(plan.ro, vec![1]);
        assert!(plan.sum.is_empty());
    }

    #[test]
    fn test_classify_dot_product() {
        // i,i->
        let plan = Einsum2Plan::new(&[0u32], &[0u32], &[] as &[u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert!(plan.lo.is_empty());
        assert!(plan.ro.is_empty());
        assert_eq!(plan.sum, vec![0]);
        assert!(plan.c_to_internal_perm.is_empty());
    }

    #[test]
    fn test_classify_left_trace() {
        // ij,jk->k: lo=[], ro=[k], sum=[j], left_trace=[i]
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[2u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert!(plan.lo.is_empty());
        assert_eq!(plan.ro, vec![2]);
        assert_eq!(plan.sum, vec![1]);
        assert_eq!(plan.left_trace, vec![0]);
        assert!(plan.right_trace.is_empty());
    }

    #[test]
    fn test_classify_right_trace() {
        // ij,jk->i: lo=[i], ro=[], sum=[j], right_trace=[k]
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[0u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert_eq!(plan.lo, vec![0]);
        assert!(plan.ro.is_empty());
        assert_eq!(plan.sum, vec![1]);
        assert!(plan.left_trace.is_empty());
        assert_eq!(plan.right_trace, vec![2]);
        // B after trace removal is [j], already in [sum] order.
        assert_eq!(plan.right_perm, vec![0]);
    }

    #[test]
    fn test_perm_matmul() {
        // ij,jk->ik
        // A: [i, j] -> [lo=[i], sum=[j], batch=[]] = [i, j] => perm [0, 1]
        // B: [j, k] -> [sum=[j], ro=[k], batch=[]] = [j, k] => perm [0, 1]
        // C: [i, k] -> [lo=[i], ro=[k], batch=[]] = [i, k] => perm [0, 1]
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[0u32, 2]).unwrap();
        assert_eq!(plan.left_perm, vec![0, 1]);
        assert_eq!(plan.right_perm, vec![0, 1]);
        assert_eq!(plan.c_to_internal_perm, vec![0, 1]);
    }

    #[test]
    fn test_perm_batched_transposed_output() {
        // bij,bjk->bki (output has transposed lo/ro)
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[0u32, 3, 1]).unwrap();
        assert_eq!(plan.batch, vec![0]);
        assert_eq!(plan.lo, vec![1]);
        assert_eq!(plan.ro, vec![3]);
        assert_eq!(plan.sum, vec![2]);
        // A: ia=[b=0, i=1, j=2], after trace removal=[b, i, j]
        // target [lo, sum, batch] = [i=1, j=2, b=0] -> positions in ia_after_trace: [1, 2, 0]
        assert_eq!(plan.left_perm, vec![1, 2, 0]);
        // B: ib=[b=0, j=2, k=3], after trace removal=[b, j, k]
        // target [sum, ro, batch] = [j=2, k=3, b=0] -> positions: [1, 2, 0]
        assert_eq!(plan.right_perm, vec![1, 2, 0]);
        // C IC order: [b=0, k=3, i=1]
        // target [lo, ro, batch] = [i=1, k=3, b=0] -> IC positions: [2, 1, 0]
        assert_eq!(plan.c_to_internal_perm, vec![2, 1, 0]);
    }

    #[test]
    fn test_perm_with_traces() {
        // tji,kuj->ki: t only in A, u only in B
        let ia = [5u32, 1, 0];
        let ib = [2u32, 6, 1];
        let plan = Einsum2Plan::new(&ia, &ib, &[2u32, 0]).unwrap();
        assert_eq!(plan.left_trace, vec![5]);
        assert_eq!(plan.right_trace, vec![6]);
        // A after trace removal = [j, i]; target [lo=i, sum=j] -> [1, 0]
        assert_eq!(plan.left_perm, vec![1, 0]);
        // B after trace removal = [k, j]; target [sum=j, ro=k] -> [1, 0]
        assert_eq!(plan.right_perm, vec![1, 0]);
        // C = [k, i]; target [lo=i, ro=k] -> [1, 0]
        assert_eq!(plan.c_to_internal_perm, vec![1, 0]);
        assert_eq!(plan.left_trace_indices(&ia), vec![0]);
        assert_eq!(plan.right_trace_indices(&ib), vec![1]);
    }

    #[test]
    fn test_internal_to_c_perm_inverts() {
        // bij,bjk->kbi: C = [k, b, i]; target [lo=i, ro=k, batch=b] -> [2, 0, 1]
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[3u32, 0, 1]).unwrap();
        assert_eq!(plan.c_to_internal_perm, vec![2, 0, 1]);
        let back = plan.internal_to_c_perm();
        assert_eq!(back, vec![1, 2, 0]);
        for (k, &p) in plan.c_to_internal_perm.iter().enumerate() {
            assert_eq!(back[p], k);
        }
    }

    #[test]
    fn test_error_orphan_output() {
        let result = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 1, 2]);
        assert!(matches!(result, Err(EinsumError::OrphanOutputAxis(_))));
    }

    #[test]
    fn test_error_duplicate() {
        let result = Einsum2Plan::new(&[0u32, 0], &[1u32], &[0u32, 1]);
        assert!(matches!(result, Err(EinsumError::DuplicateAxis(_))));
    }

    #[test]
    fn test_error_duplicate_right_and_output() {
        let right = Einsum2Plan::new(&[0u32], &[1u32, 1], &[0u32, 1]);
        assert!(matches!(right, Err(EinsumError::DuplicateAxis(_))));
        let output = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 0]);
        assert!(matches!(output, Err(EinsumError::DuplicateAxis(_))));
    }

    #[test]
    fn test_error_duplicate_reported_before_orphan() {
        // Both a duplicate in A and an orphan output label: the duplicate wins.
        let result = Einsum2Plan::new(&[0u32, 0], &[1u32], &[9u32]);
        assert!(matches!(result, Err(EinsumError::DuplicateAxis(_))));
    }

    #[test]
    fn test_char_labels() {
        let plan = Einsum2Plan::new(&['i', 'j'], &['j', 'k'], &['i', 'k']).unwrap();
        assert_eq!(plan.lo, vec!['i']);
        assert_eq!(plan.ro, vec!['k']);
        assert_eq!(plan.sum, vec!['j']);
    }
}