1use std::ops::{Add, Div, Mul, Neg};
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_traits::{One, Zero};
6use strided_kernel::{
7 batched_outer_product_into, broadcast_mul_into, map_into, mul_into, zip_map2_into,
8 zip_map3_into,
9};
10
11use crate::buffer_pool::{BufferPool, PoolScalar};
12use crate::ConjElem;
13use tenferro_tensor::{
14 col_major_strides, CompareDir, DType, Tensor, TensorOwnedView, TensorRank, TensorRead,
15 TensorValue, TensorView, TypedTensor, TypedTensorView,
16};
17
18use super::{
19 tensor_from_array, typed_array_uninit_from_pool, typed_host_data, typed_view,
20 typed_view_from_view,
21};
22
/// Dispatches a three-operand operation over the floating-point `Tensor` variants.
///
/// The three expressions are matched as one tuple. When all three are the same
/// variant among `F32`, `F64`, `C32` and `C64`, the typed payloads are bound to
/// `$x`, `$y` and `$z`, `$body` is evaluated and unwrapped with `?`, and the
/// value is wrapped back into that same variant.
///
/// Every other combination (mixed dtypes, or any variant not listed above)
/// yields `backend_failure($op, "dtype mismatch")`.
///
/// Because the expansion uses `?`, it must be invoked inside a function whose
/// return type can absorb the error produced by `$body`.
macro_rules! dispatch_ternary_result_with_pool {
    ($op:literal, $a:expr, $b:expr, $c:expr, |$x:ident, $y:ident, $z:ident| $body:expr) => {
        match ($a, $b, $c) {
            (Tensor::F32($x), Tensor::F32($y), Tensor::F32($z)) => Ok(Tensor::F32($body?)),
            (Tensor::F64($x), Tensor::F64($y), Tensor::F64($z)) => Ok(Tensor::F64($body?)),
            (Tensor::C32($x), Tensor::C32($y), Tensor::C32($z)) => Ok(Tensor::C32($body?)),
            (Tensor::C64($x), Tensor::C64($y), Tensor::C64($z)) => Ok(Tensor::C64($body?)),
            // Anything else is reported with the fixed "dtype mismatch" message.
            _ => Err(crate::Error::backend_failure($op, "dtype mismatch")),
        }
    };
}
34
35fn dtype_pair_error(op: &'static str, lhs: DType, rhs: DType) -> crate::Error {
36 if lhs == rhs {
37 crate::Error::backend_failure(op, format!("unsupported dtype {lhs:?}"))
38 } else {
39 crate::Error::DTypeMismatch { op, lhs, rhs }
40 }
41}
42
43fn tensor_pair_error(op: &'static str, lhs: &Tensor, rhs: &Tensor) -> crate::Error {
44 dtype_pair_error(op, lhs.dtype(), rhs.dtype())
45}
46
47fn read_pair_error(op: &'static str, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Error {
48 dtype_pair_error(op, lhs.dtype(), rhs.dtype())
49}
50
/// Element-level operations shared by the scalar types that implement it in
/// this file (`f32`, `f64`, `Complex<f32>`, `Complex<f64>`).
pub(crate) trait Tier2Elem: Copy + Clone + One + Zero + Send + Sync {
    /// Absolute value. The complex implementations return the norm in the
    /// real part with a zero imaginary part.
    fn abs_elem(self) -> Self;
    /// Sign of the element; every implementation in this file maps zero to zero.
    fn sign_elem(self) -> Self;
    /// Larger of `self` and `other`. The implementations in this file return
    /// NaN when a NaN is involved, and complex values are ranked by squared norm.
    fn max_elem(self, other: Self) -> Self;
    /// Smaller of `self` and `other`, with the same NaN and complex conventions
    /// as [`Tier2Elem::max_elem`].
    fn min_elem(self, other: Self) -> Self;
}
57
/// Element-level comparison driven by a [`CompareDir`].
pub(crate) trait CompareElem: Copy + Send + Sync {
    /// Returns whether `self <dir> other` holds, where `dir` selects one of
    /// `Eq`, `Lt`, `Le`, `Gt`, `Ge`.
    fn compare_elem(self, other: Self, dir: &CompareDir) -> bool;
}
61
/// Implements [`Tier2Elem`] and [`CompareElem`] for a real floating-point type.
///
/// `max_elem` / `min_elem` propagate NaN: if either operand is NaN the result
/// is NaN. Comparisons use the native float operators, so any comparison
/// involving NaN is `false`.
macro_rules! impl_tier2_elem_real {
    ($ty:ty) => {
        impl Tier2Elem for $ty {
            fn abs_elem(self) -> Self {
                <$ty>::abs(self)
            }

            fn sign_elem(self) -> Self {
                // Zero (of either sign) maps to positive zero; everything
                // else, NaN included, goes through `signum`.
                if self != Self::zero() {
                    self.signum()
                } else {
                    Self::zero()
                }
            }

            fn max_elem(self, other: Self) -> Self {
                if self.is_nan() || other.is_nan() {
                    return <$ty>::NAN;
                }
                // Ties keep `self`.
                if self >= other {
                    self
                } else {
                    other
                }
            }

            fn min_elem(self, other: Self) -> Self {
                if self.is_nan() || other.is_nan() {
                    return <$ty>::NAN;
                }
                // Ties keep `self`.
                if self <= other {
                    self
                } else {
                    other
                }
            }
        }

        impl CompareElem for $ty {
            fn compare_elem(self, other: Self, dir: &CompareDir) -> bool {
                let (lhs, rhs) = (self, other);
                match dir {
                    CompareDir::Eq => lhs == rhs,
                    CompareDir::Lt => lhs < rhs,
                    CompareDir::Le => lhs <= rhs,
                    CompareDir::Gt => lhs > rhs,
                    CompareDir::Ge => lhs >= rhs,
                }
            }
        }
    };
}
111
/// Implements [`Tier2Elem`] and [`CompareElem`] for `Complex<$real>`.
///
/// Ordering-style operations (`max_elem`, `min_elem`, `Lt`/`Le`/`Gt`/`Ge`)
/// rank values by squared norm; `Eq` is exact component-wise equality.
macro_rules! impl_tier2_elem_complex {
    ($real:ty) => {
        impl Tier2Elem for Complex<$real> {
            fn abs_elem(self) -> Self {
                // The magnitude is carried in the real part.
                let magnitude = self.norm();
                Self::new(magnitude, <$real>::zero())
            }

            fn sign_elem(self) -> Self {
                if !self.is_zero() {
                    self / self.abs_elem()
                } else {
                    Self::zero()
                }
            }

            fn max_elem(self, other: Self) -> Self {
                let (own, theirs) = (self.norm_sqr(), other.norm_sqr());
                if own.is_nan() || theirs.is_nan() {
                    return Self::new(<$real>::NAN, <$real>::NAN);
                }
                // Ties keep `self`.
                if own >= theirs {
                    self
                } else {
                    other
                }
            }

            fn min_elem(self, other: Self) -> Self {
                let (own, theirs) = (self.norm_sqr(), other.norm_sqr());
                if own.is_nan() || theirs.is_nan() {
                    return Self::new(<$real>::NAN, <$real>::NAN);
                }
                // Ties keep `self`.
                if own <= theirs {
                    self
                } else {
                    other
                }
            }
        }

        impl CompareElem for Complex<$real> {
            fn compare_elem(self, other: Self, dir: &CompareDir) -> bool {
                let (own, theirs) = (self.norm_sqr(), other.norm_sqr());
                match dir {
                    // Equality looks at the components, not the norms.
                    CompareDir::Eq => self == other,
                    CompareDir::Lt => own < theirs,
                    CompareDir::Le => own <= theirs,
                    CompareDir::Gt => own > theirs,
                    CompareDir::Ge => own >= theirs,
                }
            }
        }
    };
}
165
// Instantiate `Tier2Elem` + `CompareElem` for the two real float types and
// for the complex types built on them (`Complex<f32>`, `Complex<f64>`).
impl_tier2_elem_real!(f32);
impl_tier2_elem_real!(f64);
impl_tier2_elem_complex!(f32);
impl_tier2_elem_complex!(f64);
170
/// Implements [`CompareElem`] for a totally ordered type by computing a single
/// `Ord::cmp` and testing the resulting ordering against the requested direction.
macro_rules! impl_compare_elem_ord {
    ($ty:ty) => {
        impl CompareElem for $ty {
            fn compare_elem(self, other: Self, dir: &CompareDir) -> bool {
                let ordering = Ord::cmp(&self, &other);
                match dir {
                    CompareDir::Eq => ordering.is_eq(),
                    CompareDir::Lt => ordering.is_lt(),
                    CompareDir::Le => ordering.is_le(),
                    CompareDir::Gt => ordering.is_gt(),
                    CompareDir::Ge => ordering.is_ge(),
                }
            }
        }
    };
}
186
// Integer and boolean elements only get `CompareElem`; `Tier2Elem` is
// implemented solely for the float and complex types in this file.
impl_compare_elem_ord!(i32);
impl_compare_elem_ord!(i64);
impl_compare_elem_ord!(bool);
190
191fn complex_scalar_tensor<T>(scalar: T) -> crate::Result<TypedTensor<Complex<T>>>
192where
193 T: Copy + Clone + Zero,
194{
195 TypedTensor::from_vec_col_major(vec![], vec![Complex::new(scalar, T::zero())])
196}
197
198fn complex_scalar_tensor_from_tensor<T>(
199 input: &TypedTensor<T>,
200) -> crate::Result<TypedTensor<Complex<T>>>
201where
202 T: Copy + Clone + Zero,
203{
204 complex_scalar_tensor(typed_host_data("add", input)?[0])
205}
206
207fn complex_scalar_tensor_from_view<T, R>(
208 input: &TypedTensorView<'_, T, R>,
209) -> crate::Result<TypedTensor<Complex<T>>>
210where
211 T: Copy + Clone + Zero + 'static,
212 R: TensorRank,
213{
214 complex_scalar_tensor(typed_view_from_view("add", input)?.get(&[]))
215}
216
217fn with_local_pool<T>(f: impl FnOnce(&mut BufferPool) -> T) -> T {
218 let mut buffers = BufferPool::new();
219 f(&mut buffers)
220}
221
222pub fn add(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
237 with_local_pool(|buffers| add_with_pool(buffers, lhs, rhs))
238}
239
240pub(crate) fn add_with_pool(
241 buffers: &mut BufferPool,
242 lhs: &Tensor,
243 rhs: &Tensor,
244) -> crate::Result<Tensor> {
245 match (lhs, rhs) {
246 (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_add_with_pool(buffers, a, b)?)),
247 (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_add_with_pool(buffers, a, b)?)),
248 (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(typed_add_with_pool(buffers, a, b)?)),
249 (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(typed_add_with_pool(buffers, a, b)?)),
250 (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_add_with_pool(buffers, a, b)?)),
251 (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_add_with_pool(buffers, a, b)?)),
252 (Tensor::F32(a), Tensor::C32(b)) if a.shape().is_empty() => {
253 let scalar = complex_scalar_tensor(typed_host_data("add", a)?[0])?;
254 Ok(Tensor::C32(typed_add_with_pool(buffers, &scalar, b)?))
255 }
256 (Tensor::C32(a), Tensor::F32(b)) if b.shape().is_empty() => {
257 let scalar = complex_scalar_tensor(typed_host_data("add", b)?[0])?;
258 Ok(Tensor::C32(typed_add_with_pool(buffers, a, &scalar)?))
259 }
260 (Tensor::F64(a), Tensor::C64(b)) if a.shape().is_empty() => {
261 let scalar = complex_scalar_tensor(typed_host_data("add", a)?[0])?;
262 Ok(Tensor::C64(typed_add_with_pool(buffers, &scalar, b)?))
263 }
264 (Tensor::C64(a), Tensor::F64(b)) if b.shape().is_empty() => {
265 let scalar = complex_scalar_tensor(typed_host_data("add", b)?[0])?;
266 Ok(Tensor::C64(typed_add_with_pool(buffers, a, &scalar)?))
267 }
268 _ => Err(tensor_pair_error("add", lhs, rhs)),
269 }
270}
271
/// Adds two `TensorRead` operands (owned tensors and/or borrowed views),
/// drawing the output buffer from `buffers`.
///
/// Dispatch order:
/// 1. both operands are owned tensors: forwarded to [`add_with_pool`];
/// 2. a rank-0 real operand paired with a complex operand of matching
///    precision, where at least one side is a view: the real scalar is
///    promoted to a complex scalar tensor first;
/// 3. both operands share a dtype among `F32`, `F64`, `I32`, `I64`, `C32`,
///    `C64`, where at least one side is a view.
///
/// # Errors
///
/// Returns `read_pair_error("add", ..)` when no arm matches, and propagates
/// failures from the typed kernel or the scalar promotion.
pub(crate) fn add_read_with_pool(
    buffers: &mut BufferPool,
    lhs: TensorRead<'_>,
    rhs: TensorRead<'_>,
) -> crate::Result<Tensor> {
    // Owned/owned pairs take the tensor-only path.
    if let (TensorRead::Tensor(lhs), TensorRead::Tensor(rhs)) = (&lhs, &rhs) {
        return add_with_pool(buffers, lhs, rhs);
    }

    // Same-dtype dispatch for one variant. Each matching arm returns from the
    // enclosing function; a non-matching pair falls through to the next
    // invocation.
    macro_rules! dispatch {
        ($variant:ident) => {
            match (&lhs, &rhs) {
                (
                    TensorRead::Tensor(Tensor::$variant(a)),
                    TensorRead::View(TensorView::$variant(b)),
                ) => {
                    // Owned operands are viewed so both sides share one kernel entry point.
                    let a = a.as_view();
                    return Ok(Tensor::$variant(typed_add_view_with_pool(buffers, &a, b)?));
                }
                (
                    TensorRead::View(TensorView::$variant(a)),
                    TensorRead::Tensor(Tensor::$variant(b)),
                ) => {
                    let b = b.as_view();
                    return Ok(Tensor::$variant(typed_add_view_with_pool(buffers, a, &b)?));
                }
                (
                    TensorRead::View(TensorView::$variant(a)),
                    TensorRead::View(TensorView::$variant(b)),
                ) => {
                    return Ok(Tensor::$variant(typed_add_view_with_pool(buffers, a, b)?));
                }
                _ => {}
            }
        };
    }

    // Real-scalar/complex dispatch for one precision. The six arms cover the
    // real operand on either side and every tensor/view combination except
    // tensor/tensor, which was handled above. Operand order is preserved when
    // calling the kernel.
    macro_rules! dispatch_real_complex_scalar {
        ($real_variant:ident, $complex_variant:ident) => {
            match (&lhs, &rhs) {
                (
                    TensorRead::Tensor(Tensor::$real_variant(real)),
                    TensorRead::View(TensorView::$complex_variant(complex)),
                ) if real.shape().is_empty() => {
                    let scalar = complex_scalar_tensor_from_tensor(real)?;
                    let scalar = scalar.as_view();
                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
                        buffers, &scalar, complex,
                    )?));
                }
                (
                    TensorRead::View(TensorView::$real_variant(real)),
                    TensorRead::Tensor(Tensor::$complex_variant(complex)),
                ) if real.shape().is_empty() => {
                    let scalar = complex_scalar_tensor_from_view(real)?;
                    let scalar = scalar.as_view();
                    let complex = complex.as_view();
                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
                        buffers, &scalar, &complex,
                    )?));
                }
                (
                    TensorRead::View(TensorView::$real_variant(real)),
                    TensorRead::View(TensorView::$complex_variant(complex)),
                ) if real.shape().is_empty() => {
                    let scalar = complex_scalar_tensor_from_view(real)?;
                    let scalar = scalar.as_view();
                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
                        buffers, &scalar, complex,
                    )?));
                }
                (
                    TensorRead::Tensor(Tensor::$complex_variant(complex)),
                    TensorRead::View(TensorView::$real_variant(real)),
                ) if real.shape().is_empty() => {
                    let complex = complex.as_view();
                    let scalar = complex_scalar_tensor_from_view(real)?;
                    let scalar = scalar.as_view();
                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
                        buffers, &complex, &scalar,
                    )?));
                }
                (
                    TensorRead::View(TensorView::$complex_variant(complex)),
                    TensorRead::Tensor(Tensor::$real_variant(real)),
                ) if real.shape().is_empty() => {
                    let scalar = complex_scalar_tensor_from_tensor(real)?;
                    let scalar = scalar.as_view();
                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
                        buffers, complex, &scalar,
                    )?));
                }
                (
                    TensorRead::View(TensorView::$complex_variant(complex)),
                    TensorRead::View(TensorView::$real_variant(real)),
                ) if real.shape().is_empty() => {
                    let scalar = complex_scalar_tensor_from_view(real)?;
                    let scalar = scalar.as_view();
                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
                        buffers, complex, &scalar,
                    )?));
                }
                _ => {}
            }
        };
    }

    // Mixed real/complex pairs are tried before the same-dtype pairs.
    dispatch_real_complex_scalar!(F32, C32);
    dispatch_real_complex_scalar!(F64, C64);

    dispatch!(F32);
    dispatch!(F64);
    dispatch!(I32);
    dispatch!(I64);
    dispatch!(C32);
    dispatch!(C64);

    // Nothing matched: report the offending dtype pair.
    Err(read_pair_error("add", lhs, rhs))
}
391
392pub fn mul(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
407 with_local_pool(|buffers| mul_with_pool(buffers, lhs, rhs))
408}
409
410fn binary_read_with_pool(
411 op: &'static str,
412 buffers: &mut BufferPool,
413 lhs: TensorRead<'_>,
414 rhs: TensorRead<'_>,
415 f: impl FnOnce(&mut BufferPool, &Tensor, &Tensor) -> crate::Result<Tensor>,
416) -> crate::Result<Tensor> {
417 if let (Some(lhs), Some(rhs)) = (lhs.as_tensor(), rhs.as_tensor()) {
418 return f(buffers, lhs, rhs);
419 }
420
421 Err(read_pair_error(op, lhs, rhs))
422}
423
424pub(crate) fn mul_with_pool(
425 buffers: &mut BufferPool,
426 lhs: &Tensor,
427 rhs: &Tensor,
428) -> crate::Result<Tensor> {
429 match (lhs, rhs) {
430 (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_mul_with_pool(buffers, a, b)?)),
431 (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_mul_with_pool(buffers, a, b)?)),
432 (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(typed_mul_with_pool(buffers, a, b)?)),
433 (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(typed_mul_with_pool(buffers, a, b)?)),
434 (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_mul_with_pool(buffers, a, b)?)),
435 (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_mul_with_pool(buffers, a, b)?)),
436 (Tensor::F32(a), Tensor::C32(b)) if a.shape().is_empty() => {
437 let scalar = complex_scalar_tensor(typed_host_data("mul", a)?[0])?;
438 Ok(Tensor::C32(typed_mul_with_pool(buffers, &scalar, b)?))
439 }
440 (Tensor::C32(a), Tensor::F32(b)) if b.shape().is_empty() => {
441 let scalar = complex_scalar_tensor(typed_host_data("mul", b)?[0])?;
442 Ok(Tensor::C32(typed_mul_with_pool(buffers, a, &scalar)?))
443 }
444 (Tensor::F64(a), Tensor::C64(b)) if a.shape().is_empty() => {
445 let scalar = complex_scalar_tensor(typed_host_data("mul", a)?[0])?;
446 Ok(Tensor::C64(typed_mul_with_pool(buffers, &scalar, b)?))
447 }
448 (Tensor::C64(a), Tensor::F64(b)) if b.shape().is_empty() => {
449 let scalar = complex_scalar_tensor(typed_host_data("mul", b)?[0])?;
450 Ok(Tensor::C64(typed_mul_with_pool(buffers, a, &scalar)?))
451 }
452 _ => Err(tensor_pair_error("mul", lhs, rhs)),
453 }
454}
455
456pub(crate) fn mul_read_with_pool(
457 buffers: &mut BufferPool,
458 lhs: TensorRead<'_>,
459 rhs: TensorRead<'_>,
460) -> crate::Result<Tensor> {
461 if let (TensorRead::Tensor(lhs), TensorRead::Tensor(rhs)) = (&lhs, &rhs) {
462 return mul_with_pool(buffers, lhs, rhs);
463 }
464
465 macro_rules! dispatch {
466 ($variant:ident) => {
467 match (&lhs, &rhs) {
468 (
469 TensorRead::Tensor(Tensor::$variant(a)),
470 TensorRead::View(TensorView::$variant(b)),
471 ) => {
472 let a = a.as_view();
473 return Ok(Tensor::$variant(typed_mul_view_with_pool(buffers, &a, b)?));
474 }
475 (
476 TensorRead::View(TensorView::$variant(a)),
477 TensorRead::Tensor(Tensor::$variant(b)),
478 ) => {
479 let b = b.as_view();
480 return Ok(Tensor::$variant(typed_mul_view_with_pool(buffers, a, &b)?));
481 }
482 (
483 TensorRead::View(TensorView::$variant(a)),
484 TensorRead::View(TensorView::$variant(b)),
485 ) => {
486 return Ok(Tensor::$variant(typed_mul_view_with_pool(buffers, a, b)?));
487 }
488 _ => {}
489 }
490 };
491 }
492
493 macro_rules! dispatch_real_complex_scalar {
494 ($real_variant:ident, $complex_variant:ident) => {
495 match (&lhs, &rhs) {
496 (
497 TensorRead::Tensor(Tensor::$real_variant(real)),
498 TensorRead::View(TensorView::$complex_variant(complex)),
499 ) if real.shape().is_empty() => {
500 let scalar = complex_scalar_tensor_from_tensor(real)?;
501 let scalar = scalar.as_view();
502 return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
503 buffers, &scalar, complex,
504 )?));
505 }
506 (
507 TensorRead::View(TensorView::$real_variant(real)),
508 TensorRead::Tensor(Tensor::$complex_variant(complex)),
509 ) if real.shape().is_empty() => {
510 let scalar = complex_scalar_tensor_from_view(real)?;
511 let scalar = scalar.as_view();
512 let complex = complex.as_view();
513 return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
514 buffers, &scalar, &complex,
515 )?));
516 }
517 (
518 TensorRead::View(TensorView::$real_variant(real)),
519 TensorRead::View(TensorView::$complex_variant(complex)),
520 ) if real.shape().is_empty() => {
521 let scalar = complex_scalar_tensor_from_view(real)?;
522 let scalar = scalar.as_view();
523 return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
524 buffers, &scalar, complex,
525 )?));
526 }
527 (
528 TensorRead::Tensor(Tensor::$complex_variant(complex)),
529 TensorRead::View(TensorView::$real_variant(real)),
530 ) if real.shape().is_empty() => {
531 let complex = complex.as_view();
532 let scalar = complex_scalar_tensor_from_view(real)?;
533 let scalar = scalar.as_view();
534 return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
535 buffers, &complex, &scalar,
536 )?));
537 }
538 (
539 TensorRead::View(TensorView::$complex_variant(complex)),
540 TensorRead::Tensor(Tensor::$real_variant(real)),
541 ) if real.shape().is_empty() => {
542 let scalar = complex_scalar_tensor_from_tensor(real)?;
543 let scalar = scalar.as_view();
544 return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
545 buffers, complex, &scalar,
546 )?));
547 }
548 (
549 TensorRead::View(TensorView::$complex_variant(complex)),
550 TensorRead::View(TensorView::$real_variant(real)),
551 ) if real.shape().is_empty() => {
552 let scalar = complex_scalar_tensor_from_view(real)?;
553 let scalar = scalar.as_view();
554 return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
555 buffers, complex, &scalar,
556 )?));
557 }
558 _ => {}
559 }
560 };
561 }
562
563 dispatch_real_complex_scalar!(F32, C32);
564 dispatch_real_complex_scalar!(F64, C64);
565
566 dispatch!(F32);
567 dispatch!(F64);
568 dispatch!(I32);
569 dispatch!(I64);
570 dispatch!(C32);
571 dispatch!(C64);
572
573 binary_read_with_pool("mul", buffers, lhs, rhs, mul_with_pool)
574}
575
/// A borrowed, dtype-tagged tensor view: one variant per element type, each
/// holding a `TypedTensorView` of that type.
///
/// Produced by `read_as_cpu_view`, which maps owned tensors and views onto
/// the same representation.
enum CpuReadView<'a> {
    /// 32-bit float elements.
    F32(TypedTensorView<'a, f32>),
    /// 64-bit float elements.
    F64(TypedTensorView<'a, f64>),
    /// 32-bit signed integer elements.
    I32(TypedTensorView<'a, i32>),
    /// 64-bit signed integer elements.
    I64(TypedTensorView<'a, i64>),
    /// Boolean elements.
    Bool(TypedTensorView<'a, bool>),
    /// Complex elements with `f32` components.
    C32(TypedTensorView<'a, Complex<f32>>),
    /// Complex elements with `f64` components.
    C64(TypedTensorView<'a, Complex<f64>>),
}
585
/// Normalises a `TensorRead` into a [`CpuReadView`] of the same dtype.
///
/// Owned tensors are converted with `as_view()`; operands that are already
/// views are passed through unchanged.
fn read_as_cpu_view(input: TensorRead<'_>) -> CpuReadView<'_> {
    match input {
        // Owned tensors: borrow a view over the tensor.
        TensorRead::Tensor(Tensor::F32(tensor)) => CpuReadView::F32(tensor.as_view()),
        TensorRead::Tensor(Tensor::F64(tensor)) => CpuReadView::F64(tensor.as_view()),
        TensorRead::Tensor(Tensor::I32(tensor)) => CpuReadView::I32(tensor.as_view()),
        TensorRead::Tensor(Tensor::I64(tensor)) => CpuReadView::I64(tensor.as_view()),
        TensorRead::Tensor(Tensor::Bool(tensor)) => CpuReadView::Bool(tensor.as_view()),
        TensorRead::Tensor(Tensor::C32(tensor)) => CpuReadView::C32(tensor.as_view()),
        TensorRead::Tensor(Tensor::C64(tensor)) => CpuReadView::C64(tensor.as_view()),
        // Views: re-tag without conversion.
        TensorRead::View(TensorView::F32(view)) => CpuReadView::F32(view),
        TensorRead::View(TensorView::F64(view)) => CpuReadView::F64(view),
        TensorRead::View(TensorView::I32(view)) => CpuReadView::I32(view),
        TensorRead::View(TensorView::I64(view)) => CpuReadView::I64(view),
        TensorRead::View(TensorView::Bool(view)) => CpuReadView::Bool(view),
        TensorRead::View(TensorView::C32(view)) => CpuReadView::C32(view),
        TensorRead::View(TensorView::C64(view)) => CpuReadView::C64(view),
    }
}
604
605fn typed_binary_view_with_pool<T, L, R>(
606 op: &'static str,
607 buffers: &mut BufferPool,
608 lhs: &TypedTensorView<'_, T, L>,
609 rhs: &TypedTensorView<'_, T, R>,
610 f: impl Fn(T, T) -> T + Copy + Sync,
611) -> crate::Result<TypedTensor<T>>
612where
613 T: Copy + PoolScalar + 'static,
614 L: TensorRank,
615 R: TensorRank,
616{
617 if lhs.shape() == rhs.shape() {
618 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
620 zip_map2_into(
621 &mut out.view_mut(),
622 &typed_view_from_view(op, lhs)?,
623 &typed_view_from_view(op, rhs)?,
624 f,
625 )
626 .map_err(|err| crate::Error::backend_failure(op, err))?;
627 Ok(tensor_from_array(out))
628 } else if lhs.shape().is_empty() {
629 let scalar = typed_view_from_view(op, lhs)?.get(&[]);
630 let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
632 map_into(&mut out.view_mut(), &typed_view_from_view(op, rhs)?, |x| {
633 f(scalar, x)
634 })
635 .map_err(|err| crate::Error::backend_failure(op, err))?;
636 Ok(tensor_from_array(out))
637 } else if rhs.shape().is_empty() {
638 let scalar = typed_view_from_view(op, rhs)?.get(&[]);
639 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
641 map_into(&mut out.view_mut(), &typed_view_from_view(op, lhs)?, |x| {
642 f(x, scalar)
643 })
644 .map_err(|err| crate::Error::backend_failure(op, err))?;
645 Ok(tensor_from_array(out))
646 } else {
647 Err(crate::Error::ShapeMismatch {
648 op,
649 lhs: lhs.shape().to_vec(),
650 rhs: rhs.shape().to_vec(),
651 })
652 }
653}
654
655fn typed_unary_view_with_pool<T, R>(
656 op: &'static str,
657 buffers: &mut BufferPool,
658 input: &TypedTensorView<'_, T, R>,
659 f: impl Fn(T) -> T + Copy + Sync,
660) -> crate::Result<TypedTensor<T>>
661where
662 T: Copy + PoolScalar + 'static,
663 R: TensorRank,
664{
665 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
667 map_into(&mut out.view_mut(), &typed_view_from_view(op, input)?, f)
668 .map_err(|err| crate::Error::backend_failure(op, err))?;
669 Ok(tensor_from_array(out))
670}
671
672fn typed_same_shape_binary_view_with_pool<T, O, L, R>(
673 op: &'static str,
674 buffers: &mut BufferPool,
675 lhs: &TypedTensorView<'_, T, L>,
676 rhs: &TypedTensorView<'_, T, R>,
677 f: impl Fn(T, T) -> O + Copy + Sync,
678) -> crate::Result<TypedTensor<O>>
679where
680 T: Copy + Send + Sync + 'static,
681 O: Copy + PoolScalar,
682 L: TensorRank,
683 R: TensorRank,
684{
685 if lhs.shape() != rhs.shape() {
686 return Err(crate::Error::ShapeMismatch {
687 op,
688 lhs: lhs.shape().to_vec(),
689 rhs: rhs.shape().to_vec(),
690 });
691 }
692 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
694 zip_map2_into(
695 &mut out.view_mut(),
696 &typed_view_from_view(op, lhs)?,
697 &typed_view_from_view(op, rhs)?,
698 f,
699 )
700 .map_err(|err| crate::Error::backend_failure(op, err))?;
701 Ok(tensor_from_array(out))
702}
703
704fn typed_select_view_with_pool<T, P, A, B>(
705 buffers: &mut BufferPool,
706 pred: &TypedTensorView<'_, bool, P>,
707 on_true: &TypedTensorView<'_, T, A>,
708 on_false: &TypedTensorView<'_, T, B>,
709) -> crate::Result<TypedTensor<T>>
710where
711 T: Copy + PoolScalar + 'static,
712 P: TensorRank,
713 A: TensorRank,
714 B: TensorRank,
715{
716 if pred.shape() != on_true.shape() {
717 return Err(crate::Error::ShapeMismatch {
718 op: "select",
719 lhs: pred.shape().to_vec(),
720 rhs: on_true.shape().to_vec(),
721 });
722 }
723 if pred.shape() != on_false.shape() {
724 return Err(crate::Error::ShapeMismatch {
725 op: "select",
726 lhs: pred.shape().to_vec(),
727 rhs: on_false.shape().to_vec(),
728 });
729 }
730 let mut out = unsafe { typed_array_uninit_from_pool(buffers, pred.shape()) };
732 zip_map3_into(
733 &mut out.view_mut(),
734 &typed_view_from_view("select", pred)?,
735 &typed_view_from_view("select", on_true)?,
736 &typed_view_from_view("select", on_false)?,
737 |p, t, f| if p { t } else { f },
738 )
739 .map_err(|err| crate::Error::backend_failure("select", err))?;
740 Ok(tensor_from_array(out))
741}
742
743fn typed_clamp_view_with_pool<T, I, L, U>(
744 buffers: &mut BufferPool,
745 input: &TypedTensorView<'_, T, I>,
746 lower: &TypedTensorView<'_, T, L>,
747 upper: &TypedTensorView<'_, T, U>,
748) -> crate::Result<TypedTensor<T>>
749where
750 T: Tier2Elem + PoolScalar + 'static,
751 I: TensorRank,
752 L: TensorRank,
753 U: TensorRank,
754{
755 if input.shape() != lower.shape() {
756 return Err(crate::Error::ShapeMismatch {
757 op: "clamp",
758 lhs: input.shape().to_vec(),
759 rhs: lower.shape().to_vec(),
760 });
761 }
762 if input.shape() != upper.shape() {
763 return Err(crate::Error::ShapeMismatch {
764 op: "clamp",
765 lhs: input.shape().to_vec(),
766 rhs: upper.shape().to_vec(),
767 });
768 }
769 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
771 zip_map3_into(
772 &mut out.view_mut(),
773 &typed_view_from_view("clamp", input)?,
774 &typed_view_from_view("clamp", lower)?,
775 &typed_view_from_view("clamp", upper)?,
776 |x, lo, hi| lo.max_elem(hi.min_elem(x)),
777 )
778 .map_err(|err| crate::Error::backend_failure("clamp", err))?;
779 Ok(tensor_from_array(out))
780}
781
/// Which operand's free axes come first in the output of a split outer product.
#[derive(Clone, Copy)]
enum SplitOuterProductLayout {
    /// Output axes are ordered: lhs-free axes, rhs-free axes, batch axes.
    LhsPrefix,
    /// Output axes are ordered: rhs-free axes, lhs-free axes, batch axes.
    RhsPrefix,
}
787
/// Description of a broadcast multiply that can be executed as a batched
/// outer product, as produced by `split_outer_product_plan`.
struct SplitOuterProductPlan {
    /// Product of the free-axis extents of the operand that forms the output
    /// prefix. Not read by the code visible in this file.
    #[allow(dead_code)]
    rows: usize,
    /// Product of the free-axis extents of the other operand. Not read by the
    /// code visible in this file.
    #[allow(dead_code)]
    cols: usize,
    /// Product of the batch-axis extents (taken from the lhs operand). Not
    /// read by the code visible in this file.
    #[allow(dead_code)]
    batches: usize,
    /// Whether the lhs or the rhs free axes lead the output.
    layout: SplitOuterProductLayout,
    /// Lhs source axes that appear only in the lhs operand.
    lhs_free_axes: Vec<usize>,
    /// Rhs source axes that appear only in the rhs operand.
    rhs_free_axes: Vec<usize>,
    /// Lhs source axes shared with the rhs operand.
    lhs_batch_axes: Vec<usize>,
    /// Rhs source axes shared with the lhs operand, in the same order as
    /// `lhs_batch_axes`.
    rhs_batch_axes: Vec<usize>,
}
801
/// Classification of every output axis of a broadcast multiply by which
/// operand(s) supply it, as produced by `classify_outer_product_axes`.
///
/// All lists are filled while walking the output axes in increasing order.
struct OuterProductAxisPartition {
    /// Output axes supplied only by the lhs operand.
    lhs_free_output_axes: Vec<usize>,
    /// Output axes supplied only by the rhs operand.
    rhs_free_output_axes: Vec<usize>,
    /// Output axes supplied by both operands.
    batch_output_axes: Vec<usize>,
    /// Lhs source axes corresponding to `lhs_free_output_axes`.
    lhs_free_axes: Vec<usize>,
    /// Rhs source axes corresponding to `rhs_free_output_axes`.
    rhs_free_axes: Vec<usize>,
    /// Lhs source axes corresponding to `batch_output_axes`.
    lhs_batch_axes: Vec<usize>,
    /// Rhs source axes corresponding to `batch_output_axes`.
    rhs_batch_axes: Vec<usize>,
}
811
/// Checks that a source tensor's shape is consistent with the output axes it
/// is mapped onto.
///
/// `dims[i]` names the output axis that source axis `i` maps to. The function
/// returns `true` only when `source_shape` and `dims` have the same length and
/// every source extent equals the extent of its target output axis. A target
/// axis outside `output_shape` makes the check fail.
fn shape_matches_dims(source_shape: &[usize], output_shape: &[usize], dims: &[usize]) -> bool {
    if source_shape.len() != dims.len() {
        return false;
    }
    for (extent, target_axis) in source_shape.iter().zip(dims) {
        match output_shape.get(*target_axis) {
            Some(output_extent) if output_extent == extent => {}
            _ => return false,
        }
    }
    true
}
819
/// Inverts a source-axis -> output-axis mapping.
///
/// `dims[i]` is the output axis that source axis `i` maps to. The result has
/// one slot per output axis holding the source axis mapped there, or `None`
/// when no source axis targets it.
///
/// Returns `None` when a target axis is `>= output_rank` or when two source
/// axes map to the same output axis.
fn axes_by_output(dims: &[usize], output_rank: usize) -> Option<Vec<Option<usize>>> {
    let mut slots: Vec<Option<usize>> = vec![None; output_rank];
    for (source_axis, &target_axis) in dims.iter().enumerate() {
        let slot = slots.get_mut(target_axis)?;
        if slot.is_some() {
            // Another source axis already claimed this output axis.
            return None;
        }
        *slot = Some(source_axis);
    }
    Some(slots)
}
830
831fn axes_shape_product<T>(
832 op: &'static str,
833 view: &TypedTensorView<'_, T>,
834 axes: &[usize],
835) -> crate::Result<usize>
836where
837 T: 'static,
838{
839 axes.iter().try_fold(1usize, |acc, &axis| {
840 acc.checked_mul(view.shape()[axis])
841 .ok_or_else(|| crate::Error::backend_failure(op, "shape size overflows usize"))
842 })
843}
844
845fn classify_outer_product_axes(
846 lhs_dims: &[usize],
847 rhs_dims: &[usize],
848 output_rank: usize,
849) -> Option<OuterProductAxisPartition> {
850 let lhs_axes_by_output = axes_by_output(lhs_dims, output_rank)?;
851 let rhs_axes_by_output = axes_by_output(rhs_dims, output_rank)?;
852
853 let mut lhs_free_output_axes = Vec::new();
854 let mut rhs_free_output_axes = Vec::new();
855 let mut batch_output_axes = Vec::new();
856 let mut lhs_free_axes = Vec::new();
857 let mut rhs_free_axes = Vec::new();
858 let mut lhs_batch_axes = Vec::new();
859 let mut rhs_batch_axes = Vec::new();
860
861 for output_axis in 0..output_rank {
862 match (
863 lhs_axes_by_output[output_axis],
864 rhs_axes_by_output[output_axis],
865 ) {
866 (Some(lhs_axis), Some(rhs_axis)) => {
867 batch_output_axes.push(output_axis);
868 lhs_batch_axes.push(lhs_axis);
869 rhs_batch_axes.push(rhs_axis);
870 }
871 (Some(lhs_axis), None) => {
872 lhs_free_output_axes.push(output_axis);
873 lhs_free_axes.push(lhs_axis);
874 }
875 (None, Some(rhs_axis)) => {
876 rhs_free_output_axes.push(output_axis);
877 rhs_free_axes.push(rhs_axis);
878 }
879 (None, None) => return None,
880 }
881 }
882
883 Some(OuterProductAxisPartition {
884 lhs_free_output_axes,
885 rhs_free_output_axes,
886 batch_output_axes,
887 lhs_free_axes,
888 rhs_free_axes,
889 lhs_batch_axes,
890 rhs_batch_axes,
891 })
892}
893
/// Returns `true` when concatenating `groups` in order yields exactly
/// `0, 1, .., output_rank - 1`.
///
/// In other words, the groups must tile the output axes contiguously and in
/// the given group order, with no axis missing, repeated or out of place.
fn output_axes_match_partition(output_rank: usize, groups: &[&[usize]]) -> bool {
    let mut expected = 0usize;
    for group in groups {
        for &axis in group.iter() {
            if axis != expected {
                return false;
            }
            expected += 1;
        }
    }
    // Every axis matched its position; the total count must match as well.
    expected == output_rank
}
900
901fn split_outer_product_plan<T>(
902 lhs: &TypedTensorView<'_, T>,
903 lhs_shape: &[usize],
904 lhs_dims: &[usize],
905 rhs: &TypedTensorView<'_, T>,
906 rhs_shape: &[usize],
907 rhs_dims: &[usize],
908) -> crate::Result<Option<SplitOuterProductPlan>>
909where
910 T: 'static,
911{
912 let output_rank = lhs_shape.len();
913 if lhs_shape != rhs_shape
914 || !shape_matches_dims(lhs.shape(), lhs_shape, lhs_dims)
915 || !shape_matches_dims(rhs.shape(), rhs_shape, rhs_dims)
916 || lhs.backend_buffer().is_some()
917 || rhs.backend_buffer().is_some()
918 || lhs.offset() < 0
919 || rhs.offset() < 0
920 || lhs.strides().iter().any(|&stride| stride < 0)
921 || rhs.strides().iter().any(|&stride| stride < 0)
922 {
923 return Ok(None);
924 }
925
926 let Some(partition) = classify_outer_product_axes(lhs_dims, rhs_dims, output_rank) else {
927 return Ok(None);
928 };
929
930 let lhs_free_size = axes_shape_product("broadcast_multiply", lhs, &partition.lhs_free_axes)?;
931 let rhs_free_size = axes_shape_product("broadcast_multiply", rhs, &partition.rhs_free_axes)?;
932 if lhs_free_size <= 1 || rhs_free_size <= 1 {
933 return Ok(None);
934 }
935 let batches = axes_shape_product("broadcast_multiply", lhs, &partition.lhs_batch_axes)?;
936
937 let lhs_prefix = output_axes_match_partition(
938 output_rank,
939 &[
940 &partition.lhs_free_output_axes,
941 &partition.rhs_free_output_axes,
942 &partition.batch_output_axes,
943 ],
944 );
945 if lhs_prefix {
946 return Ok(Some(SplitOuterProductPlan {
947 rows: lhs_free_size,
948 cols: rhs_free_size,
949 batches,
950 layout: SplitOuterProductLayout::LhsPrefix,
951 lhs_free_axes: partition.lhs_free_axes,
952 rhs_free_axes: partition.rhs_free_axes,
953 lhs_batch_axes: partition.lhs_batch_axes,
954 rhs_batch_axes: partition.rhs_batch_axes,
955 }));
956 }
957
958 let rhs_prefix = output_axes_match_partition(
959 output_rank,
960 &[
961 &partition.rhs_free_output_axes,
962 &partition.lhs_free_output_axes,
963 &partition.batch_output_axes,
964 ],
965 );
966 if rhs_prefix {
967 return Ok(Some(SplitOuterProductPlan {
968 rows: rhs_free_size,
969 cols: lhs_free_size,
970 batches,
971 layout: SplitOuterProductLayout::RhsPrefix,
972 lhs_free_axes: partition.lhs_free_axes,
973 rhs_free_axes: partition.rhs_free_axes,
974 lhs_batch_axes: partition.lhs_batch_axes,
975 rhs_batch_axes: partition.rhs_batch_axes,
976 }));
977 }
978
979 Ok(None)
980}
981
982fn try_outer_product_with_pool<T>(
983 buffers: &mut BufferPool,
984 lhs: &TypedTensorView<'_, T>,
985 lhs_shape: &[usize],
986 lhs_dims: &[usize],
987 rhs: &TypedTensorView<'_, T>,
988 rhs_shape: &[usize],
989 rhs_dims: &[usize],
990) -> crate::Result<Option<TypedTensor<T>>>
991where
992 T: Copy + Clone + Mul<Output = T> + PoolScalar + 'static,
993{
994 let Some(plan) = split_outer_product_plan(lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)?
995 else {
996 return Ok(None);
997 };
998
999 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1001 let lhs_view = typed_view_from_view("broadcast_multiply", lhs)?;
1002 let rhs_view = typed_view_from_view("broadcast_multiply", rhs)?;
1003 match plan.layout {
1004 SplitOuterProductLayout::LhsPrefix => {
1005 let lhs_perm: Vec<_> = plan
1006 .lhs_free_axes
1007 .iter()
1008 .chain(plan.lhs_batch_axes.iter())
1009 .copied()
1010 .collect();
1011 let rhs_perm: Vec<_> = plan
1012 .rhs_free_axes
1013 .iter()
1014 .chain(plan.rhs_batch_axes.iter())
1015 .copied()
1016 .collect();
1017 let lhs_outer = lhs_view
1018 .permute(&lhs_perm)
1019 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1020 let rhs_outer = rhs_view
1021 .permute(&rhs_perm)
1022 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1023 batched_outer_product_into(
1024 &mut out.view_mut(),
1025 &lhs_outer,
1026 &rhs_outer,
1027 plan.lhs_free_axes.len(),
1028 plan.rhs_free_axes.len(),
1029 )
1030 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1031 }
1032 SplitOuterProductLayout::RhsPrefix => {
1033 let lhs_perm: Vec<_> = plan
1034 .lhs_free_axes
1035 .iter()
1036 .chain(plan.lhs_batch_axes.iter())
1037 .copied()
1038 .collect();
1039 let rhs_perm: Vec<_> = plan
1040 .rhs_free_axes
1041 .iter()
1042 .chain(plan.rhs_batch_axes.iter())
1043 .copied()
1044 .collect();
1045 let lhs_outer = lhs_view
1046 .permute(&lhs_perm)
1047 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1048 let rhs_outer = rhs_view
1049 .permute(&rhs_perm)
1050 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1051 batched_outer_product_into(
1052 &mut out.view_mut(),
1053 &rhs_outer,
1054 &lhs_outer,
1055 plan.rhs_free_axes.len(),
1056 plan.lhs_free_axes.len(),
1057 )
1058 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1059 }
1060 }
1061 Ok(Some(tensor_from_array(out)))
1062}
1063
/// Result of the lazy outer-product path of `broadcast_multiply`.
///
/// Produced by `try_lazy_outer_product_with_pool`; the caller wraps `base`
/// in a `Tensor` variant and hands all three fields to
/// `lazy_outer_product_value` to build a strided view.
struct LazyOuterProduct<T> {
    /// Materialized product, filled by `batched_outer_product_into` in the
    /// permuted "base" axis order.
    base: TypedTensor<T>,
    /// Logical output shape (a copy of the requested `lhs_shape`).
    shape: Vec<usize>,
    /// One stride per logical output axis, as computed by
    /// `lazy_outer_product_strides`.
    strides: Vec<isize>,
}
1069
1070fn axes_by_physical_stride<T>(view: &TypedTensorView<'_, T>, axes: &[usize]) -> Vec<usize>
1071where
1072 T: 'static,
1073{
1074 let mut sorted = axes.to_vec();
1075 sorted.sort_by(|&lhs_axis, &rhs_axis| {
1076 view.strides()[lhs_axis]
1077 .cmp(&view.strides()[rhs_axis])
1078 .then_with(|| lhs_axis.cmp(&rhs_axis))
1079 });
1080 sorted
1081}
1082
1083fn append_axis_shapes<T>(shape: &mut Vec<usize>, view: &TypedTensorView<'_, T>, axes: &[usize])
1084where
1085 T: 'static,
1086{
1087 shape.extend(axes.iter().map(|&axis| view.shape()[axis]));
1088}
1089
1090fn set_lazy_stride(
1091 logical_strides: &mut [Option<isize>],
1092 output_axis: usize,
1093 stride: isize,
1094) -> crate::Result<()> {
1095 let rank = logical_strides.len();
1096 let slot = logical_strides
1097 .get_mut(output_axis)
1098 .ok_or(crate::Error::AxisOutOfBounds {
1099 op: "broadcast_multiply",
1100 axis: output_axis,
1101 rank,
1102 })?;
1103 if slot.replace(stride).is_some() {
1104 return Err(crate::Error::DuplicateAxis {
1105 op: "broadcast_multiply",
1106 axis: output_axis,
1107 role: "lazy output layout",
1108 });
1109 }
1110 Ok(())
1111}
1112
/// Inputs to `lazy_outer_product_strides`.
///
/// Base-buffer axes are consumed in this order: all `leading_axes`, then all
/// `trailing_axes`, then the zipped batch-axis pairs.
struct LazyOuterProductStrideSpec<'a> {
    /// Logical output shape; only its length (the output rank) is read.
    output_shape: &'a [usize],
    /// Shape of the base buffer, passed to `col_major_strides`.
    base_shape: &'a [usize],
    /// Operand axes that occupy the first base axes.
    leading_axes: &'a [usize],
    /// Table indexed by a leading operand axis, giving its output axis.
    leading_dims: &'a [usize],
    /// Operand axes that follow the leading ones in the base buffer.
    trailing_axes: &'a [usize],
    /// Table indexed by a trailing operand axis, giving its output axis.
    trailing_dims: &'a [usize],
    /// Batch axes of the lhs operand, paired positionally with `rhs_batch_axes`.
    lhs_batch_axes: &'a [usize],
    /// Batch axes of the rhs operand, paired positionally with `lhs_batch_axes`.
    rhs_batch_axes: &'a [usize],
    /// Table indexed by an lhs axis, giving its output axis (used for batch axes).
    lhs_dims: &'a [usize],
    /// Table indexed by an rhs axis, giving its output axis; each batch pair
    /// must agree with `lhs_dims`.
    rhs_dims: &'a [usize],
}
1125
1126fn lazy_outer_product_strides(spec: LazyOuterProductStrideSpec<'_>) -> crate::Result<Vec<isize>> {
1127 let base_strides = col_major_strides(spec.base_shape)?;
1128 let mut logical_strides = vec![None; spec.output_shape.len()];
1129 let mut base_axis = 0usize;
1130
1131 for &axis in spec.leading_axes {
1132 set_lazy_stride(
1133 &mut logical_strides,
1134 spec.leading_dims[axis],
1135 base_strides[base_axis],
1136 )?;
1137 base_axis += 1;
1138 }
1139 for &axis in spec.trailing_axes {
1140 set_lazy_stride(
1141 &mut logical_strides,
1142 spec.trailing_dims[axis],
1143 base_strides[base_axis],
1144 )?;
1145 base_axis += 1;
1146 }
1147 for (&lhs_axis, &rhs_axis) in spec.lhs_batch_axes.iter().zip(spec.rhs_batch_axes.iter()) {
1148 let output_axis = spec.lhs_dims[lhs_axis];
1149 if spec.rhs_dims[rhs_axis] != output_axis {
1150 return Err(crate::Error::backend_failure(
1151 "broadcast_multiply",
1152 "batch axes disagree while building lazy outer-product layout",
1153 ));
1154 }
1155 set_lazy_stride(&mut logical_strides, output_axis, base_strides[base_axis])?;
1156 base_axis += 1;
1157 }
1158
1159 logical_strides
1160 .into_iter()
1161 .collect::<Option<Vec<_>>>()
1162 .ok_or_else(|| {
1163 crate::Error::backend_failure(
1164 "broadcast_multiply",
1165 "lazy outer-product layout did not cover every output axis",
1166 )
1167 })
1168}
1169
1170fn lazy_outer_product_value(
1171 tensor: Tensor,
1172 shape: Vec<usize>,
1173 strides: Vec<isize>,
1174) -> crate::Result<TensorValue> {
1175 Ok(TensorValue::View(TensorOwnedView::from_parts(
1176 Arc::new(tensor),
1177 shape,
1178 strides,
1179 0,
1180 )?))
1181}
1182
1183fn try_lazy_outer_product_with_pool<T>(
1184 buffers: &mut BufferPool,
1185 lhs: &TypedTensorView<'_, T>,
1186 lhs_shape: &[usize],
1187 lhs_dims: &[usize],
1188 rhs: &TypedTensorView<'_, T>,
1189 rhs_shape: &[usize],
1190 rhs_dims: &[usize],
1191) -> crate::Result<Option<LazyOuterProduct<T>>>
1192where
1193 T: Copy + Clone + Mul<Output = T> + PoolScalar + 'static,
1194{
1195 let Some(plan) = split_outer_product_plan(lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)?
1196 else {
1197 return Ok(None);
1198 };
1199
1200 let lhs_free_axes = axes_by_physical_stride(lhs, &plan.lhs_free_axes);
1201 let rhs_free_axes = axes_by_physical_stride(rhs, &plan.rhs_free_axes);
1202 if lhs_free_axes == plan.lhs_free_axes && rhs_free_axes == plan.rhs_free_axes {
1203 return Ok(None);
1204 }
1205
1206 let lhs_view = typed_view_from_view("broadcast_multiply", lhs)?;
1207 let rhs_view = typed_view_from_view("broadcast_multiply", rhs)?;
1208
1209 match plan.layout {
1210 SplitOuterProductLayout::LhsPrefix => {
1211 let lhs_perm: Vec<_> = lhs_free_axes
1212 .iter()
1213 .chain(plan.lhs_batch_axes.iter())
1214 .copied()
1215 .collect();
1216 let rhs_perm: Vec<_> = rhs_free_axes
1217 .iter()
1218 .chain(plan.rhs_batch_axes.iter())
1219 .copied()
1220 .collect();
1221 let lhs_outer = lhs_view
1222 .permute(&lhs_perm)
1223 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1224 let rhs_outer = rhs_view
1225 .permute(&rhs_perm)
1226 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1227
1228 let mut base_shape = Vec::with_capacity(lhs_shape.len());
1229 append_axis_shapes(&mut base_shape, lhs, &lhs_free_axes);
1230 append_axis_shapes(&mut base_shape, rhs, &rhs_free_axes);
1231 append_axis_shapes(&mut base_shape, lhs, &plan.lhs_batch_axes);
1232 let strides = lazy_outer_product_strides(LazyOuterProductStrideSpec {
1233 output_shape: lhs_shape,
1234 base_shape: &base_shape,
1235 leading_axes: &lhs_free_axes,
1236 leading_dims: lhs_dims,
1237 trailing_axes: &rhs_free_axes,
1238 trailing_dims: rhs_dims,
1239 lhs_batch_axes: &plan.lhs_batch_axes,
1240 rhs_batch_axes: &plan.rhs_batch_axes,
1241 lhs_dims,
1242 rhs_dims,
1243 })?;
1244
1245 let mut base = unsafe { typed_array_uninit_from_pool(buffers, &base_shape) };
1247 batched_outer_product_into(
1248 &mut base.view_mut(),
1249 &lhs_outer,
1250 &rhs_outer,
1251 lhs_free_axes.len(),
1252 rhs_free_axes.len(),
1253 )
1254 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1255 Ok(Some(LazyOuterProduct {
1256 base: tensor_from_array(base),
1257 shape: lhs_shape.to_vec(),
1258 strides,
1259 }))
1260 }
1261 SplitOuterProductLayout::RhsPrefix => {
1262 let lhs_perm: Vec<_> = lhs_free_axes
1263 .iter()
1264 .chain(plan.lhs_batch_axes.iter())
1265 .copied()
1266 .collect();
1267 let rhs_perm: Vec<_> = rhs_free_axes
1268 .iter()
1269 .chain(plan.rhs_batch_axes.iter())
1270 .copied()
1271 .collect();
1272 let lhs_outer = lhs_view
1273 .permute(&lhs_perm)
1274 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1275 let rhs_outer = rhs_view
1276 .permute(&rhs_perm)
1277 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1278
1279 let mut base_shape = Vec::with_capacity(lhs_shape.len());
1280 append_axis_shapes(&mut base_shape, rhs, &rhs_free_axes);
1281 append_axis_shapes(&mut base_shape, lhs, &lhs_free_axes);
1282 append_axis_shapes(&mut base_shape, lhs, &plan.lhs_batch_axes);
1283 let strides = lazy_outer_product_strides(LazyOuterProductStrideSpec {
1284 output_shape: lhs_shape,
1285 base_shape: &base_shape,
1286 leading_axes: &rhs_free_axes,
1287 leading_dims: rhs_dims,
1288 trailing_axes: &lhs_free_axes,
1289 trailing_dims: lhs_dims,
1290 lhs_batch_axes: &plan.lhs_batch_axes,
1291 rhs_batch_axes: &plan.rhs_batch_axes,
1292 lhs_dims,
1293 rhs_dims,
1294 })?;
1295
1296 let mut base = unsafe { typed_array_uninit_from_pool(buffers, &base_shape) };
1298 batched_outer_product_into(
1299 &mut base.view_mut(),
1300 &rhs_outer,
1301 &lhs_outer,
1302 rhs_free_axes.len(),
1303 lhs_free_axes.len(),
1304 )
1305 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1306 Ok(Some(LazyOuterProduct {
1307 base: tensor_from_array(base),
1308 shape: lhs_shape.to_vec(),
1309 strides,
1310 }))
1311 }
1312 }
1313}
1314
1315#[allow(clippy::too_many_arguments)]
1316fn typed_broadcast_mul_view_with_pool<T, L, R>(
1317 buffers: &mut BufferPool,
1318 lhs: &TypedTensorView<'_, T, L>,
1319 lhs_shape: &[usize],
1320 lhs_dims: &[usize],
1321 rhs: &TypedTensorView<'_, T, R>,
1322 rhs_shape: &[usize],
1323 rhs_dims: &[usize],
1324) -> crate::Result<TypedTensor<T>>
1325where
1326 T: Copy + Clone + Zero + Mul<Output = T> + PoolScalar + 'static,
1327 L: TensorRank,
1328 R: TensorRank,
1329{
1330 if lhs_shape != rhs_shape {
1331 return Err(crate::Error::ShapeMismatch {
1332 op: "broadcast_multiply",
1333 lhs: lhs_shape.to_vec(),
1334 rhs: rhs_shape.to_vec(),
1335 });
1336 }
1337 let output_rank = lhs_shape.len();
1338 let lhs_is_scalar = lhs.shape().is_empty() && lhs_dims.is_empty();
1339 let rhs_is_scalar = rhs.shape().is_empty() && rhs_dims.is_empty();
1340 let lhs_is_full_output =
1341 lhs.shape() == lhs_shape && lhs_dims.iter().copied().eq(0..output_rank);
1342 let rhs_is_full_output =
1343 rhs.shape() == rhs_shape && rhs_dims.iter().copied().eq(0..output_rank);
1344 if lhs_is_scalar && rhs_is_scalar {
1345 let lhs_scalar = typed_view_from_view("broadcast_multiply", lhs)?.get(&[]);
1346 let rhs_scalar = typed_view_from_view("broadcast_multiply", rhs)?.get(&[]);
1347 return filled_broadcast_multiply_tensor(buffers, lhs_shape, lhs_scalar * rhs_scalar);
1348 }
1349 if lhs_is_scalar && rhs_is_full_output {
1350 let scalar = typed_view_from_view("broadcast_multiply", lhs)?.get(&[]);
1351 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1353 map_into(
1354 &mut out.view_mut(),
1355 &typed_view_from_view("broadcast_multiply", rhs)?,
1356 |x| scalar * x,
1357 )
1358 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1359 return Ok(tensor_from_array(out));
1360 }
1361 if rhs_is_scalar && lhs_is_full_output {
1362 let scalar = typed_view_from_view("broadcast_multiply", rhs)?.get(&[]);
1363 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1365 map_into(
1366 &mut out.view_mut(),
1367 &typed_view_from_view("broadcast_multiply", lhs)?,
1368 |x| x * scalar,
1369 )
1370 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1371 return Ok(tensor_from_array(out));
1372 }
1373
1374 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1376 let lhs_view = typed_view_from_view("broadcast_multiply", lhs)?;
1377 let rhs_view = typed_view_from_view("broadcast_multiply", rhs)?;
1378 broadcast_mul_into(
1379 &mut out.view_mut(),
1380 &lhs_view,
1381 lhs_dims,
1382 &rhs_view,
1383 rhs_dims,
1384 )
1385 .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1386 Ok(tensor_from_array(out))
1387}
1388
1389fn filled_broadcast_multiply_tensor<T>(
1390 buffers: &mut BufferPool,
1391 shape: &[usize],
1392 fill: T,
1393) -> crate::Result<TypedTensor<T>>
1394where
1395 T: Copy + Clone + PoolScalar + 'static,
1396{
1397 let len = shape.iter().try_fold(1usize, |acc, &dim| {
1398 acc.checked_mul(dim).ok_or_else(|| {
1399 crate::Error::backend_failure("broadcast_multiply", "output shape size overflows usize")
1400 })
1401 })?;
1402 let mut data = unsafe { T::pool_acquire(buffers, len) };
1404 data.fill(fill);
1405 TypedTensor::from_vec_col_major(shape.to_vec(), data)
1406}
1407
1408#[allow(clippy::too_many_arguments)]
1409pub(crate) fn broadcast_multiply_read_with_pool(
1410 buffers: &mut BufferPool,
1411 lhs: TensorRead<'_>,
1412 lhs_shape: &[usize],
1413 lhs_dims: &[usize],
1414 rhs: TensorRead<'_>,
1415 rhs_shape: &[usize],
1416 rhs_dims: &[usize],
1417) -> crate::Result<Option<Tensor>> {
1418 let lhs = read_as_cpu_view(lhs);
1419 let rhs = read_as_cpu_view(rhs);
1420
1421 macro_rules! dispatch {
1422 ($variant:ident, $lhs:expr, $rhs:expr) => {{
1423 if let Some(out) = try_outer_product_with_pool(
1424 buffers, &$lhs, lhs_shape, lhs_dims, &$rhs, rhs_shape, rhs_dims,
1425 )? {
1426 return Ok(Some(Tensor::$variant(out)));
1427 }
1428 Ok(Some(Tensor::$variant(typed_broadcast_mul_view_with_pool(
1429 buffers, &$lhs, lhs_shape, lhs_dims, &$rhs, rhs_shape, rhs_dims,
1430 )?)))
1431 }};
1432 }
1433
1434 match (lhs, rhs) {
1435 (CpuReadView::F32(lhs), CpuReadView::F32(rhs)) => dispatch!(F32, lhs, rhs),
1436 (CpuReadView::F64(lhs), CpuReadView::F64(rhs)) => dispatch!(F64, lhs, rhs),
1437 (CpuReadView::I32(lhs), CpuReadView::I32(rhs)) => dispatch!(I32, lhs, rhs),
1438 (CpuReadView::I64(lhs), CpuReadView::I64(rhs)) => dispatch!(I64, lhs, rhs),
1439 (CpuReadView::C32(lhs), CpuReadView::C32(rhs)) => dispatch!(C32, lhs, rhs),
1440 (CpuReadView::C64(lhs), CpuReadView::C64(rhs)) => dispatch!(C64, lhs, rhs),
1441 _ => Ok(None),
1442 }
1443}
1444
1445#[allow(clippy::too_many_arguments)]
1446pub(crate) fn broadcast_multiply_value_with_pool(
1447 buffers: &mut BufferPool,
1448 lhs: TensorRead<'_>,
1449 lhs_shape: &[usize],
1450 lhs_dims: &[usize],
1451 rhs: TensorRead<'_>,
1452 rhs_shape: &[usize],
1453 rhs_dims: &[usize],
1454) -> crate::Result<Option<TensorValue>> {
1455 let lhs_view = read_as_cpu_view(lhs.clone());
1456 let rhs_view = read_as_cpu_view(rhs.clone());
1457
1458 macro_rules! dispatch_lazy {
1459 ($variant:ident, $lhs:expr, $rhs:expr) => {{
1460 if let Some(out) = try_lazy_outer_product_with_pool(
1461 buffers, &$lhs, lhs_shape, lhs_dims, &$rhs, rhs_shape, rhs_dims,
1462 )? {
1463 return Ok(Some(lazy_outer_product_value(
1464 Tensor::$variant(out.base),
1465 out.shape,
1466 out.strides,
1467 )?));
1468 }
1469 }};
1470 }
1471
1472 match (lhs_view, rhs_view) {
1473 (CpuReadView::F32(lhs_view), CpuReadView::F32(rhs_view)) => {
1474 dispatch_lazy!(F32, lhs_view, rhs_view);
1475 }
1476 (CpuReadView::F64(lhs_view), CpuReadView::F64(rhs_view)) => {
1477 dispatch_lazy!(F64, lhs_view, rhs_view);
1478 }
1479 (CpuReadView::I32(lhs_view), CpuReadView::I32(rhs_view)) => {
1480 dispatch_lazy!(I32, lhs_view, rhs_view);
1481 }
1482 (CpuReadView::I64(lhs_view), CpuReadView::I64(rhs_view)) => {
1483 dispatch_lazy!(I64, lhs_view, rhs_view);
1484 }
1485 (CpuReadView::C32(lhs_view), CpuReadView::C32(rhs_view)) => {
1486 dispatch_lazy!(C32, lhs_view, rhs_view);
1487 }
1488 (CpuReadView::C64(lhs_view), CpuReadView::C64(rhs_view)) => {
1489 dispatch_lazy!(C64, lhs_view, rhs_view);
1490 }
1491 _ => {}
1492 }
1493
1494 broadcast_multiply_read_with_pool(buffers, lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)
1495 .map(|tensor| tensor.map(TensorValue::from_tensor))
1496}
1497
1498pub fn div(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
1513 with_local_pool(|buffers| div_with_pool(buffers, lhs, rhs))
1514}
1515
1516pub(crate) fn div_with_pool(
1517 buffers: &mut BufferPool,
1518 lhs: &Tensor,
1519 rhs: &Tensor,
1520) -> crate::Result<Tensor> {
1521 match (lhs, rhs) {
1522 (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_div_with_pool(buffers, a, b)?)),
1523 (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_div_with_pool(buffers, a, b)?)),
1524 (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_div_with_pool(buffers, a, b)?)),
1525 (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_div_with_pool(buffers, a, b)?)),
1526 (Tensor::F32(a), Tensor::C32(b)) if a.shape().is_empty() => {
1527 let scalar = complex_scalar_tensor(typed_host_data("div", a)?[0])?;
1528 Ok(Tensor::C32(typed_div_with_pool(buffers, &scalar, b)?))
1529 }
1530 (Tensor::C32(a), Tensor::F32(b)) if b.shape().is_empty() => {
1531 let scalar = complex_scalar_tensor(typed_host_data("div", b)?[0])?;
1532 Ok(Tensor::C32(typed_div_with_pool(buffers, a, &scalar)?))
1533 }
1534 (Tensor::F64(a), Tensor::C64(b)) if a.shape().is_empty() => {
1535 let scalar = complex_scalar_tensor(typed_host_data("div", a)?[0])?;
1536 Ok(Tensor::C64(typed_div_with_pool(buffers, &scalar, b)?))
1537 }
1538 (Tensor::C64(a), Tensor::F64(b)) if b.shape().is_empty() => {
1539 let scalar = complex_scalar_tensor(typed_host_data("div", b)?[0])?;
1540 Ok(Tensor::C64(typed_div_with_pool(buffers, a, &scalar)?))
1541 }
1542 _ => Err(crate::Error::DTypeMismatch {
1543 op: "div",
1544 lhs: lhs.dtype(),
1545 rhs: rhs.dtype(),
1546 }),
1547 }
1548}
1549
1550pub(crate) fn div_read_with_pool(
1551 buffers: &mut BufferPool,
1552 lhs: TensorRead<'_>,
1553 rhs: TensorRead<'_>,
1554) -> crate::Result<Tensor> {
1555 let lhs_dtype = lhs.dtype();
1556 let rhs_dtype = rhs.dtype();
1557 match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
1558 (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::F32(typed_binary_view_with_pool(
1559 "div",
1560 buffers,
1561 &a,
1562 &b,
1563 |x, y| x / y,
1564 )?)),
1565 (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::F64(typed_binary_view_with_pool(
1566 "div",
1567 buffers,
1568 &a,
1569 &b,
1570 |x, y| x / y,
1571 )?)),
1572 (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::C32(typed_binary_view_with_pool(
1573 "div",
1574 buffers,
1575 &a,
1576 &b,
1577 |x, y| x / y,
1578 )?)),
1579 (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::C64(typed_binary_view_with_pool(
1580 "div",
1581 buffers,
1582 &a,
1583 &b,
1584 |x, y| x / y,
1585 )?)),
1586 (CpuReadView::F32(real), CpuReadView::C32(complex)) if real.shape().is_empty() => {
1587 let scalar = complex_scalar_tensor_from_view(&real)?;
1588 let scalar = scalar.as_view();
1589 Ok(Tensor::C32(typed_binary_view_with_pool(
1590 "div",
1591 buffers,
1592 &scalar,
1593 &complex,
1594 |x, y| x / y,
1595 )?))
1596 }
1597 (CpuReadView::C32(complex), CpuReadView::F32(real)) if real.shape().is_empty() => {
1598 let scalar = complex_scalar_tensor_from_view(&real)?;
1599 let scalar = scalar.as_view();
1600 Ok(Tensor::C32(typed_binary_view_with_pool(
1601 "div",
1602 buffers,
1603 &complex,
1604 &scalar,
1605 |x, y| x / y,
1606 )?))
1607 }
1608 (CpuReadView::F64(real), CpuReadView::C64(complex)) if real.shape().is_empty() => {
1609 let scalar = complex_scalar_tensor_from_view(&real)?;
1610 let scalar = scalar.as_view();
1611 Ok(Tensor::C64(typed_binary_view_with_pool(
1612 "div",
1613 buffers,
1614 &scalar,
1615 &complex,
1616 |x, y| x / y,
1617 )?))
1618 }
1619 (CpuReadView::C64(complex), CpuReadView::F64(real)) if real.shape().is_empty() => {
1620 let scalar = complex_scalar_tensor_from_view(&real)?;
1621 let scalar = scalar.as_view();
1622 Ok(Tensor::C64(typed_binary_view_with_pool(
1623 "div",
1624 buffers,
1625 &complex,
1626 &scalar,
1627 |x, y| x / y,
1628 )?))
1629 }
1630 _ => Err(crate::Error::DTypeMismatch {
1631 op: "div",
1632 lhs: lhs_dtype,
1633 rhs: rhs_dtype,
1634 }),
1635 }
1636}
1637
1638pub fn neg(input: &Tensor) -> crate::Result<Tensor> {
1652 with_local_pool(|buffers| neg_with_pool(buffers, input))
1653}
1654
1655pub(crate) fn neg_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1656 match input {
1657 Tensor::F32(t) => Ok(Tensor::F32(typed_neg_with_pool(buffers, t)?)),
1658 Tensor::F64(t) => Ok(Tensor::F64(typed_neg_with_pool(buffers, t)?)),
1659 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1660 "neg",
1661 format!("unsupported dtype {:?}", input.dtype()),
1662 )),
1663 Tensor::C32(t) => Ok(Tensor::C32(typed_neg_with_pool(buffers, t)?)),
1664 Tensor::C64(t) => Ok(Tensor::C64(typed_neg_with_pool(buffers, t)?)),
1665 }
1666}
1667
1668pub(crate) fn neg_read_with_pool(
1669 buffers: &mut BufferPool,
1670 input: TensorRead<'_>,
1671) -> crate::Result<Tensor> {
1672 let dtype = input.dtype();
1673 match read_as_cpu_view(input) {
1674 CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1675 "neg",
1676 buffers,
1677 &t,
1678 |x| -x,
1679 )?)),
1680 CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1681 "neg",
1682 buffers,
1683 &t,
1684 |x| -x,
1685 )?)),
1686 CpuReadView::C32(t) => Ok(Tensor::C32(typed_unary_view_with_pool(
1687 "neg",
1688 buffers,
1689 &t,
1690 |x| -x,
1691 )?)),
1692 CpuReadView::C64(t) => Ok(Tensor::C64(typed_unary_view_with_pool(
1693 "neg",
1694 buffers,
1695 &t,
1696 |x| -x,
1697 )?)),
1698 _ => Err(crate::Error::backend_failure(
1699 "neg",
1700 format!("unsupported dtype {dtype:?}"),
1701 )),
1702 }
1703}
1704
1705pub fn conj(input: &Tensor) -> crate::Result<Tensor> {
1720 with_local_pool(|buffers| conj_with_pool(buffers, input))
1721}
1722
1723pub(crate) fn conj_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1724 match input {
1725 Tensor::F32(t) => Ok(Tensor::F32(typed_conj_with_pool(buffers, t)?)),
1726 Tensor::F64(t) => Ok(Tensor::F64(typed_conj_with_pool(buffers, t)?)),
1727 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1728 "conj",
1729 format!("unsupported dtype {:?}", input.dtype()),
1730 )),
1731 Tensor::C32(t) => Ok(Tensor::C32(typed_conj_with_pool(buffers, t)?)),
1732 Tensor::C64(t) => Ok(Tensor::C64(typed_conj_with_pool(buffers, t)?)),
1733 }
1734}
1735
1736pub(crate) fn conj_read_with_pool(
1737 buffers: &mut BufferPool,
1738 input: TensorRead<'_>,
1739) -> crate::Result<Tensor> {
1740 let dtype = input.dtype();
1741 match read_as_cpu_view(input) {
1742 CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1743 "conj",
1744 buffers,
1745 &t,
1746 |x| x.conj_elem(),
1747 )?)),
1748 CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1749 "conj",
1750 buffers,
1751 &t,
1752 |x| x.conj_elem(),
1753 )?)),
1754 CpuReadView::C32(t) => Ok(Tensor::C32(typed_unary_view_with_pool(
1755 "conj",
1756 buffers,
1757 &t,
1758 |x| x.conj_elem(),
1759 )?)),
1760 CpuReadView::C64(t) => Ok(Tensor::C64(typed_unary_view_with_pool(
1761 "conj",
1762 buffers,
1763 &t,
1764 |x| x.conj_elem(),
1765 )?)),
1766 _ => Err(crate::Error::backend_failure(
1767 "conj",
1768 format!("unsupported dtype {dtype:?}"),
1769 )),
1770 }
1771}
1772
1773pub fn abs(input: &Tensor) -> crate::Result<Tensor> {
1789 with_local_pool(|buffers| abs_with_pool(buffers, input))
1790}
1791
1792pub(crate) fn abs_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1793 match input {
1794 Tensor::F32(t) => Ok(Tensor::F32(typed_abs_with_pool(buffers, t)?)),
1795 Tensor::F64(t) => Ok(Tensor::F64(typed_abs_with_pool(buffers, t)?)),
1796 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1797 "abs",
1798 format!("unsupported dtype {:?}", input.dtype()),
1799 )),
1800 Tensor::C32(t) => Ok(Tensor::F32(typed_complex_abs_with_pool(buffers, t)?)),
1801 Tensor::C64(t) => Ok(Tensor::F64(typed_complex_abs_with_pool(buffers, t)?)),
1802 }
1803}
1804
1805pub(crate) fn abs_read_with_pool(
1806 buffers: &mut BufferPool,
1807 input: TensorRead<'_>,
1808) -> crate::Result<Tensor> {
1809 let dtype = input.dtype();
1810 match read_as_cpu_view(input) {
1811 CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1812 "abs",
1813 buffers,
1814 &t,
1815 |x| x.abs_elem(),
1816 )?)),
1817 CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1818 "abs",
1819 buffers,
1820 &t,
1821 |x| x.abs_elem(),
1822 )?)),
1823 CpuReadView::C32(t) => Ok(Tensor::F32(typed_complex_abs_view_with_pool(buffers, &t)?)),
1824 CpuReadView::C64(t) => Ok(Tensor::F64(typed_complex_abs_view_with_pool(buffers, &t)?)),
1825 _ => Err(crate::Error::backend_failure(
1826 "abs",
1827 format!("unsupported dtype {dtype:?}"),
1828 )),
1829 }
1830}
1831
1832pub fn sign(input: &Tensor) -> crate::Result<Tensor> {
1846 with_local_pool(|buffers| sign_with_pool(buffers, input))
1847}
1848
1849pub(crate) fn sign_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1850 match input {
1851 Tensor::F32(t) => Ok(Tensor::F32(typed_sign_with_pool(buffers, t)?)),
1852 Tensor::F64(t) => Ok(Tensor::F64(typed_sign_with_pool(buffers, t)?)),
1853 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1854 "sign",
1855 format!("unsupported dtype {:?}", input.dtype()),
1856 )),
1857 Tensor::C32(t) => Ok(Tensor::C32(typed_sign_with_pool(buffers, t)?)),
1858 Tensor::C64(t) => Ok(Tensor::C64(typed_sign_with_pool(buffers, t)?)),
1859 }
1860}
1861
1862pub(crate) fn sign_read_with_pool(
1863 buffers: &mut BufferPool,
1864 input: TensorRead<'_>,
1865) -> crate::Result<Tensor> {
1866 let dtype = input.dtype();
1867 match read_as_cpu_view(input) {
1868 CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1869 "sign",
1870 buffers,
1871 &t,
1872 |x| x.sign_elem(),
1873 )?)),
1874 CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1875 "sign",
1876 buffers,
1877 &t,
1878 |x| x.sign_elem(),
1879 )?)),
1880 CpuReadView::C32(t) => Ok(Tensor::C32(typed_unary_view_with_pool(
1881 "sign",
1882 buffers,
1883 &t,
1884 |x| x.sign_elem(),
1885 )?)),
1886 CpuReadView::C64(t) => Ok(Tensor::C64(typed_unary_view_with_pool(
1887 "sign",
1888 buffers,
1889 &t,
1890 |x| x.sign_elem(),
1891 )?)),
1892 _ => Err(crate::Error::backend_failure(
1893 "sign",
1894 format!("unsupported dtype {dtype:?}"),
1895 )),
1896 }
1897}
1898
1899pub fn maximum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
1914 with_local_pool(|buffers| maximum_with_pool(buffers, lhs, rhs))
1915}
1916
1917pub(crate) fn maximum_with_pool(
1918 buffers: &mut BufferPool,
1919 lhs: &Tensor,
1920 rhs: &Tensor,
1921) -> crate::Result<Tensor> {
1922 match (lhs, rhs) {
1923 (Tensor::F32(a), Tensor::F32(b)) => {
1924 Ok(Tensor::F32(typed_maximum_with_pool(buffers, a, b)?))
1925 }
1926 (Tensor::F64(a), Tensor::F64(b)) => {
1927 Ok(Tensor::F64(typed_maximum_with_pool(buffers, a, b)?))
1928 }
1929 (Tensor::C32(a), Tensor::C32(b)) => {
1930 Ok(Tensor::C32(typed_maximum_with_pool(buffers, a, b)?))
1931 }
1932 (Tensor::C64(a), Tensor::C64(b)) => {
1933 Ok(Tensor::C64(typed_maximum_with_pool(buffers, a, b)?))
1934 }
1935 _ => Err(tensor_pair_error("maximum", lhs, rhs)),
1936 }
1937}
1938
1939pub(crate) fn maximum_read_with_pool(
1940 buffers: &mut BufferPool,
1941 lhs: TensorRead<'_>,
1942 rhs: TensorRead<'_>,
1943) -> crate::Result<Tensor> {
1944 let lhs_dtype = lhs.dtype();
1945 let rhs_dtype = rhs.dtype();
1946 match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
1947 (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::F32(
1948 typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1949 x.max_elem(y)
1950 })?,
1951 )),
1952 (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::F64(
1953 typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1954 x.max_elem(y)
1955 })?,
1956 )),
1957 (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::C32(
1958 typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1959 x.max_elem(y)
1960 })?,
1961 )),
1962 (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::C64(
1963 typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1964 x.max_elem(y)
1965 })?,
1966 )),
1967 _ => Err(dtype_pair_error("maximum", lhs_dtype, rhs_dtype)),
1968 }
1969}
1970
1971pub fn minimum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
1986 with_local_pool(|buffers| minimum_with_pool(buffers, lhs, rhs))
1987}
1988
1989pub(crate) fn minimum_with_pool(
1990 buffers: &mut BufferPool,
1991 lhs: &Tensor,
1992 rhs: &Tensor,
1993) -> crate::Result<Tensor> {
1994 match (lhs, rhs) {
1995 (Tensor::F32(a), Tensor::F32(b)) => {
1996 Ok(Tensor::F32(typed_minimum_with_pool(buffers, a, b)?))
1997 }
1998 (Tensor::F64(a), Tensor::F64(b)) => {
1999 Ok(Tensor::F64(typed_minimum_with_pool(buffers, a, b)?))
2000 }
2001 (Tensor::C32(a), Tensor::C32(b)) => {
2002 Ok(Tensor::C32(typed_minimum_with_pool(buffers, a, b)?))
2003 }
2004 (Tensor::C64(a), Tensor::C64(b)) => {
2005 Ok(Tensor::C64(typed_minimum_with_pool(buffers, a, b)?))
2006 }
2007 _ => Err(tensor_pair_error("minimum", lhs, rhs)),
2008 }
2009}
2010
2011pub(crate) fn minimum_read_with_pool(
2012 buffers: &mut BufferPool,
2013 lhs: TensorRead<'_>,
2014 rhs: TensorRead<'_>,
2015) -> crate::Result<Tensor> {
2016 let lhs_dtype = lhs.dtype();
2017 let rhs_dtype = rhs.dtype();
2018 match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
2019 (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::F32(
2020 typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2021 x.min_elem(y)
2022 })?,
2023 )),
2024 (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::F64(
2025 typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2026 x.min_elem(y)
2027 })?,
2028 )),
2029 (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::C32(
2030 typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2031 x.min_elem(y)
2032 })?,
2033 )),
2034 (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::C64(
2035 typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2036 x.min_elem(y)
2037 })?,
2038 )),
2039 _ => Err(dtype_pair_error("minimum", lhs_dtype, rhs_dtype)),
2040 }
2041}
2042
2043pub fn compare(lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor> {
2058 with_local_pool(|buffers| compare_with_pool(buffers, lhs, rhs, dir))
2059}
2060
2061pub(crate) fn compare_with_pool(
2062 buffers: &mut BufferPool,
2063 lhs: &Tensor,
2064 rhs: &Tensor,
2065 dir: &CompareDir,
2066) -> crate::Result<Tensor> {
2067 match (lhs, rhs) {
2068 (Tensor::F32(a), Tensor::F32(b)) => {
2069 Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2070 }
2071 (Tensor::F64(a), Tensor::F64(b)) => {
2072 Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2073 }
2074 (Tensor::I32(a), Tensor::I32(b)) => {
2075 Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2076 }
2077 (Tensor::I64(a), Tensor::I64(b)) => {
2078 Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2079 }
2080 (Tensor::Bool(a), Tensor::Bool(b)) => {
2081 Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2082 }
2083 (Tensor::C32(a), Tensor::C32(b)) => {
2084 Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2085 }
2086 (Tensor::C64(a), Tensor::C64(b)) => {
2087 Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2088 }
2089 _ => Err(crate::Error::DTypeMismatch {
2090 op: "compare",
2091 lhs: lhs.dtype(),
2092 rhs: rhs.dtype(),
2093 }),
2094 }
2095}
2096
2097pub(crate) fn compare_read_with_pool(
2098 buffers: &mut BufferPool,
2099 lhs: TensorRead<'_>,
2100 rhs: TensorRead<'_>,
2101 dir: &CompareDir,
2102) -> crate::Result<Tensor> {
2103 let lhs_dtype = lhs.dtype();
2104 let rhs_dtype = rhs.dtype();
2105 match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
2106 (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::Bool(
2107 typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2108 x.compare_elem(y, dir)
2109 })?,
2110 )),
2111 (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::Bool(
2112 typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2113 x.compare_elem(y, dir)
2114 })?,
2115 )),
2116 (CpuReadView::I32(a), CpuReadView::I32(b)) => Ok(Tensor::Bool(
2117 typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2118 x.compare_elem(y, dir)
2119 })?,
2120 )),
2121 (CpuReadView::I64(a), CpuReadView::I64(b)) => Ok(Tensor::Bool(
2122 typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2123 x.compare_elem(y, dir)
2124 })?,
2125 )),
2126 (CpuReadView::Bool(a), CpuReadView::Bool(b)) => Ok(Tensor::Bool(
2127 typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2128 x.compare_elem(y, dir)
2129 })?,
2130 )),
2131 (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::Bool(
2132 typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2133 x.compare_elem(y, dir)
2134 })?,
2135 )),
2136 (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::Bool(
2137 typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2138 x.compare_elem(y, dir)
2139 })?,
2140 )),
2141 _ => Err(crate::Error::DTypeMismatch {
2142 op: "compare",
2143 lhs: lhs_dtype,
2144 rhs: rhs_dtype,
2145 }),
2146 }
2147}
2148
2149pub fn select(pred: &Tensor, on_true: &Tensor, on_false: &Tensor) -> crate::Result<Tensor> {
2165 with_local_pool(|buffers| select_with_pool(buffers, pred, on_true, on_false))
2166}
2167
2168pub(crate) fn select_with_pool(
2169 buffers: &mut BufferPool,
2170 pred: &Tensor,
2171 on_true: &Tensor,
2172 on_false: &Tensor,
2173) -> crate::Result<Tensor> {
2174 match (pred, on_true, on_false) {
2175 (Tensor::Bool(p), Tensor::F32(t), Tensor::F32(f)) => {
2176 Ok(Tensor::F32(typed_select_with_pool(buffers, p, t, f)?))
2177 }
2178 (Tensor::Bool(p), Tensor::F64(t), Tensor::F64(f)) => {
2179 Ok(Tensor::F64(typed_select_with_pool(buffers, p, t, f)?))
2180 }
2181 (Tensor::Bool(p), Tensor::I32(t), Tensor::I32(f)) => {
2182 Ok(Tensor::I32(typed_select_with_pool(buffers, p, t, f)?))
2183 }
2184 (Tensor::Bool(p), Tensor::I64(t), Tensor::I64(f)) => {
2185 Ok(Tensor::I64(typed_select_with_pool(buffers, p, t, f)?))
2186 }
2187 (Tensor::Bool(p), Tensor::Bool(t), Tensor::Bool(f)) => {
2188 Ok(Tensor::Bool(typed_select_with_pool(buffers, p, t, f)?))
2189 }
2190 (Tensor::Bool(p), Tensor::C32(t), Tensor::C32(f)) => {
2191 Ok(Tensor::C32(typed_select_with_pool(buffers, p, t, f)?))
2192 }
2193 (Tensor::Bool(p), Tensor::C64(t), Tensor::C64(f)) => {
2194 Ok(Tensor::C64(typed_select_with_pool(buffers, p, t, f)?))
2195 }
2196 (Tensor::Bool(_), _, _) => Err(crate::Error::DTypeMismatch {
2197 op: "select",
2198 lhs: on_true.dtype(),
2199 rhs: on_false.dtype(),
2200 }),
2201 _ => Err(crate::Error::DTypeMismatch {
2202 op: "select",
2203 lhs: pred.dtype(),
2204 rhs: crate::DType::Bool,
2205 }),
2206 }
2207}
2208
2209pub(crate) fn select_read_with_pool(
2210 buffers: &mut BufferPool,
2211 pred: TensorRead<'_>,
2212 on_true: TensorRead<'_>,
2213 on_false: TensorRead<'_>,
2214) -> crate::Result<Tensor> {
2215 let pred_dtype = pred.dtype();
2216 let true_dtype = on_true.dtype();
2217 let false_dtype = on_false.dtype();
2218 match (
2219 read_as_cpu_view(pred),
2220 read_as_cpu_view(on_true),
2221 read_as_cpu_view(on_false),
2222 ) {
2223 (CpuReadView::Bool(p), CpuReadView::F32(t), CpuReadView::F32(f)) => Ok(Tensor::F32(
2224 typed_select_view_with_pool(buffers, &p, &t, &f)?,
2225 )),
2226 (CpuReadView::Bool(p), CpuReadView::F64(t), CpuReadView::F64(f)) => Ok(Tensor::F64(
2227 typed_select_view_with_pool(buffers, &p, &t, &f)?,
2228 )),
2229 (CpuReadView::Bool(p), CpuReadView::I32(t), CpuReadView::I32(f)) => Ok(Tensor::I32(
2230 typed_select_view_with_pool(buffers, &p, &t, &f)?,
2231 )),
2232 (CpuReadView::Bool(p), CpuReadView::I64(t), CpuReadView::I64(f)) => Ok(Tensor::I64(
2233 typed_select_view_with_pool(buffers, &p, &t, &f)?,
2234 )),
2235 (CpuReadView::Bool(p), CpuReadView::Bool(t), CpuReadView::Bool(f)) => Ok(Tensor::Bool(
2236 typed_select_view_with_pool(buffers, &p, &t, &f)?,
2237 )),
2238 (CpuReadView::Bool(p), CpuReadView::C32(t), CpuReadView::C32(f)) => Ok(Tensor::C32(
2239 typed_select_view_with_pool(buffers, &p, &t, &f)?,
2240 )),
2241 (CpuReadView::Bool(p), CpuReadView::C64(t), CpuReadView::C64(f)) => Ok(Tensor::C64(
2242 typed_select_view_with_pool(buffers, &p, &t, &f)?,
2243 )),
2244 (CpuReadView::Bool(_), _, _) => Err(crate::Error::DTypeMismatch {
2245 op: "select",
2246 lhs: true_dtype,
2247 rhs: false_dtype,
2248 }),
2249 _ => Err(crate::Error::DTypeMismatch {
2250 op: "select",
2251 lhs: pred_dtype,
2252 rhs: crate::DType::Bool,
2253 }),
2254 }
2255}
2256
2257pub fn clamp(input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor> {
2273 with_local_pool(|buffers| clamp_with_pool(buffers, input, lower, upper))
2274}
2275
2276pub(crate) fn clamp_with_pool(
2277 buffers: &mut BufferPool,
2278 input: &Tensor,
2279 lower: &Tensor,
2280 upper: &Tensor,
2281) -> crate::Result<Tensor> {
2282 dispatch_ternary_result_with_pool!("clamp", input, lower, upper, |x, lo, hi| {
2283 typed_clamp_with_pool(buffers, x, lo, hi)
2284 })
2285}
2286
2287pub(crate) fn clamp_read_with_pool(
2288 buffers: &mut BufferPool,
2289 input: TensorRead<'_>,
2290 lower: TensorRead<'_>,
2291 upper: TensorRead<'_>,
2292) -> crate::Result<Tensor> {
2293 let input_dtype = input.dtype();
2294 let lower_dtype = lower.dtype();
2295 match (
2296 read_as_cpu_view(input),
2297 read_as_cpu_view(lower),
2298 read_as_cpu_view(upper),
2299 ) {
2300 (CpuReadView::F32(input), CpuReadView::F32(lower), CpuReadView::F32(upper)) => Ok(
2301 Tensor::F32(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2302 ),
2303 (CpuReadView::F64(input), CpuReadView::F64(lower), CpuReadView::F64(upper)) => Ok(
2304 Tensor::F64(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2305 ),
2306 (CpuReadView::C32(input), CpuReadView::C32(lower), CpuReadView::C32(upper)) => Ok(
2307 Tensor::C32(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2308 ),
2309 (CpuReadView::C64(input), CpuReadView::C64(lower), CpuReadView::C64(upper)) => Ok(
2310 Tensor::C64(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2311 ),
2312 _ => Err(crate::Error::DTypeMismatch {
2313 op: "clamp",
2314 lhs: input_dtype,
2315 rhs: lower_dtype,
2316 }),
2317 }
2318}
2319
2320pub(crate) fn typed_add_with_pool<T>(
2321 buffers: &mut BufferPool,
2322 lhs: &TypedTensor<T>,
2323 rhs: &TypedTensor<T>,
2324) -> crate::Result<TypedTensor<T>>
2325where
2326 T: Copy + Clone + Zero + Add<Output = T> + PoolScalar,
2327{
2328 if lhs.shape() == rhs.shape() {
2329 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2331 zip_map2_into(
2332 &mut out.view_mut(),
2333 &typed_view("add", lhs)?,
2334 &typed_view("add", rhs)?,
2335 |x, y| x + y,
2336 )
2337 .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2338 Ok(tensor_from_array(out))
2339 } else if lhs.shape().is_empty() {
2340 let scalar = typed_host_data("add", lhs)?[0];
2341 let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2343 map_into(&mut out.view_mut(), &typed_view("add", rhs)?, |x| {
2344 scalar + x
2345 })
2346 .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2347 Ok(tensor_from_array(out))
2348 } else if rhs.shape().is_empty() {
2349 let scalar = typed_host_data("add", rhs)?[0];
2350 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2352 map_into(&mut out.view_mut(), &typed_view("add", lhs)?, |x| {
2353 x + scalar
2354 })
2355 .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2356 Ok(tensor_from_array(out))
2357 } else {
2358 Err(crate::Error::ShapeMismatch {
2359 op: "add",
2360 lhs: lhs.shape().to_vec(),
2361 rhs: rhs.shape().to_vec(),
2362 })
2363 }
2364}
2365
2366pub(crate) fn typed_add_view_with_pool<T, L, R>(
2367 buffers: &mut BufferPool,
2368 lhs: &TypedTensorView<'_, T, L>,
2369 rhs: &TypedTensorView<'_, T, R>,
2370) -> crate::Result<TypedTensor<T>>
2371where
2372 T: Copy + Clone + Zero + Add<Output = T> + PoolScalar + 'static,
2373 L: TensorRank,
2374 R: TensorRank,
2375{
2376 if lhs.shape() == rhs.shape() {
2377 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2379 zip_map2_into(
2380 &mut out.view_mut(),
2381 &typed_view_from_view("add", lhs)?,
2382 &typed_view_from_view("add", rhs)?,
2383 |x, y| x + y,
2384 )
2385 .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2386 Ok(tensor_from_array(out))
2387 } else if lhs.shape().is_empty() {
2388 let scalar = typed_view_from_view("add", lhs)?.get(&[]);
2389 let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2391 map_into(
2392 &mut out.view_mut(),
2393 &typed_view_from_view("add", rhs)?,
2394 |x| scalar + x,
2395 )
2396 .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2397 Ok(tensor_from_array(out))
2398 } else if rhs.shape().is_empty() {
2399 let scalar = typed_view_from_view("add", rhs)?.get(&[]);
2400 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2402 map_into(
2403 &mut out.view_mut(),
2404 &typed_view_from_view("add", lhs)?,
2405 |x| x + scalar,
2406 )
2407 .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2408 Ok(tensor_from_array(out))
2409 } else {
2410 Err(crate::Error::ShapeMismatch {
2411 op: "add",
2412 lhs: lhs.shape().to_vec(),
2413 rhs: rhs.shape().to_vec(),
2414 })
2415 }
2416}
2417
2418pub(crate) fn typed_mul_with_pool<T>(
2419 buffers: &mut BufferPool,
2420 lhs: &TypedTensor<T>,
2421 rhs: &TypedTensor<T>,
2422) -> crate::Result<TypedTensor<T>>
2423where
2424 T: Copy + Clone + Zero + Mul<Output = T> + PoolScalar + 'static,
2425{
2426 if lhs.shape() == rhs.shape() {
2427 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2429 mul_into(
2430 &mut out.view_mut(),
2431 &typed_view("mul", lhs)?,
2432 &typed_view("mul", rhs)?,
2433 )
2434 .map_err(|err| crate::Error::backend_failure("mul", err))?;
2435 Ok(tensor_from_array(out))
2436 } else if lhs.shape().is_empty() {
2437 let scalar = typed_host_data("mul", lhs)?[0];
2438 let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2440 map_into(&mut out.view_mut(), &typed_view("mul", rhs)?, |x| {
2441 scalar * x
2442 })
2443 .map_err(|err| crate::Error::backend_failure("mul", err))?;
2444 Ok(tensor_from_array(out))
2445 } else if rhs.shape().is_empty() {
2446 let scalar = typed_host_data("mul", rhs)?[0];
2447 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2449 map_into(&mut out.view_mut(), &typed_view("mul", lhs)?, |x| {
2450 x * scalar
2451 })
2452 .map_err(|err| crate::Error::backend_failure("mul", err))?;
2453 Ok(tensor_from_array(out))
2454 } else {
2455 Err(crate::Error::ShapeMismatch {
2456 op: "mul",
2457 lhs: lhs.shape().to_vec(),
2458 rhs: rhs.shape().to_vec(),
2459 })
2460 }
2461}
2462
2463pub(crate) fn typed_mul_view_with_pool<T, L, R>(
2464 buffers: &mut BufferPool,
2465 lhs: &TypedTensorView<'_, T, L>,
2466 rhs: &TypedTensorView<'_, T, R>,
2467) -> crate::Result<TypedTensor<T>>
2468where
2469 T: Copy + Clone + Zero + Mul<Output = T> + PoolScalar + 'static,
2470 L: TensorRank,
2471 R: TensorRank,
2472{
2473 if lhs.shape() == rhs.shape() {
2474 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2476 mul_into(
2477 &mut out.view_mut(),
2478 &typed_view_from_view("mul", lhs)?,
2479 &typed_view_from_view("mul", rhs)?,
2480 )
2481 .map_err(|err| crate::Error::backend_failure("mul", err))?;
2482 Ok(tensor_from_array(out))
2483 } else if lhs.shape().is_empty() {
2484 let scalar = typed_view_from_view("mul", lhs)?.get(&[]);
2485 let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2487 map_into(
2488 &mut out.view_mut(),
2489 &typed_view_from_view("mul", rhs)?,
2490 |x| scalar * x,
2491 )
2492 .map_err(|err| crate::Error::backend_failure("mul", err))?;
2493 Ok(tensor_from_array(out))
2494 } else if rhs.shape().is_empty() {
2495 let scalar = typed_view_from_view("mul", rhs)?.get(&[]);
2496 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2498 map_into(
2499 &mut out.view_mut(),
2500 &typed_view_from_view("mul", lhs)?,
2501 |x| x * scalar,
2502 )
2503 .map_err(|err| crate::Error::backend_failure("mul", err))?;
2504 Ok(tensor_from_array(out))
2505 } else {
2506 Err(crate::Error::ShapeMismatch {
2507 op: "mul",
2508 lhs: lhs.shape().to_vec(),
2509 rhs: rhs.shape().to_vec(),
2510 })
2511 }
2512}
2513
2514pub(crate) fn typed_div_with_pool<T>(
2515 buffers: &mut BufferPool,
2516 lhs: &TypedTensor<T>,
2517 rhs: &TypedTensor<T>,
2518) -> crate::Result<TypedTensor<T>>
2519where
2520 T: Copy + Clone + Zero + Div<Output = T> + PoolScalar,
2521{
2522 if lhs.shape() == rhs.shape() {
2523 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2525 zip_map2_into(
2526 &mut out.view_mut(),
2527 &typed_view("div", lhs)?,
2528 &typed_view("div", rhs)?,
2529 |x, y| x / y,
2530 )
2531 .map_err(|err| crate::Error::backend_failure("div", err))?;
2532 Ok(tensor_from_array(out))
2533 } else if lhs.shape().is_empty() {
2534 let scalar = typed_host_data("div", lhs)?[0];
2535 let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2537 map_into(&mut out.view_mut(), &typed_view("div", rhs)?, |x| {
2538 scalar / x
2539 })
2540 .map_err(|err| crate::Error::backend_failure("div", err))?;
2541 Ok(tensor_from_array(out))
2542 } else if rhs.shape().is_empty() {
2543 let scalar = typed_host_data("div", rhs)?[0];
2544 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2546 map_into(&mut out.view_mut(), &typed_view("div", lhs)?, |x| {
2547 x / scalar
2548 })
2549 .map_err(|err| crate::Error::backend_failure("div", err))?;
2550 Ok(tensor_from_array(out))
2551 } else {
2552 Err(crate::Error::ShapeMismatch {
2553 op: "div",
2554 lhs: lhs.shape().to_vec(),
2555 rhs: rhs.shape().to_vec(),
2556 })
2557 }
2558}
2559
2560pub(crate) fn typed_neg_with_pool<T>(
2561 buffers: &mut BufferPool,
2562 input: &TypedTensor<T>,
2563) -> crate::Result<TypedTensor<T>>
2564where
2565 T: Copy + Clone + Zero + Neg<Output = T> + PoolScalar,
2566{
2567 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2569 map_into(&mut out.view_mut(), &typed_view("neg", input)?, |x| -x)
2570 .map_err(|err| crate::Error::backend_failure("neg", err))?;
2571 Ok(tensor_from_array(out))
2572}
2573
2574pub(crate) fn typed_conj_with_pool<T>(
2575 buffers: &mut BufferPool,
2576 input: &TypedTensor<T>,
2577) -> crate::Result<TypedTensor<T>>
2578where
2579 T: Copy + Clone + Zero + ConjElem + PoolScalar,
2580{
2581 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2583 map_into(&mut out.view_mut(), &typed_view("conj", input)?, |x| {
2584 x.conj_elem()
2585 })
2586 .map_err(|err| crate::Error::backend_failure("conj", err))?;
2587 Ok(tensor_from_array(out))
2588}
2589
2590pub(crate) fn typed_abs_with_pool<T>(
2591 buffers: &mut BufferPool,
2592 input: &TypedTensor<T>,
2593) -> crate::Result<TypedTensor<T>>
2594where
2595 T: Tier2Elem + PoolScalar,
2596{
2597 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2599 map_into(&mut out.view_mut(), &typed_view("abs", input)?, |x| {
2600 x.abs_elem()
2601 })
2602 .map_err(|err| crate::Error::backend_failure("abs", err))?;
2603 Ok(tensor_from_array(out))
2604}
2605
2606fn typed_complex_abs_with_pool<T>(
2607 buffers: &mut BufferPool,
2608 input: &TypedTensor<Complex<T>>,
2609) -> crate::Result<TypedTensor<T>>
2610where
2611 T: num_traits::Float + PoolScalar,
2612{
2613 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2615 map_into(&mut out.view_mut(), &typed_view("abs", input)?, |x| {
2616 x.norm()
2617 })
2618 .map_err(|err| crate::Error::backend_failure("abs", err))?;
2619 Ok(tensor_from_array(out))
2620}
2621
2622fn typed_complex_abs_view_with_pool<T, R>(
2623 buffers: &mut BufferPool,
2624 input: &TypedTensorView<'_, Complex<T>, R>,
2625) -> crate::Result<TypedTensor<T>>
2626where
2627 T: num_traits::Float + PoolScalar + 'static,
2628 R: TensorRank,
2629{
2630 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2632 map_into(
2633 &mut out.view_mut(),
2634 &typed_view_from_view("abs", input)?,
2635 |x| x.norm(),
2636 )
2637 .map_err(|err| crate::Error::backend_failure("abs", err))?;
2638 Ok(tensor_from_array(out))
2639}
2640
2641pub(crate) fn typed_sign_with_pool<T>(
2642 buffers: &mut BufferPool,
2643 input: &TypedTensor<T>,
2644) -> crate::Result<TypedTensor<T>>
2645where
2646 T: Tier2Elem + PoolScalar,
2647{
2648 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2650 map_into(&mut out.view_mut(), &typed_view("sign", input)?, |x| {
2651 x.sign_elem()
2652 })
2653 .map_err(|err| crate::Error::backend_failure("sign", err))?;
2654 Ok(tensor_from_array(out))
2655}
2656
2657pub(crate) fn typed_maximum_with_pool<T>(
2658 buffers: &mut BufferPool,
2659 lhs: &TypedTensor<T>,
2660 rhs: &TypedTensor<T>,
2661) -> crate::Result<TypedTensor<T>>
2662where
2663 T: Tier2Elem + PoolScalar,
2664{
2665 if lhs.shape() != rhs.shape() {
2666 return Err(crate::Error::ShapeMismatch {
2667 op: "maximum",
2668 lhs: lhs.shape().to_vec(),
2669 rhs: rhs.shape().to_vec(),
2670 });
2671 }
2672 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2674 zip_map2_into(
2675 &mut out.view_mut(),
2676 &typed_view("maximum", lhs)?,
2677 &typed_view("maximum", rhs)?,
2678 |x, y| x.max_elem(y),
2679 )
2680 .map_err(|err| crate::Error::backend_failure("maximum", err))?;
2681 Ok(tensor_from_array(out))
2682}
2683
2684pub(crate) fn typed_minimum_with_pool<T>(
2685 buffers: &mut BufferPool,
2686 lhs: &TypedTensor<T>,
2687 rhs: &TypedTensor<T>,
2688) -> crate::Result<TypedTensor<T>>
2689where
2690 T: Tier2Elem + PoolScalar,
2691{
2692 if lhs.shape() != rhs.shape() {
2693 return Err(crate::Error::ShapeMismatch {
2694 op: "minimum",
2695 lhs: lhs.shape().to_vec(),
2696 rhs: rhs.shape().to_vec(),
2697 });
2698 }
2699 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2701 zip_map2_into(
2702 &mut out.view_mut(),
2703 &typed_view("minimum", lhs)?,
2704 &typed_view("minimum", rhs)?,
2705 |x, y| x.min_elem(y),
2706 )
2707 .map_err(|err| crate::Error::backend_failure("minimum", err))?;
2708 Ok(tensor_from_array(out))
2709}
2710
2711pub(crate) fn typed_compare_with_pool<T>(
2712 buffers: &mut BufferPool,
2713 lhs: &TypedTensor<T>,
2714 rhs: &TypedTensor<T>,
2715 dir: &CompareDir,
2716) -> crate::Result<TypedTensor<bool>>
2717where
2718 T: CompareElem,
2719{
2720 if lhs.shape() != rhs.shape() {
2721 return Err(crate::Error::ShapeMismatch {
2722 op: "compare",
2723 lhs: lhs.shape().to_vec(),
2724 rhs: rhs.shape().to_vec(),
2725 });
2726 }
2727 let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2729 zip_map2_into(
2730 &mut out.view_mut(),
2731 &typed_view("compare", lhs)?,
2732 &typed_view("compare", rhs)?,
2733 |x, y| x.compare_elem(y, dir),
2734 )
2735 .map_err(|err| crate::Error::backend_failure("compare", err))?;
2736 Ok(tensor_from_array(out))
2737}
2738
2739pub(crate) fn typed_select_with_pool<T>(
2740 buffers: &mut BufferPool,
2741 pred: &TypedTensor<bool>,
2742 on_true: &TypedTensor<T>,
2743 on_false: &TypedTensor<T>,
2744) -> crate::Result<TypedTensor<T>>
2745where
2746 T: Copy + PoolScalar,
2747{
2748 if pred.shape() != on_true.shape() {
2749 return Err(crate::Error::ShapeMismatch {
2750 op: "select",
2751 lhs: pred.shape().to_vec(),
2752 rhs: on_true.shape().to_vec(),
2753 });
2754 }
2755 if pred.shape() != on_false.shape() {
2756 return Err(crate::Error::ShapeMismatch {
2757 op: "select",
2758 lhs: pred.shape().to_vec(),
2759 rhs: on_false.shape().to_vec(),
2760 });
2761 }
2762 let mut out = unsafe { typed_array_uninit_from_pool(buffers, pred.shape()) };
2764 zip_map3_into(
2765 &mut out.view_mut(),
2766 &typed_view("select", pred)?,
2767 &typed_view("select", on_true)?,
2768 &typed_view("select", on_false)?,
2769 |p, t, f| if p { t } else { f },
2770 )
2771 .map_err(|err| crate::Error::backend_failure("select", err))?;
2772 Ok(tensor_from_array(out))
2773}
2774
2775pub(crate) fn typed_clamp_with_pool<T>(
2776 buffers: &mut BufferPool,
2777 input: &TypedTensor<T>,
2778 lower: &TypedTensor<T>,
2779 upper: &TypedTensor<T>,
2780) -> crate::Result<TypedTensor<T>>
2781where
2782 T: Tier2Elem + PoolScalar,
2783{
2784 if input.shape() != lower.shape() {
2785 return Err(crate::Error::ShapeMismatch {
2786 op: "clamp",
2787 lhs: input.shape().to_vec(),
2788 rhs: lower.shape().to_vec(),
2789 });
2790 }
2791 if input.shape() != upper.shape() {
2792 return Err(crate::Error::ShapeMismatch {
2793 op: "clamp",
2794 lhs: input.shape().to_vec(),
2795 rhs: upper.shape().to_vec(),
2796 });
2797 }
2798 let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2800 zip_map3_into(
2801 &mut out.view_mut(),
2802 &typed_view("clamp", input)?,
2803 &typed_view("clamp", lower)?,
2804 &typed_view("clamp", upper)?,
2805 |x, lo, hi| lo.max_elem(hi.min_elem(x)),
2806 )
2807 .map_err(|err| crate::Error::backend_failure("clamp", err))?;
2808 Ok(tensor_from_array(out))
2809}
2810
2811#[cfg(test)]
2812mod tests;