// chainrules_core — crate root (lib.rs)

#![doc = include_str!("../README.md")]

//! Core AD trait definitions (like Julia's ChainRulesCore.jl).
//!
//! This crate defines the interface for automatic differentiation without
//! providing an AD engine. It contains:
//!
//! - [`Differentiable`] — tangent space definition for any value type
//! - [`ReverseRule`] — per-operation reverse-mode rule (rrule/pullback)
//! - [`ForwardRule`] — per-operation forward-mode rule (frule/pushforward)
//! - Error types ([`AutodiffError`], [`AdResult`])
//! - [`NodeId`], [`SavePolicy`] — graph node identifier and save strategy
//!
//! AD engines (`Tape`, `TrackedValue`, `DualValue`, `pullback`, `hvp`) live in
//! separate crates, for example [`tidu`](https://docs.rs/tidu).
//!
//! Operation-specific AD rules (e.g., einsum rrule/frule) live in the crate
//! that defines the operation.
//!
//! # Examples
//!
//! Implementing `Differentiable` for a custom type:
//!
//! ```
//! use chainrules_core::Differentiable;
//!
//! #[derive(Clone)]
//! struct MyVec(Vec<f64>);
//!
//! impl Differentiable for MyVec {
//!     type Tangent = MyVec;
//!     fn zero_tangent(&self) -> MyVec {
//!         MyVec(vec![0.0; self.0.len()])
//!     }
//!     fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
//!         MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
//!     }
//!     fn num_elements(&self) -> usize {
//!         self.0.len()
//!     }
//!     fn seed_cotangent(&self) -> MyVec {
//!         MyVec(vec![1.0; self.0.len()])
//!     }
//! }
//! ```
46
/// Trait defining the tangent space for a differentiable type.
///
/// This is the core abstraction of the AD framework, analogous to Julia's
/// ChainRulesCore.jl tangent type system. Any type that participates in
/// automatic differentiation must implement this trait.
///
/// The tangent type represents infinitesimal perturbations of the value.
/// For most tensor types, `Tangent = Self` (e.g., the tangent of a matrix
/// is another matrix of the same shape).
///
/// Note: This trait intentionally does **not** require `Clone` on the primal
/// type. `Clone` is only required on `Tangent` (for gradient accumulation).
/// Large values (e.g., tensors) may be expensive to clone; the AD engine
/// avoids cloning primals by taking ownership where needed.
///
/// # Examples
///
/// ```
/// use chainrules_core::Differentiable;
///
/// fn example<V: Differentiable>(x: &V) {
///     let zero = x.zero_tangent();
///     let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent());
/// }
/// ```
pub trait Differentiable {
    /// The tangent type for this value.
    ///
    /// For most types, this is `Self` (e.g., tangent of a tensor is a tensor).
    type Tangent: Clone;

    /// Returns the zero tangent for this value (additive identity).
    fn zero_tangent(&self) -> Self::Tangent;

    /// Accumulates (adds) two tangents: `a + b`.
    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent;

    /// Returns the number of scalar elements in this value.
    ///
    /// For scalar types (f64, f32), this is always 1.
    /// For tensor types, this is the total number of elements.
    fn num_elements(&self) -> usize;

    /// Returns the seed cotangent for reverse-mode pullback.
    ///
    /// For a scalar loss, this returns the "one" tangent (1.0 for scalars,
    /// ones-like for single-element tensors). AD engines (e.g., a tape's
    /// `pullback` entry point) use this to initialize the backward pass.
    fn seed_cotangent(&self) -> Self::Tangent;
}
98
/// AD-specific error type.
///
/// # Examples
///
/// ```
/// use chainrules_core::AutodiffError;
///
/// let err = AutodiffError::NonScalarLoss { num_elements: 8 };
/// assert!(format!("{err}").contains("scalar"));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum AutodiffError {
    /// Loss tensor for pullback must contain exactly one element.
    #[error("pullback() requires scalar loss, got {num_elements} elements")]
    NonScalarLoss {
        /// Number of elements found in the (non-scalar) loss value.
        num_elements: usize,
    },
    /// Attempted pullback on a tensor not connected to AD tape.
    #[error("tensor is not connected to AD tape")]
    MissingNode,
    /// Tangent shape must match primal shape.
    #[error("tangent shape mismatch: expected {expected}, got {got}")]
    TangentShapeMismatch {
        /// Expected shape description.
        expected: String,
        /// Actual shape description.
        got: String,
    },
    /// A ReverseRule does not support HVP (pullback_with_tangents).
    #[error("HVP not supported by this ReverseRule implementation")]
    HvpNotSupported,
    /// The requested AD mode is not supported for the given algebra or operation.
    ///
    /// For example, tropical einsum does not support frule (JVP) or hvp —
    /// only rrule (VJP) via the argmax route is available.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::ModeNotSupported {
    ///     mode: "frule".into(),
    ///     reason: "tropical einsum supports rrule only (max is not smooth)".into(),
    /// };
    /// ```
    #[error("AD mode not supported: {mode} — {reason}")]
    ModeNotSupported {
        /// The unsupported mode (e.g., "frule", "hvp").
        mode: String,
        /// Explanation of why this mode is not supported.
        reason: String,
    },
    /// Generic AD argument error.
    #[error("invalid autodiff argument: {0}")]
    InvalidArgument(String),
    /// Attempted to execute backward/grad on a graph that was already freed.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::GraphFreed;
    /// assert!(err.to_string().contains("freed"));
    /// ```
    #[error("computation graph has been freed")]
    GraphFreed,
}
166
167/// Result alias for AD APIs.
168///
169/// # Examples
170///
171/// ```
172/// use chainrules_core::AdResult;
173///
174/// fn returns_ad_result() -> AdResult<()> { Ok(()) }
175/// ```
176pub type AdResult<T> = std::result::Result<T, AutodiffError>;
177
/// Reverse-rule pullback output entry `(input_node, input_cotangent)`.
///
/// One entry is produced per input of an operation by
/// [`ReverseRule::pullback`].
pub type PullbackEntry<V> = (NodeId, <V as Differentiable>::Tangent);

