1use std::ops::{Add, Mul};
2
3use num_traits::{Float, One, Zero};
4use strided_kernel::reduce_axis;
5
6use super::{typed_host_data, typed_view, typed_view_from_view};
7use tenferro_tensor::{Tensor, TensorRank, TensorRead, TensorView, TypedTensor, TypedTensorView};
8
9fn validate_axes(op: &'static str, axes: &[usize], rank: usize) -> crate::Result<()> {
10 let mut seen = vec![false; rank];
11 for &axis in axes {
12 if axis >= rank {
13 return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
14 }
15 if seen[axis] {
16 return Err(crate::Error::DuplicateAxis {
17 op,
18 axis,
19 role: "axes",
20 });
21 }
22 seen[axis] = true;
23 }
24 Ok(())
25}
26
27fn ensure_host_tensor(op: &'static str, input: &Tensor) -> crate::Result<()> {
28 macro_rules! ensure {
29 ($tensor:expr) => {{
30 typed_host_data(op, $tensor)?;
31 Ok(())
32 }};
33 }
34
35 match input {
36 Tensor::F32(t) => ensure!(t),
37 Tensor::F64(t) => ensure!(t),
38 Tensor::I32(t) => ensure!(t),
39 Tensor::I64(t) => ensure!(t),
40 Tensor::Bool(t) => ensure!(t),
41 Tensor::C32(t) => ensure!(t),
42 Tensor::C64(t) => ensure!(t),
43 }
44}
45
46fn validate_reduced_axes_nonempty(
47 op: &'static str,
48 shape: &[usize],
49 axes: &[usize],
50) -> crate::Result<()> {
51 validate_axes(op, axes, shape.len())?;
52 for &axis in axes {
53 if shape[axis] == 0 {
54 return Err(crate::Error::InvalidConfig {
55 op,
56 message: format!("cannot reduce over zero-length axis {axis}"),
57 });
58 }
59 }
60 Ok(())
61}
62
63fn nan_propagating_max<T: Float>(a: T, b: T) -> T {
64 if a.is_nan() || b.is_nan() {
65 T::nan()
66 } else {
67 a.max(b)
68 }
69}
70
71fn nan_propagating_min<T: Float>(a: T, b: T) -> T {
72 if a.is_nan() || b.is_nan() {
73 T::nan()
74 } else {
75 a.min(b)
76 }
77}
78
79pub fn reduce_sum(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
80 match input {
81 Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_sum(t, axes)?)),
82 Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_sum(t, axes)?)),
83 Tensor::I32(t) => Ok(Tensor::I32(typed_reduce_sum(t, axes)?)),
84 Tensor::I64(t) => Ok(Tensor::I64(typed_reduce_sum(t, axes)?)),
85 Tensor::Bool(_) => Err(crate::Error::backend_failure(
86 "reduce_sum",
87 "unsupported dtype Bool",
88 )),
89 Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_sum(t, axes)?)),
90 Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_sum(t, axes)?)),
91 }
92}
93
94pub(crate) fn reduce_sum_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
95 match input {
96 TensorRead::Tensor(input) => {
97 ensure_host_tensor("reduce_sum", input)?;
98 reduce_sum(input, axes)
99 }
100 TensorRead::View(TensorView::F32(t)) => Ok(Tensor::F32(typed_reduce_view(
101 &t,
102 axes,
103 |x| x,
104 |a, b| a + b,
105 f32::zero(),
106 "reduce_sum",
107 )?)),
108 TensorRead::View(TensorView::F64(t)) => Ok(Tensor::F64(typed_reduce_view(
109 &t,
110 axes,
111 |x| x,
112 |a, b| a + b,
113 f64::zero(),
114 "reduce_sum",
115 )?)),
116 TensorRead::View(TensorView::I32(t)) => Ok(Tensor::I32(typed_reduce_view(
117 &t,
118 axes,
119 |x| x,
120 |a, b| a + b,
121 i32::zero(),
122 "reduce_sum",
123 )?)),
124 TensorRead::View(TensorView::I64(t)) => Ok(Tensor::I64(typed_reduce_view(
125 &t,
126 axes,
127 |x| x,
128 |a, b| a + b,
129 i64::zero(),
130 "reduce_sum",
131 )?)),
132 TensorRead::View(TensorView::Bool(_)) => Err(crate::Error::backend_failure(
133 "reduce_sum",
134 "unsupported dtype Bool",
135 )),
136 TensorRead::View(TensorView::C32(t)) => Ok(Tensor::C32(typed_reduce_view(
137 &t,
138 axes,
139 |x| x,
140 |a, b| a + b,
141 num_complex::Complex32::zero(),
142 "reduce_sum",
143 )?)),
144 TensorRead::View(TensorView::C64(t)) => Ok(Tensor::C64(typed_reduce_view(
145 &t,
146 axes,
147 |x| x,
148 |a, b| a + b,
149 num_complex::Complex64::zero(),
150 "reduce_sum",
151 )?)),
152 }
153}
154
155pub fn reduce_prod(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
156 match input {
157 Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_prod(t, axes)?)),
158 Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_prod(t, axes)?)),
159 Tensor::I32(t) => Ok(Tensor::I32(typed_reduce_prod(t, axes)?)),
160 Tensor::I64(t) => Ok(Tensor::I64(typed_reduce_prod(t, axes)?)),
161 Tensor::Bool(_) => Err(crate::Error::backend_failure(
162 "reduce_prod",
163 "unsupported dtype Bool",
164 )),
165 Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_prod(t, axes)?)),
166 Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_prod(t, axes)?)),
167 }
168}
169
170pub(crate) fn reduce_prod_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
171 match input {
172 TensorRead::Tensor(input) => {
173 ensure_host_tensor("reduce_prod", input)?;
174 reduce_prod(input, axes)
175 }
176 TensorRead::View(TensorView::F32(t)) => Ok(Tensor::F32(typed_reduce_view(
177 &t,
178 axes,
179 |x| x,
180 |a, b| a * b,
181 f32::one(),
182 "reduce_prod",
183 )?)),
184 TensorRead::View(TensorView::F64(t)) => Ok(Tensor::F64(typed_reduce_view(
185 &t,
186 axes,
187 |x| x,
188 |a, b| a * b,
189 f64::one(),
190 "reduce_prod",
191 )?)),
192 TensorRead::View(TensorView::I32(t)) => Ok(Tensor::I32(typed_reduce_view(
193 &t,
194 axes,
195 |x| x,
196 |a, b| a * b,
197 i32::one(),
198 "reduce_prod",
199 )?)),
200 TensorRead::View(TensorView::I64(t)) => Ok(Tensor::I64(typed_reduce_view(
201 &t,
202 axes,
203 |x| x,
204 |a, b| a * b,
205 i64::one(),
206 "reduce_prod",
207 )?)),
208 TensorRead::View(TensorView::Bool(_)) => Err(crate::Error::backend_failure(
209 "reduce_prod",
210 "unsupported dtype Bool",
211 )),
212 TensorRead::View(TensorView::C32(t)) => Ok(Tensor::C32(typed_reduce_view(
213 &t,
214 axes,
215 |x| x,
216 |a, b| a * b,
217 num_complex::Complex32::one(),
218 "reduce_prod",
219 )?)),
220 TensorRead::View(TensorView::C64(t)) => Ok(Tensor::C64(typed_reduce_view(
221 &t,
222 axes,
223 |x| x,
224 |a, b| a * b,
225 num_complex::Complex64::one(),
226 "reduce_prod",
227 )?)),
228 }
229}
230
231pub fn reduce_max(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
232 validate_axes("reduce_max", axes, input.shape().len())?;
233 if axes.is_empty() {
234 return Ok(input.clone());
235 }
236
237 match input {
238 Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_max(tensor, axes)?)),
239 Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_max(tensor, axes)?)),
240 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) | Tensor::C32(_) | Tensor::C64(_) => {
241 Err(crate::Error::backend_failure(
242 "reduce_max",
243 format!("unsupported dtype {:?}", input.dtype()),
244 ))
245 }
246 }
247}
248
249pub(crate) fn reduce_max_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
250 validate_axes("reduce_max", axes, input.shape().len())?;
251 if axes.is_empty() {
252 return match input {
253 TensorRead::Tensor(input) => {
254 ensure_host_tensor("reduce_max", input)?;
255 Ok(input.clone())
256 }
257 TensorRead::View(input) => view_to_contiguous_tensor(input),
258 };
259 }
260
261 match input {
262 TensorRead::Tensor(input) => {
263 ensure_host_tensor("reduce_max", input)?;
264 reduce_max(input, axes)
265 }
266 TensorRead::View(TensorView::F32(t)) => {
267 validate_reduced_axes_nonempty("reduce_max", t.shape(), axes)?;
268 Ok(Tensor::F32(typed_reduce_view(
269 &t,
270 axes,
271 |x| x,
272 nan_propagating_max,
273 f32::neg_infinity(),
274 "reduce_max",
275 )?))
276 }
277 TensorRead::View(TensorView::F64(t)) => {
278 validate_reduced_axes_nonempty("reduce_max", t.shape(), axes)?;
279 Ok(Tensor::F64(typed_reduce_view(
280 &t,
281 axes,
282 |x| x,
283 nan_propagating_max,
284 f64::neg_infinity(),
285 "reduce_max",
286 )?))
287 }
288 view => Err(crate::Error::backend_failure(
289 "reduce_max",
290 format!("unsupported dtype {:?}", view.dtype()),
291 )),
292 }
293}
294
295pub fn reduce_min(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
296 validate_axes("reduce_min", axes, input.shape().len())?;
297 if axes.is_empty() {
298 return Ok(input.clone());
299 }
300
301 match input {
302 Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_min(tensor, axes)?)),
303 Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_min(tensor, axes)?)),
304 Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) | Tensor::C32(_) | Tensor::C64(_) => {
305 Err(crate::Error::backend_failure(
306 "reduce_min",
307 format!("unsupported dtype {:?}", input.dtype()),
308 ))
309 }
310 }
311}
312
313pub(crate) fn reduce_min_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
314 validate_axes("reduce_min", axes, input.shape().len())?;
315 if axes.is_empty() {
316 return match input {
317 TensorRead::Tensor(input) => {
318 ensure_host_tensor("reduce_min", input)?;
319 Ok(input.clone())
320 }
321 TensorRead::View(input) => view_to_contiguous_tensor(input),
322 };
323 }
324
325 match input {
326 TensorRead::Tensor(input) => {
327 ensure_host_tensor("reduce_min", input)?;
328 reduce_min(input, axes)
329 }
330 TensorRead::View(TensorView::F32(t)) => {
331 validate_reduced_axes_nonempty("reduce_min", t.shape(), axes)?;
332 Ok(Tensor::F32(typed_reduce_view(
333 &t,
334 axes,
335 |x| x,
336 nan_propagating_min,
337 f32::infinity(),
338 "reduce_min",
339 )?))
340 }
341 TensorRead::View(TensorView::F64(t)) => {
342 validate_reduced_axes_nonempty("reduce_min", t.shape(), axes)?;
343 Ok(Tensor::F64(typed_reduce_view(
344 &t,
345 axes,
346 |x| x,
347 nan_propagating_min,
348 f64::infinity(),
349 "reduce_min",
350 )?))
351 }
352 view => Err(crate::Error::backend_failure(
353 "reduce_min",
354 format!("unsupported dtype {:?}", view.dtype()),
355 )),
356 }
357}
358
359fn typed_reduce<T, M, R>(
360 input: &TypedTensor<T>,
361 axes: &[usize],
362 map_fn: M,
363 reduce_fn: R,
364 init: T,
365 label: &'static str,
366) -> crate::Result<TypedTensor<T>>
367where
368 T: Copy + Clone + Send + Sync,
369 M: Fn(T) -> T + Copy + Sync,
370 R: Fn(T, T) -> T + Copy + Sync,
371{
372 validate_axes(label, axes, input.shape().len())?;
373 if axes.is_empty() {
374 return Ok(input.clone());
375 }
376
377 let output_shape: Vec<usize> = input
378 .shape()
379 .iter()
380 .enumerate()
381 .filter(|(axis, _)| !axes.contains(axis))
382 .map(|(_, &dim)| dim)
383 .collect();
384
385 let mut sorted_axes = axes.to_vec();
386 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
387 let Some((&first_axis, remaining_axes)) = sorted_axes.split_first() else {
388 return Ok(input.clone());
389 };
390
391 let input_view = typed_view(label, input)?;
392 let mut current = reduce_axis(&input_view, first_axis, map_fn, reduce_fn, init)
393 .map_err(|err| crate::Error::backend_failure(label, err))?;
394
395 for &axis in remaining_axes {
396 current = reduce_axis(¤t.view(), axis, map_fn, reduce_fn, init)
397 .map_err(|err| crate::Error::backend_failure(label, err))?;
398 }
399
400 TypedTensor::from_vec_col_major(output_shape, current.into_data())
401}
402
403pub(crate) fn typed_reduce_view<T, M, R, TR>(
404 input: &TypedTensorView<'_, T, TR>,
405 axes: &[usize],
406 map_fn: M,
407 reduce_fn: R,
408 init: T,
409 label: &'static str,
410) -> crate::Result<TypedTensor<T>>
411where
412 T: Copy + Clone + Send + Sync + 'static,
413 M: Fn(T) -> T + Copy + Sync,
414 R: Fn(T, T) -> T + Copy + Sync,
415 TR: TensorRank,
416{
417 validate_axes(label, axes, input.shape().len())?;
418 if axes.is_empty() {
419 return view_to_dyn_contiguous(input);
420 }
421
422 let output_shape: Vec<usize> = input
423 .shape()
424 .iter()
425 .enumerate()
426 .filter(|(axis, _)| !axes.contains(axis))
427 .map(|(_, &dim)| dim)
428 .collect();
429
430 let mut sorted_axes = axes.to_vec();
431 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
432 let Some((&first_axis, remaining_axes)) = sorted_axes.split_first() else {
433 return view_to_dyn_contiguous(input);
434 };
435
436 let input_view = typed_view_from_view(label, input)?;
437 let mut current = reduce_axis(&input_view, first_axis, map_fn, reduce_fn, init)
438 .map_err(|err| crate::Error::backend_failure(label, err))?;
439
440 for &axis in remaining_axes {
441 current = reduce_axis(¤t.view(), axis, map_fn, reduce_fn, init)
442 .map_err(|err| crate::Error::backend_failure(label, err))?;
443 }
444
445 TypedTensor::from_vec_col_major(output_shape, current.into_data())
446}
447
448fn view_to_dyn_contiguous<T, R>(input: &TypedTensorView<'_, T, R>) -> crate::Result<TypedTensor<T>>
449where
450 T: Clone + 'static,
451 R: TensorRank,
452{
453 let compact = input.to_contiguous()?;
454 let (shape, data) = compact.into_vec_col_major()?;
455 TypedTensor::from_vec_col_major(shape, data)
456}
457
458fn view_to_contiguous_tensor(input: TensorView<'_>) -> crate::Result<Tensor> {
459 match input {
460 TensorView::F32(t) => Ok(Tensor::F32(view_to_dyn_contiguous(&t)?)),
461 TensorView::F64(t) => Ok(Tensor::F64(view_to_dyn_contiguous(&t)?)),
462 TensorView::I32(t) => Ok(Tensor::I32(view_to_dyn_contiguous(&t)?)),
463 TensorView::I64(t) => Ok(Tensor::I64(view_to_dyn_contiguous(&t)?)),
464 TensorView::Bool(t) => Ok(Tensor::Bool(view_to_dyn_contiguous(&t)?)),
465 TensorView::C32(t) => Ok(Tensor::C32(view_to_dyn_contiguous(&t)?)),
466 TensorView::C64(t) => Ok(Tensor::C64(view_to_dyn_contiguous(&t)?)),
467 }
468}
469
470pub fn typed_reduce_sum<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
471where
472 T: Copy + Clone + Send + Sync + Zero + Add<Output = T>,
473{
474 typed_reduce(input, axes, |x| x, |a, b| a + b, T::zero(), "reduce_sum")
475}
476
477pub fn typed_reduce_prod<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
478where
479 T: Copy + Clone + Send + Sync + One + Mul<Output = T>,
480{
481 typed_reduce(input, axes, |x| x, |a, b| a * b, T::one(), "reduce_prod")
482}
483
484pub fn typed_reduce_max<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
485where
486 T: Float + Send + Sync,
487{
488 validate_reduced_axes_nonempty("reduce_max", input.shape(), axes)?;
489 typed_reduce(
490 input,
491 axes,
492 |x| x,
493 nan_propagating_max,
494 T::neg_infinity(),
495 "reduce_max",
496 )
497}
498
499pub fn typed_reduce_min<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
500where
501 T: Float + Send + Sync,
502{
503 validate_reduced_axes_nonempty("reduce_min", input.shape(), axes)?;
504 typed_reduce(
505 input,
506 axes,
507 |x| x,
508 nan_propagating_min,
509 T::infinity(),
510 "reduce_min",
511 )
512}