tenferro_ops/dim_expr.rs
/// Symbolic arithmetic over tensor dimension sizes.
///
/// An expression is resolved to a concrete `usize` at execution time, from
/// the actual shapes of the op's input tensors. The leaf
/// `InputDim { input_idx, axis }` stands for the size of axis `axis` of the
/// op's `input_idx`-th input. `Const` leaves are literal sizes, and every
/// other variant combines two sub-expressions.
///
/// # Examples
///
/// ```ignore
/// use tenferro_ops::dim_expr::DimExpr;
///
/// // rows * cols of the first input
/// let expr = DimExpr::mul(
///     DimExpr::InputDim {
///         input_idx: 0,
///         axis: 0,
///     },
///     DimExpr::InputDim {
///         input_idx: 0,
///         axis: 1,
///     },
/// );
/// assert_eq!(expr.eval(&[&[3, 4]]), 12);
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum DimExpr {
    /// A literal dimension size.
    Const(usize),
    /// The size of one axis of one of the op's input tensors.
    InputDim {
        /// Position of the referenced tensor among the op's inputs.
        input_idx: usize,
        /// Axis of that tensor whose size is taken.
        axis: usize,
    },
    /// `lhs + rhs`.
    Add(Box<Self>, Box<Self>),
    /// `lhs - rhs`. Sizes are `usize`, so `lhs` is expected to be at least `rhs`.
    Sub(Box<Self>, Box<Self>),
    /// `lhs * rhs`.
    Mul(Box<Self>, Box<Self>),
    /// `lhs / rhs`, rounded toward zero (identical to floor for `usize`).
    FloorDiv(Box<Self>, Box<Self>),
    /// The smaller of `lhs` and `rhs`.
    Min(Box<Self>, Box<Self>),
    /// The larger of `lhs` and `rhs`.
    Max(Box<Self>, Box<Self>),
}
43
44impl DimExpr {
45 /// Evaluate the expression using actual input tensor shapes.
46 ///
47 /// # Panics
48 ///
49 /// Panics if an `InputDim` node references an `input_idx` that is
50 /// out of bounds for `input_shapes`, or an `axis` that is out of
51 /// bounds for the corresponding shape slice.
52 ///
53 /// # Examples
54 ///
55 /// ```ignore
56 /// use tenferro_ops::dim_expr::DimExpr;
57 ///
58 /// let expr = DimExpr::add(
59 /// DimExpr::InputDim {
60 /// input_idx: 0,
61 /// axis: 0,
62 /// },
63 /// DimExpr::Const(2),
64 /// );
65 /// assert_eq!(expr.eval(&[&[5, 7]]), 7);
66 /// ```
67 pub fn eval(&self, input_shapes: &[&[usize]]) -> usize {
68 match self {
69 Self::Const(v) => *v,
70 Self::InputDim { input_idx, axis } => input_shapes[*input_idx][*axis],
71 Self::Add(a, b) => a.eval(input_shapes) + b.eval(input_shapes),
72 Self::Sub(a, b) => a.eval(input_shapes) - b.eval(input_shapes),
73 Self::Mul(a, b) => a.eval(input_shapes) * b.eval(input_shapes),
74 Self::FloorDiv(a, b) => a.eval(input_shapes) / b.eval(input_shapes),
75 Self::Min(a, b) => a.eval(input_shapes).min(b.eval(input_shapes)),
76 Self::Max(a, b) => a.eval(input_shapes).max(b.eval(input_shapes)),
77 }
78 }
79
80 /// Return the maximum referenced `input_idx`, or `None` if the expression
81 /// contains only constants.
82 ///
83 /// # Examples
84 ///
85 /// ```ignore
86 /// use tenferro_ops::dim_expr::DimExpr;
87 ///
88 /// let expr = DimExpr::add(
89 /// DimExpr::InputDim {
90 /// input_idx: 0,
91 /// axis: 0,
92 /// },
93 /// DimExpr::InputDim {
94 /// input_idx: 2,
95 /// axis: 1,
96 /// },
97 /// );
98 /// assert_eq!(expr.max_input_idx(), Some(2));
99 /// ```
100 pub fn max_input_idx(&self) -> Option<usize> {
101 match self {
102 Self::Const(_) => None,
103 Self::InputDim { input_idx, .. } => Some(*input_idx),
104 Self::Add(a, b)
105 | Self::Sub(a, b)
106 | Self::Mul(a, b)
107 | Self::FloorDiv(a, b)
108 | Self::Min(a, b)
109 | Self::Max(a, b) => match (a.max_input_idx(), b.max_input_idx()) {
110 (Some(x), Some(y)) => Some(x.max(y)),
111 (Some(x), None) | (None, Some(x)) => Some(x),
112 (None, None) => None,
113 },
114 }
115 }
116
117 /// Remap `InputDim { input_idx: from, .. }` to `InputDim { input_idx: to, .. }`.
118 ///
119 /// # Examples
120 ///
121 /// ```ignore
122 /// use tenferro_ops::dim_expr::DimExpr;
123 ///
124 /// let expr = DimExpr::InputDim {
125 /// input_idx: 0,
126 /// axis: 1,
127 /// };
128 /// assert_eq!(expr.remap(0, 2), DimExpr::InputDim { input_idx: 2, axis: 1 });
129 /// ```
130 pub fn remap(&self, from: usize, to: usize) -> Self {
131 match self {
132 Self::Const(v) => Self::Const(*v),
133 Self::InputDim { input_idx, axis } => Self::InputDim {
134 input_idx: if *input_idx == from { to } else { *input_idx },
135 axis: *axis,
136 },
137 Self::Add(a, b) => Self::add(a.remap(from, to), b.remap(from, to)),
138 Self::Sub(a, b) => Self::sub(a.remap(from, to), b.remap(from, to)),
139 Self::Mul(a, b) => Self::mul(a.remap(from, to), b.remap(from, to)),
140 Self::FloorDiv(a, b) => Self::floor_div(a.remap(from, to), b.remap(from, to)),
141 Self::Min(a, b) => Self::min(a.remap(from, to), b.remap(from, to)),
142 Self::Max(a, b) => Self::max(a.remap(from, to), b.remap(from, to)),
143 }
144 }
145
146 /// Construct a constant dimension expression.
147 ///
148 /// # Examples
149 ///
150 /// ```ignore
151 /// use tenferro_ops::dim_expr::DimExpr;
152 ///
153 /// assert_eq!(DimExpr::constant(4), DimExpr::Const(4));
154 /// ```
155 pub fn constant(v: usize) -> Self {
156 Self::Const(v)
157 }
158
159 /// Construct an addition node.
160 ///
161 /// # Examples
162 ///
163 /// ```ignore
164 /// use tenferro_ops::dim_expr::DimExpr;
165 ///
166 /// let expr = DimExpr::add(DimExpr::Const(2), DimExpr::Const(3));
167 /// assert_eq!(expr.eval(&[]), 5);
168 /// ```
169 pub fn add(a: Self, b: Self) -> Self {
170 Self::Add(Box::new(a), Box::new(b))
171 }
172
173 /// Construct a subtraction node.
174 ///
175 /// # Examples
176 ///
177 /// ```ignore
178 /// use tenferro_ops::dim_expr::DimExpr;
179 ///
180 /// let expr = DimExpr::sub(DimExpr::Const(7), DimExpr::Const(2));
181 /// assert_eq!(expr.eval(&[]), 5);
182 /// ```
183 pub fn sub(a: Self, b: Self) -> Self {
184 Self::Sub(Box::new(a), Box::new(b))
185 }
186
187 /// Construct a multiplication node.
188 ///
189 /// # Examples
190 ///
191 /// ```ignore
192 /// use tenferro_ops::dim_expr::DimExpr;
193 ///
194 /// let expr = DimExpr::mul(DimExpr::Const(3), DimExpr::Const(4));
195 /// assert_eq!(expr.eval(&[]), 12);
196 /// ```
197 pub fn mul(a: Self, b: Self) -> Self {
198 Self::Mul(Box::new(a), Box::new(b))
199 }
200
201 /// Construct a floor-division node.
202 ///
203 /// # Examples
204 ///
205 /// ```ignore
206 /// use tenferro_ops::dim_expr::DimExpr;
207 ///
208 /// let expr = DimExpr::floor_div(DimExpr::Const(9), DimExpr::Const(2));
209 /// assert_eq!(expr.eval(&[]), 4);
210 /// ```
211 pub fn floor_div(a: Self, b: Self) -> Self {
212 Self::FloorDiv(Box::new(a), Box::new(b))
213 }
214
215 /// Construct a minimum node.
216 ///
217 /// # Examples
218 ///
219 /// ```ignore
220 /// use tenferro_ops::dim_expr::DimExpr;
221 ///
222 /// let expr = DimExpr::min(DimExpr::Const(3), DimExpr::Const(5));
223 /// assert_eq!(expr.eval(&[]), 3);
224 /// ```
225 pub fn min(a: Self, b: Self) -> Self {
226 Self::Min(Box::new(a), Box::new(b))
227 }
228
229 /// Construct a maximum node.
230 ///
231 /// # Examples
232 ///
233 /// ```ignore
234 /// use tenferro_ops::dim_expr::DimExpr;
235 ///
236 /// let expr = DimExpr::max(DimExpr::Const(3), DimExpr::Const(5));
237 /// assert_eq!(expr.eval(&[]), 5);
238 /// ```
239 pub fn max(a: Self, b: Self) -> Self {
240 Self::Max(Box::new(a), Box::new(b))
241 }
242
243 /// Return `true` when this expression is a constant.
244 ///
245 /// # Examples
246 ///
247 /// ```ignore
248 /// use tenferro_ops::dim_expr::DimExpr;
249 ///
250 /// assert!(DimExpr::Const(3).is_const());
251 /// ```
252 pub fn is_const(&self) -> bool {
253 matches!(self, Self::Const(_))
254 }
255
256 /// Convert a concrete shape to constant expressions.
257 ///
258 /// # Examples
259 ///
260 /// ```ignore
261 /// use tenferro_ops::dim_expr::DimExpr;
262 ///
263 /// assert_eq!(DimExpr::from_concrete(&[2, 3]), vec![DimExpr::Const(2), DimExpr::Const(3)]);
264 /// ```
265 pub fn from_concrete(shape: &[usize]) -> Vec<Self> {
266 shape.iter().map(|&v| Self::Const(v)).collect()
267 }
268
269 /// Build `[InputDim(input_idx, 0), ..., InputDim(input_idx, rank - 1)]`.
270 ///
271 /// # Examples
272 ///
273 /// ```ignore
274 /// use tenferro_ops::dim_expr::DimExpr;
275 ///
276 /// let shape = DimExpr::input_shape(1, 2);
277 /// assert_eq!(
278 /// shape,
279 /// vec![
280 /// DimExpr::InputDim { input_idx: 1, axis: 0 },
281 /// DimExpr::InputDim { input_idx: 1, axis: 1 },
282 /// ]
283 /// );
284 /// ```
285 pub fn input_shape(input_idx: usize, rank: usize) -> Vec<Self> {
286 (0..rank)
287 .map(|axis| Self::InputDim { input_idx, axis })
288 .collect()
289 }
290
291 /// Evaluate a slice of expressions against actual input shapes.
292 ///
293 /// # Examples
294 ///
295 /// ```ignore
296 /// use tenferro_ops::dim_expr::DimExpr;
297 ///
298 /// let exprs = vec![
299 /// DimExpr::InputDim { input_idx: 0, axis: 0 },
300 /// DimExpr::Const(4),
301 /// ];
302 /// assert_eq!(DimExpr::eval_all(&exprs, &[&[3, 5]]), vec![3, 4]);
303 /// ```
304 pub fn eval_all(exprs: &[Self], input_shapes: &[&[usize]]) -> Vec<usize> {
305 exprs.iter().map(|e| e.eval(input_shapes)).collect()
306 }
307
308 /// Remap all `InputDim` references in a slice of expressions.
309 ///
310 /// # Examples
311 ///
312 /// ```ignore
313 /// use tenferro_ops::dim_expr::DimExpr;
314 ///
315 /// let exprs = vec![DimExpr::InputDim { input_idx: 0, axis: 0 }];
316 /// assert_eq!(
317 /// DimExpr::remap_all(&exprs, 0, 1),
318 /// vec![DimExpr::InputDim { input_idx: 1, axis: 0 }]
319 /// );
320 /// ```
321 pub fn remap_all(exprs: &[Self], from: usize, to: usize) -> Vec<Self> {
322 exprs.iter().map(|e| e.remap(from, to)).collect()
323 }
324
325 /// Compute the maximum referenced `input_idx` across a slice.
326 ///
327 /// # Examples
328 ///
329 /// ```ignore
330 /// use tenferro_ops::dim_expr::DimExpr;
331 ///
332 /// let exprs = vec![
333 /// DimExpr::InputDim { input_idx: 0, axis: 0 },
334 /// DimExpr::InputDim { input_idx: 2, axis: 1 },
335 /// ];
336 /// assert_eq!(DimExpr::max_input_idx_all(&exprs), Some(2));
337 /// ```
338 pub fn max_input_idx_all(exprs: &[Self]) -> Option<usize> {
339 exprs.iter().filter_map(Self::max_input_idx).max()
340 }
341}
342
343impl From<usize> for DimExpr {
344 fn from(v: usize) -> Self {
345 Self::Const(v)
346 }
347}
348
349impl From<&DimExpr> for DimExpr {
350 fn from(value: &DimExpr) -> Self {
351 value.clone()
352 }
353}