1use crate::util::invert_perm;
4use crate::AxisId;
5use crate::EinsumError;
6
/// Plan for a two-operand einsum with left operand labels `ia`, right operand
/// labels `ib` and output labels `ic`, as built by [`Einsum2Plan::new`].
///
/// Every label is placed in exactly one of the six label lists below,
/// according to which of `ia`, `ib` and `ic` contain it. The three
/// permutation vectors describe how to reorder each tensor into the
/// planner's internal axis order.
#[derive(Debug, Clone)]
pub struct Einsum2Plan<ID: AxisId> {
    /// Labels present in `ia`, `ib` and `ic`, in `ia` order.
    pub batch: Vec<ID>,
    /// Labels present in `ia` and `ic` but not in `ib`, in `ia` order.
    pub lo: Vec<ID>,
    /// Labels present in `ib` and `ic` but not in `ia`, in `ib` order.
    pub ro: Vec<ID>,
    /// Labels present in both `ia` and `ib` but not in the output `ic`,
    /// in `ia` order.
    pub sum: Vec<ID>,
    /// Labels present only in `ia` (in neither `ib` nor `ic`), in `ia` order.
    pub left_trace: Vec<ID>,
    /// Labels present only in `ib` (in neither `ia` nor `ic`), in `ib` order.
    pub right_trace: Vec<ID>,

