Skip to main content

tenferro_ops/ad/
context.rs

1//! AD context for guard-based shape resolution.
2//!
3//! During AD graph construction, linalg rules such as SVD, QR, and LU need
4//! concrete matrix dimensions to choose between structurally different
5//! subgraphs. `ShapeGuardContext` records those dimension comparisons as guards
6//! so cached AD graphs can later be invalidated when the observed shape
7//! relationship changes.
8
9use std::cmp::Ordering;
10
11use crate::dim_expr::DimExpr;
12
/// A recorded dimension comparison made during AD graph construction.
///
/// A guard stores two concrete dimension values together with the ordering
/// observed between them. `ordering` is expected to equal
/// `dim_a.cmp(&dim_b)`; [`ShapeGuard::new`] builds a guard that satisfies
/// this by construction, whereas a struct literal relies on the caller to
/// keep the fields consistent.
///
/// # Examples
///
/// ```
/// use std::cmp::Ordering;
/// use tenferro_ops::ShapeGuard;
///
/// let guard = ShapeGuard {
///     dim_a: 5,
///     dim_b: 3,
///     ordering: Ordering::Greater,
/// };
///
/// assert_eq!(guard.ordering, Ordering::Greater);
/// assert_eq!(guard, ShapeGuard::new(5, 3));
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShapeGuard {
    /// First dimension value, such as `m`.
    pub dim_a: usize,
    /// Second dimension value, such as `n`.
    pub dim_b: usize,
    /// The observed ordering `dim_a.cmp(&dim_b)`.
    pub ordering: Ordering,
}

impl ShapeGuard {
    /// Creates a guard for `dim_a` and `dim_b`, deriving `ordering` from them.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::cmp::Ordering;
    /// use tenferro_ops::ShapeGuard;
    ///
    /// assert_eq!(ShapeGuard::new(2, 4).ordering, Ordering::Less);
    /// ```
    #[must_use]
    pub fn new(dim_a: usize, dim_b: usize) -> Self {
        Self {
            dim_a,
            dim_b,
            ordering: dim_a.cmp(&dim_b),
        }
    }

