chainrules/
lib.rs

//! AD engine: tape-based reverse-mode and dual-number forward-mode.
//!
//! This crate provides the AD execution engine, built on top of
//! [`chainrules_core`] traits. It is analogous to Zygote.jl in the Julia
//! ecosystem: a concrete AD engine that uses ChainRulesCore.jl interfaces.
//!
//! - Reverse-mode AD via [`Tape`], [`TrackedTensor`], and [`Tape::pullback`]
//! - Forward-mode AD via [`DualTensor`]
//! - Forward-over-reverse HVP via [`Tape::hvp`]
//!
//! Operation-specific AD rules (e.g., einsum rrule/frule) live in the crate
//! that defines the operation. See `tenferro-einsum` for einsum AD functions.
//!
//! Bodies are intentionally `todo!()` in the current POC phase.
//!
//! # Examples
//!
//! Reverse-mode usage (with operation-specific AD functions from other crates):
//!
//! ```ignore
//! use chainrules::{Tape, TrackedTensor};
//! use tenferro_einsum::tracked_einsum;
//! use tenferro_tensor::{MemoryOrder, Tensor};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let tape = Tape::<Tensor<f64>>::new();
//! let a = tape.leaf(Tensor::ones(
//!     &[2, 3],
//!     LogicalMemorySpace::MainMemory,
//!     MemoryOrder::ColumnMajor,
//! ));
//! let b = tape.leaf(Tensor::ones(
//!     &[3, 4],
//!     LogicalMemorySpace::MainMemory,
//!     MemoryOrder::ColumnMajor,
//! ));
//! let c = tracked_einsum("ij,jk->ik", &[&a, &b]).unwrap();
//! let loss = tracked_einsum("ij,ij->", &[&c, &c]).unwrap();
//! let grads = tape.pullback(&loss).unwrap();
//! let _ga = grads.get(a.node_id().unwrap()).unwrap();
//! ```
//!
//! Forward-mode usage:
//!
//! ```ignore
//! use chainrules::DualTensor;
//! use tenferro_einsum::dual_einsum;
//! use tenferro_tensor::{MemoryOrder, Tensor};
//!
//! let a = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor).unwrap();
//! let da = Tensor::<f64>::ones(&[2, 2], tenferro_device::LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
//! let b = Tensor::<f64>::ones(&[2, 2], tenferro_device::LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
//!
//! let a_dual = DualTensor::with_tangent(a, da).unwrap();
//! let b_dual = DualTensor::new(b);
//! let c_dual = dual_einsum("ij,jk->ik", &[&a_dual, &b_dual]).unwrap();
//! let _jvp = c_dual.tangent();
//! ```
//!
//! Forward-over-reverse HVP (Hessian-vector product):
//!
//! ```ignore
//! use chainrules::Tape;
//! use tenferro_einsum::tracked_einsum;
//! use tenferro_tensor::{MemoryOrder, Tensor};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let tape = Tape::<Tensor<f64>>::new();
//! let x = tape.leaf_with_tangent(
//!     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
//!     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),  // direction v
//! ).unwrap();
//! let loss = tracked_einsum("i,i->", &[&x, &x]).unwrap();  // f(x) = x·x
//! let result = tape.hvp(&loss).unwrap();
//! let _grad = result.gradients;  // ∇f(x) = 2x
//! let _hv = result.hvp;          // H·v = 2v
//! ```

// Re-export all core traits so downstream crates can depend on just `chainrules`.
pub use chainrules_core::*;

use std::marker::PhantomData;

/// Reverse-mode AD tape.
///
/// The tape records operations performed on [`TrackedTensor`] values and
/// enables gradient computation via [`Tape::pullback`] or HVP via
/// [`Tape::hvp`].
///
/// Create leaf values with [`Tape::leaf`], perform operations using
/// AD-aware functions (e.g., `tracked_einsum`), then call
/// [`Tape::pullback`] on the scalar loss to compute gradients.
///
/// `Tape` is cheaply cloneable (internally reference-counted). Multiple
/// clones refer to the same underlying tape.
///
/// # Examples
///
/// ```ignore
/// use chainrules::Tape;
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let a = tape.leaf(Tensor::ones(
///     &[2, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let b = tape.leaf(Tensor::ones(
///     &[3, 4],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let c = tracked_einsum("ij,jk->ik", &[&a, &b]).unwrap();
/// let loss = tracked_einsum("ij,ij->", &[&c, &c]).unwrap();
/// let grads = tape.pullback(&loss).unwrap();
/// let _ga = grads.get(a.node_id().unwrap()).unwrap();
/// ```
pub struct Tape<V: Differentiable> {
    _marker: PhantomData<V>,
}

impl<V: Differentiable> Tape<V> {
    /// Creates a new empty tape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// ```
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }

    /// Creates a leaf value on this tape that requires a gradient.
    ///
    /// The returned [`TrackedTensor`] is connected to this tape and
    /// will participate in gradient computation via [`Tape::pullback`].
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// let x = tape.leaf(3.14);
    /// assert!(x.requires_grad());
    /// ```
    pub fn leaf(&self, _value: V) -> TrackedTensor<V> {
        todo!()
    }

    /// Creates a leaf value with a tangent for HVP computation.
    ///
    /// The tangent defines the perturbation direction *v* used in
    /// forward-over-reverse Hessian-vector product computation.
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::TangentShapeMismatch`] if the tangent's shape
    /// does not match the value's shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// let x = tape.leaf_with_tangent(3.14, 1.0).unwrap();
    /// assert!(x.requires_grad());
    /// assert!(x.has_tangent());
    /// ```
    pub fn leaf_with_tangent(&self, _value: V, _tangent: V::Tangent) -> AdResult<TrackedTensor<V>> {
        todo!()
    }

    /// Runs reverse-mode pullback from a scalar loss value.
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::NonScalarLoss`] for non-scalar losses.
    /// Returns [`AutodiffError::MissingNode`] if the loss is not connected
    /// to this tape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// let x = tape.leaf(2.0);
    /// // ... compute loss from x ...
    /// let grads = tape.pullback(&x).unwrap();
    /// ```
    pub fn pullback(&self, _loss: &TrackedTensor<V>) -> AdResult<Gradients<V>> {
        todo!()
    }

    /// Computes gradient and Hessian-vector product via forward-over-reverse.
    ///
    /// Leaf values with tangents (created via [`Tape::leaf_with_tangent`])
    /// define the direction *v*. The function runs pullback through the tape,
    /// propagating both cotangents and cotangent-tangents at each node.
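    ///
    /// In symbols, for a scalar loss `f` at point `x` and direction `v`, the
    /// result pairs the gradient `∇f(x)` with the directional derivative of
    /// that gradient, `H·v = d/dε ∇f(x + ε·v) |_{ε=0}`.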
    ///
    /// Returns both the gradient (in [`HvpResult::gradients`]) and H*v (in
    /// [`HvpResult::hvp`]).
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::NonScalarLoss`] for non-scalar losses.
    /// Returns [`AutodiffError::HvpNotSupported`] if any `ReverseRule` on the tape
    /// does not implement `pullback_with_tangents`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    /// use tenferro_einsum::tracked_einsum;
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    /// use tenferro_device::LogicalMemorySpace;
    ///
    /// let tape = Tape::<Tensor<f64>>::new();
    /// let x = tape.leaf_with_tangent(
    ///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
    ///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
    /// ).unwrap();
    /// let loss = tracked_einsum("i,i->", &[&x, &x]).unwrap();
    /// let result = tape.hvp(&loss).unwrap();
    /// let _grad = result.gradients;
    /// let _hv = result.hvp;
    /// ```
    pub fn hvp(&self, _loss: &TrackedTensor<V>) -> AdResult<HvpResult<V>> {
        todo!()
    }
}

impl<V: Differentiable> Clone for Tape<V> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<V: Differentiable> Default for Tape<V> {
    fn default() -> Self {
        Self::new()
    }
}

/// Value wrapper for reverse-mode AD.
///
/// Wraps any [`Differentiable`] value and connects it to a [`Tape`]
/// for gradient computation.
///
/// Created via [`Tape::leaf`] for gradient-tracked values, or
/// [`TrackedTensor::new`] for values that do not require gradients.
///
/// # Examples
///
/// ```ignore
/// use chainrules::{Tape, TrackedTensor};
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let a = tape.leaf(Tensor::ones(
///     &[2, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// assert!(a.requires_grad());
/// ```
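///
/// Mixing a tape leaf with an untracked constant (a sketch that reuses `a`
/// and the tensor setup above, and assumes AD-aware operations such as
/// `tracked_einsum` accept constants built with [`TrackedTensor::new`]
/// alongside tape leaves):
///
/// ```ignore
/// use tenferro_einsum::tracked_einsum;
///
/// // `w` is a plain constant: no tape, no node id, no gradient.
/// let w = TrackedTensor::new(Tensor::<f64>::ones(
///     &[3, 4],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// assert!(!w.requires_grad());
///
/// // Only `a` receives a gradient in a later pullback.
/// let _c = tracked_einsum("ij,jk->ik", &[&a, &w]).unwrap();
/// ```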
pub struct TrackedTensor<V: Differentiable> {
    value: V,
    node_id: Option<NodeId>,
    tape: Option<Tape<V>>,
    requires_grad: bool,
    tangent: Option<V::Tangent>,
}

impl<V: Differentiable> TrackedTensor<V> {
    /// Creates a tracked value with `requires_grad = false` (no tape).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::TrackedTensor;
    /// let x = TrackedTensor::new(value);
    /// assert!(!x.requires_grad());
    /// ```
    pub fn new(value: V) -> Self {
        Self {
            value,
            node_id: None,
            tape: None,
            requires_grad: false,
            tangent: None,
        }
    }

    /// Returns the underlying value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let v = tracked.value();
    /// ```
    pub fn value(&self) -> &V {
        &self.value
    }

    /// Consumes and returns the underlying value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let raw = tracked.into_value();
    /// ```
    pub fn into_value(self) -> V {
        self.value
    }

    /// Returns whether this value participates in gradient propagation.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert!(tracked.requires_grad());
    /// ```
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns the graph node ID when this value is connected to a tape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// if let Some(id) = tracked.node_id() {
    ///     println!("node = {}", id.index());
    /// }
    /// ```
    pub fn node_id(&self) -> Option<NodeId> {
        self.node_id
    }

    /// Returns the tangent for HVP, or `None` if not set.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// if let Some(t) = tracked.tangent() {
    ///     // use tangent
    /// }
    /// ```
    pub fn tangent(&self) -> Option<&V::Tangent> {
        self.tangent.as_ref()
    }

    /// Returns whether this tracked value has a tangent for HVP.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert!(tracked.has_tangent());
    /// ```
    pub fn has_tangent(&self) -> bool {
        self.tangent.is_some()
    }

    /// Consumes and returns a detached value that does not require gradients.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let detached = tracked.detach();
    /// assert!(!detached.requires_grad());
    /// ```
    pub fn detach(self) -> Self {
        todo!()
    }
}

/// Value wrapper for forward-mode AD.
///
/// # Examples
///
/// ```ignore
/// use chainrules::DualTensor;
/// let dual = DualTensor::new(primal);
/// assert!(!dual.has_tangent());
/// ```
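///
/// Pushing a tangent through an operation to obtain a JVP (a sketch using
/// `dual_einsum` from `tenferro-einsum`, with `a`, `da`, and `b` constructed
/// as in the crate-level forward-mode example):
///
/// ```ignore
/// use chainrules::DualTensor;
/// use tenferro_einsum::dual_einsum;
///
/// // Seed `a` with tangent `da`; `b` carries a zero tangent.
/// let a_dual = DualTensor::with_tangent(a, da).unwrap();
/// let b_dual = DualTensor::new(b);
///
/// // The output tangent is the JVP of the einsum at (a, b) in direction (da, 0).
/// let c_dual = dual_einsum("ij,jk->ik", &[&a_dual, &b_dual]).unwrap();
/// let _jvp = c_dual.tangent();
/// ```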
pub struct DualTensor<V: Differentiable> {
    primal: V,
    tangent: Option<V::Tangent>,
}

impl<V: Differentiable> DualTensor<V> {
    /// Creates a dual value with zero tangent.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::DualTensor;
    /// let x = DualTensor::new(primal);
    /// ```
    pub fn new(primal: V) -> Self {
        Self {
            primal,
            tangent: None,
        }
    }

    /// Creates a dual value with explicit tangent.
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::TangentShapeMismatch`] if the tangent's shape
    /// does not match the primal's shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::DualTensor;
    /// let x = DualTensor::with_tangent(primal, tangent).unwrap();
    /// ```
    pub fn with_tangent(_primal: V, _tangent: V::Tangent) -> AdResult<Self> {
        todo!()
    }

    /// Returns the primal value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let p = dual.primal();
    /// ```
    pub fn primal(&self) -> &V {
        &self.primal
    }

    /// Returns the tangent, or `None` for zero tangent.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let maybe_t = dual.tangent();
    /// ```
    pub fn tangent(&self) -> Option<&V::Tangent> {
        self.tangent.as_ref()
    }

    /// Returns whether this dual value has a non-zero tangent.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert!(dual.has_tangent());
    /// ```
    pub fn has_tangent(&self) -> bool {
        self.tangent.is_some()
    }

    /// Consumes and returns `(primal, tangent)`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let (p, t) = dual.into_parts();
    /// ```
    pub fn into_parts(self) -> (V, Option<V::Tangent>) {
        (self.primal, self.tangent)
    }

    /// Consumes and returns a dual value with tangent removed.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let c = dual.detach_tangent();
    /// assert!(!c.has_tangent());
    /// ```
    pub fn detach_tangent(self) -> Self {
        todo!()
    }
}

/// Accumulated gradients indexed by [`NodeId`].
///
/// # Examples
///
/// ```ignore
/// use chainrules::{Gradients, Differentiable};
/// // V::Tangent is the gradient type
/// let mut grads = Gradients::<MyType>::new();
/// ```
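///
/// Reading gradients produced by [`Tape::pullback`] (a sketch that assumes
/// the `tape`, `a`, and `loss` from the [`Tape`] example):
///
/// ```ignore
/// let grads = tape.pullback(&loss).unwrap();
///
/// // Look up a single leaf's gradient by its node id...
/// let _ga = grads.get(a.node_id().unwrap()).unwrap();
///
/// // ...or walk every recorded `(node, gradient)` pair.
/// for (node, _grad) in grads.entries() {
///     println!("gradient recorded for node {}", node.index());
/// }
/// ```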
pub struct Gradients<V: Differentiable> {
    entries: Vec<(NodeId, V::Tangent)>,
}

impl<V: Differentiable> Gradients<V> {
    /// Creates an empty gradient container.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Gradients;
    /// let grads = Gradients::<MyType>::new();
    /// ```
    pub fn new() -> Self {
        Self { entries: vec![] }
    }

    /// Returns the gradient for `node`, if present.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// if let Some(g) = grads.get(node) {
    ///     // use gradient
    /// }
    /// ```
    pub fn get(&self, _node: NodeId) -> Option<&V::Tangent> {
        todo!()
    }

    /// Inserts a gradient for `node`, or accumulates it into the existing
    /// entry if one is already present.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// grads.accumulate(node, grad);
    /// ```
    pub fn accumulate(&mut self, _node: NodeId, _grad: V::Tangent) -> AdResult<()> {
        todo!()
    }

    /// Returns all `(node, grad)` entries.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// for (node, grad) in grads.entries() {
    ///     println!("{}", node.index());
    /// }
    /// ```
    pub fn entries(&self) -> &[(NodeId, V::Tangent)] {
        &self.entries
    }
}

impl<V: Differentiable> Default for Gradients<V> {
    fn default() -> Self {
        Self::new()
    }
}

/// Compiled pullback execution plan.
///
/// # Examples
///
/// ```ignore
/// let plan = chainrules::PullbackPlan::<MyType>::build(&loss).unwrap();
/// ```
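///
/// Building a plan once and executing it repeatedly (a sketch that assumes
/// the `tape`, `a`, and `loss` from the [`Tape`] example, and that a built
/// plan may be reused across executions):
///
/// ```ignore
/// use chainrules::PullbackPlan;
///
/// // Build the plan once from the recorded loss node...
/// let plan = PullbackPlan::build(&loss).unwrap();
/// let _loss_node = plan.loss_node();
///
/// // ...then execute it whenever fresh gradients are needed.
/// for _ in 0..3 {
///     let grads = plan.execute(&loss).unwrap();
///     let _ga = grads.get(a.node_id().unwrap());
/// }
/// ```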
#[derive(Debug, Clone)]
pub struct PullbackPlan<V: Differentiable> {
    loss: NodeId,
    _marker: PhantomData<V>,
}

impl<V: Differentiable> PullbackPlan<V> {
    /// Builds a pullback plan from a loss value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let plan = chainrules::PullbackPlan::build(&loss).unwrap();
    /// ```
    pub fn build(_loss: &TrackedTensor<V>) -> AdResult<Self> {
        todo!()
    }

    /// Executes the pre-built pullback plan.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let grads = plan.execute(&loss).unwrap();
    /// ```
    pub fn execute(&self, _loss: &TrackedTensor<V>) -> AdResult<Gradients<V>> {
        todo!()
    }

    /// Returns the loss node ID for this plan.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules::{PullbackPlan, NodeId};
    /// let _id_fn: fn(&PullbackPlan<f64>) -> NodeId = PullbackPlan::loss_node;
    /// ```
    pub fn loss_node(&self) -> NodeId {
        self.loss
    }
}

/// Result of a forward-over-reverse HVP computation.
///
/// Contains both the standard gradient and the Hessian-vector
/// product H*v, where v is the tangent direction set on leaf values
/// via [`Tape::leaf_with_tangent`].
///
/// # Examples
///
/// ```ignore
/// use chainrules::{Tape, HvpResult};
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let x = tape.leaf_with_tangent(
///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
/// ).unwrap();
/// let loss = tracked_einsum("i,i->", &[&x, &x]).unwrap();
/// let result: HvpResult<Tensor<f64>> = tape.hvp(&loss).unwrap();
/// let _grad = result.gradients.get(x.node_id().unwrap());
/// let _hv = result.hvp.get(x.node_id().unwrap());
/// ```
pub struct HvpResult<V: Differentiable> {
    /// Gradients.
    pub gradients: Gradients<V>,
    /// Hessian-vector product: H*v.
    pub hvp: Gradients<V>,
}