1use num_complex::ComplexFloat;
2use num_traits::{Float, One, Zero};
3use tenferro_algebra::Scalar;
4use tenferro_device::{Error, Result};
5use tenferro_tensor::{MemoryOrder, Tensor};
6
7use crate::cpu::common::{plan_reduction, CpuScalarValue};
8use crate::cpu::family_reduction::{
9 execute_extrema_reduction, execute_mean_reduction, execute_prod_reduction,
10 execute_sum_reduction,
11};
12use crate::cpu::{tensor_to_view, tensor_to_view_mut};
13use crate::{
14 validate_execute_inputs, validate_shape_count, validate_shape_eq, ComplexRealPrimsDescriptor,
15 ComplexRealUnaryOp, CpuBackend, CpuContext, ScalarReductionOp, TensorComplexRealPrims,
16};
17
18#[derive(Debug, Clone, PartialEq, Eq)]
27pub enum CpuComplexRealPlan {
28 PointwiseUnary {
29 op: ComplexRealUnaryOp,
30 },
31 Reduction {
32 unary_op: ComplexRealUnaryOp,
33 reduction_op: ScalarReductionOp,
34 reduced_axes: Vec<usize>,
35 },
36}
37
38fn supports_complex_real_unary(op: ComplexRealUnaryOp) -> bool {
39 matches!(
40 op,
41 ComplexRealUnaryOp::Abs | ComplexRealUnaryOp::Real | ComplexRealUnaryOp::Imag
42 )
43}
44
45fn execute_complex_real_unary_typed<Input>(
46 alpha: Input::Real,
47 input: &strided_view::StridedView<Input>,
48 beta: Input::Real,
49 output: &mut strided_view::StridedViewMut<Input::Real>,
50 op: ComplexRealUnaryOp,
51) -> Result<()>
52where
53 Input: ComplexFloat + Scalar,
54 Input::Real: Scalar + Float,
55{
56 match op {
57 ComplexRealUnaryOp::Abs => {
58 let dims = output.dims().to_vec();
59 crate::for_each_index(&dims, |idx| {
60 let mapped = input.get(idx).abs();
61 let value = alpha * mapped;
62 if beta == Input::Real::zero() {
63 output.set(idx, value);
64 } else {
65 output.set(idx, value + beta * output.get(idx));
66 }
67 });
68 Ok(())
69 }
70 ComplexRealUnaryOp::Real => {
71 let dims = output.dims().to_vec();
72 crate::for_each_index(&dims, |idx| {
73 let mapped = input.get(idx).re();
74 let value = alpha * mapped;
75 if beta == Input::Real::zero() {
76 output.set(idx, value);
77 } else {
78 output.set(idx, value + beta * output.get(idx));
79 }
80 });
81 Ok(())
82 }
83 ComplexRealUnaryOp::Imag => {
84 let dims = output.dims().to_vec();
85 crate::for_each_index(&dims, |idx| {
86 let mapped = input.get(idx).im();
87 let value = alpha * mapped;
88 if beta == Input::Real::zero() {
89 output.set(idx, value);
90 } else {
91 output.set(idx, value + beta * output.get(idx));
92 }
93 });
94 Ok(())
95 }
96 }
97}
98
99fn plan_complex_real_unary<Input>(
100 desc: &ComplexRealPrimsDescriptor,
101 shapes: &[&[usize]],
102) -> Result<CpuComplexRealPlan>
103where
104 Input: ComplexFloat + Scalar,
105 Input::Real: Scalar + Float,
106{
107 validate_shape_count(shapes, 2, "CpuComplexRealPointwiseUnary")?;
108 validate_shape_eq(shapes[0], shapes[1], "CpuComplexRealPointwiseUnary")?;
109 match desc {
110 ComplexRealPrimsDescriptor::PointwiseUnary { op } => {
111 if !supports_complex_real_unary(*op) {
112 return Err(Error::InvalidArgument(format!(
113 "complex-real unary operation {op:?} is not supported on CpuBackend for {}",
114 std::any::type_name::<Input>()
115 )));
116 }
117 Ok(CpuComplexRealPlan::PointwiseUnary { op: *op })
118 }
119 ComplexRealPrimsDescriptor::Reduction { .. } => Err(Error::InvalidArgument(
120 "expected complex-real unary descriptor".into(),
121 )),
122 }
123}
124
125fn plan_complex_real_reduction<Input>(
126 desc: &ComplexRealPrimsDescriptor,
127 shapes: &[&[usize]],
128) -> Result<CpuComplexRealPlan>
129where
130 Input: ComplexFloat + Scalar,
131 Input::Real: Scalar + Float,
132{
133 match desc {
134 ComplexRealPrimsDescriptor::Reduction {
135 modes_a,
136 modes_c,
137 unary_op,
138 reduction_op,
139 } => {
140 if !supports_complex_real_unary(*unary_op) {
141 return Err(Error::InvalidArgument(format!(
142 "complex-real unary operation {unary_op:?} is not supported on CpuBackend for {}",
143 std::any::type_name::<Input>()
144 )));
145 }
146 let spec = plan_reduction(modes_a, modes_c, shapes, "CpuComplexRealReduction")?;
147 Ok(CpuComplexRealPlan::Reduction {
148 unary_op: *unary_op,
149 reduction_op: *reduction_op,
150 reduced_axes: spec.reduced_axes,
151 })
152 }
153 _ => Err(Error::InvalidArgument(
154 "expected complex-real reduction descriptor".into(),
155 )),
156 }
157}
158
159fn execute_complex_real_unary<Input>(
160 plan: &CpuComplexRealPlan,
161 alpha: Input::Real,
162 inputs: &[&Tensor<Input>],
163 beta: Input::Real,
164 output: &mut Tensor<Input::Real>,
165) -> Result<()>
166where
167 Input: ComplexFloat + Scalar + 'static,
168 Input::Real: CpuScalarValue + Float,
169{
170 validate_execute_inputs(inputs, 1, "CpuComplexRealPointwiseUnary")?;
171 let input = tensor_to_view(inputs[0])?;
172 let mut output = tensor_to_view_mut(output)?;
173
174 match plan {
175 CpuComplexRealPlan::PointwiseUnary { op } => {
176 execute_complex_real_unary_typed::<Input>(alpha, &input, beta, &mut output, *op)
177 }
178 CpuComplexRealPlan::Reduction {
179 unary_op,
180 reduction_op,
181 reduced_axes,
182 } => {
183 let input_space = inputs[0].logical_memory_space();
184 let mut temp = Tensor::<Input::Real>::zeros(
185 inputs[0].dims(),
186 input_space,
187 MemoryOrder::ColumnMajor,
188 )?;
189 {
190 let mut temp_view = tensor_to_view_mut(&mut temp)?;
191 execute_complex_real_unary_typed::<Input>(
192 Input::Real::one(),
193 &input,
194 Input::Real::zero(),
195 &mut temp_view,
196 *unary_op,
197 )?;
198 }
199
200 let temp_view = tensor_to_view(&temp)?;
201 match reduction_op {
202 ScalarReductionOp::Sum => {
203 execute_sum_reduction(alpha, &temp_view, beta, &mut output, reduced_axes)
204 }
205 ScalarReductionOp::Prod => {
206 execute_prod_reduction(alpha, &temp_view, beta, &mut output, reduced_axes)
207 }
208 ScalarReductionOp::Mean => {
209 execute_mean_reduction(alpha, &temp_view, beta, &mut output, reduced_axes)
210 }
211 ScalarReductionOp::Max => execute_extrema_reduction(
212 alpha,
213 &temp_view,
214 beta,
215 &mut output,
216 reduced_axes,
217 true,
218 ),
219 ScalarReductionOp::Min => execute_extrema_reduction(
220 alpha,
221 &temp_view,
222 beta,
223 &mut output,
224 reduced_axes,
225 false,
226 ),
227 }
228 }
229 }
230}
231
232impl<Input> TensorComplexRealPrims<Input> for CpuBackend
233where
234 Input: ComplexFloat + Scalar + 'static,
235 Input::Real: CpuScalarValue + Float,
236{
237 type Real = Input::Real;
238 type Plan = CpuComplexRealPlan;
239 type Context = CpuContext;
240
241 fn plan(
242 _ctx: &mut Self::Context,
243 desc: &ComplexRealPrimsDescriptor,
244 shapes: &[&[usize]],
245 ) -> Result<Self::Plan> {
246 match desc {
247 ComplexRealPrimsDescriptor::PointwiseUnary { .. } => {
248 plan_complex_real_unary::<Input>(desc, shapes)
249 }
250 ComplexRealPrimsDescriptor::Reduction { .. } => {
251 plan_complex_real_reduction::<Input>(desc, shapes)
252 }
253 }
254 }
255
256 fn execute(
257 _ctx: &mut Self::Context,
258 plan: &Self::Plan,
259 alpha: Input::Real,
260 inputs: &[&Tensor<Input>],
261 beta: Input::Real,
262 output: &mut Tensor<Self::Real>,
263 ) -> Result<()> {
264 execute_complex_real_unary::<Input>(plan, alpha, inputs, beta, output)
265 }
266
267 fn has_complex_real_support(desc: ComplexRealPrimsDescriptor) -> bool {
268 matches!(
269 desc,
270 ComplexRealPrimsDescriptor::PointwiseUnary {
271 op: ComplexRealUnaryOp::Abs | ComplexRealUnaryOp::Real | ComplexRealUnaryOp::Imag
272 } | ComplexRealPrimsDescriptor::Reduction {
273 unary_op: ComplexRealUnaryOp::Abs
274 | ComplexRealUnaryOp::Real
275 | ComplexRealUnaryOp::Imag,
276 reduction_op: ScalarReductionOp::Sum
277 | ScalarReductionOp::Prod
278 | ScalarReductionOp::Mean
279 | ScalarReductionOp::Max
280 | ScalarReductionOp::Min,
281 ..
282 }
283 )
284 }
285}