tenferro_ops/ad/
context.rs1use std::cmp::Ordering;
10
11use crate::dim_expr::DimExpr;
12
/// A recorded comparison between two resolved dimension sizes.
///
/// `resolve_and_guard` creates one of these per call. It stores the two
/// concrete sizes it resolved, together with the result of
/// `dim_a.cmp(&dim_b)`.
///
/// Every field is a small `Copy` value, so the record itself is `Copy` and
/// `Hash`. Guards can therefore be passed by value, deduplicated, or used as
/// map keys without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShapeGuard {
    /// Concrete size of the first dimension (`m` in `resolve_and_guard`).
    pub dim_a: usize,
    /// Concrete size of the second dimension (`n` in `resolve_and_guard`).
    pub dim_b: usize,
    /// How `dim_a` compares to `dim_b` (`Less`, `Equal` or `Greater`).
    pub ordering: Ordering,
}
38
39#[derive(Clone, Debug, Default, PartialEq, Eq)]
50pub struct ShapeGuardContext {
51 guards: Vec<ShapeGuard>,
52}
53
54impl ShapeGuardContext {
55 pub fn guards(&self) -> &[ShapeGuard] {
66 &self.guards
67 }
68
69 pub fn clear_guards(&mut self) {
81 self.guards.clear();
82 }
83}
84
85pub(crate) fn resolve_dim(dim: &DimExpr) -> usize {
92 dim.eval(&[])
93}
94
95pub(crate) fn resolve_and_guard(
97 m: &DimExpr,
98 n: &DimExpr,
99 ctx: &mut ShapeGuardContext,
100) -> (usize, usize) {
101 let m_size = resolve_dim(m);
102 let n_size = resolve_dim(n);
103 ctx.guards.push(ShapeGuard {
104 dim_a: m_size,
105 dim_b: n_size,
106 ordering: m_size.cmp(&n_size),
107 });
108 (m_size, n_size)
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Runs `resolve_and_guard` once against a fresh context.
    ///
    /// Returns the resolved sizes together with the context that now holds
    /// the single recorded guard.
    fn guard_once(m: DimExpr, n: DimExpr) -> ((usize, usize), ShapeGuardContext) {
        let mut ctx = ShapeGuardContext::default();
        let sizes = resolve_and_guard(&m, &n, &mut ctx);
        (sizes, ctx)
    }

    #[test]
    fn resolve_dim_const() {
        let seven = DimExpr::Const(7);
        assert_eq!(resolve_dim(&seven), 7);
    }

    #[test]
    fn resolve_dim_const_expr() {
        let smaller = DimExpr::min(DimExpr::Const(3), DimExpr::Const(5));
        assert_eq!(resolve_dim(&smaller), 3);
    }

    #[test]
    fn resolve_and_guard_records_greater() {
        let (sizes, ctx) = guard_once(DimExpr::Const(5), DimExpr::Const(3));
        assert_eq!(sizes, (5, 3));
        let expected = ShapeGuard {
            dim_a: 5,
            dim_b: 3,
            ordering: Ordering::Greater,
        };
        // Comparing whole slices checks both the length and the sole entry.
        assert_eq!(ctx.guards(), [expected].as_slice());
    }

    #[test]
    fn resolve_and_guard_records_less() {
        let (sizes, ctx) = guard_once(DimExpr::Const(2), DimExpr::Const(4));
        assert_eq!(sizes, (2, 4));
        assert_eq!(ctx.guards()[0].ordering, Ordering::Less);
    }

    #[test]
    fn resolve_and_guard_records_equal() {
        let (sizes, ctx) = guard_once(DimExpr::Const(3), DimExpr::Const(3));
        assert_eq!(sizes, (3, 3));
        assert_eq!(ctx.guards()[0].ordering, Ordering::Equal);
    }

    #[test]
    fn guards_accumulate() {
        let mut ctx = ShapeGuardContext::default();
        resolve_and_guard(&DimExpr::Const(5), &DimExpr::Const(3), &mut ctx);
        resolve_and_guard(&DimExpr::Const(2), &DimExpr::Const(4), &mut ctx);
        let orderings: Vec<Ordering> = ctx.guards().iter().map(|g| g.ordering).collect();
        assert_eq!(orderings, vec![Ordering::Greater, Ordering::Less]);
    }

    #[test]
    fn clear_guards_empties() {
        let (_, mut ctx) = guard_once(DimExpr::Const(5), DimExpr::Const(3));
        assert_eq!(ctx.guards().len(), 1);
        ctx.clear_guards();
        assert!(ctx.guards().is_empty());
    }
}
177}