// tenferro_linalg/ad/support.rs

1#[cfg(test)]
2use crate::extension::LinalgOp;
3
/// AD rule support status for a linalg operation or output.
///
/// The same status type is used for whole-operation rules
/// ([`LinalgAdSupport::linearize`], [`LinalgAdSupport::transpose`]) and for
/// individual outputs ([`LinalgAdOutputSupport::status`]).
///
/// # Examples
///
/// ```rust
/// use tenferro_linalg::{linalg_ad_support, LinalgAdOpKind, LinalgAdRuleSupport};
///
/// let svd = linalg_ad_support(LinalgAdOpKind::Svd);
/// assert_eq!(svd.linearize, LinalgAdRuleSupport::SupportedViaLinearize);
/// ```
// `Hash` lets callers key maps/sets (e.g. support dashboards) by status.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LinalgAdRuleSupport {
    /// The rule is supported.
    Supported,
    /// The rule is supported by way of the linearize (forward-mode) path.
    SupportedViaLinearize,
    /// Only part of the rule is supported.
    PartiallySupported,
    /// The value carries no derivative. In the built-in manifest this is used for
    /// the `p`, `q`, `pivots`, and `parity` outputs.
    NonDifferentiable,
    /// No AD rule is available.
    Unsupported,
    /// Status pending an oracle check; no entry of the built-in manifest currently
    /// uses it.
    PendingOracle,
}
23
/// Operation keys covered by the linalg AD support manifest.
///
/// Each variant has exactly one row in the manifest; see
/// `LinalgAdOpKind::as_index` for the row position.
///
/// # Examples
///
/// ```rust
/// use tenferro_linalg::LinalgAdOpKind;
///
/// assert!(LinalgAdOpKind::Svd.as_index() < LinalgAdOpKind::COUNT);
/// ```
// `Hash` lets callers use the kind as a `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LinalgAdOpKind {
    /// Cholesky factorization (output `factor`).
    Cholesky,
    /// LU factorization (outputs `p`, `l`, `u`, `parity`).
    Lu,
    /// Packed LU factorization (outputs `packed_lu`, `pivots`, `parity`).
    LuFactor,
    /// Linear solve from a prepared LU factorization (output `solution`).
    LuSolvePrepared,
    /// Fully pivoted LU factorization (outputs `p`, `l`, `u`, `q`, `parity`).
    FullPivLu,
    /// Linear solve via fully pivoted LU (output `solution`).
    FullPivLuSolve,
    /// Singular value decomposition (outputs `u`, `singular_values`, `vt`).
    Svd,
    /// Singular values only (output `singular_values`).
    SvdVals,
    /// QR decomposition (outputs `q`, `r`).
    Qr,
    /// `eigh` eigendecomposition, conventionally for Hermitian input
    /// (outputs `eigenvalues`, `eigenvectors`).
    Eigh,
    /// `eigh` eigenvalues only (output `eigenvalues`).
    EighVals,
    /// `eig` eigendecomposition (outputs `eigenvalues`, `eigenvectors`).
    Eig,
    /// `eig` eigenvalues only (output `eigenvalues`).
    EigVals,
    /// Triangular linear solve (output `solution`).
    TriangularSolve,
}
50
51impl LinalgAdOpKind {
52    pub const COUNT: usize = 14;
53
54    /// Return the manifest index for this operation kind.
55    ///
56    /// # Examples
57    ///
58    /// ```rust
59    /// use tenferro_linalg::LinalgAdOpKind;
60    ///
61    /// assert_eq!(LinalgAdOpKind::Cholesky.as_index(), 0);
62    /// ```
63    pub const fn as_index(self) -> usize {
64        match self {
65            Self::Cholesky => 0,
66            Self::Lu => 1,
67            Self::LuFactor => 2,
68            Self::LuSolvePrepared => 3,
69            Self::FullPivLu => 4,
70            Self::FullPivLuSolve => 5,
71            Self::Svd => 6,
72            Self::SvdVals => 7,
73            Self::Qr => 8,
74            Self::Eigh => 9,
75            Self::EighVals => 10,
76            Self::Eig => 11,
77            Self::EigVals => 12,
78            Self::TriangularSolve => 13,
79        }
80    }
81
82    #[cfg(test)]
83    pub(crate) const fn from_linalg_op(op: LinalgOp) -> Self {
84        match op {
85            LinalgOp::Cholesky => Self::Cholesky,
86            LinalgOp::Lu => Self::Lu,
87            LinalgOp::LuFactor => Self::LuFactor,
88            LinalgOp::LuSolvePrepared { .. } => Self::LuSolvePrepared,
89            LinalgOp::FullPivLu => Self::FullPivLu,
90            LinalgOp::FullPivLuSolve { .. } => Self::FullPivLuSolve,
91            LinalgOp::Svd { .. } => Self::Svd,
92            LinalgOp::SvdVals { .. } => Self::SvdVals,
93            LinalgOp::Qr => Self::Qr,
94            LinalgOp::Eigh { .. } => Self::Eigh,
95            LinalgOp::EighVals { .. } => Self::EighVals,
96            LinalgOp::Eig { .. } => Self::Eig,
97            LinalgOp::EigVals { .. } => Self::EigVals,
98            LinalgOp::TriangularSolve { .. } => Self::TriangularSolve,
99        }
100    }
101}
102
103/// AD support status for one output of a linalg operation.
104///
105/// # Examples
106///
107/// ```rust
108/// use tenferro_linalg::{linalg_ad_support, LinalgAdOpKind, LinalgAdRuleSupport};
109///
110/// let full_piv_lu = linalg_ad_support(LinalgAdOpKind::FullPivLu);
111/// let l_output = full_piv_lu.outputs.iter().find(|output| output.name == "l").unwrap();
112/// assert_eq!(l_output.status, LinalgAdRuleSupport::SupportedViaLinearize);
113/// ```
114#[derive(Clone, Copy, Debug, PartialEq, Eq)]
115pub struct LinalgAdOutputSupport {
116    /// Output position in the linalg operation result tuple.
117    pub index: usize,
118    /// Stable output name used by tests and support dashboards.
119    pub name: &'static str,
120    /// AD support status for this specific output.
121    pub status: LinalgAdRuleSupport,
122}
123
124/// AD support manifest entry for one linalg operation.
125///
126/// # Examples
127///
128/// ```rust
129/// use tenferro_linalg::{linalg_ad_support, LinalgAdOpKind, LinalgAdRuleSupport};
130///
131/// let solve = linalg_ad_support(LinalgAdOpKind::TriangularSolve);
132/// assert_eq!(solve.transpose, LinalgAdRuleSupport::Supported);
133/// ```
134#[derive(Clone, Copy, Debug, PartialEq, Eq)]
135pub struct LinalgAdSupport {
136    /// Operation kind described by this manifest entry.
137    pub kind: LinalgAdOpKind,
138    /// Forward-mode graph emission support.
139    pub linearize: LinalgAdRuleSupport,
140    /// Transposed-linear graph emission support.
141    pub transpose: LinalgAdRuleSupport,
142    /// Per-output support status for multi-output operations.
143    pub outputs: &'static [LinalgAdOutputSupport],
144}
145
146const fn output(
147    index: usize,
148    name: &'static str,
149    status: LinalgAdRuleSupport,
150) -> LinalgAdOutputSupport {
151    LinalgAdOutputSupport {
152        index,
153        name,
154        status,
155    }
156}
157
158static CHOLESKY_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
159    0,
160    "factor",
161    LinalgAdRuleSupport::SupportedViaLinearize,
162)];
163static LU_OUTPUTS: [LinalgAdOutputSupport; 4] = [
164    output(0, "p", LinalgAdRuleSupport::NonDifferentiable),
165    output(1, "l", LinalgAdRuleSupport::SupportedViaLinearize),
166    output(2, "u", LinalgAdRuleSupport::SupportedViaLinearize),
167    output(3, "parity", LinalgAdRuleSupport::NonDifferentiable),
168];
169static LU_FACTOR_OUTPUTS: [LinalgAdOutputSupport; 3] = [
170    output(0, "packed_lu", LinalgAdRuleSupport::Unsupported),
171    output(1, "pivots", LinalgAdRuleSupport::NonDifferentiable),
172    output(2, "parity", LinalgAdRuleSupport::NonDifferentiable),
173];
174static SOLUTION_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
175    0,
176    "solution",
177    LinalgAdRuleSupport::SupportedViaLinearize,
178)];
179static FULL_PIV_LU_OUTPUTS: [LinalgAdOutputSupport; 5] = [
180    output(0, "p", LinalgAdRuleSupport::NonDifferentiable),
181    output(1, "l", LinalgAdRuleSupport::SupportedViaLinearize),
182    output(2, "u", LinalgAdRuleSupport::SupportedViaLinearize),
183    output(3, "q", LinalgAdRuleSupport::NonDifferentiable),
184    output(4, "parity", LinalgAdRuleSupport::NonDifferentiable),
185];
186static FULL_PIV_LU_SOLVE_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
187    0,
188    "solution",
189    LinalgAdRuleSupport::SupportedViaLinearize,
190)];
191static SVD_OUTPUTS: [LinalgAdOutputSupport; 3] = [
192    output(0, "u", LinalgAdRuleSupport::SupportedViaLinearize),
193    output(
194        1,
195        "singular_values",
196        LinalgAdRuleSupport::SupportedViaLinearize,
197    ),
198    output(2, "vt", LinalgAdRuleSupport::SupportedViaLinearize),
199];
200static SVD_VALS_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
201    0,
202    "singular_values",
203    LinalgAdRuleSupport::SupportedViaLinearize,
204)];
205static QR_OUTPUTS: [LinalgAdOutputSupport; 2] = [
206    output(0, "q", LinalgAdRuleSupport::SupportedViaLinearize),
207    output(1, "r", LinalgAdRuleSupport::SupportedViaLinearize),
208];
209static EIGH_OUTPUTS: [LinalgAdOutputSupport; 2] = [
210    output(0, "eigenvalues", LinalgAdRuleSupport::SupportedViaLinearize),
211    output(
212        1,
213        "eigenvectors",
214        LinalgAdRuleSupport::SupportedViaLinearize,
215    ),
216];
217static EIGH_VALS_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
218    0,
219    "eigenvalues",
220    LinalgAdRuleSupport::SupportedViaLinearize,
221)];
222static EIG_OUTPUTS: [LinalgAdOutputSupport; 2] = [
223    output(0, "eigenvalues", LinalgAdRuleSupport::SupportedViaLinearize),
224    output(1, "eigenvectors", LinalgAdRuleSupport::Unsupported),
225];
226static EIG_VALS_OUTPUTS: [LinalgAdOutputSupport; 1] = [output(
227    0,
228    "eigenvalues",
229    LinalgAdRuleSupport::SupportedViaLinearize,
230)];
231
232static LINALG_AD_SUPPORT: [LinalgAdSupport; LinalgAdOpKind::COUNT] = [
233    LinalgAdSupport {
234        kind: LinalgAdOpKind::Cholesky,
235        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
236        transpose: LinalgAdRuleSupport::Unsupported,
237        outputs: &CHOLESKY_OUTPUTS,
238    },
239    LinalgAdSupport {
240        kind: LinalgAdOpKind::Lu,
241        linearize: LinalgAdRuleSupport::PartiallySupported,
242        transpose: LinalgAdRuleSupport::Unsupported,
243        outputs: &LU_OUTPUTS,
244    },
245    LinalgAdSupport {
246        kind: LinalgAdOpKind::LuFactor,
247        linearize: LinalgAdRuleSupport::Unsupported,
248        transpose: LinalgAdRuleSupport::Unsupported,
249        outputs: &LU_FACTOR_OUTPUTS,
250    },
251    LinalgAdSupport {
252        kind: LinalgAdOpKind::LuSolvePrepared,
253        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
254        transpose: LinalgAdRuleSupport::PartiallySupported,
255        outputs: &SOLUTION_OUTPUTS,
256    },
257    LinalgAdSupport {
258        kind: LinalgAdOpKind::FullPivLu,
259        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
260        transpose: LinalgAdRuleSupport::Unsupported,
261        outputs: &FULL_PIV_LU_OUTPUTS,
262    },
263    LinalgAdSupport {
264        kind: LinalgAdOpKind::FullPivLuSolve,
265        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
266        transpose: LinalgAdRuleSupport::Supported,
267        outputs: &FULL_PIV_LU_SOLVE_OUTPUTS,
268    },
269    LinalgAdSupport {
270        kind: LinalgAdOpKind::Svd,
271        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
272        transpose: LinalgAdRuleSupport::Unsupported,
273        outputs: &SVD_OUTPUTS,
274    },
275    LinalgAdSupport {
276        kind: LinalgAdOpKind::SvdVals,
277        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
278        transpose: LinalgAdRuleSupport::Unsupported,
279        outputs: &SVD_VALS_OUTPUTS,
280    },
281    LinalgAdSupport {
282        kind: LinalgAdOpKind::Qr,
283        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
284        transpose: LinalgAdRuleSupport::Unsupported,
285        outputs: &QR_OUTPUTS,
286    },
287    LinalgAdSupport {
288        kind: LinalgAdOpKind::Eigh,
289        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
290        transpose: LinalgAdRuleSupport::Unsupported,
291        outputs: &EIGH_OUTPUTS,
292    },
293    LinalgAdSupport {
294        kind: LinalgAdOpKind::EighVals,
295        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
296        transpose: LinalgAdRuleSupport::Unsupported,
297        outputs: &EIGH_VALS_OUTPUTS,
298    },
299    LinalgAdSupport {
300        kind: LinalgAdOpKind::Eig,
301        linearize: LinalgAdRuleSupport::PartiallySupported,
302        transpose: LinalgAdRuleSupport::Unsupported,
303        outputs: &EIG_OUTPUTS,
304    },
305    LinalgAdSupport {
306        kind: LinalgAdOpKind::EigVals,
307        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
308        transpose: LinalgAdRuleSupport::Unsupported,
309        outputs: &EIG_VALS_OUTPUTS,
310    },
311    LinalgAdSupport {
312        kind: LinalgAdOpKind::TriangularSolve,
313        linearize: LinalgAdRuleSupport::SupportedViaLinearize,
314        transpose: LinalgAdRuleSupport::Supported,
315        outputs: &SOLUTION_OUTPUTS,
316    },
317];
318
319/// Return the complete linalg AD support manifest.
320///
321/// # Examples
322///
323/// ```rust
324/// let manifest = tenferro_linalg::all_linalg_ad_support();
325/// assert_eq!(manifest.len(), tenferro_linalg::LinalgAdOpKind::COUNT);
326/// ```
327pub fn all_linalg_ad_support() -> &'static [LinalgAdSupport; LinalgAdOpKind::COUNT] {
328    &LINALG_AD_SUPPORT
329}
330
331/// Return the support manifest entry for one linalg operation kind.
332///
333/// # Examples
334///
335/// ```rust
336/// use tenferro_linalg::{linalg_ad_support, LinalgAdOpKind};
337///
338/// let entry = linalg_ad_support(LinalgAdOpKind::Eigh);
339/// assert_eq!(entry.kind, LinalgAdOpKind::Eigh);
340/// ```
341pub fn linalg_ad_support(kind: LinalgAdOpKind) -> &'static LinalgAdSupport {
342    &LINALG_AD_SUPPORT[kind.as_index()]
343}
344
/// Test-only convenience: fetch the manifest entry for a concrete `LinalgOp`.
#[cfg(test)]
pub(crate) fn linalg_ad_support_for_op(op: LinalgOp) -> &'static LinalgAdSupport {
    let kind = LinalgAdOpKind::from_linalg_op(op);
    linalg_ad_support(kind)
}

#[cfg(test)]
mod tests;