tenferro_ops/
broadcast.rs1use thiserror::Error;
2
/// Plan for broadcasting one input shape to an output shape, as produced by
/// [`broadcast_input_plan`].
///
/// The input is right-aligned against the output. Input axes of size `1` whose aligned
/// output size is not `1` are expanded by the broadcast and are left out of the plan;
/// every other input axis is recorded together with the output axis it lands on.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BroadcastInputPlan {
    /// Sizes of the input axes that are kept (not expanded), in input order.
    pub source_shape: Vec<usize>,
    /// `dims[i]` is the output axis that `source_shape[i]` maps to; strictly increasing.
    pub dims: Vec<usize>,
}
11
12#[derive(Debug, Clone, PartialEq, Eq, Error)]
14pub enum BroadcastError {
15 #[error("cannot broadcast shapes {lhs:?} and {rhs:?}")]
17 IncompatibleBinary { lhs: Vec<usize>, rhs: Vec<usize> },
18 #[error("cannot broadcast shape {input:?} to {output:?}")]
20 IncompatibleInput {
21 input: Vec<usize>,
22 output: Vec<usize>,
23 },
24 #[error("cannot broadcast higher-rank shape {input:?} to {output:?}")]
26 RankTooLarge {
27 input: Vec<usize>,
28 output: Vec<usize>,
29 },
30}
31
32pub fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Result<Vec<usize>, BroadcastError> {
43 let rank = lhs.len().max(rhs.len());
44 let mut out = Vec::with_capacity(rank);
45 for axis in 0..rank {
46 let lhs_dim = aligned_dim(lhs, rank, axis);
47 let rhs_dim = aligned_dim(rhs, rank, axis);
48 if lhs_dim == rhs_dim {
49 out.push(lhs_dim);
50 } else if lhs_dim == 1 {
51 out.push(rhs_dim);
52 } else if rhs_dim == 1 {
53 out.push(lhs_dim);
54 } else {
55 return Err(BroadcastError::IncompatibleBinary {
56 lhs: lhs.to_vec(),
57 rhs: rhs.to_vec(),
58 });
59 }
60 }
61 Ok(out)
62}
63
64pub fn broadcast_shapes<'a>(
75 shapes: impl IntoIterator<Item = &'a [usize]>,
76) -> Result<Vec<usize>, BroadcastError> {
77 let mut iter = shapes.into_iter();
78 let Some(first) = iter.next() else {
79 return Ok(Vec::new());
80 };
81 let mut out = first.to_vec();
82 for shape in iter {
83 out = broadcast_shape(&out, shape)?;
84 }
85 Ok(out)
86}
87
88pub fn broadcast_input_plan(
103 input: &[usize],
104 output: &[usize],
105) -> Result<BroadcastInputPlan, BroadcastError> {
106 if input.len() > output.len() {
107 return Err(BroadcastError::RankTooLarge {
108 input: input.to_vec(),
109 output: output.to_vec(),
110 });
111 }
112 let rank_diff = output.len() - input.len();
113 let mut source_shape = Vec::with_capacity(input.len());
114 let mut dims = Vec::with_capacity(input.len());
115 for (src_axis, &src_dim) in input.iter().enumerate() {
116 let dst_axis = src_axis + rank_diff;
117 let dst_dim = output[dst_axis];
118 if src_dim != dst_dim && src_dim != 1 {
119 return Err(BroadcastError::IncompatibleInput {
120 input: input.to_vec(),
121 output: output.to_vec(),
122 });
123 }
124 if src_dim == 1 && dst_dim != 1 {
125 continue;
126 }
127 source_shape.push(src_dim);
128 dims.push(dst_axis);
129 }
130 Ok(BroadcastInputPlan { source_shape, dims })
131}
132
/// Returns the size of `shape` at `output_axis` after right-aligning `shape` inside a
/// shape of rank `output_rank`.
///
/// Axes in front of `shape` (the implicit padding) read as `1`. The caller must ensure
/// `shape.len() <= output_rank`, because the padding computation would otherwise underflow.
fn aligned_dim(shape: &[usize], output_rank: usize, output_axis: usize) -> usize {
    // Number of implicit leading size-1 axes prepended to `shape`.
    let padding = output_rank - shape.len();
    match output_axis.checked_sub(padding) {
        Some(index) => shape[index],
        None => 1,
    }
}