tenferro_einsum/subscripts.rs
1use crate::{Error, Result, Subscripts};
2
/// Integer-labelled einsum subscripts for any number of input operands.
///
/// The string notation exists for user convenience. Runtime integration
/// layers can instead pass this integer form inside extension payloads, so that
/// execution, shape inference, and AD never have to parse strings.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EinsumSubscripts {
    /// One label list per input tensor, in operand order.
    pub inputs: Vec<Vec<u32>>,
    /// Label list of the result tensor.
    pub output: Vec<u32>,
}

impl EinsumSubscripts {
    /// Builds subscripts by copying each borrowed label slice into an owned vector.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_einsum::EinsumSubscripts;
    ///
    /// let subscripts = EinsumSubscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
    ///
    /// assert_eq!(subscripts.inputs, vec![vec![0, 1], vec![1, 2]]);
    /// assert_eq!(subscripts.output, vec![0, 2]);
    /// ```
    pub fn new(inputs: &[&[u32]], output: &[u32]) -> Self {
        // Each `&[u32]` is turned into its own owned `Vec<u32>`.
        let owned_inputs: Vec<Vec<u32>> = inputs.iter().copied().map(<[u32]>::to_vec).collect();
        let owned_output = Vec::from(output);
        Self {
            inputs: owned_inputs,
            output: owned_output,
        }
    }

    /// Returns how many input operands these subscripts describe.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_einsum::EinsumSubscripts;
    ///
    /// let subscripts = EinsumSubscripts::new(&[&[0], &[0]], &[]);
    ///
    /// assert_eq!(subscripts.input_count(), 2);
    /// ```
    #[must_use]
    pub fn input_count(&self) -> usize {
        let Self { inputs, .. } = self;
        inputs.len()
    }
}
51}
52
53impl From<Subscripts> for EinsumSubscripts {
54 fn from(subscripts: Subscripts) -> Self {
55 Self {
56 inputs: subscripts.inputs,
57 output: subscripts.output,
58 }
59 }
60}
61
62impl From<&Subscripts> for EinsumSubscripts {
63 fn from(subscripts: &Subscripts) -> Self {
64 Self {
65 inputs: subscripts.inputs.clone(),
66 output: subscripts.output.clone(),
67 }
68 }
69}
70
71impl From<EinsumSubscripts> for Subscripts {
72 fn from(subscripts: EinsumSubscripts) -> Self {
73 Self {
74 inputs: subscripts.inputs,
75 output: subscripts.output,
76 }
77 }
78}
79
80impl From<&EinsumSubscripts> for Subscripts {
81 fn from(subscripts: &EinsumSubscripts) -> Self {
82 Self {
83 inputs: subscripts.inputs.clone(),
84 output: subscripts.output.clone(),
85 }
86 }
87}
88
89/// Parse string einsum notation into canonical integer labels.
90///
91/// # Examples
92///
93/// ```
94/// use tenferro_einsum::parse_einsum_subscripts;
95///
96/// let subscripts = parse_einsum_subscripts("ij,jk->ik").unwrap();
97///
98/// assert_eq!(subscripts.inputs.len(), 2);
99/// assert_eq!(subscripts.output, vec![b'i' as u32, b'k' as u32]);
100/// ```
101///
102/// # Errors
103///
104/// Returns an error if the notation is malformed.
105pub fn parse_einsum_subscripts(notation: &str) -> Result<EinsumSubscripts> {
106 Subscripts::parse(notation)
107 .map(EinsumSubscripts::from)
108 .map_err(|err| Error::InvalidArgument(format!("invalid einsum subscripts: {err}")))
109}
110
111#[cfg(test)]
112mod tests;