1use strided_kernel::{col_major_strides, reduce_axis, zip_map2_into, StridedArray, StridedView};
2use tenferro_algebra::Semiring;
3
4use crate::config::{
5 CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
6};
7use crate::{Buffer, Tensor, TypedTensor};
8
/// A compiled plan for a chain of elementwise operations fused into a
/// single pass (consumed by `execute_elementwise_fusion`).
///
/// NOTE(review): `outputs` and each instruction's `inputs` appear to be
/// value indices over the fusion's inputs plus prior instruction results —
/// confirm against the planner that builds these plans.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionPlan {
    /// Element dtype of the fused computation.
    pub dtype: crate::DType,
    /// Number of input tensors the plan expects.
    pub n_inputs: usize,
    /// Value indices to materialize as the fusion's results.
    pub outputs: Vec<usize>,
    /// Instruction sequence, executed in order.
    pub ops: Vec<ElementwiseFusionInst>,
}
18
/// One step of an [`ElementwiseFusionPlan`]: a single elementwise operation
/// applied to previously defined values.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionInst {
    /// Which elementwise operation this step performs.
    pub op: ElementwiseFusionOp,
    /// Value indices of this instruction's operands.
    pub inputs: Vec<usize>,
}
26
/// Elementwise operations that may appear in a fusion plan.
///
/// Variants correspond to the like-named elementwise methods on
/// [`TensorExec`] / [`TensorBackend`] (e.g. `Multiply` ↔ `mul`).
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum ElementwiseFusionOp {
    Add,
    Multiply,
    Negate,
    Conj,
    Divide,
    Abs,
    Maximum,
    Minimum,
    /// Comparison parameterized by a [`CompareDir`].
    Compare(CompareDir),
    Select,
    Clamp,
    Exp,
    Log,
    Sin,
    Cos,
    Tanh,
    Sqrt,
    Rsqrt,
    Pow,
    Expm1,
    Log1p,
}
53
/// Borrows a tensor's host buffer as a column-major [`StridedView`].
///
/// Only `Buffer::Host` is supported here:
/// - `Buffer::Backend` is not implemented yet (`todo!`).
/// - `Buffer::Cubecl` (GPU) panics with a hint to download to host first.
pub(crate) fn typed_view<T: Copy>(tensor: &TypedTensor<T>) -> StridedView<'_, T> {
    match &tensor.buffer {
        Buffer::Host(data) => {
            // Strides are recomputed from the shape; the `expect` documents
            // the invariant that host buffers are dense column-major.
            let strides = col_major_strides(&tensor.shape);
            StridedView::new(data, &tensor.shape, &strides, 0).expect("contiguous host tensor")
        }
        Buffer::Backend(_) => todo!("typed_view for backend buffers"),
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    }
}
65
66pub(crate) fn typed_array<T: Clone>(shape: &[usize], fill: T) -> StridedArray<T> {
67 let total: usize = shape.iter().product();
68 let strides = col_major_strides(shape);
69 StridedArray::from_parts(vec![fill; total], shape, &strides, 0)
70 .expect("column-major output array")
71}
72
73pub(crate) fn tensor_from_array<T: Clone>(array: StridedArray<T>) -> TypedTensor<T> {
74 TypedTensor::from_vec(array.dims().to_vec(), array.into_data())
75}
76
77fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
78 crate::Error::BackendFailure {
79 op,
80 message: err.to_string(),
81 }
82}
83
84fn validate_axis_list(
85 op: &'static str,
86 role: &'static str,
87 axes: &[usize],
88 rank: usize,
89) -> crate::Result<()> {
90 let mut seen = vec![false; rank];
91 for &axis in axes {
92 if axis >= rank {
93 return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
94 }
95 if seen[axis] {
96 return Err(crate::Error::DuplicateAxis { op, axis, role });
97 }
98 seen[axis] = true;
99 }
100 Ok(())
101}
102
103fn validate_binary_shapes(op: &'static str, lhs: &[usize], rhs: &[usize]) -> crate::Result<()> {
104 if lhs != rhs {
105 return Err(crate::Error::ShapeMismatch {
106 op,
107 lhs: lhs.to_vec(),
108 rhs: rhs.to_vec(),
109 });
110 }
111 Ok(())
112}
113
/// Object-safe execution surface used while a backend session is active.
///
/// Mirrors the operation set of [`TensorBackend`] (note: `solve` is not
/// forwarded here) so callers can operate through a `&mut dyn TensorExec`;
/// see `default_exec_session`, which adapts any backend to this trait.
pub trait TensorExec {
    // -- Elementwise arithmetic and selection --
    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
    fn select(
        &mut self,
        pred: &Tensor,
        on_true: &Tensor,
        on_false: &Tensor,
    ) -> crate::Result<Tensor>;
    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;

