tenferro_prims/families/
complex_scale.rs1use num_complex::ComplexFloat;
2use tenferro_algebra::Scalar;
3use tenferro_device::Result;
4use tenferro_tensor::Tensor;
5
6#[cfg(not(feature = "cuda"))]
7use crate::{CudaBackend, CudaContext};
8use crate::{RocmBackend, RocmContext};
9
/// Descriptor naming a single operation of the complex-scale primitive family.
///
/// The value is a plain, field-less tag: it is `Copy`, comparable and hashable,
/// so it can be passed by value and used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ComplexScalePrimsDescriptor {
    /// Pointwise multiplication.
    ///
    /// NOTE(review): presumably an element-by-element product of the two
    /// operands; the exact semantics are defined by the backends.
    PointwiseMul,
}
25
26pub trait TensorComplexScalePrims<Input: ComplexFloat + Scalar>
50where
51 Input::Real: Scalar + Send + Sync,
52{
53 type Plan;
55 type Context;
57
58 fn plan(
60 ctx: &mut Self::Context,
61 desc: &ComplexScalePrimsDescriptor,
62 shapes: &[&[usize]],
63 ) -> Result<Self::Plan>;
64
65 fn execute(
67 ctx: &mut Self::Context,
68 plan: &Self::Plan,
69 alpha: Input,
70 lhs: &Tensor<Input>,
71 rhs: &Tensor<Input::Real>,
72 beta: Input,
73 output: &mut Tensor<Input>,
74 ) -> Result<()>;
75
76 fn has_complex_scale_support(desc: ComplexScalePrimsDescriptor) -> bool;
78}
79
80#[cfg(not(feature = "cuda"))]
81impl<Input> TensorComplexScalePrims<Input> for CudaBackend
82where
83 Input: ComplexFloat + Scalar,
84 Input::Real: Scalar + Send + Sync,
85{
86 type Plan = ();
87 type Context = CudaContext;
88
89 fn plan(
90 _ctx: &mut Self::Context,
91 desc: &ComplexScalePrimsDescriptor,
92 _shapes: &[&[usize]],
93 ) -> Result<Self::Plan> {
94 Err(tenferro_device::Error::InvalidArgument(format!(
95 "complex-scale family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
96 )))
97 }
98
99 fn execute(
100 _ctx: &mut Self::Context,
101 _plan: &Self::Plan,
102 _alpha: Input,
103 _lhs: &Tensor<Input>,
104 _rhs: &Tensor<Input::Real>,
105 _beta: Input,
106 _output: &mut Tensor<Input>,
107 ) -> Result<()> {
108 Err(tenferro_device::Error::InvalidArgument(
109 "complex-scale family execution is not implemented on CudaBackend in phase 1".into(),
110 ))
111 }
112
113 fn has_complex_scale_support(_desc: ComplexScalePrimsDescriptor) -> bool {
114 false
115 }
116}
117
118impl<Input> TensorComplexScalePrims<Input> for RocmBackend
119where
120 Input: ComplexFloat + Scalar,
121 Input::Real: Scalar + Send + Sync,
122{
123 type Plan = ();
124 type Context = RocmContext;
125
126 fn plan(
127 _ctx: &mut Self::Context,
128 desc: &ComplexScalePrimsDescriptor,
129 _shapes: &[&[usize]],
130 ) -> Result<Self::Plan> {
131 Err(tenferro_device::Error::InvalidArgument(format!(
132 "complex-scale family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
133 )))
134 }
135
136 fn execute(
137 _ctx: &mut Self::Context,
138 _plan: &Self::Plan,
139 _alpha: Input,
140 _lhs: &Tensor<Input>,
141 _rhs: &Tensor<Input::Real>,
142 _beta: Input,
143 _output: &mut Tensor<Input>,
144 ) -> Result<()> {
145 Err(tenferro_device::Error::InvalidArgument(
146 "complex-scale family execution is not implemented on RocmBackend in phase 1".into(),
147 ))
148 }
149
150 fn has_complex_scale_support(_desc: ComplexScalePrimsDescriptor) -> bool {
151 false
152 }
153}