1use std::ops::{Add, Mul};
2
3use num_traits::{Float, One, Zero};
4use strided_kernel::{col_major_strides, reduce_axis, StridedArray};
5
6use crate::types::{Tensor, TypedTensor};
7
8fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
9 crate::Error::BackendFailure {
10 op,
11 message: err.to_string(),
12 }
13}
14
15fn validate_axes(op: &'static str, axes: &[usize], rank: usize) -> crate::Result<()> {
16 let mut seen = vec![false; rank];
17 for &axis in axes {
18 if axis >= rank {
19 return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
20 }
21 if seen[axis] {
22 return Err(crate::Error::DuplicateAxis {
23 op,
24 axis,
25 role: "axes",
26 });
27 }
28 seen[axis] = true;
29 }
30 Ok(())
31}
32
33pub fn reduce_sum(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
34 match input {
35 Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_sum(t, axes)?)),
36 Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_sum(t, axes)?)),
37 Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_sum(t, axes)?)),
38 Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_sum(t, axes)?)),
39 }
40}
41
42pub fn reduce_prod(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
43 match input {
44 Tensor::F32(t) => Ok(Tensor::F32(typed_reduce_prod(t, axes)?)),
45 Tensor::F64(t) => Ok(Tensor::F64(typed_reduce_prod(t, axes)?)),
46 Tensor::C32(t) => Ok(Tensor::C32(typed_reduce_prod(t, axes)?)),
47 Tensor::C64(t) => Ok(Tensor::C64(typed_reduce_prod(t, axes)?)),
48 }
49}
50
51pub fn reduce_max(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
52 if axes.is_empty() {
53 return Ok(input.clone());
54 }
55
56 match input {
57 Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_max(tensor, axes)?)),
58 Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_max(tensor, axes)?)),
59 Tensor::C32(_) | Tensor::C64(_) => Err(crate::Error::BackendFailure {
60 op: "reduce_max",
61 message: format!("unsupported dtype {:?}", input.dtype()),
62 }),
63 }
64}
65
66pub fn reduce_min(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
67 if axes.is_empty() {
68 return Ok(input.clone());
69 }
70
71 match input {
72 Tensor::F32(tensor) => Ok(Tensor::F32(typed_reduce_min(tensor, axes)?)),
73 Tensor::F64(tensor) => Ok(Tensor::F64(typed_reduce_min(tensor, axes)?)),
74 Tensor::C32(_) | Tensor::C64(_) => Err(crate::Error::BackendFailure {
75 op: "reduce_min",
76 message: format!("unsupported dtype {:?}", input.dtype()),
77 }),
78 }
79}
80
81fn typed_reduce<T, M, R>(
82 input: &TypedTensor<T>,
83 axes: &[usize],
84 map_fn: M,
85 reduce_fn: R,
86 init: T,
87 label: &'static str,
88) -> crate::Result<TypedTensor<T>>
89where
90 T: Copy + Clone,
91 M: Fn(T) -> T + Copy,
92 R: Fn(T, T) -> T + Copy,
93{
94 validate_axes(label, axes, input.shape.len())?;
95 if axes.is_empty() {
96 return Ok(input.clone());
97 }
98
99 let output_shape: Vec<usize> = input
100 .shape
101 .iter()
102 .enumerate()
103 .filter(|(axis, _)| !axes.contains(axis))
104 .map(|(_, &dim)| dim)
105 .collect();
106
107 let strides = col_major_strides(&input.shape);
108 let mut current =
109 StridedArray::from_parts(input.host_data().to_vec(), &input.shape, &strides, 0)
110 .map_err(|err| backend_failure(label, err))?;
111
112 let mut sorted_axes = axes.to_vec();
113 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
114 for axis in sorted_axes {
115 current = reduce_axis(¤t.view(), axis, map_fn, reduce_fn, init)
116 .map_err(|err| backend_failure(label, err))?;
117 }
118
119 Ok(TypedTensor::from_vec(output_shape, current.into_data()))
120}
121
122pub fn typed_reduce_sum<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
123where
124 T: Copy + Clone + Zero + Add<Output = T>,
125{
126 typed_reduce(input, axes, |x| x, |a, b| a + b, T::zero(), "reduce_sum")
127}
128
129pub fn typed_reduce_prod<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
130where
131 T: Copy + Clone + One + Mul<Output = T>,
132{
133 typed_reduce(input, axes, |x| x, |a, b| a * b, T::one(), "reduce_prod")
134}
135
136pub fn typed_reduce_max<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
137where
138 T: Float,
139{
140 typed_reduce(
141 input,
142 axes,
143 |x| x,
144 |a, b| a.max(b),
145 T::neg_infinity(),
146 "reduce_max",
147 )
148}
149
150pub fn typed_reduce_min<T>(input: &TypedTensor<T>, axes: &[usize]) -> crate::Result<TypedTensor<T>>
151where
152 T: Float,
153{
154 typed_reduce(
155 input,
156 axes,
157 |x| x,
158 |a, b| a.min(b),
159 T::infinity(),
160 "reduce_min",
161 )
162}