// tenferro_prims/families/analytic.rs
1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// Analytic unary operations.
///
/// # Examples
///
/// ```
/// use tenferro_prims::AnalyticUnaryOp;
///
/// let op = AnalyticUnaryOp::Sqrt;
/// assert_eq!(op, AnalyticUnaryOp::Sqrt);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticUnaryOp {
    /// Element-wise square root.
    Sqrt,
    /// Element-wise reciprocal square root, `1 / sqrt(x)`.
    Rsqrt,
    /// Element-wise natural exponential.
    Exp,
    /// Element-wise `exp(x) - 1` (the named function is conventionally
    /// accurate for small `x`).
    Expm1,
    /// Element-wise rounding up to the nearest integer.
    Ceil,
    /// Element-wise natural logarithm.
    Log,
    /// Element-wise `ln(1 + x)` (the named function is conventionally
    /// accurate for small `x`).
    Log1p,
    /// Element-wise sine.
    Sin,
    /// Element-wise cosine.
    Cos,
    /// Element-wise tangent.
    Tan,
    /// Element-wise hyperbolic tangent.
    Tanh,
    /// Element-wise arcsine.
    Asin,
    /// Element-wise arccosine.
    Acos,
    /// Element-wise arctangent.
    Atan,
    /// Element-wise hyperbolic sine.
    Sinh,
    /// Element-wise hyperbolic cosine.
    Cosh,
    /// Element-wise inverse hyperbolic sine.
    Asinh,
    /// Element-wise inverse hyperbolic cosine.
    Acosh,
    /// Element-wise inverse hyperbolic tangent.
    Atanh,
}
41
/// Analytic binary operations.
///
/// # Examples
///
/// ```
/// use tenferro_prims::AnalyticBinaryOp;
///
/// let op = AnalyticBinaryOp::Pow;
/// assert_eq!(op, AnalyticBinaryOp::Pow);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticBinaryOp {
    /// Element-wise power, `a` raised to `b`.
    Pow,
    /// Element-wise four-quadrant arctangent of `a / b`.
    Atan2,
    /// Element-wise Euclidean norm, `sqrt(a^2 + b^2)`.
    Hypot,
    /// Element-wise `a * ln(b)` (conventionally defined as 0 when `a == 0`
    /// — confirm against the backend kernels once they land).
    Xlogy,
}
59
/// Analytic reduction operations.
///
/// # Examples
///
/// ```
/// use tenferro_prims::AnalyticReductionOp;
///
/// let op = AnalyticReductionOp::Var;
/// assert_eq!(op, AnalyticReductionOp::Var);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticReductionOp {
    /// Variance over the reduced modes.
    Var,
    /// Standard deviation over the reduced modes.
    Std,
}
75
/// Descriptor for analytic-pointwise and analytic-reduction planning.
///
/// # Examples
///
/// ```
/// use tenferro_prims::{AnalyticPrimsDescriptor, AnalyticUnaryOp};
///
/// let desc = AnalyticPrimsDescriptor::PointwiseUnary {
///     op: AnalyticUnaryOp::Sqrt,
/// };
/// assert!(matches!(desc, AnalyticPrimsDescriptor::PointwiseUnary { .. }));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AnalyticPrimsDescriptor {
    /// Apply an analytic unary pointwise operation to one input tensor.
    PointwiseUnary {
        /// The unary analytic operation to apply.
        op: AnalyticUnaryOp,
    },
    /// Apply an analytic binary pointwise operation to two input tensors.
    PointwiseBinary {
        /// The binary analytic operation to apply.
        op: AnalyticBinaryOp,
    },
    /// Reduce one tensor into an output tensor over the dropped modes
    /// (modes present in `modes_a` but absent from `modes_c`).
    Reduction {
        /// Input modes associated with the source tensor.
        modes_a: Vec<u32>,
        /// Output modes that remain after reduction.
        modes_c: Vec<u32>,
        /// Reduction operator to use.
        op: AnalyticReductionOp,
    },
}
110
/// Analytic pointwise and reduction protocol family.
///
/// Backends implement this trait to plan and execute the operations
/// described by [`AnalyticPrimsDescriptor`].
///
/// # Examples
///
/// ```ignore
/// use tenferro_algebra::Standard;
/// use tenferro_prims::{AnalyticPrimsDescriptor, AnalyticUnaryOp, CpuBackend, CpuContext, TensorAnalyticPrims};
///
/// let mut ctx = CpuContext::new(1);
/// let desc = AnalyticPrimsDescriptor::PointwiseUnary {
///     op: AnalyticUnaryOp::Sqrt,
/// };
/// let _plan = <CpuBackend as TensorAnalyticPrims<Standard<f64>>>::plan(
///     &mut ctx,
///     &desc,
///     &[&[2, 2], &[2, 2]],
/// )
/// .unwrap();
/// ```
pub trait TensorAnalyticPrims<Alg: Algebra> {
    /// Backend-specific executable plan produced by [`Self::plan`] and
    /// consumed by [`Self::execute`].
    type Plan;
    /// Backend-specific state threaded through planning and execution.
    type Context;

