Skip to main content

tidu/rules/
ad_rule_error.rs

1use std::error::Error;
2use std::fmt;
3
/// Names the AD rule that failed or is not available.
///
/// # Examples
///
/// ```
/// use tidu::ADRuleKind;
///
/// assert_eq!(ADRuleKind::Jvp.as_str(), "jvp");
/// assert_eq!(ADRuleKind::Transpose.as_str(), "transpose");
/// assert_ne!(ADRuleKind::Jvp, ADRuleKind::Transpose);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ADRuleKind {
    /// The JVP rule, used for forward linearization.
    Jvp,
    /// The transpose (VJP) rule of a linear primitive.
    Transpose,
}
21
22impl ADRuleKind {
23    /// Returns a stable human-readable rule name.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// use tidu::ADRuleKind;
29    ///
30    /// assert_eq!(ADRuleKind::Jvp.as_str(), "jvp");
31    /// ```
32    pub const fn as_str(self) -> &'static str {
33        match self {
34            Self::Jvp => "jvp",
35            Self::Transpose => "transpose",
36        }
37    }
38}
39
40/// Error returned when an AD rule cannot be emitted.
41///
42/// # Examples
43///
44/// ```
45/// use tidu::{ADRuleError, ADRuleKind};
46///
47/// let err = ADRuleError::unsupported("my_crate::fft", ADRuleKind::Jvp);
48/// assert_eq!(err.rule(), ADRuleKind::Jvp);
49/// assert!(err.to_string().contains("my_crate::fft"));
50/// ```
51#[derive(Clone, Debug, PartialEq, Eq)]
52pub enum ADRuleError {
53    /// The requested primitive does not provide the requested AD rule.
54    Unsupported {
55        /// Stable primitive name or extension family identifier.
56        op: String,
57        /// Missing rule kind.
58        rule: ADRuleKind,
59    },
60}
61
62impl ADRuleError {
63    /// Constructs an unsupported-rule error.
64    ///
65    /// # Examples
66    ///
67    /// ```
68    /// use tidu::{ADRuleError, ADRuleKind};
69    ///
70    /// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Transpose);
71    /// assert_eq!(err.rule(), ADRuleKind::Transpose);
72    /// ```
73    pub fn unsupported(op: impl Into<String>, rule: ADRuleKind) -> Self {
74        Self::Unsupported {
75            op: op.into(),
76            rule,
77        }
78    }
79
80    /// Returns the AD rule kind associated with this error.
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use tidu::{ADRuleError, ADRuleKind};
86    ///
87    /// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Jvp);
88    /// assert_eq!(err.rule(), ADRuleKind::Jvp);
89    /// ```
90    #[cfg_attr(coverage, inline(never))]
91    pub const fn rule(&self) -> ADRuleKind {
92        match self {
93            Self::Unsupported { rule, .. } => *rule,
94        }
95    }
96}
97
98impl fmt::Display for ADRuleError {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        match self {
101            Self::Unsupported { op, rule } => {
102                write!(f, "unsupported {} AD rule for {}", rule.as_str(), op)
103            }
104        }
105    }
106}
107
// Empty implementation: every `Error` method keeps its default, so this
// error does not report an underlying source.
impl Error for ADRuleError {}
109
110/// Result type used by fallible AD rule emission.
111///
112/// # Examples
113///
114/// ```
115/// use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
116///
117/// fn missing_rule() -> ADRuleResult<()> {
118///     Err(ADRuleError::unsupported("custom::op", ADRuleKind::Transpose))
119/// }
120///
121/// assert!(missing_rule().is_err());
122/// ```
123pub type ADRuleResult<T> = Result<T, ADRuleError>;