// tenferro_prims/families/scalar.rs

1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// Unary operations applied pointwise, one scalar element at a time.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarUnaryOp;
///
/// let op = ScalarUnaryOp::Reciprocal;
/// assert!(matches!(op, ScalarUnaryOp::Reciprocal));
/// assert_ne!(op, ScalarUnaryOp::Neg);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarUnaryOp {
    Neg,
    Conj,
    Abs,
    Reciprocal,
    Real,
    Imag,
    Square,
}
29
/// Binary operations applied pointwise across two inputs.
///
/// The ordered-real comparisons (`Greater`, `GreaterEqual`) do not produce a
/// boolean tensor: their result is a numeric mask in the inputs' own scalar
/// dtype, holding `1` wherever the predicate is true and `0` elsewhere.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarBinaryOp;
///
/// let op = ScalarBinaryOp::Mul;
/// assert_eq!(op, ScalarBinaryOp::Mul);
/// assert_ne!(op, ScalarBinaryOp::Div);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarBinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Maximum,
    Minimum,
    /// Ordered-real comparison; yields a `1`/`0` mask in the input dtype.
    Greater,
    /// Ordered-real comparison; yields a `1`/`0` mask in the input dtype.
    GreaterEqual,
    ClampMin,
    ClampMax,
}
56
/// Ternary operations applied pointwise across three inputs.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarTernaryOp;
///
/// let op = ScalarTernaryOp::Where;
/// assert!(matches!(op, ScalarTernaryOp::Where));
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarTernaryOp {
    Where,
}
71
/// Operators used to fold the elements of a tensor along reduced modes.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScalarReductionOp;
///
/// let op = ScalarReductionOp::Sum;
/// assert!(matches!(op, ScalarReductionOp::Sum));
/// assert_ne!(op, ScalarReductionOp::Prod);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScalarReductionOp {
    Sum,
    Prod,
    Mean,
    Max,
    Min,
}
90
91/// Descriptor for scalar-pointwise and scalar-reduction planning.
92///
93/// # Examples
94///
95/// ```
96/// use tenferro_prims::{ScalarPrimsDescriptor, ScalarUnaryOp};
97///
98/// let desc = ScalarPrimsDescriptor::PointwiseUnary {
99///     op: ScalarUnaryOp::Reciprocal,
100/// };
101/// assert!(matches!(desc, ScalarPrimsDescriptor::PointwiseUnary { .. }));
102/// ```
103#[derive(Debug, Clone, PartialEq, Eq, Hash)]
104pub enum ScalarPrimsDescriptor {
105    /// Apply a unary pointwise operation to one input tensor.
106    PointwiseUnary {
107        /// The unary scalar operation to apply.
108        op: ScalarUnaryOp,
109    },
110    /// Apply a binary pointwise operation to two input tensors.
111    PointwiseBinary {
112        /// The binary scalar operation to apply.
113        op: ScalarBinaryOp,
114    },
115    /// Apply a ternary pointwise operation to three input tensors.
116    PointwiseTernary {
117        /// The ternary scalar operation to apply.
118        op: ScalarTernaryOp,
119    },
120    /// Reduce one tensor into an output tensor over the dropped modes.
121    Reduction {
122        /// Input modes associated with the source tensor.
123        modes_a: Vec<u32>,
124        /// Output modes that remain after reduction.
125        modes_c: Vec<u32>,
126        /// Reduction operator to use.
127        op: ScalarReductionOp,
128    },
129}
130
131/// Scalar pointwise and reduction protocol family.
132///
133/// # Examples
134///
135/// ```ignore
136/// use tenferro_algebra::Standard;
137/// use tenferro_prims::{CpuBackend, CpuContext, ScalarPrimsDescriptor, ScalarUnaryOp, TensorScalarPrims};
138///
139/// let mut ctx = CpuContext::new(1);
140/// let desc = ScalarPrimsDescriptor::PointwiseUnary {
141///     op: ScalarUnaryOp::Reciprocal,
142/// };
143/// let _plan = <CpuBackend as TensorScalarPrims<Standard<f64>>>::plan(
144///     &mut ctx,
145///     &desc,
146///     &[&[2, 2], &[2, 2]],
147/// )
148/// .unwrap();
149/// ```
150pub trait TensorScalarPrims<Alg: Algebra> {
151    type Plan;
152    type Context;
153
154    /// Plan a scalar-family operation for the given input/output shapes.
155    ///
156    /// Backends may reject descriptors that are reserved in the public
157    /// vocabulary but not yet wired to the current execution substrate.
158    fn plan(
159        ctx: &mut Self::Context,
160        desc: &ScalarPrimsDescriptor,
161        shapes: &[&[usize]],
162    ) -> Result<Self::Plan>;
163
164    /// Execute a previously planned scalar-family operation.
165    ///
166    /// The execution contract matches the rest of tenferro prims:
167    /// `output <- alpha * op(inputs) + beta * output`.
168    fn execute(
169        ctx: &mut Self::Context,
170        plan: &Self::Plan,
171        alpha: Alg::Scalar,
172        inputs: &[&Tensor<Alg::Scalar>],
173        beta: Alg::Scalar,
174        output: &mut Tensor<Alg::Scalar>,
175    ) -> Result<()>;
176
177    /// Report whether the backend advertises support for the given descriptor.
178    ///
179    /// This query is about the backend family surface, not whether a specific
180    /// shape instance is valid.
181    fn has_scalar_support(desc: ScalarPrimsDescriptor) -> bool;
182}
183
184#[cfg(not(feature = "cuda"))]
185impl<S: Scalar> TensorScalarPrims<Standard<S>> for CudaBackend {
186    type Plan = ();
187    type Context = CudaContext;
188
189    fn plan(
190        _ctx: &mut Self::Context,
191        desc: &ScalarPrimsDescriptor,
192        _shapes: &[&[usize]],
193    ) -> Result<Self::Plan> {
194        Err(Error::InvalidArgument(format!(
195            "scalar family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
196        )))
197    }
198
199    fn execute(
200        _ctx: &mut Self::Context,
201        _plan: &Self::Plan,
202        _alpha: S,
203        _inputs: &[&Tensor<S>],
204        _beta: S,
205        _output: &mut Tensor<S>,
206    ) -> Result<()> {
207        Err(Error::InvalidArgument(
208            "scalar family execution is not implemented on CudaBackend in phase 1".into(),
209        ))
210    }
211
212    fn has_scalar_support(_desc: ScalarPrimsDescriptor) -> bool {
213        false
214    }
215}
216
217impl<S: Scalar> TensorScalarPrims<Standard<S>> for RocmBackend {
218    type Plan = ();
219    type Context = RocmContext;
220
221    fn plan(
222        _ctx: &mut Self::Context,
223        desc: &ScalarPrimsDescriptor,
224        _shapes: &[&[usize]],
225    ) -> Result<Self::Plan> {
226        Err(Error::InvalidArgument(format!(
227            "scalar family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
228        )))
229    }
230
231    fn execute(
232        _ctx: &mut Self::Context,
233        _plan: &Self::Plan,
234        _alpha: S,
235        _inputs: &[&Tensor<S>],
236        _beta: S,
237        _output: &mut Tensor<S>,
238    ) -> Result<()> {
239        Err(Error::InvalidArgument(
240            "scalar family execution is not implemented on RocmBackend in phase 1".into(),
241        ))
242    }
243
244    fn has_scalar_support(_desc: ScalarPrimsDescriptor) -> bool {
245        false
246    }
247}