// tenferro_ops/dim_expr.rs

/// Arithmetic expression over tensor dimension sizes.
///
/// Evaluated at execution time from actual input tensor shapes.
/// `InputDim { input_idx, axis }` references the axis size of
/// the op's `input_idx`-th input tensor.
///
/// All arithmetic is carried out on `usize`. Evaluation never wraps around:
/// see [`DimExpr::eval`] for the conditions under which it panics instead.
///
/// # Examples
///
/// ```ignore
/// use tenferro_ops::dim_expr::DimExpr;
///
/// let expr = DimExpr::mul(
///     DimExpr::InputDim {
///         input_idx: 0,
///         axis: 0,
///     },
///     DimExpr::InputDim {
///         input_idx: 0,
///         axis: 1,
///     },
/// );
/// assert_eq!(expr.eval(&[&[3, 4]]), 12);
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum DimExpr {
    /// A concrete dimension size.
    Const(usize),
    /// Axis size of the op's `input_idx`-th input tensor.
    InputDim { input_idx: usize, axis: usize },
    /// Sum of two dimension expressions.
    Add(Box<DimExpr>, Box<DimExpr>),
    /// Difference of two dimension expressions (left minus right).
    ///
    /// Evaluation panics if the right operand exceeds the left one.
    Sub(Box<DimExpr>, Box<DimExpr>),
    /// Product of two dimension expressions.
    Mul(Box<DimExpr>, Box<DimExpr>),
    /// Floor division of two dimension expressions (left divided by right).
    ///
    /// Evaluation panics if the right operand evaluates to zero.
    FloorDiv(Box<DimExpr>, Box<DimExpr>),
    /// Minimum of two dimension expressions.
    Min(Box<DimExpr>, Box<DimExpr>),
    /// Maximum of two dimension expressions.
    Max(Box<DimExpr>, Box<DimExpr>),
}

impl DimExpr {
    /// Evaluate the expression using actual input tensor shapes.
    ///
    /// `input_shapes[i]` is the shape of the op's `i`-th input tensor.
    /// Binary nodes evaluate their left operand before their right one.
    ///
    /// # Panics
    ///
    /// Panics if
    /// - an `InputDim` node references an `input_idx` that is out of bounds
    ///   for `input_shapes`, or an `axis` that is out of bounds for the
    ///   corresponding shape slice;
    /// - a `Sub` node's right operand is larger than its left operand;
    /// - an `Add` or `Mul` node overflows `usize`;
    /// - a `FloorDiv` node's right operand evaluates to zero.
    ///
    /// These checks are performed in every build profile, so an invalid shape
    /// computation never silently wraps around in release builds.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::add(
    ///     DimExpr::InputDim {
    ///         input_idx: 0,
    ///         axis: 0,
    ///     },
    ///     DimExpr::Const(2),
    /// );
    /// assert_eq!(expr.eval(&[&[5, 7]]), 7);
    /// ```
    pub fn eval(&self, input_shapes: &[&[usize]]) -> usize {
        match self {
            Self::Const(v) => *v,
            Self::InputDim { input_idx, axis } => {
                Self::lookup_dim(input_shapes, *input_idx, *axis)
            }
            Self::Add(a, b) => Self::eval_checked(a, b, input_shapes, "+", usize::checked_add),
            Self::Sub(a, b) => Self::eval_checked(a, b, input_shapes, "-", usize::checked_sub),
            Self::Mul(a, b) => Self::eval_checked(a, b, input_shapes, "*", usize::checked_mul),
            Self::FloorDiv(a, b) => {
                // For unsigned integers truncating division is floor division.
                Self::eval_checked(a, b, input_shapes, "/", usize::checked_div)
            }
            Self::Min(a, b) => a.eval(input_shapes).min(b.eval(input_shapes)),
            Self::Max(a, b) => a.eval(input_shapes).max(b.eval(input_shapes)),
        }
    }

    /// Return the size of `axis` of the `input_idx`-th input shape.
    ///
    /// Panics with a message naming the offending index when either
    /// `input_idx` or `axis` is out of bounds.
    fn lookup_dim(input_shapes: &[&[usize]], input_idx: usize, axis: usize) -> usize {
        let shape = input_shapes.get(input_idx).unwrap_or_else(|| {
            panic!(
                "DimExpr::eval: input_idx {} is out of bounds for {} input shape(s)",
                input_idx,
                input_shapes.len()
            )
        });
        match shape.get(axis) {
            Some(&size) => size,
            None => panic!(
                "DimExpr::eval: axis {} is out of bounds for input {} of rank {}",
                axis,
                input_idx,
                shape.len()
            ),
        }
    }

