Skip to main content

tenferro_ops/
sym_dim.rs

1use std::ops::{Add, Div, Mul, Sub};
2
3use crate::dim_expr::DimExpr;
4
/// A symbolic tensor dimension expression used to build shape-agnostic graphs.
///
/// A `SymDim` is an expression tree whose leaves are either concrete sizes or
/// references to an axis of a traced tensor. Values combine with `+`, `-`,
/// `*`, `/` (floor division), [`SymDim::min`] and [`SymDim::max`], and are
/// later resolved against op-local [`DimExpr`] expressions when shape
/// metadata is propagated.
///
/// Equality is structural: two `SymDim`s compare equal only when their
/// expression trees are identical, so `a + b` and `b + a` are distinct values.
///
/// # Examples
///
/// ```rust
/// use tenferro_ops::sym_dim::SymDim;
///
/// let rows = SymDim::from(3usize);
/// let cols = SymDim::from(4usize);
/// let area = rows * cols;
/// assert_eq!(area.constant_value(), Some(12));
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SymDim(pub(crate) RawSymDim);

/// Internal expression tree backing [`SymDim`].
///
/// Binary variants own their operands through `Box`, which gives the
/// recursive type a finite size.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum RawSymDim {
    /// A concrete, known dimension size.
    Const(usize),
    /// Axis `axis` of the traced tensor identified by `tensor_id`.
    TensorAxis { tensor_id: u64, axis: usize },
    /// Sum of the two operands.
    Add(Box<RawSymDim>, Box<RawSymDim>),
    /// Left operand minus right operand.
    Sub(Box<RawSymDim>, Box<RawSymDim>),
    /// Product of the two operands.
    Mul(Box<RawSymDim>, Box<RawSymDim>),
    /// Left operand floor-divided by right operand.
    FloorDiv(Box<RawSymDim>, Box<RawSymDim>),
    /// The smaller of the two operands.
    Min(Box<RawSymDim>, Box<RawSymDim>),
    /// The larger of the two operands.
    Max(Box<RawSymDim>, Box<RawSymDim>),
}
33
34impl SymDim {
35    /// Return the smaller of two symbolic dimensions.
36    ///
37    /// # Examples
38    ///
39    /// ```rust
40    /// use tenferro_ops::sym_dim::SymDim;
41    ///
42    /// let clipped = SymDim::from(5usize).min(8usize);
43    /// ```
44    pub fn min(self, other: impl Into<SymDim>) -> Self {
45        let other = other.into();
46        binary(self, other, |lhs, rhs| {
47            RawSymDim::Min(Box::new(lhs), Box::new(rhs))
48        })
49    }
50
51    /// Return the larger of two symbolic dimensions.
52    ///
53    /// # Examples
54    ///
55    /// ```rust
56    /// use tenferro_ops::sym_dim::SymDim;
57    ///
58    /// let batch = SymDim::from(2usize).max(1usize);
59    /// ```
60    pub fn max(self, other: impl Into<SymDim>) -> Self {
61        let other = other.into();
62        binary(self, other, |lhs, rhs| {
63            RawSymDim::Max(Box::new(lhs), Box::new(rhs))
64        })
65    }
66
67    #[doc(hidden)]
68    pub fn tensor_axis(tensor_id: u64, axis: usize) -> Self {
69        Self(RawSymDim::TensorAxis { tensor_id, axis })
70    }
71
72    #[doc(hidden)]
73    pub fn constant_value(&self) -> Option<usize> {
74        raw_constant_value(&self.0)
75    }
76
77    #[doc(hidden)]
78    pub fn from_dim_expr(expr: &DimExpr, input_shapes: &[&[SymDim]]) -> Self {
79        match expr {
80            DimExpr::Const(value) => Self::from(*value),
81            DimExpr::InputDim { input_idx, axis } => input_shapes[*input_idx][*axis].clone(),
82            DimExpr::Add(lhs, rhs) => {
83                Self::from_dim_expr(lhs, input_shapes) + Self::from_dim_expr(rhs, input_shapes)
84            }
85            DimExpr::Sub(lhs, rhs) => {
86                Self::from_dim_expr(lhs, input_shapes) - Self::from_dim_expr(rhs, input_shapes)
87            }
88            DimExpr::Mul(lhs, rhs) => {
89                Self::from_dim_expr(lhs, input_shapes) * Self::from_dim_expr(rhs, input_shapes)
90            }
91            DimExpr::FloorDiv(lhs, rhs) => {
92                Self::from_dim_expr(lhs, input_shapes) / Self::from_dim_expr(rhs, input_shapes)
93            }
94            DimExpr::Min(lhs, rhs) => {
95                Self::from_dim_expr(lhs, input_shapes).min(Self::from_dim_expr(rhs, input_shapes))
96            }
97            DimExpr::Max(lhs, rhs) => {
98                Self::from_dim_expr(lhs, input_shapes).max(Self::from_dim_expr(rhs, input_shapes))
99            }
100        }
101    }
102
103    #[doc(hidden)]
104    pub fn to_dim_expr(&self, tensor_map: &[(u64, usize)]) -> std::result::Result<DimExpr, String> {
105        raw_to_dim_expr(&self.0, tensor_map)
106    }
107
108    /// Collect the unique traced tensor IDs referenced by `TensorAxis`
109    /// variants inside this expression, in traversal order (first
110    /// occurrence wins).
111    ///
112    /// Used by traced composition wrappers that build multi-input ops
113    /// from `SymDim`-valued target shapes — e.g.
114    /// `TracedTensor::broadcast_in_dim_sym`.
115    ///
116    /// # Examples
117    ///
118    /// ```rust
119    /// use tenferro_ops::sym_dim::SymDim;
120    ///
121    /// let lhs = SymDim::tensor_axis(7, 0);
122    /// let rhs = SymDim::tensor_axis(9, 1);
123    /// let sum = lhs + rhs;
124    /// assert_eq!(sum.referenced_tensor_ids(), vec![7, 9]);
125    /// ```
126    pub fn referenced_tensor_ids(&self) -> Vec<u64> {
127        let mut ids = Vec::new();
128        collect_tensor_ids(&self.0, &mut ids);
129        ids
130    }
131}
132
133fn collect_tensor_ids(raw: &RawSymDim, ids: &mut Vec<u64>) {
134    match raw {
135        RawSymDim::Const(_) => {}
136        RawSymDim::TensorAxis { tensor_id, .. } => {
137            if !ids.contains(tensor_id) {
138                ids.push(*tensor_id);
139            }
140        }
141        RawSymDim::Add(lhs, rhs)
142        | RawSymDim::Sub(lhs, rhs)
143        | RawSymDim::Mul(lhs, rhs)
144        | RawSymDim::FloorDiv(lhs, rhs)
145        | RawSymDim::Min(lhs, rhs)
146        | RawSymDim::Max(lhs, rhs) => {
147            collect_tensor_ids(lhs, ids);
148            collect_tensor_ids(rhs, ids);
149        }
150    }
151}
152
153fn raw_constant_value(raw: &RawSymDim) -> Option<usize> {
154    match raw {
155        RawSymDim::Const(value) => Some(*value),
156        RawSymDim::TensorAxis { .. } => None,
157        RawSymDim::Add(lhs, rhs) => raw_constant_value(lhs)?.checked_add(raw_constant_value(rhs)?),
158        RawSymDim::Sub(lhs, rhs) => raw_constant_value(lhs)?.checked_sub(raw_constant_value(rhs)?),
159        RawSymDim::Mul(lhs, rhs) => raw_constant_value(lhs)?.checked_mul(raw_constant_value(rhs)?),
160        RawSymDim::FloorDiv(lhs, rhs) => {
161            let rhs = raw_constant_value(rhs)?;
162            raw_constant_value(lhs)?.checked_div(rhs)
163        }
164        RawSymDim::Min(lhs, rhs) => Some(raw_constant_value(lhs)?.min(raw_constant_value(rhs)?)),
165        RawSymDim::Max(lhs, rhs) => Some(raw_constant_value(lhs)?.max(raw_constant_value(rhs)?)),
166    }
167}
168
169impl From<usize> for SymDim {
170    fn from(value: usize) -> Self {
171        Self(RawSymDim::Const(value))
172    }
173}
174
175fn raw_to_dim_expr(
176    raw: &RawSymDim,
177    tensor_map: &[(u64, usize)],
178) -> std::result::Result<DimExpr, String> {
179    Ok(match raw {
180        RawSymDim::Const(value) => DimExpr::Const(*value),
181        RawSymDim::TensorAxis { tensor_id, axis } => {
182            let input_idx = tensor_map
183                .iter()
184                .find_map(|(candidate_id, input_idx)| {
185                    (candidate_id == tensor_id).then_some(*input_idx)
186                })
187                .ok_or_else(|| format!("unknown symbolic tensor id {tensor_id}"))?;
188            DimExpr::InputDim {
189                input_idx,
190                axis: *axis,
191            }
192        }
193        RawSymDim::Add(lhs, rhs) => DimExpr::add(
194            raw_to_dim_expr(lhs, tensor_map)?,
195            raw_to_dim_expr(rhs, tensor_map)?,
196        ),
197        RawSymDim::Sub(lhs, rhs) => DimExpr::sub(
198            raw_to_dim_expr(lhs, tensor_map)?,
199            raw_to_dim_expr(rhs, tensor_map)?,
200        ),
201        RawSymDim::Mul(lhs, rhs) => DimExpr::mul(
202            raw_to_dim_expr(lhs, tensor_map)?,
203            raw_to_dim_expr(rhs, tensor_map)?,
204        ),
205        RawSymDim::FloorDiv(lhs, rhs) => DimExpr::floor_div(
206            raw_to_dim_expr(lhs, tensor_map)?,
207            raw_to_dim_expr(rhs, tensor_map)?,
208        ),
209        RawSymDim::Min(lhs, rhs) => DimExpr::min(
210            raw_to_dim_expr(lhs, tensor_map)?,
211            raw_to_dim_expr(rhs, tensor_map)?,
212        ),
213        RawSymDim::Max(lhs, rhs) => DimExpr::max(
214            raw_to_dim_expr(lhs, tensor_map)?,
215            raw_to_dim_expr(rhs, tensor_map)?,
216        ),
217    })
218}
219
220fn binary(lhs: SymDim, rhs: SymDim, f: impl FnOnce(RawSymDim, RawSymDim) -> RawSymDim) -> SymDim {
221    SymDim(f(lhs.0, rhs.0))
222}
223
224impl Add for SymDim {
225    type Output = SymDim;
226
227    fn add(self, rhs: SymDim) -> Self::Output {
228        binary(self, rhs, |lhs, rhs| {
229            RawSymDim::Add(Box::new(lhs), Box::new(rhs))
230        })
231    }
232}
233
234impl Add<usize> for SymDim {
235    type Output = SymDim;
236
237    fn add(self, rhs: usize) -> Self::Output {
238        self + SymDim::from(rhs)
239    }
240}
241
242impl Add<SymDim> for usize {
243    type Output = SymDim;
244
245    fn add(self, rhs: SymDim) -> Self::Output {
246        SymDim::from(self) + rhs
247    }
248}
249
250impl Sub for SymDim {
251    type Output = SymDim;
252
253    fn sub(self, rhs: SymDim) -> Self::Output {
254        binary(self, rhs, |lhs, rhs| {
255            RawSymDim::Sub(Box::new(lhs), Box::new(rhs))
256        })
257    }
258}
259
260impl Sub<usize> for SymDim {
261    type Output = SymDim;
262
263    fn sub(self, rhs: usize) -> Self::Output {
264        self - SymDim::from(rhs)
265    }
266}
267
268impl Sub<SymDim> for usize {
269    type Output = SymDim;
270
271    fn sub(self, rhs: SymDim) -> Self::Output {
272        SymDim::from(self) - rhs
273    }
274}
275
276impl Mul for SymDim {
277    type Output = SymDim;
278
279    fn mul(self, rhs: SymDim) -> Self::Output {
280        binary(self, rhs, |lhs, rhs| {
281            RawSymDim::Mul(Box::new(lhs), Box::new(rhs))
282        })
283    }
284}
285
286impl Mul<usize> for SymDim {
287    type Output = SymDim;
288
289    fn mul(self, rhs: usize) -> Self::Output {
290        self * SymDim::from(rhs)
291    }
292}
293
294impl Mul<SymDim> for usize {
295    type Output = SymDim;
296
297    fn mul(self, rhs: SymDim) -> Self::Output {
298        SymDim::from(self) * rhs
299    }
300}
301
302impl Div for SymDim {
303    type Output = SymDim;
304
305    fn div(self, rhs: SymDim) -> Self::Output {
306        binary(self, rhs, |lhs, rhs| {
307            RawSymDim::FloorDiv(Box::new(lhs), Box::new(rhs))
308        })
309    }
310}
311
312impl Div<usize> for SymDim {
313    type Output = SymDim;
314
315    fn div(self, rhs: usize) -> Self::Output {
316        self / SymDim::from(rhs)
317    }
318}
319
320impl Div<SymDim> for usize {
321    type Output = SymDim;
322
323    fn div(self, rhs: SymDim) -> Self::Output {
324        SymDim::from(self) / rhs
325    }
326}