strided_einsum2/
plan.rs

1//! Einsum2 plan: axis classification and permutation computation.
2
3use crate::util::invert_perm;
4use crate::AxisId;
5use crate::EinsumError;
6
7/// Pre-computed execution plan for a binary einsum contraction.
8///
9/// Classifies axes into groups and precomputes the permutations needed
10/// to arrange operands for batched matrix multiplication.
11#[derive(Debug, Clone)]
12pub struct Einsum2Plan<ID: AxisId> {
13    /// Batch axes: present in A, B, and C.
14    pub batch: Vec<ID>,
15    /// Left-output axes: present in A and C, not in B.
16    pub lo: Vec<ID>,
17    /// Right-output axes: present in B and C, not in A.
18    pub ro: Vec<ID>,
19    /// Contraction axes: present in A and B, not in C.
20    pub sum: Vec<ID>,
21
22    /// Permutation to reorder A to [lo, sum, batch] after trace reduction.
23    pub left_perm: Vec<usize>,
24    /// Permutation to reorder B to [sum, ro, batch] after trace reduction.
25    pub right_perm: Vec<usize>,
26    /// Permutation to reorder C from IC order to [lo, ro, batch] order.
27    pub c_to_internal_perm: Vec<usize>,
28}
29
30impl<ID: AxisId> Einsum2Plan<ID> {
31    /// Build a plan from axis labels.
32    ///
33    /// `ia`, `ib`, `ic` are the axis labels for A, B, C respectively.
34    ///
35    /// Uses linear scans instead of hash collections for axis classification,
36    /// which is faster for the small label sets typical in einsum contractions.
37    pub fn new(ia: &[ID], ib: &[ID], ic: &[ID]) -> Result<Self, EinsumError> {
38        // Validate: no duplicate axes within a single operand (linear scan)
39        for (i, id) in ia.iter().enumerate() {
40            if ia[..i].iter().any(|x| x == id) {
41                return Err(EinsumError::DuplicateAxis(
42                    "left operand has duplicate axis labels".into(),
43                ));
44            }
45        }
46        for (i, id) in ib.iter().enumerate() {
47            if ib[..i].iter().any(|x| x == id) {
48                return Err(EinsumError::DuplicateAxis(
49                    "right operand has duplicate axis labels".into(),
50                ));
51            }
52        }
53        for (i, id) in ic.iter().enumerate() {
54            if ic[..i].iter().any(|x| x == id) {
55                return Err(EinsumError::DuplicateAxis(
56                    "output has duplicate axis labels".into(),
57                ));
58            }
59        }
60
61        // Validate: every output axis must appear in at least one input
62        for id in ic {
63            if !ia.contains(id) && !ib.contains(id) {
64                return Err(EinsumError::OrphanOutputAxis(format!("{:?}", id)));
65            }
66        }
67
68        let mut batch = Vec::new();
69        let mut lo = Vec::new();
70        let mut sum = Vec::new();
71        let mut left_trace = Vec::new();
72
73        for id in ia {
74            if ib.contains(id) {
75                if ic.contains(id) {
76                    batch.push(id.clone());
77                } else {
78                    sum.push(id.clone());
79                }
80            } else if ic.contains(id) {
81                lo.push(id.clone());
82            } else {
83                left_trace.push(id.clone());
84            }
85        }
86
87        let mut ro = Vec::new();
88        let mut right_trace = Vec::new();
89
90        for id in ib {
91            if !ia.contains(id) {
92                if ic.contains(id) {
93                    ro.push(id.clone());
94                } else {
95                    right_trace.push(id.clone());
96                }
97            }
98        }
99
100        // Build left_perm: maps positions in ia (after trace removal) to [lo, sum, batch] order
101        // Use linear scan instead of HashMap — faster for small label sets.
102        let ia_after_trace: Vec<&ID> = ia.iter().filter(|id| !left_trace.contains(id)).collect();
103        let left_perm: Vec<usize> = lo
104            .iter()
105            .chain(sum.iter())
106            .chain(batch.iter())
107            .map(|id| {
108                ia_after_trace
109                    .iter()
110                    .position(|aid| *aid == id)
111                    .expect("left_perm: axis not found")
112            })
113            .collect();
114
115        // Build right_perm: maps positions in ib (after trace removal) to [sum, ro, batch] order
116        let ib_after_trace: Vec<&ID> = ib.iter().filter(|id| !right_trace.contains(id)).collect();
117        let right_perm: Vec<usize> = sum
118            .iter()
119            .chain(ro.iter())
120            .chain(batch.iter())
121            .map(|id| {
122                ib_after_trace
123                    .iter()
124                    .position(|bid| *bid == id)
125                    .expect("right_perm: axis not found")
126            })
127            .collect();
128
129        // Build c_to_internal_perm: maps IC order to [lo, ro, batch] order
130        let c_to_internal_perm: Vec<usize> = lo
131            .iter()
132            .chain(ro.iter())
133            .chain(batch.iter())
134            .map(|id| {
135                ic.iter()
136                    .position(|c_id| c_id == id)
137                    .expect("c_to_internal_perm: axis not found")
138            })
139            .collect();
140
141        Ok(Einsum2Plan {
142            batch,
143            lo,
144            ro,
145            sum,
146            left_perm,
147            right_perm,
148            c_to_internal_perm,
149        })
150    }
151
152    /// Get the inverse of c_to_internal_perm (maps `\[batch, lo, ro\]` back to IC order).
153    pub fn internal_to_c_perm(&self) -> Vec<usize> {
154        invert_perm(&self.c_to_internal_perm)
155    }
156}
157
#[cfg(test)]
mod tests {
    //! Unit tests for axis classification, permutation construction and
    //! input validation in `Einsum2Plan::new`.