    /// Evaluate `a` then `b` and combine them with a checked `usize` operation.
    ///
    /// `op` is only used in the panic message raised when `checked` returns
    /// `None` (overflow, underflow, or division by zero).
    fn eval_checked(
        a: &Self,
        b: &Self,
        input_shapes: &[&[usize]],
        op: &str,
        checked: fn(usize, usize) -> Option<usize>,
    ) -> usize {
        let lhs = a.eval(input_shapes);
        let rhs = b.eval(input_shapes);
        checked(lhs, rhs).unwrap_or_else(|| {
            panic!(
                "DimExpr::eval: `{} {} {}` is not a valid dimension size \
                 (overflow, underflow, or division by zero)",
                lhs, op, rhs
            )
        })
    }

    /// Return the maximum referenced `input_idx`, or `None` if the expression
    /// contains only constants.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::add(
    ///     DimExpr::InputDim {
    ///         input_idx: 0,
    ///         axis: 0,
    ///     },
    ///     DimExpr::InputDim {
    ///         input_idx: 2,
    ///         axis: 1,
    ///     },
    /// );
    /// assert_eq!(expr.max_input_idx(), Some(2));
    /// ```
    pub fn max_input_idx(&self) -> Option<usize> {
        match self {
            Self::Const(_) => None,
            Self::InputDim { input_idx, .. } => Some(*input_idx),
            // `Option`'s ordering puts `None` below every `Some`, so `max`
            // ignores constant-only subtrees.
            Self::Add(a, b)
            | Self::Sub(a, b)
            | Self::Mul(a, b)
            | Self::FloorDiv(a, b)
            | Self::Min(a, b)
            | Self::Max(a, b) => a.max_input_idx().max(b.max_input_idx()),
        }
    }

    /// Remap `InputDim { input_idx: from, .. }` to `InputDim { input_idx: to, .. }`.
    ///
    /// Every other node is copied unchanged.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::InputDim {
    ///     input_idx: 0,
    ///     axis: 1,
    /// };
    /// assert_eq!(expr.remap(0, 2), DimExpr::InputDim { input_idx: 2, axis: 1 });
    /// ```
    pub fn remap(&self, from: usize, to: usize) -> Self {
        match self {
            Self::Const(v) => Self::Const(*v),
            Self::InputDim { input_idx, axis } => Self::InputDim {
                input_idx: if *input_idx == from { to } else { *input_idx },
                axis: *axis,
            },
            Self::Add(a, b) => Self::add(a.remap(from, to), b.remap(from, to)),
            Self::Sub(a, b) => Self::sub(a.remap(from, to), b.remap(from, to)),
            Self::Mul(a, b) => Self::mul(a.remap(from, to), b.remap(from, to)),
            Self::FloorDiv(a, b) => Self::floor_div(a.remap(from, to), b.remap(from, to)),
            Self::Min(a, b) => Self::min(a.remap(from, to), b.remap(from, to)),
            Self::Max(a, b) => Self::max(a.remap(from, to), b.remap(from, to)),
        }
    }

    /// Construct a constant dimension expression.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// assert_eq!(DimExpr::constant(4), DimExpr::Const(4));
    /// ```
    pub fn constant(v: usize) -> Self {
        Self::Const(v)
    }

    /// Construct an addition node.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::add(DimExpr::Const(2), DimExpr::Const(3));
    /// assert_eq!(expr.eval(&[]), 5);
    /// ```
    pub fn add(a: Self, b: Self) -> Self {
        Self::Add(Box::new(a), Box::new(b))
    }

    /// Construct a subtraction node computing `a - b`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::sub(DimExpr::Const(7), DimExpr::Const(2));
    /// assert_eq!(expr.eval(&[]), 5);
    /// ```
    pub fn sub(a: Self, b: Self) -> Self {
        Self::Sub(Box::new(a), Box::new(b))
    }

    /// Construct a multiplication node.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::mul(DimExpr::Const(3), DimExpr::Const(4));
    /// assert_eq!(expr.eval(&[]), 12);
    /// ```
    pub fn mul(a: Self, b: Self) -> Self {
        Self::Mul(Box::new(a), Box::new(b))
    }

