1use std::ops::{Add, Div, Mul, Sub};
2
3use tenferro_ops::dim_expr::DimExpr;
4
5use crate::traced::TracedTensorId;
6
7#[derive(Clone, Debug, PartialEq, Eq)]
22pub struct SymDim(pub(crate) RawSymDim);
23
24#[derive(Clone, Debug, PartialEq, Eq)]
25pub(crate) enum RawSymDim {
26 Const(usize),
27 TensorAxis {
28 tensor_id: TracedTensorId,
29 axis: usize,
30 },
31 Add(Box<RawSymDim>, Box<RawSymDim>),
32 Sub(Box<RawSymDim>, Box<RawSymDim>),
33 Mul(Box<RawSymDim>, Box<RawSymDim>),
34 FloorDiv(Box<RawSymDim>, Box<RawSymDim>),
35 Min(Box<RawSymDim>, Box<RawSymDim>),
36 Max(Box<RawSymDim>, Box<RawSymDim>),
37}
38
39impl SymDim {
40 pub fn min(self, other: impl Into<SymDim>) -> Self {
48 let other = other.into();
49 binary(self, other, |lhs, rhs| {
50 RawSymDim::Min(Box::new(lhs), Box::new(rhs))
51 })
52 }
53
54 pub fn max(self, other: impl Into<SymDim>) -> Self {
62 let other = other.into();
63 binary(self, other, |lhs, rhs| {
64 RawSymDim::Max(Box::new(lhs), Box::new(rhs))
65 })
66 }
67
68 pub(crate) fn tensor_axis(tensor_id: TracedTensorId, axis: usize) -> Self {
69 Self(RawSymDim::TensorAxis { tensor_id, axis })
70 }
71
72 pub(crate) fn constant_value(&self) -> Option<usize> {
73 match &self.0 {
74 RawSymDim::Const(value) => Some(*value),
75 _ => None,
76 }
77 }
78
79 pub(crate) fn to_dim_expr(
80 &self,
81 tensor_map: &[(TracedTensorId, usize)],
82 ) -> std::result::Result<DimExpr, String> {
83 raw_to_dim_expr(&self.0, tensor_map)
84 }
85}
86
87impl From<usize> for SymDim {
88 fn from(value: usize) -> Self {
89 Self(RawSymDim::Const(value))
90 }
91}
92
93fn raw_to_dim_expr(
94 raw: &RawSymDim,
95 tensor_map: &[(TracedTensorId, usize)],
96) -> std::result::Result<DimExpr, String> {
97 Ok(match raw {
98 RawSymDim::Const(value) => DimExpr::Const(*value),
99 RawSymDim::TensorAxis { tensor_id, axis } => {
100 let input_idx = tensor_map
101 .iter()
102 .find_map(|(candidate_id, input_idx)| {
103 (candidate_id == tensor_id).then_some(*input_idx)
104 })
105 .ok_or_else(|| format!("unknown symbolic tensor id {tensor_id}"))?;
106 DimExpr::InputDim {
107 input_idx,
108 axis: *axis,
109 }
110 }
111 RawSymDim::Add(lhs, rhs) => DimExpr::add(
112 raw_to_dim_expr(lhs, tensor_map)?,
113 raw_to_dim_expr(rhs, tensor_map)?,
114 ),
115 RawSymDim::Sub(lhs, rhs) => DimExpr::sub(
116 raw_to_dim_expr(lhs, tensor_map)?,
117 raw_to_dim_expr(rhs, tensor_map)?,
118 ),
119 RawSymDim::Mul(lhs, rhs) => DimExpr::mul(
120 raw_to_dim_expr(lhs, tensor_map)?,
121 raw_to_dim_expr(rhs, tensor_map)?,
122 ),
123 RawSymDim::FloorDiv(lhs, rhs) => DimExpr::floor_div(
124 raw_to_dim_expr(lhs, tensor_map)?,
125 raw_to_dim_expr(rhs, tensor_map)?,
126 ),
127 RawSymDim::Min(lhs, rhs) => DimExpr::min(
128 raw_to_dim_expr(lhs, tensor_map)?,
129 raw_to_dim_expr(rhs, tensor_map)?,
130 ),
131 RawSymDim::Max(lhs, rhs) => DimExpr::max(
132 raw_to_dim_expr(lhs, tensor_map)?,
133 raw_to_dim_expr(rhs, tensor_map)?,
134 ),
135 })
136}
137
138fn binary(lhs: SymDim, rhs: SymDim, f: impl FnOnce(RawSymDim, RawSymDim) -> RawSymDim) -> SymDim {
139 SymDim(f(lhs.0, rhs.0))
140}
141
142impl Add for SymDim {
143 type Output = SymDim;
144
145 fn add(self, rhs: SymDim) -> Self::Output {
146 binary(self, rhs, |lhs, rhs| {
147 RawSymDim::Add(Box::new(lhs), Box::new(rhs))
148 })
149 }
150}
151
152impl Add<usize> for SymDim {
153 type Output = SymDim;
154
155 fn add(self, rhs: usize) -> Self::Output {
156 self + SymDim::from(rhs)
157 }
158}
159
160impl Add<SymDim> for usize {
161 type Output = SymDim;
162
163 fn add(self, rhs: SymDim) -> Self::Output {
164 SymDim::from(self) + rhs
165 }
166}
167
168impl Sub for SymDim {
169 type Output = SymDim;
170
171 fn sub(self, rhs: SymDim) -> Self::Output {
172 binary(self, rhs, |lhs, rhs| {
173 RawSymDim::Sub(Box::new(lhs), Box::new(rhs))
174 })
175 }
176}
177
178impl Sub<usize> for SymDim {
179 type Output = SymDim;
180
181 fn sub(self, rhs: usize) -> Self::Output {
182 self - SymDim::from(rhs)
183 }
184}
185
186impl Sub<SymDim> for usize {
187 type Output = SymDim;
188
189 fn sub(self, rhs: SymDim) -> Self::Output {
190 SymDim::from(self) - rhs
191 }
192}
193
194impl Mul for SymDim {
195 type Output = SymDim;
196
197 fn mul(self, rhs: SymDim) -> Self::Output {
198 binary(self, rhs, |lhs, rhs| {
199 RawSymDim::Mul(Box::new(lhs), Box::new(rhs))
200 })
201 }
202}
203
204impl Mul<usize> for SymDim {
205 type Output = SymDim;
206
207 fn mul(self, rhs: usize) -> Self::Output {
208 self * SymDim::from(rhs)
209 }
210}
211
212impl Mul<SymDim> for usize {
213 type Output = SymDim;
214
215 fn mul(self, rhs: SymDim) -> Self::Output {
216 SymDim::from(self) * rhs
217 }
218}
219
220impl Div for SymDim {
221 type Output = SymDim;
222
223 fn div(self, rhs: SymDim) -> Self::Output {
224 binary(self, rhs, |lhs, rhs| {
225 RawSymDim::FloorDiv(Box::new(lhs), Box::new(rhs))
226 })
227 }
228}
229
230impl Div<usize> for SymDim {
231 type Output = SymDim;
232
233 fn div(self, rhs: usize) -> Self::Output {
234 self / SymDim::from(rhs)
235 }
236}
237
238impl Div<SymDim> for usize {
239 type Output = SymDim;
240
241 fn div(self, rhs: SymDim) -> Self::Output {
242 SymDim::from(self) / rhs
243 }
244}