Skip to main content

tenferro_ops/
broadcast.rs

1use thiserror::Error;
2
/// Lowering plan for broadcasting one input to an output shape via `BroadcastInDim`.
///
/// As produced by [`broadcast_input_plan`], `source_shape` and `dims` have the
/// same length: entry `i` of `dims` is the output axis occupied by entry `i` of
/// `source_shape`. Entries follow increasing input-axis order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BroadcastInputPlan {
    /// Shape to use before `BroadcastInDim`: the input shape with expanding
    /// singleton axes (size 1 in the input, a different size in the output)
    /// removed.
    pub source_shape: Vec<usize>,
    /// For each axis of `source_shape`, the output-axis position it maps to.
    pub dims: Vec<usize>,
}
11
12/// Error returned when NumPy-style broadcast planning fails.
13#[derive(Debug, Clone, PartialEq, Eq, Error)]
14pub enum BroadcastError {
15    /// Two shapes cannot be broadcast together.
16    #[error("cannot broadcast shapes {lhs:?} and {rhs:?}")]
17    IncompatibleBinary { lhs: Vec<usize>, rhs: Vec<usize> },
18    /// One input cannot be broadcast to the requested output shape.
19    #[error("cannot broadcast shape {input:?} to {output:?}")]
20    IncompatibleInput {
21        input: Vec<usize>,
22        output: Vec<usize>,
23    },
24    /// A higher-rank input cannot broadcast to a lower-rank output.
25    #[error("cannot broadcast higher-rank shape {input:?} to {output:?}")]
26    RankTooLarge {
27        input: Vec<usize>,
28        output: Vec<usize>,
29    },
30}
31
32/// Compute the NumPy-style broadcast shape for two concrete shapes.
33///
34/// # Examples
35///
36/// ```
37/// use tenferro_ops::broadcast::broadcast_shape;
38///
39/// assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap(), vec![3, 4]);
40/// assert_eq!(broadcast_shape(&[], &[2, 3]).unwrap(), vec![2, 3]);
41/// ```
42pub fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Result<Vec<usize>, BroadcastError> {
43    let rank = lhs.len().max(rhs.len());
44    let mut out = Vec::with_capacity(rank);
45    for axis in 0..rank {
46        let lhs_dim = aligned_dim(lhs, rank, axis);
47        let rhs_dim = aligned_dim(rhs, rank, axis);
48        if lhs_dim == rhs_dim {
49            out.push(lhs_dim);
50        } else if lhs_dim == 1 {
51            out.push(rhs_dim);
52        } else if rhs_dim == 1 {
53            out.push(lhs_dim);
54        } else {
55            return Err(BroadcastError::IncompatibleBinary {
56                lhs: lhs.to_vec(),
57                rhs: rhs.to_vec(),
58            });
59        }
60    }
61    Ok(out)
62}
63
64/// Compute the common NumPy-style broadcast shape for zero or more shapes.
65///
66/// # Examples
67///
68/// ```
69/// use tenferro_ops::broadcast::broadcast_shapes;
70///
71/// let shape = broadcast_shapes([&[3, 1][..], &[1, 4][..], &[3, 4][..]]).unwrap();
72/// assert_eq!(shape, vec![3, 4]);
73/// ```
74pub fn broadcast_shapes<'a>(
75    shapes: impl IntoIterator<Item = &'a [usize]>,
76) -> Result<Vec<usize>, BroadcastError> {
77    let mut iter = shapes.into_iter();
78    let Some(first) = iter.next() else {
79        return Ok(Vec::new());
80    };
81    let mut out = first.to_vec();
82    for shape in iter {
83        out = broadcast_shape(&out, shape)?;
84    }
85    Ok(out)
86}
87
88/// Plan how one input should lower to `BroadcastInDim`.
89///
90/// Expanding singleton axes are omitted from `source_shape` so downstream VJP
91/// rules reduce those axes explicitly.
92///
93/// # Examples
94///
95/// ```
96/// use tenferro_ops::broadcast::broadcast_input_plan;
97///
98/// let plan = broadcast_input_plan(&[3, 1], &[3, 4]).unwrap();
99/// assert_eq!(plan.source_shape, vec![3]);
100/// assert_eq!(plan.dims, vec![0]);
101/// ```
102pub fn broadcast_input_plan(
103    input: &[usize],
104    output: &[usize],
105) -> Result<BroadcastInputPlan, BroadcastError> {
106    if input.len() > output.len() {
107        return Err(BroadcastError::RankTooLarge {
108            input: input.to_vec(),
109            output: output.to_vec(),
110        });
111    }
112    let rank_diff = output.len() - input.len();
113    let mut source_shape = Vec::with_capacity(input.len());
114    let mut dims = Vec::with_capacity(input.len());
115    for (src_axis, &src_dim) in input.iter().enumerate() {
116        let dst_axis = src_axis + rank_diff;
117        let dst_dim = output[dst_axis];
118        if src_dim != dst_dim && src_dim != 1 {
119            return Err(BroadcastError::IncompatibleInput {
120                input: input.to_vec(),
121                output: output.to_vec(),
122            });
123        }
124        if src_dim == 1 && dst_dim != 1 {
125            continue;
126        }
127        source_shape.push(src_dim);
128        dims.push(dst_axis);
129    }
130    Ok(BroadcastInputPlan { source_shape, dims })
131}
132
/// Size of `shape` along `output_axis` once `shape` is right-aligned against an
/// output of rank `output_rank`.
///
/// Output axes to the left of `shape` read as size 1. Callers are expected to
/// pass `shape.len() <= output_rank`, because the leading-pad subtraction is
/// unchecked. They must also pass `output_axis < output_rank`; otherwise the
/// index into `shape` goes out of bounds.
fn aligned_dim(shape: &[usize], output_rank: usize, output_axis: usize) -> usize {
    let leading = output_rank - shape.len();
    match output_axis.checked_sub(leading) {
        Some(idx) => shape[idx],
        None => 1,
    }
}