Skip to main content

tenferro_tensor/cpu/
elementwise.rs

1use std::ops::{Add, Div, Mul, Neg};
2
3use num_complex::Complex;
4use num_traits::{One, Zero};
5use strided_kernel::{map_into, zip_map2_into, zip_map3_into};
6
7use crate::{
8    config::CompareDir,
9    types::{ConjElem, Tensor, TypedTensor},
10};
11
12use super::{tensor_from_array, typed_array_uninit, typed_view};
13
/// Dispatches a fallible ternary kernel over three `Tensor` operands of one dtype.
///
/// When all three operands carry the same element type, their typed payloads are bound to
/// `$p`, `$q` and `$r`, `$kernel` is evaluated, any error it returns is propagated with `?`,
/// and the success value is rewrapped in the matching `Tensor` variant. Every mixed-dtype
/// combination produces `Error::BackendFailure` tagged with `$op_name`.
macro_rules! dispatch_ternary_result {
    (
        $op_name:literal,
        $first:expr,
        $second:expr,
        $third:expr,
        |$p:ident, $q:ident, $r:ident| $kernel:expr
    ) => {
        match ($first, $second, $third) {
            (Tensor::F32($p), Tensor::F32($q), Tensor::F32($r)) => Ok(Tensor::F32($kernel?)),
            (Tensor::F64($p), Tensor::F64($q), Tensor::F64($r)) => Ok(Tensor::F64($kernel?)),
            (Tensor::C32($p), Tensor::C32($q), Tensor::C32($r)) => Ok(Tensor::C32($kernel?)),
            (Tensor::C64($p), Tensor::C64($q), Tensor::C64($r)) => Ok(Tensor::C64($kernel?)),
            // The operands disagree on dtype, so no kernel is run.
            _ => Err(crate::Error::BackendFailure {
                op: $op_name,
                message: "dtype mismatch".into(),
            }),
        }
    };
}
28
29pub(crate) trait Tier2Elem: Copy + Clone + One + Zero {
30    fn abs_elem(self) -> Self;
31    fn sign_elem(self) -> Self;
32    fn max_elem(self, other: Self) -> Self;
33    fn min_elem(self, other: Self) -> Self;
34    fn compare_elem(self, other: Self, dir: &CompareDir) -> Self;
35    fn is_nonzero(self) -> bool;
36}
37
/// Implements [`Tier2Elem`] for a primitive real float type (`f32` or `f64`).
///
/// `max_elem` and `min_elem` propagate NaN: if either operand is NaN the result is NaN,
/// independent of argument order.
macro_rules! impl_tier2_elem_real {
    ($ty:ty) => {
        impl Tier2Elem for $ty {
            fn abs_elem(self) -> Self {
                self.abs()
            }

            fn sign_elem(self) -> Self {
                // `signum` returns +/-1.0 for +/-0.0, so zero (of either sign) is mapped to
                // +0.0 explicitly; NaN stays NaN through `signum`.
                if self == Self::zero() {
                    Self::zero()
                } else {
                    self.signum()
                }
            }

            fn max_elem(self, other: Self) -> Self {
                // A NaN `other` fails `>=` and is returned by the `else` branch; a NaN `self`
                // must be caught explicitly, otherwise `max(NaN, x)` would yield `x` while
                // `max(x, NaN)` yields NaN.
                if self.is_nan() || self >= other {
                    self
                } else {
                    other
                }
            }

            fn min_elem(self, other: Self) -> Self {
                // Same NaN handling as `max_elem`.
                if self.is_nan() || self <= other {
                    self
                } else {
                    other
                }
            }

            fn compare_elem(self, other: Self, dir: &CompareDir) -> Self {
                // IEEE comparisons: a NaN operand makes every predicate (including Eq) false.
                let pred = match dir {
                    CompareDir::Eq => self == other,
                    CompareDir::Lt => self < other,
                    CompareDir::Le => self <= other,
                    CompareDir::Gt => self > other,
                    CompareDir::Ge => self >= other,
                };
                if pred {
                    Self::one()
                } else {
                    Self::zero()
                }
            }

            fn is_nonzero(self) -> bool {
                // NaN compares unequal to zero, so it counts as "true".
                self != Self::zero()
            }
        }
    };
}
90
/// Implements [`Tier2Elem`] for `Complex<$real>`.
///
/// Complex numbers have no total order, so `max_elem`, `min_elem` and the ordering
/// comparisons rank values by squared magnitude (`norm_sqr`); `CompareDir::Eq` compares
/// exactly. `max_elem`/`min_elem` propagate NaN independent of argument order.
macro_rules! impl_tier2_elem_complex {
    ($real:ty) => {
        impl Tier2Elem for Complex<$real> {
            /// Modulus `|z|`, returned as a complex number with zero imaginary part.
            fn abs_elem(self) -> Self {
                Self::new(self.norm(), <$real>::zero())
            }

            /// `z / |z|` for nonzero `z`, zero for zero.
            fn sign_elem(self) -> Self {
                if self.is_zero() {
                    Self::zero()
                } else {
                    // Divide both parts by the real modulus. Dividing by the complex value
                    // `|z| + 0i` instead goes through the divisor's `norm_sqr`, which
                    // overflows to inf for large |z| (or underflows to 0 for tiny |z|) and
                    // turns the result into NaN.
                    self.unscale(self.norm())
                }
            }

            fn max_elem(self, other: Self) -> Self {
                // A NaN component makes `norm_sqr` NaN and every comparison false; test
                // `self` explicitly so NaN wins regardless of argument order.
                if self.is_nan() || self.norm_sqr() >= other.norm_sqr() {
                    self
                } else {
                    other
                }
            }

            fn min_elem(self, other: Self) -> Self {
                // Same NaN handling as `max_elem`.
                if self.is_nan() || self.norm_sqr() <= other.norm_sqr() {
                    self
                } else {
                    other
                }
            }

            fn compare_elem(self, other: Self, dir: &CompareDir) -> Self {
                // NOTE(review): `norm_sqr` overflows to inf for magnitudes above roughly
                // sqrt(<$real>::MAX), so very large values can compare as equal in magnitude.
                let pred = match dir {
                    CompareDir::Eq => self == other,
                    CompareDir::Lt => self.norm_sqr() < other.norm_sqr(),
                    CompareDir::Le => self.norm_sqr() <= other.norm_sqr(),
                    CompareDir::Gt => self.norm_sqr() > other.norm_sqr(),
                    CompareDir::Ge => self.norm_sqr() >= other.norm_sqr(),
                };
                if pred {
                    Self::one()
                } else {
                    Self::zero()
                }
            }

            fn is_nonzero(self) -> bool {
                !self.is_zero()
            }
        }
    };
}
143
144impl_tier2_elem_real!(f32);
145impl_tier2_elem_real!(f64);
146impl_tier2_elem_complex!(f32);
147impl_tier2_elem_complex!(f64);
148
149fn complex_scalar_tensor<T>(scalar: T) -> TypedTensor<Complex<T>>
150where
151    T: Copy + Clone + Zero,
152{
153    TypedTensor::from_vec(vec![], vec![Complex::new(scalar, T::zero())])
154}
155
156fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
157    crate::Error::BackendFailure {
158        op,
159        message: err.to_string(),
160    }
161}
162
163pub fn add(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
164    match (lhs, rhs) {
165        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_add(a, b)?)),
166        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_add(a, b)?)),
167        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_add(a, b)?)),
168        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_add(a, b)?)),
169        (Tensor::F32(a), Tensor::C32(b)) if a.shape.is_empty() => {
170            let scalar = complex_scalar_tensor(a.host_data()[0]);
171            Ok(Tensor::C32(typed_add(&scalar, b)?))
172        }
173        (Tensor::C32(a), Tensor::F32(b)) if b.shape.is_empty() => {
174            let scalar = complex_scalar_tensor(b.host_data()[0]);
175            Ok(Tensor::C32(typed_add(a, &scalar)?))
176        }
177        (Tensor::F64(a), Tensor::C64(b)) if a.shape.is_empty() => {
178            let scalar = complex_scalar_tensor(a.host_data()[0]);
179            Ok(Tensor::C64(typed_add(&scalar, b)?))
180        }
181        (Tensor::C64(a), Tensor::F64(b)) if b.shape.is_empty() => {
182            let scalar = complex_scalar_tensor(b.host_data()[0]);
183            Ok(Tensor::C64(typed_add(a, &scalar)?))
184        }
185        _ => Err(crate::Error::DTypeMismatch {
186            op: "add",
187            lhs: lhs.dtype(),
188            rhs: rhs.dtype(),
189        }),
190    }
191}
192
193pub fn mul(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
194    match (lhs, rhs) {
195        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_mul(a, b)?)),
196        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_mul(a, b)?)),
197        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_mul(a, b)?)),
198        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_mul(a, b)?)),
199        (Tensor::F32(a), Tensor::C32(b)) if a.shape.is_empty() => {
200            let scalar = complex_scalar_tensor(a.host_data()[0]);
201            Ok(Tensor::C32(typed_mul(&scalar, b)?))
202        }
203        (Tensor::C32(a), Tensor::F32(b)) if b.shape.is_empty() => {
204            let scalar = complex_scalar_tensor(b.host_data()[0]);
205            Ok(Tensor::C32(typed_mul(a, &scalar)?))
206        }
207        (Tensor::F64(a), Tensor::C64(b)) if a.shape.is_empty() => {
208            let scalar = complex_scalar_tensor(a.host_data()[0]);
209            Ok(Tensor::C64(typed_mul(&scalar, b)?))
210        }
211        (Tensor::C64(a), Tensor::F64(b)) if b.shape.is_empty() => {
212            let scalar = complex_scalar_tensor(b.host_data()[0]);
213            Ok(Tensor::C64(typed_mul(a, &scalar)?))
214        }
215        _ => Err(crate::Error::DTypeMismatch {
216            op: "mul",
217            lhs: lhs.dtype(),
218            rhs: rhs.dtype(),
219        }),
220    }
221}
222
223pub fn div(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
224    match (lhs, rhs) {
225        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_div(a, b)?)),
226        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_div(a, b)?)),
227        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_div(a, b)?)),
228        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_div(a, b)?)),
229        _ => Err(crate::Error::DTypeMismatch {
230            op: "div",
231            lhs: lhs.dtype(),
232            rhs: rhs.dtype(),
233        }),
234    }
235}
236
237pub fn neg(input: &Tensor) -> crate::Result<Tensor> {
238    match input {
239        Tensor::F32(t) => Ok(Tensor::F32(typed_neg(t)?)),
240        Tensor::F64(t) => Ok(Tensor::F64(typed_neg(t)?)),
241        Tensor::C32(t) => Ok(Tensor::C32(typed_neg(t)?)),
242        Tensor::C64(t) => Ok(Tensor::C64(typed_neg(t)?)),
243    }
244}
245
246pub fn conj(input: &Tensor) -> crate::Result<Tensor> {
247    match input {
248        Tensor::F32(t) => Ok(Tensor::F32(typed_conj(t)?)),
249        Tensor::F64(t) => Ok(Tensor::F64(typed_conj(t)?)),
250        Tensor::C32(t) => Ok(Tensor::C32(typed_conj(t)?)),
251        Tensor::C64(t) => Ok(Tensor::C64(typed_conj(t)?)),
252    }
253}
254
255pub fn abs(input: &Tensor) -> crate::Result<Tensor> {
256    match input {
257        Tensor::F32(t) => Ok(Tensor::F32(typed_abs(t)?)),
258        Tensor::F64(t) => Ok(Tensor::F64(typed_abs(t)?)),
259        Tensor::C32(t) => Ok(Tensor::C32(typed_abs(t)?)),
260        Tensor::C64(t) => Ok(Tensor::C64(typed_abs(t)?)),
261    }
262}
263
264pub fn sign(input: &Tensor) -> crate::Result<Tensor> {
265    match input {
266        Tensor::F32(t) => Ok(Tensor::F32(typed_sign(t)?)),
267        Tensor::F64(t) => Ok(Tensor::F64(typed_sign(t)?)),
268        Tensor::C32(t) => Ok(Tensor::C32(typed_sign(t)?)),
269        Tensor::C64(t) => Ok(Tensor::C64(typed_sign(t)?)),
270    }
271}
272
273pub fn maximum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
274    match (lhs, rhs) {
275        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_maximum(a, b)?)),
276        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_maximum(a, b)?)),
277        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_maximum(a, b)?)),
278        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_maximum(a, b)?)),
279        _ => Err(crate::Error::DTypeMismatch {
280            op: "maximum",
281            lhs: lhs.dtype(),
282            rhs: rhs.dtype(),
283        }),
284    }
285}
286
287pub fn minimum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
288    match (lhs, rhs) {
289        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_minimum(a, b)?)),
290        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_minimum(a, b)?)),
291        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_minimum(a, b)?)),
292        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_minimum(a, b)?)),
293        _ => Err(crate::Error::DTypeMismatch {
294            op: "minimum",
295            lhs: lhs.dtype(),
296            rhs: rhs.dtype(),
297        }),
298    }
299}
300
301pub fn compare(lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor> {
302    match (lhs, rhs) {
303        (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(typed_compare(a, b, dir)?)),
304        (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(typed_compare(a, b, dir)?)),
305        (Tensor::C32(a), Tensor::C32(b)) => Ok(Tensor::C32(typed_compare(a, b, dir)?)),
306        (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(typed_compare(a, b, dir)?)),
307        _ => Err(crate::Error::DTypeMismatch {
308            op: "compare",
309            lhs: lhs.dtype(),
310            rhs: rhs.dtype(),
311        }),
312    }
313}
314
315pub fn select(pred: &Tensor, on_true: &Tensor, on_false: &Tensor) -> crate::Result<Tensor> {
316    dispatch_ternary_result!("select", pred, on_true, on_false, |p, t, f| typed_select(
317        p, t, f
318    ))
319}
320
321pub fn clamp(input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor> {
322    dispatch_ternary_result!("clamp", input, lower, upper, |x, lo, hi| typed_clamp(
323        x, lo, hi
324    ))
325}
326
327pub fn typed_add<T>(lhs: &TypedTensor<T>, rhs: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
328where
329    T: Copy + Clone + Zero + Add<Output = T>,
330{
331    if lhs.shape == rhs.shape {
332        // SAFETY: zip_map2_into overwrites every output element.
333        let mut out = unsafe { typed_array_uninit(&lhs.shape) };
334        zip_map2_into(
335            &mut out.view_mut(),
336            &typed_view(lhs),
337            &typed_view(rhs),
338            |x, y| x + y,
339        )
340        .map_err(|err| crate::Error::BackendFailure {
341            op: "add",
342            message: err.to_string(),
343        })?;
344        Ok(tensor_from_array(out))
345    } else if lhs.shape.is_empty() {
346        let scalar = lhs.host_data()[0];
347        // SAFETY: map_into overwrites every output element.
348        let mut out = unsafe { typed_array_uninit(&rhs.shape) };
349        map_into(&mut out.view_mut(), &typed_view(rhs), |x| scalar + x).map_err(|err| {
350            crate::Error::BackendFailure {
351                op: "add",
352                message: err.to_string(),
353            }
354        })?;
355        Ok(tensor_from_array(out))
356    } else if rhs.shape.is_empty() {
357        let scalar = rhs.host_data()[0];
358        // SAFETY: map_into overwrites every output element.
359        let mut out = unsafe { typed_array_uninit(&lhs.shape) };
360        map_into(&mut out.view_mut(), &typed_view(lhs), |x| x + scalar).map_err(|err| {
361            crate::Error::BackendFailure {
362                op: "add",
363                message: err.to_string(),
364            }
365        })?;
366        Ok(tensor_from_array(out))
367    } else {
368        Err(crate::Error::ShapeMismatch {
369            op: "add",
370            lhs: lhs.shape.clone(),
371            rhs: rhs.shape.clone(),
372        })
373    }
374}
375
376pub fn typed_mul<T>(lhs: &TypedTensor<T>, rhs: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
377where
378    T: Copy + Clone + Zero + Mul<Output = T>,
379{
380    if lhs.shape == rhs.shape {
381        // SAFETY: zip_map2_into overwrites every output element.
382        let mut out = unsafe { typed_array_uninit(&lhs.shape) };
383        zip_map2_into(
384            &mut out.view_mut(),
385            &typed_view(lhs),
386            &typed_view(rhs),
387            |x, y| x * y,
388        )
389        .map_err(|err| backend_failure("mul", err))?;
390        Ok(tensor_from_array(out))
391    } else if lhs.shape.is_empty() {
392        let scalar = lhs.host_data()[0];
393        // SAFETY: map_into overwrites every output element.
394        let mut out = unsafe { typed_array_uninit(&rhs.shape) };
395        map_into(&mut out.view_mut(), &typed_view(rhs), |x| scalar * x)
396            .map_err(|err| backend_failure("mul", err))?;
397        Ok(tensor_from_array(out))
398    } else if rhs.shape.is_empty() {
399        let scalar = rhs.host_data()[0];
400        // SAFETY: map_into overwrites every output element.
401        let mut out = unsafe { typed_array_uninit(&lhs.shape) };
402        map_into(&mut out.view_mut(), &typed_view(lhs), |x| x * scalar)
403            .map_err(|err| backend_failure("mul", err))?;
404        Ok(tensor_from_array(out))
405    } else {
406        Err(crate::Error::ShapeMismatch {
407            op: "mul",
408            lhs: lhs.shape.clone(),
409            rhs: rhs.shape.clone(),
410        })
411    }
412}
413
414pub fn typed_div<T>(lhs: &TypedTensor<T>, rhs: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
415where
416    T: Copy + Clone + Zero + Div<Output = T>,
417{
418    if lhs.shape != rhs.shape {
419        return Err(crate::Error::ShapeMismatch {
420            op: "div",
421            lhs: lhs.shape.clone(),
422            rhs: rhs.shape.clone(),
423        });
424    }
425    // SAFETY: zip_map2_into overwrites every output element.
426    let mut out = unsafe { typed_array_uninit(&lhs.shape) };
427    zip_map2_into(
428        &mut out.view_mut(),
429        &typed_view(lhs),
430        &typed_view(rhs),
431        |x, y| x / y,
432    )
433    .map_err(|err| backend_failure("div", err))?;
434    Ok(tensor_from_array(out))
435}
436
437pub fn typed_neg<T>(input: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
438where
439    T: Copy + Clone + Zero + Neg<Output = T>,
440{
441    // SAFETY: map_into overwrites every output element.
442    let mut out = unsafe { typed_array_uninit(&input.shape) };
443    map_into(&mut out.view_mut(), &typed_view(input), |x| -x)
444        .map_err(|err| backend_failure("neg", err))?;
445    Ok(tensor_from_array(out))
446}
447
448pub fn typed_conj<T>(input: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
449where
450    T: Copy + Clone + Zero + ConjElem,
451{
452    // SAFETY: map_into overwrites every output element.
453    let mut out = unsafe { typed_array_uninit(&input.shape) };
454    map_into(&mut out.view_mut(), &typed_view(input), |x| x.conj_elem())
455        .map_err(|err| backend_failure("conj", err))?;
456    Ok(tensor_from_array(out))
457}
458
459pub(crate) fn typed_abs<T>(input: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
460where
461    T: Tier2Elem,
462{
463    // SAFETY: map_into overwrites every output element.
464    let mut out = unsafe { typed_array_uninit(&input.shape) };
465    map_into(&mut out.view_mut(), &typed_view(input), |x| x.abs_elem())
466        .map_err(|err| backend_failure("abs", err))?;
467    Ok(tensor_from_array(out))
468}
469
470pub(crate) fn typed_sign<T>(input: &TypedTensor<T>) -> crate::Result<TypedTensor<T>>
471where
472    T: Tier2Elem,
473{
474    // SAFETY: map_into overwrites every output element.
475    let mut out = unsafe { typed_array_uninit(&input.shape) };
476    map_into(&mut out.view_mut(), &typed_view(input), |x| x.sign_elem())
477        .map_err(|err| backend_failure("sign", err))?;
478    Ok(tensor_from_array(out))
479}
480
481pub(crate) fn typed_maximum<T>(
482    lhs: &TypedTensor<T>,
483    rhs: &TypedTensor<T>,
484) -> crate::Result<TypedTensor<T>>
485where
486    T: Tier2Elem,
487{
488    if lhs.shape != rhs.shape {
489        return Err(crate::Error::ShapeMismatch {
490            op: "maximum",
491            lhs: lhs.shape.clone(),
492            rhs: rhs.shape.clone(),
493        });
494    }
495    // SAFETY: zip_map2_into overwrites every output element.
496    let mut out = unsafe { typed_array_uninit(&lhs.shape) };
497    zip_map2_into(
498        &mut out.view_mut(),
499        &typed_view(lhs),
500        &typed_view(rhs),
501        |x, y| x.max_elem(y),
502    )
503    .map_err(|err| backend_failure("maximum", err))?;
504    Ok(tensor_from_array(out))
505}
506
507pub(crate) fn typed_minimum<T>(
508    lhs: &TypedTensor<T>,
509    rhs: &TypedTensor<T>,
510) -> crate::Result<TypedTensor<T>>
511where
512    T: Tier2Elem,
513{
514    if lhs.shape != rhs.shape {
515        return Err(crate::Error::ShapeMismatch {
516            op: "minimum",
517            lhs: lhs.shape.clone(),
518            rhs: rhs.shape.clone(),
519        });
520    }
521    // SAFETY: zip_map2_into overwrites every output element.
522    let mut out = unsafe { typed_array_uninit(&lhs.shape) };
523    zip_map2_into(
524        &mut out.view_mut(),
525        &typed_view(lhs),
526        &typed_view(rhs),
527        |x, y| x.min_elem(y),
528    )
529    .map_err(|err| backend_failure("minimum", err))?;
530    Ok(tensor_from_array(out))
531}
532
533pub(crate) fn typed_compare<T>(
534    lhs: &TypedTensor<T>,
535    rhs: &TypedTensor<T>,
536    dir: &CompareDir,
537) -> crate::Result<TypedTensor<T>>
538where
539    T: Tier2Elem,
540{
541    if lhs.shape != rhs.shape {
542        return Err(crate::Error::ShapeMismatch {
543            op: "compare",
544            lhs: lhs.shape.clone(),
545            rhs: rhs.shape.clone(),
546        });
547    }
548    // SAFETY: zip_map2_into overwrites every output element.
549    let mut out = unsafe { typed_array_uninit(&lhs.shape) };
550    zip_map2_into(
551        &mut out.view_mut(),
552        &typed_view(lhs),
553        &typed_view(rhs),
554        |x, y| x.compare_elem(y, dir),
555    )
556    .map_err(|err| backend_failure("compare", err))?;
557    Ok(tensor_from_array(out))
558}
559
560pub(crate) fn typed_select<T>(
561    pred: &TypedTensor<T>,
562    on_true: &TypedTensor<T>,
563    on_false: &TypedTensor<T>,
564) -> crate::Result<TypedTensor<T>>
565where
566    T: Tier2Elem,
567{
568    if pred.shape != on_true.shape {
569        return Err(crate::Error::ShapeMismatch {
570            op: "select",
571            lhs: pred.shape.clone(),
572            rhs: on_true.shape.clone(),
573        });
574    }
575    if pred.shape != on_false.shape {
576        return Err(crate::Error::ShapeMismatch {
577            op: "select",
578            lhs: pred.shape.clone(),
579            rhs: on_false.shape.clone(),
580        });
581    }
582    // SAFETY: zip_map3_into overwrites every output element.
583    let mut out = unsafe { typed_array_uninit(&pred.shape) };
584    zip_map3_into(
585        &mut out.view_mut(),
586        &typed_view(pred),
587        &typed_view(on_true),
588        &typed_view(on_false),
589        |p, t, f| if p.is_nonzero() { t } else { f },
590    )
591    .map_err(|err| backend_failure("select", err))?;
592    Ok(tensor_from_array(out))
593}
594
595pub(crate) fn typed_clamp<T>(
596    input: &TypedTensor<T>,
597    lower: &TypedTensor<T>,
598    upper: &TypedTensor<T>,
599) -> crate::Result<TypedTensor<T>>
600where
601    T: Tier2Elem,
602{
603    if input.shape != lower.shape {
604        return Err(crate::Error::ShapeMismatch {
605            op: "clamp",
606            lhs: input.shape.clone(),
607            rhs: lower.shape.clone(),
608        });
609    }
610    if input.shape != upper.shape {
611        return Err(crate::Error::ShapeMismatch {
612            op: "clamp",
613            lhs: input.shape.clone(),
614            rhs: upper.shape.clone(),
615        });
616    }
617    // SAFETY: zip_map3_into overwrites every output element.
618    let mut out = unsafe { typed_array_uninit(&input.shape) };
619    zip_map3_into(
620        &mut out.view_mut(),
621        &typed_view(input),
622        &typed_view(lower),
623        &typed_view(upper),
624        |x, lo, hi| lo.max_elem(hi.min_elem(x)),
625    )
626    .map_err(|err| backend_failure("clamp", err))?;
627    Ok(tensor_from_array(out))
628}