tenferro_prims/cpu/
analytic.rs

1use num_complex::ComplexFloat;
2use num_traits::Float;
3use tenferro_algebra::{Scalar, Standard};
4use tenferro_device::{Error, Result};
5use tenferro_tensor::Tensor;
6
7use crate::cpu::common::{
8    execute_binary_map, execute_unary_map, is_supported_ordered_real_type,
9    is_supported_scalar_type, plan_reduction, validate_pointwise_shapes, ComplexCpuScalarValue,
10    CpuScalarValue, ReductionPlanSpec,
11};
12use crate::cpu::family_reduction::{execute_std_reduction, execute_variance_reduction};
13use crate::cpu::{tensor_to_view, tensor_to_view_mut};
14use crate::infra::typed_dispatch::{
15    cast_scalar_value, cast_strided_view, cast_strided_view_mut, dispatch_complex_scalar_type,
16    dispatch_real_scalar_type, dispatch_standard_scalar_type,
17};
18use crate::{
19    validate_execute_inputs, AnalyticBinaryOp, AnalyticPrimsDescriptor, AnalyticReductionOp,
20    AnalyticUnaryOp, CpuBackend, CpuContext, TensorAnalyticPrims,
21};
22
23/// CPU execution plan for the analytic protocol family.
24///
25/// # Examples
26///
27/// ```ignore
28/// use tenferro_prims::CpuAnalyticPlan;
29/// let _ = std::mem::size_of::<CpuAnalyticPlan>();
30/// ```
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub enum CpuAnalyticPlan {
33    PointwiseUnary {
34        op: AnalyticUnaryOp,
35    },
36    PointwiseBinary {
37        op: AnalyticBinaryOp,
38    },
39    Reduction {
40        reduced_axes: Vec<usize>,
41        op: AnalyticReductionOp,
42    },
43}
44
45fn supports_analytic_unary<S: Scalar + 'static>(op: AnalyticUnaryOp) -> bool {
46    match op {
47        AnalyticUnaryOp::Ceil => is_supported_ordered_real_type::<S>(),
48        _ => {
49            is_supported_scalar_type::<S>()
50                && matches!(
51                    op,
52                    AnalyticUnaryOp::Sqrt
53                        | AnalyticUnaryOp::Rsqrt
54                        | AnalyticUnaryOp::Exp
55                        | AnalyticUnaryOp::Expm1
56                        | AnalyticUnaryOp::Log
57                        | AnalyticUnaryOp::Log1p
58                        | AnalyticUnaryOp::Sin
59                        | AnalyticUnaryOp::Cos
60                        | AnalyticUnaryOp::Tan
61                        | AnalyticUnaryOp::Tanh
62                        | AnalyticUnaryOp::Asin
63                        | AnalyticUnaryOp::Acos
64                        | AnalyticUnaryOp::Atan
65                        | AnalyticUnaryOp::Sinh
66                        | AnalyticUnaryOp::Cosh
67                        | AnalyticUnaryOp::Asinh
68                        | AnalyticUnaryOp::Acosh
69                        | AnalyticUnaryOp::Atanh
70                )
71        }
72    }
73}
74
75fn supports_analytic_binary<S: Scalar + 'static>(op: AnalyticBinaryOp) -> bool {
76    match op {
77        AnalyticBinaryOp::Pow | AnalyticBinaryOp::Xlogy => is_supported_scalar_type::<S>(),
78        AnalyticBinaryOp::Atan2 | AnalyticBinaryOp::Hypot => is_supported_ordered_real_type::<S>(),
79    }
80}
81
82fn supports_analytic_reduction<S: Scalar + 'static>(op: AnalyticReductionOp) -> bool {
83    match op {
84        AnalyticReductionOp::Var | AnalyticReductionOp::Std => {
85            is_supported_ordered_real_type::<S>()
86        }
87    }
88}
89
90fn execute_analytic_unary_typed<S: CpuScalarValue>(
91    alpha: S,
92    input: &strided_view::StridedView<S>,
93    beta: S,
94    output: &mut strided_view::StridedViewMut<S>,
95    op: AnalyticUnaryOp,
96) -> Result<()> {
97    match op {
98        AnalyticUnaryOp::Sqrt => execute_unary_map(alpha, input, beta, output, |x| x.sqrt()),
99        AnalyticUnaryOp::Rsqrt => {
100            execute_unary_map(alpha, input, beta, output, |x| S::one() / x.sqrt())
101        }
102        AnalyticUnaryOp::Exp => execute_unary_map(alpha, input, beta, output, |x| x.exp()),
103        AnalyticUnaryOp::Expm1 => {
104            execute_unary_map(alpha, input, beta, output, |x| x.exp() - S::one())
105        }
106        AnalyticUnaryOp::Ceil => {
107            dispatch_real_scalar_type!(S, Concrete, {
108                let input = cast_strided_view!(input, S, Concrete);
109                let output = cast_strided_view_mut!(output, S, Concrete);
110                let alpha = cast_scalar_value!(alpha, S, Concrete);
111                let beta = cast_scalar_value!(beta, S, Concrete);
112                return execute_unary_map(alpha, input, beta, output, |x| x.ceil());
113            });
114
115            Err(Error::InvalidArgument(format!(
116                "analytic unary operation {op:?} is not supported for {}",
117                std::any::type_name::<S>()
118            )))
119        }
120        AnalyticUnaryOp::Log => execute_unary_map(alpha, input, beta, output, |x| x.ln()),
121        AnalyticUnaryOp::Log1p => {
122            execute_unary_map(alpha, input, beta, output, |x| (x + S::one()).ln())
123        }
124        AnalyticUnaryOp::Sin => execute_unary_map(alpha, input, beta, output, |x| x.sin()),
125        AnalyticUnaryOp::Cos => execute_unary_map(alpha, input, beta, output, |x| x.cos()),
126        AnalyticUnaryOp::Tan => execute_unary_map(alpha, input, beta, output, |x| x.tan()),
127        AnalyticUnaryOp::Tanh => execute_unary_map(alpha, input, beta, output, |x| x.tanh()),
128        AnalyticUnaryOp::Asin => execute_unary_map(alpha, input, beta, output, |x| x.asin()),
129        AnalyticUnaryOp::Acos => execute_unary_map(alpha, input, beta, output, |x| x.acos()),
130        AnalyticUnaryOp::Atan => execute_unary_map(alpha, input, beta, output, |x| x.atan()),
131        AnalyticUnaryOp::Sinh => execute_unary_map(alpha, input, beta, output, |x| x.sinh()),
132        AnalyticUnaryOp::Cosh => execute_unary_map(alpha, input, beta, output, |x| x.cosh()),
133        AnalyticUnaryOp::Asinh => execute_unary_map(alpha, input, beta, output, |x| x.asinh()),
134        AnalyticUnaryOp::Acosh => execute_unary_map(alpha, input, beta, output, |x| x.acosh()),
135        AnalyticUnaryOp::Atanh => execute_unary_map(alpha, input, beta, output, |x| x.atanh()),
136    }
137}
138
139fn execute_analytic_binary_real<S: Float + CpuScalarValue>(
140    alpha: S,
141    lhs: &strided_view::StridedView<S>,
142    rhs: &strided_view::StridedView<S>,
143    beta: S,
144    output: &mut strided_view::StridedViewMut<S>,
145    op: AnalyticBinaryOp,
146) -> Result<()> {
147    match op {
148        AnalyticBinaryOp::Pow => {
149            execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| Float::powf(x, y))
150        }
151        AnalyticBinaryOp::Atan2 => {
152            execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x.atan2(y))
153        }
154        AnalyticBinaryOp::Hypot => {
155            execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x.hypot(y))
156        }
157        AnalyticBinaryOp::Xlogy => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
158            if x == S::zero() {
159                S::zero()
160            } else {
161                x * Float::ln(y)
162            }
163        }),
164    }
165}
166
167fn execute_analytic_binary_complex<S: ComplexCpuScalarValue>(
168    alpha: S,
169    lhs: &strided_view::StridedView<S>,
170    rhs: &strided_view::StridedView<S>,
171    beta: S,
172    output: &mut strided_view::StridedViewMut<S>,
173    op: AnalyticBinaryOp,
174) -> Result<()> {
175    match op {
176        AnalyticBinaryOp::Pow => {
177            execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x.pow_complex(y))
178        }
179        AnalyticBinaryOp::Xlogy => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
180            if x == S::zero() {
181                S::zero()
182            } else {
183                x * ComplexFloat::ln(y)
184            }
185        }),
186        _ => Err(Error::InvalidArgument(format!(
187            "analytic binary operation {op:?} requires ordered real scalars"
188        ))),
189    }
190}
191
192fn execute_analytic_unary<T: Scalar + 'static>(
193    alpha: T,
194    input: &strided_view::StridedView<T>,
195    beta: T,
196    output: &mut strided_view::StridedViewMut<T>,
197    op: AnalyticUnaryOp,
198) -> Result<()> {
199    dispatch_standard_scalar_type!(T, Concrete, {
200        let input = cast_strided_view!(input, T, Concrete);
201        let output = cast_strided_view_mut!(output, T, Concrete);
202        let alpha = cast_scalar_value!(alpha, T, Concrete);
203        let beta = cast_scalar_value!(beta, T, Concrete);
204        return execute_analytic_unary_typed(alpha, input, beta, output, op);
205    });
206
207    Err(Error::InvalidArgument(format!(
208        "analytic unary operation {op:?} is not supported for {}",
209        std::any::type_name::<T>()
210    )))
211}
212
213fn execute_analytic_binary<T: Scalar + 'static>(
214    alpha: T,
215    lhs: &strided_view::StridedView<T>,
216    rhs: &strided_view::StridedView<T>,
217    beta: T,
218    output: &mut strided_view::StridedViewMut<T>,
219    op: AnalyticBinaryOp,
220) -> Result<()> {
221    dispatch_real_scalar_type!(T, Concrete, {
222        let lhs = cast_strided_view!(lhs, T, Concrete);
223        let rhs = cast_strided_view!(rhs, T, Concrete);
224        let output = cast_strided_view_mut!(output, T, Concrete);
225        let alpha = cast_scalar_value!(alpha, T, Concrete);
226        let beta = cast_scalar_value!(beta, T, Concrete);
227        return execute_analytic_binary_real(alpha, lhs, rhs, beta, output, op);
228    });
229    dispatch_complex_scalar_type!(T, Concrete, {
230        let lhs = cast_strided_view!(lhs, T, Concrete);
231        let rhs = cast_strided_view!(rhs, T, Concrete);
232        let output = cast_strided_view_mut!(output, T, Concrete);
233        let alpha = cast_scalar_value!(alpha, T, Concrete);
234        let beta = cast_scalar_value!(beta, T, Concrete);
235        return execute_analytic_binary_complex(alpha, lhs, rhs, beta, output, op);
236    });
237
238    Err(Error::InvalidArgument(format!(
239        "analytic binary operation {op:?} is not supported for {}",
240        std::any::type_name::<T>()
241    )))
242}
243
244fn execute_analytic_reduction_real<S: Float + CpuScalarValue>(
245    alpha: S,
246    input: &strided_view::StridedView<S>,
247    beta: S,
248    output: &mut strided_view::StridedViewMut<S>,
249    reduced_axes: &[usize],
250    op: AnalyticReductionOp,
251) -> Result<()> {
252    match op {
253        AnalyticReductionOp::Var => {
254            execute_variance_reduction(alpha, input, beta, output, reduced_axes)
255        }
256        AnalyticReductionOp::Std => execute_std_reduction(alpha, input, beta, output, reduced_axes),
257    }
258}
259
260fn execute_analytic_reduction<T: Scalar + 'static>(
261    alpha: T,
262    input: &strided_view::StridedView<T>,
263    beta: T,
264    output: &mut strided_view::StridedViewMut<T>,
265    reduced_axes: &[usize],
266    op: AnalyticReductionOp,
267) -> Result<()> {
268    dispatch_real_scalar_type!(T, Concrete, {
269        let input = cast_strided_view!(input, T, Concrete);
270        let output = cast_strided_view_mut!(output, T, Concrete);
271        let alpha = cast_scalar_value!(alpha, T, Concrete);
272        let beta = cast_scalar_value!(beta, T, Concrete);
273        return execute_analytic_reduction_real(alpha, input, beta, output, reduced_axes, op);
274    });
275
276    Err(Error::InvalidArgument(format!(
277        "analytic reduction {op:?} is not supported for {}",
278        std::any::type_name::<T>()
279    )))
280}
281
282impl<S: Scalar + 'static> TensorAnalyticPrims<Standard<S>> for CpuBackend {
283    type Plan = CpuAnalyticPlan;
284    type Context = CpuContext;
285
286    fn plan(
287        _ctx: &mut Self::Context,
288        desc: &AnalyticPrimsDescriptor,
289        shapes: &[&[usize]],
290    ) -> Result<Self::Plan> {
291        match desc {
292            AnalyticPrimsDescriptor::PointwiseUnary { op } => {
293                validate_pointwise_shapes(shapes, 1, "AnalyticPointwiseUnary")?;
294                if !supports_analytic_unary::<S>(*op) {
295                    return Err(Error::InvalidArgument(format!(
296                        "analytic unary operation {op:?} is not supported on CpuBackend for {}",
297                        std::any::type_name::<S>()
298                    )));
299                }
300                Ok(CpuAnalyticPlan::PointwiseUnary { op: *op })
301            }
302            AnalyticPrimsDescriptor::PointwiseBinary { op } => {
303                validate_pointwise_shapes(shapes, 2, "AnalyticPointwiseBinary")?;
304                if !supports_analytic_binary::<S>(*op) {
305                    return Err(Error::InvalidArgument(format!(
306                        "analytic binary operation {op:?} is not supported on CpuBackend for {}",
307                        std::any::type_name::<S>()
308                    )));
309                }
310                Ok(CpuAnalyticPlan::PointwiseBinary { op: *op })
311            }
312            AnalyticPrimsDescriptor::Reduction {
313                modes_a,
314                modes_c,
315                op,
316            } => {
317                let ReductionPlanSpec { reduced_axes, .. } =
318                    plan_reduction(modes_a, modes_c, shapes, "AnalyticReduction")?;
319                if !supports_analytic_reduction::<S>(*op) {
320                    return Err(Error::InvalidArgument(format!(
321                        "analytic reduction {op:?} is not supported on CpuBackend for {}",
322                        std::any::type_name::<S>()
323                    )));
324                }
325                Ok(CpuAnalyticPlan::Reduction {
326                    reduced_axes,
327                    op: *op,
328                })
329            }
330        }
331    }
332
333    fn execute(
334        _ctx: &mut Self::Context,
335        plan: &Self::Plan,
336        alpha: S,
337        inputs: &[&Tensor<S>],
338        beta: S,
339        output: &mut Tensor<S>,
340    ) -> Result<()> {
341        let views: Vec<_> = inputs
342            .iter()
343            .map(|tensor| tensor_to_view(tensor))
344            .collect::<Result<_>>()?;
345        let view_refs: Vec<_> = views.iter().collect();
346        let mut out_view = tensor_to_view_mut(output)?;
347
348        match plan {
349            CpuAnalyticPlan::PointwiseUnary { op } => {
350                validate_execute_inputs(inputs, 1, "AnalyticPointwiseUnary")?;
351                execute_analytic_unary(alpha, view_refs[0], beta, &mut out_view, *op)
352            }
353            CpuAnalyticPlan::PointwiseBinary { op } => {
354                validate_execute_inputs(inputs, 2, "AnalyticPointwiseBinary")?;
355                execute_analytic_binary(alpha, view_refs[0], view_refs[1], beta, &mut out_view, *op)
356            }
357            CpuAnalyticPlan::Reduction { reduced_axes, op } => {
358                validate_execute_inputs(inputs, 1, "AnalyticReduction")?;
359                execute_analytic_reduction(
360                    alpha,
361                    view_refs[0],
362                    beta,
363                    &mut out_view,
364                    reduced_axes,
365                    *op,
366                )
367            }
368        }
369    }
370
371    fn has_analytic_support(desc: AnalyticPrimsDescriptor) -> bool {
372        match desc {
373            AnalyticPrimsDescriptor::PointwiseUnary { op } => supports_analytic_unary::<S>(op),
374            AnalyticPrimsDescriptor::PointwiseBinary { op } => supports_analytic_binary::<S>(op),
375            AnalyticPrimsDescriptor::Reduction { op, .. } => supports_analytic_reduction::<S>(op),
376        }
377    }
378}