tenferro_prims/families/analytic.rs

1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// Analytic unary operations.
///
/// Each variant names the elementwise function of the same name. This enum
/// only identifies the operation; the numerical behavior (domain handling,
/// NaN/overflow policy) is not specified here and is up to the executing
/// backend.
///
/// # Examples
///
/// ```
/// use tenferro_prims::AnalyticUnaryOp;
///
/// let op = AnalyticUnaryOp::Sqrt;
/// assert_eq!(op, AnalyticUnaryOp::Sqrt);
/// ```
// Variant order is part of the public surface (discriminants / `as` casts);
// append new variants at the end rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticUnaryOp {
    /// Square root, `sqrt(x)`.
    Sqrt,
    /// Reciprocal square root, `1 / sqrt(x)`.
    Rsqrt,
    /// Natural exponential, `exp(x)`.
    Exp,
    /// `exp(x) - 1`.
    Expm1,
    /// Ceiling, the smallest integer not less than `x`.
    ///
    /// NOTE(review): ceiling is piecewise constant rather than analytic;
    /// confirm it is intended to live in the analytic family.
    Ceil,
    /// Natural logarithm, `ln(x)`.
    Log,
    /// `ln(1 + x)`.
    Log1p,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Hyperbolic tangent.
    Tanh,
    /// Inverse sine.
    Asin,
    /// Inverse cosine.
    Acos,
    /// Inverse tangent (single argument).
    Atan,
    /// Hyperbolic sine.
    Sinh,
    /// Hyperbolic cosine.
    Cosh,
    /// Inverse hyperbolic sine.
    Asinh,
    /// Inverse hyperbolic cosine.
    Acosh,
    /// Inverse hyperbolic tangent.
    Atanh,
}
41
/// Analytic binary operations.
///
/// Each variant names an elementwise function of two operands. This enum only
/// identifies the operation; operand order and numerical conventions are fixed
/// by the executing backend.
///
/// # Examples
///
/// ```
/// use tenferro_prims::AnalyticBinaryOp;
///
/// let op = AnalyticBinaryOp::Pow;
/// assert_eq!(op, AnalyticBinaryOp::Pow);
/// ```
// Variant order is part of the public surface (discriminants / `as` casts);
// append new variants at the end rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticBinaryOp {
    /// Power: one operand raised to the other.
    ///
    /// NOTE(review): presumably `inputs[0]` is the base and `inputs[1]` the
    /// exponent — confirm against the backend implementation.
    Pow,
    /// Two-argument arctangent, `atan2`.
    ///
    /// NOTE(review): confirm which input plays the role of `y` and which `x`.
    Atan2,
    /// Euclidean length of the operand pair, `sqrt(a^2 + b^2)`.
    Hypot,
    /// `x * ln(y)`.
    ///
    /// NOTE(review): confirm the convention at `x == 0` (SciPy's `xlogy`
    /// defines the result as 0 there, even when `y == 0`).
    Xlogy,
}
59
/// Analytic reduction operations.
///
/// Used by `AnalyticPrimsDescriptor::Reduction` to pick the statistic computed
/// over the reduced modes.
///
/// # Examples
///
/// ```
/// use tenferro_prims::AnalyticReductionOp;
///
/// let op = AnalyticReductionOp::Var;
/// assert_eq!(op, AnalyticReductionOp::Var);
/// ```
// Variant order is part of the public surface (discriminants / `as` casts);
// append new variants at the end rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticReductionOp {
    /// Variance over the reduced modes.
    ///
    /// NOTE(review): confirm whether a Bessel (`n - 1`) correction is applied.
    Var,
    /// Standard deviation over the reduced modes (presumably the square root
    /// of [`AnalyticReductionOp::Var`] under the same correction).
    Std,
}
75
76/// Descriptor for analytic-pointwise and analytic-reduction planning.
77///
78/// # Examples
79///
80/// ```
81/// use tenferro_prims::{AnalyticPrimsDescriptor, AnalyticUnaryOp};
82///
83/// let desc = AnalyticPrimsDescriptor::PointwiseUnary {
84///     op: AnalyticUnaryOp::Sqrt,
85/// };
86/// assert!(matches!(desc, AnalyticPrimsDescriptor::PointwiseUnary { .. }));
87/// ```
88#[derive(Debug, Clone, PartialEq, Eq, Hash)]
89pub enum AnalyticPrimsDescriptor {
90    /// Apply an analytic unary pointwise operation to one input tensor.
91    PointwiseUnary {
92        /// The unary analytic operation to apply.
93        op: AnalyticUnaryOp,
94    },
95    /// Apply an analytic binary pointwise operation to two input tensors.
96    PointwiseBinary {
97        /// The binary analytic operation to apply.
98        op: AnalyticBinaryOp,
99    },
100    /// Reduce one tensor into an output tensor over the dropped modes.
101    Reduction {
102        /// Input modes associated with the source tensor.
103        modes_a: Vec<u32>,
104        /// Output modes that remain after reduction.
105        modes_c: Vec<u32>,
106        /// Reduction operator to use.
107        op: AnalyticReductionOp,
108    },
109}
110
111/// Analytic pointwise and reduction protocol family.
112///
113/// # Examples
114///
115/// ```ignore
116/// use tenferro_algebra::Standard;
117/// use tenferro_prims::{AnalyticPrimsDescriptor, AnalyticUnaryOp, CpuBackend, CpuContext, TensorAnalyticPrims};
118///
119/// let mut ctx = CpuContext::new(1);
120/// let desc = AnalyticPrimsDescriptor::PointwiseUnary {
121///     op: AnalyticUnaryOp::Sqrt,
122/// };
123/// let _plan = <CpuBackend as TensorAnalyticPrims<Standard<f64>>>::plan(
124///     &mut ctx,
125///     &desc,
126///     &[&[2, 2], &[2, 2]],
127/// )
128/// .unwrap();
129/// ```
130pub trait TensorAnalyticPrims<Alg: Algebra> {
131    type Plan;
132    type Context;
133
134    /// Plan an analytic-family operation for the given input/output shapes.
135    ///
136    /// Public vocabulary may be broader than the currently wired execution
137    /// surface so later backend work can land without descriptor churn.
138    fn plan(
139        ctx: &mut Self::Context,
140        desc: &AnalyticPrimsDescriptor,
141        shapes: &[&[usize]],
142    ) -> Result<Self::Plan>;
143
144    /// Execute a previously planned analytic-family operation.
145    ///
146    /// The execution contract matches the rest of tenferro prims:
147    /// `output <- alpha * op(inputs) + beta * output`.
148    fn execute(
149        ctx: &mut Self::Context,
150        plan: &Self::Plan,
151        alpha: Alg::Scalar,
152        inputs: &[&Tensor<Alg::Scalar>],
153        beta: Alg::Scalar,
154        output: &mut Tensor<Alg::Scalar>,
155    ) -> Result<()>;
156
157    /// Report whether the backend advertises support for the given descriptor.
158    ///
159    /// This is a family-level capability check and does not validate every
160    /// shape-specific precondition.
161    fn has_analytic_support(desc: AnalyticPrimsDescriptor) -> bool;
162}
163
164#[cfg(not(feature = "cuda"))]
165impl<S: Scalar> TensorAnalyticPrims<Standard<S>> for CudaBackend {
166    type Plan = ();
167    type Context = CudaContext;
168
169    fn plan(
170        _ctx: &mut Self::Context,
171        desc: &AnalyticPrimsDescriptor,
172        _shapes: &[&[usize]],
173    ) -> Result<Self::Plan> {
174        Err(Error::InvalidArgument(format!(
175            "analytic family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
176        )))
177    }
178
179    fn execute(
180        _ctx: &mut Self::Context,
181        _plan: &Self::Plan,
182        _alpha: S,
183        _inputs: &[&Tensor<S>],
184        _beta: S,
185        _output: &mut Tensor<S>,
186    ) -> Result<()> {
187        Err(Error::InvalidArgument(
188            "analytic family execution is not implemented on CudaBackend in phase 1".into(),
189        ))
190    }
191
192    fn has_analytic_support(_desc: AnalyticPrimsDescriptor) -> bool {
193        false
194    }
195}
196
197impl<S: Scalar> TensorAnalyticPrims<Standard<S>> for RocmBackend {
198    type Plan = ();
199    type Context = RocmContext;
200
201    fn plan(
202        _ctx: &mut Self::Context,
203        desc: &AnalyticPrimsDescriptor,
204        _shapes: &[&[usize]],
205    ) -> Result<Self::Plan> {
206        Err(Error::InvalidArgument(format!(
207            "analytic family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
208        )))
209    }
210
211    fn execute(
212        _ctx: &mut Self::Context,
213        _plan: &Self::Plan,
214        _alpha: S,
215        _inputs: &[&Tensor<S>],
216        _beta: S,
217        _output: &mut Tensor<S>,
218    ) -> Result<()> {
219        Err(Error::InvalidArgument(
220            "analytic family execution is not implemented on RocmBackend in phase 1".into(),
221        ))
222    }
223
224    fn has_analytic_support(_desc: AnalyticPrimsDescriptor) -> bool {
225        false
226    }
227}