chainrules_core/
lib.rs

//! Core AD trait definitions (like Julia's ChainRulesCore.jl).
//!
//! This crate defines the interface for automatic differentiation without
//! providing an AD engine. It contains:
//!
//! - [`Differentiable`] — tangent space definition for any value type
//! - [`ReverseRule`] — per-operation reverse-mode rule (rrule/pullback)
//! - [`ForwardRule`] — per-operation forward-mode rule (frule/pushforward)
//! - Error types ([`AutodiffError`], [`AdResult`])
//! - [`NodeId`], [`SavePolicy`] — graph node identifier and save strategy
//!
//! The AD engine (`TrackedTensor`, `DualTensor`, `pullback`, `hvp`) lives in
//! the [`chainrules`](https://docs.rs/chainrules) crate.
//!
//! Operation-specific AD rules (e.g., einsum rrule/frule) live in the crate
//! that defines the operation.
//!
//! # Examples
//!
//! Implementing `Differentiable` for a custom type:
//!
//! ```ignore
//! use chainrules_core::Differentiable;
//!
//! #[derive(Clone)]
//! struct MyVec(Vec<f64>);
//!
//! impl Differentiable for MyVec {
//!     type Tangent = MyVec;
//!     fn zero_tangent(&self) -> MyVec {
//!         MyVec(vec![0.0; self.0.len()])
//!     }
//!     fn accumulate_tangent(a: MyVec, b: &MyVec) -> MyVec {
//!         MyVec(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
//!     }
//! }
//! ```

/// Trait defining the tangent space for a differentiable type.
///
/// This is the core abstraction of the AD framework, analogous to Julia's
/// ChainRulesCore.jl tangent type system. Any type that participates in
/// automatic differentiation must implement this trait.
///
/// The tangent type represents infinitesimal perturbations of the value.
/// For most tensor types, `Tangent = Self` (e.g., the tangent of a matrix
/// is another matrix of the same shape).
///
/// Note: This trait intentionally does **not** require `Clone` on the primal
/// type. `Clone` is only required on `Tangent` (for gradient accumulation).
/// Large values (e.g., tensors) may be expensive to clone; the AD engine
/// avoids cloning primals by taking ownership where needed.
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::Differentiable;
///
/// // Tensor<f64> implements Differentiable with Tangent = Tensor<f64>
/// // (defined in tenferro-tensor crate)
/// fn example<V: Differentiable>(x: &V) {
///     let zero = x.zero_tangent();
///     let _acc = V::accumulate_tangent(zero.clone(), &x.zero_tangent());
/// }
/// ```
pub trait Differentiable {
    /// The tangent type for this value.
    ///
    /// For most types, this is `Self` (e.g., tangent of a tensor is a tensor).
    type Tangent: Clone;

    /// Returns the zero tangent for this value (additive identity).
    fn zero_tangent(&self) -> Self::Tangent;

    /// Accumulates (adds) two tangents: `a + b`.
    fn accumulate_tangent(a: Self::Tangent, b: &Self::Tangent) -> Self::Tangent;
}
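
// Illustrative sketch (not part of the original API): how an engine might
// fold several incoming gradients for one value into a single tangent via
// `accumulate_tangent`, starting from the zero tangent. The helper name
// `accumulate_all` is hypothetical.
#[allow(dead_code)]
fn accumulate_all<V: Differentiable>(value: &V, parts: &[V::Tangent]) -> V::Tangent {
    parts
        .iter()
        .fold(value.zero_tangent(), |acc, t| V::accumulate_tangent(acc, t))
}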

/// AD-specific error type.
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::AutodiffError;
///
/// let err = AutodiffError::NonScalarLoss { num_elements: 8 };
/// assert!(format!("{err}").contains("scalar"));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum AutodiffError {
    /// The loss tensor passed to `pullback` must contain exactly one element.
    #[error("pullback() requires scalar loss, got {num_elements} elements")]
    NonScalarLoss { num_elements: usize },
    /// Attempted pullback on a tensor not connected to the AD tape.
    #[error("tensor is not connected to AD tape")]
    MissingNode,
    /// Tangent shape must match primal shape.
    #[error("tangent shape mismatch: expected {expected}, got {got}")]
    TangentShapeMismatch {
        /// Expected shape description.
        expected: String,
        /// Actual shape description.
        got: String,
    },
    /// The [`ReverseRule`] implementation does not support HVP
    /// (Hessian-vector products) via `pullback_with_tangents`.
    #[error("HVP not supported by this ReverseRule implementation")]
    HvpNotSupported,
    /// Generic AD argument error.
    #[error("invalid autodiff argument: {0}")]
    InvalidArgument(String),
}

/// Result alias for AD APIs.
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::AdResult;
///
/// fn returns_ad_result() -> AdResult<()> { Ok(()) }
/// ```
pub type AdResult<T> = std::result::Result<T, AutodiffError>;

/// Stable identifier of an AD graph node.
///
/// # Examples
///
/// ```
/// use chainrules_core::NodeId;
///
/// let id = NodeId::new(7);
/// assert_eq!(id.index(), 7);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Creates a node ID from an integer index.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(42);
    /// assert_eq!(id.index(), 42);
    /// ```
    pub fn new(index: usize) -> Self {
        Self(index)
    }

    /// Returns the numeric index.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules_core::NodeId;
    ///
    /// let id = NodeId::new(3);
    /// assert_eq!(id.index(), 3);
    /// ```
    pub fn index(&self) -> usize {
        self.0
    }
}

/// Saved-tensor retention policy for reverse-mode rules.
///
/// # Examples
///
/// ```
/// use chainrules_core::SavePolicy;
///
/// let p = SavePolicy::SaveForPullback;
/// assert_eq!(p, SavePolicy::SaveForPullback);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SavePolicy {
    /// Keep forward tensors for exact pullback formulas.
    SaveForPullback,
    /// Discard forward tensors and require recomputation/checkpointing later.
    RecomputeOnPullback,
}
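
// A minimal sketch (the helper is hypothetical, not part of the original
// API) of how an AD engine might consult the policy when recording a node.
#[allow(dead_code)]
fn keeps_forward_tensors(policy: SavePolicy) -> bool {
    match policy {
        // Exact pullback formulas read the saved forward tensors.
        SavePolicy::SaveForPullback => true,
        // Checkpointing: drop them now, recompute during the pullback.
        SavePolicy::RecomputeOnPullback => false,
    }
}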

/// Reverse-mode AD rule interface (rrule).
///
/// Implemented by operation-specific nodes (einsum, reduce, permute, ...).
/// Named after Julia's ChainRules.jl convention: `rrule` returns a pullback.
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::{ReverseRule, Differentiable, AdResult, NodeId};
///
/// struct MyRule;
/// impl<V: Differentiable> ReverseRule<V> for MyRule {
///     fn pullback(&self, cotangent: &V::Tangent)
///         -> AdResult<Vec<(NodeId, V::Tangent)>> {
///         todo!()
///     }
///     fn inputs(&self) -> Vec<NodeId> { vec![] }
/// }
/// ```
pub trait ReverseRule<V: Differentiable> {
    /// Computes input cotangents from an output cotangent (pullback).
    fn pullback(&self, cotangent: &V::Tangent) -> AdResult<Vec<(NodeId, V::Tangent)>>;

    /// Returns input node IDs this rule depends on.
    fn inputs(&self) -> Vec<NodeId>;

    /// Computes the pullback with tangent propagation for HVP.
    ///
    /// Given an output cotangent and its tangent, returns
    /// `(node_id, input_cotangent, input_cotangent_tangent)` triples.
    ///
    /// The default implementation returns [`AutodiffError::HvpNotSupported`].
    /// Operations that support forward-over-reverse HVP override this method.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Called internally by hvp(); users rarely call this directly.
    /// let results = rule.pullback_with_tangents(&cotangent, &cotangent_tangent)?;
    /// for (node_id, grad, grad_tangent) in results {
    ///     // grad: standard cotangent for this input
    ///     // grad_tangent: cotangent tangent for HVP
    /// }
    /// ```
    fn pullback_with_tangents(
        &self,
        cotangent: &V::Tangent,
        cotangent_tangent: &V::Tangent,
    ) -> AdResult<Vec<(NodeId, V::Tangent, V::Tangent)>> {
        let _ = (cotangent, cotangent_tangent);
        Err(AutodiffError::HvpNotSupported)
    }
}
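
// Illustrative toy rule (a sketch; `ScaleRule` is hypothetical and not part
// of the original API): for y = c * x over `f64`, which implements
// `Differentiable` below, the pullback is x̄ = c * ȳ.
#[allow(dead_code)]
struct ScaleRule {
    /// Node ID of the input x.
    input: NodeId,
    /// The scale factor c, captured during the forward pass.
    c: f64,
}

impl ReverseRule<f64> for ScaleRule {
    fn pullback(&self, cotangent: &f64) -> AdResult<Vec<(NodeId, f64)>> {
        // Chain rule: the input cotangent is the output cotangent times dy/dx = c.
        Ok(vec![(self.input, self.c * *cotangent)])
    }

    fn inputs(&self) -> Vec<NodeId> {
        vec![self.input]
    }
}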

/// Forward-mode AD rule interface (frule).
///
/// Named after Julia's ChainRules.jl convention: `frule` computes the pushforward.
///
/// The type parameter `V` is the differentiable value type (e.g., `Tensor<f64>`).
///
/// # Examples
///
/// ```ignore
/// use chainrules_core::{ForwardRule, Differentiable, AdResult};
///
/// struct MyFrule;
/// impl<V: Differentiable> ForwardRule<V> for MyFrule {
///     fn pushforward(&self, tangents: &[Option<&V::Tangent>])
///         -> AdResult<V::Tangent> {
///         todo!()
///     }
/// }
/// ```
pub trait ForwardRule<V: Differentiable> {
    /// Computes output tangent from input tangents (pushforward).
    fn pushforward(&self, tangents: &[Option<&V::Tangent>]) -> AdResult<V::Tangent>;
}
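
// Forward-mode counterpart of the `ScaleRule` sketch above (also
// hypothetical): for y = c * x the pushforward is ẏ = c * ẋ. A missing
// input tangent (`None`) is treated as zero.
#[allow(dead_code)]
struct ScaleFrule {
    /// The scale factor c.
    c: f64,
}

impl ForwardRule<f64> for ScaleFrule {
    fn pushforward(&self, tangents: &[Option<&f64>]) -> AdResult<f64> {
        // Single input x; absent tangents contribute nothing.
        let x_dot = tangents.first().and_then(|t| *t).copied().unwrap_or(0.0);
        Ok(self.c * x_dot)
    }
}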

// ============================================================================
// Differentiable impls for primitive types
// ============================================================================

impl Differentiable for f64 {
    type Tangent = f64;

    fn zero_tangent(&self) -> f64 {
        0.0
    }

    fn accumulate_tangent(a: f64, b: &f64) -> f64 {
        a + b
    }
}

impl Differentiable for f32 {
    type Tangent = f32;

    fn zero_tangent(&self) -> f32 {
        0.0
    }

    fn accumulate_tangent(a: f32, b: &f32) -> f32 {
        a + b
    }
}
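
// Minimal sanity checks for the primitive impls above (a sketch; the
// original crate may organize its tests differently).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_tangents_accumulate() {
        let x = 3.0_f64;
        // The zero tangent is the additive identity.
        assert_eq!(x.zero_tangent(), 0.0);
        // For scalars, accumulation is plain addition.
        assert_eq!(f64::accumulate_tangent(1.5, &2.5), 4.0);
        assert_eq!(f32::accumulate_tangent(1.0_f32, &0.5), 1.5);
    }
}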