tenferro_einsum/syntax/nested.rs
1use std::collections::HashSet;
2
3use tenferro_device::{Error, Result};
4
5use crate::syntax::notation::{char_to_label, split_and_validate_notation};
6use crate::syntax::subscripts::Subscripts;
7
8/// Recursive einsum tree that preserves parenthesized grouping.
9///
10/// `NestedEinsum` mirrors OMEinsum.jl's `NestedEinsum`: each internal node
11/// holds [`Subscripts`] describing how its children are contracted, and leaf
12/// nodes reference an original input tensor by index.
13///
14/// # Construction
15///
16/// Use [`NestedEinsum::parse`] to build a tree from parenthesized string
17/// notation such as `"(ij,jk),kl->il"`. Without parentheses the result is
18/// a flat root node whose children are all leaves.
19///
20/// # Examples
21///
22/// ```
23/// use tenferro_einsum::NestedEinsum;
24///
25/// // Flat (no grouping): root with two leaves
26/// let flat = NestedEinsum::parse("ij,jk->ik").unwrap();
27/// assert!(matches!(flat, NestedEinsum::Node { .. }));
28///
29/// // Grouped: contract first two operands, then with third
30/// let grouped = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
31/// assert!(matches!(grouped, NestedEinsum::Node { .. }));
32/// ```
33#[derive(Debug, Clone)]
34pub enum NestedEinsum {
35 /// A leaf referencing one of the original input tensors by index.
36 Leaf(usize),
37 /// An internal node that contracts its children according to `subscripts`.
38 Node {
39 /// The subscripts for this contraction: one input per child, plus output.
40 subscripts: Subscripts,
41 /// Child sub-expressions (leaves or further nodes).
42 children: Vec<NestedEinsum>,
43 },
44}
45
46impl NestedEinsum {
47 /// Count the total number of leaf operands in the tree.
48 pub fn count_leaves(&self) -> usize {
49 match self {
50 Self::Leaf(_) => 1,
51 Self::Node { children, .. } => children.iter().map(|c| c.count_leaves()).sum(),
52 }
53 }
54
55 /// Parse parenthesized einsum notation into a recursive tree.
56 ///
57 /// Notation follows the standard `"inputs->output"` format with optional
58 /// parentheses to specify contraction order. Each parenthesized group
59 /// becomes an internal [`NestedEinsum::Node`]; bare operands become
60 /// [`NestedEinsum::Leaf`] nodes.
61 ///
62 /// # Examples
63 ///
64 /// ```
65 /// use tenferro_einsum::NestedEinsum;
66 ///
67 /// let nested = NestedEinsum::parse("(ij,jk),kl->il").unwrap();
68 /// // Root has two children: a group node and a leaf
69 /// match &nested {
70 /// NestedEinsum::Node { children, .. } => assert_eq!(children.len(), 2),
71 /// _ => panic!("expected Node"),
72 /// }
73 /// ```
74 ///
75 /// # Errors
76 ///
77 /// Returns an error if parentheses are mismatched or the notation is
78 /// otherwise malformed.
79 pub fn parse(notation: &str) -> Result<Self> {
80 let (lhs, output_str) = split_and_validate_notation(notation)?;
81
82 let output: Vec<u32> = output_str
83 .chars()
84 .map(char_to_label)
85 .collect::<Result<_>>()?;
86
87 let mut leaf_counter: usize = 0;
88 let outer_needed: HashSet<u32> = output.iter().copied().collect();
89 Self::parse_group(lhs, &outer_needed, &output, &mut leaf_counter)
90 }
91
92 /// Recursively parse a group (possibly containing sub-groups) into a Node.
93 ///
94 /// `group_str` is a comma-separated list of items (at the top level),
95 /// where each item is either a bare operand (e.g. `"ij"`) or a
96 /// parenthesized sub-group (e.g. `"(ij,jk)"`).
97 ///
98 /// `outer_needed` contains labels that the parent or siblings need from
99 /// this group. `final_output` is the overall output of the entire
100 /// expression.
101 fn parse_group(
102 group_str: &str,
103 outer_needed: &HashSet<u32>,
104 final_output: &[u32],
105 leaf_counter: &mut usize,
106 ) -> Result<Self> {
107 let items = Self::split_top_level(group_str)?;
108
109 let mut children = Vec::with_capacity(items.len());
110 let mut child_subscript_inputs: Vec<Vec<u32>> = Vec::with_capacity(items.len());
111
112 for (idx, item) in items.iter().enumerate() {
113 if item.starts_with('(') && item.ends_with(')') {
114 // Sub-group: strip outer parens and recurse
115 let inner = &item[1..item.len() - 1];
116
117 // Compute what this sub-group needs to output:
118 // labels in this group that appear in outer_needed or in sibling items
119 let group_labels = Self::collect_labels(inner)?;
120 let sibling_labels = Self::collect_sibling_labels(&items, idx)?;
121 let mut needed: HashSet<u32> = HashSet::new();
122 for &label in &group_labels {
123 if outer_needed.contains(&label) || sibling_labels.contains(&label) {
124 needed.insert(label);
125 }
126 }
127 let mut sub_output: Vec<u32> = needed.iter().copied().collect();
128 sub_output.sort();
129
130 let child = Self::parse_group(inner, &needed, &sub_output, leaf_counter)?;
131 child_subscript_inputs.push(sub_output);
132 children.push(child);
133 } else {
134 // Bare operand -> Leaf
135 let labels: Vec<u32> = item.chars().map(char_to_label).collect::<Result<_>>()?;
136 child_subscript_inputs.push(labels);
137 children.push(NestedEinsum::Leaf(*leaf_counter));
138 *leaf_counter += 1;
139 }
140 }
141
142 // Build subscripts for this node
143 let node_output: Vec<u32> = final_output.to_vec();
144 let subscripts = Subscripts {
145 inputs: child_subscript_inputs,
146 output: node_output,
147 };
148
149 Ok(NestedEinsum::Node {
150 subscripts,
151 children,
152 })
153 }
154
155 /// Split a string on commas at the top level (depth 0), respecting parentheses.
156 fn split_top_level(s: &str) -> Result<Vec<&str>> {
157 let mut items = Vec::new();
158 let mut depth: usize = 0;
159 let mut start = 0;
160
161 for (pos, c) in s.char_indices() {
162 match c {
163 '(' => depth += 1,
164 ')' => {
165 if depth == 0 {
166 return Err(Error::InvalidArgument(format!(
167 "unmatched ')' in einsum group: {s}"
168 )));
169 }
170 depth -= 1;
171 }
172 ',' if depth == 0 => {
173 items.push(&s[start..pos]);
174 start = pos + 1; // skip the comma
175 }
176 _ => {}
177 }
178 }
179 // Push the last item
180 items.push(&s[start..]);
181 Ok(items)
182 }
183
184 /// Collect all unique labels from a (possibly nested) string, ignoring
185 /// parentheses and commas.
186 fn collect_labels(s: &str) -> Result<HashSet<u32>> {
187 let mut labels = HashSet::new();
188 for c in s.chars() {
189 match c {
190 '(' | ')' | ',' => continue,
191 _ => {
192 labels.insert(char_to_label(c)?);
193 }
194 }
195 }
196 Ok(labels)
197 }
198
199 /// Collect all labels from sibling items (all items except the one at `current_idx`).
200 fn collect_sibling_labels(items: &[&str], current_idx: usize) -> Result<HashSet<u32>> {
201 let mut labels = HashSet::new();
202 for (idx, item) in items.iter().enumerate() {
203 if idx == current_idx {
204 continue;
205 }
206 for label in Self::collect_labels(item)? {
207 labels.insert(label);
208 }
209 }
210 Ok(labels)
211 }
212}