1use tenferro_ops::broadcast::{broadcast_input_plan, broadcast_shape, broadcast_shapes};
7use tenferro_tensor::{CompareDir, DType, DotGeneralConfig, Error, Result, TensorBackend};
8
9use crate::TensorOpsExt;
10use tenferro_tensor::Tensor;
11
12impl TensorOpsExt for Tensor {
13 fn convert<B: TensorBackend>(&self, to: DType, backend: &mut B) -> Result<Tensor> {
14 convert(self, to, backend)
15 }
16
17 fn cast<B: TensorBackend>(&self, to: DType, backend: &mut B) -> Result<Tensor> {
18 cast(self, to, backend)
19 }
20
21 fn add<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
22 add(self, rhs, backend)
23 }
24
25 fn sub<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
26 sub(self, rhs, backend)
27 }
28
29 fn mul<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
30 mul(self, rhs, backend)
31 }
32
33 fn div<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
34 div(self, rhs, backend)
35 }
36
37 fn pow<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
38 pow(self, rhs, backend)
39 }
40
41 fn maximum<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
42 maximum(self, rhs, backend)
43 }
44
45 fn minimum<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
46 minimum(self, rhs, backend)
47 }
48
49 fn neg<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
50 neg(self, backend)
51 }
52
53 fn abs<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
54 abs(self, backend)
55 }
56
57 fn sign<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
58 sign(self, backend)
59 }
60
61 fn conj<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
62 conj(self, backend)
63 }
64
65 fn exp<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
66 exp(self, backend)
67 }
68
69 fn log<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
70 log(self, backend)
71 }
72
73 fn sin<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
74 sin(self, backend)
75 }
76
77 fn cos<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
78 cos(self, backend)
79 }
80
81 fn tanh<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
82 tanh(self, backend)
83 }
84
85 fn sqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
86 sqrt(self, backend)
87 }
88
89 fn rsqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
90 rsqrt(self, backend)
91 }
92
93 fn expm1<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
94 expm1(self, backend)
95 }
96
97 fn log1p<B: TensorBackend>(&self, backend: &mut B) -> Result<Tensor> {
98 log1p(self, backend)
99 }
100
101 fn compare<B: TensorBackend>(
102 &self,
103 rhs: &Tensor,
104 dir: CompareDir,
105 backend: &mut B,
106 ) -> Result<Tensor> {
107 compare(self, rhs, dir, backend)
108 }
109
110 fn where_select<B: TensorBackend>(
111 &self,
112 on_true: &Tensor,
113 on_false: &Tensor,
114 backend: &mut B,
115 ) -> Result<Tensor> {
116 where_select(self, on_true, on_false, backend)
117 }
118
119 fn clamp<B: TensorBackend>(
120 &self,
121 lower: &Tensor,
122 upper: &Tensor,
123 backend: &mut B,
124 ) -> Result<Tensor> {
125 clamp(self, lower, upper, backend)
126 }
127
128 fn matmul<B: TensorBackend>(&self, rhs: &Tensor, backend: &mut B) -> Result<Tensor> {
129 matmul(self, rhs, backend)
130 }
131
132 fn reshape<B: TensorBackend>(&self, shape: &[usize], backend: &mut B) -> Result<Tensor> {
133 reshape(self, shape, backend)
134 }
135
136 fn transpose<B: TensorBackend>(&self, perm: &[usize], backend: &mut B) -> Result<Tensor> {
137 transpose(self, perm, backend)
138 }
139
140 fn reduce_sum<B: TensorBackend>(&self, axes: &[usize], backend: &mut B) -> Result<Tensor> {
141 reduce_sum(self, axes, backend)
142 }
143}
144
145fn convert(input: &Tensor, to: DType, backend: &mut impl TensorBackend) -> Result<Tensor> {
166 backend.with_backend_session(|exec| exec.convert(input, to))
167}
168
169fn cast(input: &Tensor, to: DType, backend: &mut impl TensorBackend) -> Result<Tensor> {
191 backend.with_backend_session(|exec| exec.cast(input, to))
192}
193
194fn add(lhs: &Tensor, rhs: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
207 let (lhs, rhs) = broadcast_binary(lhs, rhs, backend)?;
208 backend.with_backend_session(|exec| exec.add(&lhs, &rhs))
209}
210
/// Defines a private free function `$name(input, backend)` that forwards one
/// tensor to the backend session method `$method`.
///
/// `$summary` becomes the first doc paragraph. A usage example that calls the
/// `TensorOpsExt` method of the same name is appended in a separate fenced
/// block. The fence is marked `ignore` because the snippet refers to
/// out-of-scope `x`/`backend` and must not be compiled as a doctest.
macro_rules! unary_fn {
    ($name:ident, $method:ident, $summary:literal) => {
        #[doc = $summary]
        #[doc = ""]
        #[doc = "```ignore"]
        #[doc = concat!("let y = x.", stringify!($name), "(&mut backend).unwrap();")]
        #[doc = "```"]
        fn $name(input: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
            backend.with_backend_session(|exec| exec.$method(input))
        }
    };
}
229
/// Defines a private free function `$name(lhs, rhs, backend)` that broadcasts
/// both operands to their common shape and then applies the backend session
/// method `$method`.
///
/// `$summary` becomes the first doc paragraph. A usage example that calls the
/// `TensorOpsExt` method of the same name is appended in a separate fenced
/// block. The fence is marked `ignore` because the snippet refers to
/// out-of-scope `x`/`y`/`backend` and must not be compiled as a doctest.
macro_rules! binary_fn {
    ($name:ident, $method:ident, $summary:literal) => {
        #[doc = $summary]
        #[doc = ""]
        #[doc = "```ignore"]
        #[doc = concat!("let z = x.", stringify!($name), "(&y, &mut backend).unwrap();")]
        #[doc = "```"]
        fn $name(lhs: &Tensor, rhs: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
            // Operands must share a shape before the backend op runs.
            let (lhs, rhs) = broadcast_binary(lhs, rhs, backend)?;
            backend.with_backend_session(|exec| exec.$method(&lhs, &rhs))
        }
    };
}
250
251binary_fn!(
252 mul,
253 mul,
254 "Elementwise multiplication with NumPy-style broadcasting."
255);
256binary_fn!(
257 div,
258 div,
259 "Elementwise division with NumPy-style broadcasting."
260);
261binary_fn!(pow, pow, "Elementwise power with NumPy-style broadcasting.");
262binary_fn!(
263 maximum,
264 maximum,
265 "Elementwise maximum with NumPy-style broadcasting."
266);
267binary_fn!(
268 minimum,
269 minimum,
270 "Elementwise minimum with NumPy-style broadcasting."
271);
272
273unary_fn!(neg, neg, "Elementwise negation.");
274unary_fn!(abs, abs, "Elementwise absolute value.");
275unary_fn!(sign, sign, "Elementwise sign.");
276unary_fn!(conj, conj, "Elementwise complex conjugate.");
277unary_fn!(exp, exp, "Elementwise exponential.");
278unary_fn!(log, log, "Elementwise natural logarithm.");
279unary_fn!(sin, sin, "Elementwise sine.");
280unary_fn!(cos, cos, "Elementwise cosine.");
281unary_fn!(tanh, tanh, "Elementwise hyperbolic tangent.");
282unary_fn!(sqrt, sqrt, "Elementwise square root.");
283unary_fn!(rsqrt, rsqrt, "Elementwise reciprocal square root.");
284unary_fn!(expm1, expm1, "Elementwise `exp(x) - 1`.");
285unary_fn!(log1p, log1p, "Elementwise `log(1 + x)`.");
286
287fn sub(lhs: &Tensor, rhs: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
300 let (lhs, rhs) = broadcast_binary(lhs, rhs, backend)?;
301 let neg_rhs = backend.with_backend_session(|exec| exec.neg(&rhs))?;
302 backend.with_backend_session(|exec| exec.add(&lhs, &neg_rhs))
303}
304
305fn compare(
321 lhs: &Tensor,
322 rhs: &Tensor,
323 dir: CompareDir,
324 backend: &mut impl TensorBackend,
325) -> Result<Tensor> {
326 let (lhs, rhs) = broadcast_binary(lhs, rhs, backend)?;
327 backend.with_backend_session(|exec| exec.compare(&lhs, &rhs, &dir))
328}
329
330fn where_select(
346 condition: &Tensor,
347 on_true: &Tensor,
348 on_false: &Tensor,
349 backend: &mut impl TensorBackend,
350) -> Result<Tensor> {
351 let (condition, on_true, on_false) = broadcast_ternary(condition, on_true, on_false, backend)?;
352 backend.with_backend_session(|exec| exec.select(&condition, &on_true, &on_false))
353}
354
355fn clamp(
369 input: &Tensor,
370 lower: &Tensor,
371 upper: &Tensor,
372 backend: &mut impl TensorBackend,
373) -> Result<Tensor> {
374 let (input, lower, upper) = broadcast_ternary(input, lower, upper, backend)?;
375 backend.with_backend_session(|exec| exec.clamp(&input, &lower, &upper))
376}
377
378fn matmul(a: &Tensor, b: &Tensor, backend: &mut impl TensorBackend) -> Result<Tensor> {
393 let config = DotGeneralConfig {
394 lhs_contracting_dims: vec![a.shape().len() - 1],
395 rhs_contracting_dims: vec![0],
396 lhs_batch_dims: vec![],
397 rhs_batch_dims: vec![],
398 };
399 backend.with_backend_session(|exec| exec.dot_general(a, b, &config))
400}
401
402fn reshape(input: &Tensor, shape: &[usize], backend: &mut impl TensorBackend) -> Result<Tensor> {
415 backend.with_backend_session(|exec| exec.reshape(input, shape))
416}
417
418fn transpose(input: &Tensor, perm: &[usize], backend: &mut impl TensorBackend) -> Result<Tensor> {
431 backend.with_backend_session(|exec| exec.transpose(input, perm))
432}
433
434fn reduce_sum(input: &Tensor, axes: &[usize], backend: &mut impl TensorBackend) -> Result<Tensor> {
447 backend.with_backend_session(|exec| exec.reduce_sum(input, axes))
448}
449
450fn broadcast_binary(
451 lhs: &Tensor,
452 rhs: &Tensor,
453 backend: &mut impl TensorBackend,
454) -> Result<(Tensor, Tensor)> {
455 let shape = broadcast_shape(lhs.shape(), rhs.shape()).map_err(broadcast_error)?;
456 Ok((
457 broadcast_to(lhs, &shape, backend)?,
458 broadcast_to(rhs, &shape, backend)?,
459 ))
460}
461
462fn broadcast_ternary(
463 first: &Tensor,
464 second: &Tensor,
465 third: &Tensor,
466 backend: &mut impl TensorBackend,
467) -> Result<(Tensor, Tensor, Tensor)> {
468 let shape = broadcast_shapes([first.shape(), second.shape(), third.shape()])
469 .map_err(broadcast_error)?;
470 Ok((
471 broadcast_to(first, &shape, backend)?,
472 broadcast_to(second, &shape, backend)?,
473 broadcast_to(third, &shape, backend)?,
474 ))
475}
476
477fn broadcast_to(
478 input: &Tensor,
479 target_shape: &[usize],
480 backend: &mut impl TensorBackend,
481) -> Result<Tensor> {
482 let input_shape = input.shape();
483 if input_shape == target_shape {
484 return Ok(input.clone());
485 }
486
487 let plan = broadcast_input_plan(input_shape, target_shape).map_err(broadcast_error)?;
488 let source = if plan.source_shape == input_shape {
489 input.clone()
490 } else {
491 backend.with_backend_session(|exec| exec.reshape(input, &plan.source_shape))?
492 };
493 backend.with_backend_session(|exec| exec.broadcast_in_dim(&source, target_shape, &plan.dims))
494}
495
496fn broadcast_error(err: impl std::fmt::Display) -> Error {
497 Error::backend_failure("broadcast", err.to_string())
498}