    use super::*;

    #[test]
    fn test_classify_matmul() {
        // ij,jk->ik
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[0u32, 2]).unwrap();
        assert_eq!(plan.batch, vec![] as Vec<u32>);
        assert_eq!(plan.lo, vec![0]);
        assert_eq!(plan.ro, vec![2]);
        assert_eq!(plan.sum, vec![1]);
    }

    #[test]
    fn test_classify_batched_matmul() {
        // bij,bjk->bik
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[0u32, 1, 3]).unwrap();
        assert_eq!(plan.batch, vec![0]);
        assert_eq!(plan.lo, vec![1]);
        assert_eq!(plan.ro, vec![3]);
        assert_eq!(plan.sum, vec![2]);
    }

    #[test]
    fn test_classify_outer_product() {
        // i,j->ij
        let plan = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 1]).unwrap();
        assert!(plan.batch.is_empty());
        assert_eq!(plan.lo, vec![0]);
        assert_eq!(plan.ro, vec![1]);
        assert!(plan.sum.is_empty());
    }

    #[test]
    fn test_classify_dot_product() {
        // i,i->
        let plan = Einsum2Plan::new(&[0u32], &[0u32], &[] as &[u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert!(plan.lo.is_empty());
        assert!(plan.ro.is_empty());
        assert_eq!(plan.sum, vec![0]);
    }

    #[test]
    fn test_classify_left_trace() {
        // ij,jk->k: lo=[], ro=[k], sum=[j]
        // `i` occurs only in A, so it is a trace axis: the plan stores no
        // group for it, and the permutations index A with `i` removed.
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[2u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert!(plan.lo.is_empty());
        assert_eq!(plan.ro, vec![2]);
        assert_eq!(plan.sum, vec![1]);
        // A after trace removal = [j], so j sits at index 0 (not 1).
        assert_eq!(plan.left_perm, vec![0]);
        assert_eq!(plan.right_perm, vec![0, 1]);
        assert_eq!(plan.c_to_internal_perm, vec![0]);
    }

    #[test]
    fn test_classify_right_trace() {
        // ij,ljk->ik: `l` occurs only in B, so it is a trace axis.
        let plan = Einsum2Plan::new(&[0u32, 1], &[3u32, 1, 2], &[0u32, 2]).unwrap();
        assert!(plan.batch.is_empty());
        assert_eq!(plan.lo, vec![0]);
        assert_eq!(plan.ro, vec![2]);
        assert_eq!(plan.sum, vec![1]);
        assert_eq!(plan.left_perm, vec![0, 1]);
        // B after trace removal = [j, k], so indices are shifted down by one.
        assert_eq!(plan.right_perm, vec![0, 1]);
        assert_eq!(plan.c_to_internal_perm, vec![0, 1]);
    }

    #[test]
    fn test_perm_matmul() {
        // ij,jk->ik
        // A: [i, j] -> [lo=[i], sum=[j], batch=[]] = [i, j] => perm [0, 1]
        // B: [j, k] -> [sum=[j], ro=[k], batch=[]] = [j, k] => perm [0, 1]
        // C: [i, k] -> [lo=[i], ro=[k], batch=[]] = [i, k] => perm [0, 1]
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[0u32, 2]).unwrap();
        assert_eq!(plan.left_perm, vec![0, 1]);
        assert_eq!(plan.right_perm, vec![0, 1]);
        assert_eq!(plan.c_to_internal_perm, vec![0, 1]);
    }

    #[test]
    fn test_perm_batched_transposed_output() {
        // bij,bjk->bki (output has transposed lo/ro)
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[0u32, 3, 1]).unwrap();
        assert_eq!(plan.batch, vec![0]);
        assert_eq!(plan.lo, vec![1]);
        assert_eq!(plan.ro, vec![3]);
        assert_eq!(plan.sum, vec![2]);
        // A: ia=[b=0, i=1, j=2], after trace removal=[b, i, j]
        // target [lo, sum, batch] = [i=1, j=2, b=0] -> positions in ia_after_trace: [1, 2, 0]
        assert_eq!(plan.left_perm, vec![1, 2, 0]);
        // B: ib=[b=0, j=2, k=3], after trace removal=[b, j, k]
        // target [sum, ro, batch] = [j=2, k=3, b=0] -> positions: [1, 2, 0]
        assert_eq!(plan.right_perm, vec![1, 2, 0]);
        // C IC order: [b=0, k=3, i=1]
        // target [lo, ro, batch] = [i=1, k=3, b=0] -> IC positions: [2, 1, 0]
        assert_eq!(plan.c_to_internal_perm, vec![2, 1, 0]);
    }

    #[test]
    fn test_internal_to_c_perm_is_inverse() {
        // bij,bjk->kbi: a 3-cycle, so the inverse differs from the permutation.
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[3u32, 0, 1]).unwrap();
        assert_eq!(plan.c_to_internal_perm, vec![2, 0, 1]);
        let back = plan.internal_to_c_perm();
        assert_eq!(back.len(), plan.c_to_internal_perm.len());
        for (i, &p) in plan.c_to_internal_perm.iter().enumerate() {
            assert_eq!(back[p], i);
        }
    }

    #[test]
    fn test_error_orphan_output() {
        let result = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 1, 2]);
        assert!(matches!(result, Err(EinsumError::OrphanOutputAxis(_))));
    }

    #[test]
    fn test_error_duplicate() {
        let result = Einsum2Plan::new(&[0u32, 0], &[1u32], &[0u32, 1]);
        assert!(matches!(result, Err(EinsumError::DuplicateAxis(_))));
    }

    #[test]
    fn test_error_duplicate_right() {
        let result = Einsum2Plan::new(&[0u32], &[1u32, 1], &[0u32, 1]);
        assert!(matches!(result, Err(EinsumError::DuplicateAxis(_))));
    }

    #[test]
    fn test_error_duplicate_output() {
        let result = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 0]);
        assert!(matches!(result, Err(EinsumError::DuplicateAxis(_))));
    }

    #[test]
    fn test_char_labels() {
        let plan = Einsum2Plan::new(&['i', 'j'], &['j', 'k'], &['i', 'k']).unwrap();
        assert_eq!(plan.lo, vec!['i']);
        assert_eq!(plan.ro, vec!['k']);
        assert_eq!(plan.sum, vec!['j']);
    }
}