// tenferro_einsum/syntax/nested.rs

1use std::collections::HashSet;
2
3use tenferro_device::{Error, Result};
4
5use crate::syntax::notation::{char_to_label, split_and_validate_notation};
6use crate::syntax::subscripts::Subscripts;
7
/// Recursive einsum tree that preserves parenthesized grouping.
///
/// `NestedEinsum` mirrors OMEinsum.jl's `NestedEinsum`: each internal node
/// holds [`Subscripts`] describing how its children are contracted, and leaf
/// nodes reference an original input tensor by index.
///
/// Leaf indices are assigned left to right in the order the operands appear
/// in the notation, so `"(ij,jk),kl->il"` yields leaves `0`, `1` and `2`.
///
/// # Construction
///
/// Use [`NestedEinsum::parse`] to build a tree from parenthesized string
/// notation such as `"(ij,jk),kl->il"`.  Without parentheses the result is
/// a flat root node whose children are all leaves.
///
/// # Examples
///
/// ```
/// use tenferro_einsum::NestedEinsum;
///
/// // Flat (no grouping): root with two leaves
/// let flat = NestedEinsum::parse("ij,jk->ik").unwrap();
/// assert!(matches!(flat, NestedEinsum::Node { .. }));
///
/// // Grouped: contract first two operands, then with third
/// let grouped = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
/// assert!(matches!(grouped, NestedEinsum::Node { .. }));
/// ```
#[derive(Debug, Clone)]
pub enum NestedEinsum {
    /// A leaf referencing one of the original input tensors by its
    /// left-to-right position among all operands in the notation.
    Leaf(usize),
    /// An internal node that contracts its children according to `subscripts`.
    Node {
        /// The subscripts for this contraction: one input label list per
        /// child (in the same order as `children`), plus this node's output.
        ///
        /// For the root the output is the expression's final output; for a
        /// parenthesized sub-group it is the set of its labels still needed by
        /// its siblings or enclosing group, sorted ascending.
        subscripts: Subscripts,
        /// Child sub-expressions (leaves or further nodes).
        children: Vec<NestedEinsum>,
    },
}
45
46impl NestedEinsum {
47    /// Count the total number of leaf operands in the tree.
48    pub fn count_leaves(&self) -> usize {
49        match self {
50            Self::Leaf(_) => 1,
51            Self::Node { children, .. } => children.iter().map(|c| c.count_leaves()).sum(),
52        }
53    }
54
55    /// Parse parenthesized einsum notation into a recursive tree.
56    ///
57    /// Notation follows the standard `"inputs->output"` format with optional
58    /// parentheses to specify contraction order. Each parenthesized group
59    /// becomes an internal [`NestedEinsum::Node`]; bare operands become
60    /// [`NestedEinsum::Leaf`] nodes.
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use tenferro_einsum::NestedEinsum;
66    ///
67    /// let nested = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
68    /// // Root has two children: a group node and a leaf
69    /// match &nested {
70    ///     NestedEinsum::Node { children, .. } => assert_eq!(children.len(), 2),
71    ///     _ => panic!("expected Node"),
72    /// }
73    /// ```
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if parentheses are mismatched or the notation is
78    /// otherwise malformed.
79    pub fn parse(notation: &str) -> Result<Self> {
80        let (lhs, output_str) = split_and_validate_notation(notation)?;
81
82        let output: Vec<u32> = output_str
83            .chars()
84            .map(char_to_label)
85            .collect::<Result<_>>()?;
86
87        let mut leaf_counter: usize = 0;
88        let outer_needed: HashSet<u32> = output.iter().copied().collect();
89        Self::parse_group(lhs, &outer_needed, &output, &mut leaf_counter)
90    }
91
92    /// Recursively parse a group (possibly containing sub-groups) into a Node.
93    ///
94    /// `group_str` is a comma-separated list of items (at the top level),
95    /// where each item is either a bare operand (e.g. `"ij"`) or a
96    /// parenthesized sub-group (e.g. `"(ij,jk)"`).
97    ///
98    /// `outer_needed` contains labels that the parent or siblings need from
99    /// this group.  `final_output` is the overall output of the entire
100    /// expression.
101    fn parse_group(
102        group_str: &str,
103        outer_needed: &HashSet<u32>,
104        final_output: &[u32],
105        leaf_counter: &mut usize,
106    ) -> Result<Self> {
107        let items = Self::split_top_level(group_str)?;
108
109        let mut children = Vec::with_capacity(items.len());
110        let mut child_subscript_inputs: Vec<Vec<u32>> = Vec::with_capacity(items.len());
111
112        for (idx, item) in items.iter().enumerate() {
113            if item.starts_with('(') && item.ends_with(')') {
114                // Sub-group: strip outer parens and recurse
115                let inner = &item[1..item.len() - 1];
116
117                // Compute what this sub-group needs to output:
118                // labels in this group that appear in outer_needed or in sibling items
119                let group_labels = Self::collect_labels(inner)?;
120                let sibling_labels = Self::collect_sibling_labels(&items, idx)?;
121                let mut needed: HashSet<u32> = HashSet::new();
122                for &label in &group_labels {
123                    if outer_needed.contains(&label) || sibling_labels.contains(&label) {
124                        needed.insert(label);
125                    }
126                }
127                let mut sub_output: Vec<u32> = needed.iter().copied().collect();
128                sub_output.sort();
129
130                let child = Self::parse_group(inner, &needed, &sub_output, leaf_counter)?;
131                child_subscript_inputs.push(sub_output);
132                children.push(child);
133            } else {
134                // Bare operand -> Leaf
135                let labels: Vec<u32> = item.chars().map(char_to_label).collect::<Result<_>>()?;
136                child_subscript_inputs.push(labels);
137                children.push(NestedEinsum::Leaf(*leaf_counter));
138                *leaf_counter += 1;
139            }
140        }
141
142        // Build subscripts for this node
143        let node_output: Vec<u32> = final_output.to_vec();
144        let subscripts = Subscripts {
145            inputs: child_subscript_inputs,
146            output: node_output,
147        };
148
149        Ok(NestedEinsum::Node {
150            subscripts,
151            children,
152        })
153    }
154
155    /// Split a string on commas at the top level (depth 0), respecting parentheses.
156    fn split_top_level(s: &str) -> Result<Vec<&str>> {
157        let mut items = Vec::new();
158        let mut depth: usize = 0;
159        let mut start = 0;
160
161        for (pos, c) in s.char_indices() {
162            match c {
163                '(' => depth += 1,
164                ')' => {
165                    if depth == 0 {
166                        return Err(Error::InvalidArgument(format!(
167                            "unmatched ')' in einsum group: {s}"
168                        )));
169                    }
170                    depth -= 1;
171                }
172                ',' if depth == 0 => {
173                    items.push(&s[start..pos]);
174                    start = pos + 1; // skip the comma
175                }
176                _ => {}
177            }
178        }
179        // Push the last item
180        items.push(&s[start..]);
181        Ok(items)
182    }
183
184    /// Collect all unique labels from a (possibly nested) string, ignoring
185    /// parentheses and commas.
186    fn collect_labels(s: &str) -> Result<HashSet<u32>> {
187        let mut labels = HashSet::new();
188        for c in s.chars() {
189            match c {
190                '(' | ')' | ',' => continue,
191                _ => {
192                    labels.insert(char_to_label(c)?);
193                }
194            }
195        }
196        Ok(labels)
197    }
198
199    /// Collect all labels from sibling items (all items except the one at `current_idx`).
200    fn collect_sibling_labels(items: &[&str], current_idx: usize) -> Result<HashSet<u32>> {
201        let mut labels = HashSet::new();
202        for (idx, item) in items.iter().enumerate() {
203            if idx == current_idx {
204                continue;
205            }
206            for label in Self::collect_labels(item)? {
207                labels.insert(label);
208            }
209        }
210        Ok(labels)
211    }
212}