Skip to main content

tenferro_runtime/
typed_tensor.rs

//! Implementations of the typed tensor operation extension traits
//! (`TypedTensorOpsExt` and `TypedTensorMaskOpsExt`) for `TypedTensor`.
//!
//! Operation families that are no longer part of core, such as einsum, live
//! in their own extension crates.
5
6use tenferro_ops::broadcast::{broadcast_input_plan, broadcast_shape, broadcast_shapes};
7use tenferro_tensor::{
8    CompareDir, DotGeneralConfig, Error, Result, Tensor, TensorBackend, TensorRead, TensorScalar,
9};
10
11use crate::{TypedTensorMaskOpsExt, TypedTensorOpsExt};
12use tenferro_tensor::TypedTensor;
13
14impl<T: TensorScalar> TypedTensorOpsExt<T> for TypedTensor<T> {
15    fn add<B: TensorBackend>(
16        &self,
17        rhs: &TypedTensor<T>,
18        backend: &mut B,
19    ) -> Result<TypedTensor<T>> {
20        add(self, rhs, backend)
21    }
22
23    fn sub<B: TensorBackend>(
24        &self,
25        rhs: &TypedTensor<T>,
26        backend: &mut B,
27    ) -> Result<TypedTensor<T>> {
28        sub(self, rhs, backend)
29    }
30
31    fn mul<B: TensorBackend>(
32        &self,
33        rhs: &TypedTensor<T>,
34        backend: &mut B,
35    ) -> Result<TypedTensor<T>> {
36        mul(self, rhs, backend)
37    }
38
39    fn div<B: TensorBackend>(
40        &self,
41        rhs: &TypedTensor<T>,
42        backend: &mut B,
43    ) -> Result<TypedTensor<T>> {
44        div(self, rhs, backend)
45    }
46
47    fn pow<B: TensorBackend>(
48        &self,
49        rhs: &TypedTensor<T>,
50        backend: &mut B,
51    ) -> Result<TypedTensor<T>> {
52        pow(self, rhs, backend)
53    }
54
55    fn maximum<B: TensorBackend>(
56        &self,
57        rhs: &TypedTensor<T>,
58        backend: &mut B,
59    ) -> Result<TypedTensor<T>> {
60        maximum(self, rhs, backend)
61    }
62
63    fn minimum<B: TensorBackend>(
64        &self,
65        rhs: &TypedTensor<T>,
66        backend: &mut B,
67    ) -> Result<TypedTensor<T>> {
68        minimum(self, rhs, backend)
69    }
70
71    fn neg<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
72        neg(self, backend)
73    }
74
75    fn abs<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
76        abs(self, backend)
77    }
78
79    fn sign<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
80        sign(self, backend)
81    }
82
83    fn conj<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
84        conj(self, backend)
85    }
86
87    fn exp<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
88        exp(self, backend)
89    }
90
91    fn log<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
92        log(self, backend)
93    }
94
95    fn sin<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
96        sin(self, backend)
97    }
98
99    fn cos<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
100        cos(self, backend)
101    }
102
103    fn tanh<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
104        tanh(self, backend)
105    }
106
107    fn sqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
108        sqrt(self, backend)
109    }
110
111    fn rsqrt<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
112        rsqrt(self, backend)
113    }
114
115    fn expm1<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
116        expm1(self, backend)
117    }
118
119    fn log1p<B: TensorBackend>(&self, backend: &mut B) -> Result<TypedTensor<T>> {
120        log1p(self, backend)
121    }
122
123    fn compare<B: TensorBackend>(
124        &self,
125        rhs: &TypedTensor<T>,
126        dir: CompareDir,
127        backend: &mut B,
128    ) -> Result<TypedTensor<bool>> {
129        compare(self, rhs, dir, backend)
130    }
131
132    fn clamp<B: TensorBackend>(
133        &self,
134        lower: &TypedTensor<T>,
135        upper: &TypedTensor<T>,
136        backend: &mut B,
137    ) -> Result<TypedTensor<T>> {
138        clamp(self, lower, upper, backend)
139    }
140
141    fn matmul<B: TensorBackend>(
142        &self,
143        rhs: &TypedTensor<T>,
144        backend: &mut B,
145    ) -> Result<TypedTensor<T>> {
146        matmul(self, rhs, backend)
147    }
148
149    fn reduce_sum<B: TensorBackend>(
150        &self,
151        axes: &[usize],
152        backend: &mut B,
153    ) -> Result<TypedTensor<T>> {
154        reduce_sum(self, axes, backend)
155    }
156
157    fn reshape<B: TensorBackend>(
158        &self,
159        shape: &[usize],
160        backend: &mut B,
161    ) -> Result<TypedTensor<T>> {
162        reshape(self, shape, backend)
163    }
164
165    fn transpose<B: TensorBackend>(
166        &self,
167        perm: &[usize],
168        backend: &mut B,
169    ) -> Result<TypedTensor<T>> {
170        transpose(self, perm, backend)
171    }
172
173    fn broadcast_in_dim<B: TensorBackend>(
174        &self,
175        shape: &[usize],
176        dims: &[usize],
177        backend: &mut B,
178    ) -> Result<TypedTensor<T>> {
179        broadcast_in_dim(self, shape, dims, backend)
180    }
181}
182
183impl TypedTensorMaskOpsExt for TypedTensor<bool> {
184    fn where_select<T: TensorScalar, B: TensorBackend>(
185        &self,
186        on_true: &TypedTensor<T>,
187        on_false: &TypedTensor<T>,
188        backend: &mut B,
189    ) -> Result<TypedTensor<T>> {
190        where_select(self, on_true, on_false, backend)
191    }
192}
193
194/// Elementwise addition with NumPy-style broadcasting.
195///
196/// # Examples
197///
198/// ```rust
199/// # use tenferro_cpu::CpuBackend;
200/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
201/// # let mut backend = CpuBackend::new();
202/// # let x = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![1.0, 2.0]).unwrap();
203/// # let y = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![3.0, 4.0]).unwrap();
204/// let z = x.add(&y, &mut backend).unwrap();
205/// ```
206fn add<T: TensorScalar>(
207    lhs: &TypedTensor<T>,
208    rhs: &TypedTensor<T>,
209    backend: &mut impl TensorBackend,
210) -> Result<TypedTensor<T>> {
211    let (lhs, rhs) = broadcast_binary_read(lhs, rhs, backend)?;
212    let out =
213        backend.with_backend_session(|exec| exec.add_read(lhs.tensor_read(), rhs.tensor_read()))?;
214    into_typed_result("add", out)
215}
216
// Generates a private unary elementwise op `$op(input, backend)`. It runs the
// backend session method `$read_method` on the input and tags any dtype
// mismatch in the result with the op's own name.
macro_rules! unary_fn {
    ($op:ident, $read_method:ident, $summary:literal) => {
        #[doc = $summary]
        ///
        /// # Examples
        ///
        /// ```rust
        /// # use tenferro_cpu::CpuBackend;
        /// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
        /// # let mut backend = CpuBackend::new();
        /// # let x = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![1.0, 4.0]).unwrap();
        #[doc = concat!("let y = x.", stringify!($op), "(&mut backend).unwrap();")]
        /// ```
        fn $op<T: TensorScalar>(
            input: &TypedTensor<T>,
            backend: &mut impl TensorBackend,
        ) -> Result<TypedTensor<T>> {
            let raw = backend.with_backend_session(|exec| {
                let view = T::tensor_read(input);
                exec.$read_method(view)
            })?;
            into_typed_result(stringify!($op), raw)
        }
    };
}
239
// Generates a private binary elementwise op `$op(lhs, rhs, backend)`. Both
// operands are first broadcast to their common shape, then combined with the
// backend session method `$read_method`. Any dtype mismatch in the result is
// tagged with the op's own name.
macro_rules! binary_fn {
    ($op:ident, $read_method:ident, $summary:literal) => {
        #[doc = $summary]
        ///
        /// # Examples
        ///
        /// ```rust
        /// # use tenferro_cpu::CpuBackend;
        /// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
        /// # let mut backend = CpuBackend::new();
        /// # let x = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![2.0, 4.0]).unwrap();
        /// # let y = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![1.0, 8.0]).unwrap();
        #[doc = concat!("let z = x.", stringify!($op), "(&y, &mut backend).unwrap();")]
        /// ```
        fn $op<T: TensorScalar>(
            lhs: &TypedTensor<T>,
            rhs: &TypedTensor<T>,
            backend: &mut impl TensorBackend,
        ) -> Result<TypedTensor<T>> {
            let (left, right) = broadcast_binary_read(lhs, rhs, backend)?;
            let raw = backend.with_backend_session(|exec| {
                let left_view = left.tensor_read();
                let right_view = right.tensor_read();
                exec.$read_method(left_view, right_view)
            })?;
            into_typed_result(stringify!($op), raw)
        }
    };
}
266
267binary_fn!(
268    mul,
269    mul_read,
270    "Elementwise multiplication with NumPy-style broadcasting."
271);
272binary_fn!(
273    div,
274    div_read,
275    "Elementwise division with NumPy-style broadcasting."
276);
277binary_fn!(
278    pow,
279    pow_read,
280    "Elementwise power with NumPy-style broadcasting."
281);
282binary_fn!(
283    maximum,
284    maximum_read,
285    "Elementwise maximum with NumPy-style broadcasting."
286);
287binary_fn!(
288    minimum,
289    minimum_read,
290    "Elementwise minimum with NumPy-style broadcasting."
291);
292
293unary_fn!(neg, neg_read, "Elementwise negation.");
294unary_fn!(abs, abs_read, "Elementwise absolute value.");
295unary_fn!(sign, sign_read, "Elementwise sign.");
296unary_fn!(conj, conj_read, "Elementwise complex conjugate.");
297unary_fn!(exp, exp_read, "Elementwise exponential.");
298unary_fn!(log, log_read, "Elementwise natural logarithm.");
299unary_fn!(sin, sin_read, "Elementwise sine.");
300unary_fn!(cos, cos_read, "Elementwise cosine.");
301unary_fn!(tanh, tanh_read, "Elementwise hyperbolic tangent.");
302unary_fn!(sqrt, sqrt_read, "Elementwise square root.");
303unary_fn!(rsqrt, rsqrt_read, "Elementwise reciprocal square root.");
304unary_fn!(expm1, expm1_read, "Elementwise `exp(x) - 1`.");
305unary_fn!(log1p, log1p_read, "Elementwise `log(1 + x)`.");
306
307/// Elementwise subtraction with NumPy-style broadcasting.
308///
309/// # Examples
310///
311/// ```rust
312/// # use tenferro_cpu::CpuBackend;
313/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
314/// # let mut backend = CpuBackend::new();
315/// # let x = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![2.0, 4.0]).unwrap();
316/// # let y = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![1.0, 8.0]).unwrap();
317/// let z = x.sub(&y, &mut backend).unwrap();
318/// ```
319fn sub<T: TensorScalar>(
320    lhs: &TypedTensor<T>,
321    rhs: &TypedTensor<T>,
322    backend: &mut impl TensorBackend,
323) -> Result<TypedTensor<T>> {
324    let (lhs, rhs) = broadcast_binary_read(lhs, rhs, backend)?;
325    let neg_rhs = backend.with_backend_session(|exec| exec.neg_read(rhs.tensor_read()))?;
326    let out = backend.with_backend_session(|exec| {
327        exec.add_read(lhs.tensor_read(), TensorRead::from_tensor(&neg_rhs))
328    })?;
329    into_typed_result("sub", out)
330}
331
332/// Elementwise comparison with NumPy-style broadcasting.
333///
334/// The result is a bool tensor.
335///
336/// # Examples
337///
338/// ```rust
339/// # use tenferro_cpu::CpuBackend;
340/// use tenferro_runtime::{CompareDir, TypedTensor, TypedTensorMaskOpsExt, TypedTensorOpsExt};
341/// # let mut backend = CpuBackend::new();
342/// # let x = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![2.0, 4.0]).unwrap();
343/// # let y = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![1.0, 8.0]).unwrap();
344/// let z = x.compare(&y, CompareDir::Gt, &mut backend).unwrap();
345/// assert_eq!(z.host_data().unwrap(), &[true, false]);
346/// ```
347fn compare<T: TensorScalar>(
348    lhs: &TypedTensor<T>,
349    rhs: &TypedTensor<T>,
350    dir: CompareDir,
351    backend: &mut impl TensorBackend,
352) -> Result<TypedTensor<bool>> {
353    let (lhs, rhs) = broadcast_binary_read(lhs, rhs, backend)?;
354    let out = backend.with_backend_session(|exec| {
355        exec.compare_read(lhs.tensor_read(), rhs.tensor_read(), &dir)
356    })?;
357    into_typed_result("compare", out)
358}
359
360/// Select values from `on_true` or `on_false` using a condition tensor.
361///
362/// This corresponds to NumPy `where(condition, x, y)`.
363///
364/// # Examples
365///
366/// ```rust
367/// # use tenferro_cpu::CpuBackend;
368/// use tenferro_runtime::{CompareDir, TypedTensor, TypedTensorMaskOpsExt, TypedTensorOpsExt};
369/// # let mut backend = CpuBackend::new();
370/// # let x = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![2.0, 4.0]).unwrap();
371/// # let y = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![1.0, 8.0]).unwrap();
372/// # let condition = x.compare(&y, CompareDir::Gt, &mut backend).unwrap();
373/// let z = condition.where_select(&x, &y, &mut backend).unwrap();
374/// ```
375fn where_select<T: TensorScalar>(
376    condition: &TypedTensor<bool>,
377    on_true: &TypedTensor<T>,
378    on_false: &TypedTensor<T>,
379    backend: &mut impl TensorBackend,
380) -> Result<TypedTensor<T>> {
381    let (condition, on_true, on_false) =
382        broadcast_ternary_read(condition, on_true, on_false, backend)?;
383    let out = backend.with_backend_session(|exec| {
384        exec.select_read(
385            condition.tensor_read(),
386            on_true.tensor_read(),
387            on_false.tensor_read(),
388        )
389    })?;
390    into_typed_result("where_select", out)
391}
392
393/// Clamp values elementwise between lower and upper bounds.
394///
395/// # Examples
396///
397/// ```rust
398/// # use tenferro_cpu::CpuBackend;
399/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
400/// # let mut backend = CpuBackend::new();
401/// # let x = TypedTensor::<f64>::from_vec_col_major(vec![2], vec![-2.0, 4.0]).unwrap();
402/// # let lower = TypedTensor::<f64>::from_vec_col_major(vec![], vec![0.0]).unwrap();
403/// # let upper = TypedTensor::<f64>::from_vec_col_major(vec![], vec![3.0]).unwrap();
404/// let z = x.clamp(&lower, &upper, &mut backend).unwrap();
405/// ```
406fn clamp<T: TensorScalar>(
407    input: &TypedTensor<T>,
408    lower: &TypedTensor<T>,
409    upper: &TypedTensor<T>,
410    backend: &mut impl TensorBackend,
411) -> Result<TypedTensor<T>> {
412    let (input, lower, upper) = broadcast_ternary_read(input, lower, upper, backend)?;
413    let out = backend.with_backend_session(|exec| {
414        exec.clamp_read(
415            input.tensor_read(),
416            lower.tensor_read(),
417            upper.tensor_read(),
418        )
419    })?;
420    into_typed_result("clamp", out)
421}
422
423/// Matrix multiplication helper for rank-2 typed tensors.
424///
425/// This contracts the last dimension of `a` with the first dimension of `b`.
426///
427/// # Examples
428///
429/// ```rust
430/// # use tenferro_cpu::CpuBackend;
431/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
432/// # let mut backend = CpuBackend::new();
433/// # let a = TypedTensor::<f64>::from_vec_col_major(vec![2, 3], vec![1.0; 6]).unwrap();
434/// # let b = TypedTensor::<f64>::from_vec_col_major(vec![3, 2], vec![1.0; 6]).unwrap();
435/// let c = a.matmul(&b, &mut backend).unwrap();
436/// ```
437fn matmul<T: TensorScalar>(
438    a: &TypedTensor<T>,
439    b: &TypedTensor<T>,
440    backend: &mut impl TensorBackend,
441) -> Result<TypedTensor<T>> {
442    let config = DotGeneralConfig {
443        lhs_contracting_dims: vec![a.shape().len() - 1],
444        rhs_contracting_dims: vec![0],
445        lhs_batch_dims: vec![],
446        rhs_batch_dims: vec![],
447    };
448    let out = backend.with_backend_session(|exec| {
449        exec.dot_general_read(T::tensor_read(a), T::tensor_read(b), &config)
450    })?;
451    into_typed_result("matmul", out)
452}
453
454/// Sum elements across one or more axes.
455///
456/// # Examples
457///
458/// ```rust
459/// # use tenferro_cpu::CpuBackend;
460/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
461/// # let mut backend = CpuBackend::new();
462/// let x = TypedTensor::<f64>::from_vec_col_major(
463///     vec![2, 3],
464///     vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0],
465/// )?;
466/// let row_sums = x.reduce_sum(&[1], &mut backend).unwrap();
467/// assert_eq!(row_sums.host_data()?, &[6.0, 15.0]);
468/// # Ok::<(), tenferro_runtime::Error>(())
469/// ```
470fn reduce_sum<T: TensorScalar>(
471    input: &TypedTensor<T>,
472    axes: &[usize],
473    backend: &mut impl TensorBackend,
474) -> Result<TypedTensor<T>> {
475    let out =
476        backend.with_backend_session(|exec| exec.reduce_sum_read(T::tensor_read(input), axes))?;
477    into_typed_result("reduce_sum", out)
478}
479
480/// Reshape a typed tensor through the backend structural operation.
481///
482/// # Examples
483///
484/// ```rust
485/// # use tenferro_cpu::CpuBackend;
486/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
487/// # let mut backend = CpuBackend::new();
488/// let x = TypedTensor::<f64>::from_vec_col_major(vec![2, 3], vec![1.0; 6]).unwrap();
489/// let y = x.reshape(&[3, 2], &mut backend).unwrap();
490/// assert_eq!(y.shape(), &[3, 2]);
491/// ```
492fn reshape<T: TensorScalar>(
493    input: &TypedTensor<T>,
494    shape: &[usize],
495    backend: &mut impl TensorBackend,
496) -> Result<TypedTensor<T>> {
497    let out =
498        backend.with_backend_session(|exec| exec.reshape_read(T::tensor_read(input), shape))?;
499    into_typed_result("reshape", out)
500}
501
502/// Permute typed tensor axes through the backend structural operation.
503///
504/// # Examples
505///
506/// ```rust
507/// # use tenferro_cpu::CpuBackend;
508/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
509/// # let mut backend = CpuBackend::new();
510/// let x = TypedTensor::<f64>::from_vec_col_major(vec![2, 3], vec![1.0; 6]).unwrap();
511/// let y = x.transpose(&[1, 0], &mut backend).unwrap();
512/// assert_eq!(y.shape(), &[3, 2]);
513/// ```
514fn transpose<T: TensorScalar>(
515    input: &TypedTensor<T>,
516    perm: &[usize],
517    backend: &mut impl TensorBackend,
518) -> Result<TypedTensor<T>> {
519    let out =
520        backend.with_backend_session(|exec| exec.transpose_read(T::tensor_read(input), perm))?;
521    into_typed_result("transpose", out)
522}
523
524/// Broadcast a typed tensor into a larger shape.
525///
526/// `dims` maps each input axis to its output axis, following the concrete
527/// backend `broadcast_in_dim` contract.
528///
529/// # Examples
530///
531/// ```rust
532/// # use tenferro_cpu::CpuBackend;
533/// use tenferro_runtime::{TypedTensor, TypedTensorOpsExt};
534/// # let mut backend = CpuBackend::new();
535/// let row = TypedTensor::<f64>::from_vec_col_major(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
536/// let matrix = row.broadcast_in_dim(&[2, 3], &[1], &mut backend).unwrap();
537/// assert_eq!(matrix.shape(), &[2, 3]);
538/// ```
539fn broadcast_in_dim<T: TensorScalar>(
540    input: &TypedTensor<T>,
541    shape: &[usize],
542    dims: &[usize],
543    backend: &mut impl TensorBackend,
544) -> Result<TypedTensor<T>> {
545    let out = backend.with_backend_session(|exec| {
546        exec.broadcast_in_dim_read(T::tensor_read(input), shape, dims)
547    })?;
548    into_typed_result("broadcast_in_dim", out)
549}
550
551enum ReadInput<'a> {
552    Borrowed(TensorRead<'a>),
553    Owned(Tensor),
554}
555
556impl ReadInput<'_> {
557    fn tensor_read(&self) -> TensorRead<'_> {
558        match self {
559            Self::Borrowed(read) => read.clone(),
560            Self::Owned(tensor) => TensorRead::from_tensor(tensor),
561        }
562    }
563}
564
565fn broadcast_binary_read<'a, T: TensorScalar>(
566    lhs: &'a TypedTensor<T>,
567    rhs: &'a TypedTensor<T>,
568    backend: &mut impl TensorBackend,
569) -> Result<(ReadInput<'a>, ReadInput<'a>)> {
570    let shape = broadcast_shape(lhs.shape(), rhs.shape()).map_err(broadcast_error)?;
571    Ok((
572        broadcast_to_read(lhs, &shape, backend)?,
573        broadcast_to_read(rhs, &shape, backend)?,
574    ))
575}
576
577fn broadcast_ternary_read<'a, C: TensorScalar, T: TensorScalar>(
578    first: &'a TypedTensor<C>,
579    second: &'a TypedTensor<T>,
580    third: &'a TypedTensor<T>,
581    backend: &mut impl TensorBackend,
582) -> Result<(ReadInput<'a>, ReadInput<'a>, ReadInput<'a>)> {
583    let shape = broadcast_shapes([first.shape(), second.shape(), third.shape()])
584        .map_err(broadcast_error)?;
585    Ok((
586        broadcast_to_read(first, &shape, backend)?,
587        broadcast_to_read(second, &shape, backend)?,
588        broadcast_to_read(third, &shape, backend)?,
589    ))
590}
591
592fn broadcast_to_read<'a, T: TensorScalar>(
593    input: &'a TypedTensor<T>,
594    target_shape: &[usize],
595    backend: &mut impl TensorBackend,
596) -> Result<ReadInput<'a>> {
597    if input.shape() == target_shape {
598        return Ok(ReadInput::Borrowed(T::tensor_read(input)));
599    }
600
601    let plan = broadcast_input_plan(input.shape(), target_shape).map_err(broadcast_error)?;
602    let source = if plan.source_shape == input.shape() {
603        ReadInput::Borrowed(T::tensor_read(input))
604    } else {
605        let reshaped = backend.with_backend_session(|exec| {
606            exec.reshape_read(T::tensor_read(input), &plan.source_shape)
607        })?;
608        ReadInput::Owned(reshaped)
609    };
610    let out = backend.with_backend_session(|exec| {
611        exec.broadcast_in_dim_read(source.tensor_read(), target_shape, &plan.dims)
612    })?;
613    Ok(ReadInput::Owned(out))
614}
615
616fn broadcast_error(err: impl std::fmt::Display) -> Error {
617    Error::backend_failure("broadcast", err.to_string())
618}
619
620fn into_typed_result<T: TensorScalar>(op: &'static str, tensor: Tensor) -> Result<TypedTensor<T>> {
621    let actual = tensor.dtype();
622    T::into_typed(tensor).map_err(|_| Error::DTypeMismatch {
623        op,
624        lhs: T::dtype(),
625        rhs: actual,
626    })
627}