tenferro_prims/cpu/
complex_scale.rs1use num_complex::ComplexFloat;
2use tenferro_algebra::Scalar;
3use tenferro_device::Result;
4use tenferro_tensor::Tensor;
5
6use crate::cpu::{tensor_to_view, tensor_to_view_mut};
7use crate::{
8 validate_shape_count, validate_shape_eq, ComplexScalePrimsDescriptor, CpuBackend, CpuContext,
9 TensorComplexScalePrims,
10};
11
/// Execution plan for the CPU complex-scale primitive.
///
/// A value of this type is produced by `plan_complex_scale` and later
/// consumed by `execute_complex_scale` to choose the kernel to run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CpuComplexScalePlan {
    /// Elementwise product of a complex tensor with a real tensor, combined
    /// with the existing output as `alpha * (lhs * rhs) + beta * output`.
    PointwiseMul,
}
24
25fn plan_complex_scale(
26 desc: &ComplexScalePrimsDescriptor,
27 shapes: &[&[usize]],
28) -> Result<CpuComplexScalePlan> {
29 validate_shape_count(shapes, 3, "CpuComplexScalePointwiseMul")?;
30 validate_shape_eq(shapes[0], shapes[1], "CpuComplexScalePointwiseMul lhs/rhs")?;
31 validate_shape_eq(
32 shapes[0],
33 shapes[2],
34 "CpuComplexScalePointwiseMul lhs/output",
35 )?;
36 match desc {
37 ComplexScalePrimsDescriptor::PointwiseMul => Ok(CpuComplexScalePlan::PointwiseMul),
38 }
39}
40
41fn execute_complex_scale_typed<Input>(
42 alpha: Input,
43 lhs: &strided_view::StridedView<Input>,
44 rhs: &strided_view::StridedView<Input::Real>,
45 beta: Input,
46 output: &mut strided_view::StridedViewMut<Input>,
47) -> Result<()>
48where
49 Input: ComplexFloat
50 + Scalar
51 + std::ops::Add<Output = Input>
52 + std::ops::Mul<Input::Real, Output = Input>
53 + std::ops::Mul<Output = Input>,
54 Input::Real: Scalar + Send + Sync,
55{
56 let dims = output.dims().to_vec();
57 crate::for_each_index(&dims, |idx| {
58 output.set(
59 idx,
60 alpha * (lhs.get(idx) * rhs.get(idx)) + beta * output.get(idx),
61 );
62 });
63 Ok(())
64}
65
66fn execute_complex_scale<Input>(
67 plan: &CpuComplexScalePlan,
68 alpha: Input,
69 lhs: &Tensor<Input>,
70 rhs: &Tensor<Input::Real>,
71 beta: Input,
72 output: &mut Tensor<Input>,
73) -> Result<()>
74where
75 Input: ComplexFloat
76 + Scalar
77 + 'static
78 + std::ops::Add<Output = Input>
79 + std::ops::Mul<Input::Real, Output = Input>
80 + std::ops::Mul<Output = Input>,
81 Input::Real: Scalar + Send + Sync,
82{
83 validate_shape_eq(lhs.dims(), rhs.dims(), "CpuComplexScalePointwiseMul rhs")?;
84 validate_shape_eq(
85 lhs.dims(),
86 output.dims(),
87 "CpuComplexScalePointwiseMul output",
88 )?;
89
90 let lhs = tensor_to_view(lhs)?;
91 let rhs = tensor_to_view(rhs)?;
92 let mut output = tensor_to_view_mut(output)?;
93
94 match plan {
95 CpuComplexScalePlan::PointwiseMul => {
96 execute_complex_scale_typed::<Input>(alpha, &lhs, &rhs, beta, &mut output)
97 }
98 }
99}
100
101impl<Input> TensorComplexScalePrims<Input> for CpuBackend
102where
103 Input: ComplexFloat
104 + Scalar
105 + 'static
106 + std::ops::Add<Output = Input>
107 + std::ops::Mul<Input::Real, Output = Input>
108 + std::ops::Mul<Output = Input>,
109 Input::Real: Scalar + Send + Sync,
110{
111 type Plan = CpuComplexScalePlan;
112 type Context = CpuContext;
113
114 fn plan(
115 _ctx: &mut Self::Context,
116 desc: &ComplexScalePrimsDescriptor,
117 shapes: &[&[usize]],
118 ) -> Result<Self::Plan> {
119 plan_complex_scale(desc, shapes)
120 }
121
122 fn execute(
123 _ctx: &mut Self::Context,
124 plan: &Self::Plan,
125 alpha: Input,
126 lhs: &Tensor<Input>,
127 rhs: &Tensor<Input::Real>,
128 beta: Input,
129 output: &mut Tensor<Input>,
130 ) -> Result<()> {
131 execute_complex_scale::<Input>(plan, alpha, lhs, rhs, beta, output)
132 }
133
134 fn has_complex_scale_support(desc: ComplexScalePrimsDescriptor) -> bool {
135 matches!(desc, ComplexScalePrimsDescriptor::PointwiseMul)
136 }
137}