tenferro_ops/dim_expr.rs
/// Symbolic arithmetic over tensor dimension sizes.
///
/// A `DimExpr` is resolved to a concrete `usize` at execution time, once the
/// real shapes of the op's inputs are known. The `InputDim { input_idx, axis }`
/// leaf stands for the length of axis `axis` of the op's `input_idx`-th input.
///
/// # Examples
///
/// ```rust
/// use tenferro_ops::dim_expr::DimExpr;
///
/// // rows * cols of the first input
/// let numel = DimExpr::mul(
///     DimExpr::InputDim { input_idx: 0, axis: 0 },
///     DimExpr::InputDim { input_idx: 0, axis: 1 },
/// );
/// assert_eq!(numel.eval(&[&[3, 4]]).unwrap(), 12);
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DimExpr {
    /// A size that is already known when the expression is built.
    Const(usize),
    /// Length of axis `axis` of the op's `input_idx`-th input tensor.
    InputDim { input_idx: usize, axis: usize },
    /// Sum of the two operands; evaluation fails instead of wrapping on overflow.
    Add(Box<Self>, Box<Self>),
    /// Left operand minus right operand; evaluation fails if the result would be negative.
    Sub(Box<Self>, Box<Self>),
    /// Product of the two operands; evaluation fails instead of wrapping on overflow.
    Mul(Box<Self>, Box<Self>),
    /// Left operand divided by right operand, rounded down; a zero divisor is an error.
    FloorDiv(Box<Self>, Box<Self>),
    /// The smaller of the two operands.
    Min(Box<Self>, Box<Self>),
    /// The larger of the two operands.
    Max(Box<Self>, Box<Self>),
}
43
44/// Error produced while evaluating a [`DimExpr`] against concrete shapes.
45#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
46pub enum DimExprEvalError {
47 /// `InputDim` referenced an input that was not provided.
48 #[error(
49 "DimExpr::InputDim input index {input_idx} out of bounds for {input_count} input shapes"
50 )]
51 InputOutOfBounds {
52 input_idx: usize,
53 input_count: usize,
54 },
55 /// `InputDim` referenced an axis that does not exist on the selected input.
56 #[error("DimExpr::InputDim axis {axis} out of bounds for input {input_idx} rank {rank}")]
57 AxisOutOfBounds {
58 input_idx: usize,
59 axis: usize,
60 rank: usize,
61 },
62 /// Addition overflowed `usize`.
63 #[error("DimExpr::Add overflow: {lhs} + {rhs}")]
64 AddOverflow { lhs: usize, rhs: usize },
65 /// Subtraction would underflow `usize`.
66 #[error("DimExpr::Sub underflow: left operand {lhs} is smaller than {rhs}")]
67 SubUnderflow { lhs: usize, rhs: usize },
68 /// Multiplication overflowed `usize`.
69 #[error("DimExpr::Mul overflow: {lhs} * {rhs}")]
70 MulOverflow { lhs: usize, rhs: usize },
71 /// Floor division divisor evaluated to zero.
72 #[error("DimExpr::FloorDiv divide by zero: left operand {lhs}, divisor {rhs}")]
73 FloorDivByZero { lhs: usize, rhs: usize },
74}
75
76impl DimExpr {
77 /// Evaluate the expression using actual input tensor shapes.
78 ///
79 /// # Examples
80 ///
81 /// ```rust
82 /// use tenferro_ops::dim_expr::DimExpr;
83 ///
84 /// let expr = DimExpr::add(
85 /// DimExpr::InputDim {
86 /// input_idx: 0,
87 /// axis: 0,
88 /// },
89 /// DimExpr::Const(2),
90 /// );
91 /// assert_eq!(expr.eval(&[&[5, 7]]).unwrap(), 7);
92 /// ```
93 pub fn eval(&self, input_shapes: &[&[usize]]) -> Result<usize, DimExprEvalError> {
94 match self {
95 Self::Const(v) => Ok(*v),
96 Self::InputDim { input_idx, axis } => input_shapes
97 .get(*input_idx)
98 .ok_or(DimExprEvalError::InputOutOfBounds {
99 input_idx: *input_idx,
100 input_count: input_shapes.len(),
101 })
102 .and_then(|shape| {
103 shape
104 .get(*axis)
105 .copied()
106 .ok_or(DimExprEvalError::AxisOutOfBounds {
107 input_idx: *input_idx,
108 axis: *axis,
109 rank: shape.len(),
110 })
111 }),
112 Self::Add(a, b) => {
113 let lhs = a.eval(input_shapes)?;
114 let rhs = b.eval(input_shapes)?;
115 lhs.checked_add(rhs)
116 .ok_or(DimExprEvalError::AddOverflow { lhs, rhs })
117 }
118 Self::Sub(a, b) => {
119 let lhs = a.eval(input_shapes)?;
120 let rhs = b.eval(input_shapes)?;
121 lhs.checked_sub(rhs)
122 .ok_or(DimExprEvalError::SubUnderflow { lhs, rhs })
123 }
124 Self::Mul(a, b) => {
125 let lhs = a.eval(input_shapes)?;
126 let rhs = b.eval(input_shapes)?;
127 lhs.checked_mul(rhs)
128 .ok_or(DimExprEvalError::MulOverflow { lhs, rhs })
129 }
130 Self::FloorDiv(a, b) => {
131 let lhs = a.eval(input_shapes)?;
132 let rhs = b.eval(input_shapes)?;
133 if rhs == 0 {
134 return Err(DimExprEvalError::FloorDivByZero { lhs, rhs });
135 }
136 Ok(lhs / rhs)
137 }
138 Self::Min(a, b) => Ok(a.eval(input_shapes)?.min(b.eval(input_shapes)?)),
139 Self::Max(a, b) => Ok(a.eval(input_shapes)?.max(b.eval(input_shapes)?)),
140 }
141 }
142
143 /// Return the maximum referenced `input_idx`, or `None` if the expression
144 /// contains only constants.
145 ///
146 /// # Examples
147 ///
148 /// ```rust
149 /// use tenferro_ops::dim_expr::DimExpr;
150 ///
151 /// let expr = DimExpr::add(
152 /// DimExpr::InputDim {
153 /// input_idx: 0,
154 /// axis: 0,
155 /// },
156 /// DimExpr::InputDim {
157 /// input_idx: 2,
158 /// axis: 1,
159 /// },
160 /// );
161 /// assert_eq!(expr.max_input_idx(), Some(2));
162 /// ```
163 pub fn max_input_idx(&self) -> Option<usize> {
164 match self {
165 Self::Const(_) => None,
166 Self::InputDim { input_idx, .. } => Some(*input_idx),
167 Self::Add(a, b)
168 | Self::Sub(a, b)
169 | Self::Mul(a, b)
170 | Self::FloorDiv(a, b)
171 | Self::Min(a, b)
172 | Self::Max(a, b) => match (a.max_input_idx(), b.max_input_idx()) {
173 (Some(x), Some(y)) => Some(x.max(y)),
174 (Some(x), None) | (None, Some(x)) => Some(x),
175 (None, None) => None,
176 },
177 }
178 }
179
180 /// Remap `InputDim { input_idx: from, .. }` to `InputDim { input_idx: to, .. }`.
181 ///
182 /// # Examples
183 ///
184 /// ```rust
185 /// use tenferro_ops::dim_expr::DimExpr;
186 ///
187 /// let expr = DimExpr::InputDim {
188 /// input_idx: 0,
189 /// axis: 1,
190 /// };
191 /// assert_eq!(expr.remap(0, 2), DimExpr::InputDim { input_idx: 2, axis: 1 });
192 /// ```
193 pub fn remap(&self, from: usize, to: usize) -> Self {
194 match self {
195 Self::Const(v) => Self::Const(*v),
196 Self::InputDim { input_idx, axis } => Self::InputDim {
197 input_idx: if *input_idx == from { to } else { *input_idx },
198 axis: *axis,
199 },
200 Self::Add(a, b) => Self::add(a.remap(from, to), b.remap(from, to)),
201 Self::Sub(a, b) => Self::sub(a.remap(from, to), b.remap(from, to)),
202 Self::Mul(a, b) => Self::mul(a.remap(from, to), b.remap(from, to)),
203 Self::FloorDiv(a, b) => Self::floor_div(a.remap(from, to), b.remap(from, to)),
204 Self::Min(a, b) => Self::min(a.remap(from, to), b.remap(from, to)),
205 Self::Max(a, b) => Self::max(a.remap(from, to), b.remap(from, to)),
206 }
207 }
208
209 /// Construct a constant dimension expression.
210 ///
211 /// # Examples
212 ///
213 /// ```rust
214 /// use tenferro_ops::dim_expr::DimExpr;
215 ///
216 /// assert_eq!(DimExpr::constant(4), DimExpr::Const(4));
217 /// ```
218 pub fn constant(v: usize) -> Self {
219 Self::Const(v)
220 }
221
222 /// Construct an addition node.
223 ///
224 /// # Examples
225 ///
226 /// ```rust
227 /// use tenferro_ops::dim_expr::DimExpr;
228 ///
229 /// let expr = DimExpr::add(DimExpr::Const(2), DimExpr::Const(3));
230 /// assert_eq!(expr.eval(&[]).unwrap(), 5);
231 /// ```
232 // Public constructor names mirror the DimExpr variants; operator traits are a separate API choice.
233 #[allow(clippy::should_implement_trait)]
234 pub fn add(a: Self, b: Self) -> Self {
235 Self::Add(Box::new(a), Box::new(b))
236 }
237
238 /// Construct a subtraction node.
239 ///
240 /// # Examples
241 ///
242 /// ```rust
243 /// use tenferro_ops::dim_expr::DimExpr;
244 ///
245 /// let expr = DimExpr::sub(DimExpr::Const(7), DimExpr::Const(2));
246 /// assert_eq!(expr.eval(&[]).unwrap(), 5);
247 /// ```
248 // Public constructor names mirror the DimExpr variants; operator traits are a separate API choice.
249 #[allow(clippy::should_implement_trait)]
250 pub fn sub(a: Self, b: Self) -> Self {
251 Self::Sub(Box::new(a), Box::new(b))
252 }
253
254 /// Construct a multiplication node.
255 ///
256 /// # Examples
257 ///
258 /// ```rust
259 /// use tenferro_ops::dim_expr::DimExpr;
260 ///
261 /// let expr = DimExpr::mul(DimExpr::Const(3), DimExpr::Const(4));
262 /// assert_eq!(expr.eval(&[]).unwrap(), 12);
263 /// ```
264 // Public constructor names mirror the DimExpr variants; operator traits are a separate API choice.
265 #[allow(clippy::should_implement_trait)]
266 pub fn mul(a: Self, b: Self) -> Self {
267 Self::Mul(Box::new(a), Box::new(b))
268 }
269
270 /// Construct a floor-division node.
271 ///
272 /// # Examples
273 ///
274 /// ```rust
275 /// use tenferro_ops::dim_expr::DimExpr;
276 ///
277 /// let expr = DimExpr::floor_div(DimExpr::Const(9), DimExpr::Const(2));
278 /// assert_eq!(expr.eval(&[]).unwrap(), 4);
279 /// ```
280 pub fn floor_div(a: Self, b: Self) -> Self {
281 Self::FloorDiv(Box::new(a), Box::new(b))
282 }
283
284 /// Construct a minimum node.
285 ///
286 /// # Examples
287 ///
288 /// ```rust
289 /// use tenferro_ops::dim_expr::DimExpr;
290 ///
291 /// let expr = DimExpr::min(DimExpr::Const(3), DimExpr::Const(5));
292 /// assert_eq!(expr.eval(&[]).unwrap(), 3);
293 /// ```
294 pub fn min(a: Self, b: Self) -> Self {
295 Self::Min(Box::new(a), Box::new(b))
296 }
297
298 /// Construct a maximum node.
299 ///
300 /// # Examples
301 ///
302 /// ```rust
303 /// use tenferro_ops::dim_expr::DimExpr;
304 ///
305 /// let expr = DimExpr::max(DimExpr::Const(3), DimExpr::Const(5));
306 /// assert_eq!(expr.eval(&[]).unwrap(), 5);
307 /// ```
308 pub fn max(a: Self, b: Self) -> Self {
309 Self::Max(Box::new(a), Box::new(b))
310 }
311
312 /// Return `true` when this expression is a constant.
313 ///
314 /// # Examples
315 ///
316 /// ```rust
317 /// use tenferro_ops::dim_expr::DimExpr;
318 ///
319 /// assert!(DimExpr::Const(3).is_const());
320 /// ```
321 pub fn is_const(&self) -> bool {
322 matches!(self, Self::Const(_))
323 }
324
325 /// Convert a concrete shape to constant expressions.
326 ///
327 /// # Examples
328 ///
329 /// ```rust
330 /// use tenferro_ops::dim_expr::DimExpr;
331 ///
332 /// assert_eq!(DimExpr::from_concrete(&[2, 3]), vec![DimExpr::Const(2), DimExpr::Const(3)]);
333 /// ```
334 pub fn from_concrete(shape: &[usize]) -> Vec<Self> {
335 shape.iter().map(|&v| Self::Const(v)).collect()
336 }
337
338 /// Build `[InputDim(input_idx, 0), ..., InputDim(input_idx, rank - 1)]`.
339 ///
340 /// # Examples
341 ///
342 /// ```rust
343 /// use tenferro_ops::dim_expr::DimExpr;
344 ///
345 /// let shape = DimExpr::input_shape(1, 2);
346 /// assert_eq!(
347 /// shape,
348 /// vec![
349 /// DimExpr::InputDim { input_idx: 1, axis: 0 },
350 /// DimExpr::InputDim { input_idx: 1, axis: 1 },
351 /// ]
352 /// );
353 /// ```
354 pub fn input_shape(input_idx: usize, rank: usize) -> Vec<Self> {
355 (0..rank)
356 .map(|axis| Self::InputDim { input_idx, axis })
357 .collect()
358 }
359
360 /// Evaluate a slice of expressions against actual input shapes.
361 ///
362 /// # Examples
363 ///
364 /// ```rust
365 /// use tenferro_ops::dim_expr::DimExpr;
366 ///
367 /// let exprs = vec![
368 /// DimExpr::InputDim { input_idx: 0, axis: 0 },
369 /// DimExpr::Const(4),
370 /// ];
371 /// assert_eq!(DimExpr::eval_all(&exprs, &[&[3, 5]]).unwrap(), vec![3, 4]);
372 /// ```
373 pub fn eval_all(
374 exprs: &[Self],
375 input_shapes: &[&[usize]],
376 ) -> Result<Vec<usize>, DimExprEvalError> {
377 exprs.iter().map(|e| e.eval(input_shapes)).collect()
378 }
379
380 /// Remap all `InputDim` references in a slice of expressions.
381 ///
382 /// # Examples
383 ///
384 /// ```rust
385 /// use tenferro_ops::dim_expr::DimExpr;
386 ///
387 /// let exprs = vec![DimExpr::InputDim { input_idx: 0, axis: 0 }];
388 /// assert_eq!(
389 /// DimExpr::remap_all(&exprs, 0, 1),
390 /// vec![DimExpr::InputDim { input_idx: 1, axis: 0 }]
391 /// );
392 /// ```
393 pub fn remap_all(exprs: &[Self], from: usize, to: usize) -> Vec<Self> {
394 exprs.iter().map(|e| e.remap(from, to)).collect()
395 }
396
397 /// Compute the maximum referenced `input_idx` across a slice.
398 ///
399 /// # Examples
400 ///
401 /// ```rust
402 /// use tenferro_ops::dim_expr::DimExpr;
403 ///
404 /// let exprs = vec![
405 /// DimExpr::InputDim { input_idx: 0, axis: 0 },
406 /// DimExpr::InputDim { input_idx: 2, axis: 1 },
407 /// ];
408 /// assert_eq!(DimExpr::max_input_idx_all(&exprs), Some(2));
409 /// ```
410 pub fn max_input_idx_all(exprs: &[Self]) -> Option<usize> {
411 exprs.iter().filter_map(Self::max_input_idx).max()
412 }
413}
414
415impl From<usize> for DimExpr {
416 fn from(v: usize) -> Self {
417 Self::Const(v)
418 }
419}
420
421impl From<&DimExpr> for DimExpr {
422 fn from(value: &DimExpr) -> Self {
423 value.clone()
424 }
425}