tenferro_prims/cpu/complex_real.rs

1use num_complex::ComplexFloat;
2use num_traits::{Float, One, Zero};
3use tenferro_algebra::Scalar;
4use tenferro_device::{Error, Result};
5use tenferro_tensor::{MemoryOrder, Tensor};
6
7use crate::cpu::common::{plan_reduction, CpuScalarValue};
8use crate::cpu::family_reduction::{
9    execute_extrema_reduction, execute_mean_reduction, execute_prod_reduction,
10    execute_sum_reduction,
11};
12use crate::cpu::{tensor_to_view, tensor_to_view_mut};
13use crate::{
14    validate_execute_inputs, validate_shape_count, validate_shape_eq, ComplexRealPrimsDescriptor,
15    ComplexRealUnaryOp, CpuBackend, CpuContext, ScalarReductionOp, TensorComplexRealPrims,
16};
17
18/// CPU execution plan for the complex-to-real unary protocol family.
19///
20/// # Examples
21///
22/// ```ignore
23/// use tenferro_prims::CpuComplexRealPlan;
24/// let _ = std::mem::size_of::<CpuComplexRealPlan>();
25/// ```
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub enum CpuComplexRealPlan {
28    PointwiseUnary {
29        op: ComplexRealUnaryOp,
30    },
31    Reduction {
32        unary_op: ComplexRealUnaryOp,
33        reduction_op: ScalarReductionOp,
34        reduced_axes: Vec<usize>,
35    },
36}
37
38fn supports_complex_real_unary(op: ComplexRealUnaryOp) -> bool {
39    matches!(
40        op,
41        ComplexRealUnaryOp::Abs | ComplexRealUnaryOp::Real | ComplexRealUnaryOp::Imag
42    )
43}
44
45fn execute_complex_real_unary_typed<Input>(
46    alpha: Input::Real,
47    input: &strided_view::StridedView<Input>,
48    beta: Input::Real,
49    output: &mut strided_view::StridedViewMut<Input::Real>,
50    op: ComplexRealUnaryOp,
51) -> Result<()>
52where
53    Input: ComplexFloat + Scalar,
54    Input::Real: Scalar + Float,
55{
56    match op {
57        ComplexRealUnaryOp::Abs => {
58            let dims = output.dims().to_vec();
59            crate::for_each_index(&dims, |idx| {
60                let mapped = input.get(idx).abs();
61                let value = alpha * mapped;
62                if beta == Input::Real::zero() {
63                    output.set(idx, value);
64                } else {
65                    output.set(idx, value + beta * output.get(idx));
66                }
67            });
68            Ok(())
69        }
70        ComplexRealUnaryOp::Real => {
71            let dims = output.dims().to_vec();
72            crate::for_each_index(&dims, |idx| {
73                let mapped = input.get(idx).re();
74                let value = alpha * mapped;
75                if beta == Input::Real::zero() {
76                    output.set(idx, value);
77                } else {
78                    output.set(idx, value + beta * output.get(idx));
79                }
80            });
81            Ok(())
82        }
83        ComplexRealUnaryOp::Imag => {
84            let dims = output.dims().to_vec();
85            crate::for_each_index(&dims, |idx| {
86                let mapped = input.get(idx).im();
87                let value = alpha * mapped;
88                if beta == Input::Real::zero() {
89                    output.set(idx, value);
90                } else {
91                    output.set(idx, value + beta * output.get(idx));
92                }
93            });
94            Ok(())
95        }
96    }
97}
98
99fn plan_complex_real_unary<Input>(
100    desc: &ComplexRealPrimsDescriptor,
101    shapes: &[&[usize]],
102) -> Result<CpuComplexRealPlan>
103where
104    Input: ComplexFloat + Scalar,
105    Input::Real: Scalar + Float,
106{
107    validate_shape_count(shapes, 2, "CpuComplexRealPointwiseUnary")?;
108    validate_shape_eq(shapes[0], shapes[1], "CpuComplexRealPointwiseUnary")?;
109    match desc {
110        ComplexRealPrimsDescriptor::PointwiseUnary { op } => {
111            if !supports_complex_real_unary(*op) {
112                return Err(Error::InvalidArgument(format!(
113                    "complex-real unary operation {op:?} is not supported on CpuBackend for {}",
114                    std::any::type_name::<Input>()
115                )));
116            }
117            Ok(CpuComplexRealPlan::PointwiseUnary { op: *op })
118        }
119        ComplexRealPrimsDescriptor::Reduction { .. } => Err(Error::InvalidArgument(
120            "expected complex-real unary descriptor".into(),
121        )),
122    }
123}
124
125fn plan_complex_real_reduction<Input>(
126    desc: &ComplexRealPrimsDescriptor,
127    shapes: &[&[usize]],
128) -> Result<CpuComplexRealPlan>
129where
130    Input: ComplexFloat + Scalar,
131    Input::Real: Scalar + Float,
132{
133    match desc {
134        ComplexRealPrimsDescriptor::Reduction {
135            modes_a,
136            modes_c,
137            unary_op,
138            reduction_op,
139        } => {
140            if !supports_complex_real_unary(*unary_op) {
141                return Err(Error::InvalidArgument(format!(
142                    "complex-real unary operation {unary_op:?} is not supported on CpuBackend for {}",
143                    std::any::type_name::<Input>()
144                )));
145            }
146            let spec = plan_reduction(modes_a, modes_c, shapes, "CpuComplexRealReduction")?;
147            Ok(CpuComplexRealPlan::Reduction {
148                unary_op: *unary_op,
149                reduction_op: *reduction_op,
150                reduced_axes: spec.reduced_axes,
151            })
152        }
153        _ => Err(Error::InvalidArgument(
154            "expected complex-real reduction descriptor".into(),
155        )),
156    }
157}
158
159fn execute_complex_real_unary<Input>(
160    plan: &CpuComplexRealPlan,
161    alpha: Input::Real,
162    inputs: &[&Tensor<Input>],
163    beta: Input::Real,
164    output: &mut Tensor<Input::Real>,
165) -> Result<()>
166where
167    Input: ComplexFloat + Scalar + 'static,
168    Input::Real: CpuScalarValue + Float,
169{
170    validate_execute_inputs(inputs, 1, "CpuComplexRealPointwiseUnary")?;
171    let input = tensor_to_view(inputs[0])?;
172    let mut output = tensor_to_view_mut(output)?;
173
174    match plan {
175        CpuComplexRealPlan::PointwiseUnary { op } => {
176            execute_complex_real_unary_typed::<Input>(alpha, &input, beta, &mut output, *op)
177        }
178        CpuComplexRealPlan::Reduction {
179            unary_op,
180            reduction_op,
181            reduced_axes,
182        } => {
183            let input_space = inputs[0].logical_memory_space();
184            let mut temp = Tensor::<Input::Real>::zeros(
185                inputs[0].dims(),
186                input_space,
187                MemoryOrder::ColumnMajor,
188            )?;
189            {
190                let mut temp_view = tensor_to_view_mut(&mut temp)?;
191                execute_complex_real_unary_typed::<Input>(
192                    Input::Real::one(),
193                    &input,
194                    Input::Real::zero(),
195                    &mut temp_view,
196                    *unary_op,
197                )?;
198            }
199
200            let temp_view = tensor_to_view(&temp)?;
201            match reduction_op {
202                ScalarReductionOp::Sum => {
203                    execute_sum_reduction(alpha, &temp_view, beta, &mut output, reduced_axes)
204                }
205                ScalarReductionOp::Prod => {
206                    execute_prod_reduction(alpha, &temp_view, beta, &mut output, reduced_axes)
207                }
208                ScalarReductionOp::Mean => {
209                    execute_mean_reduction(alpha, &temp_view, beta, &mut output, reduced_axes)
210                }
211                ScalarReductionOp::Max => execute_extrema_reduction(
212                    alpha,
213                    &temp_view,
214                    beta,
215                    &mut output,
216                    reduced_axes,
217                    true,
218                ),
219                ScalarReductionOp::Min => execute_extrema_reduction(
220                    alpha,
221                    &temp_view,
222                    beta,
223                    &mut output,
224                    reduced_axes,
225                    false,
226                ),
227            }
228        }
229    }
230}
231
232impl<Input> TensorComplexRealPrims<Input> for CpuBackend
233where
234    Input: ComplexFloat + Scalar + 'static,
235    Input::Real: CpuScalarValue + Float,
236{
237    type Real = Input::Real;
238    type Plan = CpuComplexRealPlan;
239    type Context = CpuContext;
240
241    fn plan(
242        _ctx: &mut Self::Context,
243        desc: &ComplexRealPrimsDescriptor,
244        shapes: &[&[usize]],
245    ) -> Result<Self::Plan> {
246        match desc {
247            ComplexRealPrimsDescriptor::PointwiseUnary { .. } => {
248                plan_complex_real_unary::<Input>(desc, shapes)
249            }
250            ComplexRealPrimsDescriptor::Reduction { .. } => {
251                plan_complex_real_reduction::<Input>(desc, shapes)
252            }
253        }
254    }
255
256    fn execute(
257        _ctx: &mut Self::Context,
258        plan: &Self::Plan,
259        alpha: Input::Real,
260        inputs: &[&Tensor<Input>],
261        beta: Input::Real,
262        output: &mut Tensor<Self::Real>,
263    ) -> Result<()> {
264        execute_complex_real_unary::<Input>(plan, alpha, inputs, beta, output)
265    }
266
267    fn has_complex_real_support(desc: ComplexRealPrimsDescriptor) -> bool {
268        matches!(
269            desc,
270            ComplexRealPrimsDescriptor::PointwiseUnary {
271                op: ComplexRealUnaryOp::Abs | ComplexRealUnaryOp::Real | ComplexRealUnaryOp::Imag
272            } | ComplexRealPrimsDescriptor::Reduction {
273                unary_op: ComplexRealUnaryOp::Abs
274                    | ComplexRealUnaryOp::Real
275                    | ComplexRealUnaryOp::Imag,
276                reduction_op: ScalarReductionOp::Sum
277                    | ScalarReductionOp::Prod
278                    | ScalarReductionOp::Mean
279                    | ScalarReductionOp::Max
280                    | ScalarReductionOp::Min,
281                ..
282            }
283        )
284    }
285}