Skip to main content

tenferro_einsum/syntax/
subscripts.rs

1use crate::syntax::notation::{char_to_label, split_and_validate_notation};
2use crate::{Error, Result};
3
/// Flat einsum index specification built from integer dimension labels
/// (omeinsum-rs compatible).
///
/// Every tensor dimension carries a `u32` label, and the way labels repeat
/// determines the operation:
///
/// - A label shared across multiple input tensors is contracted (summed
///   over).
/// - A label repeated within one input selects that input's diagonal before
///   any reduction; if the repeated label is absent from the output, the
///   diagonal is reduced.
/// - A label repeated in the output embeds the input on a diagonal.
///
/// # Examples
///
/// Building the subscripts from integer labels:
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// // Matrix multiplication: C_{ik} = Σ_j A_{ij} * B_{jk}
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// assert_eq!(subs.inputs.len(), 2);
/// assert_eq!(subs.output, vec![0, 2]);
/// ```
///
/// Parsing the same contraction from string notation:
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// let subs = Subscripts::parse("ij,jk->ik").unwrap();
/// assert_eq!(subs.inputs.len(), 2);
/// ```
///
/// Repeated labels (trace, diagonal extraction, diagonal embedding):
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// let trace = Subscripts::parse("ii->").unwrap();
/// let diagonal = Subscripts::parse("ii->i").unwrap();
/// let embed = Subscripts::parse("i->ii").unwrap();
/// let higher_rank = Subscripts::parse("iij->ij").unwrap();
///
/// let i = u32::from('i');
/// let j = u32::from('j');
///
/// assert!(trace.output.is_empty());
/// assert_eq!(diagonal.output, vec![i]);
/// assert_eq!(embed.output, vec![i, i]);
/// assert_eq!(higher_rank.inputs[0], vec![i, i, j]);
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Subscripts {
    /// One label list per input tensor, in operand order.
    pub inputs: Vec<Vec<u32>>,
    /// Label list of the output tensor.
    pub output: Vec<u32>,
}
51
52impl Subscripts {
53    /// Create subscripts from integer label arrays.
54    ///
55    /// # Arguments
56    ///
57    /// * `inputs` — Index labels for each input tensor
58    /// * `output` — Index labels for the output tensor
59    pub fn new(inputs: &[&[u32]], output: &[u32]) -> Self {
60        Self {
61            inputs: inputs.iter().map(|s| s.to_vec()).collect(),
62            output: output.to_vec(),
63        }
64    }
65
66    /// Parse subscripts from NumPy/PyTorch-style string notation.
67    ///
68    /// Each Unicode alphanumeric character represents a dimension label.
69    /// Labels are mapped to integer IDs via Unicode scalar values (`char as u32`).
70    /// Input tensors are separated by commas, and `->` separates inputs
71    /// from the output.
72    ///
73    /// Parentheses are rejected by this flat parser. Use
74    /// [`crate::NestedEinsum::parse`] when notation specifies a parenthesized
75    /// contraction order.
76    ///
77    /// # Examples
78    ///
79    /// - `"ij,jk->ik"` — matrix multiplication
80    /// - `"ii->"` — diagonal extraction followed by reduction (trace)
81    /// - `"ii->i"` — diagonal extraction
82    /// - `"i->ii"` — diagonal embedding
83    /// - `"iij->ij"` — higher-rank diagonal extraction
84    /// - `"ijk->"` — full contraction (scalar result)
85    /// # Errors
86    ///
87    /// Returns an error if the notation is malformed or contains
88    /// parenthesized contraction order.
89    pub fn parse(notation: &str) -> Result<Self> {
90        let (inputs_str, output_str) = split_and_validate_notation(notation)?;
91        if inputs_str.contains(['(', ')']) {
92            return Err(Error::InvalidArgument(
93                "Subscripts::parse does not accept parentheses; use NestedEinsum::parse to preserve parenthesized contraction order"
94                    .into(),
95            ));
96        }
97
98        let output: Vec<u32> = output_str
99            .chars()
100            .map(char_to_label)
101            .collect::<Result<_>>()?;
102
103        let inputs: Vec<Vec<u32>> = inputs_str
104            .split(',')
105            .map(|s| s.chars().map(char_to_label).collect::<Result<_>>())
106            .collect::<Result<_>>()?;
107
108        Ok(Self { inputs, output })
109    }
110}
111
112#[cfg(test)]
113mod tests;