tenferro_tensor/cpu/
analytic.rs1use num_complex::{Complex32, Complex64};
2use num_traits::{One, Zero};
3use strided_kernel::{map_into, zip_map2_into};
4
5use crate::backend::{tensor_from_array, typed_array, typed_view};
6use crate::types::{Tensor, TypedTensor};
7
8trait UnaryAnalyticElem: Copy + Clone + One + Zero {
9 fn exp_elem(self) -> Self;
10 fn log_elem(self) -> Self;
11 fn sin_elem(self) -> Self;
12 fn cos_elem(self) -> Self;
13 fn tanh_elem(self) -> Self;
14 fn sqrt_elem(self) -> Self;
15 fn rsqrt_elem(self) -> Self;
16 fn expm1_elem(self) -> Self;
17 fn log1p_elem(self) -> Self;
18}
19
20trait PowElem: Copy + Clone + Zero {
21 fn pow_elem(self, exponent: Self) -> Self;
22}
23
/// Implements [`UnaryAnalyticElem`] and [`PowElem`] for a primitive float type.
///
/// Each trait method forwards to the float's inherent method, called through its
/// fully-qualified path so the inherent implementation is always the one selected.
macro_rules! impl_real_analytic_elem {
    ($ty:ty) => {
        impl UnaryAnalyticElem for $ty {
            fn exp_elem(self) -> Self {
                <$ty>::exp(self)
            }

            fn log_elem(self) -> Self {
                <$ty>::ln(self)
            }

            fn sin_elem(self) -> Self {
                <$ty>::sin(self)
            }

            fn cos_elem(self) -> Self {
                <$ty>::cos(self)
            }

            fn tanh_elem(self) -> Self {
                <$ty>::tanh(self)
            }

            fn sqrt_elem(self) -> Self {
                <$ty>::sqrt(self)
            }

            fn rsqrt_elem(self) -> Self {
                // Plain division of one by the square root.
                let root = <$ty>::sqrt(self);
                Self::one() / root
            }

            fn expm1_elem(self) -> Self {
                <$ty>::exp_m1(self)
            }

            fn log1p_elem(self) -> Self {
                <$ty>::ln_1p(self)
            }
        }

        impl PowElem for $ty {
            fn pow_elem(self, exponent: Self) -> Self {
                <$ty>::powf(self, exponent)
            }
        }
    };
}
71
/// Implements [`UnaryAnalyticElem`] and [`PowElem`] for a `num_complex` type.
///
/// Most operations forward to the complex type's own methods. `expm1_elem` and
/// `log1p_elem` are instead evaluated through the real and imaginary parts. The
/// naive forms `exp(z) - 1` and `ln(1 + z)` cancel catastrophically when `|z|` is
/// small, while the real implementations avoid that via `exp_m1` / `ln_1p`.
macro_rules! impl_complex_analytic_elem {
    ($ty:ty) => {
        impl UnaryAnalyticElem for $ty {
            fn exp_elem(self) -> Self {
                self.exp()
            }

            fn log_elem(self) -> Self {
                self.ln()
            }

            fn sin_elem(self) -> Self {
                self.sin()
            }

            fn cos_elem(self) -> Self {
                self.cos()
            }

            fn tanh_elem(self) -> Self {
                self.tanh()
            }

            fn sqrt_elem(self) -> Self {
                self.sqrt()
            }

            fn rsqrt_elem(self) -> Self {
                Self::one() / self.sqrt()
            }

            fn expm1_elem(self) -> Self {
                // For z = x + iy: exp(z) - 1 = (e^x cos y - 1) + i e^x sin y.
                // Rewrite the real part as expm1(x) cos y - 2 sin^2(y / 2), which is
                // algebraically identical but free of cancellation near z = 0.
                let (x, y) = (self.re, self.im);
                let half_sin = (y * 0.5).sin();
                let re = x.exp_m1() * y.cos() - 2.0 * half_sin * half_sin;
                let im = x.exp() * y.sin();
                Self::new(re, im)
            }

            fn log1p_elem(self) -> Self {
                let (x, y) = (self.re, self.im);
                if x.abs() < 0.5 && y.abs() < 0.5 {
                    // ln(1 + z) = ln|1 + z| + i arg(1 + z), and
                    // |1 + z|^2 = 1 + (x (2 + x) + y^2), so
                    // ln|1 + z| = 0.5 * ln_1p(x (2 + x) + y^2) stays accurate for small z.
                    let re = 0.5 * (x * (2.0 + x) + y * y).ln_1p();
                    let im = y.atan2(1.0 + x);
                    Self::new(re, im)
                } else {
                    // Away from zero the direct form has no cancellation problem, and
                    // it avoids the overflow that squaring large components would cause.
                    (self + Self::one()).ln()
                }
            }
        }

        impl PowElem for $ty {
            fn pow_elem(self, exponent: Self) -> Self {
                self.powc(exponent)
            }
        }
    };
}
119
120impl_real_analytic_elem!(f32);
121impl_real_analytic_elem!(f64);
122impl_complex_analytic_elem!(Complex32);
123impl_complex_analytic_elem!(Complex64);
124
125fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
126 crate::Error::BackendFailure {
127 op,
128 message: err.to_string(),
129 }
130}
131
/// Generates a public dtype-dispatching unary op `$dispatch_fn` plus its private
/// generic worker `$typed_fn`, which applies `UnaryAnalyticElem::$elem_fn` to every
/// element.
///
/// Backend errors are tagged with the public op name (`$dispatch_fn`), consistent
/// with how `pow` tags its errors. Previously they used the private helper name
/// (e.g. `"typed_exp"`).
macro_rules! define_unary_analytic_op {
    ($dispatch_fn:ident, $typed_fn:ident, $elem_fn:ident) => {
        #[doc = concat!(
            "Computes `",
            stringify!($dispatch_fn),
            "` elementwise, returning a tensor with the same dtype and shape as `input`.\n\n",
            "# Errors\n\n",
            "Returns `Error::BackendFailure` with op `\"",
            stringify!($dispatch_fn),
            "\"` if the strided kernel reports an error."
        )]
        pub fn $dispatch_fn(input: &Tensor) -> crate::Result<Tensor> {
            match input {
                Tensor::F32(t) => Ok(Tensor::F32($typed_fn(t)?)),
                Tensor::F64(t) => Ok(Tensor::F64($typed_fn(t)?)),
                Tensor::C32(t) => Ok(Tensor::C32($typed_fn(t)?)),
                Tensor::C64(t) => Ok(Tensor::C64($typed_fn(t)?)),
            }
        }

        fn $typed_fn<T>(input: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
        where
            T: UnaryAnalyticElem,
        {
            // Output buffer has the input's shape; zero is only a placeholder that
            // `map_into` overwrites.
            let mut out = typed_array(&input.shape, T::zero());
            map_into(&mut out.view_mut(), &typed_view(input), |x| x.$elem_fn())
                .map_err(|err| backend_failure(stringify!($dispatch_fn), err))?;
            Ok(tensor_from_array(out))
        }
    };
}
154
155define_unary_analytic_op!(exp, typed_exp, exp_elem);
156define_unary_analytic_op!(log, typed_log, log_elem);
157define_unary_analytic_op!(sin, typed_sin, sin_elem);
158define_unary_analytic_op!(cos, typed_cos, cos_elem);
159define_unary_analytic_op!(tanh, typed_tanh, tanh_elem);
160define_unary_analytic_op!(sqrt, typed_sqrt, sqrt_elem);
161define_unary_analytic_op!(rsqrt, typed_rsqrt, rsqrt_elem);
162define_unary_analytic_op!(expm1, typed_expm1, expm1_elem);
163define_unary_analytic_op!(log1p, typed_log1p, log1p_elem);
164
165pub fn pow(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
166 match (lhs, rhs) {
167 (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_pow(a, b)?)),
168 (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_pow(a, b)?)),
169 (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_pow(a, b)?)),
170 (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_pow(a, b)?)),
171 _ => Err(crate::Error::DTypeMismatch {
172 op: "pow",
173 lhs: lhs.dtype(),
174 rhs: rhs.dtype(),
175 }),
176 }
177}
178
179fn typed_pow<T>(lhs: &TypedTensor<T>, rhs: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
180where
181 T: PowElem,
182{
183 if lhs.shape != rhs.shape {
184 return Err(crate::Error::ShapeMismatch {
185 op: "pow",
186 lhs: lhs.shape.clone(),
187 rhs: rhs.shape.clone(),
188 });
189 }
190 let mut out = typed_array(&lhs.shape, T::zero());
191 zip_map2_into(
192 &mut out.view_mut(),
193 &typed_view(lhs),
194 &typed_view(rhs),
195 |x, y| x.pow_elem(y),
196 )
197 .map_err(|err| backend_failure("pow", err))?;
198 Ok(tensor_from_array(out))
199}