// chainrules_core/lib.rs
1#![doc = include_str!("../README.md")]
2
3//! Core AD trait definitions (like Julia's ChainRulesCore.jl).
4//!
5//! This crate defines the interface for automatic differentiation without
6//! providing an AD engine. It contains:
7//!
8//! - [`Differentiable`] — tangent space definition for any value type
9//! - [`ReverseRule`] — per-operation reverse-mode rule (rrule/pullback)
10//! - [`ForwardRule`] — per-operation forward-mode rule (frule/pushforward)
11//! - Error types ([`AutodiffError`], [`AdResult`])
12//! - [`NodeId`], [`SavePolicy`] — graph node identifier and save strategy
13//!
14//! AD engines (`Tape`, `TrackedValue`, `DualValue`, `pullback`, `hvp`) live in
15//! separate crates, for example [`tidu`](https://docs.rs/tidu).
16//!
17//! Operation-specific AD rules (e.g., einsum rrule/frule) live in the crate
18//! that defines the operation.
19//!
20//! # Examples
21//!
22//! Implementing `Differentiable` for a custom type:
23//!
24//! ```
25//! use chainrules_core::Differentiable;
26//!
27//! #[derive(Clone)]
28//! struct MyVec(Vec<f64>);
29//!
30//! impl Differentiable for MyVec {
31//! type Tangent = MyVec;
32//! fn zero_tangent(&self) -> MyVec {
33//! MyVec(vec![0.0; self.0.len()])
34//! }
35//! fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
36//! MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
37//! }
38//! fn num_elements(&self) -> usize {
39//! self.0.len()
40//! }
41//! fn seed_cotangent(&self) -> MyVec {
42//! MyVec(vec![1.0; self.0.len()])
43//! }
44//! }
45//! ```
46
/// Trait defining the tangent space for a differentiable type.
///
/// This is the core abstraction of the AD framework, analogous to Julia's
/// ChainRulesCore.jl tangent type system. Any type that participates in
/// automatic differentiation must implement this trait.
///
/// The tangent type represents infinitesimal perturbations of the value.
/// For most tensor types, `Tangent = Self` (e.g., the tangent of a matrix
/// is another matrix of the same shape).
///
/// Note: This trait intentionally does **not** require `Clone` on the primal
/// type. `Clone` is only required on `Tangent` (for gradient accumulation).
/// Large values (e.g., tensors) may be expensive to clone; the AD engine
/// avoids cloning primals by taking ownership where needed.
///
/// # Examples
///
/// ```
/// use chainrules_core::Differentiable;
///
/// fn example<V: Differentiable>(x: &V) {
///     let zero = x.zero_tangent();
///     let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent());
/// }
/// ```
pub trait Differentiable {
    /// The tangent type for this value.
    ///
    /// For most types, this is `Self` (e.g., tangent of a tensor is a tensor).
    /// `Clone` is required so tangents can be duplicated during gradient
    /// accumulation (e.g., at fan-out points in the graph).
    type Tangent: Clone;

    /// Returns the zero tangent for this value (additive identity).
    ///
    /// Typically used to initialize gradient accumulators before a
    /// backward pass.
    fn zero_tangent(&self) -> Self::Tangent;

    /// Accumulates (adds) two tangents: `a + b`.
    ///
    /// `a` is taken by value, which allows implementations to reuse its
    /// storage for the result instead of allocating.
    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent;

    /// Returns the number of scalar elements in this value.
    ///
    /// For scalar types (f64, f32), this is always 1.
    /// For tensor types, this is the total number of elements.
    /// Engines use this to enforce scalar losses (see
    /// [`AutodiffError::NonScalarLoss`]).
    fn num_elements(&self) -> usize;

