Skip to main content

tenferro_einsum/
util.rs

1use std::collections::{HashMap, HashSet};
2
3use tenferro_device::{Error, Result};
4
5use crate::syntax::subscripts::Subscripts;
6
7/// Build a label -> size mapping from subscripts and input shapes.
8///
9/// # Examples
10///
11/// ```
12/// use tenferro_einsum::{build_size_dict, Subscripts};
13///
14/// let subs = Subscripts::parse("ij,jk->ik").unwrap();
15/// let shapes = [&[2, 3][..], &[3, 4][..]];
16/// let sizes = build_size_dict(&subs, &shapes, None).unwrap();
17///
18/// assert_eq!(sizes.get(&(b'i' as u32)), Some(&2));
19/// assert_eq!(sizes.get(&(b'j' as u32)), Some(&3));
20/// assert_eq!(sizes.get(&(b'k' as u32)), Some(&4));
21/// ```
22pub fn build_size_dict(
23    subscripts: &Subscripts,
24    shapes: &[&[usize]],
25    extra: Option<&HashMap<u32, usize>>,
26) -> Result<HashMap<u32, usize>> {
27    if subscripts.inputs.len() != shapes.len() {
28        return Err(Error::InvalidArgument(format!(
29            "expected {} input shapes, got {}",
30            subscripts.inputs.len(),
31            shapes.len()
32        )));
33    }
34    let mut size_dict: HashMap<u32, usize> = HashMap::new();
35    for (i, input_subs) in subscripts.inputs.iter().enumerate() {
36        if input_subs.len() != shapes[i].len() {
37            return Err(Error::InvalidArgument(format!(
38                "input {} has {} subscript labels but shape has {} dimensions",
39                i,
40                input_subs.len(),
41                shapes[i].len()
42            )));
43        }
44        for (j, &label) in input_subs.iter().enumerate() {
45            let size = shapes[i][j];
46            if let Some(&existing) = size_dict.get(&label) {
47                if existing != size {
48                    return Err(Error::ShapeMismatch {
49                        expected: vec![existing],
50                        got: vec![size],
51                    });
52                }
53            } else {
54                size_dict.insert(label, size);
55            }
56        }
57    }
58    if let Some(sd) = extra {
59        for (&label, &size) in sd {
60            size_dict.entry(label).or_insert(size);
61        }
62    }
63    Ok(size_dict)
64}
65
66/// Compute output shape from output subscripts and size dictionary.
67///
68/// # Examples
69///
70/// ```
71/// use tenferro_einsum::{build_size_dict, compute_output_shape, Subscripts};
72///
73/// let subs = Subscripts::parse("ij,jk->ik").unwrap();
74/// let shapes = [&[2, 3][..], &[3, 4][..]];
75/// let sizes = build_size_dict(&subs, &shapes, None).unwrap();
76/// let output_shape = compute_output_shape(&subs.output, &sizes).unwrap();
77///
78/// assert_eq!(output_shape, vec![2, 4]);
79/// ```
80pub fn compute_output_shape(
81    output_subs: &[u32],
82    size_dict: &HashMap<u32, usize>,
83) -> Result<Vec<usize>> {
84    output_subs
85        .iter()
86        .map(|&label| {
87            size_dict
88                .get(&label)
89                .copied()
90                .ok_or_else(|| Error::InvalidArgument(format!("unknown size for label {label}")))
91        })
92        .collect()
93}
94
/// Compute intermediate subscripts when contracting two operands.
///
/// Scans `subs_left` and then `subs_right`, keeping every label that is present
/// in `needed`. Each kept label appears once, at the position of its first
/// occurrence. Left-operand labels therefore precede labels that occur only in
/// the right operand.
pub(crate) fn intermediate_subs(
    subs_left: &[u32],
    subs_right: &[u32],
    needed: &HashSet<u32>,
) -> Vec<u32> {
    let mut emitted: HashSet<u32> = HashSet::new();
    subs_left
        .iter()
        .chain(subs_right)
        .copied()
        // `insert` returns false for a label that was already emitted (dedup).
        .filter(|label| needed.contains(label) && emitted.insert(*label))
        .collect()
}
111
112/// Compute the cost (output size) of contracting two operands.
113pub(crate) fn contraction_cost(
114    subs_a: &[u32],
115    subs_b: &[u32],
116    needed: &HashSet<u32>,
117    size_dict: &HashMap<u32, usize>,
118) -> usize {
119    let out_subs = intermediate_subs(subs_a, subs_b, needed);
120    out_subs
121        .iter()
122        .map(|l| {
123            debug_assert!(
124                size_dict.contains_key(l),
125                "contraction_cost: label {l} missing from size_dict"
126            );
127            size_dict.get(l).copied().unwrap_or(1)
128        })
129        .product::<usize>()
130        .max(1)
131}