    // -- Transcendental / power functions --
    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor>;

    // -- Shape / layout manipulation --
    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
    fn broadcast_in_dim(
        &mut self,
        input: &Tensor,
        shape: &[usize],
        dims: &[usize],
    ) -> crate::Result<Tensor>;
    fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
    fn extract_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    fn embed_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;

    // -- Reductions --
    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // -- Contraction --
    fn dot_general(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor>;

    // -- Indexed data movement --
    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor>;
    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor>;
    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor>;
    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // -- Linear algebra --
    fn cholesky(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn triangular_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> crate::Result<Tensor>;
    fn lu(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn svd(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn qr(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eigh(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eig(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;

    /// Hands a no-longer-needed tensor back to the executor.
    fn reclaim_buffer(&mut self, tensor: Tensor);

    /// Hook for running a fused elementwise program in one pass.
    ///
    /// The default returns `Ok(None)`, signalling "fusion not supported";
    /// presumably the caller then falls back to executing the ops one by
    /// one — confirm with call sites.
    #[doc(hidden)]
    fn execute_elementwise_fusion(
        &mut self,
        _inputs: &[&Tensor],
        _plan: &ElementwiseFusionPlan,
    ) -> crate::Result<Option<Vec<Tensor>>> {
        Ok(None)
    }
}
250
/// Borrows a [`TensorBackend`] and exposes it through the object-safe
/// [`TensorExec`] trait by forwarding every call (see the `impl` below).
struct BackendExecAdapter<'a, B: TensorBackend + ?Sized> {
    backend: &'a mut B,
}
254
/// Expands a list of `name(arg: Ty, …) -> Ret;` signatures into trait
/// methods whose bodies delegate to the identically named method on
/// `self.backend`.
macro_rules! forward_exec_to_backend {
    ($($name:ident($($arg:ident : $argty:ty),*) -> $ret:ty;)+) => {
        $(
            fn $name(&mut self, $($arg: $argty),*) -> $ret {
                self.backend.$name($($arg),*)
            }
        )+
    };
}
264
// `BackendExecAdapter` implements the executor trait purely by delegation:
// every `TensorExec` method forwards to the same-named `TensorBackend`
// method on the wrapped backend (bodies generated by the macro above).
impl<B: TensorBackend + ?Sized> TensorExec for BackendExecAdapter<'_, B> {
    forward_exec_to_backend! {
        add(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        mul(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        neg(input: &Tensor) -> crate::Result<Tensor>;
        conj(input: &Tensor) -> crate::Result<Tensor>;
        div(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        abs(input: &Tensor) -> crate::Result<Tensor>;
        sign(input: &Tensor) -> crate::Result<Tensor>;
        maximum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        minimum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        compare(lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
        select(pred: &Tensor, on_true: &Tensor, on_false: &Tensor) -> crate::Result<Tensor>;
        clamp(input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;
        exp(input: &Tensor) -> crate::Result<Tensor>;
        log(input: &Tensor) -> crate::Result<Tensor>;
        sin(input: &Tensor) -> crate::Result<Tensor>;
        cos(input: &Tensor) -> crate::Result<Tensor>;
        tanh(input: &Tensor) -> crate::Result<Tensor>;
        sqrt(input: &Tensor) -> crate::Result<Tensor>;
        rsqrt(input: &Tensor) -> crate::Result<Tensor>;
        pow(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        expm1(input: &Tensor) -> crate::Result<Tensor>;
        log1p(input: &Tensor) -> crate::Result<Tensor>;
        transpose(input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
        reshape(input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
        broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> crate::Result<Tensor>;
        convert(input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
        extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor>;
        embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor>;
        tril(input: &Tensor, k: i64) -> crate::Result<Tensor>;
        triu(input: &Tensor, k: i64) -> crate::Result<Tensor>;
        reduce_sum(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        reduce_prod(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        reduce_max(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        reduce_min(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        dot_general(lhs: &Tensor, rhs: &Tensor, config: &DotGeneralConfig) -> crate::Result<Tensor>;
        gather(operand: &Tensor, start_indices: &Tensor, config: &GatherConfig) -> crate::Result<Tensor>;
        scatter(
            operand: &Tensor,
            scatter_indices: &Tensor,
            updates: &Tensor,
            config: &ScatterConfig
        ) -> crate::Result<Tensor>;
        slice(input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
        dynamic_slice(input: &Tensor, starts: &Tensor, slice_sizes: &[usize]) -> crate::Result<Tensor>;
        pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
        concatenate(inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
        reverse(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        cholesky(input: &Tensor) -> crate::Result<Tensor>;
        triangular_solve(
            a: &Tensor,
            b: &Tensor,
            left_side: bool,
            lower: bool,
            transpose_a: bool,
            unit_diagonal: bool
        ) -> crate::Result<Tensor>;
        lu(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        svd(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        qr(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        eigh(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        eig(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        reclaim_buffer(tensor: Tensor) -> ();
        execute_elementwise_fusion(
            inputs: &[&Tensor],
            plan: &ElementwiseFusionPlan
        ) -> crate::Result<Option<Vec<Tensor>>>;
    }
}
335
336pub fn default_exec_session<B: TensorBackend + ?Sized, R: Send>(
351 backend: &mut B,
352 f: impl FnOnce(&mut dyn TensorExec) -> R + Send,
353) -> R {
354 let mut adapter = BackendExecAdapter { backend };
355 f(&mut adapter)
356}
357
/// Full operation set a tensor backend must provide.
///
/// Same surface as [`TensorExec`] plus `solve` and the session/transfer
/// default methods below. The defaults suit host-resident backends;
/// device backends are expected to override the transfer hooks.
pub trait TensorBackend {
    // -- Elementwise arithmetic and selection --
    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
    fn select(
        &mut self,
        pred: &Tensor,
        on_true: &Tensor,
        on_false: &Tensor,
    ) -> crate::Result<Tensor>;
    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;

    // -- Transcendental / power functions --
    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor>;

    // -- Shape / layout manipulation --
    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
    fn broadcast_in_dim(
        &mut self,
        input: &Tensor,
        shape: &[usize],
        dims: &[usize],
    ) -> crate::Result<Tensor>;
    fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
    fn extract_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    fn embed_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;

    // -- Reductions --
    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // -- Contraction --
    fn dot_general(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor>;

    // -- Indexed data movement --
    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor>;
    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor>;
    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor>;
    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // -- Linear algebra --
    fn cholesky(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn triangular_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> crate::Result<Tensor>;
    fn lu(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn svd(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn qr(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eigh(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eig(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn solve(&mut self, a: &Tensor, b: &Tensor) -> crate::Result<Tensor>;

    /// Runs `f` against an object-safe [`TensorExec`] view of this backend.
    /// The default wraps `self` via `default_exec_session`; backends with
    /// their own session management can override.
    fn with_exec_session<R: Send>(&mut self, f: impl FnOnce(&mut dyn TensorExec) -> R + Send) -> R {
        default_exec_session(self, f)
    }

    /// Returns a host-resident copy of `tensor`. The default is a plain
    /// clone, which presumes the tensor already lives in host memory —
    /// device backends should override. (See `typed_view`, which panics on
    /// non-host buffers.)
    fn download_to_host(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
        Ok(tensor.clone())
    }

    /// Transfers a host tensor into this backend's memory. The default is
    /// a no-op clone for host-resident backends; device backends should
    /// override.
    fn upload_host_tensor(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
        Ok(tensor.clone())
    }

    /// Hands a no-longer-needed tensor back to the backend (presumably so
    /// it can recycle the buffer). The default simply drops it.
    fn reclaim_buffer(&mut self, _tensor: Tensor) {}

    /// Hook for running a fused elementwise program in one pass.
    ///
    /// The default returns `Ok(None)`, signalling "fusion not supported";
    /// presumably the caller then falls back to executing the ops one by
    /// one — confirm with call sites.
    #[doc(hidden)]
    fn execute_elementwise_fusion(
        &mut self,
        _inputs: &[&Tensor],
        _plan: &ElementwiseFusionPlan,
    ) -> crate::Result<Option<Vec<Tensor>>> {
        Ok(None)
    }
}
553
/// Backend operations parameterized over a [`Semiring`] algebra `Alg`.
///
/// Only `batched_gemm` is required; `add`, `mul`, and `reduce_sum` have
/// default host-side implementations later in this trait.
pub trait SemiringBackend<Alg: Semiring> {
    /// Batched matrix multiplication over the semiring, configured with
    /// the same [`DotGeneralConfig`] type used by `dot_general`.
    fn batched_gemm(
        &mut self,
        lhs: &TypedTensor<Alg::Scalar>,
        rhs: &TypedTensor<Alg::Scalar>,
        config: &DotGeneralConfig,
    ) -> crate::Result<TypedTensor<Alg::Scalar>>;
573
574 fn add(
575 &mut self,
576 lhs: &TypedTensor<Alg::Scalar>,
577 rhs: &TypedTensor<Alg::Scalar>,
578 ) -> crate::Result<TypedTensor<Alg::Scalar>> {
579 validate_binary_shapes("add", &lhs.shape, &rhs.shape)?;
580 let mut out = typed_array(&lhs.shape, Alg::zero());
581 zip_map2_into(
582 &mut out.view_mut(),
583 &typed_view(lhs),
584 &typed_view(rhs),
585 |x, y| Alg::add(x, y),
586 )
587 .map_err(|err| backend_failure("add", err))?;
588 Ok(tensor_from_array(out))
589 }
590
591 fn mul(
592 &mut self,
593 lhs: &TypedTensor<Alg::Scalar>,
594 rhs: &TypedTensor<Alg::Scalar>,
595 ) -> crate::Result<TypedTensor<Alg::Scalar>> {
596 validate_binary_shapes("mul", &lhs.shape, &rhs.shape)?;
597 let mut out = typed_array(&lhs.shape, Alg::zero());
598 zip_map2_into(
599 &mut out.view_mut(),
600 &typed_view(lhs),
601 &typed_view(rhs),
602 |x, y| Alg::mul(x, y),
603 )
604 .map_err(|err| backend_failure("mul", err))?;
605 Ok(tensor_from_array(out))
606 }
607
608 fn reduce_sum(
609 &mut self,
610 input: &TypedTensor<Alg::Scalar>,
611 axes: &[usize],
612 ) -> crate::Result<TypedTensor<Alg::Scalar>> {
613 validate_axis_list("reduce_sum", "axes", axes, input.shape.len())?;
614 if axes.is_empty() {
615 return Ok(input.clone());
616 }
617
618 let output_shape: Vec<usize> = input
619 .shape
620 .iter()
621 .enumerate()
622 .filter(|(axis, _)| !axes.contains(axis))
623 .map(|(_, &dim)| dim)
624 .collect();
625
626 let strides = col_major_strides(&input.shape);
627 let mut current =
628 StridedArray::from_parts(input.host_data().to_vec(), &input.shape, &strides, 0)
629 .map_err(|err| backend_failure("reduce_sum", err))?;
630
631 let mut sorted_axes = axes.to_vec();
632 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
633 for axis in sorted_axes {
634 current = reduce_axis(
635 ¤t.view(),
636 axis,
637 |x| x,
638 |a, b| Alg::add(a, b),
639 Alg::zero(),
640 )
641 .map_err(|err| backend_failure("reduce_sum", err))?;
642 }
643 Ok(TypedTensor::from_vec(output_shape, current.into_data()))
644 }
645}