Skip to main content

tenferro_cpu/
reduction.rs

1use std::ops::{Add, Mul};
2
3use num_traits::{Float, One, Zero};
4use strided_kernel::reduce_axis;
5
6use super::{typed_host_data, typed_view, typed_view_from_view};
7use tenferro_tensor::{Tensor, TensorRank, TensorRead, TensorView, TypedTensor, TypedTensorView};
8
9fn validate_axes(op: &'static str, axes: &[usize], rank: usize) -> crate::Result<()> {
10    let mut seen = vec![false; rank];
11    for &axis in axes {
12        if axis >= rank {
13            return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
14        }
15        if seen[axis] {
16            return Err(crate::Error::DuplicateAxis {
17                op,
18                axis,
19                role: "axes",
20            });
21        }
22        seen[axis] = true;
23    }
24    Ok(())
25}
26
27fn ensure_host_tensor(op: &'static str, input: &Tensor) -> crate::Result<()> {
28    macro_rules! ensure {
29        ($tensor:expr) => {{
30            typed_host_data(op, $tensor)?;
31            Ok(())
32        }};
33    }
34
35    match input {
36        Tensor::F32(t) => ensure!(t),
37        Tensor::F64(t) => ensure!(t),
38        Tensor::I32(t) => ensure!(t),
39        Tensor::I64(t) => ensure!(t),
40        Tensor::Bool(t) => ensure!(t),
41        Tensor::C32(t) => ensure!(t),
42        Tensor::C64(t) => ensure!(t),
43    }
44}
45
46fn validate_reduced_axes_nonempty(
47    op: &'static str,
48    shape: &[usize],
49    axes: &[usize],
50) -> crate::Result<()> {
51    validate_axes(op, axes, shape.len())?;
52    for &axis in axes {
53        if shape[axis] == 0 {
54            return Err(crate::Error::InvalidConfig {
55                op,
56                message: format!("cannot reduce over zero-length axis {axis}"),
57            });
58        }
59    }
60    Ok(())
61}
62
63fn nan_propagating_max<T: Float>(a: T, b: T) -> T {
64    if a.is_nan() || b.is_nan() {
65        T::nan()
66    } else {
67        a.max(b)
68    }
69}
70
71fn nan_propagating_min<T: Float>(a: T, b: T) -> T {
72    if a.is_nan() || b.is_nan() {
73        T::nan()
74    } else {
75        a.min(b)
76    }
77}
78
79pub fn reduce_sum(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
80    match input {
81        Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_sum(t, axes)?)),
82        Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_sum(t, axes)?)),
83        Tensor::I32(t) => Ok(Tensor::I32(typed_reduce_sum(t, axes)?)),
84        Tensor::I64(t) => Ok(Tensor::I64(typed_reduce_sum(t, axes)?)),
85        Tensor::Bool(_) => Err(crate::Error::backend_failure(
86            "reduce_sum",
87            "unsupported dtype Bool",
88        )),
89        Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_sum(t, axes)?)),
90        Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_sum(t, axes)?)),
91    }
92}
93
94pub(crate) fn reduce_sum_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
95    match input {
96        TensorRead::Tensor(input) => {
97            ensure_host_tensor("reduce_sum", input)?;
98            reduce_sum(input, axes)
99        }
100        TensorRead::View(TensorView::F32(t)) => Ok(Tensor::F32(typed_reduce_view(
101            &t,
102            axes,
103            |x| x,
104            |a, b| a + b,
105            f32::zero(),
106            "reduce_sum",
107        )?)),
108        TensorRead::View(TensorView::F64(t)) => Ok(Tensor::F64(typed_reduce_view(
109            &t,
110            axes,
111            |x| x,
112            |a, b| a + b,
113            f64::zero(),
114            "reduce_sum",
115        )?)),
116        TensorRead::View(TensorView::I32(t)) => Ok(Tensor::I32(typed_reduce_view(
117            &t,
118            axes,
119            |x| x,
120            |a, b| a + b,
121            i32::zero(),
122            "reduce_sum",
123        )?)),
124        TensorRead::View(TensorView::I64(t)) => Ok(Tensor::I64(typed_reduce_view(
125            &t,
126            axes,
127            |x| x,
128            |a, b| a + b,
129            i64::zero(),
130            "reduce_sum",
131        )?)),
132        TensorRead::View(TensorView::Bool(_)) => Err(crate::Error::backend_failure(
133            "reduce_sum",
134            "unsupported dtype Bool",
135        )),
136        TensorRead::View(TensorView::C32(t)) => Ok(Tensor::C32(typed_reduce_view(
137            &t,
138            axes,
139            |x| x,
140            |a, b| a + b,
141            num_complex::Complex32::zero(),
142            "reduce_sum",
143        )?)),
144        TensorRead::View(TensorView::C64(t)) => Ok(Tensor::C64(typed_reduce_view(
145            &t,
146            axes,
147            |x| x,
148            |a, b| a + b,
149            num_complex::Complex64::zero(),
150            "reduce_sum",
151        )?)),
152    }
153}
154
155pub fn reduce_prod(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
156    match input {
157        Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_prod(t, axes)?)),
158        Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_prod(t, axes)?)),
159        Tensor::I32(t) => Ok(Tensor::I32(typed_reduce_prod(t, axes)?)),
160        Tensor::I64(t) => Ok(Tensor::I64(typed_reduce_prod(t, axes)?)),
161        Tensor::Bool(_) => Err(crate::Error::backend_failure(
162            "reduce_prod",
163            "unsupported dtype Bool",
164        )),
165        Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_prod(t, axes)?)),
166        Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_prod(t, axes)?)),
167    }
168}
169
170pub(crate) fn reduce_prod_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
171    match input {
172        TensorRead::Tensor(input) => {
173            ensure_host_tensor("reduce_prod", input)?;
174            reduce_prod(input, axes)
175        }
176        TensorRead::View(TensorView::F32(t)) => Ok(Tensor::F32(typed_reduce_view(
177            &t,
178            axes,
179            |x| x,
180            |a, b| a * b,
181            f32::one(),
182            "reduce_prod",
183        )?)),
184        TensorRead::View(TensorView::F64(t)) => Ok(Tensor::F64(typed_reduce_view(
185            &t,
186            axes,
187            |x| x,
188            |a, b| a * b,
189            f64::one(),
190            "reduce_prod",
191        )?)),
192        TensorRead::View(TensorView::I32(t)) => Ok(Tensor::I32(typed_reduce_view(
193            &t,
194            axes,
195            |x| x,
196            |a, b| a * b,
197            i32::one(),
198            "reduce_prod",
199        )?)),
200        TensorRead::View(TensorView::I64(t)) => Ok(Tensor::I64(typed_reduce_view(
201            &t,
202            axes,
203            |x| x,
204            |a, b| a * b,
205            i64::one(),
206            "reduce_prod",
207        )?)),
208        TensorRead::View(TensorView::Bool(_)) => Err(crate::Error::backend_failure(
209            "reduce_prod",
210            "unsupported dtype Bool",
211        )),
212        TensorRead::View(TensorView::C32(t)) => Ok(Tensor::C32(typed_reduce_view(
213            &t,
214            axes,
215            |x| x,
216            |a, b| a * b,
217            num_complex::Complex32::one(),
218            "reduce_prod",
219        )?)),
220        TensorRead::View(TensorView::C64(t)) => Ok(Tensor::C64(typed_reduce_view(
221            &t,
222            axes,
223            |x| x,
224            |a, b| a * b,
225            num_complex::Complex64::one(),
226            "reduce_prod",
227        )?)),
228    }
229}
230
231pub fn reduce_max(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
232    validate_axes("reduce_max", axes, input.shape().len())?;
233    if axes.is_empty() {
234        return Ok(input.clone());
235    }
236
237    match input {
238        Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_max(tensor, axes)?)),
239        Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_max(tensor, axes)?)),
240        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) | Tensor::C32(_) | Tensor::C64(_) => {
241            Err(crate::Error::backend_failure(
242                "reduce_max",
243                format!("unsupported dtype {:?}", input.dtype()),
244            ))
245        }
246    }
247}
248
249pub(crate) fn reduce_max_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
250    validate_axes("reduce_max", axes, input.shape().len())?;
251    if axes.is_empty() {
252        return match input {
253            TensorRead::Tensor(input) => {
254                ensure_host_tensor("reduce_max", input)?;
255                Ok(input.clone())
256            }
257            TensorRead::View(input) => view_to_contiguous_tensor(input),
258        };
259    }
260
261    match input {
262        TensorRead::Tensor(input) => {
263            ensure_host_tensor("reduce_max", input)?;
264            reduce_max(input, axes)
265        }
266        TensorRead::View(TensorView::F32(t)) => {
267            validate_reduced_axes_nonempty("reduce_max", t.shape(), axes)?;
268            Ok(Tensor::F32(typed_reduce_view(
269                &t,
270                axes,
271                |x| x,
272                nan_propagating_max,
273                f32::neg_infinity(),
274                "reduce_max",
275            )?))
276        }
277        TensorRead::View(TensorView::F64(t)) => {
278            validate_reduced_axes_nonempty("reduce_max", t.shape(), axes)?;
279            Ok(Tensor::F64(typed_reduce_view(
280                &t,
281                axes,
282                |x| x,
283                nan_propagating_max,
284                f64::neg_infinity(),
285                "reduce_max",
286            )?))
287        }
288        view => Err(crate::Error::backend_failure(
289            "reduce_max",
290            format!("unsupported dtype {:?}", view.dtype()),
291        )),
292    }
293}
294
295pub fn reduce_min(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
296    validate_axes("reduce_min", axes, input.shape().len())?;
297    if axes.is_empty() {
298        return Ok(input.clone());
299    }
300
301    match input {
302        Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_min(tensor, axes)?)),
303        Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_min(tensor, axes)?)),
304        Tensor::I32(_) | Tensor::I64(_) | Tensor::Bool(_) | Tensor::C32(_) | Tensor::C64(_) => {
305            Err(crate::Error::backend_failure(
306                "reduce_min",
307                format!("unsupported dtype {:?}", input.dtype()),
308            ))
309        }
310    }
311}
312
313pub(crate) fn reduce_min_read(input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
314    validate_axes("reduce_min", axes, input.shape().len())?;
315    if axes.is_empty() {
316        return match input {
317            TensorRead::Tensor(input) => {
318                ensure_host_tensor("reduce_min", input)?;
319                Ok(input.clone())
320            }
321            TensorRead::View(input) => view_to_contiguous_tensor(input),
322        };
323    }
324
325    match input {
326        TensorRead::Tensor(input) => {
327            ensure_host_tensor("reduce_min", input)?;
328            reduce_min(input, axes)
329        }
330        TensorRead::View(TensorView::F32(t)) => {
331            validate_reduced_axes_nonempty("reduce_min", t.shape(), axes)?;
332            Ok(Tensor::F32(typed_reduce_view(
333                &t,
334                axes,
335                |x| x,
336                nan_propagating_min,
337                f32::infinity(),
338                "reduce_min",
339            )?))
340        }
341        TensorRead::View(TensorView::F64(t)) => {
342            validate_reduced_axes_nonempty("reduce_min", t.shape(), axes)?;
343            Ok(Tensor::F64(typed_reduce_view(
344                &t,
345                axes,
346                |x| x,
347                nan_propagating_min,
348                f64::infinity(),
349                "reduce_min",
350            )?))
351        }
352        view => Err(crate::Error::backend_failure(
353            "reduce_min",
354            format!("unsupported dtype {:?}", view.dtype()),
355        )),
356    }
357}
358
359fn typed_reduce<T, M, R>(
360    input: &TypedTensor<T>,
361    axes: &[usize],
362    map_fn: M,
363    reduce_fn: R,
364    init: T,
365    label: &'static str,
366) -> crate::Result<TypedTensor<T>>
367where
368    T: Copy + Clone + Send + Sync,
369    M: Fn(T) -> T + Copy + Sync,
370    R: Fn(T, T) -> T + Copy + Sync,
371{
372    validate_axes(label, axes, input.shape().len())?;
373    if axes.is_empty() {
374        return Ok(input.clone());
375    }
376
377    let output_shape: Vec<usize> = input
378        .shape()
379        .iter()
380        .enumerate()
381        .filter(|(axis, _)| !axes.contains(axis))
382        .map(|(_, &dim)| dim)
383        .collect();
384
385    let mut sorted_axes = axes.to_vec();
386    sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
387    let Some((&first_axis, remaining_axes)) = sorted_axes.split_first() else {
388        return Ok(input.clone());
389    };
390
391    let input_view = typed_view(label, input)?;
392    let mut current = reduce_axis(&input_view, first_axis, map_fn, reduce_fn, init)
393        .map_err(|err| crate::Error::backend_failure(label, err))?;
394
395    for &axis in remaining_axes {
396        current = reduce_axis(&current.view(), axis, map_fn, reduce_fn, init)
397            .map_err(|err| crate::Error::backend_failure(label, err))?;
398    }
399
400    TypedTensor::from_vec_col_major(output_shape, current.into_data())
401}
402
403pub(crate) fn typed_reduce_view<T, M, R, TR>(
404    input: &TypedTensorView<'_, T, TR>,
405    axes: &[usize],
406    map_fn: M,
407    reduce_fn: R,
408    init: T,
409    label: &'static str,
410) -> crate::Result<TypedTensor<T>>
411where
412    T: Copy + Clone + Send + Sync + 'static,
413    M: Fn(T) -> T + Copy + Sync,
414    R: Fn(T, T) -> T + Copy + Sync,
415    TR: TensorRank,
416{
417    validate_axes(label, axes, input.shape().len())?;
418    if axes.is_empty() {
419        return view_to_dyn_contiguous(input);
420    }
421
422    let output_shape: Vec<usize> = input
423        .shape()
424        .iter()
425        .enumerate()
426        .filter(|(axis, _)| !axes.contains(axis))
427        .map(|(_, &dim)| dim)
428        .collect();
429
430    let mut sorted_axes = axes.to_vec();
431    sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
432    let Some((&first_axis, remaining_axes)) = sorted_axes.split_first() else {
433        return view_to_dyn_contiguous(input);
434    };
435
436    let input_view = typed_view_from_view(label, input)?;
437    let mut current = reduce_axis(&input_view, first_axis, map_fn, reduce_fn, init)
438        .map_err(|err| crate::Error::backend_failure(label, err))?;
439
440    for &axis in remaining_axes {
441        current = reduce_axis(&current.view(), axis, map_fn, reduce_fn, init)
442            .map_err(|err| crate::Error::backend_failure(label, err))?;
443    }
444
445    TypedTensor::from_vec_col_major(output_shape, current.into_data())
446}
447
448fn view_to_dyn_contiguous<T, R>(input: &TypedTensorView<'_, T, R>) -> crate::Result<TypedTensor<T>>
449where
450    T: Clone + 'static,
451    R: TensorRank,
452{
453    let compact = input.to_contiguous()?;
454    let (shape, data) = compact.into_vec_col_major()?;
455    TypedTensor::from_vec_col_major(shape, data)
456}
457
458fn view_to_contiguous_tensor(input: TensorView<'_>) -> crate::Result<Tensor> {
459    match input {
460        TensorView::F32(t) => Ok(Tensor::F32(view_to_dyn_contiguous(&t)?)),
461        TensorView::F64(t) => Ok(Tensor::F64(view_to_dyn_contiguous(&t)?)),
462        TensorView::I32(t) => Ok(Tensor::I32(view_to_dyn_contiguous(&t)?)),
463        TensorView::I64(t) => Ok(Tensor::I64(view_to_dyn_contiguous(&t)?)),
464        TensorView::Bool(t) => Ok(Tensor::Bool(view_to_dyn_contiguous(&t)?)),
465        TensorView::C32(t) => Ok(Tensor::C32(view_to_dyn_contiguous(&t)?)),
466        TensorView::C64(t) => Ok(Tensor::C64(view_to_dyn_contiguous(&t)?)),
467    }
468}
469
470pub fn typed_reduce_sum<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
471where
472    T: Copy + Clone + Send + Sync + Zero + Add<Output = T>,
473{
474    typed_reduce(input, axes, |x| x, |a, b| a + b, T::zero(), "reduce_sum")
475}
476
477pub fn typed_reduce_prod<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
478where
479    T: Copy + Clone + Send + Sync + One + Mul<Output = T>,
480{
481    typed_reduce(input, axes, |x| x, |a, b| a * b, T::one(), "reduce_prod")
482}
483
484pub fn typed_reduce_max<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
485where
486    T: Float + Send + Sync,
487{
488    validate_reduced_axes_nonempty("reduce_max", input.shape(), axes)?;
489    typed_reduce(
490        input,
491        axes,
492        |x| x,
493        nan_propagating_max,
494        T::neg_infinity(),
495        "reduce_max",
496    )
497}
498
499pub fn typed_reduce_min<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
500where
501    T: Float + Send + Sync,
502{
503    validate_reduced_axes_nonempty("reduce_min", input.shape(), axes)?;
504    typed_reduce(
505        input,
506        axes,
507        |x| x,
508        nan_propagating_min,
509        T::infinity(),
510        "reduce_min",
511    )
512}