// chainrules_core/lib.rs

1#![doc = include_str!("../README.md")]
2
3//! Core AD trait definitions (like Julia's ChainRulesCore.jl).
4//!
5//! This crate defines the interface for automatic differentiation without
6//! providing an AD engine. It contains:
7//!
8//! - [`Differentiable`] — tangent space definition for any value type
9//! - [`ReverseRule`] — per-operation reverse-mode rule (rrule/pullback)
10//! - [`ForwardRule`] — per-operation forward-mode rule (frule/pushforward)
11//! - Error types ([`AutodiffError`], [`AdResult`])
12//! - [`NodeId`], [`SavePolicy`] — graph node identifier and save strategy
13//!
14//! AD engines (`Tape`, `TrackedValue`, `DualValue`, `pullback`, `hvp`) live in
15//! separate crates, for example [`tidu`](https://docs.rs/tidu).
16//!
17//! Operation-specific AD rules (e.g., einsum rrule/frule) live in the crate
18//! that defines the operation.
19//!
20//! # Examples
21//!
22//! Implementing `Differentiable` for a custom type:
23//!
24//! ```ignore
25//! use chainrules_core::Differentiable;
26//!
27//! #[derive(Clone)]
28//! struct MyVec(Vec<f64>);
29//!
30//! impl Differentiable for MyVec {
31//!     type Tangent = MyVec;
32//!     fn zero_tangent(&self) -> MyVec {
33//!         MyVec(vec![0.0; self.0.len()])
34//!     }
35//!     fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
36//!         MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
37//!     }
38//!     fn num_elements(&self) -> usize {
39//!         self.0.len()
40//!     }
41//!     fn seed_cotangent(&self) -> MyVec {
42//!         MyVec(vec![1.0; self.0.len()])
43//!     }
44//! }
45//! ```
46
/// Defines the tangent space of a value that can participate in AD.
///
/// This is the core abstraction of the AD framework, analogous to the
/// tangent-type system of Julia's ChainRulesCore.jl: every type flowing
/// through automatic differentiation implements this trait.
///
/// A tangent represents an infinitesimal perturbation of the primal value.
/// For most tensor-like types the tangent has the same representation as
/// the value itself, i.e. `Tangent = Self`.
///
/// Note: `Clone` is deliberately required only on
/// [`Differentiable::Tangent`] (gradient accumulation needs it), never on
/// the primal type — primals such as large tensors may be expensive to
/// clone, so AD engines take ownership where needed instead of cloning.
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::Differentiable;
///
/// // Tensor<f64> implements Differentiable with Tangent = Tensor<f64>
/// // (defined in tenferro-tensor crate)
/// fn example<V: Differentiable>(x: &V) {
///     let zero = x.zero_tangent();
///     let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent());
/// }
/// ```
pub trait Differentiable {
    /// Tangent type paired with this value; `Self` for most tensor types.
    type Tangent: Clone;

    /// Produces the additive identity of this value's tangent space.
    fn zero_tangent(&self) -> Self::Tangent;

    /// Adds two tangents, returning `a + b`.
    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent;

    /// Total count of scalar elements: 1 for scalar types (`f64`, `f32`),
    /// the full element count for tensor types.
    fn num_elements(&self) -> usize;

