tenferro_prims/cpu/
complex_scale.rs

1use num_complex::ComplexFloat;
2use tenferro_algebra::Scalar;
3use tenferro_device::Result;
4use tenferro_tensor::Tensor;
5
6use crate::cpu::{tensor_to_view, tensor_to_view_mut};
7use crate::{
8    validate_shape_count, validate_shape_eq, ComplexScalePrimsDescriptor, CpuBackend, CpuContext,
9    TensorComplexScalePrims,
10};
11
/// Plan type used by the CPU backend for the complex-by-real pointwise
/// protocol family.
///
/// The enum carries no payload: a plan only records which operation of the
/// family was selected.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::CpuComplexScalePlan;
/// let _ = std::mem::size_of::<CpuComplexScalePlan>();
/// ```
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CpuComplexScalePlan {
    /// Pointwise multiplication of a complex operand by a real operand.
    PointwiseMul,
}
24
25fn plan_complex_scale(
26    desc: &ComplexScalePrimsDescriptor,
27    shapes: &[&[usize]],
28) -> Result<CpuComplexScalePlan> {
29    validate_shape_count(shapes, 3, "CpuComplexScalePointwiseMul")?;
30    validate_shape_eq(shapes[0], shapes[1], "CpuComplexScalePointwiseMul lhs/rhs")?;
31    validate_shape_eq(
32        shapes[0],
33        shapes[2],
34        "CpuComplexScalePointwiseMul lhs/output",
35    )?;
36    match desc {
37        ComplexScalePrimsDescriptor::PointwiseMul => Ok(CpuComplexScalePlan::PointwiseMul),
38    }
39}
40
41fn execute_complex_scale_typed<Input>(
42    alpha: Input,
43    lhs: &strided_view::StridedView<Input>,
44    rhs: &strided_view::StridedView<Input::Real>,
45    beta: Input,
46    output: &mut strided_view::StridedViewMut<Input>,
47) -> Result<()>
48where
49    Input: ComplexFloat
50        + Scalar
51        + std::ops::Add<Output = Input>
52        + std::ops::Mul<Input::Real, Output = Input>
53        + std::ops::Mul<Output = Input>,
54    Input::Real: Scalar + Send + Sync,
55{
56    let dims = output.dims().to_vec();
57    crate::for_each_index(&dims, |idx| {
58        output.set(
59            idx,
60            alpha * (lhs.get(idx) * rhs.get(idx)) + beta * output.get(idx),
61        );
62    });
63    Ok(())
64}
65
66fn execute_complex_scale<Input>(
67    plan: &CpuComplexScalePlan,
68    alpha: Input,
69    lhs: &Tensor<Input>,
70    rhs: &Tensor<Input::Real>,
71    beta: Input,
72    output: &mut Tensor<Input>,
73) -> Result<()>
74where
75    Input: ComplexFloat
76        + Scalar
77        + 'static
78        + std::ops::Add<Output = Input>
79        + std::ops::Mul<Input::Real, Output = Input>
80        + std::ops::Mul<Output = Input>,
81    Input::Real: Scalar + Send + Sync,
82{
83    validate_shape_eq(lhs.dims(), rhs.dims(), "CpuComplexScalePointwiseMul rhs")?;
84    validate_shape_eq(
85        lhs.dims(),
86        output.dims(),
87        "CpuComplexScalePointwiseMul output",
88    )?;
89
90    let lhs = tensor_to_view(lhs)?;
91    let rhs = tensor_to_view(rhs)?;
92    let mut output = tensor_to_view_mut(output)?;
93
94    match plan {
95        CpuComplexScalePlan::PointwiseMul => {
96            execute_complex_scale_typed::<Input>(alpha, &lhs, &rhs, beta, &mut output)
97        }
98    }
99}
100
101impl<Input> TensorComplexScalePrims<Input> for CpuBackend
102where
103    Input: ComplexFloat
104        + Scalar
105        + 'static
106        + std::ops::Add<Output = Input>
107        + std::ops::Mul<Input::Real, Output = Input>
108        + std::ops::Mul<Output = Input>,
109    Input::Real: Scalar + Send + Sync,
110{
111    type Plan = CpuComplexScalePlan;
112    type Context = CpuContext;
113
114    fn plan(
115        _ctx: &mut Self::Context,
116        desc: &ComplexScalePrimsDescriptor,
117        shapes: &[&[usize]],
118    ) -> Result<Self::Plan> {
119        plan_complex_scale(desc, shapes)
120    }
121
122    fn execute(
123        _ctx: &mut Self::Context,
124        plan: &Self::Plan,
125        alpha: Input,
126        lhs: &Tensor<Input>,
127        rhs: &Tensor<Input::Real>,
128        beta: Input,
129        output: &mut Tensor<Input>,
130    ) -> Result<()> {
131        execute_complex_scale::<Input>(plan, alpha, lhs, rhs, beta, output)
132    }
133
134    fn has_complex_scale_support(desc: ComplexScalePrimsDescriptor) -> bool {
135        matches!(desc, ComplexScalePrimsDescriptor::PointwiseMul)
136    }
137}