tenferro_ad/context.rs
1//! Explicit ownership for automatic-differentiation rule sets.
2
3use tenferro_ops::{ExtensionRegistryError, ExtensionRuleSet};
4use tenferro_runtime::{Result, TracedTensor};
5
6/// Explicit automatic-differentiation context.
7///
8/// `AdContext` owns the extension AD rules used by traced AD transforms.
9///
10/// # Examples
11///
12/// ```rust
13/// use tenferro_ad::AdContext;
14///
15/// let ad = AdContext::builder().build().unwrap();
16/// assert!(ad.extension_rules().lookup_rule("example.missing.v1").is_none());
17/// ```
18#[derive(Clone, Debug)]
19pub struct AdContext {
20 extension_rules: ExtensionRuleSet,
21}
22
23impl AdContext {
24 /// Start building an explicit AD context.
25 ///
26 /// # Examples
27 ///
28 /// ```rust
29 /// use tenferro_ad::AdContext;
30 ///
31 /// let _builder = AdContext::builder();
32 /// ```
33 pub fn builder() -> AdContextBuilder {
34 AdContextBuilder::default()
35 }
36
37 /// Return the extension rules owned by this context.
38 ///
39 /// # Examples
40 ///
41 /// ```rust
42 /// use tenferro_ad::AdContext;
43 ///
44 /// let ad = AdContext::builder().build().unwrap();
45 /// assert!(!ad.extension_rules().is_rule_registered("example.missing.v1"));
46 /// ```
47 pub fn extension_rules(&self) -> &ExtensionRuleSet {
48 &self.extension_rules
49 }
50
51 pub(crate) fn extension_rule_set(&self) -> ExtensionRuleSet {
52 self.extension_rules.clone()
53 }
54
55 /// Gradient of a scalar traced output with respect to a traced input.
56 ///
57 /// For complex scalar outputs, tenferro returns the Hermitian-adjoint
58 /// cotangent. To compare seed-`1` scalar gradients with JAX's public
59 /// `grad` values, use the complex conjugate of this result. See
60 /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
61 ///
62 /// # Examples
63 ///
64 /// ```rust
65 /// use tenferro_ad::AdContext;
66 /// use tenferro_runtime::TracedTensor;
67 ///
68 /// let ad = AdContext::builder().build().unwrap();
69 /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
70 /// let loss = (&x * &x).unwrap();
71 /// let grad = ad.grad(&loss, &x).unwrap();
72 /// assert_eq!(grad.rank, 0);
73 /// ```
74 pub fn grad(&self, output: &TracedTensor, wrt: &TracedTensor) -> Result<TracedTensor> {
75 crate::traced::grad_with_rules(output, wrt, &self.extension_rules)
76 }
77
78 /// Gradient that returns `None` when `wrt` is inactive.
79 ///
80 /// # Examples
81 ///
82 /// ```rust
83 /// use tenferro_ad::AdContext;
84 /// use tenferro_runtime::TracedTensor;
85 ///
86 /// let ad = AdContext::builder().build().unwrap();
87 /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
88 /// let loss = (&x * &x).unwrap();
89 /// assert!(ad.grad_optional(&loss, &x).unwrap().is_some());
90 /// ```
91 pub fn grad_optional(
92 &self,
93 output: &TracedTensor,
94 wrt: &TracedTensor,
95 ) -> Result<Option<TracedTensor>> {
96 crate::traced::grad_optional_with_rules(output, wrt, &self.extension_rules)
97 }
98
99 /// Forward-mode Jacobian-vector product.
100 ///
101 /// # Examples
102 ///
103 /// ```rust
104 /// use tenferro_ad::AdContext;
105 /// use tenferro_runtime::TracedTensor;
106 ///
107 /// let ad = AdContext::builder().build().unwrap();
108 /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
109 /// let dx = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
110 /// let y = (&x * &x).unwrap();
111 /// let dy = ad.jvp(&y, &x, &dx).unwrap();
112 /// assert_eq!(dy.rank, 0);
113 /// ```
114 pub fn jvp(
115 &self,
116 output: &TracedTensor,
117 wrt: &TracedTensor,
118 tangent: &TracedTensor,
119 ) -> Result<TracedTensor> {
120 crate::traced::jvp_with_rules(output, wrt, tangent, &self.extension_rules)
121 }
122
123 /// Forward-mode Jacobian-vector product that returns `None` for inactive output.
124 ///
125 /// # Examples
126 ///
127 /// ```rust
128 /// use tenferro_ad::AdContext;
129 /// use tenferro_runtime::TracedTensor;
130 ///
131 /// let ad = AdContext::builder().build().unwrap();
132 /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
133 /// let dx = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
134 /// let y = (&x * &x).unwrap();
135 /// assert!(ad.jvp_optional(&y, &x, &dx).unwrap().is_some());
136 /// ```
137 pub fn jvp_optional(
138 &self,
139 output: &TracedTensor,
140 wrt: &TracedTensor,
141 tangent: &TracedTensor,
142 ) -> Result<Option<TracedTensor>> {
143 crate::traced::jvp_optional_with_rules(output, wrt, tangent, &self.extension_rules)
144 }
145
146 /// Reverse-mode vector-Jacobian product.
147 ///
148 /// Complex cotangents use tenferro's Hermitian real-inner-product
149 /// convention. Non-real complex cotangent seeds therefore need an explicit
150 /// seed-convention comparison when matching JAX. See
151 /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
152 ///
153 /// # Examples
154 ///
155 /// ```rust
156 /// use tenferro_ad::AdContext;
157 /// use tenferro_runtime::TracedTensor;
158 ///
159 /// let ad = AdContext::builder().build().unwrap();
160 /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
161 /// let dy = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
162 /// let y = (&x * &x).unwrap();
163 /// let dx = ad.vjp(&y, &x, &dy).unwrap();
164 /// assert_eq!(dx.rank, 0);
165 /// ```
166 pub fn vjp(
167 &self,
168 output: &TracedTensor,
169 wrt: &TracedTensor,
170 cotangent: &TracedTensor,
171 ) -> Result<TracedTensor> {
172 crate::traced::vjp_with_rules(output, wrt, cotangent, &self.extension_rules)
173 }
174
175 /// Reverse-mode vector-Jacobian product that returns `None` for inactive input.
176 ///
177 /// # Examples
178 ///
179 /// ```rust
180 /// use tenferro_ad::AdContext;
181 /// use tenferro_runtime::TracedTensor;
182 ///
183 /// let ad = AdContext::builder().build().unwrap();
184 /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
185 /// let dy = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
186 /// let y = (&x * &x).unwrap();
187 /// assert!(ad.vjp_optional(&y, &x, &dy).unwrap().is_some());
188 /// ```
189 pub fn vjp_optional(
190 &self,
191 output: &TracedTensor,
192 wrt: &TracedTensor,
193 cotangent: &TracedTensor,
194 ) -> Result<Option<TracedTensor>> {
195 crate::traced::vjp_optional_with_rules(output, wrt, cotangent, &self.extension_rules)
196 }
197}
198
199/// Builder for [`AdContext`].
200///
201/// # Examples
202///
203/// ```rust
204/// use tenferro_ad::AdContextBuilder;
205///
206/// let ad = AdContextBuilder::new().build().unwrap();
207/// assert!(ad.extension_rules().lookup_rule("example.missing.v1").is_none());
208/// ```
209#[derive(Clone, Debug, Default)]
210pub struct AdContextBuilder {
211 extension_rule_sets: Vec<ExtensionRuleSet>,
212}
213
214impl AdContextBuilder {
215 /// Create an empty builder.
216 ///
217 /// # Examples
218 ///
219 /// ```rust
220 /// use tenferro_ad::AdContextBuilder;
221 ///
222 /// let _builder = AdContextBuilder::new();
223 /// ```
224 pub fn new() -> Self {
225 Self::default()
226 }
227
228 /// Include an owned extension rule set.
229 ///
230 /// # Examples
231 ///
232 /// ```rust
233 /// use tenferro_ad::{AdContext, extension::ExtensionRuleSet};
234 ///
235 /// let _ad = AdContext::builder()
236 /// .with_extension_rules(ExtensionRuleSet::new())
237 /// .build()
238 /// .unwrap();
239 /// ```
240 pub fn with_extension_rules(mut self, rules: ExtensionRuleSet) -> Self {
241 self.extension_rule_sets.push(rules);
242 self
243 }
244
245 /// Build the context.
246 ///
247 /// Duplicate extension family registrations are rejected.
248 ///
249 /// # Examples
250 ///
251 /// ```rust
252 /// use tenferro_ad::AdContext;
253 ///
254 /// let ad = AdContext::builder().build().unwrap();
255 /// assert!(ad.extension_rules().lookup_rule("example.missing.v1").is_none());
256 /// ```
257 pub fn build(self) -> std::result::Result<AdContext, ExtensionRegistryError> {
258 let mut extension_rules = ExtensionRuleSet::new();
259 for rules in self.extension_rule_sets {
260 extension_rules.merge(rules)?;
261 }
262 Ok(AdContext { extension_rules })
263 }
264}