tenferro_ad/eager_ops_elementwise.rs
1use tenferro_ops::std_tensor_op::StdTensorOp;
2
3use crate::eager::EagerTensor;
4use crate::eager_ops::{broadcast_binary, broadcast_ternary};
5use crate::error::Result;
6use crate::CompareDir;
7
8impl EagerTensor {
9 /// Elementwise absolute value.
10 ///
11 /// # Examples
12 ///
13 /// ```
14 /// use tenferro_cpu::CpuBackend;
15 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
16 ///
17 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
18 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![-1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
19 /// let y = x.abs().unwrap();
20 ///
21 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
22 /// ```
23 pub fn abs(&self) -> Result<Self> {
24 self.unary_op(StdTensorOp::Abs)
25 }
26
27 /// Elementwise complex conjugate.
28 ///
29 /// # Examples
30 ///
31 /// ```
32 /// use tenferro_cpu::CpuBackend;
33 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
34 ///
35 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
36 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, -2.0]).unwrap(), ctx.clone()).unwrap();
37 /// let y = x.conj().unwrap();
38 ///
39 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, -2.0]);
40 /// ```
41 pub fn conj(&self) -> Result<Self> {
42 self.unary_op(StdTensorOp::Conj)
43 }
44
45 /// Elementwise sign.
46 ///
47 /// # Examples
48 ///
49 /// ```
50 /// use tenferro_cpu::CpuBackend;
51 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
52 ///
53 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
54 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![-2.0_f64, 3.0]).unwrap(), ctx.clone()).unwrap();
55 /// let y = x.sign().unwrap();
56 ///
57 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[-1.0, 1.0]);
58 /// ```
59 pub fn sign(&self) -> Result<Self> {
60 self.unary_op(StdTensorOp::Sign)
61 }
62
63 /// Elementwise natural logarithm.
64 ///
65 /// # Examples
66 ///
67 /// ```
68 /// use tenferro_cpu::CpuBackend;
69 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
70 ///
71 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
72 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap(), ctx.clone()).unwrap();
73 /// let y = x.log().unwrap();
74 ///
75 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
76 /// ```
77 pub fn log(&self) -> Result<Self> {
78 self.unary_op(StdTensorOp::Log)
79 }
80
81 /// Elementwise square root.
82 ///
83 /// # Examples
84 ///
85 /// ```
86 /// use tenferro_cpu::CpuBackend;
87 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
88 ///
89 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
90 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![4.0_f64]).unwrap(), ctx.clone()).unwrap();
91 /// let y = x.sqrt().unwrap();
92 ///
93 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0]);
94 /// ```
95 pub fn sqrt(&self) -> Result<Self> {
96 self.unary_op(StdTensorOp::Sqrt)
97 }
98
99 /// Elementwise reciprocal square root.
100 ///
101 /// # Examples
102 ///
103 /// ```
104 /// use tenferro_cpu::CpuBackend;
105 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
106 ///
107 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
108 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![4.0_f64]).unwrap(), ctx.clone()).unwrap();
109 /// let y = x.rsqrt().unwrap();
110 ///
111 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.5]);
112 /// ```
113 pub fn rsqrt(&self) -> Result<Self> {
114 self.unary_op(StdTensorOp::Rsqrt)
115 }
116
117 /// Elementwise sine.
118 ///
119 /// # Examples
120 ///
121 /// ```
122 /// use tenferro_cpu::CpuBackend;
123 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
124 ///
125 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
126 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
127 /// let y = x.sin().unwrap();
128 ///
129 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
130 /// ```
131 pub fn sin(&self) -> Result<Self> {
132 self.unary_op(StdTensorOp::Sin)
133 }
134
135 /// Elementwise cosine.
136 ///
137 /// # Examples
138 ///
139 /// ```
140 /// use tenferro_cpu::CpuBackend;
141 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
142 ///
143 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
144 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
145 /// let y = x.cos().unwrap();
146 ///
147 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0]);
148 /// ```
149 pub fn cos(&self) -> Result<Self> {
150 self.unary_op(StdTensorOp::Cos)
151 }
152
153 /// Elementwise hyperbolic tangent.
154 ///
155 /// # Examples
156 ///
157 /// ```
158 /// use tenferro_cpu::CpuBackend;
159 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
160 ///
161 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
162 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
163 /// let y = x.tanh().unwrap();
164 ///
165 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
166 /// ```
167 pub fn tanh(&self) -> Result<Self> {
168 self.unary_op(StdTensorOp::Tanh)
169 }
170
171 /// Elementwise `exp(x) - 1`.
172 ///
173 /// # Examples
174 ///
175 /// ```
176 /// use tenferro_cpu::CpuBackend;
177 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
178 ///
179 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
180 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
181 /// let y = x.expm1().unwrap();
182 ///
183 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
184 /// ```
185 pub fn expm1(&self) -> Result<Self> {
186 self.unary_op(StdTensorOp::Expm1)
187 }
188
189 /// Elementwise `log(1 + x)`.
190 ///
191 /// # Examples
192 ///
193 /// ```
194 /// use tenferro_cpu::CpuBackend;
195 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
196 ///
197 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
198 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
199 /// let y = x.log1p().unwrap();
200 ///
201 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0]);
202 /// ```
203 pub fn log1p(&self) -> Result<Self> {
204 self.unary_op(StdTensorOp::Log1p)
205 }
206
207 /// Elementwise division.
208 ///
209 /// # Examples
210 ///
211 /// ```
212 /// use tenferro_cpu::CpuBackend;
213 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
214 ///
215 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
216 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![8.0_f64, -6.0, 9.0]).unwrap(), ctx.clone()).unwrap();
217 /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![2.0_f64, 3.0, 3.0]).unwrap(), ctx.clone()).unwrap();
218 /// let z = x.div(&y).unwrap();
219 ///
220 /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0, -2.0, 3.0]);
221 /// ```
222 pub fn div(&self, other: &Self) -> Result<Self> {
223 let (lhs, rhs) = broadcast_binary("div", self, other)?;
224 lhs.binary_op(&rhs, StdTensorOp::Div)
225 }
226
227 /// Elementwise power.
228 ///
229 /// # Examples
230 ///
231 /// ```
232 /// use tenferro_cpu::CpuBackend;
233 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
234 ///
235 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
236 /// let base = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 3.0]).unwrap(), ctx.clone()).unwrap();
237 /// let exp = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
238 /// let y = base.pow(&exp).unwrap();
239 ///
240 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[8.0, 9.0]);
241 /// ```
242 pub fn pow(&self, other: &Self) -> Result<Self> {
243 let (lhs, rhs) = broadcast_binary("pow", self, other)?;
244 lhs.binary_op(&rhs, StdTensorOp::Pow)
245 }
246
247 /// Elementwise maximum.
248 ///
249 /// # Examples
250 ///
251 /// ```
252 /// use tenferro_cpu::CpuBackend;
253 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
254 ///
255 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
256 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 5.0]).unwrap(), ctx.clone()).unwrap();
257 /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
258 /// let z = x.maximum(&y).unwrap();
259 ///
260 /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 5.0]);
261 /// ```
262 pub fn maximum(&self, other: &Self) -> Result<Self> {
263 let (lhs, rhs) = broadcast_binary("maximum", self, other)?;
264 lhs.binary_op(&rhs, StdTensorOp::Maximum)
265 }
266
267 /// Elementwise minimum.
268 ///
269 /// # Examples
270 ///
271 /// ```
272 /// use tenferro_cpu::CpuBackend;
273 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
274 ///
275 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
276 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 5.0]).unwrap(), ctx.clone()).unwrap();
277 /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
278 /// let z = x.minimum(&y).unwrap();
279 ///
280 /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 4.0]);
281 /// ```
282 pub fn minimum(&self, other: &Self) -> Result<Self> {
283 let (lhs, rhs) = broadcast_binary("minimum", self, other)?;
284 lhs.binary_op(&rhs, StdTensorOp::Minimum)
285 }
286
287 /// Elementwise comparison.
288 pub fn compare(&self, other: &Self, dir: CompareDir) -> Result<Self> {
289 let (lhs, rhs) = broadcast_binary("compare", self, other)?;
290 lhs.binary_op(&rhs, StdTensorOp::Compare(dir))
291 }
292
293 /// Select values from `on_true` or `on_false` using `condition`.
294 ///
295 /// # Examples
296 ///
297 /// ```
298 /// use tenferro_cpu::CpuBackend;
299 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
300 ///
301 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
302 /// let condition = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![false, true]).unwrap(), ctx.clone()).unwrap();
303 /// let on_true = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![10.0_f64, 20.0]).unwrap(), ctx.clone()).unwrap();
304 /// let on_false = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
305 /// let y = EagerTensor::select(&condition, &on_true, &on_false).unwrap();
306 ///
307 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 20.0]);
308 /// ```
309 pub fn select(condition: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
310 Self::where_select(condition, on_true, on_false)
311 }
312
313 /// Select values from `on_true` or `on_false` using `condition`.
314 pub fn where_select(condition: &Self, on_true: &Self, on_false: &Self) -> Result<Self> {
315 let (condition, on_true, on_false) =
316 broadcast_ternary("where_select", condition, on_true, on_false)?;
317 condition.ternary_op(&on_true, &on_false, StdTensorOp::Select)
318 }
319
320 /// Clamp values elementwise between lower and upper bounds.
321 ///
322 /// # Examples
323 ///
324 /// ```
325 /// use tenferro_cpu::CpuBackend;
326 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
327 ///
328 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
329 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![-2.0_f64, 0.5, 5.0]).unwrap(), ctx.clone()).unwrap();
330 /// let lower = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![-1.0_f64, 0.0, 1.0]).unwrap(), ctx.clone()).unwrap();
331 /// let upper = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 4.0]).unwrap(), ctx.clone()).unwrap();
332 /// let y = x.clamp(&lower, &upper).unwrap();
333 ///
334 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[-1.0, 0.5, 4.0]);
335 /// ```
336 pub fn clamp(&self, lower: &Self, upper: &Self) -> Result<Self> {
337 let (input, lower, upper) = broadcast_ternary("clamp", self, lower, upper)?;
338 input.ternary_op(&lower, &upper, StdTensorOp::Clamp)
339 }
340}