    /// Returns `true` if `dim_a` and `dim_b` compare the same way this guard
    /// recorded.
    ///
    /// Only the ordering is checked, not the exact dimension values: a guard
    /// recorded for `(5, 3)` is still satisfied by `(9, 1)`, because both
    /// pairs compare as [`Ordering::Greater`].
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_ops::ShapeGuard;
    ///
    /// let guard = ShapeGuard::new(5, 3);
    /// assert!(guard.is_satisfied_by(9, 1));
    /// assert!(!guard.is_satisfied_by(3, 3));
    /// ```
    #[must_use]
    pub fn is_satisfied_by(&self, dim_a: usize, dim_b: usize) -> bool {
        dim_a.cmp(&dim_b) == self.ordering
    }
}
38
39/// AD context providing dimension resolution and guard recording.
40///
41/// # Examples
42///
43/// ```
44/// use tenferro_ops::ShapeGuardContext;
45///
46/// let ctx = ShapeGuardContext::default();
47/// assert!(ctx.guards().is_empty());
48/// ```
49#[derive(Clone, Debug, Default, PartialEq, Eq)]
50pub struct ShapeGuardContext {
51    guards: Vec<ShapeGuard>,
52}
53
54impl ShapeGuardContext {
55    /// Returns the guards recorded so far.
56    ///
57    /// # Examples
58    ///
59    /// ```
60    /// use tenferro_ops::ShapeGuardContext;
61    ///
62    /// let ctx = ShapeGuardContext::default();
63    /// assert_eq!(ctx.guards(), &[]);
64    /// ```
65    pub fn guards(&self) -> &[ShapeGuard] {
66        &self.guards
67    }
68
69    /// Clears all recorded guards.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use tenferro_ops::ShapeGuardContext;
75    ///
76    /// let mut ctx = ShapeGuardContext::default();
77    /// ctx.clear_guards();
78    /// assert!(ctx.guards().is_empty());
79    /// ```
80    pub fn clear_guards(&mut self) {
81        self.guards.clear();
82    }
83}
84
85/// Resolve a [`DimExpr`] to a concrete `usize`.
86///
87/// Currently this evaluates the expression without any input shapes, which
88/// works for `DimExpr::Const` and expressions composed entirely from constants.
89/// `DimExpr::InputDim` references will panic, which is currently a programming
90/// invariant enforced by the linalg AD callers.
91pub(crate) fn resolve_dim(dim: &DimExpr) -> usize {
92    dim.eval(&[])
93}
94
95/// Resolve matrix dimensions and record their ordering as a guard.
96pub(crate) fn resolve_and_guard(
97    m: &DimExpr,
98    n: &DimExpr,
99    ctx: &mut ShapeGuardContext,
100) -> (usize, usize) {
101    let m_size = resolve_dim(m);
102    let n_size = resolve_dim(n);
103    ctx.guards.push(ShapeGuard {
104        dim_a: m_size,
105        dim_b: n_size,
106        ordering: m_size.cmp(&n_size),
107    });
108    (m_size, n_size)
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Runs `resolve_and_guard` on `m` and `n` against a fresh context and
    /// returns the resolved dimensions together with that context.
    fn guard_fresh(m: DimExpr, n: DimExpr) -> ((usize, usize), ShapeGuardContext) {
        let mut ctx = ShapeGuardContext::default();
        let dims = resolve_and_guard(&m, &n, &mut ctx);
        (dims, ctx)
    }

    #[test]
    fn resolve_dim_const() {
        assert_eq!(resolve_dim(&DimExpr::Const(7)), 7);
    }

    #[test]
    fn resolve_dim_const_expr() {
        let expr = DimExpr::min(DimExpr::Const(3), DimExpr::Const(5));
        assert_eq!(resolve_dim(&expr), 3);
    }

    #[test]
    fn default_context_is_empty() {
        assert!(ShapeGuardContext::default().guards().is_empty());
    }

    #[test]
    fn resolve_and_guard_records_greater() {
        let (dims, ctx) = guard_fresh(DimExpr::Const(5), DimExpr::Const(3));
        assert_eq!(dims, (5, 3));
        assert_eq!(
            ctx.guards(),
            &[ShapeGuard {
                dim_a: 5,
                dim_b: 3,
                ordering: Ordering::Greater,
            }]
        );
    }

    #[test]
    fn resolve_and_guard_records_less() {
        let (dims, ctx) = guard_fresh(DimExpr::Const(2), DimExpr::Const(4));
        assert_eq!(dims, (2, 4));
        // Check the recorded values too, not just the ordering.
        assert_eq!(
            ctx.guards(),
            &[ShapeGuard {
                dim_a: 2,
                dim_b: 4,
                ordering: Ordering::Less,
            }]
        );
    }

    #[test]
    fn resolve_and_guard_records_equal() {
        let (dims, ctx) = guard_fresh(DimExpr::Const(3), DimExpr::Const(3));
        assert_eq!(dims, (3, 3));
        assert_eq!(
            ctx.guards(),
            &[ShapeGuard {
                dim_a: 3,
                dim_b: 3,
                ordering: Ordering::Equal,
            }]
        );
    }

    #[test]
    fn guards_accumulate() {
        let mut ctx = ShapeGuardContext::default();
        resolve_and_guard(&DimExpr::Const(5), &DimExpr::Const(3), &mut ctx);
        resolve_and_guard(&DimExpr::Const(2), &DimExpr::Const(4), &mut ctx);
        // Guards are kept in recording order.
        assert_eq!(
            ctx.guards(),
            &[
                ShapeGuard {
                    dim_a: 5,
                    dim_b: 3,
                    ordering: Ordering::Greater,
                },
                ShapeGuard {
                    dim_a: 2,
                    dim_b: 4,
                    ordering: Ordering::Less,
                },
            ]
        );
    }

    #[test]
    fn clear_guards_empties() {
        let mut ctx = ShapeGuardContext::default();
        resolve_and_guard(&DimExpr::Const(5), &DimExpr::Const(3), &mut ctx);
        assert_eq!(ctx.guards().len(), 1);
        ctx.clear_guards();
        assert!(ctx.guards().is_empty());
    }

    #[test]
    fn recording_resumes_after_clear() {
        let mut ctx = ShapeGuardContext::default();
        resolve_and_guard(&DimExpr::Const(5), &DimExpr::Const(3), &mut ctx);
        ctx.clear_guards();
        resolve_and_guard(&DimExpr::Const(3), &DimExpr::Const(3), &mut ctx);
        // Only the guard recorded after the clear must remain.
        assert_eq!(
            ctx.guards(),
            &[ShapeGuard {
                dim_a: 3,
                dim_b: 3,
                ordering: Ordering::Equal,
            }]
        );
    }
}