tenferro_einsum/syntax/subscripts.rs
1use crate::syntax::notation::{char_to_label, split_and_validate_notation};
2use crate::{Error, Result};
3
/// Einsum subscripts expressed with integer labels (omeinsum-rs compatible).
///
/// Every tensor dimension carries a `u32` label. A label that appears in
/// more than one input tensor is contracted (summed over). A label repeated
/// inside a single input selects that input's diagonal before any reduction
/// takes place; when the repeated label is missing from the output, the
/// diagonal is reduced as well. A label repeated in the output embeds the
/// input on a diagonal.
///
/// # Examples
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// // Matrix multiplication: C_{ik} = Σ_j A_{ij} * B_{jk}
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// assert_eq!(subs.inputs.len(), 2);
/// assert_eq!(subs.output, vec![0, 2]);
/// ```
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// // Parse from string notation
/// let subs = Subscripts::parse("ij,jk->ik").unwrap();
/// assert_eq!(subs.inputs.len(), 2);
/// ```
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// let trace = Subscripts::parse("ii->").unwrap();
/// let diagonal = Subscripts::parse("ii->i").unwrap();
/// let embed = Subscripts::parse("i->ii").unwrap();
/// let higher_rank = Subscripts::parse("iij->ij").unwrap();
///
/// assert!(trace.output.is_empty());
/// assert_eq!(diagonal.output, vec![b'i' as u32]);
/// assert_eq!(embed.output, vec![b'i' as u32, b'i' as u32]);
/// assert_eq!(higher_rank.inputs[0], vec![b'i' as u32, b'i' as u32, b'j' as u32]);
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Subscripts {
    /// Dimension labels of the input tensors, one inner vector per input.
    pub inputs: Vec<Vec<u32>>,
    /// Dimension labels of the output tensor.
    pub output: Vec<u32>,
}
51
52impl Subscripts {
53 /// Create subscripts from integer label arrays.
54 ///
55 /// # Arguments
56 ///
57 /// * `inputs` — Index labels for each input tensor
58 /// * `output` — Index labels for the output tensor
59 pub fn new(inputs: &[&[u32]], output: &[u32]) -> Self {
60 Self {
61 inputs: inputs.iter().map(|s| s.to_vec()).collect(),
62 output: output.to_vec(),
63 }
64 }
65
66 /// Parse subscripts from NumPy/PyTorch-style string notation.
67 ///
68 /// Each Unicode alphanumeric character represents a dimension label.
69 /// Labels are mapped to integer IDs via Unicode scalar values (`char as u32`).
70 /// Input tensors are separated by commas, and `->` separates inputs
71 /// from the output.
72 ///
73 /// Parentheses are rejected by this flat parser. Use
74 /// [`crate::NestedEinsum::parse`] when notation specifies a parenthesized
75 /// contraction order.
76 ///
77 /// # Examples
78 ///
79 /// - `"ij,jk->ik"` — matrix multiplication
80 /// - `"ii->"` — diagonal extraction followed by reduction (trace)
81 /// - `"ii->i"` — diagonal extraction
82 /// - `"i->ii"` — diagonal embedding
83 /// - `"iij->ij"` — higher-rank diagonal extraction
84 /// - `"ijk->"` — full contraction (scalar result)
85 /// # Errors
86 ///
87 /// Returns an error if the notation is malformed or contains
88 /// parenthesized contraction order.
89 pub fn parse(notation: &str) -> Result<Self> {
90 let (inputs_str, output_str) = split_and_validate_notation(notation)?;
91 if inputs_str.contains(['(', ')']) {
92 return Err(Error::InvalidArgument(
93 "Subscripts::parse does not accept parentheses; use NestedEinsum::parse to preserve parenthesized contraction order"
94 .into(),
95 ));
96 }
97
98 let output: Vec<u32> = output_str
99 .chars()
100 .map(char_to_label)
101 .collect::<Result<_>>()?;
102
103 let inputs: Vec<Vec<u32>> = inputs_str
104 .split(',')
105 .map(|s| s.chars().map(char_to_label).collect::<Result<_>>())
106 .collect::<Result<_>>()?;
107
108 Ok(Self { inputs, output })
109 }
110}
111
112#[cfg(test)]
113mod tests;