Skip to main content

tenferro_ops/
reduction.rs

1use thiserror::Error;
2
3/// Error returned when computing public reduction output shapes.
4#[derive(Debug, Clone, PartialEq, Eq, Error)]
5pub enum ReductionShapeError {
6    /// Axis is outside the input rank.
7    #[error("axis {axis} is out of bounds for rank {rank}")]
8    AxisOutOfBounds { axis: usize, rank: usize },
9    /// Axis appears more than once.
10    #[error("duplicate axis {axis}")]
11    DuplicateAxis { axis: usize },
12}
13
14/// Compute the output shape for a reduction.
15///
16/// # Examples
17///
18/// ```
19/// use tenferro_ops::reduction::reduced_shape;
20///
21/// assert_eq!(reduced_shape(&[2, 3, 4], &[1], false).unwrap(), vec![2, 4]);
22/// assert_eq!(reduced_shape(&[2, 3, 4], &[1], true).unwrap(), vec![2, 1, 4]);
23/// ```
24pub fn reduced_shape(
25    input_shape: &[usize],
26    axes: &[usize],
27    keepdims: bool,
28) -> Result<Vec<usize>, ReductionShapeError> {
29    let mut reduced = vec![false; input_shape.len()];
30    for &axis in axes {
31        if axis >= input_shape.len() {
32            return Err(ReductionShapeError::AxisOutOfBounds {
33                axis,
34                rank: input_shape.len(),
35            });
36        }
37        if reduced[axis] {
38            return Err(ReductionShapeError::DuplicateAxis { axis });
39        }
40        reduced[axis] = true;
41    }
42    let mut out = Vec::with_capacity(input_shape.len());
43    for (axis, &dim) in input_shape.iter().enumerate() {
44        if reduced[axis] {
45            if keepdims {
46                out.push(1);
47            }
48        } else {
49            out.push(dim);
50        }
51    }
52    Ok(out)
53}