// tenferro_ops/dim_expr.rs

/// Arithmetic expression over tensor dimension sizes.
///
/// The expression stays symbolic until [`DimExpr::eval`] is given the concrete
/// shapes of an op's inputs. The leaf `InputDim { input_idx, axis }` stands for
/// the size of axis `axis` of the op's `input_idx`-th input tensor. Every
/// non-leaf variant combines two boxed sub-expressions.
///
/// # Examples
///
/// ```rust
/// use tenferro_ops::dim_expr::DimExpr;
///
/// let expr = DimExpr::mul(
///     DimExpr::InputDim {
///         input_idx: 0,
///         axis: 0,
///     },
///     DimExpr::InputDim {
///         input_idx: 0,
///         axis: 1,
///     },
/// );
/// assert_eq!(expr.eval(&[&[3, 4]]).unwrap(), 12);
/// ```
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum DimExpr {
    /// A fixed, concrete dimension size.
    Const(usize),
    /// Size of axis `axis` of the op's `input_idx`-th input tensor.
    InputDim {
        /// Position of the input tensor among the op's inputs.
        input_idx: usize,
        /// Axis of that input whose size is read.
        axis: usize,
    },
    /// `lhs + rhs`.
    Add(Box<DimExpr>, Box<DimExpr>),
    /// `lhs - rhs`.
    Sub(Box<DimExpr>, Box<DimExpr>),
    /// `lhs * rhs`.
    Mul(Box<DimExpr>, Box<DimExpr>),
    /// `lhs / rhs`, rounded down.
    FloorDiv(Box<DimExpr>, Box<DimExpr>),
    /// The smaller of `lhs` and `rhs`.
    Min(Box<DimExpr>, Box<DimExpr>),
    /// The larger of `lhs` and `rhs`.
    Max(Box<DimExpr>, Box<DimExpr>),
}