    /// Returns the seed cotangent for reverse-mode pullback.
    ///
    /// For a scalar loss, this returns the "one" tangent (1.0 for scalars,
    /// ones-like for single-element tensors). Used internally by
    /// [`Tape::pullback`](https://docs.rs/chainrules) to initialize the
    /// backward pass.
    fn seed_cotangent(&self) -> Self::Tangent;
}
98
/// AD-specific error type.
///
/// Returned by all fallible AD APIs in this crate via the [`AdResult`]
/// alias.
///
/// # Examples
///
/// ```
/// use chainrules_core::AutodiffError;
///
/// let err = AutodiffError::NonScalarLoss { num_elements: 8 };
/// assert!(format!("{err}").contains("scalar"));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum AutodiffError {
    /// Loss tensor for pullback must contain exactly one element.
    #[error("pullback() requires scalar loss, got {num_elements} elements")]
    NonScalarLoss {
        /// Number of elements actually found in the loss value.
        num_elements: usize,
    },
    /// Attempted pullback on a tensor not connected to AD tape.
    #[error("tensor is not connected to AD tape")]
    MissingNode,
    /// Tangent shape must match primal shape.
    #[error("tangent shape mismatch: expected {expected}, got {got}")]
    TangentShapeMismatch {
        /// Expected shape description.
        expected: String,
        /// Actual shape description.
        got: String,
    },
    /// A ReverseRule does not support HVP (pullback_with_tangents).
    ///
    /// This is the default for rules that do not override the
    /// HVP-related methods of [`ReverseRule`].
    #[error("HVP not supported by this ReverseRule implementation")]
    HvpNotSupported,
    /// The requested AD mode is not supported for the given algebra or operation.
    ///
    /// For example, tropical einsum does not support frule (JVP) or hvp —
    /// only rrule (VJP) via the argmax route is available.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::ModeNotSupported {
    ///     mode: "frule".into(),
    ///     reason: "tropical einsum supports rrule only (max is not smooth)".into(),
    /// };
    /// ```
    #[error("AD mode not supported: {mode} — {reason}")]
    ModeNotSupported {
        /// The unsupported mode (e.g., "frule", "hvp").
        mode: String,
        /// Explanation of why this mode is not supported.
        reason: String,
    },
    /// Generic AD argument error.
    #[error("invalid autodiff argument: {0}")]
    InvalidArgument(String),
    /// Attempted to execute backward/grad on a graph that was already freed.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::GraphFreed;
    /// assert!(err.to_string().contains("freed"));
    /// ```
    #[error("computation graph has been freed")]
    GraphFreed,
}
166
/// Result alias for AD APIs.
///
/// # Examples
///
/// ```
/// use chainrules_core::AdResult;
///
/// fn returns_ad_result() -> AdResult<()> { Ok(()) }
/// ```
pub type AdResult<T> = std::result::Result<T, AutodiffError>;

/// Reverse-rule pullback output entry `(input_node, input_cotangent)`.
///
/// Returned by [`ReverseRule::pullback`] — one entry per input node that
/// receives a cotangent contribution.
pub type PullbackEntry<V> = (NodeId, <V as Differentiable>::Tangent);

/// Reverse-rule pullback-with-tangents output entry.
///
/// Tuple layout: `(input_node, input_cotangent, input_cotangent_tangent)`.
/// Returned by [`ReverseRule::pullback_with_tangents`] during HVP
/// (forward-over-reverse) computation.
pub type PullbackWithTangentsEntry<V> = (
    NodeId,
    <V as Differentiable>::Tangent,
    <V as Differentiable>::Tangent,
);
189
/// Stable identifier of an AD graph node.
///
/// A thin newtype over a `usize` index. IDs are `Copy`, hashable, and
/// totally ordered, so they can key hash maps and B-tree maps or be
/// sorted for deterministic graph traversal.
///
/// # Examples
///
/// ```
/// use chainrules_core::NodeId;
///
/// let id = NodeId::new(7);
/// assert_eq!(id.index(), 7);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(usize);

impl NodeId {
    /// Creates a node ID from an integer index.
    ///
    /// `const` so IDs can be constructed in constant contexts
    /// (e.g., tests, static tables).
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(42);
    /// assert_eq!(id.index(), 42);
    /// ```
    #[must_use]
    pub const fn new(index: usize) -> Self {
        Self(index)
    }

