Skip to main content

tenferro/
eager_ops_elementwise.rs

1use tenferro_ops::std_tensor_op::StdTensorOp;
2use tenferro_tensor::TensorBackend;
3
4use crate::eager::EagerTensor;
5use crate::error::Result;
6
7impl<B: TensorBackend> EagerTensor<B> {
8    /// Elementwise absolute value.
9    ///
10    /// # Examples
11    ///
12    /// ```
13    /// use tenferro::{EagerTensor, Tensor};
14    ///
15    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![-1.0_f64, 2.0]));
16    /// let y = x.abs().unwrap();
17    ///
18    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
19    /// ```
20    pub fn abs(&self) -> Result<Self> {
21        self.unary_op(StdTensorOp::Abs)
22    }
23
24    /// Elementwise complex conjugate.
25    ///
26    /// # Examples
27    ///
28    /// ```
29    /// use tenferro::{EagerTensor, Tensor};
30    ///
31    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, -2.0]));
32    /// let y = x.conj().unwrap();
33    ///
34    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, -2.0]);
35    /// ```
36    pub fn conj(&self) -> Result<Self> {
37        self.unary_op(StdTensorOp::Conj)
38    }
39
40    /// Elementwise sign.
41    ///
42    /// # Examples
43    ///
44    /// ```
45    /// use tenferro::{EagerTensor, Tensor};
46    ///
47    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![-2.0_f64, 3.0]));
48    /// let y = x.sign().unwrap();
49    ///
50    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[-1.0, 1.0]);
51    /// ```
52    pub fn sign(&self) -> Result<Self> {
53        self.unary_op(StdTensorOp::Sign)
54    }
55
56    /// Elementwise natural logarithm.
57    ///
58    /// # Examples
59    ///
60    /// ```
61    /// use tenferro::{EagerTensor, Tensor};
62    ///
63    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![1.0_f64]));
64    /// let y = x.log().unwrap();
65    ///
66    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
67    /// ```
68    pub fn log(&self) -> Result<Self> {
69        self.unary_op(StdTensorOp::Log)
70    }
71
72    /// Elementwise square root.
73    ///
74    /// # Examples
75    ///
76    /// ```
77    /// use tenferro::{EagerTensor, Tensor};
78    ///
79    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![4.0_f64]));
80    /// let y = x.sqrt().unwrap();
81    ///
82    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[2.0]);
83    /// ```
84    pub fn sqrt(&self) -> Result<Self> {
85        self.unary_op(StdTensorOp::Sqrt)
86    }
87
88    /// Elementwise reciprocal square root.
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// use tenferro::{EagerTensor, Tensor};
94    ///
95    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![4.0_f64]));
96    /// let y = x.rsqrt().unwrap();
97    ///
98    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.5]);
99    /// ```
100    pub fn rsqrt(&self) -> Result<Self> {
101        self.unary_op(StdTensorOp::Rsqrt)
102    }
103
104    /// Elementwise sine.
105    ///
106    /// # Examples
107    ///
108    /// ```
109    /// use tenferro::{EagerTensor, Tensor};
110    ///
111    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
112    /// let y = x.sin().unwrap();
113    ///
114    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
115    /// ```
116    pub fn sin(&self) -> Result<Self> {
117        self.unary_op(StdTensorOp::Sin)
118    }
119
120    /// Elementwise cosine.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use tenferro::{EagerTensor, Tensor};
126    ///
127    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
128    /// let y = x.cos().unwrap();
129    ///
130    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0]);
131    /// ```
132    pub fn cos(&self) -> Result<Self> {
133        self.unary_op(StdTensorOp::Cos)
134    }
135
136    /// Elementwise hyperbolic tangent.
137    ///
138    /// # Examples
139    ///
140    /// ```
141    /// use tenferro::{EagerTensor, Tensor};
142    ///
143    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
144    /// let y = x.tanh().unwrap();
145    ///
146    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
147    /// ```
148    pub fn tanh(&self) -> Result<Self> {
149        self.unary_op(StdTensorOp::Tanh)
150    }
151
152    /// Elementwise `exp(x) - 1`.
153    ///
154    /// # Examples
155    ///
156    /// ```
157    /// use tenferro::{EagerTensor, Tensor};
158    ///
159    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
160    /// let y = x.expm1().unwrap();
161    ///
162    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
163    /// ```
164    pub fn expm1(&self) -> Result<Self> {
165        self.unary_op(StdTensorOp::Expm1)
166    }
167
168    /// Elementwise `log(1 + x)`.
169    ///
170    /// # Examples
171    ///
172    /// ```
173    /// use tenferro::{EagerTensor, Tensor};
174    ///
175    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
176    /// let y = x.log1p().unwrap();
177    ///
178    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0]);
179    /// ```
180    pub fn log1p(&self) -> Result<Self> {
181        self.unary_op(StdTensorOp::Log1p)
182    }
183
184    /// Elementwise division.
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// use tenferro::{EagerTensor, Tensor};
190    ///
191    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![8.0_f64, -6.0, 9.0]));
192    /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![2.0_f64, 3.0, 3.0]));
193    /// let z = x.div(&y).unwrap();
194    ///
195    /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[4.0, -2.0, 3.0]);
196    /// ```
197    pub fn div(&self, other: &Self) -> Result<Self> {
198        self.binary_op(other, StdTensorOp::Div)
199    }
200
201    /// Elementwise power.
202    ///
203    /// # Examples
204    ///
205    /// ```
206    /// use tenferro::{EagerTensor, Tensor};
207    ///
208    /// let base = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![2.0_f64, 3.0]));
209    /// let exp = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 2.0]));
210    /// let y = base.pow(&exp).unwrap();
211    ///
212    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[8.0, 9.0]);
213    /// ```
214    pub fn pow(&self, other: &Self) -> Result<Self> {
215        self.binary_op(other, StdTensorOp::Pow)
216    }
217
218    /// Elementwise maximum.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// use tenferro::{EagerTensor, Tensor};
224    ///
225    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 5.0]));
226    /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
227    /// let z = x.maximum(&y).unwrap();
228    ///
229    /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[3.0, 5.0]);
230    /// ```
231    pub fn maximum(&self, other: &Self) -> Result<Self> {
232        self.binary_op(other, StdTensorOp::Maximum)
233    }
234
235    /// Elementwise minimum.
236    ///
237    /// # Examples
238    ///
239    /// ```
240    /// use tenferro::{EagerTensor, Tensor};
241    ///
242    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 5.0]));
243    /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
244    /// let z = x.minimum(&y).unwrap();
245    ///
246    /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[1.0, 4.0]);
247    /// ```
248    pub fn minimum(&self, other: &Self) -> Result<Self> {
249        self.binary_op(other, StdTensorOp::Minimum)
250    }
251
252    /// Select values from `on_true` or `on_false` using `condition`.
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// use tenferro::{EagerTensor, Tensor};
258    ///
259    /// let condition = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![0.0_f64, 1.0]));
260    /// let on_true = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![10.0_f64, 20.0]));
261    /// let on_false = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
262    /// let y = EagerTensor::select(&condition, &on_true, &on_false).unwrap();
263    ///
264    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 20.0]);
265    /// ```
266    pub fn select(condition: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
267        condition.ternary_op(on_true, on_false, StdTensorOp::Select)
268    }
269}