1use std::collections::{HashMap, HashSet};
2
3use tenferro_device::{Error, Result};
4
5use crate::syntax::subscripts::Subscripts;
6
7pub fn build_size_dict(
23 subscripts: &Subscripts,
24 shapes: &[&[usize]],
25 extra: Option<&HashMap<u32, usize>>,
26) -> Result<HashMap<u32, usize>> {
27 if subscripts.inputs.len() != shapes.len() {
28 return Err(Error::InvalidArgument(format!(
29 "expected {} input shapes, got {}",
30 subscripts.inputs.len(),
31 shapes.len()
32 )));
33 }
34 let mut size_dict: HashMap<u32, usize> = HashMap::new();
35 for (i, input_subs) in subscripts.inputs.iter().enumerate() {
36 if input_subs.len() != shapes[i].len() {
37 return Err(Error::InvalidArgument(format!(
38 "input {} has {} subscript labels but shape has {} dimensions",
39 i,
40 input_subs.len(),
41 shapes[i].len()
42 )));
43 }
44 for (j, &label) in input_subs.iter().enumerate() {
45 let size = shapes[i][j];
46 if let Some(&existing) = size_dict.get(&label) {
47 if existing != size {
48 return Err(Error::ShapeMismatch {
49 expected: vec![existing],
50 got: vec![size],
51 });
52 }
53 } else {
54 size_dict.insert(label, size);
55 }
56 }
57 }
58 if let Some(sd) = extra {
59 for (&label, &size) in sd {
60 size_dict.entry(label).or_insert(size);
61 }
62 }
63 Ok(size_dict)
64}
65
66pub fn compute_output_shape(
81 output_subs: &[u32],
82 size_dict: &HashMap<u32, usize>,
83) -> Result<Vec<usize>> {
84 output_subs
85 .iter()
86 .map(|&label| {
87 size_dict
88 .get(&label)
89 .copied()
90 .ok_or_else(|| Error::InvalidArgument(format!("unknown size for label {label}")))
91 })
92 .collect()
93}
94
/// Collects the labels of `subs_left` followed by `subs_right` that are
/// members of `needed`.
///
/// Labels are kept in first-occurrence order, and each label appears at most
/// once in the result.
pub(crate) fn intermediate_subs(
    subs_left: &[u32],
    subs_right: &[u32],
    needed: &HashSet<u32>,
) -> Vec<u32> {
    // Labels already emitted, so that duplicates are skipped.
    let mut emitted: HashSet<u32> = HashSet::new();
    subs_left
        .iter()
        .chain(subs_right)
        .copied()
        // `&&` short-circuits, so only needed labels are ever recorded.
        .filter(|label| needed.contains(label) && emitted.insert(*label))
        .collect()
}
111
112pub(crate) fn contraction_cost(
114 subs_a: &[u32],
115 subs_b: &[u32],
116 needed: &HashSet<u32>,
117 size_dict: &HashMap<u32, usize>,
118) -> usize {
119 let out_subs = intermediate_subs(subs_a, subs_b, needed);
120 out_subs
121 .iter()
122 .map(|l| {
123 debug_assert!(
124 size_dict.contains_key(l),
125 "contraction_cost: label {l} missing from size_dict"
126 );
127 size_dict.get(l).copied().unwrap_or(1)
128 })
129 .product::<usize>()
130 .max(1)
131}