// tenferro_ad/eager_ops_elementwise.rs

1use tenferro_ops::std_tensor_op::StdTensorOp;
2
3use crate::eager::EagerTensor;
4use crate::eager_ops::{broadcast_binary, broadcast_ternary};
5use crate::error::Result;
6use crate::CompareDir;
7
8impl EagerTensor {
9    /// Elementwise absolute value.
10    ///
11    /// # Examples
12    ///
13    /// ```
14    /// use tenferro_cpu::CpuBackend;
15    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
16    ///
17    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
18    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![-1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
19    /// let y = x.abs().unwrap();
20    ///
21    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
22    /// ```
23    pub fn abs(&self) -> Result<Self> {
24        self.unary_op(StdTensorOp::Abs)
25    }
26
27    /// Elementwise complex conjugate.
28    ///
29    /// # Examples
30    ///
31    /// ```
32    /// use tenferro_cpu::CpuBackend;
33    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
34    ///
35    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
36    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, -2.0]).unwrap(), ctx.clone()).unwrap();
37    /// let y = x.conj().unwrap();
38    ///
39    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, -2.0]);
40    /// ```
41    pub fn conj(&self) -> Result<Self> {
42        self.unary_op(StdTensorOp::Conj)
43    }
44
45    /// Elementwise sign.
46    ///
47    /// # Examples
48    ///
49    /// ```
50    /// use tenferro_cpu::CpuBackend;
51    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
52    ///
53    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
54    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![-2.0_f64, 3.0]).unwrap(), ctx.clone()).unwrap();
55    /// let y = x.sign().unwrap();
56    ///
57    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[-1.0, 1.0]);
58    /// ```
59    pub fn sign(&self) -> Result<Self> {
60        self.unary_op(StdTensorOp::Sign)
61    }
62
63    /// Elementwise natural logarithm.
64    ///
65    /// # Examples
66    ///
67    /// ```
68    /// use tenferro_cpu::CpuBackend;
69    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
70    ///
71    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
72    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap(), ctx.clone()).unwrap();
73    /// let y = x.log().unwrap();
74    ///
75    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
76    /// ```
77    pub fn log(&self) -> Result<Self> {
78        self.unary_op(StdTensorOp::Log)
79    }
80
81    /// Elementwise square root.
82    ///
83    /// # Examples
84    ///
85    /// ```
86    /// use tenferro_cpu::CpuBackend;
87    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
88    ///
89    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
90    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![4.0_f64]).unwrap(), ctx.clone()).unwrap();
91    /// let y = x.sqrt().unwrap();
92    ///
93    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0]);
94    /// ```
95    pub fn sqrt(&self) -> Result<Self> {
96        self.unary_op(StdTensorOp::Sqrt)
97    }
98
99    /// Elementwise reciprocal square root.
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use tenferro_cpu::CpuBackend;
105    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
106    ///
107    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
108    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![4.0_f64]).unwrap(), ctx.clone()).unwrap();
109    /// let y = x.rsqrt().unwrap();
110    ///
111    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.5]);
112    /// ```
113    pub fn rsqrt(&self) -> Result<Self> {
114        self.unary_op(StdTensorOp::Rsqrt)
115    }
116
117    /// Elementwise sine.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// use tenferro_cpu::CpuBackend;
123    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
124    ///
125    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
126    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
127    /// let y = x.sin().unwrap();
128    ///
129    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
130    /// ```
131    pub fn sin(&self) -> Result<Self> {
132        self.unary_op(StdTensorOp::Sin)
133    }
134
135    /// Elementwise cosine.
136    ///
137    /// # Examples
138    ///
139    /// ```
140    /// use tenferro_cpu::CpuBackend;
141    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
142    ///
143    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
144    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
145    /// let y = x.cos().unwrap();
146    ///
147    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0]);
148    /// ```
149    pub fn cos(&self) -> Result<Self> {
150        self.unary_op(StdTensorOp::Cos)
151    }
152
153    /// Elementwise hyperbolic tangent.
154    ///
155    /// # Examples
156    ///
157    /// ```
158    /// use tenferro_cpu::CpuBackend;
159    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
160    ///
161    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
162    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
163    /// let y = x.tanh().unwrap();
164    ///
165    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
166    /// ```
167    pub fn tanh(&self) -> Result<Self> {
168        self.unary_op(StdTensorOp::Tanh)
169    }
170
171    /// Elementwise `exp(x) - 1`.
172    ///
173    /// # Examples
174    ///
175    /// ```
176    /// use tenferro_cpu::CpuBackend;
177    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
178    ///
179    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
180    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
181    /// let y = x.expm1().unwrap();
182    ///
183    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
184    /// ```
185    pub fn expm1(&self) -> Result<Self> {
186        self.unary_op(StdTensorOp::Expm1)
187    }
188
189    /// Elementwise `log(1 + x)`.
190    ///
191    /// # Examples
192    ///
193    /// ```
194    /// use tenferro_cpu::CpuBackend;
195    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
196    ///
197    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
198    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
199    /// let y = x.log1p().unwrap();
200    ///
201    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
202    /// ```
203    pub fn log1p(&self) -> Result<Self> {
204        self.unary_op(StdTensorOp::Log1p)
205    }
206
207    /// Elementwise division.
208    ///
209    /// # Examples
210    ///
211    /// ```
212    /// use tenferro_cpu::CpuBackend;
213    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
214    ///
215    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
216    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![8.0_f64, -6.0, 9.0]).unwrap(), ctx.clone()).unwrap();
217    /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![2.0_f64, 3.0, 3.0]).unwrap(), ctx.clone()).unwrap();
218    /// let z = x.div(&y).unwrap();
219    ///
220    /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0, -2.0, 3.0]);
221    /// ```
222    pub fn div(&self, other: &Self) -> Result<Self> {
223        let (lhs, rhs) = broadcast_binary("div", self, other)?;
224        lhs.binary_op(&rhs, StdTensorOp::Div)
225    }
226
227    /// Elementwise power.
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// use tenferro_cpu::CpuBackend;
233    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
234    ///
235    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
236    /// let base = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 3.0]).unwrap(), ctx.clone()).unwrap();
237    /// let exp = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
238    /// let y = base.pow(&exp).unwrap();
239    ///
240    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[8.0, 9.0]);
241    /// ```
242    pub fn pow(&self, other: &Self) -> Result<Self> {
243        let (lhs, rhs) = broadcast_binary("pow", self, other)?;
244        lhs.binary_op(&rhs, StdTensorOp::Pow)
245    }
246
247    /// Elementwise maximum.
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use tenferro_cpu::CpuBackend;
253    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
254    ///
255    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
256    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 5.0]).unwrap(), ctx.clone()).unwrap();
257    /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
258    /// let z = x.maximum(&y).unwrap();
259    ///
260    /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 5.0]);
261    /// ```
262    pub fn maximum(&self, other: &Self) -> Result<Self> {
263        let (lhs, rhs) = broadcast_binary("maximum", self, other)?;
264        lhs.binary_op(&rhs, StdTensorOp::Maximum)
265    }
266
267    /// Elementwise minimum.
268    ///
269    /// # Examples
270    ///
271    /// ```
272    /// use tenferro_cpu::CpuBackend;
273    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
274    ///
275    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
276    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 5.0]).unwrap(), ctx.clone()).unwrap();
277    /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
278    /// let z = x.minimum(&y).unwrap();
279    ///
280    /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 4.0]);
281    /// ```
282    pub fn minimum(&self, other: &Self) -> Result<Self> {
283        let (lhs, rhs) = broadcast_binary("minimum", self, other)?;
284        lhs.binary_op(&rhs, StdTensorOp::Minimum)
285    }
286
287    /// Elementwise comparison.
288    pub fn compare(&self, other: &Self, dir: CompareDir) -> Result<Self> {
289        let (lhs, rhs) = broadcast_binary("compare", self, other)?;
290        lhs.binary_op(&rhs, StdTensorOp::Compare(dir))
291    }
292
293    /// Select values from `on_true` or `on_false` using `condition`.
294    ///
295    /// # Examples
296    ///
297    /// ```
298    /// use tenferro_cpu::CpuBackend;
299    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
300    ///
301    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
302    /// let condition = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![false, true]).unwrap(), ctx.clone()).unwrap();
303    /// let on_true = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![10.0_f64, 20.0]).unwrap(), ctx.clone()).unwrap();
304    /// let on_false = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
305    /// let y = EagerTensor::select(&condition, &on_true, &on_false).unwrap();
306    ///
307    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 20.0]);
308    /// ```
309    pub fn select(condition: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
310        Self::where_select(condition, on_true, on_false)
311    }
312
313    /// Select values from `on_true` or `on_false` using `condition`.
314    pub fn where_select(condition: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
315        let (condition, on_true, on_false) =
316            broadcast_ternary("where_select", condition, on_true, on_false)?;
317        condition.ternary_op(&on_true, &on_false, StdTensorOp::Select)
318    }
319
320    /// Clamp values elementwise between lower and upper bounds.
321    ///
322    /// # Examples
323    ///
324    /// ```
325    /// use tenferro_cpu::CpuBackend;
326    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
327    ///
328    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
329    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![-2.0_f64, 0.5, 5.0]).unwrap(), ctx.clone()).unwrap();
330    /// let lower = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![-1.0_f64, 0.0, 1.0]).unwrap(), ctx.clone()).unwrap();
331    /// let upper = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 4.0]).unwrap(), ctx.clone()).unwrap();
332    /// let y = x.clamp(&lower, &upper).unwrap();
333    ///
334    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[-1.0, 0.5, 4.0]);
335    /// ```
336    pub fn clamp(&self, lower: &Self, upper: &Self) -> Result<Self> {
337        let (input, lower, upper) = broadcast_ternary("clamp", self, lower, upper)?;
338        input.ternary_op(&lower, &upper, StdTensorOp::Clamp)
339    }
340}