tenferro_prims/cpu/scalar.rs

1use tenferro_algebra::{Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5use crate::cpu::common::{
6    execute_binary_map, execute_ternary_map, execute_unary_map, is_supported_ordered_real_type,
7    is_supported_scalar_type, plan_reduction, validate_pointwise_shapes, CpuScalarValue,
8};
9use crate::cpu::family_reduction::{
10    execute_extrema_reduction, execute_mean_reduction, execute_prod_reduction,
11    execute_sum_reduction,
12};
13use crate::cpu::{tensor_to_view, tensor_to_view_mut};
14use crate::infra::typed_dispatch::{
15    cast_scalar_value, cast_strided_view, cast_strided_view_mut, dispatch_complex_scalar_type,
16    dispatch_real_scalar_type, dispatch_standard_scalar_type,
17};
18use crate::{
19    validate_execute_inputs, CpuBackend, CpuContext, ScalarBinaryOp, ScalarPrimsDescriptor,
20    ScalarReductionOp, ScalarTernaryOp, ScalarUnaryOp, TensorScalarPrims,
21};
22
23/// CPU execution plan for the scalar protocol family.
24///
25/// # Examples
26///
27/// ```ignore
28/// use tenferro_prims::CpuScalarPlan;
29/// let _ = std::mem::size_of::<CpuScalarPlan>();
30/// ```
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub enum CpuScalarPlan {
33    PointwiseUnary {
34        op: ScalarUnaryOp,
35    },
36    PointwiseBinary {
37        op: ScalarBinaryOp,
38    },
39    PointwiseTernary {
40        op: ScalarTernaryOp,
41    },
42    Reduction {
43        reduced_axes: Vec<usize>,
44        op: ScalarReductionOp,
45    },
46}
47
48fn supports_scalar_unary<S: Scalar + 'static>(op: ScalarUnaryOp) -> bool {
49    is_supported_scalar_type::<S>()
50        && matches!(
51            op,
52            ScalarUnaryOp::Neg
53                | ScalarUnaryOp::Conj
54                | ScalarUnaryOp::Abs
55                | ScalarUnaryOp::Reciprocal
56                | ScalarUnaryOp::Real
57                | ScalarUnaryOp::Imag
58                | ScalarUnaryOp::Square
59        )
60}
61
62fn supports_scalar_binary<S: Scalar + 'static>(op: ScalarBinaryOp) -> bool {
63    match op {
64        ScalarBinaryOp::Add | ScalarBinaryOp::Sub | ScalarBinaryOp::Mul | ScalarBinaryOp::Div => {
65            is_supported_scalar_type::<S>()
66        }
67        ScalarBinaryOp::Maximum
68        | ScalarBinaryOp::Minimum
69        | ScalarBinaryOp::Greater
70        | ScalarBinaryOp::GreaterEqual
71        | ScalarBinaryOp::ClampMin
72        | ScalarBinaryOp::ClampMax => is_supported_ordered_real_type::<S>(),
73    }
74}
75
76fn supports_scalar_ternary<S: Scalar + 'static>(op: ScalarTernaryOp) -> bool {
77    matches!(op, ScalarTernaryOp::Where) && is_supported_ordered_real_type::<S>()
78}
79
80fn supports_scalar_reduction<S: Scalar + 'static>(op: ScalarReductionOp) -> bool {
81    match op {
82        ScalarReductionOp::Sum | ScalarReductionOp::Prod | ScalarReductionOp::Mean => {
83            is_supported_scalar_type::<S>()
84        }
85        ScalarReductionOp::Max | ScalarReductionOp::Min => is_supported_ordered_real_type::<S>(),
86    }
87}
88
89fn execute_scalar_unary_typed<S: CpuScalarValue>(
90    alpha: S,
91    input: &strided_view::StridedView<S>,
92    beta: S,
93    output: &mut strided_view::StridedViewMut<S>,
94    op: ScalarUnaryOp,
95) -> Result<()> {
96    match op {
97        ScalarUnaryOp::Neg => execute_unary_map(alpha, input, beta, output, |x| -x),
98        ScalarUnaryOp::Conj => execute_unary_map(alpha, input, beta, output, |x| x.conj()),
99        ScalarUnaryOp::Abs => {
100            execute_unary_map(alpha, input, beta, output, |x| S::from_real(x.abs()))
101        }
102        ScalarUnaryOp::Reciprocal => execute_unary_map(alpha, input, beta, output, |x| x.recip()),
103        ScalarUnaryOp::Real => {
104            execute_unary_map(alpha, input, beta, output, |x| S::from_real(x.re()))
105        }
106        ScalarUnaryOp::Imag => {
107            execute_unary_map(alpha, input, beta, output, |x| S::from_real(x.im()))
108        }
109        ScalarUnaryOp::Square => execute_unary_map(alpha, input, beta, output, |x| x * x),
110    }
111}
112
113fn execute_scalar_binary_real<S: num_traits::Float + CpuScalarValue>(
114    alpha: S,
115    lhs: &strided_view::StridedView<S>,
116    rhs: &strided_view::StridedView<S>,
117    beta: S,
118    output: &mut strided_view::StridedViewMut<S>,
119    op: ScalarBinaryOp,
120) -> Result<()> {
121    match op {
122        ScalarBinaryOp::Add => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x + y),
123        ScalarBinaryOp::Sub => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x - y),
124        ScalarBinaryOp::Mul => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x * y),
125        ScalarBinaryOp::Div => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x / y),
126        ScalarBinaryOp::Maximum => {
127            execute_binary_map(
128                alpha,
129                lhs,
130                rhs,
131                beta,
132                output,
133                |x, y| if x >= y { x } else { y },
134            )
135        }
136        ScalarBinaryOp::Minimum => {
137            execute_binary_map(
138                alpha,
139                lhs,
140                rhs,
141                beta,
142                output,
143                |x, y| if x <= y { x } else { y },
144            )
145        }
146        ScalarBinaryOp::Greater => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
147            if x > y {
148                S::one()
149            } else {
150                S::zero()
151            }
152        }),
153        ScalarBinaryOp::GreaterEqual => {
154            execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
155                if x >= y {
156                    S::one()
157                } else {
158                    S::zero()
159                }
160            })
161        }
162        ScalarBinaryOp::ClampMin => {
163            execute_binary_map(
164                alpha,
165                lhs,
166                rhs,
167                beta,
168                output,
169                |x, y| if x >= y { x } else { y },
170            )
171        }
172        ScalarBinaryOp::ClampMax => {
173            execute_binary_map(
174                alpha,
175                lhs,
176                rhs,
177                beta,
178                output,
179                |x, y| if x <= y { x } else { y },
180            )
181        }
182    }
183}
184
185fn execute_scalar_binary_complex<S: CpuScalarValue>(
186    alpha: S,
187    lhs: &strided_view::StridedView<S>,
188    rhs: &strided_view::StridedView<S>,
189    beta: S,
190    output: &mut strided_view::StridedViewMut<S>,
191    op: ScalarBinaryOp,
192) -> Result<()> {
193    match op {
194        ScalarBinaryOp::Add => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x + y),
195        ScalarBinaryOp::Sub => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x - y),
196        ScalarBinaryOp::Mul => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x * y),
197        ScalarBinaryOp::Div => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x / y),
198        _ => Err(Error::InvalidArgument(format!(
199            "scalar binary operation {op:?} requires ordered real scalars"
200        ))),
201    }
202}
203
204fn execute_scalar_ternary_real<S: num_traits::Float + CpuScalarValue>(
205    alpha: S,
206    cond: &strided_view::StridedView<S>,
207    on_true: &strided_view::StridedView<S>,
208    on_false: &strided_view::StridedView<S>,
209    beta: S,
210    output: &mut strided_view::StridedViewMut<S>,
211    op: ScalarTernaryOp,
212) -> Result<()> {
213    match op {
214        ScalarTernaryOp::Where => {
215            execute_ternary_map(alpha, cond, on_true, on_false, beta, output, |c, t, f| {
216                if c != S::zero() {
217                    t
218                } else {
219                    f
220                }
221            })
222        }
223    }
224}
225
226fn execute_scalar_unary<T: Scalar + 'static>(
227    alpha: T,
228    input: &strided_view::StridedView<T>,
229    beta: T,
230    output: &mut strided_view::StridedViewMut<T>,
231    op: ScalarUnaryOp,
232) -> Result<()> {
233    dispatch_standard_scalar_type!(T, Concrete, {
234        let input = cast_strided_view!(input, T, Concrete);
235        let output = cast_strided_view_mut!(output, T, Concrete);
236        let alpha = cast_scalar_value!(alpha, T, Concrete);
237        let beta = cast_scalar_value!(beta, T, Concrete);
238        return execute_scalar_unary_typed::<Concrete>(alpha, input, beta, output, op);
239    });
240
241    Err(Error::InvalidArgument(format!(
242        "scalar unary operation {op:?} is not supported for {}",
243        std::any::type_name::<T>()
244    )))
245}
246
247fn execute_scalar_binary<T: Scalar + 'static>(
248    alpha: T,
249    lhs: &strided_view::StridedView<T>,
250    rhs: &strided_view::StridedView<T>,
251    beta: T,
252    output: &mut strided_view::StridedViewMut<T>,
253    op: ScalarBinaryOp,
254) -> Result<()> {
255    dispatch_real_scalar_type!(T, Concrete, {
256        let lhs = cast_strided_view!(lhs, T, Concrete);
257        let rhs = cast_strided_view!(rhs, T, Concrete);
258        let output = cast_strided_view_mut!(output, T, Concrete);
259        let alpha = cast_scalar_value!(alpha, T, Concrete);
260        let beta = cast_scalar_value!(beta, T, Concrete);
261        return execute_scalar_binary_real::<Concrete>(alpha, lhs, rhs, beta, output, op);
262    });
263    dispatch_complex_scalar_type!(T, Concrete, {
264        let lhs = cast_strided_view!(lhs, T, Concrete);
265        let rhs = cast_strided_view!(rhs, T, Concrete);
266        let output = cast_strided_view_mut!(output, T, Concrete);
267        let alpha = cast_scalar_value!(alpha, T, Concrete);
268        let beta = cast_scalar_value!(beta, T, Concrete);
269        return execute_scalar_binary_complex::<Concrete>(alpha, lhs, rhs, beta, output, op);
270    });
271
272    Err(Error::InvalidArgument(format!(
273        "scalar binary operation {op:?} is not supported for {}",
274        std::any::type_name::<T>()
275    )))
276}
277
278fn execute_scalar_ternary<T: Scalar + 'static>(
279    alpha: T,
280    cond: &strided_view::StridedView<T>,
281    on_true: &strided_view::StridedView<T>,
282    on_false: &strided_view::StridedView<T>,
283    beta: T,
284    output: &mut strided_view::StridedViewMut<T>,
285    op: ScalarTernaryOp,
286) -> Result<()> {
287    dispatch_real_scalar_type!(T, Concrete, {
288        let cond = cast_strided_view!(cond, T, Concrete);
289        let on_true = cast_strided_view!(on_true, T, Concrete);
290        let on_false = cast_strided_view!(on_false, T, Concrete);
291        let output = cast_strided_view_mut!(output, T, Concrete);
292        let alpha = cast_scalar_value!(alpha, T, Concrete);
293        let beta = cast_scalar_value!(beta, T, Concrete);
294        return execute_scalar_ternary_real::<Concrete>(
295            alpha, cond, on_true, on_false, beta, output, op,
296        );
297    });
298
299    Err(Error::InvalidArgument(format!(
300        "scalar ternary operation {op:?} is not supported for {}",
301        std::any::type_name::<T>()
302    )))
303}
304
305fn execute_scalar_reduction<T: Scalar + 'static>(
306    alpha: T,
307    input: &strided_view::StridedView<T>,
308    beta: T,
309    output: &mut strided_view::StridedViewMut<T>,
310    reduced_axes: &[usize],
311    op: ScalarReductionOp,
312) -> Result<()> {
313    match op {
314        ScalarReductionOp::Sum => {
315            dispatch_standard_scalar_type!(T, Concrete, {
316                let input = cast_strided_view!(input, T, Concrete);
317                let output = cast_strided_view_mut!(output, T, Concrete);
318                let alpha = cast_scalar_value!(alpha, T, Concrete);
319                let beta = cast_scalar_value!(beta, T, Concrete);
320                return execute_sum_reduction::<Concrete>(alpha, input, beta, output, reduced_axes);
321            });
322        }
323        ScalarReductionOp::Prod => {
324            dispatch_standard_scalar_type!(T, Concrete, {
325                let input = cast_strided_view!(input, T, Concrete);
326                let output = cast_strided_view_mut!(output, T, Concrete);
327                let alpha = cast_scalar_value!(alpha, T, Concrete);
328                let beta = cast_scalar_value!(beta, T, Concrete);
329                return execute_prod_reduction::<Concrete>(
330                    alpha,
331                    input,
332                    beta,
333                    output,
334                    reduced_axes,
335                );
336            });
337        }
338        ScalarReductionOp::Mean => {
339            dispatch_standard_scalar_type!(T, Concrete, {
340                let input = cast_strided_view!(input, T, Concrete);
341                let output = cast_strided_view_mut!(output, T, Concrete);
342                let alpha = cast_scalar_value!(alpha, T, Concrete);
343                let beta = cast_scalar_value!(beta, T, Concrete);
344                return execute_mean_reduction::<Concrete>(
345                    alpha,
346                    input,
347                    beta,
348                    output,
349                    reduced_axes,
350                );
351            });
352        }
353        ScalarReductionOp::Max => {
354            dispatch_real_scalar_type!(T, Concrete, {
355                let input = cast_strided_view!(input, T, Concrete);
356                let output = cast_strided_view_mut!(output, T, Concrete);
357                let alpha = cast_scalar_value!(alpha, T, Concrete);
358                let beta = cast_scalar_value!(beta, T, Concrete);
359                return execute_extrema_reduction(alpha, input, beta, output, reduced_axes, true);
360            });
361        }
362        ScalarReductionOp::Min => {
363            dispatch_real_scalar_type!(T, Concrete, {
364                let input = cast_strided_view!(input, T, Concrete);
365                let output = cast_strided_view_mut!(output, T, Concrete);
366                let alpha = cast_scalar_value!(alpha, T, Concrete);
367                let beta = cast_scalar_value!(beta, T, Concrete);
368                return execute_extrema_reduction(alpha, input, beta, output, reduced_axes, false);
369            });
370        }
371    }
372
373    Err(Error::InvalidArgument(format!(
374        "scalar reduction {op:?} is not supported for {}",
375        std::any::type_name::<T>()
376    )))
377}
378
379impl<S: Scalar + 'static> TensorScalarPrims<Standard<S>> for CpuBackend {
380    type Plan = CpuScalarPlan;
381    type Context = CpuContext;
382
383    fn plan(
384        _ctx: &mut Self::Context,
385        desc: &ScalarPrimsDescriptor,
386        shapes: &[&[usize]],
387    ) -> Result<Self::Plan> {
388        match desc {
389            ScalarPrimsDescriptor::PointwiseUnary { op } => {
390                validate_pointwise_shapes(shapes, 1, "ScalarPointwiseUnary")?;
391                if !supports_scalar_unary::<S>(*op) {
392                    return Err(Error::InvalidArgument(format!(
393                        "scalar unary operation {op:?} is not supported on CpuBackend for {}",
394                        std::any::type_name::<S>()
395                    )));
396                }
397                Ok(CpuScalarPlan::PointwiseUnary { op: *op })
398            }
399            ScalarPrimsDescriptor::PointwiseBinary { op } => {
400                validate_pointwise_shapes(shapes, 2, "ScalarPointwiseBinary")?;
401                if !supports_scalar_binary::<S>(*op) {
402                    return Err(Error::InvalidArgument(format!(
403                        "scalar binary operation {op:?} is not supported on CpuBackend for {}",
404                        std::any::type_name::<S>()
405                    )));
406                }
407                Ok(CpuScalarPlan::PointwiseBinary { op: *op })
408            }
409            ScalarPrimsDescriptor::PointwiseTernary { op } => {
410                validate_pointwise_shapes(shapes, 3, "ScalarPointwiseTernary")?;
411                if !supports_scalar_ternary::<S>(*op) {
412                    return Err(Error::InvalidArgument(format!(
413                        "scalar ternary operation {op:?} is not supported on CpuBackend for {}",
414                        std::any::type_name::<S>()
415                    )));
416                }
417                Ok(CpuScalarPlan::PointwiseTernary { op: *op })
418            }
419            ScalarPrimsDescriptor::Reduction {
420                modes_a,
421                modes_c,
422                op,
423            } => {
424                if !supports_scalar_reduction::<S>(*op) {
425                    return Err(Error::InvalidArgument(format!(
426                        "scalar reduction {op:?} is not supported on CpuBackend for {}",
427                        std::any::type_name::<S>()
428                    )));
429                }
430                let spec = plan_reduction(modes_a, modes_c, shapes, "ScalarReduction")?;
431                let _ = spec.reduced_total;
432                Ok(CpuScalarPlan::Reduction {
433                    reduced_axes: spec.reduced_axes,
434                    op: *op,
435                })
436            }
437        }
438    }
439
440    fn execute(
441        _ctx: &mut Self::Context,
442        plan: &Self::Plan,
443        alpha: S,
444        inputs: &[&Tensor<S>],
445        beta: S,
446        output: &mut Tensor<S>,
447    ) -> Result<()> {
448        let views: Vec<_> = inputs
449            .iter()
450            .map(|tensor| tensor_to_view(tensor))
451            .collect::<Result<_>>()?;
452        let view_refs: Vec<_> = views.iter().collect();
453        let mut out_view = tensor_to_view_mut(output)?;
454
455        match plan {
456            CpuScalarPlan::PointwiseUnary { op } => {
457                validate_execute_inputs(inputs, 1, "ScalarPointwiseUnary")?;
458                execute_scalar_unary(alpha, view_refs[0], beta, &mut out_view, *op)
459            }
460            CpuScalarPlan::PointwiseBinary { op } => {
461                validate_execute_inputs(inputs, 2, "ScalarPointwiseBinary")?;
462                execute_scalar_binary(alpha, view_refs[0], view_refs[1], beta, &mut out_view, *op)
463            }
464            CpuScalarPlan::PointwiseTernary { op } => {
465                validate_execute_inputs(inputs, 3, "ScalarPointwiseTernary")?;
466                execute_scalar_ternary(
467                    alpha,
468                    view_refs[0],
469                    view_refs[1],
470                    view_refs[2],
471                    beta,
472                    &mut out_view,
473                    *op,
474                )
475            }
476            CpuScalarPlan::Reduction { reduced_axes, op } => {
477                validate_execute_inputs(inputs, 1, "ScalarReduction")?;
478                execute_scalar_reduction(
479                    alpha,
480                    view_refs[0],
481                    beta,
482                    &mut out_view,
483                    reduced_axes,
484                    *op,
485                )
486            }
487        }
488    }
489
490    fn has_scalar_support(desc: ScalarPrimsDescriptor) -> bool {
491        match desc {
492            ScalarPrimsDescriptor::PointwiseUnary { op } => supports_scalar_unary::<S>(op),
493            ScalarPrimsDescriptor::PointwiseBinary { op } => supports_scalar_binary::<S>(op),
494            ScalarPrimsDescriptor::PointwiseTernary { op } => supports_scalar_ternary::<S>(op),
495            ScalarPrimsDescriptor::Reduction { op, .. } => supports_scalar_reduction::<S>(op),
496        }
497    }
498}