// chainrules_core/lib.rs
1//! Core AD trait definitions (like Julia's ChainRulesCore.jl).
2//!
3//! This crate defines the interface for automatic differentiation without
4//! providing an AD engine. It contains:
5//!
6//! - [`Differentiable`] — tangent space definition for any value type
7//! - [`ReverseRule`] — per-operation reverse-mode rule (rrule/pullback)
8//! - [`ForwardRule`] — per-operation forward-mode rule (frule/pushforward)
9//! - Error types ([`AutodiffError`], [`AdResult`])
10//! - [`NodeId`], [`SavePolicy`] — graph node identifier and save strategy
11//!
12//! The AD engine (`TrackedTensor`, `DualTensor`, `pullback`, `hvp`) lives in
13//! the [`chainrules`](https://docs.rs/chainrules) crate.
14//!
15//! Operation-specific AD rules (e.g., einsum rrule/frule) live in the crate
16//! that defines the operation.
17//!
18//! # Examples
19//!
20//! Implementing `Differentiable` for a custom type:
21//!
22//! ```ignore
23//! use chainrules_core::Differentiable;
24//!
25//! #[derive(Clone)]
26//! struct MyVec(Vec<f64>);
27//!
28//! impl Differentiable for MyVec {
29//! type Tangent = MyVec;
30//! fn zero_tangent(&self) -> MyVec {
31//! MyVec(vec![0.0; self.0.len()])
32//! }
33//! fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
34//! MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
35//! }
36//! }
37//! ```
38
/// Defines the tangent space of a value type for automatic differentiation.
///
/// This is the core abstraction of the AD framework, analogous to the
/// tangent type system of Julia's ChainRulesCore.jl. Every type that flows
/// through AD must describe what an infinitesimal perturbation of its
/// values looks like. For most tensor types the tangent is the type itself
/// (the tangent of a matrix is another matrix of the same shape).
///
/// Note that `Clone` is deliberately required only on `Tangent` (tangents
/// are duplicated during gradient accumulation), **not** on the primal
/// type: large primals such as tensors may be expensive to clone, so the
/// AD engine moves them by ownership instead.
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::Differentiable;
///
/// // Tensor<f64> implements Differentiable with Tangent = Tensor<f64>
/// // (defined in tenferro-tensor crate)
/// fn demo<V: Differentiable>(value: &V) {
///     let zero = value.zero_tangent();
///     let _sum = V::accumulate_tangent(zero.clone(), &value.zero_tangent());
/// }
/// ```
pub trait Differentiable {
    /// Type representing infinitesimal perturbations of `Self`.
    ///
    /// Most tensor-like types use `Tangent = Self`.
    type Tangent: Clone;

    /// Produces the additive identity of the tangent space, matching `self`.
    fn zero_tangent(&self) -> Self::Tangent;