    /// Produces the "one" cotangent that seeds a reverse-mode pullback
    /// (1.0 for scalars, ones-like for single-element tensors). Used by the
    /// AD engine to initialize the backward pass.
    fn seed_cotangent(&self) -> Self::Tangent;
}
100
/// AD-specific error type.
///
/// Derives [`thiserror::Error`], so it implements [`std::error::Error`] and
/// [`std::fmt::Display`]; the Display text comes from the `#[error]`
/// attribute on each variant.
///
/// # Examples
///
/// ```
/// use chainrules_core::AutodiffError;
///
/// let err = AutodiffError::NonScalarLoss { num_elements: 8 };
/// assert!(format!("{err}").contains("scalar"));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum AutodiffError {
    /// Loss tensor for pullback must contain exactly one element.
    #[error("pullback() requires scalar loss, got {num_elements} elements")]
    NonScalarLoss {
        /// Number of elements the offending loss value actually had.
        num_elements: usize,
    },
    /// Attempted pullback on a tensor not connected to AD tape.
    #[error("tensor is not connected to AD tape")]
    MissingNode,
    /// Tangent shape must match primal shape.
    #[error("tangent shape mismatch: expected {expected}, got {got}")]
    TangentShapeMismatch {
        /// Expected shape description.
        expected: String,
        /// Actual shape description.
        got: String,
    },
    /// A ReverseRule does not support HVP (pullback_with_tangents).
    ///
    /// This is the default-method failure of `ReverseRule`; rules that
    /// implement forward-over-reverse HVP never return it.
    #[error("HVP not supported by this ReverseRule implementation")]
    HvpNotSupported,
    /// The requested AD mode is not supported for the given algebra or operation.
    ///
    /// For example, tropical einsum does not support frule (JVP) or hvp —
    /// only rrule (VJP) via the argmax route is available.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::ModeNotSupported {
    ///     mode: "frule".into(),
    ///     reason: "tropical einsum supports rrule only (max is not smooth)".into(),
    /// };
    /// ```
    #[error("AD mode not supported: {mode} — {reason}")]
    ModeNotSupported {
        /// The unsupported mode (e.g., "frule", "hvp").
        mode: String,
        /// Explanation of why this mode is not supported.
        reason: String,
    },
    /// Generic AD argument error. The payload is a human-readable message.
    #[error("invalid autodiff argument: {0}")]
    InvalidArgument(String),
    /// Attempted to execute backward/grad on a graph that was already freed.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::GraphFreed;
    /// assert!(err.to_string().contains("freed"));
    /// ```
    #[error("computation graph has been freed")]
    GraphFreed,
}
168
169/// Result alias for AD APIs.
170///
171/// # Examples
172///
173/// ```
174/// use chainrules_core::AdResult;
175///
176/// fn returns_ad_result() -> AdResult<()> { Ok(()) }
177/// ```
178pub type AdResult<T> = std::result::Result<T, AutodiffError>;
179
/// Reverse-rule pullback output entry `(input_node, input_cotangent)`.
///
/// `V` is the differentiable value type; the second component is its
/// tangent, i.e. the cotangent flowing to the input node.
pub type PullbackEntry<V> = (NodeId, <V as Differentiable>::Tangent);

