tenferro_ops/
reduction.rs1use thiserror::Error;
2
3#[derive(Debug, Clone, PartialEq, Eq, Error)]
5pub enum ReductionShapeError {
6 #[error("axis {axis} is out of bounds for rank {rank}")]
8 AxisOutOfBounds { axis: usize, rank: usize },
9 #[error("duplicate axis {axis}")]
11 DuplicateAxis { axis: usize },
12}
13
14pub fn reduced_shape(
25 input_shape: &[usize],
26 axes: &[usize],
27 keepdims: bool,
28) -> Result<Vec<usize>, ReductionShapeError> {
29 let mut reduced = vec![false; input_shape.len()];
30 for &axis in axes {
31 if axis >= input_shape.len() {
32 return Err(ReductionShapeError::AxisOutOfBounds {
33 axis,
34 rank: input_shape.len(),
35 });
36 }
37 if reduced[axis] {
38 return Err(ReductionShapeError::DuplicateAxis { axis });
39 }
40 reduced[axis] = true;
41 }
42 let mut out = Vec::with_capacity(input_shape.len());
43 for (axis, &dim) in input_shape.iter().enumerate() {
44 if reduced[axis] {
45 if keepdims {
46 out.push(1);
47 }
48 } else {
49 out.push(dim);
50 }
51 }
52 Ok(out)
53}