    /// Adds two tangents (`a + b`); used to accumulate gradients that flow
    /// into the same node.
    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent;
}
78
/// AD-specific error type.
///
/// Error type used by the fallible AD rule APIs in this crate (see
/// [`AdResult`]). `Display` messages come from `thiserror`'s `#[error]`
/// attributes on each variant.
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::AutodiffError;
///
/// let err = AutodiffError::NonScalarLoss { num_elements: 8 };
/// assert!(format!("{err}").contains("scalar"));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum AutodiffError {
    /// Loss tensor for pullback must contain exactly one element.
    #[error("pullback() requires scalar loss, got {num_elements} elements")]
    NonScalarLoss {
        /// Number of elements the non-scalar loss tensor actually had.
        num_elements: usize,
    },
    /// Attempted pullback on a tensor not connected to AD tape.
    #[error("tensor is not connected to AD tape")]
    MissingNode,
    /// Tangent shape must match primal shape.
    #[error("tangent shape mismatch: expected {expected}, got {got}")]
    TangentShapeMismatch {
        /// Expected shape description.
        expected: String,
        /// Actual shape description.
        got: String,
    },
    /// A ReverseRule does not support HVP (pullback_with_tangents).
    ///
    /// Returned by the default `ReverseRule::pullback_with_tangents`
    /// implementation.
    #[error("HVP not supported by this ReverseRule implementation")]
    HvpNotSupported,
    /// Generic AD argument error.
    #[error("invalid autodiff argument: {0}")]
    InvalidArgument(String),
}
112
113/// Result alias for AD APIs.
114///
115/// # Examples
116///
117/// ```ignore
118/// use chainrules_core::AdResult;
119///
120/// fn returns_ad_result() -> AdResult<()> { Ok(()) }
121/// ```
122pub type AdResult<T> = std::result::Result<T, AutodiffError>;
123
/// Stable identifier of an AD graph node.
///
/// Wraps a `usize` index. IDs are `Copy`, comparable, and hashable, so
/// they can key gradient-accumulation maps cheaply.
///
/// # Examples
///
/// ```
/// use chainrules_core::NodeId;
///
/// let id = NodeId::new(7);
/// assert_eq!(id.index(), 7);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Creates a node ID from an integer index.
    ///
    /// This is a `const fn`, so node IDs can also be built in `const`
    /// contexts.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// const ID: NodeId = NodeId::new(42);
    /// assert_eq!(ID.index(), 42);
    /// ```
    #[must_use]
    pub const fn new(index: usize) -> Self {
        Self(index)
    }

    /// Returns the numeric index.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(3);
    /// assert_eq!(id.index(), 3);
    /// ```
    #[must_use]
    pub const fn index(&self) -> usize {
        self.0
    }
}
166
/// Retention policy for tensors saved during the forward pass of a
/// reverse-mode rule.
///
/// # Examples
///
/// ```
/// use chainrules_core::SavePolicy;
///
/// let policy = SavePolicy::SaveForPullback;
/// assert_eq!(policy, SavePolicy::SaveForPullback);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SavePolicy {
    /// Retain forward-pass tensors so the pullback can use exact formulas.
    SaveForPullback,
    /// Drop forward-pass tensors; the pullback must recompute them later
    /// (checkpointing).
    RecomputeOnPullback,
}
184
185/// Reverse-mode AD rule interface (rrule).
186///
187/// Implemented by operation-specific nodes (einsum, reduce, permute, ...).
188/// Named after Julia's ChainRules.jl convention: `rrule` returns a pullback.
189///
190/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
191///
192/// # Examples
193///
194/// ```ignore
195/// use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
196///
197/// struct MyRule;
198/// impl<V: Differentiable> ReverseRule<V> for MyRule {
199/// fn pullback(&self, cotangent: &V::Tangent)
200/// -> AdResult<Vec<(NodeId, V::Tangent)>> {
201/// todo!()
202/// }
203/// fn inputs(&self) -> Vec<NodeId> { vec![] }
204/// }
205/// ```
206pub trait ReverseRule<V: Differentiable> {
207 /// Computes input cotangents from an output cotangent (pullback).
208 fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<(NodeId, V::Tangent)>>;
209
210 /// Returns input node IDs this rule depends on.
211 fn inputs(&self) -> Vec<NodeId>;
212
213 /// Computes pullback with tangent propagation for HVP.
214 ///
215 /// Given an output cotangent and its tangent, returns
216 /// `(node_id, input_cotangent, input_cotangent_tangent)` triples.
217 ///
218 /// The default implementation returns [`AutodiffError::HvpNotSupported`].
219 /// Operations that support forward-over-reverse HVP override this method.
220 ///
221 /// # Examples
222 ///
223 /// ```ignore
224 /// // Called internally by hvp(); users rarely call this directly.
225 /// let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?;
226 /// for (node_id, grad, grad_tangent) in results {
227 /// // grad: standard cotangent for this input
228 /// // grad_tangent: cotangent tangent for HVP
229 /// }
230 /// ```
231 fn pullback_with_tangents(
232 &self,
233 cotangent: &V::Tangent,
234 cotangent_tangent: &V::Tangent,
235 ) -> AdResult<Vec<(NodeId, V::Tangent, V::Tangent)>> {
236 let _ = (cotangent, cotangent_tangent);
237 Err(AutodiffError::HvpNotSupported)
238 }
239}
240
241/// Forward-mode AD rule interface (frule).
242///
243/// Named after Julia's ChainRules.jl convention: `frule` computes pushforward.
244///
245/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
246///
247/// # Examples
248///
249/// ```ignore
250/// use chainrules_core::{ForwardRule, Differentiable, AdResult};
251///
252/// struct MyFrule;
253/// impl<V: Differentiable> ForwardRule<V> for MyFrule {
254/// fn pushforward(&self, tangents: &[Option<&V::Tangent>])
255/// -> AdResult<V::Tangent> {
256/// todo!()
257/// }
258/// }
259/// ```
260pub trait ForwardRule<V: Differentiable> {
261 /// Computes output tangent from input tangents (pushforward).
262 fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
263}
264
265// ============================================================================
266// Differentiable impls for primitive types
267// ============================================================================
268
269impl Differentiable for f64 {
270 type Tangent = f64;
271
272 fn zero_tangent(&self) -> f64 {
273 0.0
274 }
275
276 fn accumulate_tangent(a: f64, b: &f64) -> f64 {
277 a + b
278 }
279}
280
281impl Differentiable for f32 {
282 type Tangent = f32;
283
284 fn zero_tangent(&self) -> f32 {
285 0.0
286 }
287
288 fn accumulate_tangent(a: f32, b: &f32) -> f32 {
289 a + b
290 }
291}