Skip to main content

tenferro/
sym_dim.rs

1use std::ops::{Add, Div, Mul, Sub};
2
3use tenferro_ops::dim_expr::DimExpr;
4
5use crate::traced::TracedTensorId;
6
7/// A symbolic tensor dimension expression used to build shape-agnostic graphs.
8///
9/// `SymDim` values are produced from traced tensor axes via
10/// [`TracedTensor::sym_size`](crate::traced::TracedTensor::sym_size) and can be
11/// combined with basic arithmetic before being passed to
12/// [`TracedTensor::reshape_sym`](crate::traced::TracedTensor::reshape_sym).
13///
14/// # Examples
15///
16/// ```rust,ignore
17/// let rows = x.sym_size(0);
18/// let cols = x.sym_size(1);
19/// let y = x.reshape_sym(&[rows * cols])?;
20/// ```
21#[derive(Clone, Debug, PartialEq, Eq)]
22pub struct SymDim(pub(crate) RawSymDim);
23
24#[derive(Clone, Debug, PartialEq, Eq)]
25pub(crate) enum RawSymDim {
26    Const(usize),
27    TensorAxis {
28        tensor_id: TracedTensorId,
29        axis: usize,
30    },
31    Add(Box<RawSymDim>, Box<RawSymDim>),
32    Sub(Box<RawSymDim>, Box<RawSymDim>),
33    Mul(Box<RawSymDim>, Box<RawSymDim>),
34    FloorDiv(Box<RawSymDim>, Box<RawSymDim>),
35    Min(Box<RawSymDim>, Box<RawSymDim>),
36    Max(Box<RawSymDim>, Box<RawSymDim>),
37}
38
39impl SymDim {
40    /// Return the smaller of two symbolic dimensions.
41    ///
42    /// # Examples
43    ///
44    /// ```rust,ignore
45    /// let clipped = x.sym_size(0).min(8usize);
46    /// ```
47    pub fn min(self, other: impl Into<SymDim>) -> Self {
48        let other = other.into();
49        binary(self, other, |lhs, rhs| {
50            RawSymDim::Min(Box::new(lhs), Box::new(rhs))
51        })
52    }
53
54    /// Return the larger of two symbolic dimensions.
55    ///
56    /// # Examples
57    ///
58    /// ```rust,ignore
59    /// let batch = x.sym_size(0).max(1usize);
60    /// ```
61    pub fn max(self, other: impl Into<SymDim>) -> Self {
62        let other = other.into();
63        binary(self, other, |lhs, rhs| {
64            RawSymDim::Max(Box::new(lhs), Box::new(rhs))
65        })
66    }
67
68    pub(crate) fn tensor_axis(tensor_id: TracedTensorId, axis: usize) -> Self {
69        Self(RawSymDim::TensorAxis { tensor_id, axis })
70    }
71
72    pub(crate) fn constant_value(&self) -> Option<usize> {
73        match &self.0 {
74            RawSymDim::Const(value) => Some(*value),
75            _ => None,
76        }
77    }
78
79    pub(crate) fn to_dim_expr(
80        &self,
81        tensor_map: &[(TracedTensorId, usize)],
82    ) -> std::result::Result<DimExpr, String> {
83        raw_to_dim_expr(&self.0, tensor_map)
84    }
85}
86
87impl From<usize> for SymDim {
88    fn from(value: usize) -> Self {
89        Self(RawSymDim::Const(value))
90    }
91}
92
93fn raw_to_dim_expr(
94    raw: &RawSymDim,
95    tensor_map: &[(TracedTensorId, usize)],
96) -> std::result::Result<DimExpr, String> {
97    Ok(match raw {
98        RawSymDim::Const(value) => DimExpr::Const(*value),
99        RawSymDim::TensorAxis { tensor_id, axis } => {
100            let input_idx = tensor_map
101                .iter()
102                .find_map(|(candidate_id, input_idx)| {
103                    (candidate_id == tensor_id).then_some(*input_idx)
104                })
105                .ok_or_else(|| format!("unknown symbolic tensor id {tensor_id}"))?;
106            DimExpr::InputDim {
107                input_idx,
108                axis: *axis,
109            }
110        }
111        RawSymDim::Add(lhs, rhs) => DimExpr::add(
112            raw_to_dim_expr(lhs, tensor_map)?,
113            raw_to_dim_expr(rhs, tensor_map)?,
114        ),
115        RawSymDim::Sub(lhs, rhs) => DimExpr::sub(
116            raw_to_dim_expr(lhs, tensor_map)?,
117            raw_to_dim_expr(rhs, tensor_map)?,
118        ),
119        RawSymDim::Mul(lhs, rhs) => DimExpr::mul(
120            raw_to_dim_expr(lhs, tensor_map)?,
121            raw_to_dim_expr(rhs, tensor_map)?,
122        ),
123        RawSymDim::FloorDiv(lhs, rhs) => DimExpr::floor_div(
124            raw_to_dim_expr(lhs, tensor_map)?,
125            raw_to_dim_expr(rhs, tensor_map)?,
126        ),
127        RawSymDim::Min(lhs, rhs) => DimExpr::min(
128            raw_to_dim_expr(lhs, tensor_map)?,
129            raw_to_dim_expr(rhs, tensor_map)?,
130        ),
131        RawSymDim::Max(lhs, rhs) => DimExpr::max(
132            raw_to_dim_expr(lhs, tensor_map)?,
133            raw_to_dim_expr(rhs, tensor_map)?,
134        ),
135    })
136}
137
138fn binary(lhs: SymDim, rhs: SymDim, f: impl FnOnce(RawSymDim, RawSymDim) -> RawSymDim) -> SymDim {
139    SymDim(f(lhs.0, rhs.0))
140}
141
142impl Add for SymDim {
143    type Output = SymDim;
144
145    fn add(self, rhs: SymDim) -> Self::Output {
146        binary(self, rhs, |lhs, rhs| {
147            RawSymDim::Add(Box::new(lhs), Box::new(rhs))
148        })
149    }
150}
151
152impl Add<usize> for SymDim {
153    type Output = SymDim;
154
155    fn add(self, rhs: usize) -> Self::Output {
156        self + SymDim::from(rhs)
157    }
158}
159
160impl Add<SymDim> for usize {
161    type Output = SymDim;
162
163    fn add(self, rhs: SymDim) -> Self::Output {
164        SymDim::from(self) + rhs
165    }
166}
167
168impl Sub for SymDim {
169    type Output = SymDim;
170
171    fn sub(self, rhs: SymDim) -> Self::Output {
172        binary(self, rhs, |lhs, rhs| {
173            RawSymDim::Sub(Box::new(lhs), Box::new(rhs))
174        })
175    }
176}
177
178impl Sub<usize> for SymDim {
179    type Output = SymDim;
180
181    fn sub(self, rhs: usize) -> Self::Output {
182        self - SymDim::from(rhs)
183    }
184}
185
186impl Sub<SymDim> for usize {
187    type Output = SymDim;
188
189    fn sub(self, rhs: SymDim) -> Self::Output {
190        SymDim::from(self) - rhs
191    }
192}
193
194impl Mul for SymDim {
195    type Output = SymDim;
196
197    fn mul(self, rhs: SymDim) -> Self::Output {
198        binary(self, rhs, |lhs, rhs| {
199            RawSymDim::Mul(Box::new(lhs), Box::new(rhs))
200        })
201    }
202}
203
204impl Mul<usize> for SymDim {
205    type Output = SymDim;
206
207    fn mul(self, rhs: usize) -> Self::Output {
208        self * SymDim::from(rhs)
209    }
210}
211
212impl Mul<SymDim> for usize {
213    type Output = SymDim;
214
215    fn mul(self, rhs: SymDim) -> Self::Output {
216        SymDim::from(self) * rhs
217    }
218}
219
220impl Div for SymDim {
221    type Output = SymDim;
222
223    fn div(self, rhs: SymDim) -> Self::Output {
224        binary(self, rhs, |lhs, rhs| {
225            RawSymDim::FloorDiv(Box::new(lhs), Box::new(rhs))
226        })
227    }
228}
229
230impl Div<usize> for SymDim {
231    type Output = SymDim;
232
233    fn div(self, rhs: usize) -> Self::Output {
234        self / SymDim::from(rhs)
235    }
236}
237
238impl Div<SymDim> for usize {
239    type Output = SymDim;
240
241    fn div(self, rhs: SymDim) -> Self::Output {
242        SymDim::from(self) / rhs
243    }
244}