// tenferro_prims/families/complex_scale.rs

1use num_complex::ComplexFloat;
2use tenferro_algebra::Scalar;
3use tenferro_device::Result;
4use tenferro_tensor::Tensor;
5
6#[cfg(not(feature = "cuda"))]
7use crate::{CudaBackend, CudaContext};
8use crate::{RocmBackend, RocmContext};
9
/// Selects which cross-dtype (complex tensor by real tensor) pointwise
/// operation a complex-scale plan performs.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::ComplexScalePrimsDescriptor;
///
/// let desc = ComplexScalePrimsDescriptor::PointwiseMul;
/// assert!(matches!(desc, ComplexScalePrimsDescriptor::PointwiseMul));
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ComplexScalePrimsDescriptor {
    /// Elementwise product of a complex-valued tensor with a real-valued tensor.
    PointwiseMul,
}
25
26/// Cross-dtype complex-by-real pointwise family.
27///
28/// The left-hand side and output are complex-valued, while the right-hand side
29/// is real-valued. The execution contract matches the rest of tenferro prims:
30/// `output <- alpha * (lhs * rhs) + beta * output`.
31///
32/// # Examples
33///
34/// ```ignore
35/// use num_complex::Complex64;
36/// use tenferro_prims::{
37///     ComplexScalePrimsDescriptor, CpuBackend, CpuContext, TensorComplexScalePrims,
38/// };
39///
40/// let mut ctx = CpuContext::new(1);
41/// let desc = ComplexScalePrimsDescriptor::PointwiseMul;
42/// let _plan = <CpuBackend as TensorComplexScalePrims<Complex64>>::plan(
43///     &mut ctx,
44///     &desc,
45///     &[&[2, 2], &[2, 2], &[2, 2]],
46/// )
47/// .unwrap();
48/// ```
49pub trait TensorComplexScalePrims<Input: ComplexFloat + Scalar>
50where
51    Input::Real: Scalar + Send + Sync,
52{
53    /// Backend plan type.
54    type Plan;
55    /// Backend execution context.
56    type Context;
57
58    /// Plan a complex-by-real pointwise operation for the given shapes.
59    fn plan(
60        ctx: &mut Self::Context,
61        desc: &ComplexScalePrimsDescriptor,
62        shapes: &[&[usize]],
63    ) -> Result<Self::Plan>;
64
65    /// Execute a previously planned complex-by-real pointwise operation.
66    fn execute(
67        ctx: &mut Self::Context,
68        plan: &Self::Plan,
69        alpha: Input,
70        lhs: &Tensor<Input>,
71        rhs: &Tensor<Input::Real>,
72        beta: Input,
73        output: &mut Tensor<Input>,
74    ) -> Result<()>;
75
76    /// Report whether the backend advertises support for the given descriptor.
77    fn has_complex_scale_support(desc: ComplexScalePrimsDescriptor) -> bool;
78}
79
80#[cfg(not(feature = "cuda"))]
81impl<Input> TensorComplexScalePrims<Input> for CudaBackend
82where
83    Input: ComplexFloat + Scalar,
84    Input::Real: Scalar + Send + Sync,
85{
86    type Plan = ();
87    type Context = CudaContext;
88
89    fn plan(
90        _ctx: &mut Self::Context,
91        desc: &ComplexScalePrimsDescriptor,
92        _shapes: &[&[usize]],
93    ) -> Result<Self::Plan> {
94        Err(tenferro_device::Error::InvalidArgument(format!(
95            "complex-scale family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
96        )))
97    }
98
99    fn execute(
100        _ctx: &mut Self::Context,
101        _plan: &Self::Plan,
102        _alpha: Input,
103        _lhs: &Tensor<Input>,
104        _rhs: &Tensor<Input::Real>,
105        _beta: Input,
106        _output: &mut Tensor<Input>,
107    ) -> Result<()> {
108        Err(tenferro_device::Error::InvalidArgument(
109            "complex-scale family execution is not implemented on CudaBackend in phase 1".into(),
110        ))
111    }
112
113    fn has_complex_scale_support(_desc: ComplexScalePrimsDescriptor) -> bool {
114        false
115    }
116}
117
118impl<Input> TensorComplexScalePrims<Input> for RocmBackend
119where
120    Input: ComplexFloat + Scalar,
121    Input::Real: Scalar + Send + Sync,
122{
123    type Plan = ();
124    type Context = RocmContext;
125
126    fn plan(
127        _ctx: &mut Self::Context,
128        desc: &ComplexScalePrimsDescriptor,
129        _shapes: &[&[usize]],
130    ) -> Result<Self::Plan> {
131        Err(tenferro_device::Error::InvalidArgument(format!(
132            "complex-scale family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
133        )))
134    }
135
136    fn execute(
137        _ctx: &mut Self::Context,
138        _plan: &Self::Plan,
139        _alpha: Input,
140        _lhs: &Tensor<Input>,
141        _rhs: &Tensor<Input::Real>,
142        _beta: Input,
143        _output: &mut Tensor<Input>,
144    ) -> Result<()> {
145        Err(tenferro_device::Error::InvalidArgument(
146            "complex-scale family execution is not implemented on RocmBackend in phase 1".into(),
147        ))
148    }
149
150    fn has_complex_scale_support(_desc: ComplexScalePrimsDescriptor) -> bool {
151        false
152    }
153}