Skip to main content

tenferro_einsum/syntax/
nested.rs

1use std::collections::HashSet;
2
3use crate::syntax::notation::{char_to_label, split_and_validate_notation};
4use crate::syntax::subscripts::Subscripts;
5use crate::{Error, Result};
6
/// Recursive einsum tree that preserves parenthesized grouping.
///
/// `NestedEinsum` mirrors OMEinsum.jl's `NestedEinsum`: each internal node
/// holds [`Subscripts`] describing how its children are contracted, and leaf
/// nodes reference an original input tensor by index.
///
/// Leaf indices are assigned left to right, in the order the bare operands
/// appear in the notation, regardless of how deeply they are nested.
///
/// # Construction
///
/// Use [`NestedEinsum::parse`] to build a tree from parenthesized string
/// notation such as `"(ij,jk),kl->il"`. Without parentheses the result is
/// a flat root node whose children are all leaves.
///
/// # Examples
///
/// ```
/// use tenferro_einsum::NestedEinsum;
///
/// // Flat (no grouping): root with two leaves
/// let flat = NestedEinsum::parse("ij,jk->ik").unwrap();
/// assert!(matches!(flat, NestedEinsum::Node { .. }));
/// assert_eq!(flat.count_leaves(), 2);
///
/// // Grouped: contract first two operands, then with third
/// let grouped = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
/// assert!(matches!(grouped, NestedEinsum::Node { .. }));
/// assert_eq!(grouped.count_leaves(), 3);
/// ```
#[derive(Debug, Clone)]
pub enum NestedEinsum {
    /// A leaf referencing one of the original input tensors by index
    /// (0-based, numbered left to right across the whole expression).
    Leaf(usize),
    /// An internal node that contracts its children according to `subscripts`.
    Node {
        /// The subscripts for this contraction: one input per child, plus output.
        ///
        /// For the root node the output is the expression's requested output;
        /// for a parenthesized group it is the group's labels that are still
        /// needed outside the group (by its siblings or its enclosing context).
        subscripts: Subscripts,
        /// Child sub-expressions (leaves or further nodes), in source order.
        children: Vec<NestedEinsum>,
    },
}
44
45impl NestedEinsum {
46    /// Count the total number of leaf operands in the tree.
47    pub fn count_leaves(&self) -> usize {
48        match self {
49            Self::Leaf(_) => 1,
50            Self::Node { children, .. } => children.iter().map(|c| c.count_leaves()).sum(),
51        }
52    }
53
54    /// Parse parenthesized einsum notation into a recursive tree.
55    ///
56    /// Notation follows the standard `"inputs->output"` format with optional
57    /// parentheses to specify contraction order. Each parenthesized group
58    /// becomes an internal [`NestedEinsum::Node`]; bare operands become
59    /// [`NestedEinsum::Leaf`] nodes.
60    ///
61    /// # Examples
62    ///
63    /// ```
64    /// use tenferro_einsum::NestedEinsum;
65    ///
66    /// let nested = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
67    /// // Root has two children: a group node and a leaf
68    /// match &nested {
69    ///     NestedEinsum::Node { children, .. } => assert_eq!(children.len(), 2),
70    ///     _ => panic!("expected Node"),
71    /// }
72    /// ```
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if parentheses are mismatched or the notation is
77    /// otherwise malformed.
78    pub fn parse(notation: &str) -> Result<Self> {
79        let (lhs, output_str) = split_and_validate_notation(notation)?;
80
81        let output: Vec<u32> = output_str
82            .chars()
83            .map(char_to_label)
84            .collect::<Result<_>>()?;
85
86        let mut leaf_counter: usize = 0;
87        let outer_needed: HashSet<u32> = output.iter().copied().collect();
88        Self::parse_group(lhs, &outer_needed, &output, &mut leaf_counter)
89    }
90
91    /// Recursively parse a group (possibly containing sub-groups) into a Node.
92    ///
93    /// `group_str` is a comma-separated list of items (at the top level),
94    /// where each item is either a bare operand (e.g. `"ij"`) or a
95    /// parenthesized sub-group (e.g. `"(ij,jk)"`).
96    ///
97    /// `outer_needed` contains labels that the parent or siblings need from
98    /// this group.  `final_output` is the overall output of the entire
99    /// expression.
100    fn parse_group(
101        group_str: &str,
102        outer_needed: &HashSet<u32>,
103        final_output: &[u32],
104        leaf_counter: &mut usize,
105    ) -> Result<Self> {
106        let items = Self::split_top_level(group_str)?;
107
108        let mut children = Vec::with_capacity(items.len());
109        let mut child_subscript_inputs: Vec<Vec<u32>> = Vec::with_capacity(items.len());
110
111        for (idx, item) in items.iter().enumerate() {
112            if item.starts_with('(') && item.ends_with(')') {
113                // Sub-group: strip outer parens and recurse
114                let inner = &item[1..item.len() - 1];
115
116                // Compute what this sub-group needs to output:
117                // labels in this group that appear in outer_needed or in sibling items
118                let group_labels = Self::collect_labels_in_order(inner)?;
119                let sibling_labels = Self::collect_sibling_labels(&items, idx)?;
120                let mut needed: HashSet<u32> = HashSet::new();
121                let mut sub_output = Vec::new();
122                for label in group_labels {
123                    if outer_needed.contains(&label) || sibling_labels.contains(&label) {
124                        needed.insert(label);
125                        sub_output.push(label);
126                    }
127                }
128
129                let child = Self::parse_group(inner, &needed, &sub_output, leaf_counter)?;
130                child_subscript_inputs.push(sub_output);
131                children.push(child);
132            } else {
133                // Bare operand -> Leaf
134                let labels: Vec<u32> = item.chars().map(char_to_label).collect::<Result<_>>()?;
135                child_subscript_inputs.push(labels);
136                children.push(NestedEinsum::Leaf(*leaf_counter));
137                *leaf_counter += 1;
138            }
139        }
140
141        // Build subscripts for this node
142        let node_output: Vec<u32> = final_output.to_vec();
143        let subscripts = Subscripts {
144            inputs: child_subscript_inputs,
145            output: node_output,
146        };
147
148        Ok(NestedEinsum::Node {
149            subscripts,
150            children,
151        })
152    }
153
154    /// Split a string on commas at the top level (depth 0), respecting parentheses.
155    fn split_top_level(s: &str) -> Result<Vec<&str>> {
156        let mut items = Vec::new();
157        let mut depth: usize = 0;
158        let mut start = 0;
159
160        for (pos, c) in s.char_indices() {
161            match c {
162                '(' => depth += 1,
163                ')' => {
164                    if depth == 0 {
165                        return Err(Error::InvalidArgument(format!(
166                            "unmatched ')' in einsum group: {s}"
167                        )));
168                    }
169                    depth -= 1;
170                }
171                ',' if depth == 0 => {
172                    items.push(&s[start..pos]);
173                    start = pos + 1; // skip the comma
174                }
175                _ => {}
176            }
177        }
178        // Push the last item
179        items.push(&s[start..]);
180        Ok(items)
181    }
182
183    /// Collect all unique labels from a (possibly nested) string, ignoring
184    /// parentheses and commas.
185    fn collect_labels(s: &str) -> Result<HashSet<u32>> {
186        let mut labels = HashSet::new();
187        for c in s.chars() {
188            match c {
189                '(' | ')' | ',' => continue,
190                _ => {
191                    labels.insert(char_to_label(c)?);
192                }
193            }
194        }
195        Ok(labels)
196    }
197
198    /// Collect unique labels in first-appearance order, ignoring
199    /// parentheses and commas.
200    fn collect_labels_in_order(s: &str) -> Result<Vec<u32>> {
201        let mut seen = HashSet::new();
202        let mut labels = Vec::new();
203        for c in s.chars() {
204            match c {
205                '(' | ')' | ',' => continue,
206                _ => {
207                    let label = char_to_label(c)?;
208                    if seen.insert(label) {
209                        labels.push(label);
210                    }
211                }
212            }
213        }
214        Ok(labels)
215    }
216
217    /// Collect all labels from sibling items (all items except the one at `current_idx`).
218    fn collect_sibling_labels(items: &[&str], current_idx: usize) -> Result<HashSet<u32>> {
219        let mut labels = HashSet::new();
220        for (idx, item) in items.iter().enumerate() {
221            if idx == current_idx {
222                continue;
223            }
224            for label in Self::collect_labels(item)? {
225                labels.insert(label);
226            }
227        }
228        Ok(labels)
229    }
230}
231
232#[cfg(test)]
233mod tests;