1use num_complex::ComplexFloat;
2use num_traits::Float;
3use tenferro_algebra::{Scalar, Standard};
4use tenferro_device::{Error, Result};
5use tenferro_tensor::Tensor;
6
7use crate::cpu::common::{
8 execute_binary_map, execute_unary_map, is_supported_ordered_real_type,
9 is_supported_scalar_type, plan_reduction, validate_pointwise_shapes, ComplexCpuScalarValue,
10 CpuScalarValue, ReductionPlanSpec,
11};
12use crate::cpu::family_reduction::{execute_std_reduction, execute_variance_reduction};
13use crate::cpu::{tensor_to_view, tensor_to_view_mut};
14use crate::infra::typed_dispatch::{
15 cast_scalar_value, cast_strided_view, cast_strided_view_mut, dispatch_complex_scalar_type,
16 dispatch_real_scalar_type, dispatch_standard_scalar_type,
17};
18use crate::{
19 validate_execute_inputs, AnalyticBinaryOp, AnalyticPrimsDescriptor, AnalyticReductionOp,
20 AnalyticUnaryOp, CpuBackend, CpuContext, TensorAnalyticPrims,
21};
22
23#[derive(Debug, Clone, PartialEq, Eq)]
32pub enum CpuAnalyticPlan {
33 PointwiseUnary {
34 op: AnalyticUnaryOp,
35 },
36 PointwiseBinary {
37 op: AnalyticBinaryOp,
38 },
39 Reduction {
40 reduced_axes: Vec<usize>,
41 op: AnalyticReductionOp,
42 },
43}
44
45fn supports_analytic_unary<S: Scalar + 'static>(op: AnalyticUnaryOp) -> bool {
46 match op {
47 AnalyticUnaryOp::Ceil => is_supported_ordered_real_type::<S>(),
48 _ => {
49 is_supported_scalar_type::<S>()
50 && matches!(
51 op,
52 AnalyticUnaryOp::Sqrt
53 | AnalyticUnaryOp::Rsqrt
54 | AnalyticUnaryOp::Exp
55 | AnalyticUnaryOp::Expm1
56 | AnalyticUnaryOp::Log
57 | AnalyticUnaryOp::Log1p
58 | AnalyticUnaryOp::Sin
59 | AnalyticUnaryOp::Cos
60 | AnalyticUnaryOp::Tan
61 | AnalyticUnaryOp::Tanh
62 | AnalyticUnaryOp::Asin
63 | AnalyticUnaryOp::Acos
64 | AnalyticUnaryOp::Atan
65 | AnalyticUnaryOp::Sinh
66 | AnalyticUnaryOp::Cosh
67 | AnalyticUnaryOp::Asinh
68 | AnalyticUnaryOp::Acosh
69 | AnalyticUnaryOp::Atanh
70 )
71 }
72 }
73}
74
75fn supports_analytic_binary<S: Scalar + 'static>(op: AnalyticBinaryOp) -> bool {
76 match op {
77 AnalyticBinaryOp::Pow | AnalyticBinaryOp::Xlogy => is_supported_scalar_type::<S>(),
78 AnalyticBinaryOp::Atan2 | AnalyticBinaryOp::Hypot => is_supported_ordered_real_type::<S>(),
79 }
80}
81
82fn supports_analytic_reduction<S: Scalar + 'static>(op: AnalyticReductionOp) -> bool {
83 match op {
84 AnalyticReductionOp::Var | AnalyticReductionOp::Std => {
85 is_supported_ordered_real_type::<S>()
86 }
87 }
88}
89
90fn execute_analytic_unary_typed<S: CpuScalarValue>(
91 alpha: S,
92 input: &strided_view::StridedView<S>,
93 beta: S,
94 output: &mut strided_view::StridedViewMut<S>,
95 op: AnalyticUnaryOp,
96) -> Result<()> {
97 match op {
98 AnalyticUnaryOp::Sqrt => execute_unary_map(alpha, input, beta, output, |x| x.sqrt()),
99 AnalyticUnaryOp::Rsqrt => {
100 execute_unary_map(alpha, input, beta, output, |x| S::one() / x.sqrt())
101 }
102 AnalyticUnaryOp::Exp => execute_unary_map(alpha, input, beta, output, |x| x.exp()),
103 AnalyticUnaryOp::Expm1 => {
104 execute_unary_map(alpha, input, beta, output, |x| x.exp() - S::one())
105 }
106 AnalyticUnaryOp::Ceil => {
107 dispatch_real_scalar_type!(S, Concrete, {
108 let input = cast_strided_view!(input, S, Concrete);
109 let output = cast_strided_view_mut!(output, S, Concrete);
110 let alpha = cast_scalar_value!(alpha, S, Concrete);
111 let beta = cast_scalar_value!(beta, S, Concrete);
112 return execute_unary_map(alpha, input, beta, output, |x| x.ceil());
113 });
114
115 Err(Error::InvalidArgument(format!(
116 "analytic unary operation {op:?} is not supported for {}",
117 std::any::type_name::<S>()
118 )))
119 }
120 AnalyticUnaryOp::Log => execute_unary_map(alpha, input, beta, output, |x| x.ln()),
121 AnalyticUnaryOp::Log1p => {
122 execute_unary_map(alpha, input, beta, output, |x| (x + S::one()).ln())
123 }
124 AnalyticUnaryOp::Sin => execute_unary_map(alpha, input, beta, output, |x| x.sin()),
125 AnalyticUnaryOp::Cos => execute_unary_map(alpha, input, beta, output, |x| x.cos()),
126 AnalyticUnaryOp::Tan => execute_unary_map(alpha, input, beta, output, |x| x.tan()),
127 AnalyticUnaryOp::Tanh => execute_unary_map(alpha, input, beta, output, |x| x.tanh()),
128 AnalyticUnaryOp::Asin => execute_unary_map(alpha, input, beta, output, |x| x.asin()),
129 AnalyticUnaryOp::Acos => execute_unary_map(alpha, input, beta, output, |x| x.acos()),
130 AnalyticUnaryOp::Atan => execute_unary_map(alpha, input, beta, output, |x| x.atan()),
131 AnalyticUnaryOp::Sinh => execute_unary_map(alpha, input, beta, output, |x| x.sinh()),
132 AnalyticUnaryOp::Cosh => execute_unary_map(alpha, input, beta, output, |x| x.cosh()),
133 AnalyticUnaryOp::Asinh => execute_unary_map(alpha, input, beta, output, |x| x.asinh()),
134 AnalyticUnaryOp::Acosh => execute_unary_map(alpha, input, beta, output, |x| x.acosh()),
135 AnalyticUnaryOp::Atanh => execute_unary_map(alpha, input, beta, output, |x| x.atanh()),
136 }
137}
138
139fn execute_analytic_binary_real<S: Float + CpuScalarValue>(
140 alpha: S,
141 lhs: &strided_view::StridedView<S>,
142 rhs: &strided_view::StridedView<S>,
143 beta: S,
144 output: &mut strided_view::StridedViewMut<S>,
145 op: AnalyticBinaryOp,
146) -> Result<()> {
147 match op {
148 AnalyticBinaryOp::Pow => {
149 execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| Float::powf(x, y))
150 }
151 AnalyticBinaryOp::Atan2 => {
152 execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x.atan2(y))
153 }
154 AnalyticBinaryOp::Hypot => {
155 execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x.hypot(y))
156 }
157 AnalyticBinaryOp::Xlogy => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
158 if x == S::zero() {
159 S::zero()
160 } else {
161 x * Float::ln(y)
162 }
163 }),
164 }
165}
166
167fn execute_analytic_binary_complex<S: ComplexCpuScalarValue>(
168 alpha: S,
169 lhs: &strided_view::StridedView<S>,
170 rhs: &strided_view::StridedView<S>,
171 beta: S,
172 output: &mut strided_view::StridedViewMut<S>,
173 op: AnalyticBinaryOp,
174) -> Result<()> {
175 match op {
176 AnalyticBinaryOp::Pow => {
177 execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| x.pow_complex(y))
178 }
179 AnalyticBinaryOp::Xlogy => execute_binary_map(alpha, lhs, rhs, beta, output, |x, y| {
180 if x == S::zero() {
181 S::zero()
182 } else {
183 x * ComplexFloat::ln(y)
184 }
185 }),
186 _ => Err(Error::InvalidArgument(format!(
187 "analytic binary operation {op:?} requires ordered real scalars"
188 ))),
189 }
190}
191
192fn execute_analytic_unary<T: Scalar + 'static>(
193 alpha: T,
194 input: &strided_view::StridedView<T>,
195 beta: T,
196 output: &mut strided_view::StridedViewMut<T>,
197 op: AnalyticUnaryOp,
198) -> Result<()> {
199 dispatch_standard_scalar_type!(T, Concrete, {
200 let input = cast_strided_view!(input, T, Concrete);
201 let output = cast_strided_view_mut!(output, T, Concrete);
202 let alpha = cast_scalar_value!(alpha, T, Concrete);
203 let beta = cast_scalar_value!(beta, T, Concrete);
204 return execute_analytic_unary_typed(alpha, input, beta, output, op);
205 });
206
207 Err(Error::InvalidArgument(format!(
208 "analytic unary operation {op:?} is not supported for {}",
209 std::any::type_name::<T>()
210 )))
211}
212
213fn execute_analytic_binary<T: Scalar + 'static>(
214 alpha: T,
215 lhs: &strided_view::StridedView<T>,
216 rhs: &strided_view::StridedView<T>,
217 beta: T,
218 output: &mut strided_view::StridedViewMut<T>,
219 op: AnalyticBinaryOp,
220) -> Result<()> {
221 dispatch_real_scalar_type!(T, Concrete, {
222 let lhs = cast_strided_view!(lhs, T, Concrete);
223 let rhs = cast_strided_view!(rhs, T, Concrete);
224 let output = cast_strided_view_mut!(output, T, Concrete);
225 let alpha = cast_scalar_value!(alpha, T, Concrete);
226 let beta = cast_scalar_value!(beta, T, Concrete);
227 return execute_analytic_binary_real(alpha, lhs, rhs, beta, output, op);
228 });
229 dispatch_complex_scalar_type!(T, Concrete, {
230 let lhs = cast_strided_view!(lhs, T, Concrete);
231 let rhs = cast_strided_view!(rhs, T, Concrete);
232 let output = cast_strided_view_mut!(output, T, Concrete);
233 let alpha = cast_scalar_value!(alpha, T, Concrete);
234 let beta = cast_scalar_value!(beta, T, Concrete);
235 return execute_analytic_binary_complex(alpha, lhs, rhs, beta, output, op);
236 });
237
238 Err(Error::InvalidArgument(format!(
239 "analytic binary operation {op:?} is not supported for {}",
240 std::any::type_name::<T>()
241 )))
242}
243
244fn execute_analytic_reduction_real<S: Float + CpuScalarValue>(
245 alpha: S,
246 input: &strided_view::StridedView<S>,
247 beta: S,
248 output: &mut strided_view::StridedViewMut<S>,
249 reduced_axes: &[usize],
250 op: AnalyticReductionOp,
251) -> Result<()> {
252 match op {
253 AnalyticReductionOp::Var => {
254 execute_variance_reduction(alpha, input, beta, output, reduced_axes)
255 }
256 AnalyticReductionOp::Std => execute_std_reduction(alpha, input, beta, output, reduced_axes),
257 }
258}
259
260fn execute_analytic_reduction<T: Scalar + 'static>(
261 alpha: T,
262 input: &strided_view::StridedView<T>,
263 beta: T,
264 output: &mut strided_view::StridedViewMut<T>,
265 reduced_axes: &[usize],
266 op: AnalyticReductionOp,
267) -> Result<()> {
268 dispatch_real_scalar_type!(T, Concrete, {
269 let input = cast_strided_view!(input, T, Concrete);
270 let output = cast_strided_view_mut!(output, T, Concrete);
271 let alpha = cast_scalar_value!(alpha, T, Concrete);
272 let beta = cast_scalar_value!(beta, T, Concrete);
273 return execute_analytic_reduction_real(alpha, input, beta, output, reduced_axes, op);
274 });
275
276 Err(Error::InvalidArgument(format!(
277 "analytic reduction {op:?} is not supported for {}",
278 std::any::type_name::<T>()
279 )))
280}
281
282impl<S: Scalar + 'static> TensorAnalyticPrims<Standard<S>> for CpuBackend {
283 type Plan = CpuAnalyticPlan;
284 type Context = CpuContext;
285
286 fn plan(
287 _ctx: &mut Self::Context,
288 desc: &AnalyticPrimsDescriptor,
289 shapes: &[&[usize]],
290 ) -> Result<Self::Plan> {
291 match desc {
292 AnalyticPrimsDescriptor::PointwiseUnary { op } => {
293 validate_pointwise_shapes(shapes, 1, "AnalyticPointwiseUnary")?;
294 if !supports_analytic_unary::<S>(*op) {
295 return Err(Error::InvalidArgument(format!(
296 "analytic unary operation {op:?} is not supported on CpuBackend for {}",
297 std::any::type_name::<S>()
298 )));
299 }
300 Ok(CpuAnalyticPlan::PointwiseUnary { op: *op })
301 }
302 AnalyticPrimsDescriptor::PointwiseBinary { op } => {
303 validate_pointwise_shapes(shapes, 2, "AnalyticPointwiseBinary")?;
304 if !supports_analytic_binary::<S>(*op) {
305 return Err(Error::InvalidArgument(format!(
306 "analytic binary operation {op:?} is not supported on CpuBackend for {}",
307 std::any::type_name::<S>()
308 )));
309 }
310 Ok(CpuAnalyticPlan::PointwiseBinary { op: *op })
311 }
312 AnalyticPrimsDescriptor::Reduction {
313 modes_a,
314 modes_c,
315 op,
316 } => {
317 let ReductionPlanSpec { reduced_axes, .. } =
318 plan_reduction(modes_a, modes_c, shapes, "AnalyticReduction")?;
319 if !supports_analytic_reduction::<S>(*op) {
320 return Err(Error::InvalidArgument(format!(
321 "analytic reduction {op:?} is not supported on CpuBackend for {}",
322 std::any::type_name::<S>()
323 )));
324 }
325 Ok(CpuAnalyticPlan::Reduction {
326 reduced_axes,
327 op: *op,
328 })
329 }
330 }
331 }
332
333 fn execute(
334 _ctx: &mut Self::Context,
335 plan: &Self::Plan,
336 alpha: S,
337 inputs: &[&Tensor<S>],
338 beta: S,
339 output: &mut Tensor<S>,
340 ) -> Result<()> {
341 let views: Vec<_> = inputs
342 .iter()
343 .map(|tensor| tensor_to_view(tensor))
344 .collect::<Result<_>>()?;
345 let view_refs: Vec<_> = views.iter().collect();
346 let mut out_view = tensor_to_view_mut(output)?;
347
348 match plan {
349 CpuAnalyticPlan::PointwiseUnary { op } => {
350 validate_execute_inputs(inputs, 1, "AnalyticPointwiseUnary")?;
351 execute_analytic_unary(alpha, view_refs[0], beta, &mut out_view, *op)
352 }
353 CpuAnalyticPlan::PointwiseBinary { op } => {
354 validate_execute_inputs(inputs, 2, "AnalyticPointwiseBinary")?;
355 execute_analytic_binary(alpha, view_refs[0], view_refs[1], beta, &mut out_view, *op)
356 }
357 CpuAnalyticPlan::Reduction { reduced_axes, op } => {
358 validate_execute_inputs(inputs, 1, "AnalyticReduction")?;
359 execute_analytic_reduction(
360 alpha,
361 view_refs[0],
362 beta,
363 &mut out_view,
364 reduced_axes,
365 *op,
366 )
367 }
368 }
369 }
370
371 fn has_analytic_support(desc: AnalyticPrimsDescriptor) -> bool {
372 match desc {
373 AnalyticPrimsDescriptor::PointwiseUnary { op } => supports_analytic_unary::<S>(op),
374 AnalyticPrimsDescriptor::PointwiseBinary { op } => supports_analytic_binary::<S>(op),
375 AnalyticPrimsDescriptor::Reduction { op, .. } => supports_analytic_reduction::<S>(op),
376 }
377 }
378}