44/// Error produced while evaluating a [`DimExpr`] against concrete shapes.
45#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
46pub enum DimExprEvalError {
47    /// `InputDim` referenced an input that was not provided.
48    #[error(
49        "DimExpr::InputDim input index {input_idx} out of bounds for {input_count} input shapes"
50    )]
51    InputOutOfBounds {
52        input_idx: usize,
53        input_count: usize,
54    },
55    /// `InputDim` referenced an axis that does not exist on the selected input.
56    #[error("DimExpr::InputDim axis {axis} out of bounds for input {input_idx} rank {rank}")]
57    AxisOutOfBounds {
58        input_idx: usize,
59        axis: usize,
60        rank: usize,
61    },
62    /// Addition overflowed `usize`.
63    #[error("DimExpr::Add overflow: {lhs} + {rhs}")]
64    AddOverflow { lhs: usize, rhs: usize },
65    /// Subtraction would underflow `usize`.
66    #[error("DimExpr::Sub underflow: left operand {lhs} is smaller than {rhs}")]
67    SubUnderflow { lhs: usize, rhs: usize },
68    /// Multiplication overflowed `usize`.
69    #[error("DimExpr::Mul overflow: {lhs} * {rhs}")]
70    MulOverflow { lhs: usize, rhs: usize },
71    /// Floor division divisor evaluated to zero.
72    #[error("DimExpr::FloorDiv divide by zero: left operand {lhs}, divisor {rhs}")]
73    FloorDivByZero { lhs: usize, rhs: usize },
74}
75
76impl DimExpr {
77    /// Evaluate the expression using actual input tensor shapes.
78    ///
79    /// # Examples
80    ///
81    /// ```rust
82    /// use tenferro_ops::dim_expr::DimExpr;
83    ///
84    /// let expr = DimExpr::add(
85    ///     DimExpr::InputDim {
86    ///         input_idx: 0,
87    ///         axis: 0,
88    ///     },
89    ///     DimExpr::Const(2),
90    /// );
91    /// assert_eq!(expr.eval(&[&[5, 7]]).unwrap(), 7);
92    /// ```
93    pub fn eval(&self, input_shapes: &[&[usize]]) -> Result<usize, DimExprEvalError> {
94        match self {
95            Self::Const(v) => Ok(*v),
96            Self::InputDim { input_idx, axis } => input_shapes
97                .get(*input_idx)
98                .ok_or(DimExprEvalError::InputOutOfBounds {
99                    input_idx: *input_idx,
100                    input_count: input_shapes.len(),
101                })
102                .and_then(|shape| {
103                    shape
104                        .get(*axis)
105                        .copied()
106                        .ok_or(DimExprEvalError::AxisOutOfBounds {
107                            input_idx: *input_idx,
108                            axis: *axis,
109                            rank: shape.len(),
110                        })
111                }),
112            Self::Add(a, b) => {
113                let lhs = a.eval(input_shapes)?;
114                let rhs = b.eval(input_shapes)?;
115                lhs.checked_add(rhs)
116                    .ok_or(DimExprEvalError::AddOverflow { lhs, rhs })
117            }
118            Self::Sub(a, b) => {
119                let lhs = a.eval(input_shapes)?;
120                let rhs = b.eval(input_shapes)?;
121                lhs.checked_sub(rhs)
122                    .ok_or(DimExprEvalError::SubUnderflow { lhs, rhs })
123            }
124            Self::Mul(a, b) => {
125                let lhs = a.eval(input_shapes)?;
126                let rhs = b.eval(input_shapes)?;
127                lhs.checked_mul(rhs)
128                    .ok_or(DimExprEvalError::MulOverflow { lhs, rhs })
129            }
130            Self::FloorDiv(a, b) => {
131                let lhs = a.eval(input_shapes)?;
132                let rhs = b.eval(input_shapes)?;
133                if rhs == 0 {
134                    return Err(DimExprEvalError::FloorDivByZero { lhs, rhs });
135                }
136                Ok(lhs / rhs)
137            }
138            Self::Min(a, b) => Ok(a.eval(input_shapes)?.min(b.eval(input_shapes)?)),
139            Self::Max(a, b) => Ok(a.eval(input_shapes)?.max(b.eval(input_shapes)?)),
140        }
141    }
142
143    /// Return the maximum referenced `input_idx`, or `None` if the expression
144    /// contains only constants.
145    ///
146    /// # Examples
147    ///
148    /// ```rust
149    /// use tenferro_ops::dim_expr::DimExpr;
150    ///
151    /// let expr = DimExpr::add(
152    ///     DimExpr::InputDim {
153    ///         input_idx: 0,
154    ///         axis: 0,
155    ///     },
156    ///     DimExpr::InputDim {
157    ///         input_idx: 2,
158    ///         axis: 1,
159    ///     },
160    /// );
161    /// assert_eq!(expr.max_input_idx(), Some(2));
162    /// ```
163    pub fn max_input_idx(&self) -> Option<usize> {
164        match self {
165            Self::Const(_) => None,
166            Self::InputDim { input_idx, .. } => Some(*input_idx),
167            Self::Add(a, b)
168            | Self::Sub(a, b)
169            | Self::Mul(a, b)
170            | Self::FloorDiv(a, b)
171            | Self::Min(a, b)
172            | Self::Max(a, b) => match (a.max_input_idx(), b.max_input_idx()) {
173                (Some(x), Some(y)) => Some(x.max(y)),
174                (Some(x), None) | (None, Some(x)) => Some(x),
175                (None, None) => None,
176            },
177        }
178    }
179
180    /// Remap `InputDim { input_idx: from, .. }` to `InputDim { input_idx: to, .. }`.
181    ///
182    /// # Examples
183    ///
184    /// ```rust
185    /// use tenferro_ops::dim_expr::DimExpr;
186    ///
187    /// let expr = DimExpr::InputDim {
188    ///     input_idx: 0,
189    ///     axis: 1,
190    /// };
191    /// assert_eq!(expr.remap(0, 2), DimExpr::InputDim { input_idx: 2, axis: 1 });
192    /// ```
193    pub fn remap(&self, from: usize, to: usize) -> Self {
194        match self {
195            Self::Const(v) => Self::Const(*v),
196            Self::InputDim { input_idx, axis } => Self::InputDim {
197                input_idx: if *input_idx == from { to } else { *input_idx },
198                axis: *axis,
199            },
200            Self::Add(a, b) => Self::add(a.remap(from, to), b.remap(from, to)),
201            Self::Sub(a, b) => Self::sub(a.remap(from, to), b.remap(from, to)),
202            Self::Mul(a, b) => Self::mul(a.remap(from, to), b.remap(from, to)),
203            Self::FloorDiv(a, b) => Self::floor_div(a.remap(from, to), b.remap(from, to)),
204            Self::Min(a, b) => Self::min(a.remap(from, to), b.remap(from, to)),
205            Self::Max(a, b) => Self::max(a.remap(from, to), b.remap(from, to)),
206        }
207    }
208
209    /// Construct a constant dimension expression.
210    ///
211    /// # Examples
212    ///
213    /// ```rust
214    /// use tenferro_ops::dim_expr::DimExpr;
215    ///
216    /// assert_eq!(DimExpr::constant(4), DimExpr::Const(4));
217    /// ```
218    pub fn constant(v: usize) -> Self {
219        Self::Const(v)
220    }
221
222    /// Construct an addition node.
223    ///
224    /// # Examples
225    ///
226    /// ```rust
227    /// use tenferro_ops::dim_expr::DimExpr;
228    ///
229    /// let expr = DimExpr::add(DimExpr::Const(2), DimExpr::Const(3));
230    /// assert_eq!(expr.eval(&[]).unwrap(), 5);
231    /// ```
232    // Public constructor names mirror the DimExpr variants; operator traits are a separate API choice.
233    #[allow(clippy::should_implement_trait)]
234    pub fn add(a: Self, b: Self) -> Self {
235        Self::Add(Box::new(a), Box::new(b))
236    }
237
238    /// Construct a subtraction node.
239    ///
240    /// # Examples
241    ///
242    /// ```rust
243    /// use tenferro_ops::dim_expr::DimExpr;
244    ///
245    /// let expr = DimExpr::sub(DimExpr::Const(7), DimExpr::Const(2));
246    /// assert_eq!(expr.eval(&[]).unwrap(), 5);
247    /// ```
248    // Public constructor names mirror the DimExpr variants; operator traits are a separate API choice.
249    #[allow(clippy::should_implement_trait)]
250    pub fn sub(a: Self, b: Self) -> Self {
251        Self::Sub(Box::new(a), Box::new(b))
252    }
253
254    /// Construct a multiplication node.
255    ///
256    /// # Examples
257    ///
258    /// ```rust
259    /// use tenferro_ops::dim_expr::DimExpr;
260    ///
261    /// let expr = DimExpr::mul(DimExpr::Const(3), DimExpr::Const(4));
262    /// assert_eq!(expr.eval(&[]).unwrap(), 12);
263    /// ```
264    // Public constructor names mirror the DimExpr variants; operator traits are a separate API choice.
265    #[allow(clippy::should_implement_trait)]
266    pub fn mul(a: Self, b: Self) -> Self {
267        Self::Mul(Box::new(a), Box::new(b))
268    }
269
270    /// Construct a floor-division node.
271    ///
272    /// # Examples
273    ///
274    /// ```rust
275    /// use tenferro_ops::dim_expr::DimExpr;
276    ///
277    /// let expr = DimExpr::floor_div(DimExpr::Const(9), DimExpr::Const(2));
278    /// assert_eq!(expr.eval(&[]).unwrap(), 4);
279    /// ```
280    pub fn floor_div(a: Self, b: Self) -> Self {
281        Self::FloorDiv(Box::new(a), Box::new(b))
282    }
283
284    /// Construct a minimum node.
285    ///
286    /// # Examples
287    ///
288    /// ```rust
289    /// use tenferro_ops::dim_expr::DimExpr;
290    ///
291    /// let expr = DimExpr::min(DimExpr::Const(3), DimExpr::Const(5));
292    /// assert_eq!(expr.eval(&[]).unwrap(), 3);
293    /// ```
294    pub fn min(a: Self, b: Self) -> Self {
295        Self::Min(Box::new(a), Box::new(b))
296    }
297
298    /// Construct a maximum node.
299    ///
300    /// # Examples
301    ///
302    /// ```rust
303    /// use tenferro_ops::dim_expr::DimExpr;
304    ///
305    /// let expr = DimExpr::max(DimExpr::Const(3), DimExpr::Const(5));
306    /// assert_eq!(expr.eval(&[]).unwrap(), 5);
307    /// ```
308    pub fn max(a: Self, b: Self) -> Self {
309        Self::Max(Box::new(a), Box::new(b))
310    }
311
312    /// Return `true` when this expression is a constant.
313    ///
314    /// # Examples
315    ///
316    /// ```rust
317    /// use tenferro_ops::dim_expr::DimExpr;
318    ///
319    /// assert!(DimExpr::Const(3).is_const());
320    /// ```
321    pub fn is_const(&self) -> bool {
322        matches!(self, Self::Const(_))
323    }
324
325    /// Convert a concrete shape to constant expressions.
326    ///
327    /// # Examples
328    ///
329    /// ```rust
330    /// use tenferro_ops::dim_expr::DimExpr;
331    ///
332    /// assert_eq!(DimExpr::from_concrete(&[2, 3]), vec![DimExpr::Const(2), DimExpr::Const(3)]);
333    /// ```
334    pub fn from_concrete(shape: &[usize]) -> Vec<Self> {
335        shape.iter().map(|&v| Self::Const(v)).collect()
336    }
337
338    /// Build `[InputDim(input_idx, 0), ..., InputDim(input_idx, rank - 1)]`.
339    ///
340    /// # Examples
341    ///
342    /// ```rust
343    /// use tenferro_ops::dim_expr::DimExpr;
344    ///
345    /// let shape = DimExpr::input_shape(1, 2);
346    /// assert_eq!(
347    ///     shape,
348    ///     vec![
349    ///         DimExpr::InputDim { input_idx: 1, axis: 0 },
350    ///         DimExpr::InputDim { input_idx: 1, axis: 1 },
351    ///     ]
352    /// );
353    /// ```
354    pub fn input_shape(input_idx: usize, rank: usize) -> Vec<Self> {
355        (0..rank)
356            .map(|axis| Self::InputDim { input_idx, axis })
357            .collect()
358    }
359
360    /// Evaluate a slice of expressions against actual input shapes.
361    ///
362    /// # Examples
363    ///
364    /// ```rust
365    /// use tenferro_ops::dim_expr::DimExpr;
366    ///
367    /// let exprs = vec![
368    ///     DimExpr::InputDim { input_idx: 0, axis: 0 },
369    ///     DimExpr::Const(4),
370    /// ];
371    /// assert_eq!(DimExpr::eval_all(&exprs, &[&[3, 5]]).unwrap(), vec![3, 4]);
372    /// ```
373    pub fn eval_all(
374        exprs: &[Self],
375        input_shapes: &[&[usize]],
376    ) -> Result<Vec<usize>, DimExprEvalError> {
377        exprs.iter().map(|e| e.eval(input_shapes)).collect()
378    }
379
380    /// Remap all `InputDim` references in a slice of expressions.
381    ///
382    /// # Examples
383    ///
384    /// ```rust
385    /// use tenferro_ops::dim_expr::DimExpr;
386    ///
387    /// let exprs = vec![DimExpr::InputDim { input_idx: 0, axis: 0 }];
388    /// assert_eq!(
389    ///     DimExpr::remap_all(&exprs, 0, 1),
390    ///     vec![DimExpr::InputDim { input_idx: 1, axis: 0 }]
391    /// );
392    /// ```
393    pub fn remap_all(exprs: &[Self], from: usize, to: usize) -> Vec<Self> {
394        exprs.iter().map(|e| e.remap(from, to)).collect()
395    }
396
397    /// Compute the maximum referenced `input_idx` across a slice.
398    ///
399    /// # Examples
400    ///
401    /// ```rust
402    /// use tenferro_ops::dim_expr::DimExpr;
403    ///
404    /// let exprs = vec![
405    ///     DimExpr::InputDim { input_idx: 0, axis: 0 },
406    ///     DimExpr::InputDim { input_idx: 2, axis: 1 },
407    /// ];
408    /// assert_eq!(DimExpr::max_input_idx_all(&exprs), Some(2));
409    /// ```
410    pub fn max_input_idx_all(exprs: &[Self]) -> Option<usize> {
411        exprs.iter().filter_map(Self::max_input_idx).max()
412    }
413}
414
415impl From<usize> for DimExpr {
416    fn from(v: usize) -> Self {
417        Self::Const(v)
418    }
419}
420
421impl From<&DimExpr> for DimExpr {
422    fn from(value: &DimExpr) -> Self {
423        value.clone()
424    }
425}