1#[cfg(test)]
2use crate::extension::LinalgOp;
3
/// Support level of an automatic-differentiation (AD) rule for a linalg operation.
///
/// The same scale is used for an op's `linearize` rule, its `transpose` rule and
/// for each individual output (see `LinalgAdSupport` / `LinalgAdOutputSupport`).
///
/// NOTE(review): the per-variant notes below are partly inferred from the variant
/// names and from how the `LINALG_AD_SUPPORT` table uses them — confirm against
/// the AD rule implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinalgAdRuleSupport {
    /// The rule is available.
    Supported,
    /// The rule is available, presumably by going through the op's linearize
    /// rule rather than a dedicated implementation.
    SupportedViaLinearize,
    /// Only part of the rule is available; presumably some outputs or
    /// configurations are not covered.
    PartiallySupported,
    /// The value carries no derivative information (the tables use this for
    /// outputs such as permutations `p`/`q`, `pivots` and `parity`).
    NonDifferentiable,
    /// No rule is available.
    Unsupported,
    /// Presumably: awaiting validation against a reference oracle.
    PendingOracle,
}
23
/// Field-less kind of a linalg operation, used to index the AD support table.
///
/// There is one kind per `LinalgOp` variant; `LinalgAdOpKind::from_linalg_op`
/// (test builds only) performs that mapping.
///
/// Variants are declared in `as_index` order (`Cholesky` = 0 through
/// `TriangularSolve` = 13). When adding a kind, keep this order, the
/// `as_index` mapping and the `LINALG_AD_SUPPORT` table in sync.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinalgAdOpKind {
    Cholesky,
    Lu,
    LuFactor,
    LuSolvePrepared,
    FullPivLu,
    FullPivLuSolve,
    Svd,
    SvdVals,
    Qr,
    Eigh,
    EighVals,
    Eig,
    EigVals,
    TriangularSolve,
}
50
51impl LinalgAdOpKind {
52 pub const COUNT: usize = 14;
53
54 pub const fn as_index(self) -> usize {
64 match self {
65 Self::Cholesky => 0,
66 Self::Lu => 1,
67 Self::LuFactor => 2,
68 Self::LuSolvePrepared => 3,
69 Self::FullPivLu => 4,
70 Self::FullPivLuSolve => 5,
71 Self::Svd => 6,
72 Self::SvdVals => 7,
73 Self::Qr => 8,
74 Self::Eigh => 9,
75 Self::EighVals => 10,
76 Self::Eig => 11,
77 Self::EigVals => 12,
78 Self::TriangularSolve => 13,
79 }
80 }
81
82 #[cfg(test)]
83 pub(crate) const fn from_linalg_op(op: LinalgOp) -> Self {
84 match op {
85 LinalgOp::Cholesky => Self::Cholesky,
86 LinalgOp::Lu => Self::Lu,
87 LinalgOp::LuFactor => Self::LuFactor,
88 LinalgOp::LuSolvePrepared { .. } => Self::LuSolvePrepared,
89 LinalgOp::FullPivLu => Self::FullPivLu,
90 LinalgOp::FullPivLuSolve { .. } => Self::FullPivLuSolve,
91 LinalgOp::Svd { .. } => Self::Svd,
92 LinalgOp::SvdVals { .. } => Self::SvdVals,
93 LinalgOp::Qr => Self::Qr,
94 LinalgOp::Eigh { .. } => Self::Eigh,
95 LinalgOp::EighVals { .. } => Self::EighVals,
96 LinalgOp::Eig { .. } => Self::Eig,
97 LinalgOp::EigVals { .. } => Self::EigVals,
98 LinalgOp::TriangularSolve { .. } => Self::TriangularSolve,
99 }
100 }
101}
102
103#[derive(Clone, Copy, Debug, PartialEq, Eq)]
115pub struct LinalgAdOutputSupport {
116 pub index: usize,
118 pub name: &'static str,
120 pub status: LinalgAdRuleSupport,
122}
123
124#[derive(Clone, Copy, Debug, PartialEq, Eq)]
135pub struct LinalgAdSupport {
136 pub kind: LinalgAdOpKind,
138 pub linearize: LinalgAdRuleSupport,
140 pub transpose: LinalgAdRuleSupport,
142 pub outputs: &'static [LinalgAdOutputSupport],
144}
145
146const fn output(
147 index: usize,
148 name: &'static str,
149 status: LinalgAdRuleSupport,
150) -> LinalgAdOutputSupport {
151 LinalgAdOutputSupport {
152 index,
153 name,
154 status,
155 }
156}
157
// Per-output AD support tables, referenced from `LINALG_AD_SUPPORT`.
// In every table the `index` passed to `output` equals the entry's position in
// its array.

