Skip to main content

tenferro_tensor/cpu/
analytic.rs

1use num_complex::{Complex32, Complex64};
2use num_traits::{One, Zero};
3use strided_kernel::{map_into, zip_map2_into};
4
5use crate::backend::{tensor_from_array, typed_array, typed_view};
6use crate::types::{Tensor, TypedTensor};
7
8trait UnaryAnalyticElem: Copy + Clone + One + Zero {
9    fn exp_elem(self) -> Self;
10    fn log_elem(self) -> Self;
11    fn sin_elem(self) -> Self;
12    fn cos_elem(self) -> Self;
13    fn tanh_elem(self) -> Self;
14    fn sqrt_elem(self) -> Self;
15    fn rsqrt_elem(self) -> Self;
16    fn expm1_elem(self) -> Self;
17    fn log1p_elem(self) -> Self;
18}
19
20trait PowElem: Copy + Clone + Zero {
21    fn pow_elem(self, exponent: Self) -> Self;
22}
23
/// Implements `UnaryAnalyticElem` and `PowElem` for a primitive float type by
/// forwarding every operation to that type's inherent float methods.
macro_rules! impl_real_analytic_elem {
    ($float:ty) => {
        impl UnaryAnalyticElem for $float {
            fn exp_elem(self) -> Self {
                <$float>::exp(self)
            }

            fn log_elem(self) -> Self {
                <$float>::ln(self)
            }

            fn sin_elem(self) -> Self {
                <$float>::sin(self)
            }

            fn cos_elem(self) -> Self {
                <$float>::cos(self)
            }

            fn tanh_elem(self) -> Self {
                <$float>::tanh(self)
            }

            fn sqrt_elem(self) -> Self {
                <$float>::sqrt(self)
            }

            fn rsqrt_elem(self) -> Self {
                // `recip` is defined as `1 / x`, so this is exactly `1 / sqrt(x)`.
                <$float>::sqrt(self).recip()
            }

            fn expm1_elem(self) -> Self {
                // The dedicated intrinsic keeps full precision for inputs near zero.
                <$float>::exp_m1(self)
            }

            fn log1p_elem(self) -> Self {
                // The dedicated intrinsic keeps full precision for inputs near zero.
                <$float>::ln_1p(self)
            }
        }

        impl PowElem for $float {
            fn pow_elem(self, exponent: Self) -> Self {
                <$float>::powf(self, exponent)
            }
        }
    };
}
71
/// Implements `UnaryAnalyticElem` and `PowElem` for a `num_complex` complex type.
///
/// Most operations forward to the complex type's own methods. `expm1_elem` and
/// `log1p_elem` are evaluated component-wise near the origin so that, like the real
/// `exp_m1` / `ln_1p`, they do not lose precision to cancellation for small inputs.
macro_rules! impl_complex_analytic_elem {
    ($ty:ty) => {
        impl UnaryAnalyticElem for $ty {
            fn exp_elem(self) -> Self {
                self.exp()
            }

            fn log_elem(self) -> Self {
                self.ln()
            }

            fn sin_elem(self) -> Self {
                self.sin()
            }

            fn cos_elem(self) -> Self {
                self.cos()
            }

            fn tanh_elem(self) -> Self {
                self.tanh()
            }

            fn sqrt_elem(self) -> Self {
                self.sqrt()
            }

            fn rsqrt_elem(self) -> Self {
                Self::one() / self.sqrt()
            }

            /// `exp(z) - 1` for `z = x + iy`.
            ///
            /// Near the origin the naive `exp(z) - 1` subtracts two nearly equal values.
            /// There, the real part `e^x·cos(y) - 1` is rewritten as
            /// `expm1(x)·cos(y) - 2·sin²(y/2)`, which has no such cancellation.
            /// Elsewhere (including non-finite inputs) the plain formula is used.
            fn expm1_elem(self) -> Self {
                let (x, y) = (self.re, self.im);
                if x.abs() < 1.0 && y.abs() < 1.0 {
                    let half_sin = (y * 0.5).sin();
                    let re = x.exp_m1() * y.cos() - 2.0 * half_sin * half_sin;
                    let im = x.exp() * y.sin();
                    Self::new(re, im)
                } else {
                    self.exp() - Self::one()
                }
            }

            /// `ln(1 + z)` for `z = x + iy`.
            ///
            /// Near the origin, forming `1 + z` first discards the low bits of `z`.
            /// There, `ln|1 + z| = ½·ln1p(2x + x² + y²)` and `arg(1 + z) = atan2(y, 1 + x)`
            /// are used instead. Elsewhere (including non-finite inputs) the plain
            /// formula is used.
            fn log1p_elem(self) -> Self {
                let (x, y) = (self.re, self.im);
                if x.abs() < 0.5 && y.abs() < 0.5 {
                    let re = 0.5 * (x * (2.0 + x) + y * y).ln_1p();
                    let im = y.atan2(1.0 + x);
                    Self::new(re, im)
                } else {
                    (self + Self::one()).ln()
                }
            }
        }

        impl PowElem for $ty {
            fn pow_elem(self, exponent: Self) -> Self {
                self.powc(exponent)
            }
        }
    };
}
119
120impl_real_analytic_elem!(f32);
121impl_real_analytic_elem!(f64);
122impl_complex_analytic_elem!(Complex32);
123impl_complex_analytic_elem!(Complex64);
124
125fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
126    crate::Error::BackendFailure {
127        op,
128        message: err.to_string(),
129    }
130}
131
/// Generates a public dtype-dispatching unary op (`$dispatch_fn`) and its private
/// generic kernel (`$typed_fn`), which applies `UnaryAnalyticElem::$elem_fn` to every
/// element.
macro_rules! define_unary_analytic_op {
    ($dispatch_fn:ident, $typed_fn:ident, $elem_fn:ident) => {
        #[doc = concat!(
            "Elementwise `", stringify!($dispatch_fn), "` of `input`.\n\n",
            "The result has the same dtype variant and shape as `input`.\n\n",
            "# Errors\n\n",
            "Returns `Error::BackendFailure` with op `\"", stringify!($dispatch_fn),
            "\"` if the underlying `map_into` kernel reports an error."
        )]
        pub fn $dispatch_fn(input: &Tensor) -> crate::Result<Tensor> {
            match input {
                Tensor::F32(t) => Ok(Tensor::F32($typed_fn(t)?)),
                Tensor::F64(t) => Ok(Tensor::F64($typed_fn(t)?)),
                Tensor::C32(t) => Ok(Tensor::C32($typed_fn(t)?)),
                Tensor::C64(t) => Ok(Tensor::C64($typed_fn(t)?)),
            }
        }

        /// Typed kernel: allocates a zero-initialised output with `input`'s shape and
        /// fills it by mapping the element function over `input`.
        fn $typed_fn<T>(input: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
        where
            T: UnaryAnalyticElem,
        {
            let mut out = typed_array(&input.shape, T::zero());
            map_into(&mut out.view_mut(), &typed_view(input), |x| x.$elem_fn())
                // Report the public op name (as `pow` does), not the internal helper name.
                .map_err(|err| backend_failure(stringify!($dispatch_fn), err))?;
            Ok(tensor_from_array(out))
        }
    };
}
154
155define_unary_analytic_op!(exp, typed_exp, exp_elem);
156define_unary_analytic_op!(log, typed_log, log_elem);
157define_unary_analytic_op!(sin, typed_sin, sin_elem);
158define_unary_analytic_op!(cos, typed_cos, cos_elem);
159define_unary_analytic_op!(tanh, typed_tanh, tanh_elem);
160define_unary_analytic_op!(sqrt, typed_sqrt, sqrt_elem);
161define_unary_analytic_op!(rsqrt, typed_rsqrt, rsqrt_elem);
162define_unary_analytic_op!(expm1, typed_expm1, expm1_elem);
163define_unary_analytic_op!(log1p, typed_log1p, log1p_elem);
164
165pub fn pow(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
166    match (lhs, rhs) {
167        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_pow(a, b)?)),
168        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_pow(a, b)?)),
169        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_pow(a, b)?)),
170        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_pow(a, b)?)),
171        _ => Err(crate::Error::DTypeMismatch {
172            op: "pow",
173            lhs: lhs.dtype(),
174            rhs: rhs.dtype(),
175        }),
176    }
177}
178
179fn typed_pow<T>(lhs: &TypedTensor<T>, rhs: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
180where
181    T: PowElem,
182{
183    if lhs.shape != rhs.shape {
184        return Err(crate::Error::ShapeMismatch {
185            op: "pow",
186            lhs: lhs.shape.clone(),
187            rhs: rhs.shape.clone(),
188        });
189    }
190    let mut out = typed_array(&lhs.shape, T::zero());
191    zip_map2_into(
192        &mut out.view_mut(),
193        &typed_view(lhs),
194        &typed_view(rhs),
195        |x, y| x.pow_elem(y),
196    )
197    .map_err(|err| backend_failure("pow", err))?;
198    Ok(tensor_from_array(out))
199}