Skip to main content

tenferro_tensor/cpu/
reduction.rs

1use std::ops::{Add, Mul};
2
3use num_traits::{Float, One, Zero};
4use strided_kernel::{col_major_strides, reduce_axis, StridedArray};
5
6use crate::types::{Tensor, TypedTensor};
7
8fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
9    crate::Error::BackendFailure {
10        op,
11        message: err.to_string(),
12    }
13}
14
15fn validate_axes(op: &'static str, axes: &[usize], rank: usize) -> crate::Result<()> {
16    let mut seen = vec![false; rank];
17    for &axis in axes {
18        if axis >= rank {
19            return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
20        }
21        if seen[axis] {
22            return Err(crate::Error::DuplicateAxis {
23                op,
24                axis,
25                role: "axes",
26            });
27        }
28        seen[axis] = true;
29    }
30    Ok(())
31}
32
33pub fn reduce_sum(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
34    match input {
35        Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_sum(t, axes)?)),
36        Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_sum(t, axes)?)),
37        Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_sum(t, axes)?)),
38        Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_sum(t, axes)?)),
39    }
40}
41
42pub fn reduce_prod(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
43    match input {
44        Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_prod(t, axes)?)),
45        Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_prod(t, axes)?)),
46        Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_prod(t, axes)?)),
47        Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_prod(t, axes)?)),
48    }
49}
50
51pub fn reduce_max(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
52    if axes.is_empty() {
53        return Ok(input.clone());
54    }
55
56    match input {
57        Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_max(tensor, axes)?)),
58        Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_max(tensor, axes)?)),
59        Tensor::C32(_) | Tensor::C64(_) => Err(crate::Error::BackendFailure {
60            op: "reduce_max",
61            message: format!("unsupported dtype {:?}", input.dtype()),
62        }),
63    }
64}
65
66pub fn reduce_min(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
67    if axes.is_empty() {
68        return Ok(input.clone());
69    }
70
71    match input {
72        Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_min(tensor, axes)?)),
73        Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_min(tensor, axes)?)),
74        Tensor::C32(_) | Tensor::C64(_) => Err(crate::Error::BackendFailure {
75            op: "reduce_min",
76            message: format!("unsupported dtype {:?}", input.dtype()),
77        }),
78    }
79}
80
81fn typed_reduce<T, M, R>(
82    input: &TypedTensor<T>,
83    axes: &[usize],
84    map_fn: M,
85    reduce_fn: R,
86    init: T,
87    label: &'static str,
88) -> crate::Result<TypedTensor<T>>
89where
90    T: Copy + Clone,
91    M: Fn(T) -> T + Copy,
92    R: Fn(T, T) -> T + Copy,
93{
94    validate_axes(label, axes, input.shape.len())?;
95    if axes.is_empty() {
96        return Ok(input.clone());
97    }
98
99    let output_shape: Vec<usize> = input
100        .shape
101        .iter()
102        .enumerate()
103        .filter(|(axis, _)| !axes.contains(axis))
104        .map(|(_, &dim)| dim)
105        .collect();
106
107    let strides = col_major_strides(&input.shape);
108    let mut current =
109        StridedArray::from_parts(input.host_data().to_vec(), &input.shape, &strides, 0)
110            .map_err(|err| backend_failure(label, err))?;
111
112    let mut sorted_axes = axes.to_vec();
113    sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
114    for axis in sorted_axes {
115        current = reduce_axis(&current.view(), axis, map_fn, reduce_fn, init)
116            .map_err(|err| backend_failure(label, err))?;
117    }
118
119    Ok(TypedTensor::from_vec(output_shape, current.into_data()))
120}
121
122pub fn typed_reduce_sum<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
123where
124    T: Copy + Clone + Zero + Add<Output = T>,
125{
126    typed_reduce(input, axes, |x| x, |a, b| a + b, T::zero(), "reduce_sum")
127}
128
129pub fn typed_reduce_prod<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
130where
131    T: Copy + Clone + One + Mul<Output = T>,
132{
133    typed_reduce(input, axes, |x| x, |a, b| a * b, T::one(), "reduce_prod")
134}
135
136pub fn typed_reduce_max<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
137where
138    T: Float,
139{
140    typed_reduce(
141        input,
142        axes,
143        |x| x,
144        |a, b| a.max(b),
145        T::neg_infinity(),
146        "reduce_max",
147    )
148}
149
150pub fn typed_reduce_min<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
151where
152    T: Float,
153{
154    typed_reduce(
155        input,
156        axes,
157        |x| x,
158        |a, b| a.min(b),
159        T::infinity(),
160        "reduce_min",
161    )
162}