// chainrules_core/lib.rs
#![doc = include_str!("../README.md")]

//! Core AD trait definitions (like Julia's ChainRulesCore.jl).
//!
//! This crate defines the interface for automatic differentiation without
//! providing an AD engine. It contains:
//!
//! - [`Differentiable`] — tangent space definition for any value type
//! - [`ReverseRule`] — per-operation reverse-mode rule (rrule/pullback)
//! - [`ForwardRule`] — per-operation forward-mode rule (frule/pushforward)
//! - Error types ([`AutodiffError`], [`AdResult`])
//! - [`NodeId`], [`SavePolicy`] — graph node identifier and save strategy
//!
//! AD engines (`Tape`, `TrackedValue`, `DualValue`, `pullback`, `hvp`) live in
//! separate crates, for example [`tidu`](https://docs.rs/tidu).
//!
//! Operation-specific AD rules (e.g., einsum rrule/frule) live in the crate
//! that defines the operation.
//!
//! # Examples
//!
//! Implementing `Differentiable` for a custom type:
//!
//! ```ignore
//! use chainrules_core::Differentiable;
//!
//! #[derive(Clone)]
//! struct MyVec(Vec<f64>);
//!
//! impl Differentiable for MyVec {
//!     type Tangent = MyVec;
//!     fn zero_tangent(&self) -> MyVec {
//!         MyVec(vec![0.0; self.0.len()])
//!     }
//!     fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
//!         MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
//!     }
//!     fn num_elements(&self) -> usize {
//!         self.0.len()
//!     }
//!     fn seed_cotangent(&self) -> MyVec {
//!         MyVec(vec![1.0; self.0.len()])
//!     }
//! }
//! ```
/// Defines the tangent space of a value that can take part in automatic
/// differentiation.
///
/// This is the central abstraction of the framework, mirroring the tangent
/// type system of Julia's ChainRulesCore.jl. Every value type flowing
/// through an AD engine implements this trait.
///
/// A tangent represents an infinitesimal perturbation of the primal value.
/// For most tensor-like types the tangent is the type itself
/// (`Tangent = Self`): perturbing a matrix yields another matrix of the same
/// shape.
///
/// Note: `Clone` is deliberately required only on the `Tangent` type (needed
/// for gradient accumulation), not on the primal type. Primals such as large
/// tensors may be expensive to clone; AD engines avoid cloning them by
/// taking ownership where needed.
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::Differentiable;
///
/// // Tensor<f64> implements Differentiable with Tangent = Tensor<f64>
/// // (defined in tenferro-tensor crate)
/// fn example<V: Differentiable>(x: &V) {
///     let zero = x.zero_tangent();
///     let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent());
/// }
/// ```
pub trait Differentiable {
    /// Tangent type of this value.
    ///
    /// For most types this is `Self` (e.g., the tangent of a tensor is a
    /// tensor).
    type Tangent: Clone;

    /// Additive identity of the tangent space, shaped like `self`.
    fn zero_tangent(&self) -> Self::Tangent;

    /// Sum of two tangents (`a + b`), used to accumulate gradients.
    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent;

    /// Total count of scalar elements in this value.
    ///
    /// Always 1 for scalar types (`f64`, `f32`); the element count for
    /// tensor types.
    fn num_elements(&self) -> usize;