/// Reverse-rule pullback-with-tangents output entry.
///
/// Tuple layout: `(input_node, input_cotangent, input_cotangent_tangent)`.
/// Produced by [`ReverseRule::pullback_with_tangents`] for HVP support.
pub type PullbackWithTangentsEntry<V> = (
    NodeId,
    <V as Differentiable>::Tangent,
    <V as Differentiable>::Tangent,
);
189
/// Stable identifier of an AD graph node.
///
/// A thin `Copy` newtype around the node's index on the tape.
///
/// # Examples
///
/// ```
/// use chainrules_core::NodeId;
///
/// let id = NodeId::new(7);
/// assert_eq!(id.index(), 7);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Creates a node ID from an integer index.
    ///
    /// This is a `const fn`, so IDs can also be built in `const`/`static`
    /// contexts (e.g., well-known sentinel IDs in tests).
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(42);
    /// assert_eq!(id.index(), 42);
    /// ```
    pub const fn new(index: usize) -> Self {
        Self(index)
    }

    /// Returns the numeric index.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(3);
    /// assert_eq!(id.index(), 3);
    /// ```
    #[must_use]
    pub const fn index(&self) -> usize {
        self.0
    }
}
232
/// Saved-tensor retention policy for reverse-mode rules.
///
/// Derives `Hash` (in addition to `Eq`/`Copy`) so policies can be used as
/// keys in sets/maps, e.g. per-operation policy tables.
///
/// # Examples
///
/// ```
/// use chainrules_core::SavePolicy;
///
/// let p = SavePolicy::SaveForPullback;
/// assert_eq!(p, SavePolicy::SaveForPullback);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SavePolicy {
    /// Keep forward tensors for exact pullback formulas.
    SaveForPullback,
    /// Discard forward tensors and require recomputation/checkpointing later.
    RecomputeOnPullback,
}
250
/// Reverse-mode AD rule interface (rrule).
///
/// Implemented by operation-specific nodes (einsum, reduce, permute, ...).
/// Named after Julia's ChainRules.jl convention: `rrule` returns a pullback.
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// Implementors must be `Send + Sync` because rule objects may be stored on an
/// AD tape that is shared across threads.
///
/// # Examples
///
/// Custom reverse rule for scalar multiplication `output = a * b`:
///
/// ```
/// use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
///
/// struct ScalarMulRule {
///     a: f64,
///     b: f64,
///     a_node: NodeId,
///     b_node: NodeId,
/// }
///
/// impl ReverseRule<f64> for ScalarMulRule {
///     fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
///         // d(a*b)/da = b, d(a*b)/db = a
///         let da = cotangent * self.b;
///         let db = cotangent * self.a;
///         Ok(vec![(self.a_node, da), (self.b_node, db)])
///     }
///
///     fn inputs(&self) -> Vec<NodeId> {
///         vec![self.a_node, self.b_node]
///     }
/// }
///
/// // Verify: for a=3, b=5, cotangent=1 → da=5, db=3
/// let rule = ScalarMulRule {
///     a: 3.0, b: 5.0,
///     a_node: NodeId::new(0), b_node: NodeId::new(1),
/// };
/// let grads = rule.pullback(&1.0).unwrap();
/// assert_eq!(grads[0], (NodeId::new(0), 5.0)); // da = cotangent * b
/// assert_eq!(grads[1], (NodeId::new(1), 3.0)); // db = cotangent * a
/// ```
pub trait ReverseRule<V: Differentiable>: Send + Sync {
    /// Computes input cotangents from an output cotangent (pullback).
    fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<PullbackEntry<V>>>;

    /// Returns input node IDs this rule depends on.
    fn inputs(&self) -> Vec<NodeId>;

