// strided_opteinsum/parse.rs

/// Parsed contraction tree node (structure only; no operand data).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EinsumNode {
    /// Leaf: the index labels of a single input tensor, together with that
    /// tensor's 0-based position among the operands. An empty `ids` denotes a
    /// scalar operand.
    Leaf { ids: Vec<char>, tensor_index: usize },
    /// Contraction of the child nodes, in order.
    Contract { args: Vec<EinsumNode> },
}

/// Parsed einsum code: contraction tree plus final output indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EinsumCode {
    /// Root of the contraction tree.
    pub root: EinsumNode,
    /// Index labels of the final output, in order (may be empty for a scalar
    /// result, and may repeat labels).
    pub output_ids: Vec<char>,
}

17/// Parse an einsum string like "(ij,jk),kl->il" into an EinsumCode.
18pub fn parse_einsum(s: &str) -> crate::Result<EinsumCode> {
19    // Strip all whitespace
20    let s: String = s.chars().filter(|c| !c.is_whitespace()).collect();
21
22    // Split on "->"
23    let arrow_pos = s
24        .find("->")
25        .ok_or_else(|| crate::EinsumError::ParseError("missing '->' in einsum string".into()))?;
26    let lhs = &s[..arrow_pos];
27    let rhs = &s[arrow_pos + 2..];
28
29    // Parse output indices
30    let output_ids: Vec<char> = rhs.chars().collect();
31    for &c in &output_ids {
32        if !c.is_alphabetic() {
33            return Err(crate::EinsumError::ParseError(format!(
34                "invalid character '{}' in output indices",
35                c
36            )));
37        }
38    }
39
40    // Parse LHS as args_list.
41    // Empty LHS (e.g. "->ii") is a single scalar operand with no indices.
42    // Leading/trailing commas (e.g. ",k->k") produce scalar operands too.
43    let mut counter: usize = 0;
44    let root = parse_args_list(lhs, &mut counter)?;
45
46    Ok(EinsumCode { root, output_ids })
47}
48
49/// Parse a comma-separated args list at the current level, returning a `Contract` node.
50///
51/// An empty string produces a single scalar Leaf (0-index operand).
52fn parse_args_list(s: &str, counter: &mut usize) -> crate::Result<EinsumNode> {
53    let parts = split_top_level(s)?;
54    if parts.is_empty() {
55        // Empty string (no commas, no chars) → single scalar operand
56        return parse_arg("", counter).map(|leaf| EinsumNode::Contract { args: vec![leaf] });
57    }
58    let mut args = Vec::with_capacity(parts.len());
59    for part in parts {
60        args.push(parse_arg(&part, counter)?);
61    }
62    // If there is exactly one arg and it is a Contract, unwrap it to avoid
63    // redundant nesting from outer parentheses like "((ij,jk),(kl,lm))".
64    if args.len() == 1 {
65        if let EinsumNode::Contract { .. } = &args[0] {
66            return Ok(args.into_iter().next().unwrap());
67        }
68    }
69    Ok(EinsumNode::Contract { args })
70}
71
72/// Parse a single arg: if wrapped in `(...)`, recursively parse inner as args_list;
73/// otherwise it's a leaf with index characters.
74fn parse_arg(s: &str, counter: &mut usize) -> crate::Result<EinsumNode> {
75    if s.starts_with('(') && s.ends_with(')') {
76        // Strip outer parens and recursively parse
77        let inner = &s[1..s.len() - 1];
78        parse_args_list(inner, counter)
79    } else {
80        // Empty string is a valid scalar operand (0-index tensor).
81        // Leaf: validate all chars are alphabetic (ASCII or Unicode letters)
82        for c in s.chars() {
83            if !c.is_alphabetic() {
84                return Err(crate::EinsumError::ParseError(format!(
85                    "invalid character '{}' in index labels",
86                    c
87                )));
88            }
89        }
90        let ids: Vec<char> = s.chars().collect();
91        let tensor_index = *counter;
92        *counter += 1;
93        Ok(EinsumNode::Leaf { ids, tensor_index })
94    }
95}
96
97/// Split a string by commas, respecting parenthesis nesting depth.
98fn split_top_level(s: &str) -> crate::Result<Vec<String>> {
99    let mut parts = Vec::new();
100    let mut depth = 0usize;
101    let mut current = String::new();
102    for c in s.chars() {
103        match c {
104            '(' => {
105                depth += 1;
106                current.push(c);
107            }
108            ')' => {
109                if depth == 0 {
110                    return Err(crate::EinsumError::ParseError("unbalanced ')'".into()));
111                }
112                depth -= 1;
113                current.push(c);
114            }
115            ',' if depth == 0 => {
116                // Empty `current` is valid: it represents a scalar (0-index) operand.
117                parts.push(std::mem::take(&mut current));
118            }
119            _ => {
120                current.push(c);
121            }
122        }
123    }
124    if depth != 0 {
125        return Err(crate::EinsumError::ParseError("unbalanced '('".into()));
126    }
127    if !current.is_empty() || !parts.is_empty() {
128        // Push final segment. Empty `current` after a comma is a valid scalar operand.
129        parts.push(current);
130    }
131    Ok(parts)
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaf node whose index labels are the characters of `ids` (e.g. `"ij"`).
    fn leaf(ids: &str, tensor_index: usize) -> EinsumNode {
        EinsumNode::Leaf {
            ids: ids.chars().collect(),
            tensor_index,
        }
    }

    /// Contraction node over `args`.
    fn contract(args: Vec<EinsumNode>) -> EinsumNode {
        EinsumNode::Contract { args }
    }

    /// Index labels of `s` as a `Vec<char>`.
    fn ids(s: &str) -> Vec<char> {
        s.chars().collect()
    }

    /// Assert that the parser rejects `s`.
    fn assert_rejected(s: &str) {
        assert!(parse_einsum(s).is_err(), "expected {:?} to be rejected", s);
    }

    #[test]
    fn test_parse_flat() {
        let code = parse_einsum("ij,jk->ik").unwrap();
        assert_eq!(code.output_ids, ids("ik"));
        assert_eq!(code.root, contract(vec![leaf("ij", 0), leaf("jk", 1)]));
    }

    #[test]
    fn test_parse_nested() {
        let code = parse_einsum("(ij,jk),kl->il").unwrap();
        assert_eq!(code.output_ids, ids("il"));
        assert_eq!(
            code.root,
            contract(vec![
                contract(vec![leaf("ij", 0), leaf("jk", 1)]),
                leaf("kl", 2),
            ])
        );
    }

    #[test]
    fn test_parse_deep_nested() {
        let code = parse_einsum("((ij,jk),(kl,lm))->im").unwrap();
        assert_eq!(code.output_ids, ids("im"));
        assert_eq!(
            code.root,
            contract(vec![
                contract(vec![leaf("ij", 0), leaf("jk", 1)]),
                contract(vec![leaf("kl", 2), leaf("lm", 3)]),
            ])
        );
    }

    #[test]
    fn test_parse_scalar_output() {
        let code = parse_einsum("ij,ji->").unwrap();
        assert_eq!(code.output_ids, ids(""));
        assert_eq!(code.root, contract(vec![leaf("ij", 0), leaf("ji", 1)]));
    }

    #[test]
    fn test_parse_single_tensor() {
        let code = parse_einsum("ijk->kji").unwrap();
        assert_eq!(code.output_ids, ids("kji"));
        assert_eq!(code.root, contract(vec![leaf("ijk", 0)]));
    }

    #[test]
    fn test_parse_three_flat() {
        let code = parse_einsum("ij,jk,kl->il").unwrap();
        assert_eq!(code.output_ids, ids("il"));
        assert_eq!(
            code.root,
            contract(vec![leaf("ij", 0), leaf("jk", 1), leaf("kl", 2)])
        );
    }

    #[test]
    fn test_parse_whitespace() {
        let code = parse_einsum(" (ij, jk) , kl -> il ").unwrap();
        assert_eq!(code.output_ids, ids("il"));
        // Whitespace is insignificant: same result as the compact spelling.
        assert_eq!(code, parse_einsum("(ij,jk),kl->il").unwrap());
    }

    #[test]
    fn test_parse_redundant_outer_parens() {
        assert_eq!(
            parse_einsum("(ijk)->kji").unwrap(),
            parse_einsum("ijk->kji").unwrap()
        );
        assert_eq!(
            parse_einsum("((ij,jk),kl)->il").unwrap(),
            parse_einsum("(ij,jk),kl->il").unwrap()
        );
    }

    #[test]
    fn test_parse_error_no_arrow() {
        assert_rejected("ij,jk");
    }

    #[test]
    fn test_parse_error_unbalanced_parens() {
        assert_rejected("(ij,jk->ik");
        assert_rejected("ij),jk->ik");
        assert_rejected("((ij,jk),kl->il");
    }

    #[test]
    fn test_parse_error_invalid_characters() {
        assert_rejected("i1,jk->ik");
        assert_rejected("ij,jk->i1");
        assert_rejected("ij,jk->ik->k");
    }

    #[test]
    fn test_parse_error_missing_comma_between_groups() {
        assert_rejected("(ij)(jk)->ik");
        assert_rejected("(ij)k->ik");
    }

    #[test]
    fn test_parse_scalar_operand_leading_comma() {
        // ",k->k" = scalar (tensor 0) + vector k (tensor 1)
        let code = parse_einsum(",k->k").unwrap();
        assert_eq!(code.output_ids, ids("k"));
        assert_eq!(code.root, contract(vec![leaf("", 0), leaf("k", 1)]));
    }

    #[test]
    fn test_parse_scalar_operand_trailing_comma() {
        // "i,->i" = vector i (tensor 0) + scalar (tensor 1)
        let code = parse_einsum("i,->i").unwrap();
        assert_eq!(code.output_ids, ids("i"));
        assert_eq!(code.root, contract(vec![leaf("i", 0), leaf("", 1)]));
    }

    #[test]
    fn test_parse_two_scalars() {
        // ",->": two scalar operands
        let code = parse_einsum(",->").unwrap();
        assert_eq!(code.output_ids, ids(""));
        assert_eq!(code.root, contract(vec![leaf("", 0), leaf("", 1)]));
    }

    #[test]
    fn test_parse_scalar_between_tensors() {
        // "ij,,jk->ik" = tensor 0 (ij) + scalar (tensor 1) + tensor 2 (jk)
        let code = parse_einsum("ij,,jk->ik").unwrap();
        assert_eq!(code.output_ids, ids("ik"));
        assert_eq!(
            code.root,
            contract(vec![leaf("ij", 0), leaf("", 1), leaf("jk", 2)])
        );
    }

    #[test]
    fn test_parse_empty_lhs_scalar() {
        // "->ii" = single scalar operand with generative output
        let code = parse_einsum("->ii").unwrap();
        assert_eq!(code.output_ids, ids("ii"));
        assert_eq!(code.root, contract(vec![leaf("", 0)]));
    }

    #[test]
    fn test_parse_unicode_greek() {
        let code = parse_einsum("αβ,βγ->αγ").unwrap();
        assert_eq!(code.output_ids, ids("αγ"));
        assert_eq!(code.root, contract(vec![leaf("αβ", 0), leaf("βγ", 1)]));
    }

    #[test]
    fn test_parse_unicode_mixed() {
        let code = parse_einsum("αi,iβ->αβ").unwrap();
        assert_eq!(code.output_ids, ids("αβ"));
        assert_eq!(code.root, contract(vec![leaf("αi", 0), leaf("iβ", 1)]));
    }

    #[test]
    fn test_parse_unicode_nested() {
        let code = parse_einsum("(αβ,βγ),γδ->αδ").unwrap();
        assert_eq!(code.output_ids, ids("αδ"));
        assert_eq!(
            code.root,
            contract(vec![
                contract(vec![leaf("αβ", 0), leaf("βγ", 1)]),
                leaf("γδ", 2),
            ])
        );
    }
}