1use tenferro_ops::broadcast::{broadcast_input_plan, broadcast_shape, broadcast_shapes};
7use tenferro_tensor::{
8 CompareDir, DotGeneralConfig, Error, Result, Tensor, TensorBackend, TensorRead, TensorScalar,
9};
10
11use crate::{TypedTensorMaskOpsExt, TypedTensorOpsExt};
12use tenferro_tensor::TypedTensor;
13
14impl<T: TensorScalar> TypedTensorOpsExt<T> for TypedTensor<T> {
15 fn add<B: TensorBackend>(
16 &self,
17 rhs: &TypedTensor<T>,
18 backend: &mut B,
19 ) -> Result<TypedTensor<T>> {
20 add(self, rhs, backend)
21 }
22
23 fn sub<B: TensorBackend>(
24 &self,
25 rhs: &TypedTensor<T>,
26 backend: &mut B,
27 ) -> Result<TypedTensor<T>> {
28 sub(self, rhs, backend)
29 }
30
31 fn mul<B: TensorBackend>(
32 &self,
33 rhs: &TypedTensor<T>,
34 backend: &mut B,
35 ) -> Result<TypedTensor<T>> {
36 mul(self, rhs, backend)
37 }
38
39 fn div<B: TensorBackend>(
40 &self,
41 rhs: &TypedTensor<T>,
42 backend: &mut B,
43 ) -> Result<TypedTensor<T>> {
44 div(self, rhs, backend)
45 }
46
47 fn pow<B: TensorBackend>(
48 &self,
49 rhs: &TypedTensor<T>,
50 backend: &mut B,
51 ) -> Result<TypedTensor<T>> {
52 pow(self, rhs, backend)
53 }
54
55 fn maximum<B: TensorBackend>(
56 &self,
57 rhs: &TypedTensor<T>,
58 backend: &mut B,
59 ) -> Result<TypedTensor<T>> {
60 maximum(self, rhs, backend)
61 }
62
63 fn minimum<B: TensorBackend>(
64 &self,
65 rhs: &TypedTensor<T>,
66 backend: &mut B,
67 ) -> Result<TypedTensor<T>> {
68 minimum(self, rhs, backend)
69 }
70
71 fn neg<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
72 neg(self, backend)
73 }
74
75 fn abs<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
76 abs(self, backend)
77 }
78
79 fn sign<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
80 sign(self, backend)
81 }
82
83 fn conj<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
84 conj(self, backend)
85 }
86
87 fn exp<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
88 exp(self, backend)
89 }
90
91 fn log<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
92 log(self, backend)
93 }
94
95 fn sin<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
96 sin(self, backend)
97 }
98
99 fn cos<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
100 cos(self, backend)
101 }
102
103 fn tanh<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
104 tanh(self, backend)
105 }
106
107 fn sqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
108 sqrt(self, backend)
109 }
110
111 fn rsqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
112 rsqrt(self, backend)
113 }
114
115 fn expm1<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
116 expm1(self, backend)
117 }
118
119 fn log1p<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
120 log1p(self, backend)
121 }
122
123 fn compare<B: TensorBackend>(
124 &self,
125 rhs: &TypedTensor<T>,
126 dir: CompareDir,
127 backend: &mut B,
128 ) -> Result<TypedTensor<bool>> {
129 compare(self, rhs, dir, backend)
130 }
131
132 fn clamp<B: TensorBackend>(
133 &self,
134 lower: &TypedTensor<T>,
135 upper: &TypedTensor<T>,
136 backend: &mut B,
137 ) -> Result<TypedTensor<T>> {
138 clamp(self, lower, upper, backend)
139 }
140
141 fn matmul<B: TensorBackend>(
142 &self,
143 rhs: &TypedTensor<T>,
144 backend: &mut B,
145 ) -> Result<TypedTensor<T>> {
146 matmul(self, rhs, backend)
147 }
148
149 fn reduce_sum<B: TensorBackend>(
150 &self,
151 axes: &[usize],
152 backend: &mut B,
153 ) -> Result<TypedTensor<T>> {
154 reduce_sum(self, axes, backend)
155 }
156
157 fn reshape<B: TensorBackend>(
158 &self,
159 shape: &[usize],
160 backend: &mut B,
161 ) -> Result<TypedTensor<T>> {
162 reshape(self, shape, backend)
163 }
164
165 fn transpose<B: TensorBackend>(
166 &self,
167 perm: &[usize],
168 backend: &mut B,
169 ) -> Result<TypedTensor<T>> {
170 transpose(self, perm, backend)
171 }
172
173 fn broadcast_in_dim<B: TensorBackend>(
174 &self,
175 shape: &[usize],
176 dims: &[usize],
177 backend: &mut B,
178 ) -> Result<TypedTensor<T>> {
179 broadcast_in_dim(self, shape, dims, backend)
180 }
181}
182
183impl TypedTensorMaskOpsExt for TypedTensor<bool> {
184 fn where_select<T: TensorScalar, B: TensorBackend>(
185 &self,
186 on_true: &TypedTensor<T>,
187 on_false: &TypedTensor<T>,
188 backend: &mut B,
189 ) -> Result<TypedTensor<T>> {
190 where_select(self, on_true, on_false, backend)
191 }
192}
193
194fn add<T: TensorScalar>(
207 lhs: &TypedTensor<T>,
208 rhs: &TypedTensor<T>,
209 backend: &mut impl TensorBackend,
210) -> Result<TypedTensor<T>> {
211 let (lhs, rhs) = broadcast_binary_read(lhs, rhs, backend)?;
212 let out =
213 backend.with_backend_session(|exec| exec.add_read(lhs.tensor_read(), rhs.tensor_read()))?;
214 into_typed_result("add", out)
215}
216
// Generates a private unary elementwise function.
//
// Arguments:
//   * `$op`   - name of the generated function; also the operation tag passed
//               to `into_typed_result`.
//   * `$call` - backend-session method invoked with the input's tensor read.
//   * `$doc`  - one-line summary attached as the first doc attribute.
macro_rules! unary_fn {
    ($op:ident, $call:ident, $doc:literal) => {
        #[doc = $doc]
        #[doc = concat!("let y = x.", stringify!($op), "(&mut backend).unwrap();")]
        fn $op<T: TensorScalar>(
            input: &TypedTensor<T>,
            backend: &mut impl TensorBackend,
        ) -> Result<TypedTensor<T>> {
            // No broadcasting is involved: the input is read as-is.
            let raw =
                backend.with_backend_session(|exec| exec.$call(T::tensor_read(input)))?;
            into_typed_result(stringify!($op), raw)
        }
    };
}
239
// Generates a private binary elementwise function.
//
// Arguments:
//   * `$op`   - name of the generated function; also the operation tag passed
//               to `into_typed_result`.
//   * `$call` - backend-session method invoked with both broadcast operands.
//   * `$doc`  - one-line summary attached as the first doc attribute.
macro_rules! binary_fn {
    ($op:ident, $call:ident, $doc:literal) => {
        #[doc = $doc]
        #[doc = concat!("let z = x.", stringify!($op), "(&y, &mut backend).unwrap();")]
        fn $op<T: TensorScalar>(
            lhs: &TypedTensor<T>,
            rhs: &TypedTensor<T>,
            backend: &mut impl TensorBackend,
        ) -> Result<TypedTensor<T>> {
            // Bring both operands to their common broadcast shape first.
            let (left, right) = broadcast_binary_read(lhs, rhs, backend)?;
            let raw = backend
                .with_backend_session(|exec| exec.$call(left.tensor_read(), right.tensor_read()))?;
            into_typed_result(stringify!($op), raw)
        }
    };
}
266
267binary_fn!(
268 mul,
269 mul_read,
270 "Elementwise multiplication with NumPy-style broadcasting."
271);
272binary_fn!(
273 div,
274 div_read,
275 "Elementwise division with NumPy-style broadcasting."
276);
277binary_fn!(
278 pow,
279 pow_read,
280 "Elementwise power with NumPy-style broadcasting."
281);
282binary_fn!(
283 maximum,
284 maximum_read,
285 "Elementwise maximum with NumPy-style broadcasting."
286);
287binary_fn!(
288 minimum,
289 minimum_read,
290 "Elementwise minimum with NumPy-style broadcasting."
291);
292
293unary_fn!(neg, neg_read, "Elementwise negation.");
294unary_fn!(abs, abs_read, "Elementwise absolute value.");
295unary_fn!(sign, sign_read, "Elementwise sign.");
296unary_fn!(conj, conj_read, "Elementwise complex conjugate.");
297unary_fn!(exp, exp_read, "Elementwise exponential.");
298unary_fn!(log, log_read, "Elementwise natural logarithm.");
299unary_fn!(sin, sin_read, "Elementwise sine.");
300unary_fn!(cos, cos_read, "Elementwise cosine.");
301unary_fn!(tanh, tanh_read, "Elementwise hyperbolic tangent.");
302unary_fn!(sqrt, sqrt_read, "Elementwise square root.");
303unary_fn!(rsqrt, rsqrt_read, "Elementwise reciprocal square root.");
304unary_fn!(expm1, expm1_read, "Elementwise `exp(x) - 1`.");
305unary_fn!(log1p, log1p_read, "Elementwise `log(1 + x)`.");
306
307fn sub<T: TensorScalar>(
320 lhs: &TypedTensor<T>,
321 rhs: &TypedTensor<T>,
322 backend: &mut impl TensorBackend,
323) -> Result<TypedTensor<T>> {
324 let (lhs, rhs) = broadcast_binary_read(lhs, rhs, backend)?;
325 let neg_rhs = backend.with_backend_session(|exec| exec.neg_read(rhs.tensor_read()))?;
326 let out = backend.with_backend_session(|exec| {
327 exec.add_read(lhs.tensor_read(), TensorRead::from_tensor(&neg_rhs))
328 })?;
329 into_typed_result("sub", out)
330}
331
332fn compare<T: TensorScalar>(
348 lhs: &TypedTensor<T>,
349 rhs: &TypedTensor<T>,
350 dir: CompareDir,
351 backend: &mut impl TensorBackend,
352) -> Result<TypedTensor<bool>> {
353 let (lhs, rhs) = broadcast_binary_read(lhs, rhs, backend)?;
354 let out = backend.with_backend_session(|exec| {
355 exec.compare_read(lhs.tensor_read(), rhs.tensor_read(), &dir)
356 })?;
357 into_typed_result("compare", out)
358}
359
360fn where_select<T: TensorScalar>(
376 condition: &TypedTensor<bool>,
377 on_true: &TypedTensor<T>,
378 on_false: &TypedTensor<T>,
379 backend: &mut impl TensorBackend,
380) -> Result<TypedTensor<T>> {
381 let (condition, on_true, on_false) =
382 broadcast_ternary_read(condition, on_true, on_false, backend)?;
383 let out = backend.with_backend_session(|exec| {
384 exec.select_read(
385 condition.tensor_read(),
386 on_true.tensor_read(),
387 on_false.tensor_read(),
388 )
389 })?;
390 into_typed_result("where_select", out)
391}
392
393fn clamp<T: TensorScalar>(
407 input: &TypedTensor<T>,
408 lower: &TypedTensor<T>,
409 upper: &TypedTensor<T>,
410 backend: &mut impl TensorBackend,
411) -> Result<TypedTensor<T>> {
412 let (input, lower, upper) = broadcast_ternary_read(input, lower, upper, backend)?;
413 let out = backend.with_backend_session(|exec| {
414 exec.clamp_read(
415 input.tensor_read(),
416 lower.tensor_read(),
417 upper.tensor_read(),
418 )
419 })?;
420 into_typed_result("clamp", out)
421}
422
423fn matmul<T: TensorScalar>(
438 a: &TypedTensor<T>,
439 b: &TypedTensor<T>,
440 backend: &mut impl TensorBackend,
441) -> Result<TypedTensor<T>> {
442 let config = DotGeneralConfig {
443 lhs_contracting_dims: vec![a.shape().len() - 1],
444 rhs_contracting_dims: vec![0],
445 lhs_batch_dims: vec![],
446 rhs_batch_dims: vec![],
447 };
448 let out = backend.with_backend_session(|exec| {
449 exec.dot_general_read(T::tensor_read(a), T::tensor_read(b), &config)
450 })?;
451 into_typed_result("matmul", out)
452}
453
454fn reduce_sum<T: TensorScalar>(
471 input: &TypedTensor<T>,
472 axes: &[usize],
473 backend: &mut impl TensorBackend,
474) -> Result<TypedTensor<T>> {
475 let out =
476 backend.with_backend_session(|exec| exec.reduce_sum_read(T::tensor_read(input), axes))?;
477 into_typed_result("reduce_sum", out)
478}
479
480fn reshape<T: TensorScalar>(
493 input: &TypedTensor<T>,
494 shape: &[usize],
495 backend: &mut impl TensorBackend,
496) -> Result<TypedTensor<T>> {
497 let out =
498 backend.with_backend_session(|exec| exec.reshape_read(T::tensor_read(input), shape))?;
499 into_typed_result("reshape", out)
500}
501
502fn transpose<T: TensorScalar>(
515 input: &TypedTensor<T>,
516 perm: &[usize],
517 backend: &mut impl TensorBackend,
518) -> Result<TypedTensor<T>> {
519 let out =
520 backend.with_backend_session(|exec| exec.transpose_read(T::tensor_read(input), perm))?;
521 into_typed_result("transpose", out)
522}
523
524fn broadcast_in_dim<T: TensorScalar>(
540 input: &TypedTensor<T>,
541 shape: &[usize],
542 dims: &[usize],
543 backend: &mut impl TensorBackend,
544) -> Result<TypedTensor<T>> {
545 let out = backend.with_backend_session(|exec| {
546 exec.broadcast_in_dim_read(T::tensor_read(input), shape, dims)
547 })?;
548 into_typed_result("broadcast_in_dim", out)
549}
550
551enum ReadInput<'a> {
552 Borrowed(TensorRead<'a>),
553 Owned(Tensor),
554}
555
556impl ReadInput<'_> {
557 fn tensor_read(&self) -> TensorRead<'_> {
558 match self {
559 Self::Borrowed(read) => read.clone(),
560 Self::Owned(tensor) => TensorRead::from_tensor(tensor),
561 }
562 }
563}
564
565fn broadcast_binary_read<'a, T: TensorScalar>(
566 lhs: &'a TypedTensor<T>,
567 rhs: &'a TypedTensor<T>,
568 backend: &mut impl TensorBackend,
569) -> Result<(ReadInput<'a>, ReadInput<'a>)> {
570 let shape = broadcast_shape(lhs.shape(), rhs.shape()).map_err(broadcast_error)?;
571 Ok((
572 broadcast_to_read(lhs, &shape, backend)?,
573 broadcast_to_read(rhs, &shape, backend)?,
574 ))
575}
576
577fn broadcast_ternary_read<'a, C: TensorScalar, T: TensorScalar>(
578 first: &'a TypedTensor<C>,
579 second: &'a TypedTensor<T>,
580 third: &'a TypedTensor<T>,
581 backend: &mut impl TensorBackend,
582) -> Result<(ReadInput<'a>, ReadInput<'a>, ReadInput<'a>)> {
583 let shape = broadcast_shapes([first.shape(), second.shape(), third.shape()])
584 .map_err(broadcast_error)?;
585 Ok((
586 broadcast_to_read(first, &shape, backend)?,
587 broadcast_to_read(second, &shape, backend)?,
588 broadcast_to_read(third, &shape, backend)?,
589 ))
590}
591
592fn broadcast_to_read<'a, T: TensorScalar>(
593 input: &'a TypedTensor<T>,
594 target_shape: &[usize],
595 backend: &mut impl TensorBackend,
596) -> Result<ReadInput<'a>> {
597 if input.shape() == target_shape {
598 return Ok(ReadInput::Borrowed(T::tensor_read(input)));
599 }
600
601 let plan = broadcast_input_plan(input.shape(), target_shape).map_err(broadcast_error)?;
602 let source = if plan.source_shape == input.shape() {
603 ReadInput::Borrowed(T::tensor_read(input))
604 } else {
605 let reshaped = backend.with_backend_session(|exec| {
606 exec.reshape_read(T::tensor_read(input), &plan.source_shape)
607 })?;
608 ReadInput::Owned(reshaped)
609 };
610 let out = backend.with_backend_session(|exec| {
611 exec.broadcast_in_dim_read(source.tensor_read(), target_shape, &plan.dims)
612 })?;
613 Ok(ReadInput::Owned(out))
614}
615
616fn broadcast_error(err: impl std::fmt::Display) -> Error {
617 Error::backend_failure("broadcast", err.to_string())
618}
619
620fn into_typed_result<T: TensorScalar>(op: &'static str, tensor: Tensor) -> Result<TypedTensor<T>> {
621 let actual = tensor.dtype();
622 T::into_typed(tensor).map_err(|_| Error::DTypeMismatch {
623 op,
624 lhs: T::dtype(),
625 rhs: actual,
626 })
627}