    /// Returns the numeric index.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(3);
    /// assert_eq!(id.index(), 3);
    /// ```
    #[must_use]
    pub const fn index(&self) -> usize {
        self.0
    }
}
232
/// Saved-tensor retention policy for reverse-mode rules.
///
/// Controls whether a rule keeps the forward-pass values needed by its
/// pullback, trading memory for recomputation time.
///
/// # Examples
///
/// ```
/// use chainrules_core::SavePolicy;
///
/// let p = SavePolicy::SaveForPullback;
/// assert_eq!(p, SavePolicy::SaveForPullback);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SavePolicy {
    /// Keep forward tensors for exact pullback formulas.
    SaveForPullback,
    /// Discard forward tensors and require recomputation/checkpointing later.
    RecomputeOnPullback,
}
250
/// Reverse-mode AD rule interface (rrule).
///
/// Implemented by operation-specific nodes (einsum, reduce, permute, ...).
/// Named after Julia's ChainRules.jl convention: `rrule` returns a pullback.
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// Implementors must be `Send + Sync` because rule objects may be stored on an
/// AD tape that is shared across threads.
///
/// Only [`pullback`](ReverseRule::pullback) and
/// [`inputs`](ReverseRule::inputs) are required; the HVP-related methods
/// default to returning [`AutodiffError::HvpNotSupported`].
///
/// # Examples
///
/// Custom reverse rule for scalar multiplication `output = a * b`:
///
/// ```
/// use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
///
/// struct ScalarMulRule {
///     a: f64,
///     b: f64,
///     a_node: NodeId,
///     b_node: NodeId,
/// }
///
/// impl ReverseRule<f64> for ScalarMulRule {
///     fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
///         // d(a*b)/da = b, d(a*b)/db = a
///         let da = cotangent * self.b;
///         let db = cotangent * self.a;
///         Ok(vec![(self.a_node, da), (self.b_node, db)])
///     }
///
///     fn inputs(&self) -> Vec<NodeId> {
///         vec![self.a_node, self.b_node]
///     }
/// }
///
/// // Verify: for a=3, b=5, cotangent=1 → da=5, db=3
/// let rule = ScalarMulRule {
///     a: 3.0, b: 5.0,
///     a_node: NodeId::new(0), b_node: NodeId::new(1),
/// };
/// let grads = rule.pullback(&1.0).unwrap();
/// assert_eq!(grads[0], (NodeId::new(0), 5.0)); // da = cotangent * b
/// assert_eq!(grads[1], (NodeId::new(1), 3.0)); // db = cotangent * a
/// ```
pub trait ReverseRule<V: Differentiable>: Send + Sync {
    /// Computes input cotangents from an output cotangent (pullback).
    ///
    /// Returns one `(input_node, input_cotangent)` entry per contributing
    /// input.
    fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<PullbackEntry<V>>>;

    /// Returns input node IDs this rule depends on.
    fn inputs(&self) -> Vec<NodeId>;

