tenferro_prims/families/complex_real.rs
1use num_complex::ComplexFloat;
2use tenferro_algebra::Scalar;
3use tenferro_device::{Error, Result};
4use tenferro_tensor::Tensor;
5
6#[cfg(not(feature = "cuda"))]
7use crate::{CudaBackend, CudaContext};
8use crate::{RocmBackend, RocmContext, ScalarReductionOp};
9
/// Selects which real-valued quantity is extracted from each complex element.
///
/// These operations change the element type: the input is complex-valued and
/// the result is real-valued.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::ComplexRealUnaryOp;
///
/// for op in [
///     ComplexRealUnaryOp::Abs,
///     ComplexRealUnaryOp::Real,
///     ComplexRealUnaryOp::Imag,
/// ] {
///     // The enum is `Copy`, so a copy compares equal to the original.
///     let copy = op;
///     assert_eq!(copy, op);
/// }
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ComplexRealUnaryOp {
    /// Absolute value (modulus) of each complex element.
    Abs,
    /// Real part of each complex element.
    Real,
    /// Imaginary part of each complex element.
    Imag,
}
30
31/// Descriptor for complex-to-real planning.
32///
33/// # Examples
34///
35/// ```rust
36/// use tenferro_prims::{ComplexRealPrimsDescriptor, ComplexRealUnaryOp};
37///
38/// let desc = ComplexRealPrimsDescriptor::PointwiseUnary {
39/// op: ComplexRealUnaryOp::Abs,
40/// };
41/// assert!(matches!(desc, ComplexRealPrimsDescriptor::PointwiseUnary { .. }));
42///
43/// let desc = ComplexRealPrimsDescriptor::Reduction {
44/// modes_a: vec![0, 1],
45/// modes_c: vec![1],
46/// unary_op: ComplexRealUnaryOp::Abs,
47/// reduction_op: tenferro_prims::ScalarReductionOp::Sum,
48/// };
49/// assert!(matches!(desc, ComplexRealPrimsDescriptor::Reduction { .. }));
50/// ```
51#[derive(Debug, Clone, PartialEq, Eq, Hash)]
52pub enum ComplexRealPrimsDescriptor {
53 /// Apply a complex-to-real unary operation to one input tensor.
54 PointwiseUnary {
55 /// The unary operation to apply.
56 op: ComplexRealUnaryOp,
57 },
58 /// Apply a complex-to-real unary map and then reduce over selected axes.
59 Reduction {
60 /// Input modes associated with the source tensor.
61 modes_a: Vec<u32>,
62 /// Output modes that remain after reduction.
63 modes_c: Vec<u32>,
64 /// The unary operation to apply before reduction.
65 unary_op: ComplexRealUnaryOp,
66 /// The reduction operator to apply to the real-valued result.
67 reduction_op: ScalarReductionOp,
68 },
69}
70
71/// Cross-dtype complex-to-real unary protocol family.
72///
73/// The input tensor is complex-valued and the output tensor is real-valued.
74///
75/// # Examples
76///
77/// ```ignore
78/// use num_complex::Complex64;
79/// use tenferro_prims::{
80/// ComplexRealPrimsDescriptor, ComplexRealUnaryOp, CpuBackend, CpuContext,
81/// TensorComplexRealPrims,
82/// };
83///
84/// let mut ctx = CpuContext::new(1);
85/// let desc = ComplexRealPrimsDescriptor::PointwiseUnary {
86/// op: ComplexRealUnaryOp::Abs,
87/// };
88/// let _plan = <CpuBackend as TensorComplexRealPrims<Complex64>>::plan(
89/// &mut ctx,
90/// &desc,
91/// &[&[2, 2], &[2, 2]],
92/// )
93/// .unwrap();
94/// ```
95pub trait TensorComplexRealPrims<Input: ComplexFloat + Scalar> {
96 /// The real-valued output scalar type.
97 type Real: Scalar + Send + Sync;
98 /// Backend plan type.
99 type Plan;
100 /// Backend execution context.
101 type Context;
102
103 /// Plan a complex-to-real unary operation for the given input/output shapes.
104 fn plan(
105 ctx: &mut Self::Context,
106 desc: &ComplexRealPrimsDescriptor,
107 shapes: &[&[usize]],
108 ) -> Result<Self::Plan>;
109
110 /// Execute a previously planned complex-to-real unary operation.
111 ///
112 /// The execution contract matches the rest of tenferro prims:
113 /// `output <- alpha * op(inputs) + beta * output`.
114 fn execute(
115 ctx: &mut Self::Context,
116 plan: &Self::Plan,
117 alpha: Input::Real,
118 inputs: &[&Tensor<Input>],
119 beta: Input::Real,
120 output: &mut Tensor<Self::Real>,
121 ) -> Result<()>;
122
123 /// Report whether the backend advertises support for the given descriptor.
124 fn has_complex_real_support(desc: ComplexRealPrimsDescriptor) -> bool;
125}
126
127#[cfg(not(feature = "cuda"))]
128impl<Input> TensorComplexRealPrims<Input> for CudaBackend
129where
130 Input: ComplexFloat + Scalar,
131 Input::Real: Scalar,
132{
133 type Real = Input::Real;
134 type Plan = ();
135 type Context = CudaContext;
136
137 fn plan(
138 _ctx: &mut Self::Context,
139 desc: &ComplexRealPrimsDescriptor,
140 _shapes: &[&[usize]],
141 ) -> Result<Self::Plan> {
142 Err(Error::InvalidArgument(format!(
143 "complex-real family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
144 )))
145 }
146
147 fn execute(
148 _ctx: &mut Self::Context,
149 _plan: &Self::Plan,
150 _alpha: Input::Real,
151 _inputs: &[&Tensor<Input>],
152 _beta: Input::Real,
153 _output: &mut Tensor<Self::Real>,
154 ) -> Result<()> {
155 Err(Error::InvalidArgument(
156 "complex-real family execution is not implemented on CudaBackend in phase 1".into(),
157 ))
158 }
159
160 fn has_complex_real_support(_desc: ComplexRealPrimsDescriptor) -> bool {
161 false
162 }
163}
164
165impl<Input> TensorComplexRealPrims<Input> for RocmBackend
166where
167 Input: ComplexFloat + Scalar,
168 Input::Real: Scalar,
169{
170 type Real = Input::Real;
171 type Plan = ();
172 type Context = RocmContext;
173
174 fn plan(
175 _ctx: &mut Self::Context,
176 desc: &ComplexRealPrimsDescriptor,
177 _shapes: &[&[usize]],
178 ) -> Result<Self::Plan> {
179 Err(Error::InvalidArgument(format!(
180 "complex-real family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
181 )))
182 }
183
184 fn execute(
185 _ctx: &mut Self::Context,
186 _plan: &Self::Plan,
187 _alpha: Input::Real,
188 _inputs: &[&Tensor<Input>],
189 _beta: Input::Real,
190 _output: &mut Tensor<Self::Real>,
191 ) -> Result<()> {
192 Err(Error::InvalidArgument(
193 "complex-real family execution is not implemented on RocmBackend in phase 1".into(),
194 ))
195 }
196
197 fn has_complex_real_support(_desc: ComplexRealPrimsDescriptor) -> bool {
198 false
199 }
200}