1use crate::util::invert_perm;
4use crate::AxisId;
5use crate::EinsumError;
6
7#[derive(Debug, Clone)]
12pub struct Einsum2Plan<ID: AxisId> {
13 pub batch: Vec<ID>,
15 pub lo: Vec<ID>,
17 pub ro: Vec<ID>,
19 pub sum: Vec<ID>,
21
22 pub left_perm: Vec<usize>,
24 pub right_perm: Vec<usize>,
26 pub c_to_internal_perm: Vec<usize>,
28}
29
30impl<ID: AxisId> Einsum2Plan<ID> {
31 pub fn new(ia: &[ID], ib: &[ID], ic: &[ID]) -> Result<Self, EinsumError> {
38 for (i, id) in ia.iter().enumerate() {
40 if ia[..i].iter().any(|x| x == id) {
41 return Err(EinsumError::DuplicateAxis(
42 "left operand has duplicate axis labels".into(),
43 ));
44 }
45 }
46 for (i, id) in ib.iter().enumerate() {
47 if ib[..i].iter().any(|x| x == id) {
48 return Err(EinsumError::DuplicateAxis(
49 "right operand has duplicate axis labels".into(),
50 ));
51 }
52 }
53 for (i, id) in ic.iter().enumerate() {
54 if ic[..i].iter().any(|x| x == id) {
55 return Err(EinsumError::DuplicateAxis(
56 "output has duplicate axis labels".into(),
57 ));
58 }
59 }
60
61 for id in ic {
63 if !ia.contains(id) && !ib.contains(id) {
64 return Err(EinsumError::OrphanOutputAxis(format!("{:?}", id)));
65 }
66 }
67
68 let mut batch = Vec::new();
69 let mut lo = Vec::new();
70 let mut sum = Vec::new();
71 let mut left_trace = Vec::new();
72
73 for id in ia {
74 if ib.contains(id) {
75 if ic.contains(id) {
76 batch.push(id.clone());
77 } else {
78 sum.push(id.clone());
79 }
80 } else if ic.contains(id) {
81 lo.push(id.clone());
82 } else {
83 left_trace.push(id.clone());
84 }
85 }
86
87 let mut ro = Vec::new();
88 let mut right_trace = Vec::new();
89
90 for id in ib {
91 if !ia.contains(id) {
92 if ic.contains(id) {
93 ro.push(id.clone());
94 } else {
95 right_trace.push(id.clone());
96 }
97 }
98 }
99
100 let ia_after_trace: Vec<&ID> = ia.iter().filter(|id| !left_trace.contains(id)).collect();
103 let left_perm: Vec<usize> = lo
104 .iter()
105 .chain(sum.iter())
106 .chain(batch.iter())
107 .map(|id| {
108 ia_after_trace
109 .iter()
110 .position(|aid| *aid == id)
111 .expect("left_perm: axis not found")
112 })
113 .collect();
114
115 let ib_after_trace: Vec<&ID> = ib.iter().filter(|id| !right_trace.contains(id)).collect();
117 let right_perm: Vec<usize> = sum
118 .iter()
119 .chain(ro.iter())
120 .chain(batch.iter())
121 .map(|id| {
122 ib_after_trace
123 .iter()
124 .position(|bid| *bid == id)
125 .expect("right_perm: axis not found")
126 })
127 .collect();
128
129 let c_to_internal_perm: Vec<usize> = lo
131 .iter()
132 .chain(ro.iter())
133 .chain(batch.iter())
134 .map(|id| {
135 ic.iter()
136 .position(|c_id| c_id == id)
137 .expect("c_to_internal_perm: axis not found")
138 })
139 .collect();
140
141 Ok(Einsum2Plan {
142 batch,
143 lo,
144 ro,
145 sum,
146 left_perm,
147 right_perm,
148 c_to_internal_perm,
149 })
150 }
151
152 pub fn internal_to_c_perm(&self) -> Vec<usize> {
154 invert_perm(&self.c_to_internal_perm)
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plan over `u32` labels, panicking if the specification is rejected.
    fn plan_for(ia: &[u32], ib: &[u32], ic: &[u32]) -> Einsum2Plan<u32> {
        Einsum2Plan::new(ia, ib, ic).unwrap()
    }

    // Matrix product: label 1 is shared by both operands and absent from the output.
    #[test]
    fn test_classify_matmul() {
        let p = plan_for(&[0, 1], &[1, 2], &[0, 2]);
        assert_eq!(p.batch, Vec::<u32>::new());
        assert_eq!(p.lo, [0u32]);
        assert_eq!(p.ro, [2u32]);
        assert_eq!(p.sum, [1u32]);
    }

    // Batched matrix product: label 0 is present everywhere.
    #[test]
    fn test_classify_batched_matmul() {
        let p = plan_for(&[0, 1, 2], &[0, 2, 3], &[0, 1, 3]);
        assert_eq!(p.batch, [0u32]);
        assert_eq!(p.lo, [1u32]);
        assert_eq!(p.ro, [3u32]);
        assert_eq!(p.sum, [2u32]);
    }

    // Outer product: no shared labels at all.
    #[test]
    fn test_classify_outer_product() {
        let p = plan_for(&[0], &[1], &[0, 1]);
        assert!(p.batch.is_empty());
        assert_eq!(p.lo, [0u32]);
        assert_eq!(p.ro, [1u32]);
        assert!(p.sum.is_empty());
    }

    // Dot product: the single shared label is contracted, output has no axes.
    #[test]
    fn test_classify_dot_product() {
        let p = plan_for(&[0], &[0], &[]);
        assert!(p.batch.is_empty());
        assert!(p.lo.is_empty());
        assert!(p.ro.is_empty());
        assert_eq!(p.sum, [0u32]);
    }

    // Label 0 occurs only in the left operand and not in the output, so it is
    // not recorded in any of the four classification lists.
    #[test]
    fn test_classify_left_trace() {
        let p = plan_for(&[0, 1], &[1, 2], &[2]);
        assert!(p.batch.is_empty());
        assert!(p.lo.is_empty());
        assert_eq!(p.ro, [2u32]);
        assert_eq!(p.sum, [1u32]);
    }

    // With operands already in internal order every permutation is the identity.
    #[test]
    fn test_perm_matmul() {
        let p = plan_for(&[0, 1], &[1, 2], &[0, 2]);
        assert_eq!(p.left_perm, [0usize, 1]);
        assert_eq!(p.right_perm, [0usize, 1]);
        assert_eq!(p.c_to_internal_perm, [0usize, 1]);
    }

    // Batch label first in the operands and a transposed output.
    #[test]
    fn test_perm_batched_transposed_output() {
        let p = plan_for(&[0, 1, 2], &[0, 2, 3], &[0, 3, 1]);
        assert_eq!(p.batch, [0u32]);
        assert_eq!(p.lo, [1u32]);
        assert_eq!(p.ro, [3u32]);
        assert_eq!(p.sum, [2u32]);
        assert_eq!(p.left_perm, [1usize, 2, 0]);
        assert_eq!(p.right_perm, [1usize, 2, 0]);
        assert_eq!(p.c_to_internal_perm, [2usize, 1, 0]);
    }

    // Output label 2 appears in neither operand.
    #[test]
    fn test_error_orphan_output() {
        assert!(Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 1, 2]).is_err());
    }

    // Left operand repeats label 0.
    #[test]
    fn test_error_duplicate() {
        assert!(Einsum2Plan::new(&[0u32, 0], &[1u32], &[0u32, 1]).is_err());
    }

    // Labels need not be integers.
    #[test]
    fn test_char_labels() {
        let p = Einsum2Plan::new(&['i', 'j'], &['j', 'k'], &['i', 'k']).unwrap();
        assert_eq!(p.lo, ['i']);
        assert_eq!(p.ro, ['k']);
        assert_eq!(p.sum, ['j']);
    }
}