tenferro/eager_ops_elementwise.rs
1use tenferro_ops::std_tensor_op::StdTensorOp;
2use tenferro_tensor::TensorBackend;
3
4use crate::eager::EagerTensor;
5use crate::error::Result;
6
7impl<B: TensorBackend> EagerTensor<B> {
8 /// Elementwise absolute value.
9 ///
10 /// # Examples
11 ///
12 /// ```
13 /// use tenferro::{EagerTensor, Tensor};
14 ///
15 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![-1.0_f64, 2.0]));
16 /// let y = x.abs().unwrap();
17 ///
18 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
19 /// ```
20 pub fn abs(&self) -> Result<Self> {
21 self.unary_op(StdTensorOp::Abs)
22 }
23
24 /// Elementwise complex conjugate.
25 ///
26 /// # Examples
27 ///
28 /// ```
29 /// use tenferro::{EagerTensor, Tensor};
30 ///
31 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, -2.0]));
32 /// let y = x.conj().unwrap();
33 ///
34 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, -2.0]);
35 /// ```
36 pub fn conj(&self) -> Result<Self> {
37 self.unary_op(StdTensorOp::Conj)
38 }
39
40 /// Elementwise sign.
41 ///
42 /// # Examples
43 ///
44 /// ```
45 /// use tenferro::{EagerTensor, Tensor};
46 ///
47 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![-2.0_f64, 3.0]));
48 /// let y = x.sign().unwrap();
49 ///
50 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[-1.0, 1.0]);
51 /// ```
52 pub fn sign(&self) -> Result<Self> {
53 self.unary_op(StdTensorOp::Sign)
54 }
55
56 /// Elementwise natural logarithm.
57 ///
58 /// # Examples
59 ///
60 /// ```
61 /// use tenferro::{EagerTensor, Tensor};
62 ///
63 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![1.0_f64]));
64 /// let y = x.log().unwrap();
65 ///
66 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
67 /// ```
68 pub fn log(&self) -> Result<Self> {
69 self.unary_op(StdTensorOp::Log)
70 }
71
72 /// Elementwise square root.
73 ///
74 /// # Examples
75 ///
76 /// ```
77 /// use tenferro::{EagerTensor, Tensor};
78 ///
79 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![4.0_f64]));
80 /// let y = x.sqrt().unwrap();
81 ///
82 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[2.0]);
83 /// ```
84 pub fn sqrt(&self) -> Result<Self> {
85 self.unary_op(StdTensorOp::Sqrt)
86 }
87
88 /// Elementwise reciprocal square root.
89 ///
90 /// # Examples
91 ///
92 /// ```
93 /// use tenferro::{EagerTensor, Tensor};
94 ///
95 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![4.0_f64]));
96 /// let y = x.rsqrt().unwrap();
97 ///
98 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.5]);
99 /// ```
100 pub fn rsqrt(&self) -> Result<Self> {
101 self.unary_op(StdTensorOp::Rsqrt)
102 }
103
104 /// Elementwise sine.
105 ///
106 /// # Examples
107 ///
108 /// ```
109 /// use tenferro::{EagerTensor, Tensor};
110 ///
111 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
112 /// let y = x.sin().unwrap();
113 ///
114 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
115 /// ```
116 pub fn sin(&self) -> Result<Self> {
117 self.unary_op(StdTensorOp::Sin)
118 }
119
120 /// Elementwise cosine.
121 ///
122 /// # Examples
123 ///
124 /// ```
125 /// use tenferro::{EagerTensor, Tensor};
126 ///
127 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
128 /// let y = x.cos().unwrap();
129 ///
130 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0]);
131 /// ```
132 pub fn cos(&self) -> Result<Self> {
133 self.unary_op(StdTensorOp::Cos)
134 }
135
136 /// Elementwise hyperbolic tangent.
137 ///
138 /// # Examples
139 ///
140 /// ```
141 /// use tenferro::{EagerTensor, Tensor};
142 ///
143 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
144 /// let y = x.tanh().unwrap();
145 ///
146 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
147 /// ```
148 pub fn tanh(&self) -> Result<Self> {
149 self.unary_op(StdTensorOp::Tanh)
150 }
151
152 /// Elementwise `exp(x) - 1`.
153 ///
154 /// # Examples
155 ///
156 /// ```
157 /// use tenferro::{EagerTensor, Tensor};
158 ///
159 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
160 /// let y = x.expm1().unwrap();
161 ///
162 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
163 /// ```
164 pub fn expm1(&self) -> Result<Self> {
165 self.unary_op(StdTensorOp::Expm1)
166 }
167
168 /// Elementwise `log(1 + x)`.
169 ///
170 /// # Examples
171 ///
172 /// ```
173 /// use tenferro::{EagerTensor, Tensor};
174 ///
175 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
176 /// let y = x.log1p().unwrap();
177 ///
178 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
179 /// ```
180 pub fn log1p(&self) -> Result<Self> {
181 self.unary_op(StdTensorOp::Log1p)
182 }
183
184 /// Elementwise division.
185 ///
186 /// # Examples
187 ///
188 /// ```
189 /// use tenferro::{EagerTensor, Tensor};
190 ///
191 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![8.0_f64, -6.0, 9.0]));
192 /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![2.0_f64, 3.0, 3.0]));
193 /// let z = x.div(&y).unwrap();
194 ///
195 /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[4.0, -2.0, 3.0]);
196 /// ```
197 pub fn div(&self, other: &Self) -> Result<Self> {
198 self.binary_op(other, StdTensorOp::Div)
199 }
200
201 /// Elementwise power.
202 ///
203 /// # Examples
204 ///
205 /// ```
206 /// use tenferro::{EagerTensor, Tensor};
207 ///
208 /// let base = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![2.0_f64, 3.0]));
209 /// let exp = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 2.0]));
210 /// let y = base.pow(&exp).unwrap();
211 ///
212 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[8.0, 9.0]);
213 /// ```
214 pub fn pow(&self, other: &Self) -> Result<Self> {
215 self.binary_op(other, StdTensorOp::Pow)
216 }
217
218 /// Elementwise maximum.
219 ///
220 /// # Examples
221 ///
222 /// ```
223 /// use tenferro::{EagerTensor, Tensor};
224 ///
225 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 5.0]));
226 /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
227 /// let z = x.maximum(&y).unwrap();
228 ///
229 /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[3.0, 5.0]);
230 /// ```
231 pub fn maximum(&self, other: &Self) -> Result<Self> {
232 self.binary_op(other, StdTensorOp::Maximum)
233 }
234
235 /// Elementwise minimum.
236 ///
237 /// # Examples
238 ///
239 /// ```
240 /// use tenferro::{EagerTensor, Tensor};
241 ///
242 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 5.0]));
243 /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
244 /// let z = x.minimum(&y).unwrap();
245 ///
246 /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[1.0, 4.0]);
247 /// ```
248 pub fn minimum(&self, other: &Self) -> Result<Self> {
249 self.binary_op(other, StdTensorOp::Minimum)
250 }
251
252 /// Select values from `on_true` or `on_false` using `condition`.
253 ///
254 /// # Examples
255 ///
256 /// ```
257 /// use tenferro::{EagerTensor, Tensor};
258 ///
259 /// let condition = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![0.0_f64, 1.0]));
260 /// let on_true = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![10.0_f64, 20.0]));
261 /// let on_false = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
262 /// let y = EagerTensor::select(&condition, &on_true, &on_false).unwrap();
263 ///
264 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 20.0]);
265 /// ```
266 pub fn select(condition: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
267 condition.ternary_op(on_true, on_false, StdTensorOp::Select)
268 }
269}