// tenferro_einsum/subscripts.rs

1use crate::{Error, Result, Subscripts};
2
/// N-ary einsum specification expressed with integer index labels.
///
/// The string form (e.g. `"ij,jk->ik"`) exists for the convenience of users.
/// This integer-labelled form is the canonical one: runtime integration
/// layers can store it in extension payloads, which lets execution, shape
/// inference, and AD work without parsing strings.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EinsumSubscripts {
    /// One label list per input tensor, in operand order.
    pub inputs: Vec<Vec<u32>>,
    /// Label list of the output tensor.
    pub output: Vec<u32>,
}

impl EinsumSubscripts {
    /// Build a specification by copying the given integer label slices.
    ///
    /// `inputs` holds one label slice per operand; `output` holds the labels
    /// of the result. Both are copied into owned vectors.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_einsum::EinsumSubscripts;
    ///
    /// let subscripts = EinsumSubscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
    ///
    /// assert_eq!(subscripts.inputs, vec![vec![0, 1], vec![1, 2]]);
    /// assert_eq!(subscripts.output, vec![0, 2]);
    /// ```
    pub fn new(inputs: &[&[u32]], output: &[u32]) -> Self {
        // Copy each operand's labels into an owned vector, preserving order.
        let mut owned_inputs: Vec<Vec<u32>> = Vec::with_capacity(inputs.len());
        for operand in inputs {
            owned_inputs.push(Vec::from(*operand));
        }

        Self {
            inputs: owned_inputs,
            output: Vec::from(output),
        }
    }

    /// Return how many input operands this specification describes.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_einsum::EinsumSubscripts;
    ///
    /// let subscripts = EinsumSubscripts::new(&[&[0], &[0]], &[]);
    ///
    /// assert_eq!(subscripts.input_count(), 2);
    /// ```
    #[must_use]
    pub fn input_count(&self) -> usize {
        let Self { inputs, .. } = self;
        inputs.len()
    }
}
52
53impl From<Subscripts> for EinsumSubscripts {
54    fn from(subscripts: Subscripts) -> Self {
55        Self {
56            inputs: subscripts.inputs,
57            output: subscripts.output,
58        }
59    }
60}
61
62impl From<&Subscripts> for EinsumSubscripts {
63    fn from(subscripts: &Subscripts) -> Self {
64        Self {
65            inputs: subscripts.inputs.clone(),
66            output: subscripts.output.clone(),
67        }
68    }
69}
70
71impl From<EinsumSubscripts> for Subscripts {
72    fn from(subscripts: EinsumSubscripts) -> Self {
73        Self {
74            inputs: subscripts.inputs,
75            output: subscripts.output,
76        }
77    }
78}
79
80impl From<&EinsumSubscripts> for Subscripts {
81    fn from(subscripts: &EinsumSubscripts) -> Self {
82        Self {
83            inputs: subscripts.inputs.clone(),
84            output: subscripts.output.clone(),
85        }
86    }
87}
88
89/// Parse string einsum notation into canonical integer labels.
90///
91/// # Examples
92///
93/// ```
94/// use tenferro_einsum::parse_einsum_subscripts;
95///
96/// let subscripts = parse_einsum_subscripts("ij,jk->ik").unwrap();
97///
98/// assert_eq!(subscripts.inputs.len(), 2);
99/// assert_eq!(subscripts.output, vec![b'i' as u32, b'k' as u32]);
100/// ```
101///
102/// # Errors
103///
104/// Returns an error if the notation is malformed.
105pub fn parse_einsum_subscripts(notation: &str) -> Result<EinsumSubscripts> {
106    Subscripts::parse(notation)
107        .map(EinsumSubscripts::from)
108        .map_err(|err| Error::InvalidArgument(format!("invalid einsum subscripts: {err}")))
109}
110
111#[cfg(test)]
112mod tests;