1use tenferro_algebra::{Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5use crate::cpu::common::{
6 execute_binary_map, execute_ternary_map, execute_unary_map, is_supported_ordered_real_type,
7 is_supported_scalar_type, plan_reduction, validate_pointwise_shapes, CpuScalarValue,
8};
9use crate::cpu::family_reduction::{
10 execute_extrema_reduction, execute_mean_reduction, execute_prod_reduction,
11 execute_sum_reduction,
12};
13use crate::cpu::{tensor_to_view, tensor_to_view_mut};
14use crate::infra::typed_dispatch::{
15 cast_scalar_value, cast_strided_view, cast_strided_view_mut, dispatch_complex_scalar_type,
16 dispatch_real_scalar_type, dispatch_standard_scalar_type,
17};
18use crate::{
19 validate_execute_inputs, CpuBackend, CpuContext, ScalarBinaryOp, ScalarPrimsDescriptor,
20 ScalarReductionOp, ScalarTernaryOp, ScalarUnaryOp, TensorScalarPrims,
21};
22
23#[derive(Debug, Clone, PartialEq, Eq)]
32pub enum CpuScalarPlan {
33 PointwiseUnary {
34 op: ScalarUnaryOp,
35 },
36 PointwiseBinary {
37 op: ScalarBinaryOp,
38 },
39 PointwiseTernary {
40 op: ScalarTernaryOp,
41 },
42 Reduction {
43 reduced_axes: Vec<usize>,
44 op: ScalarReductionOp,
45 },
46}
47
48fn supports_scalar_unary<S: Scalar + 'static>(op: ScalarUnaryOp) -> bool {
49 is_supported_scalar_type::<S>()
50 && matches!(
51 op,
52 ScalarUnaryOp::Neg
53 | ScalarUnaryOp::Conj
54 | ScalarUnaryOp::Abs
55 | ScalarUnaryOp::Reciprocal
56 | ScalarUnaryOp::Real
57 | ScalarUnaryOp::Imag
58 | ScalarUnaryOp::Square
59 )
60}
61
62fn supports_scalar_binary<S: Scalar + 'static>(op: ScalarBinaryOp) -> bool {
63 match op {
64 ScalarBinaryOp::Add | ScalarBinaryOp::Sub | ScalarBinaryOp::Mul | ScalarBinaryOp::Div => {
65 is_supported_scalar_type::<S>()
66 }
67 ScalarBinaryOp::Maximum
68 | ScalarBinaryOp::Minimum
69 | ScalarBinaryOp::Greater
70 | ScalarBinaryOp::GreaterEqual
71 | ScalarBinaryOp::ClampMin
72 | ScalarBinaryOp::ClampMax => is_supported_ordered_real_type::<S>(),
73 }
74}
75
76fn supports_scalar_ternary<S: Scalar + 'static>(op: ScalarTernaryOp) -> bool {
77 matches!(op, ScalarTernaryOp::Where) && is_supported_ordered_real_type::<S>()
78}
79
80fn supports_scalar_reduction<S: Scalar + 'static>(op: ScalarReductionOp) -> bool {
81 match op {
82 ScalarReductionOp::Sum | ScalarReductionOp::Prod | ScalarReductionOp::Mean => {
83 is_supported_scalar_type::<S>()
84 }
85 ScalarReductionOp::Max | ScalarReductionOp::Min => is_supported_ordered_real_type::<S>(),
86 }
87}
88
89fn execute_scalar_unary_typed<S: CpuScalarValue>(
90 alpha: S,
91 input: &strided_view::StridedView<S>,
92 beta: S,
93 output: &mut strided_view::StridedViewMut<S>,
94 op: ScalarUnaryOp,
95) -> Result<()> {
96 match op {
97 ScalarUnaryOp::Neg => execute_unary_map(alpha, input, beta, output, |x| -x),
98 ScalarUnaryOp::Conj => execute_unary_map(alpha, input, beta, output, |x| x.conj()),
99 ScalarUnaryOp::Abs => {
100 execute_unary_map(alpha, input, beta, output, |x| S::from_real(x.abs()))
101 }
102 ScalarUnaryOp::Reciprocal => execute_unary_map(alpha, input, beta, output, |x| x.recip()),
103 ScalarUnaryOp::Real => {
104 execute_unary_map(alpha, input, beta, output, |x| S::from_real(x.re()))
105 }
106 ScalarUnaryOp::Imag => {
107 execute_unary_map(alpha, input, beta, output, |x| S::from_real(x.im()))
108 }
109 ScalarUnaryOp::Square => execute_unary_map(alpha, input, beta, output, |x| x * x),
110 }
111}
112
113fn execute_scalar_binary_real<S: num_traits::Float + CpuScalarValue>(
114 alpha: S,
115 lhs: &strided_view::StridedView<S>,
116 rhs: &strided_view::StridedView<S>,
117 beta: S,
118 output: &mut strided_view::StridedViewMut<S>,
119 op: ScalarBinaryOp,
120) -> Result<()> {
121 match op {
122 ScalarBinaryOp::Add => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x + y),
123 ScalarBinaryOp::Sub => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x - y),
124 ScalarBinaryOp::Mul => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x * y),
125 ScalarBinaryOp::Div => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x / y),
126 ScalarBinaryOp::Maximum => {
127 execute_binary_map(
128 alpha,
129 lhs,
130 rhs,
131 beta,
132 output,
133 |x, y| if x >= y { x } else { y },
134 )
135 }
136 ScalarBinaryOp::Minimum => {
137 execute_binary_map(
138 alpha,
139 lhs,
140 rhs,
141 beta,
142 output,
143 |x, y| if x <= y { x } else { y },
144 )
145 }
146 ScalarBinaryOp::Greater => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
147 if x > y {
148 S::one()
149 } else {
150 S::zero()
151 }
152 }),
153 ScalarBinaryOp::GreaterEqual => {
154 execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
155 if x >= y {
156 S::one()
157 } else {
158 S::zero()
159 }
160 })
161 }
162 ScalarBinaryOp::ClampMin => {
163 execute_binary_map(
164 alpha,
165 lhs,
166 rhs,
167 beta,
168 output,
169 |x, y| if x >= y { x } else { y },
170 )
171 }
172 ScalarBinaryOp::ClampMax => {
173 execute_binary_map(
174 alpha,
175 lhs,
176 rhs,
177 beta,
178 output,
179 |x, y| if x <= y { x } else { y },
180 )
181 }
182 }
183}
184
185fn execute_scalar_binary_complex<S: CpuScalarValue>(
186 alpha: S,
187 lhs: &strided_view::StridedView<S>,
188 rhs: &strided_view::StridedView<S>,
189 beta: S,
190 output: &mut strided_view::StridedViewMut<S>,
191 op: ScalarBinaryOp,
192) -> Result<()> {
193 match op {
194 ScalarBinaryOp::Add => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x + y),
195 ScalarBinaryOp::Sub => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x - y),
196 ScalarBinaryOp::Mul => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x * y),
197 ScalarBinaryOp::Div => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x / y),
198 _ => Err(Error::InvalidArgument(format!(
199 "scalar binary operation {op:?} requires ordered real scalars"
200 ))),
201 }
202}
203
204fn execute_scalar_ternary_real<S: num_traits::Float + CpuScalarValue>(
205 alpha: S,
206 cond: &strided_view::StridedView<S>,
207 on_true: &strided_view::StridedView<S>,
208 on_false: &strided_view::StridedView<S>,
209 beta: S,
210 output: &mut strided_view::StridedViewMut<S>,
211 op: ScalarTernaryOp,
212) -> Result<()> {
213 match op {
214 ScalarTernaryOp::Where => {
215 execute_ternary_map(alpha, cond, on_true, on_false, beta, output, |c, t, f| {
216 if c != S::zero() {
217 t
218 } else {
219 f
220 }
221 })
222 }
223 }
224}
225
226fn execute_scalar_unary<T: Scalar + 'static>(
227 alpha: T,
228 input: &strided_view::StridedView<T>,
229 beta: T,
230 output: &mut strided_view::StridedViewMut<T>,
231 op: ScalarUnaryOp,
232) -> Result<()> {
233 dispatch_standard_scalar_type!(T, Concrete, {
234 let input = cast_strided_view!(input, T, Concrete);
235 let output = cast_strided_view_mut!(output, T, Concrete);
236 let alpha = cast_scalar_value!(alpha, T, Concrete);
237 let beta = cast_scalar_value!(beta, T, Concrete);
238 return execute_scalar_unary_typed::<Concrete>(alpha, input, beta, output, op);
239 });
240
241 Err(Error::InvalidArgument(format!(
242 "scalar unary operation {op:?} is not supported for {}",
243 std::any::type_name::<T>()
244 )))
245}
246
247fn execute_scalar_binary<T: Scalar + 'static>(
248 alpha: T,
249 lhs: &strided_view::StridedView<T>,
250 rhs: &strided_view::StridedView<T>,
251 beta: T,
252 output: &mut strided_view::StridedViewMut<T>,
253 op: ScalarBinaryOp,
254) -> Result<()> {
255 dispatch_real_scalar_type!(T, Concrete, {
256 let lhs = cast_strided_view!(lhs, T, Concrete);
257 let rhs = cast_strided_view!(rhs, T, Concrete);
258 let output = cast_strided_view_mut!(output, T, Concrete);
259 let alpha = cast_scalar_value!(alpha, T, Concrete);
260 let beta = cast_scalar_value!(beta, T, Concrete);
261 return execute_scalar_binary_real::<Concrete>(alpha, lhs, rhs, beta, output, op);
262 });
263 dispatch_complex_scalar_type!(T, Concrete, {
264 let lhs = cast_strided_view!(lhs, T, Concrete);
265 let rhs = cast_strided_view!(rhs, T, Concrete);
266 let output = cast_strided_view_mut!(output, T, Concrete);
267 let alpha = cast_scalar_value!(alpha, T, Concrete);
268 let beta = cast_scalar_value!(beta, T, Concrete);
269 return execute_scalar_binary_complex::<Concrete>(alpha, lhs, rhs, beta, output, op);
270 });
271
272 Err(Error::InvalidArgument(format!(
273 "scalar binary operation {op:?} is not supported for {}",
274 std::any::type_name::<T>()
275 )))
276}
277
278fn execute_scalar_ternary<T: Scalar + 'static>(
279 alpha: T,
280 cond: &strided_view::StridedView<T>,
281 on_true: &strided_view::StridedView<T>,
282 on_false: &strided_view::StridedView<T>,
283 beta: T,
284 output: &mut strided_view::StridedViewMut<T>,
285 op: ScalarTernaryOp,
286) -> Result<()> {
287 dispatch_real_scalar_type!(T, Concrete, {
288 let cond = cast_strided_view!(cond, T, Concrete);
289 let on_true = cast_strided_view!(on_true, T, Concrete);
290 let on_false = cast_strided_view!(on_false, T, Concrete);
291 let output = cast_strided_view_mut!(output, T, Concrete);
292 let alpha = cast_scalar_value!(alpha, T, Concrete);
293 let beta = cast_scalar_value!(beta, T, Concrete);
294 return execute_scalar_ternary_real::<Concrete>(
295 alpha, cond, on_true, on_false, beta, output, op,
296 );
297 });
298
299 Err(Error::InvalidArgument(format!(
300 "scalar ternary operation {op:?} is not supported for {}",
301 std::any::type_name::<T>()
302 )))
303}
304
305fn execute_scalar_reduction<T: Scalar + 'static>(
306 alpha: T,
307 input: &strided_view::StridedView<T>,
308 beta: T,
309 output: &mut strided_view::StridedViewMut<T>,
310 reduced_axes: &[usize],
311 op: ScalarReductionOp,
312) -> Result<()> {
313 match op {
314 ScalarReductionOp::Sum => {
315 dispatch_standard_scalar_type!(T, Concrete, {
316 let input = cast_strided_view!(input, T, Concrete);
317 let output = cast_strided_view_mut!(output, T, Concrete);
318 let alpha = cast_scalar_value!(alpha, T, Concrete);
319 let beta = cast_scalar_value!(beta, T, Concrete);
320 return execute_sum_reduction::<Concrete>(alpha, input, beta, output, reduced_axes);
321 });
322 }
323 ScalarReductionOp::Prod => {
324 dispatch_standard_scalar_type!(T, Concrete, {
325 let input = cast_strided_view!(input, T, Concrete);
326 let output = cast_strided_view_mut!(output, T, Concrete);
327 let alpha = cast_scalar_value!(alpha, T, Concrete);
328 let beta = cast_scalar_value!(beta, T, Concrete);
329 return execute_prod_reduction::<Concrete>(
330 alpha,
331 input,
332 beta,
333 output,
334 reduced_axes,
335 );
336 });
337 }
338 ScalarReductionOp::Mean => {
339 dispatch_standard_scalar_type!(T, Concrete, {
340 let input = cast_strided_view!(input, T, Concrete);
341 let output = cast_strided_view_mut!(output, T, Concrete);
342 let alpha = cast_scalar_value!(alpha, T, Concrete);
343 let beta = cast_scalar_value!(beta, T, Concrete);
344 return execute_mean_reduction::<Concrete>(
345 alpha,
346 input,
347 beta,
348 output,
349 reduced_axes,
350 );
351 });
352 }
353 ScalarReductionOp::Max => {
354 dispatch_real_scalar_type!(T, Concrete, {
355 let input = cast_strided_view!(input, T, Concrete);
356 let output = cast_strided_view_mut!(output, T, Concrete);
357 let alpha = cast_scalar_value!(alpha, T, Concrete);
358 let beta = cast_scalar_value!(beta, T, Concrete);
359 return execute_extrema_reduction(alpha, input, beta, output, reduced_axes, true);
360 });
361 }
362 ScalarReductionOp::Min => {
363 dispatch_real_scalar_type!(T, Concrete, {
364 let input = cast_strided_view!(input, T, Concrete);
365 let output = cast_strided_view_mut!(output, T, Concrete);
366 let alpha = cast_scalar_value!(alpha, T, Concrete);
367 let beta = cast_scalar_value!(beta, T, Concrete);
368 return execute_extrema_reduction(alpha, input, beta, output, reduced_axes, false);
369 });
370 }
371 }
372
373 Err(Error::InvalidArgument(format!(
374 "scalar reduction {op:?} is not supported for {}",
375 std::any::type_name::<T>()
376 )))
377}
378
379impl<S: Scalar + 'static> TensorScalarPrims<Standard<S>> for CpuBackend {
380 type Plan = CpuScalarPlan;
381 type Context = CpuContext;
382
383 fn plan(
384 _ctx: &mut Self::Context,
385 desc: &ScalarPrimsDescriptor,
386 shapes: &[&[usize]],
387 ) -> Result<Self::Plan> {
388 match desc {
389 ScalarPrimsDescriptor::PointwiseUnary { op } => {
390 validate_pointwise_shapes(shapes, 1, "ScalarPointwiseUnary")?;
391 if !supports_scalar_unary::<S>(*op) {
392 return Err(Error::InvalidArgument(format!(
393 "scalar unary operation {op:?} is not supported on CpuBackend for {}",
394 std::any::type_name::<S>()
395 )));
396 }
397 Ok(CpuScalarPlan::PointwiseUnary { op: *op })
398 }
399 ScalarPrimsDescriptor::PointwiseBinary { op } => {
400 validate_pointwise_shapes(shapes, 2, "ScalarPointwiseBinary")?;
401 if !supports_scalar_binary::<S>(*op) {
402 return Err(Error::InvalidArgument(format!(
403 "scalar binary operation {op:?} is not supported on CpuBackend for {}",
404 std::any::type_name::<S>()
405 )));
406 }
407 Ok(CpuScalarPlan::PointwiseBinary { op: *op })
408 }
409 ScalarPrimsDescriptor::PointwiseTernary { op } => {
410 validate_pointwise_shapes(shapes, 3, "ScalarPointwiseTernary")?;
411 if !supports_scalar_ternary::<S>(*op) {
412 return Err(Error::InvalidArgument(format!(
413 "scalar ternary operation {op:?} is not supported on CpuBackend for {}",
414 std::any::type_name::<S>()
415 )));
416 }
417 Ok(CpuScalarPlan::PointwiseTernary { op: *op })
418 }
419 ScalarPrimsDescriptor::Reduction {
420 modes_a,
421 modes_c,
422 op,
423 } => {
424 if !supports_scalar_reduction::<S>(*op) {
425 return Err(Error::InvalidArgument(format!(
426 "scalar reduction {op:?} is not supported on CpuBackend for {}",
427 std::any::type_name::<S>()
428 )));
429 }
430 let spec = plan_reduction(modes_a, modes_c, shapes, "ScalarReduction")?;
431 let _ = spec.reduced_total;
432 Ok(CpuScalarPlan::Reduction {
433 reduced_axes: spec.reduced_axes,
434 op: *op,
435 })
436 }
437 }
438 }
439
440 fn execute(
441 _ctx: &mut Self::Context,
442 plan: &Self::Plan,
443 alpha: S,
444 inputs: &[&Tensor<S>],
445 beta: S,
446 output: &mut Tensor<S>,
447 ) -> Result<()> {
448 let views: Vec<_> = inputs
449 .iter()
450 .map(|tensor| tensor_to_view(tensor))
451 .collect::<Result<_>>()?;
452 let view_refs: Vec<_> = views.iter().collect();
453 let mut out_view = tensor_to_view_mut(output)?;
454
455 match plan {
456 CpuScalarPlan::PointwiseUnary { op } => {
457 validate_execute_inputs(inputs, 1, "ScalarPointwiseUnary")?;
458 execute_scalar_unary(alpha, view_refs[0], beta, &mut out_view, *op)
459 }
460 CpuScalarPlan::PointwiseBinary { op } => {
461 validate_execute_inputs(inputs, 2, "ScalarPointwiseBinary")?;
462 execute_scalar_binary(alpha, view_refs[0], view_refs[1], beta, &mut out_view, *op)
463 }
464 CpuScalarPlan::PointwiseTernary { op } => {
465 validate_execute_inputs(inputs, 3, "ScalarPointwiseTernary")?;
466 execute_scalar_ternary(
467 alpha,
468 view_refs[0],
469 view_refs[1],
470 view_refs[2],
471 beta,
472 &mut out_view,
473 *op,
474 )
475 }
476 CpuScalarPlan::Reduction { reduced_axes, op } => {
477 validate_execute_inputs(inputs, 1, "ScalarReduction")?;
478 execute_scalar_reduction(
479 alpha,
480 view_refs[0],
481 beta,
482 &mut out_view,
483 reduced_axes,
484 *op,
485 )
486 }
487 }
488 }
489
490 fn has_scalar_support(desc: ScalarPrimsDescriptor) -> bool {
491 match desc {
492 ScalarPrimsDescriptor::PointwiseUnary { op } => supports_scalar_unary::<S>(op),
493 ScalarPrimsDescriptor::PointwiseBinary { op } => supports_scalar_binary::<S>(op),
494 ScalarPrimsDescriptor::PointwiseTernary { op } => supports_scalar_ternary::<S>(op),
495 ScalarPrimsDescriptor::Reduction { op, .. } => supports_scalar_reduction::<S>(op),
496 }
497 }
498}