    /// For each label of `lo ++ sum ++ batch` (in that order), its axis index
    /// in `ia` after the `left_trace` axes have been removed.
    pub left_perm: Vec<usize>,
    /// For each label of `sum ++ ro ++ batch` (in that order), its axis index
    /// in `ib` after the `right_trace` axes have been removed.
    pub right_perm: Vec<usize>,
    /// For each label of `lo ++ ro ++ batch` (in that order), its axis index
    /// in `ic`.
    pub c_to_internal_perm: Vec<usize>,
}
33
34impl<ID: AxisId> Einsum2Plan<ID> {
35 pub fn new(ia: &[ID], ib: &[ID], ic: &[ID]) -> Result<Self, EinsumError> {
42 for (i, id) in ia.iter().enumerate() {
44 if ia[..i].iter().any(|x| x == id) {
45 return Err(EinsumError::DuplicateAxis(
46 "left operand has duplicate axis labels".into(),
47 ));
48 }
49 }
50 for (i, id) in ib.iter().enumerate() {
51 if ib[..i].iter().any(|x| x == id) {
52 return Err(EinsumError::DuplicateAxis(
53 "right operand has duplicate axis labels".into(),
54 ));
55 }
56 }
57 for (i, id) in ic.iter().enumerate() {
58 if ic[..i].iter().any(|x| x == id) {
59 return Err(EinsumError::DuplicateAxis(
60 "output has duplicate axis labels".into(),
61 ));
62 }
63 }
64
65 for id in ic {
67 if !ia.contains(id) && !ib.contains(id) {
68 return Err(EinsumError::OrphanOutputAxis(format!("{:?}", id)));
69 }
70 }
71
72 let mut batch = Vec::new();
73 let mut lo = Vec::new();
74 let mut sum = Vec::new();
75 let mut left_trace = Vec::new();
76
77 for id in ia {
78 if ib.contains(id) {
79 if ic.contains(id) {
80 batch.push(id.clone());
81 } else {
82 sum.push(id.clone());
83 }
84 } else if ic.contains(id) {
85 lo.push(id.clone());
86 } else {
87 left_trace.push(id.clone());
88 }
89 }
90
91 let mut ro = Vec::new();
92 let mut right_trace = Vec::new();
93
94 for id in ib {
95 if !ia.contains(id) {
96 if ic.contains(id) {
97 ro.push(id.clone());
98 } else {
99 right_trace.push(id.clone());
100 }
101 }
102 }
103
104 let ia_after_trace: Vec<&ID> = ia.iter().filter(|id| !left_trace.contains(id)).collect();
107 let left_perm: Vec<usize> = lo
108 .iter()
109 .chain(sum.iter())
110 .chain(batch.iter())
111 .map(|id| {
112 ia_after_trace
113 .iter()
114 .position(|aid| *aid == id)
115 .expect("left_perm: axis not found")
116 })
117 .collect();
118
119 let ib_after_trace: Vec<&ID> = ib.iter().filter(|id| !right_trace.contains(id)).collect();
121 let right_perm: Vec<usize> = sum
122 .iter()
123 .chain(ro.iter())
124 .chain(batch.iter())
125 .map(|id| {
126 ib_after_trace
127 .iter()
128 .position(|bid| *bid == id)
129 .expect("right_perm: axis not found")
130 })
131 .collect();
132
133 let c_to_internal_perm: Vec<usize> = lo
135 .iter()
136 .chain(ro.iter())
137 .chain(batch.iter())
138 .map(|id| {
139 ic.iter()
140 .position(|c_id| c_id == id)
141 .expect("c_to_internal_perm: axis not found")
142 })
143 .collect();
144
145 Ok(Einsum2Plan {
146 batch,
147 lo,
148 ro,
149 sum,
150 left_trace,
151 right_trace,
152 left_perm,
153 right_perm,
154 c_to_internal_perm,
155 })
156 }
157
158 pub fn left_trace_indices(&self, ia: &[ID]) -> Vec<usize> {
160 self.left_trace
161 .iter()
162 .filter_map(|id| ia.iter().position(|x| x == id))
163 .collect()
164 }
165
166 pub fn right_trace_indices(&self, ib: &[ID]) -> Vec<usize> {
168 self.right_trace
169 .iter()
170 .filter_map(|id| ib.iter().position(|x| x == id))
171 .collect()
172 }
173
174 pub fn internal_to_c_perm(&self) -> Vec<usize> {
176 invert_perm(&self.c_to_internal_perm)
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_matmul() {
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[0u32, 2]).unwrap();
        assert_eq!(plan.batch, vec![] as Vec<u32>);
        assert_eq!(plan.lo, vec![0]);
        assert_eq!(plan.ro, vec![2]);
        assert_eq!(plan.sum, vec![1]);
        assert!(plan.left_trace.is_empty());
        assert!(plan.right_trace.is_empty());
    }

    #[test]
    fn test_classify_batched_matmul() {
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[0u32, 1, 3]).unwrap();
        assert_eq!(plan.batch, vec![0]);
        assert_eq!(plan.lo, vec![1]);
        assert_eq!(plan.ro, vec![3]);
        assert_eq!(plan.sum, vec![2]);
    }

    #[test]
    fn test_classify_outer_product() {
        let plan = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 1]).unwrap();
        assert!(plan.batch.is_empty());
        assert_eq!(plan.lo, vec![0]);
        assert_eq!(plan.ro, vec![1]);
        assert!(plan.sum.is_empty());
    }

    #[test]
    fn test_classify_dot_product() {
        let plan = Einsum2Plan::new(&[0u32], &[0u32], &[] as &[u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert!(plan.lo.is_empty());
        assert!(plan.ro.is_empty());
        assert_eq!(plan.sum, vec![0]);
    }

    #[test]
    fn test_classify_scalar_times_scalar() {
        let empty: &[u32] = &[];
        let plan = Einsum2Plan::new(empty, empty, empty).unwrap();
        assert!(plan.batch.is_empty());
        assert!(plan.lo.is_empty());
        assert!(plan.ro.is_empty());
        assert!(plan.sum.is_empty());
        assert!(plan.left_trace.is_empty());
        assert!(plan.right_trace.is_empty());
        assert!(plan.left_perm.is_empty());
        assert!(plan.right_perm.is_empty());
        assert!(plan.c_to_internal_perm.is_empty());
    }

    #[test]
    fn test_classify_left_trace() {
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[2u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert!(plan.lo.is_empty());
        assert_eq!(plan.ro, vec![2]);
        assert_eq!(plan.sum, vec![1]);
        assert_eq!(plan.left_trace, vec![0]);
    }

    #[test]
    fn test_classify_right_trace() {
        // Labels 2 and 3 occur only in the right operand.
        let ib = [1u32, 2, 3];
        let plan = Einsum2Plan::new(&[0u32, 1], &ib, &[0u32]).unwrap();
        assert!(plan.batch.is_empty());
        assert_eq!(plan.lo, vec![0]);
        assert!(plan.ro.is_empty());
        assert_eq!(plan.sum, vec![1]);
        assert!(plan.left_trace.is_empty());
        assert_eq!(plan.right_trace, vec![2, 3]);
        assert_eq!(plan.right_trace_indices(&ib), vec![1, 2]);
        // With the trace axes removed, the right operand is just [1].
        assert_eq!(plan.right_perm, vec![0]);
        assert_eq!(plan.left_perm, vec![0, 1]);
    }

    #[test]
    fn test_left_trace_indices_and_perm() {
        // Label 0 is a left-only trace axis sitting between sum (1) and lo (2).
        let ia = [1u32, 0, 2];
        let plan = Einsum2Plan::new(&ia, &[1u32], &[2u32]).unwrap();
        assert_eq!(plan.lo, vec![2]);
        assert_eq!(plan.sum, vec![1]);
        assert_eq!(plan.left_trace, vec![0]);
        assert_eq!(plan.left_trace_indices(&ia), vec![1]);
        // After removing label 0, `ia` is [1, 2]; internal order lo ++ sum = [2, 1].
        assert_eq!(plan.left_perm, vec![1, 0]);
    }

    #[test]
    fn test_perm_matmul() {
        let plan = Einsum2Plan::new(&[0u32, 1], &[1u32, 2], &[0u32, 2]).unwrap();
        // lo ++ sum = [0, 1] in ia = [0, 1]; sum ++ ro = [1, 2] in ib = [1, 2];
        // lo ++ ro = [0, 2] in ic = [0, 2].
        assert_eq!(plan.left_perm, vec![0, 1]);
        assert_eq!(plan.right_perm, vec![0, 1]);
        assert_eq!(plan.c_to_internal_perm, vec![0, 1]);
    }

    #[test]
    fn test_perm_batched_transposed_output() {
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[0u32, 3, 1]).unwrap();
        assert_eq!(plan.batch, vec![0]);
        assert_eq!(plan.lo, vec![1]);
        assert_eq!(plan.ro, vec![3]);
        assert_eq!(plan.sum, vec![2]);
        // lo ++ sum ++ batch = [1, 2, 0], located in ia = [0, 1, 2].
        assert_eq!(plan.left_perm, vec![1, 2, 0]);
        // sum ++ ro ++ batch = [2, 3, 0], located in ib = [0, 2, 3].
        assert_eq!(plan.right_perm, vec![1, 2, 0]);
        // lo ++ ro ++ batch = [1, 3, 0], located in ic = [0, 3, 1].
        assert_eq!(plan.c_to_internal_perm, vec![2, 1, 0]);
    }

    #[test]
    fn test_internal_to_c_perm_inverts_c_to_internal_perm() {
        let plan = Einsum2Plan::new(&[0u32, 1, 2], &[0u32, 2, 3], &[3u32, 0, 1]).unwrap();
        // lo ++ ro ++ batch = [1, 3, 0], located in ic = [3, 0, 1]: a 3-cycle,
        // so a buggy "inverse" that returned its input would be caught.
        assert_eq!(plan.c_to_internal_perm, vec![2, 0, 1]);
        let inv = plan.internal_to_c_perm();
        assert_eq!(inv, vec![1, 2, 0]);
        for (i, &p) in plan.c_to_internal_perm.iter().enumerate() {
            assert_eq!(inv[p], i);
        }
    }

    #[test]
    fn test_error_orphan_output() {
        let result = Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 1, 2]);
        assert!(matches!(result, Err(EinsumError::OrphanOutputAxis(_))));
    }

    #[test]
    fn test_error_duplicate() {
        let result = Einsum2Plan::new(&[0u32, 0], &[1u32], &[0u32, 1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_error_duplicate_in_each_position() {
        assert!(matches!(
            Einsum2Plan::new(&[0u32, 0], &[1u32], &[0u32, 1]),
            Err(EinsumError::DuplicateAxis(_))
        ));
        assert!(matches!(
            Einsum2Plan::new(&[0u32], &[1u32, 1], &[0u32, 1]),
            Err(EinsumError::DuplicateAxis(_))
        ));
        assert!(matches!(
            Einsum2Plan::new(&[0u32], &[1u32], &[0u32, 0]),
            Err(EinsumError::DuplicateAxis(_))
        ));
        // Duplicate detection runs before the orphan check.
        assert!(matches!(
            Einsum2Plan::new(&[0u32], &[1u32], &[7u32, 7]),
            Err(EinsumError::DuplicateAxis(_))
        ));
    }

    #[test]
    fn test_char_labels() {
        let plan = Einsum2Plan::new(&['i', 'j'], &['j', 'k'], &['i', 'k']).unwrap();
        assert_eq!(plan.lo, vec!['i']);
        assert_eq!(plan.ro, vec!['k']);
        assert_eq!(plan.sum, vec!['j']);
    }
}