/// One node of a parsed einsum expression tree.
///
/// `parse_einsum` emits a `Leaf` for every operand and a `Contract` for the
/// top-level operand list and for parenthesized groups (a group that is the sole
/// element of its list is returned as-is rather than wrapped a second time).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EinsumNode {
    /// A single input operand.
    Leaf {
        /// Index labels of the operand in the order written; empty for a scalar.
        ids: Vec<char>,
        /// Zero-based position of this operand among all operands, counted left to
        /// right across the whole expression regardless of nesting depth.
        tensor_index: usize,
    },
    /// A group of sub-expressions: the top-level operand list or one parenthesized
    /// group. Presumably its `args` are contracted together in one step.
    Contract {
        /// The grouped sub-expressions, in source order.
        args: Vec<EinsumNode>,
    },
}

/// A fully parsed einsum specification.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EinsumCode {
    /// Operand tree built from the text left of `->`.
    pub root: EinsumNode,
    /// Output index labels, taken verbatim from the text right of `->`.
    pub output_ids: Vec<char>,
}
16
17pub fn parse_einsum(s: &str) -> crate::Result<EinsumCode> {
19 let s: String = s.chars().filter(|c| !c.is_whitespace()).collect();
21
22 let arrow_pos = s
24 .find("->")
25 .ok_or_else(|| crate::EinsumError::ParseError("missing '->' in einsum string".into()))?;
26 let lhs = &s[..arrow_pos];
27 let rhs = &s[arrow_pos + 2..];
28
29 let output_ids: Vec<char> = rhs.chars().collect();
31 for &c in &output_ids {
32 if !c.is_alphabetic() {
33 return Err(crate::EinsumError::ParseError(format!(
34 "invalid character '{}' in output indices",
35 c
36 )));
37 }
38 }
39
40 let mut counter: usize = 0;
44 let root = parse_args_list(lhs, &mut counter)?;
45
46 Ok(EinsumCode { root, output_ids })
47}
48
49fn parse_args_list(s: &str, counter: &mut usize) -> crate::Result<EinsumNode> {
53 let parts = split_top_level(s)?;
54 if parts.is_empty() {
55 return parse_arg("", counter).map(|leaf| EinsumNode::Contract { args: vec![leaf] });
57 }
58 let mut args = Vec::with_capacity(parts.len());
59 for part in parts {
60 args.push(parse_arg(&part, counter)?);
61 }
62 if args.len() == 1 {
65 if let EinsumNode::Contract { .. } = &args[0] {
66 return Ok(args.into_iter().next().unwrap());
67 }
68 }
69 Ok(EinsumNode::Contract { args })
70}
71
72fn parse_arg(s: &str, counter: &mut usize) -> crate::Result<EinsumNode> {
75 if s.starts_with('(') && s.ends_with(')') {
76 let inner = &s[1..s.len() - 1];
78 parse_args_list(inner, counter)
79 } else {
80 for c in s.chars() {
83 if !c.is_alphabetic() {
84 return Err(crate::EinsumError::ParseError(format!(
85 "invalid character '{}' in index labels",
86 c
87 )));
88 }
89 }
90 let ids: Vec<char> = s.chars().collect();
91 let tensor_index = *counter;
92 *counter += 1;
93 Ok(EinsumNode::Leaf { ids, tensor_index })
94 }
95}
96
97fn split_top_level(s: &str) -> crate::Result<Vec<String>> {
99 let mut parts = Vec::new();
100 let mut depth = 0usize;
101 let mut current = String::new();
102 for c in s.chars() {
103 match c {
104 '(' => {
105 depth += 1;
106 current.push(c);
107 }
108 ')' => {
109 if depth == 0 {
110 return Err(crate::EinsumError::ParseError("unbalanced ')'".into()));
111 }
112 depth -= 1;
113 current.push(c);
114 }
115 ',' if depth == 0 => {
116 parts.push(std::mem::take(&mut current));
118 }
119 _ => {
120 current.push(c);
121 }
122 }
123 }
124 if depth != 0 {
125 return Err(crate::EinsumError::ParseError("unbalanced '('".into()));
126 }
127 if !current.is_empty() || !parts.is_empty() {
128 parts.push(current);
130 }
131 Ok(parts)
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a leaf whose labels are the characters of `ids`.
    fn leaf(ids: &str, tensor_index: usize) -> EinsumNode {
        EinsumNode::Leaf {
            ids: ids.chars().collect(),
            tensor_index,
        }
    }

    /// Builds a contraction node over `args`.
    fn contract(args: Vec<EinsumNode>) -> EinsumNode {
        EinsumNode::Contract { args }
    }

    /// Asserts that `s` is rejected with an `EinsumError::ParseError`.
    fn assert_parse_error(s: &str) {
        let result = parse_einsum(s);
        assert!(
            matches!(result, Err(crate::EinsumError::ParseError(_))),
            "expected ParseError for {:?}, got {:?}",
            s,
            result
        );
    }

    #[test]
    fn test_parse_flat() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        assert_eq!(code.output_ids, vec!['i', 'k']);
        assert_eq!(code.root, contract(vec![leaf("ij", 0), leaf("jk", 1)]));
    }

    #[test]
    fn test_parse_nested() {
        let code = parse_einsum("(ij,jk),kl->il").unwrap();
        assert_eq!(code.output_ids, vec!['i', 'l']);
        assert_eq!(
            code.root,
            contract(vec![
                contract(vec![leaf("ij", 0), leaf("jk", 1)]),
                leaf("kl", 2),
            ])
        );
    }

    #[test]
    fn test_parse_deep_nested() {
        let code = parse_einsum("((ij,jk),(kl,lm))->im").unwrap();
        assert_eq!(code.output_ids, vec!['i', 'm']);
        // The outermost parentheses wrap the whole list and are collapsed.
        assert_eq!(
            code.root,
            contract(vec![
                contract(vec![leaf("ij", 0), leaf("jk", 1)]),
                contract(vec![leaf("kl", 2), leaf("lm", 3)]),
            ])
        );
    }

    #[test]
    fn test_parse_redundant_parentheses_collapse() {
        let flat = parse_einsum("ij,jk->ik").unwrap();
        let wrapped = parse_einsum("((ij,jk))->ik").unwrap();
        assert_eq!(wrapped, flat);
    }

    #[test]
    fn test_parse_scalar_output() {
        let code = parse_einsum("ij,ji->").unwrap();
        assert!(code.output_ids.is_empty());
        assert_eq!(code.root, contract(vec![leaf("ij", 0), leaf("ji", 1)]));
    }

    #[test]
    fn test_parse_single_tensor() {
        let code = parse_einsum("ijk->kji").unwrap();
        assert_eq!(code.output_ids, vec!['k', 'j', 'i']);
        assert_eq!(code.root, contract(vec![leaf("ijk", 0)]));
    }

    #[test]
    fn test_parse_three_flat() {
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        assert_eq!(code.output_ids, vec!['i', 'l']);
        assert_eq!(
            code.root,
            contract(vec![leaf("ij", 0), leaf("jk", 1), leaf("kl", 2)])
        );
    }

    #[test]
    fn test_parse_whitespace() {
        let spaced = parse_einsum(" (ij, jk) , kl -> il ").unwrap();
        assert_eq!(spaced.output_ids, vec!['i', 'l']);
        // Whitespace is stripped before parsing, so the tree matches the compact form.
        assert_eq!(spaced, parse_einsum("(ij,jk),kl->il").unwrap());
    }

    #[test]
    fn test_parse_error_no_arrow() {
        assert_parse_error("ij,jk");
    }

    #[test]
    fn test_parse_error_malformed_parentheses() {
        assert_parse_error("(ij,jk->ik");
        assert_parse_error("ij),jk->ik");
        // Adjacent groups without a separating comma are rejected.
        assert_parse_error("(ij)(jk)->ik");
    }

    #[test]
    fn test_parse_error_invalid_characters() {
        assert_parse_error("ij,jk->i1");
        assert_parse_error("i1,jk->ik");
        // A second arrow lands in the output side and is not alphabetic.
        assert_parse_error("ij,jk->ik->k");
    }

    #[test]
    fn test_parse_scalar_operand_leading_comma() {
        // An empty operand before the comma is a scalar.
        let code = parse_einsum(",k->k").unwrap();
        assert_eq!(code.output_ids, vec!['k']);
        assert_eq!(code.root, contract(vec![leaf("", 0), leaf("k", 1)]));
    }

    #[test]
    fn test_parse_scalar_operand_trailing_comma() {
        // An empty operand after the comma is a scalar.
        let code = parse_einsum("i,->i").unwrap();
        assert_eq!(code.output_ids, vec!['i']);
        assert_eq!(code.root, contract(vec![leaf("i", 0), leaf("", 1)]));
    }

    #[test]
    fn test_parse_two_scalars() {
        let code = parse_einsum(",->").unwrap();
        assert!(code.output_ids.is_empty());
        assert_eq!(code.root, contract(vec![leaf("", 0), leaf("", 1)]));
    }

    #[test]
    fn test_parse_scalar_between_tensors() {
        let code = parse_einsum("ij,,jk->ik").unwrap();
        assert_eq!(code.output_ids, vec!['i', 'k']);
        assert_eq!(
            code.root,
            contract(vec![leaf("ij", 0), leaf("", 1), leaf("jk", 2)])
        );
    }

    #[test]
    fn test_parse_empty_lhs_scalar() {
        // An empty left-hand side is a single scalar operand.
        let code = parse_einsum("->ii").unwrap();
        assert_eq!(code.output_ids, vec!['i', 'i']);
        assert_eq!(code.root, contract(vec![leaf("", 0)]));
    }

    #[test]
    fn test_parse_unicode_greek() {
        let code = parse_einsum("αβ,βγ->αγ").unwrap();
        assert_eq!(code.output_ids, vec!['α', 'γ']);
        assert_eq!(code.root, contract(vec![leaf("αβ", 0), leaf("βγ", 1)]));
    }

    #[test]
    fn test_parse_unicode_mixed() {
        let code = parse_einsum("αi,iβ->αβ").unwrap();
        assert_eq!(code.output_ids, vec!['α', 'β']);
        assert_eq!(code.root, contract(vec![leaf("αi", 0), leaf("iβ", 1)]));
    }

    #[test]
    fn test_parse_unicode_nested() {
        let code = parse_einsum("(αβ,βγ),γδ->αδ").unwrap();
        assert_eq!(code.output_ids, vec!['α', 'δ']);
        assert_eq!(
            code.root,
            contract(vec![
                contract(vec![leaf("αβ", 0), leaf("βγ", 1)]),
                leaf("γδ", 2),
            ])
        );
    }
}