    /// Seed cotangent that initializes a reverse-mode pullback.
    ///
    /// For a scalar loss this is the "one" tangent (1.0 for scalars,
    /// ones-like for single-element tensors). Used internally by
    /// [`Tape::pullback`](https://docs.rs/chainrules) to start the backward
    /// pass.
    fn seed_cotangent(&self) -> Self::Tangent;
}
100
/// AD-specific error type.
///
/// Covers the failure modes surfaced by the AD rule and engine APIs.
///
/// # Examples
///
/// ```
/// use chainrules_core::AutodiffError;
///
/// let err = AutodiffError::NonScalarLoss { num_elements: 8 };
/// assert!(format!("{err}").contains("scalar"));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum AutodiffError {
    /// Loss tensor for pullback must contain exactly one element.
    #[error("pullback() requires scalar loss, got {num_elements} elements")]
    NonScalarLoss {
        /// Number of elements actually found in the loss value.
        num_elements: usize,
    },
    /// Attempted pullback on a tensor not connected to AD tape.
    #[error("tensor is not connected to AD tape")]
    MissingNode,
    /// Tangent shape must match primal shape.
    #[error("tangent shape mismatch: expected {expected}, got {got}")]
    TangentShapeMismatch {
        /// Expected shape description.
        expected: String,
        /// Actual shape description.
        got: String,
    },
    /// A ReverseRule does not support HVP (pullback_with_tangents).
    #[error("HVP not supported by this ReverseRule implementation")]
    HvpNotSupported,
    /// The requested AD mode is not supported for the given algebra or operation.
    ///
    /// For example, tropical einsum does not support frule (JVP) or hvp —
    /// only rrule (VJP) via the argmax route is available.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::ModeNotSupported {
    ///     mode: "frule".into(),
    ///     reason: "tropical einsum supports rrule only (max is not smooth)".into(),
    /// };
    /// ```
    #[error("AD mode not supported: {mode} — {reason}")]
    ModeNotSupported {
        /// The unsupported mode (e.g., "frule", "hvp").
        mode: String,
        /// Explanation of why this mode is not supported.
        reason: String,
    },
    /// Generic AD argument error.
    #[error("invalid autodiff argument: {0}")]
    InvalidArgument(String),
    /// Attempted to execute backward/grad on a graph that was already freed.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::AutodiffError;
    ///
    /// let err = AutodiffError::GraphFreed;
    /// assert!(err.to_string().contains("freed"));
    /// ```
    #[error("computation graph has been freed")]
    GraphFreed,
}
168
169/// Result alias for AD APIs.
170///
171/// # Examples
172///
173/// ```
174/// use chainrules_core::AdResult;
175///
176/// fn returns_ad_result() -> AdResult<()> { Ok(()) }
177/// ```
178pub type AdResult<T> = std::result::Result<T, AutodiffError>;
179
/// Reverse-rule pullback output entry `(input_node, input_cotangent)`.
///
/// Each entry pairs an input's graph node with the cotangent contribution
/// flowing back to it from a [`ReverseRule::pullback`] call.
pub type PullbackEntry<V> = (NodeId, <V as Differentiable>::Tangent);
182
/// Reverse-rule pullback-with-tangents output entry.
///
/// Tuple layout: `(input_node, input_cotangent, input_cotangent_tangent)`.
/// Produced by [`ReverseRule::pullback_with_tangents`] when a rule supports
/// forward-over-reverse Hessian-vector products.
pub type PullbackWithTangentsEntry<V> = (
    NodeId,
    <V as Differentiable>::Tangent,
    <V as Differentiable>::Tangent,
);
191
/// Stable identifier of an AD graph node.
///
/// A thin newtype over the node's index, so engines can refer to graph
/// nodes without borrowing the graph itself.
///
/// # Examples
///
/// ```
/// use chainrules_core::NodeId;
///
/// let id = NodeId::new(7);
/// assert_eq!(id.index(), 7);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Creates a node ID from an integer index.
    ///
    /// `const` so IDs can also be built in constant contexts
    /// (e.g., test fixtures or static tables).
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(42);
    /// assert_eq!(id.index(), 42);
    /// ```
    #[must_use]
    pub const fn new(index: usize) -> Self {
        Self(index)
    }

    /// Returns the numeric index.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(3);
    /// assert_eq!(id.index(), 3);
    /// ```
    #[must_use]
    pub const fn index(&self) -> usize {
        self.0
    }
}
234
/// Retention policy for tensors saved by reverse-mode rules.
///
/// Controls whether a rule keeps its forward intermediates alive for the
/// backward pass or drops them in favor of later recomputation.
///
/// # Examples
///
/// ```
/// use chainrules_core::SavePolicy;
///
/// let p = SavePolicy::SaveForPullback;
/// assert_eq!(p, SavePolicy::SaveForPullback);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SavePolicy {
    /// Keep forward tensors so exact pullback formulas can use them.
    SaveForPullback,
    /// Drop forward tensors now; recompute (checkpoint) them at pullback time.
    RecomputeOnPullback,
}
252
253/// Reverse-mode AD rule interface (rrule).
254///
255/// Implemented by operation-specific nodes (einsum, reduce, permute, ...).
256/// Named after Julia's ChainRules.jl convention: `rrule` returns a pullback.
257///
258/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
259///
260/// # Examples
261///
262/// Custom reverse rule for scalar multiplication `output = a * b`:
263///
264/// ```
265/// use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
266///
267/// struct ScalarMulRule {
268/// a: f64,
269/// b: f64,
270/// a_node: NodeId,
271/// b_node: NodeId,
272/// }
273///
274/// impl ReverseRule<f64> for ScalarMulRule {
275/// fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
276/// // d(a*b)/da = b, d(a*b)/db = a
277/// let da = cotangent * self.b;
278/// let db = cotangent * self.a;
279/// Ok(vec![(self.a_node, da), (self.b_node, db)])
280/// }
281///
282/// fn inputs(&self) -> Vec<NodeId> {
283/// vec![self.a_node, self.b_node]
284/// }
285/// }
286///
287/// // Verify: for a=3, b=5, cotangent=1 → da=5, db=3
288/// let rule = ScalarMulRule {
289/// a: 3.0, b: 5.0,
290/// a_node: NodeId::new(0), b_node: NodeId::new(1),
291/// };
292/// let grads = rule.pullback(&1.0).unwrap();
293/// assert_eq!(grads[0], (NodeId::new(0), 5.0)); // da = cotangent * b
294/// assert_eq!(grads[1], (NodeId::new(1), 3.0)); // db = cotangent * a
295/// ```
296pub trait ReverseRule<V: Differentiable>: Send + Sync {
297 /// Computes input cotangents from an output cotangent (pullback).
298 fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<PullbackEntry<V>>>;
299
300 /// Returns input node IDs this rule depends on.
301 fn inputs(&self) -> Vec<NodeId>;
302
303 /// Computes pullback with tangent propagation for HVP.
304 ///
305 /// Given an output cotangent and its tangent, returns
306 /// `(node_id, input_cotangent, input_cotangent_tangent)` triples.
307 ///
308 /// The default implementation returns [`AutodiffError::HvpNotSupported`].
309 /// Operations that support forward-over-reverse HVP override this method.
310 ///
311 /// # Examples
312 ///
313 /// ```ignore
314 /// // Called internally by hvp(); users rarely call this directly.
315 /// let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?;
316 /// for (node_id, grad, grad_tangent) in results {
317 /// // grad: standard cotangent for this input
318 /// // grad_tangent: cotangent tangent for HVP
319 /// }
320 /// ```
321 fn pullback_with_tangents(
322 &self,
323 cotangent: &V::Tangent,
324 cotangent_tangent: &V::Tangent,
325 ) -> AdResult<Vec<PullbackWithTangentsEntry<V>>> {
326 let _ = (cotangent, cotangent_tangent);
327 Err(AutodiffError::HvpNotSupported)
328 }
329}
330
331/// Forward-mode AD rule interface (frule).
332///
333/// Named after Julia's ChainRules.jl convention: `frule` computes pushforward.
334///
335/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
336///
337/// # Examples
338///
339/// Custom forward rule for scalar multiplication `output = a * b`:
340///
341/// ```
342/// use chainrules_core::{ForwardRule, Differentiable, AdResult};
343///
344/// struct ScalarMulFrule {
345/// a: f64,
346/// b: f64,
347/// }
348///
349/// impl ForwardRule<f64> for ScalarMulFrule {
350/// fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
351/// // d(a*b) = da*b + a*db
352/// let da = tangents.get(0).and_then(|t| *t).copied().unwrap_or(0.0);
353/// let db = tangents.get(1).and_then(|t| *t).copied().unwrap_or(0.0);
354/// Ok(da * self.b + self.a * db)
355/// }
356/// }
357///
358/// // Verify: for a=3, b=5, da=1, db=0 → d(a*b) = 1*5 + 3*0 = 5
359/// let rule = ScalarMulFrule { a: 3.0, b: 5.0 };
360/// let result = rule.pushforward(&[Some(&1.0), Some(&0.0)]).unwrap();
361/// assert_eq!(result, 5.0);
362///
363/// // Both tangents active: da=1, db=1 → d(a*b) = 1*5 + 3*1 = 8
364/// let result = rule.pushforward(&[Some(&1.0), Some(&1.0)]).unwrap();
365/// assert_eq!(result, 8.0);
366/// ```
367pub trait ForwardRule<V: Differentiable>: Send + Sync {
368 /// Computes output tangent from input tangents (pushforward).
369 fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
370}
371
// ============================================================================
// Differentiable impls for primitive types
// ============================================================================
375
// `f64` is its own tangent space: a perturbation of a scalar is a scalar.
impl Differentiable for f64 {
    type Tangent = f64;

    fn zero_tangent(&self) -> f64 {
        0.0
    }

    fn accumulate_tangent(a: f64, b: &f64) -> f64 {
        a + b
    }

    fn num_elements(&self) -> usize {
        // A scalar always holds exactly one element.
        1
    }

    fn seed_cotangent(&self) -> f64 {
        // d(loss)/d(loss) = 1 seeds the backward pass.
        1.0
    }
}
395
// `f32` mirrors the `f64` impl: the scalar is its own tangent space.
impl Differentiable for f32 {
    type Tangent = f32;

    fn zero_tangent(&self) -> f32 {
        0.0
    }

    fn accumulate_tangent(a: f32, b: &f32) -> f32 {
        a + b
    }

    fn num_elements(&self) -> usize {
        // A scalar always holds exactly one element.
        1
    }

    fn seed_cotangent(&self) -> f32 {
        // d(loss)/d(loss) = 1 seeds the backward pass.
        1.0
    }
}
415
416#[cfg(test)]
417mod tests;