/// Output support for `LinalgAdOpKind::Cholesky`: a single `factor` output.
static CHOLESKY_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
    0,
    "factor",
    LinalgAdRuleSupport::SupportedViaLinearize,
)];
/// Output support for `LinalgAdOpKind::Lu`. The `p` and `parity` outputs are
/// marked non-differentiable; `l` and `u` have AD support.
static LU_OUTPUTS: [LinalgAdOutputSupport; 4] = [
    output(0, "p", LinalgAdRuleSupport::NonDifferentiable),
    output(1, "l", LinalgAdRuleSupport::SupportedViaLinearize),
    output(2, "u", LinalgAdRuleSupport::SupportedViaLinearize),
    output(3, "parity", LinalgAdRuleSupport::NonDifferentiable),
];
/// Output support for `LinalgAdOpKind::LuFactor`. `packed_lu` is unsupported and
/// `pivots`/`parity` are non-differentiable, so no output of this op has AD support.
static LU_FACTOR_OUTPUTS: [LinalgAdOutputSupport; 3] = [
    output(0, "packed_lu", LinalgAdRuleSupport::Unsupported),
    output(1, "pivots", LinalgAdRuleSupport::NonDifferentiable),
    output(2, "parity", LinalgAdRuleSupport::NonDifferentiable),
];
/// A single `solution` output. `LINALG_AD_SUPPORT` shares this table between
/// `LinalgAdOpKind::LuSolvePrepared` and `LinalgAdOpKind::TriangularSolve`.
static SOLUTION_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
    0,
    "solution",
    LinalgAdRuleSupport::SupportedViaLinearize,
)];
/// Output support for `LinalgAdOpKind::FullPivLu`: the `LU_OUTPUTS` layout with an
/// extra non-differentiable `q` output inserted before `parity`.
static FULL_PIV_LU_OUTPUTS: [LinalgAdOutputSupport; 5] = [
    output(0, "p", LinalgAdRuleSupport::NonDifferentiable),
    output(1, "l", LinalgAdRuleSupport::SupportedViaLinearize),
    output(2, "u", LinalgAdRuleSupport::SupportedViaLinearize),
    output(3, "q", LinalgAdRuleSupport::NonDifferentiable),
    output(4, "parity", LinalgAdRuleSupport::NonDifferentiable),
];
/// Output support for `LinalgAdOpKind::FullPivLuSolve`: a single `solution` output.
// NOTE(review): contents are identical to `SOLUTION_OUTPUTS`; presumably kept as a
// separate static so the two can diverge — confirm, otherwise the table could reuse it.
static FULL_PIV_LU_SOLVE_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
    0,
    "solution",
    LinalgAdRuleSupport::SupportedViaLinearize,
)];
/// Output support for `LinalgAdOpKind::Svd`: `u`, `singular_values` and `vt`.
static SVD_OUTPUTS: [LinalgAdOutputSupport; 3] = [
    output(0, "u", LinalgAdRuleSupport::SupportedViaLinearize),
    output(
        1,
        "singular_values",
        LinalgAdRuleSupport::SupportedViaLinearize,
    ),
    output(2, "vt", LinalgAdRuleSupport::SupportedViaLinearize),
];
/// Output support for `LinalgAdOpKind::SvdVals`: `singular_values` only.
static SVD_VALS_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
    0,
    "singular_values",
    LinalgAdRuleSupport::SupportedViaLinearize,
)];
/// Output support for `LinalgAdOpKind::Qr`: `q` and `r`.
static QR_OUTPUTS: [LinalgAdOutputSupport; 2] = [
    output(0, "q", LinalgAdRuleSupport::SupportedViaLinearize),
    output(1, "r", LinalgAdRuleSupport::SupportedViaLinearize),
];
/// Output support for `LinalgAdOpKind::Eigh`: `eigenvalues` and `eigenvectors`.
static EIGH_OUTPUTS: [LinalgAdOutputSupport; 2] = [
    output(0, "eigenvalues", LinalgAdRuleSupport::SupportedViaLinearize),
    output(
        1,
        "eigenvectors",
        LinalgAdRuleSupport::SupportedViaLinearize,
    ),
];
/// Output support for `LinalgAdOpKind::EighVals`: `eigenvalues` only.
static EIGH_VALS_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
    0,
    "eigenvalues",
    LinalgAdRuleSupport::SupportedViaLinearize,
)];
/// Output support for `LinalgAdOpKind::Eig`. Only `eigenvalues` has AD support;
/// `eigenvectors` is unsupported, consistent with the `Eig` entry of
/// `LINALG_AD_SUPPORT` reporting its linearize rule as `PartiallySupported`.
static EIG_OUTPUTS: [LinalgAdOutputSupport; 2] = [
    output(0, "eigenvalues", LinalgAdRuleSupport::SupportedViaLinearize),
    output(1, "eigenvectors", LinalgAdRuleSupport::Unsupported),
];
/// Output support for `LinalgAdOpKind::EigVals`: `eigenvalues` only.
static EIG_VALS_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
    0,
    "eigenvalues",
    LinalgAdRuleSupport::SupportedViaLinearize,
)];
231
/// AD support entry for every `LinalgAdOpKind`.
///
/// `linalg_ad_support` looks entries up positionally with `kind.as_index()`.
/// Invariant: the entry at position `i` must describe the kind whose
/// `as_index()` is `i`, so keep this array in `LinalgAdOpKind` index order.
static LINALG_AD_SUPPORT: [LinalgAdSupport; LinalgAdOpKind::COUNT] = [
    // as_index() == 0
    LinalgAdSupport {
        kind: LinalgAdOpKind::Cholesky,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &CHOLESKY_OUTPUTS,
    },
    // as_index() == 1
    LinalgAdSupport {
        kind: LinalgAdOpKind::Lu,
        linearize: LinalgAdRuleSupport::PartiallySupported,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &LU_OUTPUTS,
    },
    // as_index() == 2
    LinalgAdSupport {
        kind: LinalgAdOpKind::LuFactor,
        linearize: LinalgAdRuleSupport::Unsupported,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &LU_FACTOR_OUTPUTS,
    },
    // as_index() == 3
    LinalgAdSupport {
        kind: LinalgAdOpKind::LuSolvePrepared,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::PartiallySupported,
        outputs: &SOLUTION_OUTPUTS,
    },
    // as_index() == 4
    LinalgAdSupport {
        kind: LinalgAdOpKind::FullPivLu,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &FULL_PIV_LU_OUTPUTS,
    },
    // as_index() == 5
    LinalgAdSupport {
        kind: LinalgAdOpKind::FullPivLuSolve,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Supported,
        outputs: &FULL_PIV_LU_SOLVE_OUTPUTS,
    },
    // as_index() == 6
    LinalgAdSupport {
        kind: LinalgAdOpKind::Svd,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &SVD_OUTPUTS,
    },
    // as_index() == 7
    LinalgAdSupport {
        kind: LinalgAdOpKind::SvdVals,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &SVD_VALS_OUTPUTS,
    },
    // as_index() == 8
    LinalgAdSupport {
        kind: LinalgAdOpKind::Qr,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &QR_OUTPUTS,
    },
    // as_index() == 9
    LinalgAdSupport {
        kind: LinalgAdOpKind::Eigh,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &EIGH_OUTPUTS,
    },
    // as_index() == 10
    LinalgAdSupport {
        kind: LinalgAdOpKind::EighVals,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &EIGH_VALS_OUTPUTS,
    },
    // as_index() == 11; linearize is partial because the `eigenvectors`
    // output in EIG_OUTPUTS is marked Unsupported.
    LinalgAdSupport {
        kind: LinalgAdOpKind::Eig,
        linearize: LinalgAdRuleSupport::PartiallySupported,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &EIG_OUTPUTS,
    },
    // as_index() == 12
    LinalgAdSupport {
        kind: LinalgAdOpKind::EigVals,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Unsupported,
        outputs: &EIG_VALS_OUTPUTS,
    },
    // as_index() == 13
    LinalgAdSupport {
        kind: LinalgAdOpKind::TriangularSolve,
        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
        transpose: LinalgAdRuleSupport::Supported,
        outputs: &SOLUTION_OUTPUTS,
    },
];
318
319pub fn all_linalg_ad_support() -> &'static [LinalgAdSupport; LinalgAdOpKind::COUNT] {
328 &LINALG_AD_SUPPORT
329}
330
331pub fn linalg_ad_support(kind: LinalgAdOpKind) -> &'static LinalgAdSupport {
342 &LINALG_AD_SUPPORT[kind.as_index()]
343}
344
/// Test-only helper: returns the AD support entry for a concrete `LinalgOp` by
/// first reducing it to its `LinalgAdOpKind`.
#[cfg(test)]
pub(crate) fn linalg_ad_support_for_op(op: LinalgOp) -> &'static LinalgAdSupport {
    let kind = LinalgAdOpKind::from_linalg_op(op);
    linalg_ad_support(kind)
}
349
// Test-only submodule; `mod tests;` loads its contents from a separate `tests`
// source file (standard Rust module file lookup).
#[cfg(test)]
mod tests;