// tenferro_ops/axis.rs

use thiserror::Error;

3/// Error returned when normalizing user-facing axis arguments.
4#[derive(Debug, Clone, PartialEq, Eq, Error)]
5pub enum AxisError {
6    /// Axis is outside `[-rank, rank)`.
7    #[error("axis {axis} is out of bounds for rank {rank}")]
8    OutOfBounds { axis: isize, rank: usize },
9    /// Axis appears more than once after negative-axis normalization.
10    #[error("duplicate axis {axis}")]
11    Duplicate { axis: usize },
12}
13
14/// Normalize a possibly-negative axis against `rank`.
15///
16/// # Examples
17///
18/// ```
19/// use tenferro_ops::axis::normalize_axis;
20///
21/// assert_eq!(normalize_axis(-1, 3).unwrap(), 2);
22/// assert!(normalize_axis(3, 3).is_err());
23/// ```
24pub fn normalize_axis(axis: isize, rank: usize) -> Result<usize, AxisError> {
25    let rank_i = rank as isize;
26    let normalized = if axis < 0 { rank_i + axis } else { axis };
27    if normalized < 0 || normalized >= rank_i {
28        return Err(AxisError::OutOfBounds { axis, rank });
29    }
30    Ok(normalized as usize)
31}
32
33/// Normalize a list of possibly-negative axes and reject duplicates.
34///
35/// # Examples
36///
37/// ```
38/// use tenferro_ops::axis::normalize_axes;
39///
40/// assert_eq!(normalize_axes(&[0, -1], 3).unwrap(), vec![0, 2]);
41/// assert!(normalize_axes(&[1, -2], 3).is_err());
42/// ```
43pub fn normalize_axes(axes: &[isize], rank: usize) -> Result<Vec<usize>, AxisError> {
44    let mut out = Vec::with_capacity(axes.len());
45    let mut seen = vec![false; rank];
46    for &axis in axes {
47        let normalized = normalize_axis(axis, rank)?;
48        if seen[normalized] {
49            return Err(AxisError::Duplicate { axis: normalized });
50        }
51        seen[normalized] = true;
52        out.push(normalized);
53    }
54    Ok(out)
55}