    /// Plan an analytic-family operation for the given input/output shapes.
    ///
    /// Public vocabulary may be broader than the currently wired execution
    /// surface so later backend work can land without descriptor churn.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend cannot plan the requested
    /// descriptor/shape combination.
    fn plan(
        ctx: &mut Self::Context,
        desc: &AnalyticPrimsDescriptor,
        shapes: &[&[usize]],
    ) -> Result<Self::Plan>;

    /// Execute a previously planned analytic-family operation.
    ///
    /// The execution contract matches the rest of tenferro prims:
    /// `output <- alpha * op(inputs) + beta * output`.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend cannot execute the plan.
    fn execute(
        ctx: &mut Self::Context,
        plan: &Self::Plan,
        alpha: Alg::Scalar,
        inputs: &[&Tensor<Alg::Scalar>],
        beta: Alg::Scalar,
        output: &mut Tensor<Alg::Scalar>,
    ) -> Result<()>;

    /// Report whether the backend advertises support for the given descriptor.
    ///
    /// This is a family-level capability check and does not validate every
    /// shape-specific precondition.
    fn has_analytic_support(desc: AnalyticPrimsDescriptor) -> bool;
}
163
164#[cfg(not(feature = "cuda"))]
165impl<S: Scalar> TensorAnalyticPrims<Standard<S>> for CudaBackend {
166 type Plan = ();
167 type Context = CudaContext;
168
169 fn plan(
170 _ctx: &mut Self::Context,
171 desc: &AnalyticPrimsDescriptor,
172 _shapes: &[&[usize]],
173 ) -> Result<Self::Plan> {
174 Err(Error::InvalidArgument(format!(
175 "analytic family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
176 )))
177 }
178
179 fn execute(
180 _ctx: &mut Self::Context,
181 _plan: &Self::Plan,
182 _alpha: S,
183 _inputs: &[&Tensor<S>],
184 _beta: S,
185 _output: &mut Tensor<S>,
186 ) -> Result<()> {
187 Err(Error::InvalidArgument(
188 "analytic family execution is not implemented on CudaBackend in phase 1".into(),
189 ))
190 }
191
192 fn has_analytic_support(_desc: AnalyticPrimsDescriptor) -> bool {
193 false
194 }
195}
196
197impl<S: Scalar> TensorAnalyticPrims<Standard<S>> for RocmBackend {
198 type Plan = ();
199 type Context = RocmContext;
200
201 fn plan(
202 _ctx: &mut Self::Context,
203 desc: &AnalyticPrimsDescriptor,
204 _shapes: &[&[usize]],
205 ) -> Result<Self::Plan> {
206 Err(Error::InvalidArgument(format!(
207 "analytic family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
208 )))
209 }
210
211 fn execute(
212 _ctx: &mut Self::Context,
213 _plan: &Self::Plan,
214 _alpha: S,
215 _inputs: &[&Tensor<S>],
216 _beta: S,
217 _output: &mut Tensor<S>,
218 ) -> Result<()> {
219 Err(Error::InvalidArgument(
220 "analytic family execution is not implemented on RocmBackend in phase 1".into(),
221 ))
222 }
223
224 fn has_analytic_support(_desc: AnalyticPrimsDescriptor) -> bool {
225 false
226 }
227}