tenferro_einsum/syntax/subscripts.rs
1use tenferro_device::Result;
2
3use crate::syntax::notation::{char_to_label, split_and_validate_notation};
4
/// Integer-labelled subscripts describing an einsum expression
/// (compatible with omeinsum-rs).
///
/// Every tensor dimension carries a `u32` label. Labels shared across
/// several input tensors are contracted (summed over), whereas labels
/// present only in the output are free indices.
///
/// # Examples
///
/// Building subscripts directly from label slices:
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// // Matrix multiplication: C_{ik} = Σ_j A_{ij} * B_{jk}
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// assert_eq!(subs.inputs.len(), 2);
/// assert_eq!(subs.output, vec![0, 2]);
/// ```
///
/// Parsing the same contraction from string notation:
///
/// ```ignore
/// use tenferro_einsum::Subscripts;
///
/// let subs = Subscripts::parse("ij,jk->ik").unwrap();
/// assert_eq!(subs.inputs.len(), 2);
/// ```
#[derive(Clone, Debug)]
pub struct Subscripts {
    /// One list of index labels per input tensor.
    pub inputs: Vec<Vec<u32>>,
    /// The index labels of the output tensor.
    pub output: Vec<u32>,
}
36
37impl Subscripts {
38 /// Create subscripts from integer label arrays.
39 ///
40 /// # Arguments
41 ///
42 /// * `inputs` — Index labels for each input tensor
43 /// * `output` — Index labels for the output tensor
44 pub fn new(inputs: &[&[u32]], output: &[u32]) -> Self {
45 Self {
46 inputs: inputs.iter().map(|s| s.to_vec()).collect(),
47 output: output.to_vec(),
48 }
49 }
50
51 /// Parse subscripts from NumPy/PyTorch-style string notation.
52 ///
53 /// Each Unicode alphanumeric character represents a dimension label.
54 /// Labels are mapped to integer IDs via Unicode scalar values (`char as u32`).
55 /// Input tensors are separated by commas, and `->` separates inputs
56 /// from the output.
57 ///
58 /// Parentheses in the notation are accepted but stripped during parsing.
59 /// To respect parenthesized contraction order, use [`crate::NestedEinsum::parse`]
60 /// or pass the parenthesized string directly to [`crate::einsum`].
61 ///
62 /// # Examples
63 ///
64 /// - `"ij,jk->ik"` — matrix multiplication
65 /// - `"ii->i"` — diagonal extraction
66 /// - `"ijk->"` — full contraction (scalar result)
67 /// - `"ij,(jk,kl)->il"` — contract B and C first, then with A
68 ///
69 /// # Errors
70 ///
71 /// Returns an error if the notation is malformed.
72 pub fn parse(notation: &str) -> Result<Self> {
73 let (inputs_str, output_str) = split_and_validate_notation(notation)?;
74
75 let output: Vec<u32> = output_str
76 .chars()
77 .map(char_to_label)
78 .collect::<Result<_>>()?;
79
80 // Parentheses already validated by split_and_validate_notation() above.
81 // Strip parentheses and parse input labels
82 let clean_inputs = inputs_str.replace(['(', ')'], "");
83 let inputs: Vec<Vec<u32>> = clean_inputs
84 .split(',')
85 .map(|s| s.chars().map(char_to_label).collect::<Result<_>>())
86 .collect::<Result<_>>()?;
87
88 Ok(Self { inputs, output })
89 }
90}