    /// Construct a floor-division node computing `a / b`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::floor_div(DimExpr::Const(9), DimExpr::Const(2));
    /// assert_eq!(expr.eval(&[]), 4);
    /// ```
    pub fn floor_div(a: Self, b: Self) -> Self {
        Self::FloorDiv(Box::new(a), Box::new(b))
    }

    /// Construct a minimum node.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::min(DimExpr::Const(3), DimExpr::Const(5));
    /// assert_eq!(expr.eval(&[]), 3);
    /// ```
    pub fn min(a: Self, b: Self) -> Self {
        Self::Min(Box::new(a), Box::new(b))
    }

    /// Construct a maximum node.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let expr = DimExpr::max(DimExpr::Const(3), DimExpr::Const(5));
    /// assert_eq!(expr.eval(&[]), 5);
    /// ```
    pub fn max(a: Self, b: Self) -> Self {
        Self::Max(Box::new(a), Box::new(b))
    }

    /// Return `true` when this expression is a constant.
    ///
    /// Only a top-level `Const` node counts; constant-only compound
    /// expressions such as `Add(Const, Const)` return `false`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// assert!(DimExpr::Const(3).is_const());
    /// ```
    pub fn is_const(&self) -> bool {
        matches!(self, Self::Const(_))
    }

    /// Convert a concrete shape to constant expressions.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// assert_eq!(DimExpr::from_concrete(&[2, 3]), vec![DimExpr::Const(2), DimExpr::Const(3)]);
    /// ```
    pub fn from_concrete(shape: &[usize]) -> Vec<Self> {
        shape.iter().copied().map(Self::Const).collect()
    }

    /// Build `[InputDim(input_idx, 0), ..., InputDim(input_idx, rank - 1)]`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let shape = DimExpr::input_shape(1, 2);
    /// assert_eq!(
    ///     shape,
    ///     vec![
    ///         DimExpr::InputDim { input_idx: 1, axis: 0 },
    ///         DimExpr::InputDim { input_idx: 1, axis: 1 },
    ///     ]
    /// );
    /// ```
    pub fn input_shape(input_idx: usize, rank: usize) -> Vec<Self> {
        (0..rank)
            .map(|axis| Self::InputDim { input_idx, axis })
            .collect()
    }

    /// Evaluate a slice of expressions against actual input shapes.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`DimExpr::eval`].
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let exprs = vec![
    ///     DimExpr::InputDim { input_idx: 0, axis: 0 },
    ///     DimExpr::Const(4),
    /// ];
    /// assert_eq!(DimExpr::eval_all(&exprs, &[&[3, 5]]), vec![3, 4]);
    /// ```
    pub fn eval_all(exprs: &[Self], input_shapes: &[&[usize]]) -> Vec<usize> {
        exprs.iter().map(|e| e.eval(input_shapes)).collect()
    }

    /// Remap all `InputDim` references in a slice of expressions.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let exprs = vec![DimExpr::InputDim { input_idx: 0, axis: 0 }];
    /// assert_eq!(
    ///     DimExpr::remap_all(&exprs, 0, 1),
    ///     vec![DimExpr::InputDim { input_idx: 1, axis: 0 }]
    /// );
    /// ```
    pub fn remap_all(exprs: &[Self], from: usize, to: usize) -> Vec<Self> {
        exprs.iter().map(|e| e.remap(from, to)).collect()
    }

    /// Compute the maximum referenced `input_idx` across a slice.
    ///
    /// Returns `None` when the slice is empty or contains only constants.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_ops::dim_expr::DimExpr;
    ///
    /// let exprs = vec![
    ///     DimExpr::InputDim { input_idx: 0, axis: 0 },
    ///     DimExpr::InputDim { input_idx: 2, axis: 1 },
    /// ];
    /// assert_eq!(DimExpr::max_input_idx_all(&exprs), Some(2));
    /// ```
    pub fn max_input_idx_all(exprs: &[Self]) -> Option<usize> {
        exprs.iter().filter_map(Self::max_input_idx).max()
    }
}

343impl From<usize> for DimExpr {
344    fn from(v: usize) -> Self {
345        Self::Const(v)
346    }
347}
348
349impl From<&DimExpr> for DimExpr {
350    fn from(value: &DimExpr) -> Self {
351        value.clone()
352    }
353}