Skip to main content

tenferro_cpu/
elementwise.rs

1use std::ops::{Add, Div, Mul, Neg};
2use std::sync::Arc;
3
4use num_complex::Complex;
5use num_traits::{One, Zero};
6use strided_kernel::{
7    batched_outer_product_into, broadcast_mul_into, map_into, mul_into, zip_map2_into,
8    zip_map3_into,
9};
10
11use crate::buffer_pool::{BufferPool, PoolScalar};
12use crate::ConjElem;
13use tenferro_tensor::{
14    col_major_strides, CompareDir, DType, Tensor, TensorOwnedView, TensorRank, TensorRead,
15    TensorValue, TensorView, TypedTensor, TypedTensorView,
16};
17
18use super::{
19    tensor_from_array, typed_array_uninit_from_pool, typed_host_data, typed_view,
20    typed_view_from_view,
21};
22
/// Dispatches a ternary elementwise kernel over three tensors of one floating dtype.
///
/// Usage: `dispatch_ternary_result_with_pool!("op", a, b, c, |x, y, z| body)`.
/// The three operands are matched together. When all three are the same variant among `F32`,
/// `F64`, `C32` and `C64`, `body` runs with `x`, `y`, `z` bound to the typed payloads. Its
/// result (after `?`) is rewrapped in that variant. Any other combination, whether mixed
/// dtypes or a dtype outside that list, evaluates to
/// `Err(backend_failure(op, "dtype mismatch"))`.
macro_rules! dispatch_ternary_result_with_pool {
    // Internal rule: emit one same-dtype arm per listed variant, then the fallback arm.
    (@arms $op:literal, $operands:expr, ($x:ident, $y:ident, $z:ident), $body:expr, [$($variant:ident),+]) => {
        match $operands {
            $(
                (Tensor::$variant($x), Tensor::$variant($y), Tensor::$variant($z)) => {
                    Ok(Tensor::$variant($body?))
                }
            )+
            _ => Err(crate::Error::backend_failure($op, "dtype mismatch")),
        }
    };
    ($op:literal, $a:expr, $b:expr, $c:expr, |$x:ident, $y:ident, $z:ident| $body:expr) => {
        dispatch_ternary_result_with_pool!(
            @arms $op, ($a, $b, $c), ($x, $y, $z), $body, [F32, F64, C32, C64]
        )
    };
}
34
35fn dtype_pair_error(op: &'static str, lhs: DType, rhs: DType) -> crate::Error {
36    if lhs == rhs {
37        crate::Error::backend_failure(op, format!("unsupported dtype {lhs:?}"))
38    } else {
39        crate::Error::DTypeMismatch { op, lhs, rhs }
40    }
41}
42
43fn tensor_pair_error(op: &'static str, lhs: &Tensor, rhs: &Tensor) -> crate::Error {
44    dtype_pair_error(op, lhs.dtype(), rhs.dtype())
45}
46
47fn read_pair_error(op: &'static str, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Error {
48    dtype_pair_error(op, lhs.dtype(), rhs.dtype())
49}
50
51pub(crate) trait Tier2Elem: Copy + Clone + One + Zero + Send + Sync {
52    fn abs_elem(self) -> Self;
53    fn sign_elem(self) -> Self;
54    fn max_elem(self, other: Self) -> Self;
55    fn min_elem(self, other: Self) -> Self;
56}
57
58pub(crate) trait CompareElem: Copy + Send + Sync {
59    fn compare_elem(self, other: Self, dir: &CompareDir) -> bool;
60}
61
/// Implements `Tier2Elem` and `CompareElem` for a primitive float type (`f32`/`f64`).
///
/// `max_elem`/`min_elem` propagate NaN: a NaN in either input yields the type's `NAN`.
/// `sign_elem` maps both `+0.0` and `-0.0` to `+0.0`. Comparisons follow the IEEE-754 partial
/// order, so every comparison involving NaN is `false`.
macro_rules! impl_tier2_elem_real {
    ($ty:ty) => {
        impl Tier2Elem for $ty {
            fn abs_elem(self) -> Self {
                <$ty>::abs(self)
            }

            fn sign_elem(self) -> Self {
                let zero = Self::zero();
                // `-0.0 != 0.0` is false, so negative zero also lands in the `zero` branch.
                if self != zero {
                    self.signum()
                } else {
                    zero
                }
            }

            fn max_elem(self, other: Self) -> Self {
                if self.is_nan() || other.is_nan() {
                    return <$ty>::NAN;
                }
                // Ties (including +0.0 vs -0.0) keep `self`.
                if other > self {
                    other
                } else {
                    self
                }
            }

            fn min_elem(self, other: Self) -> Self {
                if self.is_nan() || other.is_nan() {
                    return <$ty>::NAN;
                }
                // Ties (including +0.0 vs -0.0) keep `self`.
                if other < self {
                    other
                } else {
                    self
                }
            }
        }

        impl CompareElem for $ty {
            fn compare_elem(self, other: Self, dir: &CompareDir) -> bool {
                use std::cmp::Ordering;
                // `None` (a NaN operand) satisfies no relation.
                let ordering = self.partial_cmp(&other);
                match dir {
                    CompareDir::Eq => ordering == Some(Ordering::Equal),
                    CompareDir::Lt => ordering == Some(Ordering::Less),
                    CompareDir::Le => matches!(ordering, Some(Ordering::Less | Ordering::Equal)),
                    CompareDir::Gt => ordering == Some(Ordering::Greater),
                    CompareDir::Ge => {
                        matches!(ordering, Some(Ordering::Greater | Ordering::Equal))
                    }
                }
            }
        }
    };
}
111
/// Implements `Tier2Elem` and `CompareElem` for `Complex<$real>`.
///
/// The ordering-based operations (`max_elem`, `min_elem` and the `Lt`/`Le`/`Gt`/`Ge`
/// comparisons) compare magnitudes through `norm_sqr`. `Eq` compares both components exactly.
macro_rules! impl_tier2_elem_complex {
    ($real:ty) => {
        impl Tier2Elem for Complex<$real> {
            /// Modulus `|z|` as a complex number with zero imaginary part.
            fn abs_elem(self) -> Self {
                Self::new(self.norm(), <$real>::zero())
            }

            /// Unit-modulus phase `z / |z|`, or zero for `z == 0`.
            fn sign_elem(self) -> Self {
                if self.is_zero() {
                    Self::zero()
                } else {
                    // Scale each component by the real modulus instead of dividing by the
                    // complex value `|z| + 0i`. num-complex's general complex division squares
                    // the divisor, which overflows to infinity for large `|z|` and underflows
                    // to zero for tiny `|z|`. Either way a well-defined sign would become NaN.
                    let magnitude = self.norm();
                    Self::new(self.re / magnitude, self.im / magnitude)
                }
            }

            /// Operand with the larger magnitude; NaN if either magnitude is NaN.
            fn max_elem(self, other: Self) -> Self {
                // NOTE(review): `norm_sqr` can overflow/underflow for extreme magnitudes; the
                // resulting ties are resolved in favour of `self`.
                let lhs_norm = self.norm_sqr();
                let rhs_norm = other.norm_sqr();
                if lhs_norm.is_nan() || rhs_norm.is_nan() {
                    Self::new(<$real>::NAN, <$real>::NAN)
                } else if lhs_norm >= rhs_norm {
                    self
                } else {
                    other
                }
            }

            /// Operand with the smaller magnitude; NaN if either magnitude is NaN.
            fn min_elem(self, other: Self) -> Self {
                let lhs_norm = self.norm_sqr();
                let rhs_norm = other.norm_sqr();
                if lhs_norm.is_nan() || rhs_norm.is_nan() {
                    Self::new(<$real>::NAN, <$real>::NAN)
                } else if lhs_norm <= rhs_norm {
                    self
                } else {
                    other
                }
            }
        }

        impl CompareElem for Complex<$real> {
            fn compare_elem(self, other: Self, dir: &CompareDir) -> bool {
                match dir {
                    // Equality is exact on both components; ordering uses squared magnitude.
                    CompareDir::Eq => self == other,
                    CompareDir::Lt => self.norm_sqr() < other.norm_sqr(),
                    CompareDir::Le => self.norm_sqr() <= other.norm_sqr(),
                    CompareDir::Gt => self.norm_sqr() > other.norm_sqr(),
                    CompareDir::Ge => self.norm_sqr() >= other.norm_sqr(),
                }
            }
        }
    };
}
165
166impl_tier2_elem_real!(f32);
167impl_tier2_elem_real!(f64);
168impl_tier2_elem_complex!(f32);
169impl_tier2_elem_complex!(f64);
170
171macro_rules! impl_compare_elem_ord {
172    ($ty:ty) => {
173        impl CompareElem for $ty {
174            fn compare_elem(self, other: Self, dir: &CompareDir) -> bool {
175                match dir {
176                    CompareDir::Eq => self == other,
177                    CompareDir::Lt => self < other,
178                    CompareDir::Le => self <= other,
179                    CompareDir::Gt => self > other,
180                    CompareDir::Ge => self >= other,
181                }
182            }
183        }
184    };
185}
186
187impl_compare_elem_ord!(i32);
188impl_compare_elem_ord!(i64);
189impl_compare_elem_ord!(bool);
190
191fn complex_scalar_tensor<T>(scalar: T) -> crate::Result<TypedTensor<Complex<T>>>
192where
193    T: Copy + Clone + Zero,
194{
195    TypedTensor::from_vec_col_major(vec![], vec![Complex::new(scalar, T::zero())])
196}
197
198fn complex_scalar_tensor_from_tensor<T>(
199    input: &TypedTensor<T>,
200) -> crate::Result<TypedTensor<Complex<T>>>
201where
202    T: Copy + Clone + Zero,
203{
204    complex_scalar_tensor(typed_host_data("add", input)?[0])
205}
206
207fn complex_scalar_tensor_from_view<T, R>(
208    input: &TypedTensorView<'_, T, R>,
209) -> crate::Result<TypedTensor<Complex<T>>>
210where
211    T: Copy + Clone + Zero + 'static,
212    R: TensorRank,
213{
214    complex_scalar_tensor(typed_view_from_view("add", input)?.get(&[]))
215}
216
217fn with_local_pool<T>(f: impl FnOnce(&mut BufferPool) -> T) -> T {
218    let mut buffers = BufferPool::new();
219    f(&mut buffers)
220}
221
222/// Add two CPU tensors elementwise.
223///
224/// # Examples
225///
226/// ```
227/// use tenferro_cpu::add;
228/// use tenferro_tensor::Tensor;
229///
230/// let a = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0])?;
231/// let b = Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0])?;
232/// let out = add(&a, &b)?;
233/// assert_eq!(out.as_slice::<f64>().unwrap(), &[4.0, 6.0]);
234/// # Ok::<(), tenferro_tensor::Error>(())
235/// ```
236pub fn add(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
237    with_local_pool(|buffers| add_with_pool(buffers, lhs, rhs))
238}
239
240pub(crate) fn add_with_pool(
241    buffers: &mut BufferPool,
242    lhs: &Tensor,
243    rhs: &Tensor,
244) -> crate::Result<Tensor> {
245    match (lhs, rhs) {
246        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_add_with_pool(buffers, a, b)?)),
247        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_add_with_pool(buffers, a, b)?)),
248        (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(typed_add_with_pool(buffers, a, b)?)),
249        (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(typed_add_with_pool(buffers, a, b)?)),
250        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_add_with_pool(buffers, a, b)?)),
251        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_add_with_pool(buffers, a, b)?)),
252        (Tensor::F32(a), Tensor::C32(b)) if a.shape().is_empty() => {
253            let scalar = complex_scalar_tensor(typed_host_data("add", a)?[0])?;
254            Ok(Tensor::C32(typed_add_with_pool(buffers, &scalar, b)?))
255        }
256        (Tensor::C32(a), Tensor::F32(b)) if b.shape().is_empty() => {
257            let scalar = complex_scalar_tensor(typed_host_data("add", b)?[0])?;
258            Ok(Tensor::C32(typed_add_with_pool(buffers, a, &scalar)?))
259        }
260        (Tensor::F64(a), Tensor::C64(b)) if a.shape().is_empty() => {
261            let scalar = complex_scalar_tensor(typed_host_data("add", a)?[0])?;
262            Ok(Tensor::C64(typed_add_with_pool(buffers, &scalar, b)?))
263        }
264        (Tensor::C64(a), Tensor::F64(b)) if b.shape().is_empty() => {
265            let scalar = complex_scalar_tensor(typed_host_data("add", b)?[0])?;
266            Ok(Tensor::C64(typed_add_with_pool(buffers, a, &scalar)?))
267        }
268        _ => Err(tensor_pair_error("add", lhs, rhs)),
269    }
270}
271
272pub(crate) fn add_read_with_pool(
273    buffers: &mut BufferPool,
274    lhs: TensorRead<'_>,
275    rhs: TensorRead<'_>,
276) -> crate::Result<Tensor> {
277    if let (TensorRead::Tensor(lhs), TensorRead::Tensor(rhs)) = (&lhs, &rhs) {
278        return add_with_pool(buffers, lhs, rhs);
279    }
280
281    macro_rules! dispatch {
282        ($variant:ident) => {
283            match (&lhs, &rhs) {
284                (
285                    TensorRead::Tensor(Tensor::$variant(a)),
286                    TensorRead::View(TensorView::$variant(b)),
287                ) => {
288                    let a = a.as_view();
289                    return Ok(Tensor::$variant(typed_add_view_with_pool(buffers, &a, b)?));
290                }
291                (
292                    TensorRead::View(TensorView::$variant(a)),
293                    TensorRead::Tensor(Tensor::$variant(b)),
294                ) => {
295                    let b = b.as_view();
296                    return Ok(Tensor::$variant(typed_add_view_with_pool(buffers, a, &b)?));
297                }
298                (
299                    TensorRead::View(TensorView::$variant(a)),
300                    TensorRead::View(TensorView::$variant(b)),
301                ) => {
302                    return Ok(Tensor::$variant(typed_add_view_with_pool(buffers, a, b)?));
303                }
304                _ => {}
305            }
306        };
307    }
308
309    macro_rules! dispatch_real_complex_scalar {
310        ($real_variant:ident, $complex_variant:ident) => {
311            match (&lhs, &rhs) {
312                (
313                    TensorRead::Tensor(Tensor::$real_variant(real)),
314                    TensorRead::View(TensorView::$complex_variant(complex)),
315                ) if real.shape().is_empty() => {
316                    let scalar = complex_scalar_tensor_from_tensor(real)?;
317                    let scalar = scalar.as_view();
318                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
319                        buffers, &scalar, complex,
320                    )?));
321                }
322                (
323                    TensorRead::View(TensorView::$real_variant(real)),
324                    TensorRead::Tensor(Tensor::$complex_variant(complex)),
325                ) if real.shape().is_empty() => {
326                    let scalar = complex_scalar_tensor_from_view(real)?;
327                    let scalar = scalar.as_view();
328                    let complex = complex.as_view();
329                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
330                        buffers, &scalar, &complex,
331                    )?));
332                }
333                (
334                    TensorRead::View(TensorView::$real_variant(real)),
335                    TensorRead::View(TensorView::$complex_variant(complex)),
336                ) if real.shape().is_empty() => {
337                    let scalar = complex_scalar_tensor_from_view(real)?;
338                    let scalar = scalar.as_view();
339                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
340                        buffers, &scalar, complex,
341                    )?));
342                }
343                (
344                    TensorRead::Tensor(Tensor::$complex_variant(complex)),
345                    TensorRead::View(TensorView::$real_variant(real)),
346                ) if real.shape().is_empty() => {
347                    let complex = complex.as_view();
348                    let scalar = complex_scalar_tensor_from_view(real)?;
349                    let scalar = scalar.as_view();
350                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
351                        buffers, &complex, &scalar,
352                    )?));
353                }
354                (
355                    TensorRead::View(TensorView::$complex_variant(complex)),
356                    TensorRead::Tensor(Tensor::$real_variant(real)),
357                ) if real.shape().is_empty() => {
358                    let scalar = complex_scalar_tensor_from_tensor(real)?;
359                    let scalar = scalar.as_view();
360                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
361                        buffers, complex, &scalar,
362                    )?));
363                }
364                (
365                    TensorRead::View(TensorView::$complex_variant(complex)),
366                    TensorRead::View(TensorView::$real_variant(real)),
367                ) if real.shape().is_empty() => {
368                    let scalar = complex_scalar_tensor_from_view(real)?;
369                    let scalar = scalar.as_view();
370                    return Ok(Tensor::$complex_variant(typed_add_view_with_pool(
371                        buffers, complex, &scalar,
372                    )?));
373                }
374                _ => {}
375            }
376        };
377    }
378
379    dispatch_real_complex_scalar!(F32, C32);
380    dispatch_real_complex_scalar!(F64, C64);
381
382    dispatch!(F32);
383    dispatch!(F64);
384    dispatch!(I32);
385    dispatch!(I64);
386    dispatch!(C32);
387    dispatch!(C64);
388
389    Err(read_pair_error("add", lhs, rhs))
390}
391
392/// Multiply two CPU tensors elementwise.
393///
394/// # Examples
395///
396/// ```
397/// use tenferro_cpu::mul;
398/// use tenferro_tensor::Tensor;
399///
400/// let a = Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 3.0])?;
401/// let b = Tensor::from_vec_col_major(vec![2], vec![4.0_f64, 5.0])?;
402/// let out = mul(&a, &b)?;
403/// assert_eq!(out.as_slice::<f64>().unwrap(), &[8.0, 15.0]);
404/// # Ok::<(), tenferro_tensor::Error>(())
405/// ```
406pub fn mul(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
407    with_local_pool(|buffers| mul_with_pool(buffers, lhs, rhs))
408}
409
410fn binary_read_with_pool(
411    op: &'static str,
412    buffers: &mut BufferPool,
413    lhs: TensorRead<'_>,
414    rhs: TensorRead<'_>,
415    f: impl FnOnce(&mut BufferPool, &Tensor, &Tensor) -> crate::Result<Tensor>,
416) -> crate::Result<Tensor> {
417    if let (Some(lhs), Some(rhs)) = (lhs.as_tensor(), rhs.as_tensor()) {
418        return f(buffers, lhs, rhs);
419    }
420
421    Err(read_pair_error(op, lhs, rhs))
422}
423
424pub(crate) fn mul_with_pool(
425    buffers: &mut BufferPool,
426    lhs: &Tensor,
427    rhs: &Tensor,
428) -> crate::Result<Tensor> {
429    match (lhs, rhs) {
430        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_mul_with_pool(buffers, a, b)?)),
431        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_mul_with_pool(buffers, a, b)?)),
432        (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(typed_mul_with_pool(buffers, a, b)?)),
433        (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(typed_mul_with_pool(buffers, a, b)?)),
434        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_mul_with_pool(buffers, a, b)?)),
435        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_mul_with_pool(buffers, a, b)?)),
436        (Tensor::F32(a), Tensor::C32(b)) if a.shape().is_empty() => {
437            let scalar = complex_scalar_tensor(typed_host_data("mul", a)?[0])?;
438            Ok(Tensor::C32(typed_mul_with_pool(buffers, &scalar, b)?))
439        }
440        (Tensor::C32(a), Tensor::F32(b)) if b.shape().is_empty() => {
441            let scalar = complex_scalar_tensor(typed_host_data("mul", b)?[0])?;
442            Ok(Tensor::C32(typed_mul_with_pool(buffers, a, &scalar)?))
443        }
444        (Tensor::F64(a), Tensor::C64(b)) if a.shape().is_empty() => {
445            let scalar = complex_scalar_tensor(typed_host_data("mul", a)?[0])?;
446            Ok(Tensor::C64(typed_mul_with_pool(buffers, &scalar, b)?))
447        }
448        (Tensor::C64(a), Tensor::F64(b)) if b.shape().is_empty() => {
449            let scalar = complex_scalar_tensor(typed_host_data("mul", b)?[0])?;
450            Ok(Tensor::C64(typed_mul_with_pool(buffers, a, &scalar)?))
451        }
452        _ => Err(tensor_pair_error("mul", lhs, rhs)),
453    }
454}
455
456pub(crate) fn mul_read_with_pool(
457    buffers: &mut BufferPool,
458    lhs: TensorRead<'_>,
459    rhs: TensorRead<'_>,
460) -> crate::Result<Tensor> {
461    if let (TensorRead::Tensor(lhs), TensorRead::Tensor(rhs)) = (&lhs, &rhs) {
462        return mul_with_pool(buffers, lhs, rhs);
463    }
464
465    macro_rules! dispatch {
466        ($variant:ident) => {
467            match (&lhs, &rhs) {
468                (
469                    TensorRead::Tensor(Tensor::$variant(a)),
470                    TensorRead::View(TensorView::$variant(b)),
471                ) => {
472                    let a = a.as_view();
473                    return Ok(Tensor::$variant(typed_mul_view_with_pool(buffers, &a, b)?));
474                }
475                (
476                    TensorRead::View(TensorView::$variant(a)),
477                    TensorRead::Tensor(Tensor::$variant(b)),
478                ) => {
479                    let b = b.as_view();
480                    return Ok(Tensor::$variant(typed_mul_view_with_pool(buffers, a, &b)?));
481                }
482                (
483                    TensorRead::View(TensorView::$variant(a)),
484                    TensorRead::View(TensorView::$variant(b)),
485                ) => {
486                    return Ok(Tensor::$variant(typed_mul_view_with_pool(buffers, a, b)?));
487                }
488                _ => {}
489            }
490        };
491    }
492
493    macro_rules! dispatch_real_complex_scalar {
494        ($real_variant:ident, $complex_variant:ident) => {
495            match (&lhs, &rhs) {
496                (
497                    TensorRead::Tensor(Tensor::$real_variant(real)),
498                    TensorRead::View(TensorView::$complex_variant(complex)),
499                ) if real.shape().is_empty() => {
500                    let scalar = complex_scalar_tensor_from_tensor(real)?;
501                    let scalar = scalar.as_view();
502                    return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
503                        buffers, &scalar, complex,
504                    )?));
505                }
506                (
507                    TensorRead::View(TensorView::$real_variant(real)),
508                    TensorRead::Tensor(Tensor::$complex_variant(complex)),
509                ) if real.shape().is_empty() => {
510                    let scalar = complex_scalar_tensor_from_view(real)?;
511                    let scalar = scalar.as_view();
512                    let complex = complex.as_view();
513                    return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
514                        buffers, &scalar, &complex,
515                    )?));
516                }
517                (
518                    TensorRead::View(TensorView::$real_variant(real)),
519                    TensorRead::View(TensorView::$complex_variant(complex)),
520                ) if real.shape().is_empty() => {
521                    let scalar = complex_scalar_tensor_from_view(real)?;
522                    let scalar = scalar.as_view();
523                    return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
524                        buffers, &scalar, complex,
525                    )?));
526                }
527                (
528                    TensorRead::Tensor(Tensor::$complex_variant(complex)),
529                    TensorRead::View(TensorView::$real_variant(real)),
530                ) if real.shape().is_empty() => {
531                    let complex = complex.as_view();
532                    let scalar = complex_scalar_tensor_from_view(real)?;
533                    let scalar = scalar.as_view();
534                    return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
535                        buffers, &complex, &scalar,
536                    )?));
537                }
538                (
539                    TensorRead::View(TensorView::$complex_variant(complex)),
540                    TensorRead::Tensor(Tensor::$real_variant(real)),
541                ) if real.shape().is_empty() => {
542                    let scalar = complex_scalar_tensor_from_tensor(real)?;
543                    let scalar = scalar.as_view();
544                    return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
545                        buffers, complex, &scalar,
546                    )?));
547                }
548                (
549                    TensorRead::View(TensorView::$complex_variant(complex)),
550                    TensorRead::View(TensorView::$real_variant(real)),
551                ) if real.shape().is_empty() => {
552                    let scalar = complex_scalar_tensor_from_view(real)?;
553                    let scalar = scalar.as_view();
554                    return Ok(Tensor::$complex_variant(typed_mul_view_with_pool(
555                        buffers, complex, &scalar,
556                    )?));
557                }
558                _ => {}
559            }
560        };
561    }
562
563    dispatch_real_complex_scalar!(F32, C32);
564    dispatch_real_complex_scalar!(F64, C64);
565
566    dispatch!(F32);
567    dispatch!(F64);
568    dispatch!(I32);
569    dispatch!(I64);
570    dispatch!(C32);
571    dispatch!(C64);
572
573    binary_read_with_pool("mul", buffers, lhs, rhs, mul_with_pool)
574}
575
576enum CpuReadView<'a> {
577    F32(TypedTensorView<'a, f32>),
578    F64(TypedTensorView<'a, f64>),
579    I32(TypedTensorView<'a, i32>),
580    I64(TypedTensorView<'a, i64>),
581    Bool(TypedTensorView<'a, bool>),
582    C32(TypedTensorView<'a, Complex<f32>>),
583    C64(TypedTensorView<'a, Complex<f64>>),
584}
585
586fn read_as_cpu_view(input: TensorRead<'_>) -> CpuReadView<'_> {
587    match input {
588        TensorRead::Tensor(Tensor::F32(tensor)) => CpuReadView::F32(tensor.as_view()),
589        TensorRead::Tensor(Tensor::F64(tensor)) => CpuReadView::F64(tensor.as_view()),
590        TensorRead::Tensor(Tensor::I32(tensor)) => CpuReadView::I32(tensor.as_view()),
591        TensorRead::Tensor(Tensor::I64(tensor)) => CpuReadView::I64(tensor.as_view()),
592        TensorRead::Tensor(Tensor::Bool(tensor)) => CpuReadView::Bool(tensor.as_view()),
593        TensorRead::Tensor(Tensor::C32(tensor)) => CpuReadView::C32(tensor.as_view()),
594        TensorRead::Tensor(Tensor::C64(tensor)) => CpuReadView::C64(tensor.as_view()),
595        TensorRead::View(TensorView::F32(view)) => CpuReadView::F32(view),
596        TensorRead::View(TensorView::F64(view)) => CpuReadView::F64(view),
597        TensorRead::View(TensorView::I32(view)) => CpuReadView::I32(view),
598        TensorRead::View(TensorView::I64(view)) => CpuReadView::I64(view),
599        TensorRead::View(TensorView::Bool(view)) => CpuReadView::Bool(view),
600        TensorRead::View(TensorView::C32(view)) => CpuReadView::C32(view),
601        TensorRead::View(TensorView::C64(view)) => CpuReadView::C64(view),
602    }
603}
604
605fn typed_binary_view_with_pool<T, L, R>(
606    op: &'static str,
607    buffers: &mut BufferPool,
608    lhs: &TypedTensorView<'_, T, L>,
609    rhs: &TypedTensorView<'_, T, R>,
610    f: impl Fn(T, T) -> T + Copy + Sync,
611) -> crate::Result<TypedTensor<T>>
612where
613    T: Copy + PoolScalar + 'static,
614    L: TensorRank,
615    R: TensorRank,
616{
617    if lhs.shape() == rhs.shape() {
618        // SAFETY: the following kernel overwrites every output element before any read.
619        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
620        zip_map2_into(
621            &mut out.view_mut(),
622            &typed_view_from_view(op, lhs)?,
623            &typed_view_from_view(op, rhs)?,
624            f,
625        )
626        .map_err(|err| crate::Error::backend_failure(op, err))?;
627        Ok(tensor_from_array(out))
628    } else if lhs.shape().is_empty() {
629        let scalar = typed_view_from_view(op, lhs)?.get(&[]);
630        // SAFETY: the following kernel overwrites every output element before any read.
631        let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
632        map_into(&mut out.view_mut(), &typed_view_from_view(op, rhs)?, |x| {
633            f(scalar, x)
634        })
635        .map_err(|err| crate::Error::backend_failure(op, err))?;
636        Ok(tensor_from_array(out))
637    } else if rhs.shape().is_empty() {
638        let scalar = typed_view_from_view(op, rhs)?.get(&[]);
639        // SAFETY: the following kernel overwrites every output element before any read.
640        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
641        map_into(&mut out.view_mut(), &typed_view_from_view(op, lhs)?, |x| {
642            f(x, scalar)
643        })
644        .map_err(|err| crate::Error::backend_failure(op, err))?;
645        Ok(tensor_from_array(out))
646    } else {
647        Err(crate::Error::ShapeMismatch {
648            op,
649            lhs: lhs.shape().to_vec(),
650            rhs: rhs.shape().to_vec(),
651        })
652    }
653}
654
655fn typed_unary_view_with_pool<T, R>(
656    op: &'static str,
657    buffers: &mut BufferPool,
658    input: &TypedTensorView<'_, T, R>,
659    f: impl Fn(T) -> T + Copy + Sync,
660) -> crate::Result<TypedTensor<T>>
661where
662    T: Copy + PoolScalar + 'static,
663    R: TensorRank,
664{
665    // SAFETY: the following kernel overwrites every output element before any read.
666    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
667    map_into(&mut out.view_mut(), &typed_view_from_view(op, input)?, f)
668        .map_err(|err| crate::Error::backend_failure(op, err))?;
669    Ok(tensor_from_array(out))
670}
671
672fn typed_same_shape_binary_view_with_pool<T, O, L, R>(
673    op: &'static str,
674    buffers: &mut BufferPool,
675    lhs: &TypedTensorView<'_, T, L>,
676    rhs: &TypedTensorView<'_, T, R>,
677    f: impl Fn(T, T) -> O + Copy + Sync,
678) -> crate::Result<TypedTensor<O>>
679where
680    T: Copy + Send + Sync + 'static,
681    O: Copy + PoolScalar,
682    L: TensorRank,
683    R: TensorRank,
684{
685    if lhs.shape() != rhs.shape() {
686        return Err(crate::Error::ShapeMismatch {
687            op,
688            lhs: lhs.shape().to_vec(),
689            rhs: rhs.shape().to_vec(),
690        });
691    }
692    // SAFETY: the following kernel overwrites every output element before any read.
693    let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
694    zip_map2_into(
695        &mut out.view_mut(),
696        &typed_view_from_view(op, lhs)?,
697        &typed_view_from_view(op, rhs)?,
698        f,
699    )
700    .map_err(|err| crate::Error::backend_failure(op, err))?;
701    Ok(tensor_from_array(out))
702}
703
704fn typed_select_view_with_pool<T, P, A, B>(
705    buffers: &mut BufferPool,
706    pred: &TypedTensorView<'_, bool, P>,
707    on_true: &TypedTensorView<'_, T, A>,
708    on_false: &TypedTensorView<'_, T, B>,
709) -> crate::Result<TypedTensor<T>>
710where
711    T: Copy + PoolScalar + 'static,
712    P: TensorRank,
713    A: TensorRank,
714    B: TensorRank,
715{
716    if pred.shape() != on_true.shape() {
717        return Err(crate::Error::ShapeMismatch {
718            op: "select",
719            lhs: pred.shape().to_vec(),
720            rhs: on_true.shape().to_vec(),
721        });
722    }
723    if pred.shape() != on_false.shape() {
724        return Err(crate::Error::ShapeMismatch {
725            op: "select",
726            lhs: pred.shape().to_vec(),
727            rhs: on_false.shape().to_vec(),
728        });
729    }
730    // SAFETY: the following kernel overwrites every output element before any read.
731    let mut out = unsafe { typed_array_uninit_from_pool(buffers, pred.shape()) };
732    zip_map3_into(
733        &mut out.view_mut(),
734        &typed_view_from_view("select", pred)?,
735        &typed_view_from_view("select", on_true)?,
736        &typed_view_from_view("select", on_false)?,
737        |p, t, f| if p { t } else { f },
738    )
739    .map_err(|err| crate::Error::backend_failure("select", err))?;
740    Ok(tensor_from_array(out))
741}
742
743fn typed_clamp_view_with_pool<T, I, L, U>(
744    buffers: &mut BufferPool,
745    input: &TypedTensorView<'_, T, I>,
746    lower: &TypedTensorView<'_, T, L>,
747    upper: &TypedTensorView<'_, T, U>,
748) -> crate::Result<TypedTensor<T>>
749where
750    T: Tier2Elem + PoolScalar + 'static,
751    I: TensorRank,
752    L: TensorRank,
753    U: TensorRank,
754{
755    if input.shape() != lower.shape() {
756        return Err(crate::Error::ShapeMismatch {
757            op: "clamp",
758            lhs: input.shape().to_vec(),
759            rhs: lower.shape().to_vec(),
760        });
761    }
762    if input.shape() != upper.shape() {
763        return Err(crate::Error::ShapeMismatch {
764            op: "clamp",
765            lhs: input.shape().to_vec(),
766            rhs: upper.shape().to_vec(),
767        });
768    }
769    // SAFETY: the following kernel overwrites every output element before any read.
770    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
771    zip_map3_into(
772        &mut out.view_mut(),
773        &typed_view_from_view("clamp", input)?,
774        &typed_view_from_view("clamp", lower)?,
775        &typed_view_from_view("clamp", upper)?,
776        |x, lo, hi| lo.max_elem(hi.min_elem(x)),
777    )
778    .map_err(|err| crate::Error::backend_failure("clamp", err))?;
779    Ok(tensor_from_array(out))
780}
781
/// Which operand's axes lead in a split outer-product layout.
///
/// NOTE(review): the meaning is inferred from the variant names (`LhsPrefix` / `RhsPrefix`);
/// confirm against the code that builds `SplitOuterProductPlan`.
#[derive(Copy, Clone)]
enum SplitOuterProductLayout {
    LhsPrefix,
    RhsPrefix,
}
787
/// Plan produced by `split_outer_product_plan` describing how a broadcast
/// multiply decomposes into a (batched) outer product.
struct SplitOuterProductPlan {
    // `rows`, `cols` and `batches` are filled in by the planner but are not read by
    // the consumers in this chunk, hence `dead_code`. NOTE(review): either use them
    // (e.g. for validation) or drop them.
    /// Element count of the free axes of the operand placed first in the output
    /// (lhs for `LhsPrefix`, rhs for `RhsPrefix`).
    #[allow(dead_code)]
    rows: usize,
    /// Element count of the free axes of the operand placed second in the output.
    #[allow(dead_code)]
    cols: usize,
    /// Product of the extents of the shared (batch) axes.
    #[allow(dead_code)]
    batches: usize,
    /// Which operand's free axes lead in the output.
    layout: SplitOuterProductLayout,
    /// lhs source axes that map to output axes not covered by rhs, in output-axis order.
    lhs_free_axes: Vec<usize>,
    /// rhs source axes that map to output axes not covered by lhs, in output-axis order.
    rhs_free_axes: Vec<usize>,
    /// lhs source axes of the shared output axes, in output-axis order.
    lhs_batch_axes: Vec<usize>,
    /// rhs source axes of the shared output axes, in output-axis order (paired with
    /// `lhs_batch_axes`).
    rhs_batch_axes: Vec<usize>,
}
801
/// Classification of output axes by which operand covers them, as built by
/// `classify_outer_product_axes`. Every list is in ascending output-axis order.
struct OuterProductAxisPartition {
    /// Output axes covered only by lhs.
    lhs_free_output_axes: Vec<usize>,
    /// Output axes covered only by rhs.
    rhs_free_output_axes: Vec<usize>,
    /// Output axes covered by both operands.
    batch_output_axes: Vec<usize>,
    /// lhs source axes corresponding to `lhs_free_output_axes`.
    lhs_free_axes: Vec<usize>,
    /// rhs source axes corresponding to `rhs_free_output_axes`.
    rhs_free_axes: Vec<usize>,
    /// lhs source axes corresponding to `batch_output_axes`.
    lhs_batch_axes: Vec<usize>,
    /// rhs source axes corresponding to `batch_output_axes`.
    rhs_batch_axes: Vec<usize>,
}
811
/// Returns `true` when `dims` maps each source axis to an in-range output axis
/// whose extent equals the source extent, and `dims` has one entry per source axis.
fn shape_matches_dims(source_shape: &[usize], output_shape: &[usize], dims: &[usize]) -> bool {
    if source_shape.len() != dims.len() {
        return false;
    }
    for (&extent, &output_axis) in source_shape.iter().zip(dims) {
        match output_shape.get(output_axis) {
            Some(&target) if target == extent => {}
            _ => return false,
        }
    }
    true
}
819
/// Inverts a source-axis -> output-axis mapping.
///
/// Slot `k` of the result holds the source axis that maps to output axis `k`,
/// or `None` if no source axis does. Returns `None` if any output axis is out
/// of range or targeted more than once.
fn axes_by_output(dims: &[usize], output_rank: usize) -> Option<Vec<Option<usize>>> {
    let mut inverse: Vec<Option<usize>> = vec![None; output_rank];
    for (source_axis, &output_axis) in dims.iter().enumerate() {
        if output_axis >= output_rank || inverse[output_axis].is_some() {
            return None;
        }
        inverse[output_axis] = Some(source_axis);
    }
    Some(inverse)
}
830
831fn axes_shape_product<T>(
832    op: &'static str,
833    view: &TypedTensorView<'_, T>,
834    axes: &[usize],
835) -> crate::Result<usize>
836where
837    T: 'static,
838{
839    axes.iter().try_fold(1usize, |acc, &axis| {
840        acc.checked_mul(view.shape()[axis])
841            .ok_or_else(|| crate::Error::backend_failure(op, "shape size overflows usize"))
842    })
843}
844
845fn classify_outer_product_axes(
846    lhs_dims: &[usize],
847    rhs_dims: &[usize],
848    output_rank: usize,
849) -> Option<OuterProductAxisPartition> {
850    let lhs_axes_by_output = axes_by_output(lhs_dims, output_rank)?;
851    let rhs_axes_by_output = axes_by_output(rhs_dims, output_rank)?;
852
853    let mut lhs_free_output_axes = Vec::new();
854    let mut rhs_free_output_axes = Vec::new();
855    let mut batch_output_axes = Vec::new();
856    let mut lhs_free_axes = Vec::new();
857    let mut rhs_free_axes = Vec::new();
858    let mut lhs_batch_axes = Vec::new();
859    let mut rhs_batch_axes = Vec::new();
860
861    for output_axis in 0..output_rank {
862        match (
863            lhs_axes_by_output[output_axis],
864            rhs_axes_by_output[output_axis],
865        ) {
866            (Some(lhs_axis), Some(rhs_axis)) => {
867                batch_output_axes.push(output_axis);
868                lhs_batch_axes.push(lhs_axis);
869                rhs_batch_axes.push(rhs_axis);
870            }
871            (Some(lhs_axis), None) => {
872                lhs_free_output_axes.push(output_axis);
873                lhs_free_axes.push(lhs_axis);
874            }
875            (None, Some(rhs_axis)) => {
876                rhs_free_output_axes.push(output_axis);
877                rhs_free_axes.push(rhs_axis);
878            }
879            (None, None) => return None,
880        }
881    }
882
883    Some(OuterProductAxisPartition {
884        lhs_free_output_axes,
885        rhs_free_output_axes,
886        batch_output_axes,
887        lhs_free_axes,
888        rhs_free_axes,
889        lhs_batch_axes,
890        rhs_batch_axes,
891    })
892}
893
/// Returns `true` when concatenating `groups` yields exactly `0, 1, ..., output_rank - 1`.
fn output_axes_match_partition(output_rank: usize, groups: &[&[usize]]) -> bool {
    let mut expected = 0usize;
    for group in groups {
        for &axis in group.iter() {
            if axis != expected {
                return false;
            }
            expected += 1;
        }
    }
    expected == output_rank
}
900
901fn split_outer_product_plan<T>(
902    lhs: &TypedTensorView<'_, T>,
903    lhs_shape: &[usize],
904    lhs_dims: &[usize],
905    rhs: &TypedTensorView<'_, T>,
906    rhs_shape: &[usize],
907    rhs_dims: &[usize],
908) -> crate::Result<Option<SplitOuterProductPlan>>
909where
910    T: 'static,
911{
912    let output_rank = lhs_shape.len();
913    if lhs_shape != rhs_shape
914        || !shape_matches_dims(lhs.shape(), lhs_shape, lhs_dims)
915        || !shape_matches_dims(rhs.shape(), rhs_shape, rhs_dims)
916        || lhs.backend_buffer().is_some()
917        || rhs.backend_buffer().is_some()
918        || lhs.offset() < 0
919        || rhs.offset() < 0
920        || lhs.strides().iter().any(|&stride| stride < 0)
921        || rhs.strides().iter().any(|&stride| stride < 0)
922    {
923        return Ok(None);
924    }
925
926    let Some(partition) = classify_outer_product_axes(lhs_dims, rhs_dims, output_rank) else {
927        return Ok(None);
928    };
929
930    let lhs_free_size = axes_shape_product("broadcast_multiply", lhs, &partition.lhs_free_axes)?;
931    let rhs_free_size = axes_shape_product("broadcast_multiply", rhs, &partition.rhs_free_axes)?;
932    if lhs_free_size <= 1 || rhs_free_size <= 1 {
933        return Ok(None);
934    }
935    let batches = axes_shape_product("broadcast_multiply", lhs, &partition.lhs_batch_axes)?;
936
937    let lhs_prefix = output_axes_match_partition(
938        output_rank,
939        &[
940            &partition.lhs_free_output_axes,
941            &partition.rhs_free_output_axes,
942            &partition.batch_output_axes,
943        ],
944    );
945    if lhs_prefix {
946        return Ok(Some(SplitOuterProductPlan {
947            rows: lhs_free_size,
948            cols: rhs_free_size,
949            batches,
950            layout: SplitOuterProductLayout::LhsPrefix,
951            lhs_free_axes: partition.lhs_free_axes,
952            rhs_free_axes: partition.rhs_free_axes,
953            lhs_batch_axes: partition.lhs_batch_axes,
954            rhs_batch_axes: partition.rhs_batch_axes,
955        }));
956    }
957
958    let rhs_prefix = output_axes_match_partition(
959        output_rank,
960        &[
961            &partition.rhs_free_output_axes,
962            &partition.lhs_free_output_axes,
963            &partition.batch_output_axes,
964        ],
965    );
966    if rhs_prefix {
967        return Ok(Some(SplitOuterProductPlan {
968            rows: rhs_free_size,
969            cols: lhs_free_size,
970            batches,
971            layout: SplitOuterProductLayout::RhsPrefix,
972            lhs_free_axes: partition.lhs_free_axes,
973            rhs_free_axes: partition.rhs_free_axes,
974            lhs_batch_axes: partition.lhs_batch_axes,
975            rhs_batch_axes: partition.rhs_batch_axes,
976        }));
977    }
978
979    Ok(None)
980}
981
982fn try_outer_product_with_pool<T>(
983    buffers: &mut BufferPool,
984    lhs: &TypedTensorView<'_, T>,
985    lhs_shape: &[usize],
986    lhs_dims: &[usize],
987    rhs: &TypedTensorView<'_, T>,
988    rhs_shape: &[usize],
989    rhs_dims: &[usize],
990) -> crate::Result<Option<TypedTensor<T>>>
991where
992    T: Copy + Clone + Mul<Output = T> + PoolScalar + 'static,
993{
994    let Some(plan) = split_outer_product_plan(lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)?
995    else {
996        return Ok(None);
997    };
998
999    // SAFETY: every element in the column-major output is assigned below.
1000    let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1001    let lhs_view = typed_view_from_view("broadcast_multiply", lhs)?;
1002    let rhs_view = typed_view_from_view("broadcast_multiply", rhs)?;
1003    match plan.layout {
1004        SplitOuterProductLayout::LhsPrefix => {
1005            let lhs_perm: Vec<_> = plan
1006                .lhs_free_axes
1007                .iter()
1008                .chain(plan.lhs_batch_axes.iter())
1009                .copied()
1010                .collect();
1011            let rhs_perm: Vec<_> = plan
1012                .rhs_free_axes
1013                .iter()
1014                .chain(plan.rhs_batch_axes.iter())
1015                .copied()
1016                .collect();
1017            let lhs_outer = lhs_view
1018                .permute(&lhs_perm)
1019                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1020            let rhs_outer = rhs_view
1021                .permute(&rhs_perm)
1022                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1023            batched_outer_product_into(
1024                &mut out.view_mut(),
1025                &lhs_outer,
1026                &rhs_outer,
1027                plan.lhs_free_axes.len(),
1028                plan.rhs_free_axes.len(),
1029            )
1030            .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1031        }
1032        SplitOuterProductLayout::RhsPrefix => {
1033            let lhs_perm: Vec<_> = plan
1034                .lhs_free_axes
1035                .iter()
1036                .chain(plan.lhs_batch_axes.iter())
1037                .copied()
1038                .collect();
1039            let rhs_perm: Vec<_> = plan
1040                .rhs_free_axes
1041                .iter()
1042                .chain(plan.rhs_batch_axes.iter())
1043                .copied()
1044                .collect();
1045            let lhs_outer = lhs_view
1046                .permute(&lhs_perm)
1047                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1048            let rhs_outer = rhs_view
1049                .permute(&rhs_perm)
1050                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1051            batched_outer_product_into(
1052                &mut out.view_mut(),
1053                &rhs_outer,
1054                &lhs_outer,
1055                plan.rhs_free_axes.len(),
1056                plan.lhs_free_axes.len(),
1057            )
1058            .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1059        }
1060    }
1061    Ok(Some(tensor_from_array(out)))
1062}
1063
/// Result of the lazy outer-product path: a physically ordered `base` tensor plus
/// the logical `shape`/`strides` that present it as the requested output.
struct LazyOuterProduct<T> {
    /// Materialized product in its physical (column-major base) layout.
    base: TypedTensor<T>,
    /// Logical output shape (the requested broadcast output shape).
    shape: Vec<usize>,
    /// Per-output-axis strides into `base`; consumed with offset 0 by
    /// `lazy_outer_product_value`.
    strides: Vec<isize>,
}
1069
1070fn axes_by_physical_stride<T>(view: &TypedTensorView<'_, T>, axes: &[usize]) -> Vec<usize>
1071where
1072    T: 'static,
1073{
1074    let mut sorted = axes.to_vec();
1075    sorted.sort_by(|&lhs_axis, &rhs_axis| {
1076        view.strides()[lhs_axis]
1077            .cmp(&view.strides()[rhs_axis])
1078            .then_with(|| lhs_axis.cmp(&rhs_axis))
1079    });
1080    sorted
1081}
1082
1083fn append_axis_shapes<T>(shape: &mut Vec<usize>, view: &TypedTensorView<'_, T>, axes: &[usize])
1084where
1085    T: 'static,
1086{
1087    shape.extend(axes.iter().map(|&axis| view.shape()[axis]));
1088}
1089
1090fn set_lazy_stride(
1091    logical_strides: &mut [Option<isize>],
1092    output_axis: usize,
1093    stride: isize,
1094) -> crate::Result<()> {
1095    let rank = logical_strides.len();
1096    let slot = logical_strides
1097        .get_mut(output_axis)
1098        .ok_or(crate::Error::AxisOutOfBounds {
1099            op: "broadcast_multiply",
1100            axis: output_axis,
1101            rank,
1102        })?;
1103    if slot.replace(stride).is_some() {
1104        return Err(crate::Error::DuplicateAxis {
1105            op: "broadcast_multiply",
1106            axis: output_axis,
1107            role: "lazy output layout",
1108        });
1109    }
1110    Ok(())
1111}
1112
/// Inputs for `lazy_outer_product_strides`.
///
/// "Leading" is the operand whose free axes occupy the first physical base
/// positions; "trailing" is the other one.
struct LazyOuterProductStrideSpec<'a> {
    /// Logical output shape; only its length (the output rank) is used.
    output_shape: &'a [usize],
    /// Physical base shape: `[leading free..., trailing free..., batch...]`.
    base_shape: &'a [usize],
    /// Leading operand's free source axes, in physical (base) order.
    leading_axes: &'a [usize],
    /// Leading operand's source-axis -> output-axis mapping.
    leading_dims: &'a [usize],
    /// Trailing operand's free source axes, in physical (base) order.
    trailing_axes: &'a [usize],
    /// Trailing operand's source-axis -> output-axis mapping.
    trailing_dims: &'a [usize],
    /// lhs source axes of the batch axes, in base order.
    lhs_batch_axes: &'a [usize],
    /// rhs source axes of the batch axes, paired with `lhs_batch_axes`.
    rhs_batch_axes: &'a [usize],
    /// lhs source-axis -> output-axis mapping.
    lhs_dims: &'a [usize],
    /// rhs source-axis -> output-axis mapping.
    rhs_dims: &'a [usize],
}
1125
1126fn lazy_outer_product_strides(spec: LazyOuterProductStrideSpec<'_>) -> crate::Result<Vec<isize>> {
1127    let base_strides = col_major_strides(spec.base_shape)?;
1128    let mut logical_strides = vec![None; spec.output_shape.len()];
1129    let mut base_axis = 0usize;
1130
1131    for &axis in spec.leading_axes {
1132        set_lazy_stride(
1133            &mut logical_strides,
1134            spec.leading_dims[axis],
1135            base_strides[base_axis],
1136        )?;
1137        base_axis += 1;
1138    }
1139    for &axis in spec.trailing_axes {
1140        set_lazy_stride(
1141            &mut logical_strides,
1142            spec.trailing_dims[axis],
1143            base_strides[base_axis],
1144        )?;
1145        base_axis += 1;
1146    }
1147    for (&lhs_axis, &rhs_axis) in spec.lhs_batch_axes.iter().zip(spec.rhs_batch_axes.iter()) {
1148        let output_axis = spec.lhs_dims[lhs_axis];
1149        if spec.rhs_dims[rhs_axis] != output_axis {
1150            return Err(crate::Error::backend_failure(
1151                "broadcast_multiply",
1152                "batch axes disagree while building lazy outer-product layout",
1153            ));
1154        }
1155        set_lazy_stride(&mut logical_strides, output_axis, base_strides[base_axis])?;
1156        base_axis += 1;
1157    }
1158
1159    logical_strides
1160        .into_iter()
1161        .collect::<Option<Vec<_>>>()
1162        .ok_or_else(|| {
1163            crate::Error::backend_failure(
1164                "broadcast_multiply",
1165                "lazy outer-product layout did not cover every output axis",
1166            )
1167        })
1168}
1169
1170fn lazy_outer_product_value(
1171    tensor: Tensor,
1172    shape: Vec<usize>,
1173    strides: Vec<isize>,
1174) -> crate::Result<TensorValue> {
1175    Ok(TensorValue::View(TensorOwnedView::from_parts(
1176        Arc::new(tensor),
1177        shape,
1178        strides,
1179        0,
1180    )?))
1181}
1182
1183fn try_lazy_outer_product_with_pool<T>(
1184    buffers: &mut BufferPool,
1185    lhs: &TypedTensorView<'_, T>,
1186    lhs_shape: &[usize],
1187    lhs_dims: &[usize],
1188    rhs: &TypedTensorView<'_, T>,
1189    rhs_shape: &[usize],
1190    rhs_dims: &[usize],
1191) -> crate::Result<Option<LazyOuterProduct<T>>>
1192where
1193    T: Copy + Clone + Mul<Output = T> + PoolScalar + 'static,
1194{
1195    let Some(plan) = split_outer_product_plan(lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)?
1196    else {
1197        return Ok(None);
1198    };
1199
1200    let lhs_free_axes = axes_by_physical_stride(lhs, &plan.lhs_free_axes);
1201    let rhs_free_axes = axes_by_physical_stride(rhs, &plan.rhs_free_axes);
1202    if lhs_free_axes == plan.lhs_free_axes && rhs_free_axes == plan.rhs_free_axes {
1203        return Ok(None);
1204    }
1205
1206    let lhs_view = typed_view_from_view("broadcast_multiply", lhs)?;
1207    let rhs_view = typed_view_from_view("broadcast_multiply", rhs)?;
1208
1209    match plan.layout {
1210        SplitOuterProductLayout::LhsPrefix => {
1211            let lhs_perm: Vec<_> = lhs_free_axes
1212                .iter()
1213                .chain(plan.lhs_batch_axes.iter())
1214                .copied()
1215                .collect();
1216            let rhs_perm: Vec<_> = rhs_free_axes
1217                .iter()
1218                .chain(plan.rhs_batch_axes.iter())
1219                .copied()
1220                .collect();
1221            let lhs_outer = lhs_view
1222                .permute(&lhs_perm)
1223                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1224            let rhs_outer = rhs_view
1225                .permute(&rhs_perm)
1226                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1227
1228            let mut base_shape = Vec::with_capacity(lhs_shape.len());
1229            append_axis_shapes(&mut base_shape, lhs, &lhs_free_axes);
1230            append_axis_shapes(&mut base_shape, rhs, &rhs_free_axes);
1231            append_axis_shapes(&mut base_shape, lhs, &plan.lhs_batch_axes);
1232            let strides = lazy_outer_product_strides(LazyOuterProductStrideSpec {
1233                output_shape: lhs_shape,
1234                base_shape: &base_shape,
1235                leading_axes: &lhs_free_axes,
1236                leading_dims: lhs_dims,
1237                trailing_axes: &rhs_free_axes,
1238                trailing_dims: rhs_dims,
1239                lhs_batch_axes: &plan.lhs_batch_axes,
1240                rhs_batch_axes: &plan.rhs_batch_axes,
1241                lhs_dims,
1242                rhs_dims,
1243            })?;
1244
1245            // SAFETY: every element in the physical base output is assigned below.
1246            let mut base = unsafe { typed_array_uninit_from_pool(buffers, &base_shape) };
1247            batched_outer_product_into(
1248                &mut base.view_mut(),
1249                &lhs_outer,
1250                &rhs_outer,
1251                lhs_free_axes.len(),
1252                rhs_free_axes.len(),
1253            )
1254            .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1255            Ok(Some(LazyOuterProduct {
1256                base: tensor_from_array(base),
1257                shape: lhs_shape.to_vec(),
1258                strides,
1259            }))
1260        }
1261        SplitOuterProductLayout::RhsPrefix => {
1262            let lhs_perm: Vec<_> = lhs_free_axes
1263                .iter()
1264                .chain(plan.lhs_batch_axes.iter())
1265                .copied()
1266                .collect();
1267            let rhs_perm: Vec<_> = rhs_free_axes
1268                .iter()
1269                .chain(plan.rhs_batch_axes.iter())
1270                .copied()
1271                .collect();
1272            let lhs_outer = lhs_view
1273                .permute(&lhs_perm)
1274                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1275            let rhs_outer = rhs_view
1276                .permute(&rhs_perm)
1277                .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1278
1279            let mut base_shape = Vec::with_capacity(lhs_shape.len());
1280            append_axis_shapes(&mut base_shape, rhs, &rhs_free_axes);
1281            append_axis_shapes(&mut base_shape, lhs, &lhs_free_axes);
1282            append_axis_shapes(&mut base_shape, lhs, &plan.lhs_batch_axes);
1283            let strides = lazy_outer_product_strides(LazyOuterProductStrideSpec {
1284                output_shape: lhs_shape,
1285                base_shape: &base_shape,
1286                leading_axes: &rhs_free_axes,
1287                leading_dims: rhs_dims,
1288                trailing_axes: &lhs_free_axes,
1289                trailing_dims: lhs_dims,
1290                lhs_batch_axes: &plan.lhs_batch_axes,
1291                rhs_batch_axes: &plan.rhs_batch_axes,
1292                lhs_dims,
1293                rhs_dims,
1294            })?;
1295
1296            // SAFETY: every element in the physical base output is assigned below.
1297            let mut base = unsafe { typed_array_uninit_from_pool(buffers, &base_shape) };
1298            batched_outer_product_into(
1299                &mut base.view_mut(),
1300                &rhs_outer,
1301                &lhs_outer,
1302                rhs_free_axes.len(),
1303                lhs_free_axes.len(),
1304            )
1305            .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1306            Ok(Some(LazyOuterProduct {
1307                base: tensor_from_array(base),
1308                shape: lhs_shape.to_vec(),
1309                strides,
1310            }))
1311        }
1312    }
1313}
1314
1315#[allow(clippy::too_many_arguments)]
1316fn typed_broadcast_mul_view_with_pool<T, L, R>(
1317    buffers: &mut BufferPool,
1318    lhs: &TypedTensorView<'_, T, L>,
1319    lhs_shape: &[usize],
1320    lhs_dims: &[usize],
1321    rhs: &TypedTensorView<'_, T, R>,
1322    rhs_shape: &[usize],
1323    rhs_dims: &[usize],
1324) -> crate::Result<TypedTensor<T>>
1325where
1326    T: Copy + Clone + Zero + Mul<Output = T> + PoolScalar + 'static,
1327    L: TensorRank,
1328    R: TensorRank,
1329{
1330    if lhs_shape != rhs_shape {
1331        return Err(crate::Error::ShapeMismatch {
1332            op: "broadcast_multiply",
1333            lhs: lhs_shape.to_vec(),
1334            rhs: rhs_shape.to_vec(),
1335        });
1336    }
1337    let output_rank = lhs_shape.len();
1338    let lhs_is_scalar = lhs.shape().is_empty() && lhs_dims.is_empty();
1339    let rhs_is_scalar = rhs.shape().is_empty() && rhs_dims.is_empty();
1340    let lhs_is_full_output =
1341        lhs.shape() == lhs_shape && lhs_dims.iter().copied().eq(0..output_rank);
1342    let rhs_is_full_output =
1343        rhs.shape() == rhs_shape && rhs_dims.iter().copied().eq(0..output_rank);
1344    if lhs_is_scalar && rhs_is_scalar {
1345        let lhs_scalar = typed_view_from_view("broadcast_multiply", lhs)?.get(&[]);
1346        let rhs_scalar = typed_view_from_view("broadcast_multiply", rhs)?.get(&[]);
1347        return filled_broadcast_multiply_tensor(buffers, lhs_shape, lhs_scalar * rhs_scalar);
1348    }
1349    if lhs_is_scalar && rhs_is_full_output {
1350        let scalar = typed_view_from_view("broadcast_multiply", lhs)?.get(&[]);
1351        // SAFETY: map_into overwrites every output element.
1352        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1353        map_into(
1354            &mut out.view_mut(),
1355            &typed_view_from_view("broadcast_multiply", rhs)?,
1356            |x| scalar * x,
1357        )
1358        .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1359        return Ok(tensor_from_array(out));
1360    }
1361    if rhs_is_scalar && lhs_is_full_output {
1362        let scalar = typed_view_from_view("broadcast_multiply", rhs)?.get(&[]);
1363        // SAFETY: map_into overwrites every output element.
1364        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1365        map_into(
1366            &mut out.view_mut(),
1367            &typed_view_from_view("broadcast_multiply", lhs)?,
1368            |x| x * scalar,
1369        )
1370        .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1371        return Ok(tensor_from_array(out));
1372    }
1373
1374    // SAFETY: broadcast_mul_into overwrites every output element.
1375    let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs_shape) };
1376    let lhs_view = typed_view_from_view("broadcast_multiply", lhs)?;
1377    let rhs_view = typed_view_from_view("broadcast_multiply", rhs)?;
1378    broadcast_mul_into(
1379        &mut out.view_mut(),
1380        &lhs_view,
1381        lhs_dims,
1382        &rhs_view,
1383        rhs_dims,
1384    )
1385    .map_err(|err| crate::Error::backend_failure("broadcast_multiply", err))?;
1386    Ok(tensor_from_array(out))
1387}
1388
1389fn filled_broadcast_multiply_tensor<T>(
1390    buffers: &mut BufferPool,
1391    shape: &[usize],
1392    fill: T,
1393) -> crate::Result<TypedTensor<T>>
1394where
1395    T: Copy + Clone + PoolScalar + 'static,
1396{
1397    let len = shape.iter().try_fold(1usize, |acc, &dim| {
1398        acc.checked_mul(dim).ok_or_else(|| {
1399            crate::Error::backend_failure("broadcast_multiply", "output shape size overflows usize")
1400        })
1401    })?;
1402    // SAFETY: every pooled element is initialized with `fill` before tensor construction.
1403    let mut data = unsafe { T::pool_acquire(buffers, len) };
1404    data.fill(fill);
1405    TypedTensor::from_vec_col_major(shape.to_vec(), data)
1406}
1407
1408#[allow(clippy::too_many_arguments)]
1409pub(crate) fn broadcast_multiply_read_with_pool(
1410    buffers: &mut BufferPool,
1411    lhs: TensorRead<'_>,
1412    lhs_shape: &[usize],
1413    lhs_dims: &[usize],
1414    rhs: TensorRead<'_>,
1415    rhs_shape: &[usize],
1416    rhs_dims: &[usize],
1417) -> crate::Result<Option<Tensor>> {
1418    let lhs = read_as_cpu_view(lhs);
1419    let rhs = read_as_cpu_view(rhs);
1420
1421    macro_rules! dispatch {
1422        ($variant:ident, $lhs:expr, $rhs:expr) => {{
1423            if let Some(out) = try_outer_product_with_pool(
1424                buffers, &$lhs, lhs_shape, lhs_dims, &$rhs, rhs_shape, rhs_dims,
1425            )? {
1426                return Ok(Some(Tensor::$variant(out)));
1427            }
1428            Ok(Some(Tensor::$variant(typed_broadcast_mul_view_with_pool(
1429                buffers, &$lhs, lhs_shape, lhs_dims, &$rhs, rhs_shape, rhs_dims,
1430            )?)))
1431        }};
1432    }
1433
1434    match (lhs, rhs) {
1435        (CpuReadView::F32(lhs), CpuReadView::F32(rhs)) => dispatch!(F32, lhs, rhs),
1436        (CpuReadView::F64(lhs), CpuReadView::F64(rhs)) => dispatch!(F64, lhs, rhs),
1437        (CpuReadView::I32(lhs), CpuReadView::I32(rhs)) => dispatch!(I32, lhs, rhs),
1438        (CpuReadView::I64(lhs), CpuReadView::I64(rhs)) => dispatch!(I64, lhs, rhs),
1439        (CpuReadView::C32(lhs), CpuReadView::C32(rhs)) => dispatch!(C32, lhs, rhs),
1440        (CpuReadView::C64(lhs), CpuReadView::C64(rhs)) => dispatch!(C64, lhs, rhs),
1441        _ => Ok(None),
1442    }
1443}
1444
1445#[allow(clippy::too_many_arguments)]
1446pub(crate) fn broadcast_multiply_value_with_pool(
1447    buffers: &mut BufferPool,
1448    lhs: TensorRead<'_>,
1449    lhs_shape: &[usize],
1450    lhs_dims: &[usize],
1451    rhs: TensorRead<'_>,
1452    rhs_shape: &[usize],
1453    rhs_dims: &[usize],
1454) -> crate::Result<Option<TensorValue>> {
1455    let lhs_view = read_as_cpu_view(lhs.clone());
1456    let rhs_view = read_as_cpu_view(rhs.clone());
1457
1458    macro_rules! dispatch_lazy {
1459        ($variant:ident, $lhs:expr, $rhs:expr) => {{
1460            if let Some(out) = try_lazy_outer_product_with_pool(
1461                buffers, &$lhs, lhs_shape, lhs_dims, &$rhs, rhs_shape, rhs_dims,
1462            )? {
1463                return Ok(Some(lazy_outer_product_value(
1464                    Tensor::$variant(out.base),
1465                    out.shape,
1466                    out.strides,
1467                )?));
1468            }
1469        }};
1470    }
1471
1472    match (lhs_view, rhs_view) {
1473        (CpuReadView::F32(lhs_view), CpuReadView::F32(rhs_view)) => {
1474            dispatch_lazy!(F32, lhs_view, rhs_view);
1475        }
1476        (CpuReadView::F64(lhs_view), CpuReadView::F64(rhs_view)) => {
1477            dispatch_lazy!(F64, lhs_view, rhs_view);
1478        }
1479        (CpuReadView::I32(lhs_view), CpuReadView::I32(rhs_view)) => {
1480            dispatch_lazy!(I32, lhs_view, rhs_view);
1481        }
1482        (CpuReadView::I64(lhs_view), CpuReadView::I64(rhs_view)) => {
1483            dispatch_lazy!(I64, lhs_view, rhs_view);
1484        }
1485        (CpuReadView::C32(lhs_view), CpuReadView::C32(rhs_view)) => {
1486            dispatch_lazy!(C32, lhs_view, rhs_view);
1487        }
1488        (CpuReadView::C64(lhs_view), CpuReadView::C64(rhs_view)) => {
1489            dispatch_lazy!(C64, lhs_view, rhs_view);
1490        }
1491        _ => {}
1492    }
1493
1494    broadcast_multiply_read_with_pool(buffers, lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)
1495        .map(|tensor| tensor.map(TensorValue::from_tensor))
1496}
1497
1498/// Divide two CPU tensors elementwise.
1499///
1500/// # Examples
1501///
1502/// ```
1503/// use tenferro_cpu::div;
1504/// use tenferro_tensor::Tensor;
1505///
1506/// let a = Tensor::from_vec_col_major(vec![2], vec![8.0_f64, 15.0])?;
1507/// let b = Tensor::from_vec_col_major(vec![2], vec![2.0_f64, 5.0])?;
1508/// let out = div(&a, &b)?;
1509/// assert_eq!(out.as_slice::<f64>().unwrap(), &[4.0, 3.0]);
1510/// # Ok::<(), tenferro_tensor::Error>(())
1511/// ```
1512pub fn div(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
1513    with_local_pool(|buffers| div_with_pool(buffers, lhs, rhs))
1514}
1515
1516pub(crate) fn div_with_pool(
1517    buffers: &mut BufferPool,
1518    lhs: &Tensor,
1519    rhs: &Tensor,
1520) -> crate::Result<Tensor> {
1521    match (lhs, rhs) {
1522        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_div_with_pool(buffers, a, b)?)),
1523        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_div_with_pool(buffers, a, b)?)),
1524        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_div_with_pool(buffers, a, b)?)),
1525        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_div_with_pool(buffers, a, b)?)),
1526        (Tensor::F32(a), Tensor::C32(b)) if a.shape().is_empty() => {
1527            let scalar = complex_scalar_tensor(typed_host_data("div", a)?[0])?;
1528            Ok(Tensor::C32(typed_div_with_pool(buffers, &scalar, b)?))
1529        }
1530        (Tensor::C32(a), Tensor::F32(b)) if b.shape().is_empty() => {
1531            let scalar = complex_scalar_tensor(typed_host_data("div", b)?[0])?;
1532            Ok(Tensor::C32(typed_div_with_pool(buffers, a, &scalar)?))
1533        }
1534        (Tensor::F64(a), Tensor::C64(b)) if a.shape().is_empty() => {
1535            let scalar = complex_scalar_tensor(typed_host_data("div", a)?[0])?;
1536            Ok(Tensor::C64(typed_div_with_pool(buffers, &scalar, b)?))
1537        }
1538        (Tensor::C64(a), Tensor::F64(b)) if b.shape().is_empty() => {
1539            let scalar = complex_scalar_tensor(typed_host_data("div", b)?[0])?;
1540            Ok(Tensor::C64(typed_div_with_pool(buffers, a, &scalar)?))
1541        }
1542        _ => Err(crate::Error::DTypeMismatch {
1543            op: "div",
1544            lhs: lhs.dtype(),
1545            rhs: rhs.dtype(),
1546        }),
1547    }
1548}
1549
1550pub(crate) fn div_read_with_pool(
1551    buffers: &mut BufferPool,
1552    lhs: TensorRead<'_>,
1553    rhs: TensorRead<'_>,
1554) -> crate::Result<Tensor> {
1555    let lhs_dtype = lhs.dtype();
1556    let rhs_dtype = rhs.dtype();
1557    match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
1558        (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::F32(typed_binary_view_with_pool(
1559            "div",
1560            buffers,
1561            &a,
1562            &b,
1563            |x, y| x / y,
1564        )?)),
1565        (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::F64(typed_binary_view_with_pool(
1566            "div",
1567            buffers,
1568            &a,
1569            &b,
1570            |x, y| x / y,
1571        )?)),
1572        (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::C32(typed_binary_view_with_pool(
1573            "div",
1574            buffers,
1575            &a,
1576            &b,
1577            |x, y| x / y,
1578        )?)),
1579        (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::C64(typed_binary_view_with_pool(
1580            "div",
1581            buffers,
1582            &a,
1583            &b,
1584            |x, y| x / y,
1585        )?)),
1586        (CpuReadView::F32(real), CpuReadView::C32(complex)) if real.shape().is_empty() => {
1587            let scalar = complex_scalar_tensor_from_view(&real)?;
1588            let scalar = scalar.as_view();
1589            Ok(Tensor::C32(typed_binary_view_with_pool(
1590                "div",
1591                buffers,
1592                &scalar,
1593                &complex,
1594                |x, y| x / y,
1595            )?))
1596        }
1597        (CpuReadView::C32(complex), CpuReadView::F32(real)) if real.shape().is_empty() => {
1598            let scalar = complex_scalar_tensor_from_view(&real)?;
1599            let scalar = scalar.as_view();
1600            Ok(Tensor::C32(typed_binary_view_with_pool(
1601                "div",
1602                buffers,
1603                &complex,
1604                &scalar,
1605                |x, y| x / y,
1606            )?))
1607        }
1608        (CpuReadView::F64(real), CpuReadView::C64(complex)) if real.shape().is_empty() => {
1609            let scalar = complex_scalar_tensor_from_view(&real)?;
1610            let scalar = scalar.as_view();
1611            Ok(Tensor::C64(typed_binary_view_with_pool(
1612                "div",
1613                buffers,
1614                &scalar,
1615                &complex,
1616                |x, y| x / y,
1617            )?))
1618        }
1619        (CpuReadView::C64(complex), CpuReadView::F64(real)) if real.shape().is_empty() => {
1620            let scalar = complex_scalar_tensor_from_view(&real)?;
1621            let scalar = scalar.as_view();
1622            Ok(Tensor::C64(typed_binary_view_with_pool(
1623                "div",
1624                buffers,
1625                &complex,
1626                &scalar,
1627                |x, y| x / y,
1628            )?))
1629        }
1630        _ => Err(crate::Error::DTypeMismatch {
1631            op: "div",
1632            lhs: lhs_dtype,
1633            rhs: rhs_dtype,
1634        }),
1635    }
1636}
1637
1638/// Negate a CPU tensor elementwise.
1639///
1640/// # Examples
1641///
1642/// ```
1643/// use tenferro_cpu::neg;
1644/// use tenferro_tensor::Tensor;
1645///
1646/// let input = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, -2.0])?;
1647/// let out = neg(&input)?;
1648/// assert_eq!(out.as_slice::<f64>().unwrap(), &[-1.0, 2.0]);
1649/// # Ok::<(), tenferro_tensor::Error>(())
1650/// ```
1651pub fn neg(input: &Tensor) -> crate::Result<Tensor> {
1652    with_local_pool(|buffers| neg_with_pool(buffers, input))
1653}
1654
1655pub(crate) fn neg_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1656    match input {
1657        Tensor::F32(t) => Ok(Tensor::F32(typed_neg_with_pool(buffers, t)?)),
1658        Tensor::F64(t) => Ok(Tensor::F64(typed_neg_with_pool(buffers, t)?)),
1659        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1660            "neg",
1661            format!("unsupported dtype {:?}", input.dtype()),
1662        )),
1663        Tensor::C32(t) => Ok(Tensor::C32(typed_neg_with_pool(buffers, t)?)),
1664        Tensor::C64(t) => Ok(Tensor::C64(typed_neg_with_pool(buffers, t)?)),
1665    }
1666}
1667
1668pub(crate) fn neg_read_with_pool(
1669    buffers: &mut BufferPool,
1670    input: TensorRead<'_>,
1671) -> crate::Result<Tensor> {
1672    let dtype = input.dtype();
1673    match read_as_cpu_view(input) {
1674        CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1675            "neg",
1676            buffers,
1677            &t,
1678            |x| -x,
1679        )?)),
1680        CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1681            "neg",
1682            buffers,
1683            &t,
1684            |x| -x,
1685        )?)),
1686        CpuReadView::C32(t) => Ok(Tensor::C32(typed_unary_view_with_pool(
1687            "neg",
1688            buffers,
1689            &t,
1690            |x| -x,
1691        )?)),
1692        CpuReadView::C64(t) => Ok(Tensor::C64(typed_unary_view_with_pool(
1693            "neg",
1694            buffers,
1695            &t,
1696            |x| -x,
1697        )?)),
1698        _ => Err(crate::Error::backend_failure(
1699            "neg",
1700            format!("unsupported dtype {dtype:?}"),
1701        )),
1702    }
1703}
1704
1705/// Conjugate a real or complex CPU tensor elementwise.
1706///
1707/// # Examples
1708///
1709/// ```
1710/// use num_complex::Complex64;
1711/// use tenferro_cpu::conj;
1712/// use tenferro_tensor::Tensor;
1713///
1714/// let input = Tensor::from_vec_col_major(vec![1], vec![Complex64::new(1.0, 2.0)])?;
1715/// let out = conj(&input)?;
1716/// assert_eq!(out.as_slice::<Complex64>().unwrap(), &[Complex64::new(1.0, -2.0)]);
1717/// # Ok::<(), tenferro_tensor::Error>(())
1718/// ```
1719pub fn conj(input: &Tensor) -> crate::Result<Tensor> {
1720    with_local_pool(|buffers| conj_with_pool(buffers, input))
1721}
1722
1723pub(crate) fn conj_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1724    match input {
1725        Tensor::F32(t) => Ok(Tensor::F32(typed_conj_with_pool(buffers, t)?)),
1726        Tensor::F64(t) => Ok(Tensor::F64(typed_conj_with_pool(buffers, t)?)),
1727        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1728            "conj",
1729            format!("unsupported dtype {:?}", input.dtype()),
1730        )),
1731        Tensor::C32(t) => Ok(Tensor::C32(typed_conj_with_pool(buffers, t)?)),
1732        Tensor::C64(t) => Ok(Tensor::C64(typed_conj_with_pool(buffers, t)?)),
1733    }
1734}
1735
1736pub(crate) fn conj_read_with_pool(
1737    buffers: &mut BufferPool,
1738    input: TensorRead<'_>,
1739) -> crate::Result<Tensor> {
1740    let dtype = input.dtype();
1741    match read_as_cpu_view(input) {
1742        CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1743            "conj",
1744            buffers,
1745            &t,
1746            |x| x.conj_elem(),
1747        )?)),
1748        CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1749            "conj",
1750            buffers,
1751            &t,
1752            |x| x.conj_elem(),
1753        )?)),
1754        CpuReadView::C32(t) => Ok(Tensor::C32(typed_unary_view_with_pool(
1755            "conj",
1756            buffers,
1757            &t,
1758            |x| x.conj_elem(),
1759        )?)),
1760        CpuReadView::C64(t) => Ok(Tensor::C64(typed_unary_view_with_pool(
1761            "conj",
1762            buffers,
1763            &t,
1764            |x| x.conj_elem(),
1765        )?)),
1766        _ => Err(crate::Error::backend_failure(
1767            "conj",
1768            format!("unsupported dtype {dtype:?}"),
1769        )),
1770    }
1771}
1772
1773/// Compute elementwise absolute values.
1774///
1775/// Complex inputs return real magnitudes (`C32 -> F32`, `C64 -> F64`).
1776///
1777/// # Examples
1778///
1779/// ```
1780/// use tenferro_cpu::abs;
1781/// use tenferro_tensor::Tensor;
1782///
1783/// let input = Tensor::from_vec_col_major(vec![2], vec![-3.0_f64, 4.0])?;
1784/// let out = abs(&input)?;
1785/// assert_eq!(out.as_slice::<f64>().unwrap(), &[3.0, 4.0]);
1786/// # Ok::<(), tenferro_tensor::Error>(())
1787/// ```
1788pub fn abs(input: &Tensor) -> crate::Result<Tensor> {
1789    with_local_pool(|buffers| abs_with_pool(buffers, input))
1790}
1791
1792pub(crate) fn abs_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1793    match input {
1794        Tensor::F32(t) => Ok(Tensor::F32(typed_abs_with_pool(buffers, t)?)),
1795        Tensor::F64(t) => Ok(Tensor::F64(typed_abs_with_pool(buffers, t)?)),
1796        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1797            "abs",
1798            format!("unsupported dtype {:?}", input.dtype()),
1799        )),
1800        Tensor::C32(t) => Ok(Tensor::F32(typed_complex_abs_with_pool(buffers, t)?)),
1801        Tensor::C64(t) => Ok(Tensor::F64(typed_complex_abs_with_pool(buffers, t)?)),
1802    }
1803}
1804
1805pub(crate) fn abs_read_with_pool(
1806    buffers: &mut BufferPool,
1807    input: TensorRead<'_>,
1808) -> crate::Result<Tensor> {
1809    let dtype = input.dtype();
1810    match read_as_cpu_view(input) {
1811        CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1812            "abs",
1813            buffers,
1814            &t,
1815            |x| x.abs_elem(),
1816        )?)),
1817        CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1818            "abs",
1819            buffers,
1820            &t,
1821            |x| x.abs_elem(),
1822        )?)),
1823        CpuReadView::C32(t) => Ok(Tensor::F32(typed_complex_abs_view_with_pool(buffers, &t)?)),
1824        CpuReadView::C64(t) => Ok(Tensor::F64(typed_complex_abs_view_with_pool(buffers, &t)?)),
1825        _ => Err(crate::Error::backend_failure(
1826            "abs",
1827            format!("unsupported dtype {dtype:?}"),
1828        )),
1829    }
1830}
1831
1832/// Compute elementwise signs.
1833///
1834/// # Examples
1835///
1836/// ```
1837/// use tenferro_cpu::sign;
1838/// use tenferro_tensor::Tensor;
1839///
1840/// let input = Tensor::from_vec_col_major(vec![3], vec![-2.0_f64, 0.0, 3.0])?;
1841/// let out = sign(&input)?;
1842/// assert_eq!(out.as_slice::<f64>().unwrap(), &[-1.0, 0.0, 1.0]);
1843/// # Ok::<(), tenferro_tensor::Error>(())
1844/// ```
1845pub fn sign(input: &Tensor) -> crate::Result<Tensor> {
1846    with_local_pool(|buffers| sign_with_pool(buffers, input))
1847}
1848
1849pub(crate) fn sign_with_pool(buffers: &mut BufferPool, input: &Tensor) -> crate::Result<Tensor> {
1850    match input {
1851        Tensor::F32(t) => Ok(Tensor::F32(typed_sign_with_pool(buffers, t)?)),
1852        Tensor::F64(t) => Ok(Tensor::F64(typed_sign_with_pool(buffers, t)?)),
1853        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) => Err(crate::Error::backend_failure(
1854            "sign",
1855            format!("unsupported dtype {:?}", input.dtype()),
1856        )),
1857        Tensor::C32(t) => Ok(Tensor::C32(typed_sign_with_pool(buffers, t)?)),
1858        Tensor::C64(t) => Ok(Tensor::C64(typed_sign_with_pool(buffers, t)?)),
1859    }
1860}
1861
1862pub(crate) fn sign_read_with_pool(
1863    buffers: &mut BufferPool,
1864    input: TensorRead<'_>,
1865) -> crate::Result<Tensor> {
1866    let dtype = input.dtype();
1867    match read_as_cpu_view(input) {
1868        CpuReadView::F32(t) => Ok(Tensor::F32(typed_unary_view_with_pool(
1869            "sign",
1870            buffers,
1871            &t,
1872            |x| x.sign_elem(),
1873        )?)),
1874        CpuReadView::F64(t) => Ok(Tensor::F64(typed_unary_view_with_pool(
1875            "sign",
1876            buffers,
1877            &t,
1878            |x| x.sign_elem(),
1879        )?)),
1880        CpuReadView::C32(t) => Ok(Tensor::C32(typed_unary_view_with_pool(
1881            "sign",
1882            buffers,
1883            &t,
1884            |x| x.sign_elem(),
1885        )?)),
1886        CpuReadView::C64(t) => Ok(Tensor::C64(typed_unary_view_with_pool(
1887            "sign",
1888            buffers,
1889            &t,
1890            |x| x.sign_elem(),
1891        )?)),
1892        _ => Err(crate::Error::backend_failure(
1893            "sign",
1894            format!("unsupported dtype {dtype:?}"),
1895        )),
1896    }
1897}
1898
1899/// Compute elementwise maximum values.
1900///
1901/// # Examples
1902///
1903/// ```
1904/// use tenferro_cpu::maximum;
1905/// use tenferro_tensor::Tensor;
1906///
1907/// let a = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 5.0])?;
1908/// let b = Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0])?;
1909/// let out = maximum(&a, &b)?;
1910/// assert_eq!(out.as_slice::<f64>().unwrap(), &[3.0, 5.0]);
1911/// # Ok::<(), tenferro_tensor::Error>(())
1912/// ```
1913pub fn maximum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
1914    with_local_pool(|buffers| maximum_with_pool(buffers, lhs, rhs))
1915}
1916
1917pub(crate) fn maximum_with_pool(
1918    buffers: &mut BufferPool,
1919    lhs: &Tensor,
1920    rhs: &Tensor,
1921) -> crate::Result<Tensor> {
1922    match (lhs, rhs) {
1923        (Tensor::F32(a), Tensor::F32(b)) => {
1924            Ok(Tensor::F32(typed_maximum_with_pool(buffers, a, b)?))
1925        }
1926        (Tensor::F64(a), Tensor::F64(b)) => {
1927            Ok(Tensor::F64(typed_maximum_with_pool(buffers, a, b)?))
1928        }
1929        (Tensor::C32(a), Tensor::C32(b)) => {
1930            Ok(Tensor::C32(typed_maximum_with_pool(buffers, a, b)?))
1931        }
1932        (Tensor::C64(a), Tensor::C64(b)) => {
1933            Ok(Tensor::C64(typed_maximum_with_pool(buffers, a, b)?))
1934        }
1935        _ => Err(tensor_pair_error("maximum", lhs, rhs)),
1936    }
1937}
1938
1939pub(crate) fn maximum_read_with_pool(
1940    buffers: &mut BufferPool,
1941    lhs: TensorRead<'_>,
1942    rhs: TensorRead<'_>,
1943) -> crate::Result<Tensor> {
1944    let lhs_dtype = lhs.dtype();
1945    let rhs_dtype = rhs.dtype();
1946    match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
1947        (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::F32(
1948            typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1949                x.max_elem(y)
1950            })?,
1951        )),
1952        (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::F64(
1953            typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1954                x.max_elem(y)
1955            })?,
1956        )),
1957        (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::C32(
1958            typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1959                x.max_elem(y)
1960            })?,
1961        )),
1962        (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::C64(
1963            typed_same_shape_binary_view_with_pool("maximum", buffers, &a, &b, |x, y| {
1964                x.max_elem(y)
1965            })?,
1966        )),
1967        _ => Err(dtype_pair_error("maximum", lhs_dtype, rhs_dtype)),
1968    }
1969}
1970
1971/// Compute elementwise minimum values.
1972///
1973/// # Examples
1974///
1975/// ```
1976/// use tenferro_cpu::minimum;
1977/// use tenferro_tensor::Tensor;
1978///
1979/// let a = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 5.0])?;
1980/// let b = Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0])?;
1981/// let out = minimum(&a, &b)?;
1982/// assert_eq!(out.as_slice::<f64>().unwrap(), &[1.0, 4.0]);
1983/// # Ok::<(), tenferro_tensor::Error>(())
1984/// ```
1985pub fn minimum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
1986    with_local_pool(|buffers| minimum_with_pool(buffers, lhs, rhs))
1987}
1988
1989pub(crate) fn minimum_with_pool(
1990    buffers: &mut BufferPool,
1991    lhs: &Tensor,
1992    rhs: &Tensor,
1993) -> crate::Result<Tensor> {
1994    match (lhs, rhs) {
1995        (Tensor::F32(a), Tensor::F32(b)) => {
1996            Ok(Tensor::F32(typed_minimum_with_pool(buffers, a, b)?))
1997        }
1998        (Tensor::F64(a), Tensor::F64(b)) => {
1999            Ok(Tensor::F64(typed_minimum_with_pool(buffers, a, b)?))
2000        }
2001        (Tensor::C32(a), Tensor::C32(b)) => {
2002            Ok(Tensor::C32(typed_minimum_with_pool(buffers, a, b)?))
2003        }
2004        (Tensor::C64(a), Tensor::C64(b)) => {
2005            Ok(Tensor::C64(typed_minimum_with_pool(buffers, a, b)?))
2006        }
2007        _ => Err(tensor_pair_error("minimum", lhs, rhs)),
2008    }
2009}
2010
2011pub(crate) fn minimum_read_with_pool(
2012    buffers: &mut BufferPool,
2013    lhs: TensorRead<'_>,
2014    rhs: TensorRead<'_>,
2015) -> crate::Result<Tensor> {
2016    let lhs_dtype = lhs.dtype();
2017    let rhs_dtype = rhs.dtype();
2018    match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
2019        (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::F32(
2020            typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2021                x.min_elem(y)
2022            })?,
2023        )),
2024        (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::F64(
2025            typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2026                x.min_elem(y)
2027            })?,
2028        )),
2029        (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::C32(
2030            typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2031                x.min_elem(y)
2032            })?,
2033        )),
2034        (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::C64(
2035            typed_same_shape_binary_view_with_pool("minimum", buffers, &a, &b, |x, y| {
2036                x.min_elem(y)
2037            })?,
2038        )),
2039        _ => Err(dtype_pair_error("minimum", lhs_dtype, rhs_dtype)),
2040    }
2041}
2042
2043/// Compare two CPU tensors elementwise.
2044///
2045/// # Examples
2046///
2047/// ```
2048/// use tenferro_cpu::compare;
2049/// use tenferro_tensor::{CompareDir, Tensor};
2050///
2051/// let a = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 5.0])?;
2052/// let b = Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0])?;
2053/// let out = compare(&a, &b, &CompareDir::Gt)?;
2054/// assert_eq!(out.as_slice::<bool>().unwrap(), &[false, true]);
2055/// # Ok::<(), tenferro_tensor::Error>(())
2056/// ```
2057pub fn compare(lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor> {
2058    with_local_pool(|buffers| compare_with_pool(buffers, lhs, rhs, dir))
2059}
2060
2061pub(crate) fn compare_with_pool(
2062    buffers: &mut BufferPool,
2063    lhs: &Tensor,
2064    rhs: &Tensor,
2065    dir: &CompareDir,
2066) -> crate::Result<Tensor> {
2067    match (lhs, rhs) {
2068        (Tensor::F32(a), Tensor::F32(b)) => {
2069            Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2070        }
2071        (Tensor::F64(a), Tensor::F64(b)) => {
2072            Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2073        }
2074        (Tensor::I32(a), Tensor::I32(b)) => {
2075            Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2076        }
2077        (Tensor::I64(a), Tensor::I64(b)) => {
2078            Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2079        }
2080        (Tensor::Bool(a), Tensor::Bool(b)) => {
2081            Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2082        }
2083        (Tensor::C32(a), Tensor::C32(b)) => {
2084            Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2085        }
2086        (Tensor::C64(a), Tensor::C64(b)) => {
2087            Ok(Tensor::Bool(typed_compare_with_pool(buffers, a, b, dir)?))
2088        }
2089        _ => Err(crate::Error::DTypeMismatch {
2090            op: "compare",
2091            lhs: lhs.dtype(),
2092            rhs: rhs.dtype(),
2093        }),
2094    }
2095}
2096
2097pub(crate) fn compare_read_with_pool(
2098    buffers: &mut BufferPool,
2099    lhs: TensorRead<'_>,
2100    rhs: TensorRead<'_>,
2101    dir: &CompareDir,
2102) -> crate::Result<Tensor> {
2103    let lhs_dtype = lhs.dtype();
2104    let rhs_dtype = rhs.dtype();
2105    match (read_as_cpu_view(lhs), read_as_cpu_view(rhs)) {
2106        (CpuReadView::F32(a), CpuReadView::F32(b)) => Ok(Tensor::Bool(
2107            typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2108                x.compare_elem(y, dir)
2109            })?,
2110        )),
2111        (CpuReadView::F64(a), CpuReadView::F64(b)) => Ok(Tensor::Bool(
2112            typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2113                x.compare_elem(y, dir)
2114            })?,
2115        )),
2116        (CpuReadView::I32(a), CpuReadView::I32(b)) => Ok(Tensor::Bool(
2117            typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2118                x.compare_elem(y, dir)
2119            })?,
2120        )),
2121        (CpuReadView::I64(a), CpuReadView::I64(b)) => Ok(Tensor::Bool(
2122            typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2123                x.compare_elem(y, dir)
2124            })?,
2125        )),
2126        (CpuReadView::Bool(a), CpuReadView::Bool(b)) => Ok(Tensor::Bool(
2127            typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2128                x.compare_elem(y, dir)
2129            })?,
2130        )),
2131        (CpuReadView::C32(a), CpuReadView::C32(b)) => Ok(Tensor::Bool(
2132            typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2133                x.compare_elem(y, dir)
2134            })?,
2135        )),
2136        (CpuReadView::C64(a), CpuReadView::C64(b)) => Ok(Tensor::Bool(
2137            typed_same_shape_binary_view_with_pool("compare", buffers, &a, &b, |x, y| {
2138                x.compare_elem(y, dir)
2139            })?,
2140        )),
2141        _ => Err(crate::Error::DTypeMismatch {
2142            op: "compare",
2143            lhs: lhs_dtype,
2144            rhs: rhs_dtype,
2145        }),
2146    }
2147}
2148
2149/// Select values from two tensors using a boolean predicate tensor.
2150///
2151/// # Examples
2152///
2153/// ```
2154/// use tenferro_cpu::select;
2155/// use tenferro_tensor::Tensor;
2156///
2157/// let pred = Tensor::from_vec_col_major(vec![2], vec![true, false])?;
2158/// let on_true = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0])?;
2159/// let on_false = Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0])?;
2160/// let out = select(&pred, &on_true, &on_false)?;
2161/// assert_eq!(out.as_slice::<f64>().unwrap(), &[1.0, 4.0]);
2162/// # Ok::<(), tenferro_tensor::Error>(())
2163/// ```
2164pub fn select(pred: &Tensor, on_true: &Tensor, on_false: &Tensor) -> crate::Result<Tensor> {
2165    with_local_pool(|buffers| select_with_pool(buffers, pred, on_true, on_false))
2166}
2167
2168pub(crate) fn select_with_pool(
2169    buffers: &mut BufferPool,
2170    pred: &Tensor,
2171    on_true: &Tensor,
2172    on_false: &Tensor,
2173) -> crate::Result<Tensor> {
2174    match (pred, on_true, on_false) {
2175        (Tensor::Bool(p), Tensor::F32(t), Tensor::F32(f)) => {
2176            Ok(Tensor::F32(typed_select_with_pool(buffers, p, t, f)?))
2177        }
2178        (Tensor::Bool(p), Tensor::F64(t), Tensor::F64(f)) => {
2179            Ok(Tensor::F64(typed_select_with_pool(buffers, p, t, f)?))
2180        }
2181        (Tensor::Bool(p), Tensor::I32(t), Tensor::I32(f)) => {
2182            Ok(Tensor::I32(typed_select_with_pool(buffers, p, t, f)?))
2183        }
2184        (Tensor::Bool(p), Tensor::I64(t), Tensor::I64(f)) => {
2185            Ok(Tensor::I64(typed_select_with_pool(buffers, p, t, f)?))
2186        }
2187        (Tensor::Bool(p), Tensor::Bool(t), Tensor::Bool(f)) => {
2188            Ok(Tensor::Bool(typed_select_with_pool(buffers, p, t, f)?))
2189        }
2190        (Tensor::Bool(p), Tensor::C32(t), Tensor::C32(f)) => {
2191            Ok(Tensor::C32(typed_select_with_pool(buffers, p, t, f)?))
2192        }
2193        (Tensor::Bool(p), Tensor::C64(t), Tensor::C64(f)) => {
2194            Ok(Tensor::C64(typed_select_with_pool(buffers, p, t, f)?))
2195        }
2196        (Tensor::Bool(_), _, _) => Err(crate::Error::DTypeMismatch {
2197            op: "select",
2198            lhs: on_true.dtype(),
2199            rhs: on_false.dtype(),
2200        }),
2201        _ => Err(crate::Error::DTypeMismatch {
2202            op: "select",
2203            lhs: pred.dtype(),
2204            rhs: crate::DType::Bool,
2205        }),
2206    }
2207}
2208
2209pub(crate) fn select_read_with_pool(
2210    buffers: &mut BufferPool,
2211    pred: TensorRead<'_>,
2212    on_true: TensorRead<'_>,
2213    on_false: TensorRead<'_>,
2214) -> crate::Result<Tensor> {
2215    let pred_dtype = pred.dtype();
2216    let true_dtype = on_true.dtype();
2217    let false_dtype = on_false.dtype();
2218    match (
2219        read_as_cpu_view(pred),
2220        read_as_cpu_view(on_true),
2221        read_as_cpu_view(on_false),
2222    ) {
2223        (CpuReadView::Bool(p), CpuReadView::F32(t), CpuReadView::F32(f)) => Ok(Tensor::F32(
2224            typed_select_view_with_pool(buffers, &p, &t, &f)?,
2225        )),
2226        (CpuReadView::Bool(p), CpuReadView::F64(t), CpuReadView::F64(f)) => Ok(Tensor::F64(
2227            typed_select_view_with_pool(buffers, &p, &t, &f)?,
2228        )),
2229        (CpuReadView::Bool(p), CpuReadView::I32(t), CpuReadView::I32(f)) => Ok(Tensor::I32(
2230            typed_select_view_with_pool(buffers, &p, &t, &f)?,
2231        )),
2232        (CpuReadView::Bool(p), CpuReadView::I64(t), CpuReadView::I64(f)) => Ok(Tensor::I64(
2233            typed_select_view_with_pool(buffers, &p, &t, &f)?,
2234        )),
2235        (CpuReadView::Bool(p), CpuReadView::Bool(t), CpuReadView::Bool(f)) => Ok(Tensor::Bool(
2236            typed_select_view_with_pool(buffers, &p, &t, &f)?,
2237        )),
2238        (CpuReadView::Bool(p), CpuReadView::C32(t), CpuReadView::C32(f)) => Ok(Tensor::C32(
2239            typed_select_view_with_pool(buffers, &p, &t, &f)?,
2240        )),
2241        (CpuReadView::Bool(p), CpuReadView::C64(t), CpuReadView::C64(f)) => Ok(Tensor::C64(
2242            typed_select_view_with_pool(buffers, &p, &t, &f)?,
2243        )),
2244        (CpuReadView::Bool(_), _, _) => Err(crate::Error::DTypeMismatch {
2245            op: "select",
2246            lhs: true_dtype,
2247            rhs: false_dtype,
2248        }),
2249        _ => Err(crate::Error::DTypeMismatch {
2250            op: "select",
2251            lhs: pred_dtype,
2252            rhs: crate::DType::Bool,
2253        }),
2254    }
2255}
2256
2257/// Clamp CPU tensor values elementwise between lower and upper bounds.
2258///
2259/// # Examples
2260///
2261/// ```
2262/// use tenferro_cpu::clamp;
2263/// use tenferro_tensor::Tensor;
2264///
2265/// let input = Tensor::from_vec_col_major(vec![3], vec![-1.0_f64, 2.0, 8.0])?;
2266/// let lower = Tensor::from_vec_col_major(vec![3], vec![0.0_f64, 0.0, 0.0])?;
2267/// let upper = Tensor::from_vec_col_major(vec![3], vec![5.0_f64, 5.0, 5.0])?;
2268/// let out = clamp(&input, &lower, &upper)?;
2269/// assert_eq!(out.as_slice::<f64>().unwrap(), &[0.0, 2.0, 5.0]);
2270/// # Ok::<(), tenferro_tensor::Error>(())
2271/// ```
2272pub fn clamp(input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor> {
2273    with_local_pool(|buffers| clamp_with_pool(buffers, input, lower, upper))
2274}
2275
2276pub(crate) fn clamp_with_pool(
2277    buffers: &mut BufferPool,
2278    input: &Tensor,
2279    lower: &Tensor,
2280    upper: &Tensor,
2281) -> crate::Result<Tensor> {
2282    dispatch_ternary_result_with_pool!("clamp", input, lower, upper, |x, lo, hi| {
2283        typed_clamp_with_pool(buffers, x, lo, hi)
2284    })
2285}
2286
2287pub(crate) fn clamp_read_with_pool(
2288    buffers: &mut BufferPool,
2289    input: TensorRead<'_>,
2290    lower: TensorRead<'_>,
2291    upper: TensorRead<'_>,
2292) -> crate::Result<Tensor> {
2293    let input_dtype = input.dtype();
2294    let lower_dtype = lower.dtype();
2295    match (
2296        read_as_cpu_view(input),
2297        read_as_cpu_view(lower),
2298        read_as_cpu_view(upper),
2299    ) {
2300        (CpuReadView::F32(input), CpuReadView::F32(lower), CpuReadView::F32(upper)) => Ok(
2301            Tensor::F32(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2302        ),
2303        (CpuReadView::F64(input), CpuReadView::F64(lower), CpuReadView::F64(upper)) => Ok(
2304            Tensor::F64(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2305        ),
2306        (CpuReadView::C32(input), CpuReadView::C32(lower), CpuReadView::C32(upper)) => Ok(
2307            Tensor::C32(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2308        ),
2309        (CpuReadView::C64(input), CpuReadView::C64(lower), CpuReadView::C64(upper)) => Ok(
2310            Tensor::C64(typed_clamp_view_with_pool(buffers, &input, &lower, &upper)?),
2311        ),
2312        _ => Err(crate::Error::DTypeMismatch {
2313            op: "clamp",
2314            lhs: input_dtype,
2315            rhs: lower_dtype,
2316        }),
2317    }
2318}
2319
2320pub(crate) fn typed_add_with_pool<T>(
2321    buffers: &mut BufferPool,
2322    lhs: &TypedTensor<T>,
2323    rhs: &TypedTensor<T>,
2324) -> crate::Result<TypedTensor<T>>
2325where
2326    T: Copy + Clone + Zero + Add<Output = T> + PoolScalar,
2327{
2328    if lhs.shape() == rhs.shape() {
2329        // SAFETY: zip_map2_into overwrites every output element.
2330        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2331        zip_map2_into(
2332            &mut out.view_mut(),
2333            &typed_view("add", lhs)?,
2334            &typed_view("add", rhs)?,
2335            |x, y| x + y,
2336        )
2337        .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2338        Ok(tensor_from_array(out))
2339    } else if lhs.shape().is_empty() {
2340        let scalar = typed_host_data("add", lhs)?[0];
2341        // SAFETY: map_into overwrites every output element.
2342        let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2343        map_into(&mut out.view_mut(), &typed_view("add", rhs)?, |x| {
2344            scalar + x
2345        })
2346        .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2347        Ok(tensor_from_array(out))
2348    } else if rhs.shape().is_empty() {
2349        let scalar = typed_host_data("add", rhs)?[0];
2350        // SAFETY: map_into overwrites every output element.
2351        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2352        map_into(&mut out.view_mut(), &typed_view("add", lhs)?, |x| {
2353            x + scalar
2354        })
2355        .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2356        Ok(tensor_from_array(out))
2357    } else {
2358        Err(crate::Error::ShapeMismatch {
2359            op: "add",
2360            lhs: lhs.shape().to_vec(),
2361            rhs: rhs.shape().to_vec(),
2362        })
2363    }
2364}
2365
2366pub(crate) fn typed_add_view_with_pool<T, L, R>(
2367    buffers: &mut BufferPool,
2368    lhs: &TypedTensorView<'_, T, L>,
2369    rhs: &TypedTensorView<'_, T, R>,
2370) -> crate::Result<TypedTensor<T>>
2371where
2372    T: Copy + Clone + Zero + Add<Output = T> + PoolScalar + 'static,
2373    L: TensorRank,
2374    R: TensorRank,
2375{
2376    if lhs.shape() == rhs.shape() {
2377        // SAFETY: zip_map2_into overwrites every output element.
2378        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2379        zip_map2_into(
2380            &mut out.view_mut(),
2381            &typed_view_from_view("add", lhs)?,
2382            &typed_view_from_view("add", rhs)?,
2383            |x, y| x + y,
2384        )
2385        .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2386        Ok(tensor_from_array(out))
2387    } else if lhs.shape().is_empty() {
2388        let scalar = typed_view_from_view("add", lhs)?.get(&[]);
2389        // SAFETY: map_into overwrites every output element.
2390        let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2391        map_into(
2392            &mut out.view_mut(),
2393            &typed_view_from_view("add", rhs)?,
2394            |x| scalar + x,
2395        )
2396        .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2397        Ok(tensor_from_array(out))
2398    } else if rhs.shape().is_empty() {
2399        let scalar = typed_view_from_view("add", rhs)?.get(&[]);
2400        // SAFETY: map_into overwrites every output element.
2401        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2402        map_into(
2403            &mut out.view_mut(),
2404            &typed_view_from_view("add", lhs)?,
2405            |x| x + scalar,
2406        )
2407        .map_err(|err| crate::Error::backend_failure("add", err.to_string()))?;
2408        Ok(tensor_from_array(out))
2409    } else {
2410        Err(crate::Error::ShapeMismatch {
2411            op: "add",
2412            lhs: lhs.shape().to_vec(),
2413            rhs: rhs.shape().to_vec(),
2414        })
2415    }
2416}
2417
2418pub(crate) fn typed_mul_with_pool<T>(
2419    buffers: &mut BufferPool,
2420    lhs: &TypedTensor<T>,
2421    rhs: &TypedTensor<T>,
2422) -> crate::Result<TypedTensor<T>>
2423where
2424    T: Copy + Clone + Zero + Mul<Output = T> + PoolScalar + 'static,
2425{
2426    if lhs.shape() == rhs.shape() {
2427        // SAFETY: mul_into overwrites every output element.
2428        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2429        mul_into(
2430            &mut out.view_mut(),
2431            &typed_view("mul", lhs)?,
2432            &typed_view("mul", rhs)?,
2433        )
2434        .map_err(|err| crate::Error::backend_failure("mul", err))?;
2435        Ok(tensor_from_array(out))
2436    } else if lhs.shape().is_empty() {
2437        let scalar = typed_host_data("mul", lhs)?[0];
2438        // SAFETY: map_into overwrites every output element.
2439        let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2440        map_into(&mut out.view_mut(), &typed_view("mul", rhs)?, |x| {
2441            scalar * x
2442        })
2443        .map_err(|err| crate::Error::backend_failure("mul", err))?;
2444        Ok(tensor_from_array(out))
2445    } else if rhs.shape().is_empty() {
2446        let scalar = typed_host_data("mul", rhs)?[0];
2447        // SAFETY: map_into overwrites every output element.
2448        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2449        map_into(&mut out.view_mut(), &typed_view("mul", lhs)?, |x| {
2450            x * scalar
2451        })
2452        .map_err(|err| crate::Error::backend_failure("mul", err))?;
2453        Ok(tensor_from_array(out))
2454    } else {
2455        Err(crate::Error::ShapeMismatch {
2456            op: "mul",
2457            lhs: lhs.shape().to_vec(),
2458            rhs: rhs.shape().to_vec(),
2459        })
2460    }
2461}
2462
2463pub(crate) fn typed_mul_view_with_pool<T, L, R>(
2464    buffers: &mut BufferPool,
2465    lhs: &TypedTensorView<'_, T, L>,
2466    rhs: &TypedTensorView<'_, T, R>,
2467) -> crate::Result<TypedTensor<T>>
2468where
2469    T: Copy + Clone + Zero + Mul<Output = T> + PoolScalar + 'static,
2470    L: TensorRank,
2471    R: TensorRank,
2472{
2473    if lhs.shape() == rhs.shape() {
2474        // SAFETY: mul_into overwrites every output element.
2475        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2476        mul_into(
2477            &mut out.view_mut(),
2478            &typed_view_from_view("mul", lhs)?,
2479            &typed_view_from_view("mul", rhs)?,
2480        )
2481        .map_err(|err| crate::Error::backend_failure("mul", err))?;
2482        Ok(tensor_from_array(out))
2483    } else if lhs.shape().is_empty() {
2484        let scalar = typed_view_from_view("mul", lhs)?.get(&[]);
2485        // SAFETY: map_into overwrites every output element.
2486        let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2487        map_into(
2488            &mut out.view_mut(),
2489            &typed_view_from_view("mul", rhs)?,
2490            |x| scalar * x,
2491        )
2492        .map_err(|err| crate::Error::backend_failure("mul", err))?;
2493        Ok(tensor_from_array(out))
2494    } else if rhs.shape().is_empty() {
2495        let scalar = typed_view_from_view("mul", rhs)?.get(&[]);
2496        // SAFETY: map_into overwrites every output element.
2497        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2498        map_into(
2499            &mut out.view_mut(),
2500            &typed_view_from_view("mul", lhs)?,
2501            |x| x * scalar,
2502        )
2503        .map_err(|err| crate::Error::backend_failure("mul", err))?;
2504        Ok(tensor_from_array(out))
2505    } else {
2506        Err(crate::Error::ShapeMismatch {
2507            op: "mul",
2508            lhs: lhs.shape().to_vec(),
2509            rhs: rhs.shape().to_vec(),
2510        })
2511    }
2512}
2513
2514pub(crate) fn typed_div_with_pool<T>(
2515    buffers: &mut BufferPool,
2516    lhs: &TypedTensor<T>,
2517    rhs: &TypedTensor<T>,
2518) -> crate::Result<TypedTensor<T>>
2519where
2520    T: Copy + Clone + Zero + Div<Output = T> + PoolScalar,
2521{
2522    if lhs.shape() == rhs.shape() {
2523        // SAFETY: zip_map2_into overwrites every output element.
2524        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2525        zip_map2_into(
2526            &mut out.view_mut(),
2527            &typed_view("div", lhs)?,
2528            &typed_view("div", rhs)?,
2529            |x, y| x / y,
2530        )
2531        .map_err(|err| crate::Error::backend_failure("div", err))?;
2532        Ok(tensor_from_array(out))
2533    } else if lhs.shape().is_empty() {
2534        let scalar = typed_host_data("div", lhs)?[0];
2535        // SAFETY: map_into overwrites every output element.
2536        let mut out = unsafe { typed_array_uninit_from_pool(buffers, rhs.shape()) };
2537        map_into(&mut out.view_mut(), &typed_view("div", rhs)?, |x| {
2538            scalar / x
2539        })
2540        .map_err(|err| crate::Error::backend_failure("div", err))?;
2541        Ok(tensor_from_array(out))
2542    } else if rhs.shape().is_empty() {
2543        let scalar = typed_host_data("div", rhs)?[0];
2544        // SAFETY: map_into overwrites every output element.
2545        let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2546        map_into(&mut out.view_mut(), &typed_view("div", lhs)?, |x| {
2547            x / scalar
2548        })
2549        .map_err(|err| crate::Error::backend_failure("div", err))?;
2550        Ok(tensor_from_array(out))
2551    } else {
2552        Err(crate::Error::ShapeMismatch {
2553            op: "div",
2554            lhs: lhs.shape().to_vec(),
2555            rhs: rhs.shape().to_vec(),
2556        })
2557    }
2558}
2559
2560pub(crate) fn typed_neg_with_pool<T>(
2561    buffers: &mut BufferPool,
2562    input: &TypedTensor<T>,
2563) -> crate::Result<TypedTensor<T>>
2564where
2565    T: Copy + Clone + Zero + Neg<Output = T> + PoolScalar,
2566{
2567    // SAFETY: map_into overwrites every output element.
2568    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2569    map_into(&mut out.view_mut(), &typed_view("neg", input)?, |x| -x)
2570        .map_err(|err| crate::Error::backend_failure("neg", err))?;
2571    Ok(tensor_from_array(out))
2572}
2573
2574pub(crate) fn typed_conj_with_pool<T>(
2575    buffers: &mut BufferPool,
2576    input: &TypedTensor<T>,
2577) -> crate::Result<TypedTensor<T>>
2578where
2579    T: Copy + Clone + Zero + ConjElem + PoolScalar,
2580{
2581    // SAFETY: map_into overwrites every output element.
2582    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2583    map_into(&mut out.view_mut(), &typed_view("conj", input)?, |x| {
2584        x.conj_elem()
2585    })
2586    .map_err(|err| crate::Error::backend_failure("conj", err))?;
2587    Ok(tensor_from_array(out))
2588}
2589
2590pub(crate) fn typed_abs_with_pool<T>(
2591    buffers: &mut BufferPool,
2592    input: &TypedTensor<T>,
2593) -> crate::Result<TypedTensor<T>>
2594where
2595    T: Tier2Elem + PoolScalar,
2596{
2597    // SAFETY: map_into overwrites every output element.
2598    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2599    map_into(&mut out.view_mut(), &typed_view("abs", input)?, |x| {
2600        x.abs_elem()
2601    })
2602    .map_err(|err| crate::Error::backend_failure("abs", err))?;
2603    Ok(tensor_from_array(out))
2604}
2605
2606fn typed_complex_abs_with_pool<T>(
2607    buffers: &mut BufferPool,
2608    input: &TypedTensor<Complex<T>>,
2609) -> crate::Result<TypedTensor<T>>
2610where
2611    T: num_traits::Float + PoolScalar,
2612{
2613    // SAFETY: the following kernel overwrites every output element before any read.
2614    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2615    map_into(&mut out.view_mut(), &typed_view("abs", input)?, |x| {
2616        x.norm()
2617    })
2618    .map_err(|err| crate::Error::backend_failure("abs", err))?;
2619    Ok(tensor_from_array(out))
2620}
2621
2622fn typed_complex_abs_view_with_pool<T, R>(
2623    buffers: &mut BufferPool,
2624    input: &TypedTensorView<'_, Complex<T>, R>,
2625) -> crate::Result<TypedTensor<T>>
2626where
2627    T: num_traits::Float + PoolScalar + 'static,
2628    R: TensorRank,
2629{
2630    // SAFETY: the following kernel overwrites every output element before any read.
2631    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2632    map_into(
2633        &mut out.view_mut(),
2634        &typed_view_from_view("abs", input)?,
2635        |x| x.norm(),
2636    )
2637    .map_err(|err| crate::Error::backend_failure("abs", err))?;
2638    Ok(tensor_from_array(out))
2639}
2640
2641pub(crate) fn typed_sign_with_pool<T>(
2642    buffers: &mut BufferPool,
2643    input: &TypedTensor<T>,
2644) -> crate::Result<TypedTensor<T>>
2645where
2646    T: Tier2Elem + PoolScalar,
2647{
2648    // SAFETY: map_into overwrites every output element.
2649    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2650    map_into(&mut out.view_mut(), &typed_view("sign", input)?, |x| {
2651        x.sign_elem()
2652    })
2653    .map_err(|err| crate::Error::backend_failure("sign", err))?;
2654    Ok(tensor_from_array(out))
2655}
2656
2657pub(crate) fn typed_maximum_with_pool<T>(
2658    buffers: &mut BufferPool,
2659    lhs: &TypedTensor<T>,
2660    rhs: &TypedTensor<T>,
2661) -> crate::Result<TypedTensor<T>>
2662where
2663    T: Tier2Elem + PoolScalar,
2664{
2665    if lhs.shape() != rhs.shape() {
2666        return Err(crate::Error::ShapeMismatch {
2667            op: "maximum",
2668            lhs: lhs.shape().to_vec(),
2669            rhs: rhs.shape().to_vec(),
2670        });
2671    }
2672    // SAFETY: zip_map2_into overwrites every output element.
2673    let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2674    zip_map2_into(
2675        &mut out.view_mut(),
2676        &typed_view("maximum", lhs)?,
2677        &typed_view("maximum", rhs)?,
2678        |x, y| x.max_elem(y),
2679    )
2680    .map_err(|err| crate::Error::backend_failure("maximum", err))?;
2681    Ok(tensor_from_array(out))
2682}
2683
2684pub(crate) fn typed_minimum_with_pool<T>(
2685    buffers: &mut BufferPool,
2686    lhs: &TypedTensor<T>,
2687    rhs: &TypedTensor<T>,
2688) -> crate::Result<TypedTensor<T>>
2689where
2690    T: Tier2Elem + PoolScalar,
2691{
2692    if lhs.shape() != rhs.shape() {
2693        return Err(crate::Error::ShapeMismatch {
2694            op: "minimum",
2695            lhs: lhs.shape().to_vec(),
2696            rhs: rhs.shape().to_vec(),
2697        });
2698    }
2699    // SAFETY: zip_map2_into overwrites every output element.
2700    let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2701    zip_map2_into(
2702        &mut out.view_mut(),
2703        &typed_view("minimum", lhs)?,
2704        &typed_view("minimum", rhs)?,
2705        |x, y| x.min_elem(y),
2706    )
2707    .map_err(|err| crate::Error::backend_failure("minimum", err))?;
2708    Ok(tensor_from_array(out))
2709}
2710
2711pub(crate) fn typed_compare_with_pool<T>(
2712    buffers: &mut BufferPool,
2713    lhs: &TypedTensor<T>,
2714    rhs: &TypedTensor<T>,
2715    dir: &CompareDir,
2716) -> crate::Result<TypedTensor<bool>>
2717where
2718    T: CompareElem,
2719{
2720    if lhs.shape() != rhs.shape() {
2721        return Err(crate::Error::ShapeMismatch {
2722            op: "compare",
2723            lhs: lhs.shape().to_vec(),
2724            rhs: rhs.shape().to_vec(),
2725        });
2726    }
2727    // SAFETY: zip_map2_into overwrites every output element.
2728    let mut out = unsafe { typed_array_uninit_from_pool(buffers, lhs.shape()) };
2729    zip_map2_into(
2730        &mut out.view_mut(),
2731        &typed_view("compare", lhs)?,
2732        &typed_view("compare", rhs)?,
2733        |x, y| x.compare_elem(y, dir),
2734    )
2735    .map_err(|err| crate::Error::backend_failure("compare", err))?;
2736    Ok(tensor_from_array(out))
2737}
2738
2739pub(crate) fn typed_select_with_pool<T>(
2740    buffers: &mut BufferPool,
2741    pred: &TypedTensor<bool>,
2742    on_true: &TypedTensor<T>,
2743    on_false: &TypedTensor<T>,
2744) -> crate::Result<TypedTensor<T>>
2745where
2746    T: Copy + PoolScalar,
2747{
2748    if pred.shape() != on_true.shape() {
2749        return Err(crate::Error::ShapeMismatch {
2750            op: "select",
2751            lhs: pred.shape().to_vec(),
2752            rhs: on_true.shape().to_vec(),
2753        });
2754    }
2755    if pred.shape() != on_false.shape() {
2756        return Err(crate::Error::ShapeMismatch {
2757            op: "select",
2758            lhs: pred.shape().to_vec(),
2759            rhs: on_false.shape().to_vec(),
2760        });
2761    }
2762    // SAFETY: zip_map3_into overwrites every output element.
2763    let mut out = unsafe { typed_array_uninit_from_pool(buffers, pred.shape()) };
2764    zip_map3_into(
2765        &mut out.view_mut(),
2766        &typed_view("select", pred)?,
2767        &typed_view("select", on_true)?,
2768        &typed_view("select", on_false)?,
2769        |p, t, f| if p { t } else { f },
2770    )
2771    .map_err(|err| crate::Error::backend_failure("select", err))?;
2772    Ok(tensor_from_array(out))
2773}
2774
2775pub(crate) fn typed_clamp_with_pool<T>(
2776    buffers: &mut BufferPool,
2777    input: &TypedTensor<T>,
2778    lower: &TypedTensor<T>,
2779    upper: &TypedTensor<T>,
2780) -> crate::Result<TypedTensor<T>>
2781where
2782    T: Tier2Elem + PoolScalar,
2783{
2784    if input.shape() != lower.shape() {
2785        return Err(crate::Error::ShapeMismatch {
2786            op: "clamp",
2787            lhs: input.shape().to_vec(),
2788            rhs: lower.shape().to_vec(),
2789        });
2790    }
2791    if input.shape() != upper.shape() {
2792        return Err(crate::Error::ShapeMismatch {
2793            op: "clamp",
2794            lhs: input.shape().to_vec(),
2795            rhs: upper.shape().to_vec(),
2796        });
2797    }
2798    // SAFETY: zip_map3_into overwrites every output element.
2799    let mut out = unsafe { typed_array_uninit_from_pool(buffers, input.shape()) };
2800    zip_map3_into(
2801        &mut out.view_mut(),
2802        &typed_view("clamp", input)?,
2803        &typed_view("clamp", lower)?,
2804        &typed_view("clamp", upper)?,
2805        |x, lo, hi| lo.max_elem(hi.min_elem(x)),
2806    )
2807    .map_err(|err| crate::Error::backend_failure("clamp", err))?;
2808    Ok(tensor_from_array(out))
2809}
2810
2811#[cfg(test)]
2812mod tests;