1use std::collections::HashSet;
2use std::marker::PhantomData;
3
4use tenferro_algebra::{Scalar, Standard};
5use tenferro_device::{Error, Result};
6
7use crate::{
8 mode_position, validate_rank, validate_shape_count, validate_shape_eq, SemiringBinaryOp,
9 SemiringCoreDescriptor, SemiringFastPathDescriptor, TensorSemiringCore, TensorSemiringFastPath,
10};
11
12use super::context::CpuBackend;
13use super::context::CpuContext;
14use super::execution::execute_semiring_plan;
15use super::plan::{build_contract_gemm_spec, compute_paired_components, CpuPlan};
16
17impl CpuBackend {
18 pub(super) fn build_semiring_core_plan<T: Scalar>(
19 desc: &SemiringCoreDescriptor,
20 shapes: &[&[usize]],
21 ) -> Result<CpuPlan<T>> {
22 match desc {
23 SemiringCoreDescriptor::BatchedGemm {
24 batch_dims,
25 m,
26 n,
27 k,
28 } => {
29 validate_shape_count(shapes, 3, "BatchedGemm")?;
30 if !Self::supports_batched_gemm_type::<T>() {
31 return Err(Error::InvalidArgument(format!(
32 "BatchedGemm supports only f32, f64, Complex32, and Complex64 (got {})",
33 std::any::type_name::<T>()
34 )));
35 }
36 let expected_a: Vec<usize> = [*m, *k]
37 .iter()
38 .copied()
39 .chain(batch_dims.iter().copied())
40 .collect();
41 let expected_b: Vec<usize> = [*k, *n]
42 .iter()
43 .copied()
44 .chain(batch_dims.iter().copied())
45 .collect();
46 let expected_c: Vec<usize> = [*m, *n]
47 .iter()
48 .copied()
49 .chain(batch_dims.iter().copied())
50 .collect();
51 validate_shape_eq(shapes[0], &expected_a, "BatchedGemm input A")?;
52 validate_shape_eq(shapes[1], &expected_b, "BatchedGemm input B")?;
53 validate_shape_eq(shapes[2], &expected_c, "BatchedGemm output C")?;
54 Ok(CpuPlan::BatchedGemm {
55 batch_dims: batch_dims.clone(),
56 m: *m,
57 n: *n,
58 k: *k,
59 _marker: PhantomData,
60 })
61 }
62 SemiringCoreDescriptor::ReduceAdd { modes_a, modes_c } => {
63 validate_shape_count(shapes, 2, "ReduceAdd")?;
64 validate_rank(shapes[0], modes_a.len(), "ReduceAdd input A")?;
65 validate_rank(shapes[1], modes_c.len(), "ReduceAdd output C")?;
66 let reduced_axes: Vec<usize> = modes_a
67 .iter()
68 .enumerate()
69 .filter(|(_, mode)| !modes_c.contains(mode))
70 .map(|(idx, _)| idx)
71 .collect();
72 for window in reduced_axes.windows(2) {
73 if window[0] >= window[1] {
74 return Err(Error::InvalidArgument(format!(
75 "ReduceAdd: reduced_axes must be sorted and unique, got {reduced_axes:?}"
76 )));
77 }
78 }
79 if let Some(&last) = reduced_axes.last() {
80 if last >= modes_a.len() {
81 return Err(Error::InvalidArgument(format!(
82 "ReduceAdd: reduced axis {last} out of range for rank {}",
83 modes_a.len()
84 )));
85 }
86 }
87 Ok(CpuPlan::ReduceAdd {
88 reduced_axes,
89 _marker: PhantomData,
90 })
91 }
92 SemiringCoreDescriptor::Trace {
93 modes_a,
94 modes_c,
95 paired,
96 } => {
97 validate_shape_count(shapes, 2, "Trace")?;
98 validate_rank(shapes[0], modes_a.len(), "Trace input A")?;
99 validate_rank(shapes[1], modes_c.len(), "Trace output C")?;
100 let paired_axes: Vec<(usize, usize)> = paired
101 .iter()
102 .map(|(m1, m2)| {
103 Ok((mode_position(modes_a, *m1)?, mode_position(modes_a, *m2)?))
104 })
105 .collect::<Result<_>>()?;
106 for &(ax1, ax2) in &paired_axes {
107 if shapes[0][ax1] != shapes[0][ax2] {
108 return Err(Error::InvalidArgument(format!(
109 "Trace paired axes ({ax1}, {ax2}) have mismatched dimensions: {} vs {}",
110 shapes[0][ax1], shapes[0][ax2]
111 )));
112 }
113 }
114 let free_axes: Vec<usize> = modes_c
115 .iter()
116 .map(|mode| mode_position(modes_a, *mode))
117 .collect::<Result<_>>()?;
118 let (components, comp_dims) = compute_paired_components(&paired_axes, shapes[0]);
119 Ok(CpuPlan::Trace {
120 free_axes,
121 components,
122 comp_dims,
123 _marker: PhantomData,
124 })
125 }
126 SemiringCoreDescriptor::AntiTrace {
127 modes_a,
128 modes_c,
129 paired,
130 } => {
131 validate_shape_count(shapes, 2, "AntiTrace")?;
132 validate_rank(shapes[0], modes_a.len(), "AntiTrace input A")?;
133 validate_rank(shapes[1], modes_c.len(), "AntiTrace output C")?;
134 let paired_axes: Vec<(usize, usize)> = paired
135 .iter()
136 .map(|(m1, m2)| {
137 Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?))
138 })
139 .collect::<Result<_>>()?;
140 for &(ax1, ax2) in &paired_axes {
141 if shapes[1][ax1] != shapes[1][ax2] {
142 return Err(Error::InvalidArgument(format!(
143 "AntiTrace paired axes ({ax1}, {ax2}) have mismatched dimensions: {} vs {}",
144 shapes[1][ax1], shapes[1][ax2]
145 )));
146 }
147 }
148 let free_axes: Vec<usize> = modes_a
149 .iter()
150 .map(|mode| mode_position(modes_c, *mode))
151 .collect::<Result<_>>()?;
152 let (components, comp_dims) = compute_paired_components(&paired_axes, shapes[1]);
153 Ok(CpuPlan::AntiTrace {
154 paired_axes,
155 free_axes,
156 components,
157 comp_dims,
158 _marker: PhantomData,
159 })
160 }
161 SemiringCoreDescriptor::AntiDiag {
162 modes_a,
163 modes_c,
164 paired,
165 } => {
166 validate_shape_count(shapes, 2, "AntiDiag")?;
167 validate_rank(shapes[0], modes_a.len(), "AntiDiag input A")?;
168 validate_rank(shapes[1], modes_c.len(), "AntiDiag output C")?;
169 let paired_axes: Vec<(usize, usize)> = paired
170 .iter()
171 .map(|(m1, m2)| {
172 Ok((mode_position(modes_c, *m1)?, mode_position(modes_c, *m2)?))
173 })
174 .collect::<Result<_>>()?;
175 let free_axes: Vec<usize> = modes_a
176 .iter()
177 .map(|mode| mode_position(modes_c, *mode))
178 .collect::<Result<_>>()?;
179 let (components, comp_dims) = compute_paired_components(&paired_axes, shapes[1]);
180 let free_ax_set: HashSet<usize> = free_axes.iter().copied().collect();
181 let generative_comps: Vec<usize> = components
182 .iter()
183 .enumerate()
184 .filter(|(_, comp)| comp.iter().all(|ax| !free_ax_set.contains(ax)))
185 .map(|(idx, _)| idx)
186 .collect();
187 Ok(CpuPlan::AntiDiag {
188 paired_axes,
189 free_axes,
190 components,
191 comp_dims,
192 generative_comps,
193 _marker: PhantomData,
194 })
195 }
196 SemiringCoreDescriptor::MakeContiguous => {
197 validate_shape_count(shapes, 2, "MakeContiguous")?;
198 validate_shape_eq(shapes[1], shapes[0], "MakeContiguous output")?;
199 Ok(CpuPlan::MakeContiguous {
200 _marker: PhantomData,
201 })
202 }
203 }
204 }
205
206 pub(super) fn build_semiring_fast_path_plan<T: Scalar>(
207 desc: &SemiringFastPathDescriptor,
208 shapes: &[&[usize]],
209 ) -> Result<CpuPlan<T>> {
210 match desc {
211 SemiringFastPathDescriptor::Contract {
212 modes_a,
213 modes_b,
214 modes_c,
215 } => {
216 validate_shape_count(shapes, 3, "Contract")?;
217 validate_rank(shapes[0], modes_a.len(), "Contract input A")?;
218 validate_rank(shapes[1], modes_b.len(), "Contract input B")?;
219 validate_rank(shapes[2], modes_c.len(), "Contract output C")?;
220 for (a_pos, &mode) in modes_a.iter().enumerate() {
221 if let Some(b_pos) = modes_b.iter().position(|&m| m == mode) {
222 if shapes[0][a_pos] != shapes[1][b_pos] {
223 return Err(Error::InvalidArgument(format!(
224 "Contract mode {mode} has mismatched dimensions: A={} vs B={}",
225 shapes[0][a_pos], shapes[1][b_pos]
226 )));
227 }
228 }
229 }
230 let gemm_spec = build_contract_gemm_spec(modes_a, modes_b, modes_c);
231 Ok(CpuPlan::Contract {
232 modes_a: modes_a.clone(),
233 modes_b: modes_b.clone(),
234 modes_c: modes_c.clone(),
235 gemm_spec,
236 _marker: PhantomData,
237 })
238 }
239 SemiringFastPathDescriptor::ElementwiseBinary { op } => {
240 validate_shape_count(shapes, 3, "ElementwiseBinary")?;
241 validate_shape_eq(shapes[1], shapes[0], "ElementwiseBinary input B")?;
242 validate_shape_eq(shapes[2], shapes[0], "ElementwiseBinary output C")?;
243 Ok(CpuPlan::ElementwiseBinary {
244 op: *op,
245 _marker: PhantomData,
246 })
247 }
248 }
249 }
250}
251
252impl<S: Scalar> TensorSemiringCore<Standard<S>> for CpuBackend {
253 type Plan = CpuPlan<S>;
254 type Context = CpuContext;
255
256 fn plan(
257 ctx: &mut CpuContext,
258 desc: &SemiringCoreDescriptor,
259 shapes: &[&[usize]],
260 ) -> Result<CpuPlan<S>> {
261 if let Some(cached) = ctx
262 .plan_cache
263 .get::<CpuPlan<S>, SemiringCoreDescriptor>(desc, shapes)
264 {
265 return Ok(cached);
266 }
267
268 let plan = Self::build_semiring_core_plan::<S>(desc, shapes)?;
269 ctx.plan_cache.insert(desc, shapes, plan.clone());
270 Ok(plan)
271 }
272
273 fn execute(
274 ctx: &mut CpuContext,
275 plan: &CpuPlan<S>,
276 alpha: S,
277 inputs: &[&tenferro_tensor::Tensor<S>],
278 beta: S,
279 output: &mut tenferro_tensor::Tensor<S>,
280 ) -> Result<()> {
281 execute_semiring_plan(ctx, plan, alpha, inputs, beta, output)
282 }
283}
284
285impl<S: Scalar> TensorSemiringFastPath<Standard<S>> for CpuBackend {
286 type Plan = CpuPlan<S>;
287 type Context = CpuContext;
288
289 fn plan(
290 ctx: &mut CpuContext,
291 desc: &SemiringFastPathDescriptor,
292 shapes: &[&[usize]],
293 ) -> Result<CpuPlan<S>> {
294 if let Some(cached) = ctx
295 .plan_cache
296 .get::<CpuPlan<S>, SemiringFastPathDescriptor>(desc, shapes)
297 {
298 return Ok(cached);
299 }
300
301 let plan = Self::build_semiring_fast_path_plan::<S>(desc, shapes)?;
302 ctx.plan_cache.insert(desc, shapes, plan.clone());
303 Ok(plan)
304 }
305
306 fn execute(
307 ctx: &mut CpuContext,
308 plan: &CpuPlan<S>,
309 alpha: S,
310 inputs: &[&tenferro_tensor::Tensor<S>],
311 beta: S,
312 output: &mut tenferro_tensor::Tensor<S>,
313 ) -> Result<()> {
314 execute_semiring_plan(ctx, plan, alpha, inputs, beta, output)
315 }
316
317 fn has_fast_path(desc: SemiringFastPathDescriptor) -> bool {
318 matches!(
319 desc,
320 SemiringFastPathDescriptor::Contract { .. }
321 | SemiringFastPathDescriptor::ElementwiseBinary {
322 op: SemiringBinaryOp::Add | SemiringBinaryOp::Mul,
323 }
324 )
325 }
326}