1use std::ops::{Add, Div, Mul, Sub};
2
3use crate::dim_expr::DimExpr;
4
5#[derive(Clone, Debug, PartialEq, Eq)]
20pub struct SymDim(pub(crate) RawSymDim);
21
/// Internal expression tree behind `SymDim`.
///
/// Leaves are constants and tensor-axis references. Every other variant is a binary
/// operation over two boxed sub-expressions, stored in `(lhs, rhs)` order.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum RawSymDim {
    /// A concrete, known size.
    Const(usize),
    /// A reference to one axis of a tensor.
    /// NOTE(review): presumably the extent of axis `axis` of the tensor identified by
    /// `tensor_id`; confirm against the code that allocates tensor ids.
    TensorAxis { tensor_id: u64, axis: usize },
    /// `lhs + rhs`.
    Add(Box<RawSymDim>, Box<RawSymDim>),
    /// `lhs - rhs`.
    Sub(Box<RawSymDim>, Box<RawSymDim>),
    /// `lhs * rhs`.
    Mul(Box<RawSymDim>, Box<RawSymDim>),
    /// Floor (integer) division `lhs / rhs`.
    FloorDiv(Box<RawSymDim>, Box<RawSymDim>),
    /// `min(lhs, rhs)`.
    Min(Box<RawSymDim>, Box<RawSymDim>),
    /// `max(lhs, rhs)`.
    Max(Box<RawSymDim>, Box<RawSymDim>),
}
33
34impl SymDim {
35 pub fn min(self, other: impl Into<SymDim>) -> Self {
45 let other = other.into();
46 binary(self, other, |lhs, rhs| {
47 RawSymDim::Min(Box::new(lhs), Box::new(rhs))
48 })
49 }
50
51 pub fn max(self, other: impl Into<SymDim>) -> Self {
61 let other = other.into();
62 binary(self, other, |lhs, rhs| {
63 RawSymDim::Max(Box::new(lhs), Box::new(rhs))
64 })
65 }
66
67 #[doc(hidden)]
68 pub fn tensor_axis(tensor_id: u64, axis: usize) -> Self {
69 Self(RawSymDim::TensorAxis { tensor_id, axis })
70 }
71
72 #[doc(hidden)]
73 pub fn constant_value(&self) -> Option<usize> {
74 raw_constant_value(&self.0)
75 }
76
77 #[doc(hidden)]
78 pub fn from_dim_expr(expr: &DimExpr, input_shapes: &[&[SymDim]]) -> Self {
79 match expr {
80 DimExpr::Const(value) => Self::from(*value),
81 DimExpr::InputDim { input_idx, axis } => input_shapes[*input_idx][*axis].clone(),
82 DimExpr::Add(lhs, rhs) => {
83 Self::from_dim_expr(lhs, input_shapes) + Self::from_dim_expr(rhs, input_shapes)
84 }
85 DimExpr::Sub(lhs, rhs) => {
86 Self::from_dim_expr(lhs, input_shapes) - Self::from_dim_expr(rhs, input_shapes)
87 }
88 DimExpr::Mul(lhs, rhs) => {
89 Self::from_dim_expr(lhs, input_shapes) * Self::from_dim_expr(rhs, input_shapes)
90 }
91 DimExpr::FloorDiv(lhs, rhs) => {
92 Self::from_dim_expr(lhs, input_shapes) / Self::from_dim_expr(rhs, input_shapes)
93 }
94 DimExpr::Min(lhs, rhs) => {
95 Self::from_dim_expr(lhs, input_shapes).min(Self::from_dim_expr(rhs, input_shapes))
96 }
97 DimExpr::Max(lhs, rhs) => {
98 Self::from_dim_expr(lhs, input_shapes).max(Self::from_dim_expr(rhs, input_shapes))
99 }
100 }
101 }
102
103 #[doc(hidden)]
104 pub fn to_dim_expr(&self, tensor_map: &[(u64, usize)]) -> std::result::Result<DimExpr, String> {
105 raw_to_dim_expr(&self.0, tensor_map)
106 }
107
108 pub fn referenced_tensor_ids(&self) -> Vec<u64> {
127 let mut ids = Vec::new();
128 collect_tensor_ids(&self.0, &mut ids);
129 ids
130 }
131}
132
133fn collect_tensor_ids(raw: &RawSymDim, ids: &mut Vec<u64>) {
134 match raw {
135 RawSymDim::Const(_) => {}
136 RawSymDim::TensorAxis { tensor_id, .. } => {
137 if !ids.contains(tensor_id) {
138 ids.push(*tensor_id);
139 }
140 }
141 RawSymDim::Add(lhs, rhs)
142 | RawSymDim::Sub(lhs, rhs)
143 | RawSymDim::Mul(lhs, rhs)
144 | RawSymDim::FloorDiv(lhs, rhs)
145 | RawSymDim::Min(lhs, rhs)
146 | RawSymDim::Max(lhs, rhs) => {
147 collect_tensor_ids(lhs, ids);
148 collect_tensor_ids(rhs, ids);
149 }
150 }
151}
152
153fn raw_constant_value(raw: &RawSymDim) -> Option<usize> {
154 match raw {
155 RawSymDim::Const(value) => Some(*value),
156 RawSymDim::TensorAxis { .. } => None,
157 RawSymDim::Add(lhs, rhs) => raw_constant_value(lhs)?.checked_add(raw_constant_value(rhs)?),
158 RawSymDim::Sub(lhs, rhs) => raw_constant_value(lhs)?.checked_sub(raw_constant_value(rhs)?),
159 RawSymDim::Mul(lhs, rhs) => raw_constant_value(lhs)?.checked_mul(raw_constant_value(rhs)?),
160 RawSymDim::FloorDiv(lhs, rhs) => {
161 let rhs = raw_constant_value(rhs)?;
162 raw_constant_value(lhs)?.checked_div(rhs)
163 }
164 RawSymDim::Min(lhs, rhs) => Some(raw_constant_value(lhs)?.min(raw_constant_value(rhs)?)),
165 RawSymDim::Max(lhs, rhs) => Some(raw_constant_value(lhs)?.max(raw_constant_value(rhs)?)),
166 }
167}
168
169impl From<usize> for SymDim {
170 fn from(value: usize) -> Self {
171 Self(RawSymDim::Const(value))
172 }
173}
174
175fn raw_to_dim_expr(
176 raw: &RawSymDim,
177 tensor_map: &[(u64, usize)],
178) -> std::result::Result<DimExpr, String> {
179 Ok(match raw {
180 RawSymDim::Const(value) => DimExpr::Const(*value),
181 RawSymDim::TensorAxis { tensor_id, axis } => {
182 let input_idx = tensor_map
183 .iter()
184 .find_map(|(candidate_id, input_idx)| {
185 (candidate_id == tensor_id).then_some(*input_idx)
186 })
187 .ok_or_else(|| format!("unknown symbolic tensor id {tensor_id}"))?;
188 DimExpr::InputDim {
189 input_idx,
190 axis: *axis,
191 }
192 }
193 RawSymDim::Add(lhs, rhs) => DimExpr::add(
194 raw_to_dim_expr(lhs, tensor_map)?,
195 raw_to_dim_expr(rhs, tensor_map)?,
196 ),
197 RawSymDim::Sub(lhs, rhs) => DimExpr::sub(
198 raw_to_dim_expr(lhs, tensor_map)?,
199 raw_to_dim_expr(rhs, tensor_map)?,
200 ),
201 RawSymDim::Mul(lhs, rhs) => DimExpr::mul(
202 raw_to_dim_expr(lhs, tensor_map)?,
203 raw_to_dim_expr(rhs, tensor_map)?,
204 ),
205 RawSymDim::FloorDiv(lhs, rhs) => DimExpr::floor_div(
206 raw_to_dim_expr(lhs, tensor_map)?,
207 raw_to_dim_expr(rhs, tensor_map)?,
208 ),
209 RawSymDim::Min(lhs, rhs) => DimExpr::min(
210 raw_to_dim_expr(lhs, tensor_map)?,
211 raw_to_dim_expr(rhs, tensor_map)?,
212 ),
213 RawSymDim::Max(lhs, rhs) => DimExpr::max(
214 raw_to_dim_expr(lhs, tensor_map)?,
215 raw_to_dim_expr(rhs, tensor_map)?,
216 ),
217 })
218}
219
220fn binary(lhs: SymDim, rhs: SymDim, f: impl FnOnce(RawSymDim, RawSymDim) -> RawSymDim) -> SymDim {
221 SymDim(f(lhs.0, rhs.0))
222}
223
224impl Add for SymDim {
225 type Output = SymDim;
226
227 fn add(self, rhs: SymDim) -> Self::Output {
228 binary(self, rhs, |lhs, rhs| {
229 RawSymDim::Add(Box::new(lhs), Box::new(rhs))
230 })
231 }
232}
233
234impl Add<usize> for SymDim {
235 type Output = SymDim;
236
237 fn add(self, rhs: usize) -> Self::Output {
238 self + SymDim::from(rhs)
239 }
240}
241
242impl Add<SymDim> for usize {
243 type Output = SymDim;
244
245 fn add(self, rhs: SymDim) -> Self::Output {
246 SymDim::from(self) + rhs
247 }
248}
249
250impl Sub for SymDim {
251 type Output = SymDim;
252
253 fn sub(self, rhs: SymDim) -> Self::Output {
254 binary(self, rhs, |lhs, rhs| {
255 RawSymDim::Sub(Box::new(lhs), Box::new(rhs))
256 })
257 }
258}
259
260impl Sub<usize> for SymDim {
261 type Output = SymDim;
262
263 fn sub(self, rhs: usize) -> Self::Output {
264 self - SymDim::from(rhs)
265 }
266}
267
268impl Sub<SymDim> for usize {
269 type Output = SymDim;
270
271 fn sub(self, rhs: SymDim) -> Self::Output {
272 SymDim::from(self) - rhs
273 }
274}
275
276impl Mul for SymDim {
277 type Output = SymDim;
278
279 fn mul(self, rhs: SymDim) -> Self::Output {
280 binary(self, rhs, |lhs, rhs| {
281 RawSymDim::Mul(Box::new(lhs), Box::new(rhs))
282 })
283 }
284}
285
286impl Mul<usize> for SymDim {
287 type Output = SymDim;
288
289 fn mul(self, rhs: usize) -> Self::Output {
290 self * SymDim::from(rhs)
291 }
292}
293
294impl Mul<SymDim> for usize {
295 type Output = SymDim;
296
297 fn mul(self, rhs: SymDim) -> Self::Output {
298 SymDim::from(self) * rhs
299 }
300}
301
302impl Div for SymDim {
303 type Output = SymDim;
304
305 fn div(self, rhs: SymDim) -> Self::Output {
306 binary(self, rhs, |lhs, rhs| {
307 RawSymDim::FloorDiv(Box::new(lhs), Box::new(rhs))
308 })
309 }
310}
311
312impl Div<usize> for SymDim {
313 type Output = SymDim;
314
315 fn div(self, rhs: usize) -> Self::Output {
316 self / SymDim::from(rhs)
317 }
318}
319
320impl Div<SymDim> for usize {
321 type Output = SymDim;
322
323 fn div(self, rhs: SymDim) -> Self::Output {
324 SymDim::from(self) / rhs
325 }
326}