/// Reverse-rule pullback-with-tangents output entry.
///
/// Tuple layout: `(input_node, input_cotangent, input_cotangent_tangent)`.
/// Returned by `ReverseRule::pullback_with_tangents` for HVP support.
pub type PullbackWithTangentsEntry<V> = (
    NodeId,
    <V as Differentiable>::Tangent,
    <V as Differentiable>::Tangent,
);
191
/// Stable identifier of an AD graph node.
///
/// A thin, copyable wrapper around a `usize` index. Derives `Eq`/`Hash`
/// so it can be used as a key in gradient maps.
///
/// # Examples
///
/// ```
/// use chainrules_core::NodeId;
///
/// let id = NodeId::new(7);
/// assert_eq!(id.index(), 7);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Creates a node ID from an integer index.
    ///
    /// `const` so IDs can be constructed in constant contexts.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(42);
    /// assert_eq!(id.index(), 42);
    /// ```
    #[must_use]
    pub const fn new(index: usize) -> Self {
        Self(index)
    }

    /// Returns the numeric index.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(3);
    /// assert_eq!(id.index(), 3);
    /// ```
    #[must_use]
    pub const fn index(&self) -> usize {
        self.0
    }
}
234
/// Saved-tensor retention policy for reverse-mode rules.
///
/// Controls whether a rule keeps its forward tensors alive for the backward
/// pass or discards them in favor of later recomputation (checkpointing).
///
/// Derives `Hash` (consistent with [`NodeId`]) so policies can be used as
/// map/set keys.
///
/// # Examples
///
/// ```
/// use chainrules_core::SavePolicy;
///
/// let p = SavePolicy::SaveForPullback;
/// assert_eq!(p, SavePolicy::SaveForPullback);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SavePolicy {
    /// Keep forward tensors for exact pullback formulas.
    SaveForPullback,
    /// Discard forward tensors and require recomputation/checkpointing later.
    RecomputeOnPullback,
}
252
253/// Reverse-mode AD rule interface (rrule).
254///
255/// Implemented by operation-specific nodes (einsum, reduce, permute, ...).
256/// Named after Julia's ChainRules.jl convention: `rrule` returns a pullback.
257///
258/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
259///
260/// # Examples
261///
262/// Custom reverse rule for scalar multiplication `output = a * b`:
263///
264/// ```
265/// use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
266///
267/// struct ScalarMulRule {
268///     a: f64,
269///     b: f64,
270///     a_node: NodeId,
271///     b_node: NodeId,
272/// }
273///
274/// impl ReverseRule<f64> for ScalarMulRule {
275///     fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
276///         // d(a*b)/da = b, d(a*b)/db = a
277///         let da = cotangent * self.b;
278///         let db = cotangent * self.a;
279///         Ok(vec![(self.a_node, da), (self.b_node, db)])
280///     }
281///
282///     fn inputs(&self) -> Vec<NodeId> {
283///         vec![self.a_node, self.b_node]
284///     }
285/// }
286///
287/// // Verify: for a=3, b=5, cotangent=1 → da=5, db=3
288/// let rule = ScalarMulRule {
289///     a: 3.0, b: 5.0,
290///     a_node: NodeId::new(0), b_node: NodeId::new(1),
291/// };
292/// let grads = rule.pullback(&1.0).unwrap();
293/// assert_eq!(grads[0], (NodeId::new(0), 5.0)); // da = cotangent * b
294/// assert_eq!(grads[1], (NodeId::new(1), 3.0)); // db = cotangent * a
295/// ```
296pub trait ReverseRule<V: Differentiable>: Send + Sync {
297    /// Computes input cotangents from an output cotangent (pullback).
298    fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<PullbackEntry<V>>>;
299
300    /// Returns input node IDs this rule depends on.
301    fn inputs(&self) -> Vec<NodeId>;
302
303    /// Computes pullback with tangent propagation for HVP.
304    ///
305    /// Given an output cotangent and its tangent, returns
306    /// `(node_id, input_cotangent, input_cotangent_tangent)` triples.
307    ///
308    /// The default implementation returns [`AutodiffError::HvpNotSupported`].
309    /// Operations that support forward-over-reverse HVP override this method.
310    ///
311    /// # Examples
312    ///
313    /// ```ignore
314    /// // Called internally by hvp(); users rarely call this directly.
315    /// let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?;
316    /// for (node_id, grad, grad_tangent) in results {
317    ///     // grad: standard cotangent for this input
318    ///     // grad_tangent: cotangent tangent for HVP
319    /// }
320    /// ```
321    fn pullback_with_tangents(
322        &self,
323        cotangent: &V::Tangent,
324        cotangent_tangent: &V::Tangent,
325    ) -> AdResult<Vec<PullbackWithTangentsEntry<V>>> {
326        let _ = (cotangent, cotangent_tangent);
327        Err(AutodiffError::HvpNotSupported)
328    }
329}
330
331/// Forward-mode AD rule interface (frule).
332///
333/// Named after Julia's ChainRules.jl convention: `frule` computes pushforward.
334///
335/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
336///
337/// # Examples
338///
339/// Custom forward rule for scalar multiplication `output = a * b`:
340///
341/// ```
342/// use chainrules_core::{ForwardRule, Differentiable, AdResult};
343///
344/// struct ScalarMulFrule {
345///     a: f64,
346///     b: f64,
347/// }
348///
349/// impl ForwardRule<f64> for ScalarMulFrule {
350///     fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
351///         // d(a*b) = da*b + a*db
352///         let da = tangents.get(0).and_then(|t| *t).copied().unwrap_or(0.0);
353///         let db = tangents.get(1).and_then(|t| *t).copied().unwrap_or(0.0);
354///         Ok(da * self.b + self.a * db)
355///     }
356/// }
357///
358/// // Verify: for a=3, b=5, da=1, db=0 → d(a*b) = 1*5 + 3*0 = 5
359/// let rule = ScalarMulFrule { a: 3.0, b: 5.0 };
360/// let result = rule.pushforward(&[Some(&1.0), Some(&0.0)]).unwrap();
361/// assert_eq!(result, 5.0);
362///
363/// // Both tangents active: da=1, db=1 → d(a*b) = 1*5 + 3*1 = 8
364/// let result = rule.pushforward(&[Some(&1.0), Some(&1.0)]).unwrap();
365/// assert_eq!(result, 8.0);
366/// ```
367pub trait ForwardRule<V: Differentiable>: Send + Sync {
368    /// Computes output tangent from input tangents (pushforward).
369    fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
370}
371
372// ============================================================================
373// Differentiable impls for primitive types
374// ============================================================================
375
376impl Differentiable for f64 {
377    type Tangent = f64;
378
379    fn zero_tangent(&self) -> f64 {
380        0.0
381    }
382
383    fn accumulate_tangent(a: f64, b: &f64) -> f64 {
384        a + b
385    }
386
387    fn num_elements(&self) -> usize {
388        1
389    }
390
391    fn seed_cotangent(&self) -> f64 {
392        1.0
393    }
394}
395
396impl Differentiable for f32 {
397    type Tangent = f32;
398
399    fn zero_tangent(&self) -> f32 {
400        0.0
401    }
402
403    fn accumulate_tangent(a: f32, b: &f32) -> f32 {
404        a + b
405    }
406
407    fn num_elements(&self) -> usize {
408        1
409    }
410
411    fn seed_cotangent(&self) -> f32 {
412        1.0
413    }
414}
415
416#[cfg(test)]
417mod tests;