Skip to main content

tenferro_ad/
context.rs

1//! Explicit ownership for automatic-differentiation rule sets.
2
3use tenferro_ops::{ExtensionRegistryError, ExtensionRuleSet};
4use tenferro_runtime::{Result, TracedTensor};
5
6/// Explicit automatic-differentiation context.
7///
8/// `AdContext` owns the extension AD rules used by traced AD transforms.
9///
10/// # Examples
11///
12/// ```rust
13/// use tenferro_ad::AdContext;
14///
15/// let ad = AdContext::builder().build().unwrap();
16/// assert!(ad.extension_rules().lookup_rule("example.missing.v1").is_none());
17/// ```
18#[derive(Clone, Debug)]
19pub struct AdContext {
20    extension_rules: ExtensionRuleSet,
21}
22
23impl AdContext {
24    /// Start building an explicit AD context.
25    ///
26    /// # Examples
27    ///
28    /// ```rust
29    /// use tenferro_ad::AdContext;
30    ///
31    /// let _builder = AdContext::builder();
32    /// ```
33    pub fn builder() -> AdContextBuilder {
34        AdContextBuilder::default()
35    }
36
37    /// Return the extension rules owned by this context.
38    ///
39    /// # Examples
40    ///
41    /// ```rust
42    /// use tenferro_ad::AdContext;
43    ///
44    /// let ad = AdContext::builder().build().unwrap();
45    /// assert!(!ad.extension_rules().is_rule_registered("example.missing.v1"));
46    /// ```
47    pub fn extension_rules(&self) -> &ExtensionRuleSet {
48        &self.extension_rules
49    }
50
51    pub(crate) fn extension_rule_set(&self) -> ExtensionRuleSet {
52        self.extension_rules.clone()
53    }
54
55    /// Gradient of a scalar traced output with respect to a traced input.
56    ///
57    /// For complex scalar outputs, tenferro returns the Hermitian-adjoint
58    /// cotangent. To compare seed-`1` scalar gradients with JAX's public
59    /// `grad` values, use the complex conjugate of this result. See
60    /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
61    ///
62    /// # Examples
63    ///
64    /// ```rust
65    /// use tenferro_ad::AdContext;
66    /// use tenferro_runtime::TracedTensor;
67    ///
68    /// let ad = AdContext::builder().build().unwrap();
69    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
70    /// let loss = (&x * &x).unwrap();
71    /// let grad = ad.grad(&loss, &x).unwrap();
72    /// assert_eq!(grad.rank, 0);
73    /// ```
74    pub fn grad(&self, output: &TracedTensor, wrt: &TracedTensor) -> Result<TracedTensor> {
75        crate::traced::grad_with_rules(output, wrt, &self.extension_rules)
76    }
77
78    /// Gradient that returns `None` when `wrt` is inactive.
79    ///
80    /// # Examples
81    ///
82    /// ```rust
83    /// use tenferro_ad::AdContext;
84    /// use tenferro_runtime::TracedTensor;
85    ///
86    /// let ad = AdContext::builder().build().unwrap();
87    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
88    /// let loss = (&x * &x).unwrap();
89    /// assert!(ad.grad_optional(&loss, &x).unwrap().is_some());
90    /// ```
91    pub fn grad_optional(
92        &self,
93        output: &TracedTensor,
94        wrt: &TracedTensor,
95    ) -> Result<Option<TracedTensor>> {
96        crate::traced::grad_optional_with_rules(output, wrt, &self.extension_rules)
97    }
98
99    /// Forward-mode Jacobian-vector product.
100    ///
101    /// # Examples
102    ///
103    /// ```rust
104    /// use tenferro_ad::AdContext;
105    /// use tenferro_runtime::TracedTensor;
106    ///
107    /// let ad = AdContext::builder().build().unwrap();
108    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
109    /// let dx = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
110    /// let y = (&x * &x).unwrap();
111    /// let dy = ad.jvp(&y, &x, &dx).unwrap();
112    /// assert_eq!(dy.rank, 0);
113    /// ```
114    pub fn jvp(
115        &self,
116        output: &TracedTensor,
117        wrt: &TracedTensor,
118        tangent: &TracedTensor,
119    ) -> Result<TracedTensor> {
120        crate::traced::jvp_with_rules(output, wrt, tangent, &self.extension_rules)
121    }
122
123    /// Forward-mode Jacobian-vector product that returns `None` for inactive output.
124    ///
125    /// # Examples
126    ///
127    /// ```rust
128    /// use tenferro_ad::AdContext;
129    /// use tenferro_runtime::TracedTensor;
130    ///
131    /// let ad = AdContext::builder().build().unwrap();
132    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
133    /// let dx = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
134    /// let y = (&x * &x).unwrap();
135    /// assert!(ad.jvp_optional(&y, &x, &dx).unwrap().is_some());
136    /// ```
137    pub fn jvp_optional(
138        &self,
139        output: &TracedTensor,
140        wrt: &TracedTensor,
141        tangent: &TracedTensor,
142    ) -> Result<Option<TracedTensor>> {
143        crate::traced::jvp_optional_with_rules(output, wrt, tangent, &self.extension_rules)
144    }
145
146    /// Reverse-mode vector-Jacobian product.
147    ///
148    /// Complex cotangents use tenferro's Hermitian real-inner-product
149    /// convention. Non-real complex cotangent seeds therefore need an explicit
150    /// seed-convention comparison when matching JAX. See
151    /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
152    ///
153    /// # Examples
154    ///
155    /// ```rust
156    /// use tenferro_ad::AdContext;
157    /// use tenferro_runtime::TracedTensor;
158    ///
159    /// let ad = AdContext::builder().build().unwrap();
160    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
161    /// let dy = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
162    /// let y = (&x * &x).unwrap();
163    /// let dx = ad.vjp(&y, &x, &dy).unwrap();
164    /// assert_eq!(dx.rank, 0);
165    /// ```
166    pub fn vjp(
167        &self,
168        output: &TracedTensor,
169        wrt: &TracedTensor,
170        cotangent: &TracedTensor,
171    ) -> Result<TracedTensor> {
172        crate::traced::vjp_with_rules(output, wrt, cotangent, &self.extension_rules)
173    }
174
175    /// Reverse-mode vector-Jacobian product that returns `None` for inactive input.
176    ///
177    /// # Examples
178    ///
179    /// ```rust
180    /// use tenferro_ad::AdContext;
181    /// use tenferro_runtime::TracedTensor;
182    ///
183    /// let ad = AdContext::builder().build().unwrap();
184    /// let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
185    /// let dy = TracedTensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
186    /// let y = (&x * &x).unwrap();
187    /// assert!(ad.vjp_optional(&y, &x, &dy).unwrap().is_some());
188    /// ```
189    pub fn vjp_optional(
190        &self,
191        output: &TracedTensor,
192        wrt: &TracedTensor,
193        cotangent: &TracedTensor,
194    ) -> Result<Option<TracedTensor>> {
195        crate::traced::vjp_optional_with_rules(output, wrt, cotangent, &self.extension_rules)
196    }
197}
198
199/// Builder for [`AdContext`].
200///
201/// # Examples
202///
203/// ```rust
204/// use tenferro_ad::AdContextBuilder;
205///
206/// let ad = AdContextBuilder::new().build().unwrap();
207/// assert!(ad.extension_rules().lookup_rule("example.missing.v1").is_none());
208/// ```
209#[derive(Clone, Debug, Default)]
210pub struct AdContextBuilder {
211    extension_rule_sets: Vec<ExtensionRuleSet>,
212}
213
214impl AdContextBuilder {
215    /// Create an empty builder.
216    ///
217    /// # Examples
218    ///
219    /// ```rust
220    /// use tenferro_ad::AdContextBuilder;
221    ///
222    /// let _builder = AdContextBuilder::new();
223    /// ```
224    pub fn new() -> Self {
225        Self::default()
226    }
227
228    /// Include an owned extension rule set.
229    ///
230    /// # Examples
231    ///
232    /// ```rust
233    /// use tenferro_ad::{AdContext, extension::ExtensionRuleSet};
234    ///
235    /// let _ad = AdContext::builder()
236    ///     .with_extension_rules(ExtensionRuleSet::new())
237    ///     .build()
238    ///     .unwrap();
239    /// ```
240    pub fn with_extension_rules(mut self, rules: ExtensionRuleSet) -> Self {
241        self.extension_rule_sets.push(rules);
242        self
243    }
244
245    /// Build the context.
246    ///
247    /// Duplicate extension family registrations are rejected.
248    ///
249    /// # Examples
250    ///
251    /// ```rust
252    /// use tenferro_ad::AdContext;
253    ///
254    /// let ad = AdContext::builder().build().unwrap();
255    /// assert!(ad.extension_rules().lookup_rule("example.missing.v1").is_none());
256    /// ```
257    pub fn build(self) -> std::result::Result<AdContext, ExtensionRegistryError> {
258        let mut extension_rules = ExtensionRuleSet::new();
259        for rules in self.extension_rule_sets {
260            extension_rules.merge(rules)?;
261        }
262        Ok(AdContext { extension_rules })
263    }
264}