tenferro_prims/cpu/
plan.rs1use std::marker::PhantomData;
2
3use tenferro_algebra::Scalar;
4
5use crate::SemiringBinaryOp;
6
7pub(super) fn compute_paired_components(
13 paired_axes: &[(usize, usize)],
14 shape: &[usize],
15) -> (Vec<Vec<usize>>, Vec<usize>) {
16 use std::collections::HashMap;
17
18 if paired_axes.is_empty() {
19 return (vec![], vec![]);
20 }
21
22 let mut all_axes: Vec<usize> = Vec::new();
23 for &(ax1, ax2) in paired_axes {
24 all_axes.push(ax1);
25 all_axes.push(ax2);
26 }
27 all_axes.sort();
28 all_axes.dedup();
29
30 let mut parent: HashMap<usize, usize> = all_axes.iter().map(|&ax| (ax, ax)).collect();
31
32 fn find(parent: &mut HashMap<usize, usize>, x: usize) -> usize {
33 let p = parent[&x];
34 if p != x {
35 let root = find(parent, p);
36 parent.insert(x, root);
37 root
38 } else {
39 x
40 }
41 }
42
43 for &(ax1, ax2) in paired_axes {
44 let r1 = find(&mut parent, ax1);
45 let r2 = find(&mut parent, ax2);
46 if r1 != r2 {
47 let (lo, hi) = if r1 < r2 { (r1, r2) } else { (r2, r1) };
48 parent.insert(hi, lo);
49 }
50 }
51
52 let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
53 for &ax in &all_axes {
54 let root = find(&mut parent, ax);
55 groups.entry(root).or_default().push(ax);
56 }
57
58 let mut components: Vec<Vec<usize>> = groups.into_values().collect();
59 components.sort_by_key(|c| c[0]);
60
61 let comp_dims: Vec<usize> = components.iter().map(|c| shape[c[0]]).collect();
62
63 (components, comp_dims)
64}
65
66#[derive(Debug, Clone)]
68pub(super) struct ContractGemmSpec {
69 pub(super) a_target: Vec<u32>,
71 pub(super) b_target: Vec<u32>,
73 pub(super) c_target: Vec<u32>,
75 pub(super) batch_modes: Vec<u32>,
76 pub(super) m_modes: Vec<u32>,
77 pub(super) n_modes: Vec<u32>,
78 pub(super) k_modes: Vec<u32>,
79}
80
81pub(super) fn build_contract_gemm_spec(
84 modes_a: &[u32],
85 modes_b: &[u32],
86 modes_c: &[u32],
87) -> Option<ContractGemmSpec> {
88 let batch_modes: Vec<u32> = modes_c
89 .iter()
90 .copied()
91 .filter(|m| modes_a.contains(m) && modes_b.contains(m))
92 .collect();
93 let m_modes: Vec<u32> = modes_c
94 .iter()
95 .copied()
96 .filter(|m| modes_a.contains(m) && !modes_b.contains(m))
97 .collect();
98 let n_modes: Vec<u32> = modes_c
99 .iter()
100 .copied()
101 .filter(|m| modes_b.contains(m) && !modes_a.contains(m))
102 .collect();
103 let k_modes: Vec<u32> = modes_a
104 .iter()
105 .copied()
106 .filter(|m| modes_b.contains(m) && !modes_c.contains(m))
107 .collect();
108
109 let expected_a = batch_modes.len() + m_modes.len() + k_modes.len();
110 let expected_b = batch_modes.len() + k_modes.len() + n_modes.len();
111 if expected_a != modes_a.len() || expected_b != modes_b.len() {
112 return None;
113 }
114 if batch_modes.len() + m_modes.len() + n_modes.len() != modes_c.len() {
115 return None;
116 }
117
118 let a_target: Vec<u32> = batch_modes
119 .iter()
120 .chain(m_modes.iter())
121 .chain(k_modes.iter())
122 .copied()
123 .collect();
124 let b_target: Vec<u32> = batch_modes
125 .iter()
126 .chain(k_modes.iter())
127 .chain(n_modes.iter())
128 .copied()
129 .collect();
130 let c_target: Vec<u32> = batch_modes
131 .iter()
132 .chain(m_modes.iter())
133 .chain(n_modes.iter())
134 .copied()
135 .collect();
136
137 Some(ContractGemmSpec {
138 a_target,
139 b_target,
140 c_target,
141 batch_modes,
142 m_modes,
143 n_modes,
144 k_modes,
145 })
146}
147
/// Precomputed execution plan for one CPU tensor operation on element type `T`.
///
/// No variant stores a `T` value directly; each carries a `_marker:
/// PhantomData<T>` that only ties the plan to its element type.
///
/// NOTE(review): the per-variant descriptions below are inferred from field
/// names and from the helpers in this module; confirm against the executor.
#[derive(Debug, Clone)]
pub enum CpuPlan<T: Scalar> {
    /// Batched matrix multiply — presumably `batch_dims` are the batch extents
    /// and `m`, `n`, `k` the GEMM sizes (TODO confirm).
    BatchedGemm {
        batch_dims: Vec<usize>,
        m: usize,
        n: usize,
        k: usize,
        _marker: PhantomData<T>,
    },
    /// Additive reduction over the listed axes.
    ReduceAdd {
        reduced_axes: Vec<usize>,
        _marker: PhantomData<T>,
    },
    /// Trace over groups of paired axes.
    Trace {
        /// Axes that are not part of any pair (assumed kept in the output — confirm).
        free_axes: Vec<usize>,
        /// Axis groups; same layout as the first result of
        /// `compute_paired_components` (presumably produced by it).
        components: Vec<Vec<usize>>,
        /// One extent per entry of `components`.
        comp_dims: Vec<usize>,
        _marker: PhantomData<T>,
    },
    /// Adjoint-style counterpart of `Trace` (inferred from the name — confirm).
    AntiTrace {
        /// The original axis pairs.
        paired_axes: Vec<(usize, usize)>,
        free_axes: Vec<usize>,
        /// Connected components of `paired_axes` (presumably from
        /// `compute_paired_components`).
        components: Vec<Vec<usize>>,
        /// One extent per entry of `components`.
        comp_dims: Vec<usize>,
        _marker: PhantomData<T>,
    },
    /// Diagonal-embedding counterpart (inferred from the name — confirm).
    AntiDiag {
        paired_axes: Vec<(usize, usize)>,
        free_axes: Vec<usize>,
        components: Vec<Vec<usize>>,
        comp_dims: Vec<usize>,
        /// NOTE(review): looks like indices into `components`; verify at the
        /// construction site.
        generative_comps: Vec<usize>,
        _marker: PhantomData<T>,
    },
    /// Binary contraction described by mode labels for A, B and the output C.
    Contract {
        modes_a: Vec<u32>,
        modes_b: Vec<u32>,
        modes_c: Vec<u32>,
        /// GEMM role split of the modes, when one exists — presumably the result
        /// of `build_contract_gemm_spec(&modes_a, &modes_b, &modes_c)`.
        // `ContractGemmSpec` is `pub(super)` while this enum is `pub`, so the
        // field would trip the `private_interfaces` lint; the allow silences it.
        #[allow(private_interfaces)]
        gemm_spec: Option<ContractGemmSpec>,
        _marker: PhantomData<T>,
    },
    /// Element-wise binary operation selected by `op`.
    ElementwiseBinary {
        op: SemiringBinaryOp,
        _marker: PhantomData<T>,
    },
    /// Materialise a contiguous copy (inferred from the name — confirm).
    MakeContiguous { _marker: PhantomData<T> },
}