tenferro_einsum/syntax/nested.rs
1use std::collections::HashSet;
2
3use crate::syntax::notation::{char_to_label, split_and_validate_notation};
4use crate::syntax::subscripts::Subscripts;
5use crate::{Error, Result};
6
/// Recursive einsum tree that preserves parenthesized grouping.
///
/// `NestedEinsum` mirrors OMEinsum.jl's `NestedEinsum`: each internal node
/// holds [`Subscripts`] describing how its children are contracted, and leaf
/// nodes reference an original input tensor by index.
///
/// # Construction
///
/// Use [`NestedEinsum::parse`] to build a tree from parenthesized string
/// notation such as `"(ij,jk),kl->il"`. Without parentheses the result is
/// a flat root node whose children are all leaves. Leaf indices are assigned
/// in the left-to-right order in which bare operands appear in the notation.
///
/// # Examples
///
/// ```
/// use tenferro_einsum::NestedEinsum;
///
/// // Flat (no grouping): root with two leaves
/// let flat = NestedEinsum::parse("ij,jk->ik").unwrap();
/// assert!(matches!(flat, NestedEinsum::Node { .. }));
/// assert_eq!(flat.count_leaves(), 2);
///
/// // Grouped: contract first two operands, then with third
/// let grouped = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
/// assert!(matches!(grouped, NestedEinsum::Node { .. }));
/// assert_eq!(grouped.count_leaves(), 3);
/// ```
#[derive(Debug, Clone)]
pub enum NestedEinsum {
    /// A leaf referencing one of the original input tensors by index.
    ///
    /// When built by [`NestedEinsum::parse`], the index is the operand's
    /// zero-based position among the inputs of the notation.
    Leaf(usize),
    /// An internal node that contracts its children according to `subscripts`.
    Node {
        /// The subscripts for this contraction: one input per child, plus output.
        ///
        /// `subscripts.inputs[i]` holds the labels of `children[i]`.
        subscripts: Subscripts,
        /// Child sub-expressions (leaves or further nodes).
        children: Vec<NestedEinsum>,
    },
}
43}
44
45impl NestedEinsum {
46 /// Count the total number of leaf operands in the tree.
47 pub fn count_leaves(&self) -> usize {
48 match self {
49 Self::Leaf(_) => 1,
50 Self::Node { children, .. } => children.iter().map(|c| c.count_leaves()).sum(),
51 }
52 }
53
54 /// Parse parenthesized einsum notation into a recursive tree.
55 ///
56 /// Notation follows the standard `"inputs->output"` format with optional
57 /// parentheses to specify contraction order. Each parenthesized group
58 /// becomes an internal [`NestedEinsum::Node`]; bare operands become
59 /// [`NestedEinsum::Leaf`] nodes.
60 ///
61 /// # Examples
62 ///
63 /// ```
64 /// use tenferro_einsum::NestedEinsum;
65 ///
66 /// let nested = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
67 /// // Root has two children: a group node and a leaf
68 /// match &nested {
69 /// NestedEinsum::Node { children, .. } => assert_eq!(children.len(), 2),
70 /// _ => panic!("expected Node"),
71 /// }
72 /// ```
73 ///
74 /// # Errors
75 ///
76 /// Returns an error if parentheses are mismatched or the notation is
77 /// otherwise malformed.
78 pub fn parse(notation: &str) -> Result<Self> {
79 let (lhs, output_str) = split_and_validate_notation(notation)?;
80
81 let output: Vec<u32> = output_str
82 .chars()
83 .map(char_to_label)
84 .collect::<Result<_>>()?;
85
86 let mut leaf_counter: usize = 0;
87 let outer_needed: HashSet<u32> = output.iter().copied().collect();
88 Self::parse_group(lhs, &outer_needed, &output, &mut leaf_counter)
89 }
90
91 /// Recursively parse a group (possibly containing sub-groups) into a Node.
92 ///
93 /// `group_str` is a comma-separated list of items (at the top level),
94 /// where each item is either a bare operand (e.g. `"ij"`) or a
95 /// parenthesized sub-group (e.g. `"(ij,jk)"`).
96 ///
97 /// `outer_needed` contains labels that the parent or siblings need from
98 /// this group. `final_output` is the overall output of the entire
99 /// expression.
100 fn parse_group(
101 group_str: &str,
102 outer_needed: &HashSet<u32>,
103 final_output: &[u32],
104 leaf_counter: &mut usize,
105 ) -> Result<Self> {
106 let items = Self::split_top_level(group_str)?;
107
108 let mut children = Vec::with_capacity(items.len());
109 let mut child_subscript_inputs: Vec<Vec<u32>> = Vec::with_capacity(items.len());
110
111 for (idx, item) in items.iter().enumerate() {
112 if item.starts_with('(') && item.ends_with(')') {
113 // Sub-group: strip outer parens and recurse
114 let inner = &item[1..item.len() - 1];
115
116 // Compute what this sub-group needs to output:
117 // labels in this group that appear in outer_needed or in sibling items
118 let group_labels = Self::collect_labels_in_order(inner)?;
119 let sibling_labels = Self::collect_sibling_labels(&items, idx)?;
120 let mut needed: HashSet<u32> = HashSet::new();
121 let mut sub_output = Vec::new();
122 for label in group_labels {
123 if outer_needed.contains(&label) || sibling_labels.contains(&label) {
124 needed.insert(label);
125 sub_output.push(label);
126 }
127 }
128
129 let child = Self::parse_group(inner, &needed, &sub_output, leaf_counter)?;
130 child_subscript_inputs.push(sub_output);
131 children.push(child);
132 } else {
133 // Bare operand -> Leaf
134 let labels: Vec<u32> = item.chars().map(char_to_label).collect::<Result<_>>()?;
135 child_subscript_inputs.push(labels);
136 children.push(NestedEinsum::Leaf(*leaf_counter));
137 *leaf_counter += 1;
138 }
139 }
140
141 // Build subscripts for this node
142 let node_output: Vec<u32> = final_output.to_vec();
143 let subscripts = Subscripts {
144 inputs: child_subscript_inputs,
145 output: node_output,
146 };
147
148 Ok(NestedEinsum::Node {
149 subscripts,
150 children,
151 })
152 }
153
154 /// Split a string on commas at the top level (depth 0), respecting parentheses.
155 fn split_top_level(s: &str) -> Result<Vec<&str>> {
156 let mut items = Vec::new();
157 let mut depth: usize = 0;
158 let mut start = 0;
159
160 for (pos, c) in s.char_indices() {
161 match c {
162 '(' => depth += 1,
163 ')' => {
164 if depth == 0 {
165 return Err(Error::InvalidArgument(format!(
166 "unmatched ')' in einsum group: {s}"
167 )));
168 }
169 depth -= 1;
170 }
171 ',' if depth == 0 => {
172 items.push(&s[start..pos]);
173 start = pos + 1; // skip the comma
174 }
175 _ => {}
176 }
177 }
178 // Push the last item
179 items.push(&s[start..]);
180 Ok(items)
181 }
182
183 /// Collect all unique labels from a (possibly nested) string, ignoring
184 /// parentheses and commas.
185 fn collect_labels(s: &str) -> Result<HashSet<u32>> {
186 let mut labels = HashSet::new();
187 for c in s.chars() {
188 match c {
189 '(' | ')' | ',' => continue,
190 _ => {
191 labels.insert(char_to_label(c)?);
192 }
193 }
194 }
195 Ok(labels)
196 }
197
198 /// Collect unique labels in first-appearance order, ignoring
199 /// parentheses and commas.
200 fn collect_labels_in_order(s: &str) -> Result<Vec<u32>> {
201 let mut seen = HashSet::new();
202 let mut labels = Vec::new();
203 for c in s.chars() {
204 match c {
205 '(' | ')' | ',' => continue,
206 _ => {
207 let label = char_to_label(c)?;
208 if seen.insert(label) {
209 labels.push(label);
210 }
211 }
212 }
213 }
214 Ok(labels)
215 }
216
217 /// Collect all labels from sibling items (all items except the one at `current_idx`).
218 fn collect_sibling_labels(items: &[&str], current_idx: usize) -> Result<HashSet<u32>> {
219 let mut labels = HashSet::new();
220 for (idx, item) in items.iter().enumerate() {
221 if idx == current_idx {
222 continue;
223 }
224 for label in Self::collect_labels(item)? {
225 labels.insert(label);
226 }
227 }
228 Ok(labels)
229 }
230}
231
232#[cfg(test)]
233mod tests;