tenferro_prims/families/scalar.rs
1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// Elementwise unary operations understood by the scalar prims family.
///
/// Each variant selects one pointwise operation applied to a single input
/// tensor; it is carried by [`ScalarPrimsDescriptor::PointwiseUnary`].
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarUnaryOp;
///
/// let chosen = ScalarUnaryOp::Reciprocal;
/// assert_eq!(chosen, ScalarUnaryOp::Reciprocal);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarUnaryOp {
    Neg,
    Conj,
    Abs,
    Reciprocal,
    Real,
    Imag,
    Square,
}
29
/// Elementwise binary operations understood by the scalar prims family.
///
/// The ordered-real comparison variants ([`Greater`](Self::Greater) and
/// [`GreaterEqual`](Self::GreaterEqual)) do not yield booleans. Instead they
/// produce a numeric mask in the same scalar dtype as their inputs: `1` where
/// the predicate holds and `0` everywhere else.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarBinaryOp;
///
/// let chosen = ScalarBinaryOp::Mul;
/// assert_eq!(chosen, ScalarBinaryOp::Mul);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarBinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Maximum,
    Minimum,
    /// Ordered-real comparison; yields a `1`/`0` mask in the input dtype.
    Greater,
    /// Ordered-real comparison; yields a `1`/`0` mask in the input dtype.
    GreaterEqual,
    ClampMin,
    ClampMax,
}
56
/// Elementwise ternary operations understood by the scalar prims family.
///
/// Carried by [`ScalarPrimsDescriptor::PointwiseTernary`]; `Where` is the
/// only operation currently defined.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarTernaryOp;
///
/// let chosen = ScalarTernaryOp::Where;
/// assert_eq!(chosen, ScalarTernaryOp::Where);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarTernaryOp {
    Where,
}
71
/// Reduction operators understood by the scalar prims family.
///
/// Selected through the `op` field of [`ScalarPrimsDescriptor::Reduction`].
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarReductionOp;
///
/// let chosen = ScalarReductionOp::Sum;
/// assert_eq!(chosen, ScalarReductionOp::Sum);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarReductionOp {
    Sum,
    Prod,
    Mean,
    Max,
    Min,
}
90
91/// Descriptor for scalar-pointwise and scalar-reduction planning.
92///
93/// # Examples
94///
95/// ```
96/// use tenferro_prims::{ScalarPrimsDescriptor, ScalarUnaryOp};
97///
98/// let desc = ScalarPrimsDescriptor::PointwiseUnary {
99/// op: ScalarUnaryOp::Reciprocal,
100/// };
101/// assert!(matches!(desc, ScalarPrimsDescriptor::PointwiseUnary { .. }));
102/// ```
103#[derive(Debug, Clone, PartialEq, Eq, Hash)]
104pub enum ScalarPrimsDescriptor {
105 /// Apply a unary pointwise operation to one input tensor.
106 PointwiseUnary {
107 /// The unary scalar operation to apply.
108 op: ScalarUnaryOp,
109 },
110 /// Apply a binary pointwise operation to two input tensors.
111 PointwiseBinary {
112 /// The binary scalar operation to apply.
113 op: ScalarBinaryOp,
114 },
115 /// Apply a ternary pointwise operation to three input tensors.
116 PointwiseTernary {
117 /// The ternary scalar operation to apply.
118 op: ScalarTernaryOp,
119 },
120 /// Reduce one tensor into an output tensor over the dropped modes.
121 Reduction {
122 /// Input modes associated with the source tensor.
123 modes_a: Vec<u32>,
124 /// Output modes that remain after reduction.
125 modes_c: Vec<u32>,
126 /// Reduction operator to use.
127 op: ScalarReductionOp,
128 },
129}
130
131/// Scalar pointwise and reduction protocol family.
132///
133/// # Examples
134///
135/// ```ignore
136/// use tenferro_algebra::Standard;
137/// use tenferro_prims::{CpuBackend, CpuContext, ScalarPrimsDescriptor, ScalarUnaryOp, TensorScalarPrims};
138///
139/// let mut ctx = CpuContext::new(1);
140/// let desc = ScalarPrimsDescriptor::PointwiseUnary {
141/// op: ScalarUnaryOp::Reciprocal,
142/// };
143/// let _plan = <CpuBackend as TensorScalarPrims<Standard<f64>>>::plan(
144/// &mut ctx,
145/// &desc,
146/// &[&[2, 2], &[2, 2]],
147/// )
148/// .unwrap();
149/// ```
150pub trait TensorScalarPrims<Alg: Algebra> {
151 type Plan;
152 type Context;
153
154 /// Plan a scalar-family operation for the given input/output shapes.
155 ///
156 /// Backends may reject descriptors that are reserved in the public
157 /// vocabulary but not yet wired to the current execution substrate.
158 fn plan(
159 ctx: &mut Self::Context,
160 desc: &ScalarPrimsDescriptor,
161 shapes: &[&[usize]],
162 ) -> Result<Self::Plan>;
163
164 /// Execute a previously planned scalar-family operation.
165 ///
166 /// The execution contract matches the rest of tenferro prims:
167 /// `output <- alpha * op(inputs) + beta * output`.
168 fn execute(
169 ctx: &mut Self::Context,
170 plan: &Self::Plan,
171 alpha: Alg::Scalar,
172 inputs: &[&Tensor<Alg::Scalar>],
173 beta: Alg::Scalar,
174 output: &mut Tensor<Alg::Scalar>,
175 ) -> Result<()>;
176
177 /// Report whether the backend advertises support for the given descriptor.
178 ///
179 /// This query is about the backend family surface, not whether a specific
180 /// shape instance is valid.
181 fn has_scalar_support(desc: ScalarPrimsDescriptor) -> bool;
182}
183
184#[cfg(not(feature = "cuda"))]
185impl<S: Scalar> TensorScalarPrims<Standard<S>> for CudaBackend {
186 type Plan = ();
187 type Context = CudaContext;
188
189 fn plan(
190 _ctx: &mut Self::Context,
191 desc: &ScalarPrimsDescriptor,
192 _shapes: &[&[usize]],
193 ) -> Result<Self::Plan> {
194 Err(Error::InvalidArgument(format!(
195 "scalar family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
196 )))
197 }
198
199 fn execute(
200 _ctx: &mut Self::Context,
201 _plan: &Self::Plan,
202 _alpha: S,
203 _inputs: &[&Tensor<S>],
204 _beta: S,
205 _output: &mut Tensor<S>,
206 ) -> Result<()> {
207 Err(Error::InvalidArgument(
208 "scalar family execution is not implemented on CudaBackend in phase 1".into(),
209 ))
210 }
211
212 fn has_scalar_support(_desc: ScalarPrimsDescriptor) -> bool {
213 false
214 }
215}
216
217impl<S: Scalar> TensorScalarPrims<Standard<S>> for RocmBackend {
218 type Plan = ();
219 type Context = RocmContext;
220
221 fn plan(
222 _ctx: &mut Self::Context,
223 desc: &ScalarPrimsDescriptor,
224 _shapes: &[&[usize]],
225 ) -> Result<Self::Plan> {
226 Err(Error::InvalidArgument(format!(
227 "scalar family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
228 )))
229 }
230
231 fn execute(
232 _ctx: &mut Self::Context,
233 _plan: &Self::Plan,
234 _alpha: S,
235 _inputs: &[&Tensor<S>],
236 _beta: S,
237 _output: &mut Tensor<S>,
238 ) -> Result<()> {
239 Err(Error::InvalidArgument(
240 "scalar family execution is not implemented on RocmBackend in phase 1".into(),
241 ))
242 }
243
244 fn has_scalar_support(_desc: ScalarPrimsDescriptor) -> bool {
245 false
246 }
247}