// tidu/rules/ad_rule_error.rs
use std::error::Error;
use std::fmt;

/// Identifies which AD rule failed or is unavailable.
///
/// The value is a plain tag: it derives `Copy`, equality, and `Hash`, so it
/// can be passed by value, compared, and used as a set or map key.
///
/// # Examples
///
/// ```
/// use tidu::ADRuleKind;
///
/// assert_eq!(ADRuleKind::Jvp.as_str(), "jvp");
/// assert_eq!(ADRuleKind::Transpose.as_str(), "transpose");
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ADRuleKind {
    /// JVP rule for forward linearization.
    Jvp,
    /// Transpose / VJP rule for a linear primitive.
    Transpose,
}

22impl ADRuleKind {
23 /// Returns a stable human-readable rule name.
24 ///
25 /// # Examples
26 ///
27 /// ```
28 /// use tidu::ADRuleKind;
29 ///
30 /// assert_eq!(ADRuleKind::Jvp.as_str(), "jvp");
31 /// ```
32 pub const fn as_str(self) -> &'static str {
33 match self {
34 Self::Jvp => "jvp",
35 Self::Transpose => "transpose",
36 }
37 }
38}
39
40/// Error returned when an AD rule cannot be emitted.
41///
42/// # Examples
43///
44/// ```
45/// use tidu::{ADRuleError, ADRuleKind};
46///
47/// let err = ADRuleError::unsupported("my_crate::fft", ADRuleKind::Jvp);
48/// assert_eq!(err.rule(), ADRuleKind::Jvp);
49/// assert!(err.to_string().contains("my_crate::fft"));
50/// ```
51#[derive(Clone, Debug, PartialEq, Eq)]
52pub enum ADRuleError {
53 /// The requested primitive does not provide the requested AD rule.
54 Unsupported {
55 /// Stable primitive name or extension family identifier.
56 op: String,
57 /// Missing rule kind.
58 rule: ADRuleKind,
59 },
60}
61
62impl ADRuleError {
63 /// Constructs an unsupported-rule error.
64 ///
65 /// # Examples
66 ///
67 /// ```
68 /// use tidu::{ADRuleError, ADRuleKind};
69 ///
70 /// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Transpose);
71 /// assert_eq!(err.rule(), ADRuleKind::Transpose);
72 /// ```
73 pub fn unsupported(op: impl Into<String>, rule: ADRuleKind) -> Self {
74 Self::Unsupported {
75 op: op.into(),
76 rule,
77 }
78 }
79
80 /// Returns the AD rule kind associated with this error.
81 ///
82 /// # Examples
83 ///
84 /// ```
85 /// use tidu::{ADRuleError, ADRuleKind};
86 ///
87 /// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Jvp);
88 /// assert_eq!(err.rule(), ADRuleKind::Jvp);
89 /// ```
90 #[cfg_attr(coverage, inline(never))]
91 pub const fn rule(&self) -> ADRuleKind {
92 match self {
93 Self::Unsupported { rule, .. } => *rule,
94 }
95 }
96}
97
98impl fmt::Display for ADRuleError {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 match self {
101 Self::Unsupported { op, rule } => {
102 write!(f, "unsupported {} AD rule for {}", rule.as_str(), op)
103 }
104 }
105 }
106}
107
108impl Error for ADRuleError {}
109
110/// Result type used by fallible AD rule emission.
111///
112/// # Examples
113///
114/// ```
115/// use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
116///
117/// fn missing_rule() -> ADRuleResult<()> {
118/// Err(ADRuleError::unsupported("custom::op", ADRuleKind::Transpose))
119/// }
120///
121/// assert!(missing_rule().is_err());
122/// ```
123pub type ADRuleResult<T> = Result<T, ADRuleError>;