tenferro_prims/families/
complex_real.rs

1use num_complex::ComplexFloat;
2use tenferro_algebra::Scalar;
3use tenferro_device::{Error, Result};
4use tenferro_tensor::Tensor;
5
6#[cfg(not(feature = "cuda"))]
7use crate::{CudaBackend, CudaContext};
8use crate::{RocmBackend, RocmContext, ScalarReductionOp};
9
/// Unary maps that take a complex scalar to a real scalar.
///
/// These operations are cross-dtype: the element fed in is complex-valued and
/// the element produced is real-valued.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::ComplexRealUnaryOp;
///
/// let ops = [
///     ComplexRealUnaryOp::Abs,
///     ComplexRealUnaryOp::Real,
///     ComplexRealUnaryOp::Imag,
/// ];
/// // The enum is `Copy`, so indexing hands out independent values.
/// let first = ops[0];
/// assert_eq!(first, ComplexRealUnaryOp::Abs);
/// assert_ne!(ops[1], ops[2]);
/// ```
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum ComplexRealUnaryOp {
    /// Absolute value (modulus) of the complex input.
    Abs,
    /// Real part of the complex input.
    Real,
    /// Imaginary part of the complex input.
    Imag,
}
30
31/// Descriptor for complex-to-real planning.
32///
33/// # Examples
34///
35/// ```rust
36/// use tenferro_prims::{ComplexRealPrimsDescriptor, ComplexRealUnaryOp};
37///
38/// let desc = ComplexRealPrimsDescriptor::PointwiseUnary {
39///     op: ComplexRealUnaryOp::Abs,
40/// };
41/// assert!(matches!(desc, ComplexRealPrimsDescriptor::PointwiseUnary { .. }));
42///
43/// let desc = ComplexRealPrimsDescriptor::Reduction {
44///     modes_a: vec![0, 1],
45///     modes_c: vec![1],
46///     unary_op: ComplexRealUnaryOp::Abs,
47///     reduction_op: tenferro_prims::ScalarReductionOp::Sum,
48/// };
49/// assert!(matches!(desc, ComplexRealPrimsDescriptor::Reduction { .. }));
50/// ```
51#[derive(Debug, Clone, PartialEq, Eq, Hash)]
52pub enum ComplexRealPrimsDescriptor {
53    /// Apply a complex-to-real unary operation to one input tensor.
54    PointwiseUnary {
55        /// The unary operation to apply.
56        op: ComplexRealUnaryOp,
57    },
58    /// Apply a complex-to-real unary map and then reduce over selected axes.
59    Reduction {
60        /// Input modes associated with the source tensor.
61        modes_a: Vec<u32>,
62        /// Output modes that remain after reduction.
63        modes_c: Vec<u32>,
64        /// The unary operation to apply before reduction.
65        unary_op: ComplexRealUnaryOp,
66        /// The reduction operator to apply to the real-valued result.
67        reduction_op: ScalarReductionOp,
68    },
69}
70
71/// Cross-dtype complex-to-real unary protocol family.
72///
73/// The input tensor is complex-valued and the output tensor is real-valued.
74///
75/// # Examples
76///
77/// ```ignore
78/// use num_complex::Complex64;
79/// use tenferro_prims::{
80///     ComplexRealPrimsDescriptor, ComplexRealUnaryOp, CpuBackend, CpuContext,
81///     TensorComplexRealPrims,
82/// };
83///
84/// let mut ctx = CpuContext::new(1);
85/// let desc = ComplexRealPrimsDescriptor::PointwiseUnary {
86///     op: ComplexRealUnaryOp::Abs,
87/// };
88/// let _plan = <CpuBackend as TensorComplexRealPrims<Complex64>>::plan(
89///     &mut ctx,
90///     &desc,
91///     &[&[2, 2], &[2, 2]],
92/// )
93/// .unwrap();
94/// ```
95pub trait TensorComplexRealPrims<Input: ComplexFloat + Scalar> {
96    /// The real-valued output scalar type.
97    type Real: Scalar + Send + Sync;
98    /// Backend plan type.
99    type Plan;
100    /// Backend execution context.
101    type Context;
102
103    /// Plan a complex-to-real unary operation for the given input/output shapes.
104    fn plan(
105        ctx: &mut Self::Context,
106        desc: &ComplexRealPrimsDescriptor,
107        shapes: &[&[usize]],
108    ) -> Result<Self::Plan>;
109
110    /// Execute a previously planned complex-to-real unary operation.
111    ///
112    /// The execution contract matches the rest of tenferro prims:
113    /// `output <- alpha * op(inputs) + beta * output`.
114    fn execute(
115        ctx: &mut Self::Context,
116        plan: &Self::Plan,
117        alpha: Input::Real,
118        inputs: &[&Tensor<Input>],
119        beta: Input::Real,
120        output: &mut Tensor<Self::Real>,
121    ) -> Result<()>;
122
123    /// Report whether the backend advertises support for the given descriptor.
124    fn has_complex_real_support(desc: ComplexRealPrimsDescriptor) -> bool;
125}
126
127#[cfg(not(feature = "cuda"))]
128impl<Input> TensorComplexRealPrims<Input> for CudaBackend
129where
130    Input: ComplexFloat + Scalar,
131    Input::Real: Scalar,
132{
133    type Real = Input::Real;
134    type Plan = ();
135    type Context = CudaContext;
136
137    fn plan(
138        _ctx: &mut Self::Context,
139        desc: &ComplexRealPrimsDescriptor,
140        _shapes: &[&[usize]],
141    ) -> Result<Self::Plan> {
142        Err(Error::InvalidArgument(format!(
143            "complex-real family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
144        )))
145    }
146
147    fn execute(
148        _ctx: &mut Self::Context,
149        _plan: &Self::Plan,
150        _alpha: Input::Real,
151        _inputs: &[&Tensor<Input>],
152        _beta: Input::Real,
153        _output: &mut Tensor<Self::Real>,
154    ) -> Result<()> {
155        Err(Error::InvalidArgument(
156            "complex-real family execution is not implemented on CudaBackend in phase 1".into(),
157        ))
158    }
159
160    fn has_complex_real_support(_desc: ComplexRealPrimsDescriptor) -> bool {
161        false
162    }
163}
164
165impl<Input> TensorComplexRealPrims<Input> for RocmBackend
166where
167    Input: ComplexFloat + Scalar,
168    Input::Real: Scalar,
169{
170    type Real = Input::Real;
171    type Plan = ();
172    type Context = RocmContext;
173
174    fn plan(
175        _ctx: &mut Self::Context,
176        desc: &ComplexRealPrimsDescriptor,
177        _shapes: &[&[usize]],
178    ) -> Result<Self::Plan> {
179        Err(Error::InvalidArgument(format!(
180            "complex-real family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
181        )))
182    }
183
184    fn execute(
185        _ctx: &mut Self::Context,
186        _plan: &Self::Plan,
187        _alpha: Input::Real,
188        _inputs: &[&Tensor<Input>],
189        _beta: Input::Real,
190        _output: &mut Tensor<Self::Real>,
191    ) -> Result<()> {
192        Err(Error::InvalidArgument(
193            "complex-real family execution is not implemented on RocmBackend in phase 1".into(),
194        ))
195    }
196
197    fn has_complex_real_support(_desc: ComplexRealPrimsDescriptor) -> bool {
198        false
199    }
200}