    /// Computes the forward tangent of this operation's output.
    ///
    /// Given a closure that returns the tangent for each input node
    /// (or `None` if the input has no tangent), returns the output tangent.
    ///
    /// The default implementation returns [`AutodiffError::HvpNotSupported`].
    /// Operations that support deferred HVP override this method.
    fn forward_tangents<'t>(
        &self,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
    ) -> AdResult<Option<V::Tangent>>
    where
        V::Tangent: 't,
    {
        // Default rule has no forward-tangent formula; the closure is
        // deliberately never consulted.
        let _ = input_tangents;
        Err(AutodiffError::HvpNotSupported)
    }

    /// Computes pullback with tangent propagation for HVP.
    ///
    /// Given an output cotangent, its tangent, and a closure providing input
    /// tangents by node ID, returns
    /// `(node_id, input_cotangent, input_cotangent_tangent)` triples.
    ///
    /// The `input_tangents` closure provides access to forward-propagated
    /// tangents for each input node, enabling deferred tangent injection
    /// without storing tangents in the rule struct.
    ///
    /// The default implementation returns [`AutodiffError::HvpNotSupported`].
    /// Operations that support forward-over-reverse HVP override this method.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Called internally by hvp(); users rarely call this directly.
    /// let results = rule.pullback_with_tangents(
    ///     &cotangent, &cotangent_tangent, &|node| tangents_vec[node.index()].as_ref(),
    /// )?;
    /// for (node_id, grad, grad_tangent) in results {
    ///     // grad: standard cotangent for this input
    ///     // grad_tangent: cotangent tangent for HVP
    /// }
    /// ```
    fn pullback_with_tangents<'t>(
        &self,
        cotangent: &V::Tangent,
        cotangent_tangent: &V::Tangent,
        input_tangents: &dyn Fn(NodeId) -> Option<&'t V::Tangent>,
    ) -> AdResult<Vec<PullbackWithTangentsEntry<V>>>
    where
        V::Tangent: 't,
    {
        // Default rule does not support HVP; arguments intentionally unused.
        let _ = (cotangent, cotangent_tangent, input_tangents);
        Err(AutodiffError::HvpNotSupported)
    }
}
360
/// Forward-mode AD rule interface (frule).
///
/// Named after Julia's ChainRules.jl convention: `frule` computes pushforward.
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// Implementors must be `Send + Sync` because rule objects may be stored on an
/// AD tape that is shared across threads.
///
/// # Examples
///
/// Custom forward rule for scalar multiplication `output = a * b`:
///
/// ```
/// use chainrules_core::{ForwardRule, Differentiable, AdResult};
///
/// struct ScalarMulFrule {
///     a: f64,
///     b: f64,
/// }
///
/// impl ForwardRule<f64> for ScalarMulFrule {
///     fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
///         // d(a*b) = da*b + a*db
///         let da = tangents.get(0).and_then(|t| *t).copied().unwrap_or(0.0);
///         let db = tangents.get(1).and_then(|t| *t).copied().unwrap_or(0.0);
///         Ok(da * self.b + self.a * db)
///     }
/// }
///
/// // Verify: for a=3, b=5, da=1, db=0 → d(a*b) = 1*5 + 3*0 = 5
/// let rule = ScalarMulFrule { a: 3.0, b: 5.0 };
/// let result = rule.pushforward(&[Some(&1.0), Some(&0.0)]).unwrap();
/// assert_eq!(result, 5.0);
///
/// // Both tangents active: da=1, db=1 → d(a*b) = 1*5 + 3*1 = 8
/// let result = rule.pushforward(&[Some(&1.0), Some(&1.0)]).unwrap();
/// assert_eq!(result, 8.0);
/// ```
pub trait ForwardRule<V: Differentiable>: Send + Sync {
    /// Computes output tangent from input tangents (pushforward).
    ///
    /// `tangents[i]` is the tangent of input `i`, or `None` when that input
    /// carries no tangent (treated as zero by implementations).
    fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
}
404
405// ============================================================================
406// Differentiable impls for primitive types
407// ============================================================================
408
409impl Differentiable for f64 {
410    type Tangent = f64;
411
412    fn zero_tangent(&self) -> f64 {
413        0.0
414    }
415
416    fn accumulate_tangent(a: f64, b: &f64) -> f64 {
417        a + b
418    }
419
420    fn num_elements(&self) -> usize {
421        1
422    }
423
424    fn seed_cotangent(&self) -> f64 {
425        1.0
426    }
427}
428
429impl Differentiable for f32 {
430    type Tangent = f32;
431
432    fn zero_tangent(&self) -> f32 {
433        0.0
434    }
435
436    fn accumulate_tangent(a: f32, b: &f32) -> f32 {
437        a + b
438    }
439
440    fn num_elements(&self) -> usize {
441        1
442    }
443
444    fn seed_cotangent(&self) -> f32 {
445        1.0
446    }
447}
448
// Unit tests for this crate live in the sibling `tests.rs` module,
// compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;