    /// Computes the forward tangent of this operation's output.
    ///
    /// Given a closure that returns the tangent for each input node
    /// (or `None` if the input has no tangent), returns the output tangent.
    /// A returned `Ok(None)` means no input carried a tangent.
    ///
    /// The default implementation returns [`AutodiffError::HvpNotSupported`].
    /// Operations that support deferred HVP override this method.
    fn forward_tangents<'t>(
        &self,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
    ) -> AdResult<Option<V::Tangent>>
    where
        V::Tangent: 't,
    {
        // Silence the unused-parameter warning without renaming the
        // parameter (the name appears in rustdoc for implementors).
        let _ = input_tangents;
        Err(AutodiffError::HvpNotSupported)
    }

    /// Computes pullback with tangent propagation for HVP.
    ///
    /// Given an output cotangent, its tangent, and a closure providing input
    /// tangents by node ID, returns
    /// `(node_id, input_cotangent, input_cotangent_tangent)` triples.
    ///
    /// The `input_tangents` closure provides access to forward-propagated
    /// tangents for each input node, enabling deferred tangent injection
    /// without storing tangents in the rule struct.
    ///
    /// The default implementation returns [`AutodiffError::HvpNotSupported`].
    /// Operations that support forward-over-reverse HVP override this method.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Called internally by hvp(); users rarely call this directly.
    /// let results = rule.pullback_with_tangents(
    ///     &cotangent, &cotangent_tangent, &|node| tangents_vec[node.index()].as_ref(),
    /// )?;
    /// for (node_id, grad, grad_tangent) in results {
    ///     // grad: standard cotangent for this input
    ///     // grad_tangent: cotangent tangent for HVP
    /// }
    /// ```
    fn pullback_with_tangents<'t>(
        &self,
        cotangent: &V::Tangent,
        cotangent_tangent: &V::Tangent,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
    ) -> AdResult<Vec<PullbackWithTangentsEntry<V>>>
    where
        V::Tangent: 't,
    {
        // Parameters are intentionally unused in the default implementation.
        let _ = (cotangent, cotangent_tangent, input_tangents);
        Err(AutodiffError::HvpNotSupported)
    }
}
360
/// Forward-mode AD rule interface (frule).
///
/// Named after Julia's ChainRules.jl convention: `frule` computes pushforward.
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// Implementors must be `Send + Sync` because rule objects may be stored on an
/// AD tape that is shared across threads.
///
/// # Examples
///
/// Custom forward rule for scalar multiplication `output = a * b`:
///
/// ```
/// use chainrules_core::{ForwardRule, Differentiable, AdResult};
///
/// struct ScalarMulFrule {
///     a: f64,
///     b: f64,
/// }
///
/// impl ForwardRule<f64> for ScalarMulFrule {
///     fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
///         // d(a*b) = da*b + a*db
///         let da = tangents.get(0).and_then(|t| *t).copied().unwrap_or(0.0);
///         let db = tangents.get(1).and_then(|t| *t).copied().unwrap_or(0.0);
///         Ok(da * self.b + self.a * db)
///     }
/// }
///
/// // Verify: for a=3, b=5, da=1, db=0 → d(a*b) = 1*5 + 3*0 = 5
/// let rule = ScalarMulFrule { a: 3.0, b: 5.0 };
/// let result = rule.pushforward(&[Some(&1.0), Some(&0.0)]).unwrap();
/// assert_eq!(result, 5.0);
///
/// // Both tangents active: da=1, db=1 → d(a*b) = 1*5 + 3*1 = 8
/// let rule_result = rule.pushforward(&[Some(&1.0), Some(&1.0)]).unwrap();
/// assert_eq!(rule_result, 8.0);
/// ```
pub trait ForwardRule<V: Differentiable>: Send + Sync {
    /// Computes output tangent from input tangents (pushforward).
    ///
    /// `tangents` holds one slot per input; `None` means the input has no
    /// tangent (equivalent to a zero tangent, as in the example above).
    fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
}
404
405// ============================================================================
406// Differentiable impls for primitive types
407// ============================================================================
408
409impl Differentiable for f64 {
410 type Tangent = f64;
411
412 fn zero_tangent(&self) -> f64 {
413 0.0
414 }
415
416 fn accumulate_tangent(a: f64, b: &f64) -> f64 {
417 a + b
418 }
419
420 fn num_elements(&self) -> usize {
421 1
422 }
423
424 fn seed_cotangent(&self) -> f64 {
425 1.0
426 }
427}
428
429impl Differentiable for f32 {
430 type Tangent = f32;
431
432 fn zero_tangent(&self) -> f32 {
433 0.0
434 }
435
436 fn accumulate_tangent(a: f32, b: &f32) -> f32 {
437 a + b
438 }
439
440 fn num_elements(&self) -> usize {
441 1
442 }
443
444 fn seed_cotangent(&self) -> f32 {
445 1.0
446 }
447}
448
449#[cfg(test)]
450mod tests;