chainrules/lib.rs
//! AD engine: tape-based reverse-mode and dual-number forward-mode.
//!
//! This crate provides the AD execution engine, built on top of the
//! [`chainrules_core`] traits. It is analogous to Zygote.jl in the Julia
//! ecosystem: a concrete AD engine that consumes ChainRulesCore.jl-style
//! rule interfaces.
//!
//! - Reverse-mode AD via [`Tape`], [`TrackedTensor`], and [`Tape::pullback`]
//! - Forward-mode AD via [`DualTensor`]
//! - Forward-over-reverse HVP via [`Tape::hvp`]
//!
//! Operation-specific AD rules (e.g., the einsum rrule/frule) live in the
//! crate that defines the operation. See `tenferro-einsum` for the einsum AD
//! functions.
//!
//! Bodies are intentionally `todo!()` in the current POC phase.
//!
//! # Examples
//!
//! Reverse-mode usage (with operation-specific AD functions from other crates):
//!
//! ```ignore
//! use chainrules::{Tape, TrackedTensor};
//! use tenferro_einsum::tracked_einsum;
//! use tenferro_tensor::{MemoryOrder, Tensor};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let tape = Tape::<Tensor<f64>>::new();
//! let a = tape.leaf(Tensor::ones(
//!     &[2, 3],
//!     LogicalMemorySpace::MainMemory,
//!     MemoryOrder::ColumnMajor,
//! ));
//! let b = tape.leaf(Tensor::ones(
//!     &[3, 4],
//!     LogicalMemorySpace::MainMemory,
//!     MemoryOrder::ColumnMajor,
//! ));
//! let c = tracked_einsum("ij,jk->ik", &[&a, &b]).unwrap();
//! let loss = tracked_einsum("ij,ij->", &[&c, &c]).unwrap();
//! let grads = tape.pullback(&loss).unwrap();
//! let _ga = grads.get(a.node_id().unwrap()).unwrap();
//! ```
//!
//! Forward-mode usage:
//!
//! ```ignore
//! use chainrules::DualTensor;
//! use tenferro_einsum::dual_einsum;
//! use tenferro_tensor::{MemoryOrder, Tensor};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let a = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor)
//!     .unwrap();
//! let da = Tensor::<f64>::ones(&[2, 2], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
//! let b = Tensor::<f64>::ones(&[2, 2], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
//!
//! let a_dual = DualTensor::with_tangent(a, da).unwrap();
//! let b_dual = DualTensor::new(b);
//! let c_dual = dual_einsum("ij,jk->ik", &[&a_dual, &b_dual]).unwrap();
//! let _jvp = c_dual.tangent();
//! ```
//!
//! Forward-over-reverse HVP (Hessian-vector product):
//!
//! ```ignore
//! use chainrules::Tape;
//! use tenferro_einsum::tracked_einsum;
//! use tenferro_tensor::{MemoryOrder, Tensor};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let tape = Tape::<Tensor<f64>>::new();
//! let x = tape.leaf_with_tangent(
//!     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
//!     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor), // direction v
//! ).unwrap();
//! let loss = tracked_einsum("i,i->", &[&x, &x]).unwrap(); // f(x) = x·x
//! let result = tape.hvp(&loss).unwrap();
//! let _grad = result.gradients; // ∇f(x) = 2x
//! let _hv = result.hvp; // H·v = 2v
//! ```

// Re-export all core traits so downstream can depend on just `chainrules`.
pub use chainrules_core::*;

use std::marker::PhantomData;

/// Reverse-mode AD tape.
///
/// The tape records operations performed on [`TrackedTensor`] values and
/// enables gradient computation via [`Tape::pullback`] or HVP via
/// [`Tape::hvp`].
///
/// Create leaf values with [`Tape::leaf`], perform operations using
/// AD-aware functions (e.g., `tracked_einsum`), then call
/// [`Tape::pullback`] on the scalar loss to compute gradients.
///
/// `Tape` is cheaply cloneable (internally reference-counted). Multiple
/// clones refer to the same underlying tape.
///
/// # Examples
///
/// ```ignore
/// use chainrules::Tape;
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let a = tape.leaf(Tensor::ones(
///     &[2, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let b = tape.leaf(Tensor::ones(
///     &[3, 4],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let c = tracked_einsum("ij,jk->ik", &[&a, &b]).unwrap();
/// let loss = tracked_einsum("ij,ij->", &[&c, &c]).unwrap();
/// let grads = tape.pullback(&loss).unwrap();
/// let _ga = grads.get(a.node_id().unwrap()).unwrap();
/// ```
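///
/// A minimal sketch of the shared-tape semantics implied by the
/// reference-counted `Clone` (assuming leaves created through a clone record
/// onto the same underlying tape):
///
/// ```ignore
/// use chainrules::Tape;
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let handle = tape.clone(); // same underlying tape
/// let x = handle.leaf(Tensor::ones(
///     &[2, 2],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let loss = tracked_einsum("ij,ij->", &[&x, &x]).unwrap();
/// // Both handles refer to the same recording, so pulling back through the
/// // original handle still reaches `x`.
/// let grads = tape.pullback(&loss).unwrap();
/// let _gx = grads.get(x.node_id().unwrap()).unwrap();
/// ```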
pub struct Tape<V: Differentiable> {
    _marker: PhantomData<V>,
}

impl<V: Differentiable> Tape<V> {
    /// Creates a new empty tape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// ```
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }

    /// Creates a leaf value on this tape that requires gradients.
    ///
    /// The returned [`TrackedTensor`] is connected to this tape and
    /// will participate in gradient computation via [`Tape::pullback`].
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// let x = tape.leaf(3.14);
    /// assert!(x.requires_grad());
    /// ```
    pub fn leaf(&self, _value: V) -> TrackedTensor<V> {
        todo!()
    }

    /// Creates a leaf value with a tangent for HVP computation.
    ///
    /// The tangent defines the perturbation direction *v* used in
    /// forward-over-reverse Hessian-vector product computation.
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::TangentShapeMismatch`] if the tangent's shape
    /// does not match the value's shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// let x = tape.leaf_with_tangent(3.14, 1.0).unwrap();
    /// assert!(x.requires_grad());
    /// assert!(x.has_tangent());
    /// ```
    pub fn leaf_with_tangent(&self, _value: V, _tangent: V::Tangent) -> AdResult<TrackedTensor<V>> {
        todo!()
    }

    /// Runs reverse-mode pullback from a scalar loss value.
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::NonScalarLoss`] for non-scalar losses.
    /// Returns [`AutodiffError::MissingNode`] if the loss is not connected
    /// to this tape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    ///
    /// let tape = Tape::<f64>::new();
    /// let x = tape.leaf(2.0);
    /// // In practice, the loss is a scalar produced by tracked operations on
    /// // `x`; here the scalar leaf itself serves as the loss.
    /// let grads = tape.pullback(&x).unwrap();
    /// ```
    pub fn pullback(&self, _loss: &TrackedTensor<V>) -> AdResult<Gradients<V>> {
        todo!()
    }

    /// Computes the gradient and Hessian-vector product via forward-over-reverse.
    ///
    /// Leaf values with tangents (created via [`Tape::leaf_with_tangent`])
    /// define the direction *v*. The function runs pullback through the tape,
    /// propagating both cotangents and cotangent-tangents at each node, so
    /// the result is the directional derivative of the gradient:
    /// H·v = d/dε ∇f(x + ε·v) at ε = 0.
    ///
    /// Returns both the gradient (in [`HvpResult::gradients`]) and H·v (in
    /// [`HvpResult::hvp`]).
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::NonScalarLoss`] for non-scalar losses.
    /// Returns [`AutodiffError::HvpNotSupported`] if any `ReverseRule` on the
    /// tape does not implement `pullback_with_tangents`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Tape;
    /// use tenferro_einsum::tracked_einsum;
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    /// use tenferro_device::LogicalMemorySpace;
    ///
    /// let tape = Tape::<Tensor<f64>>::new();
    /// let x = tape.leaf_with_tangent(
    ///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
    ///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
    /// ).unwrap();
    /// let loss = tracked_einsum("i,i->", &[&x, &x]).unwrap();
    /// let result = tape.hvp(&loss).unwrap();
    /// let _grad = result.gradients;
    /// let _hv = result.hvp;
    /// ```
    pub fn hvp(&self, _loss: &TrackedTensor<V>) -> AdResult<HvpResult<V>> {
        todo!()
    }
}

impl<V: Differentiable> Clone for Tape<V> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<V: Differentiable> Default for Tape<V> {
    fn default() -> Self {
        Self::new()
    }
}

/// Value wrapper for reverse-mode AD.
///
/// Wraps any [`Differentiable`] value and connects it to a [`Tape`]
/// for gradient computation.
///
/// Created via [`Tape::leaf`] for gradient-tracked values, or
/// [`TrackedTensor::new`] for values that do not require gradients.
///
/// # Examples
///
/// ```ignore
/// use chainrules::{Tape, TrackedTensor};
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let a = tape.leaf(Tensor::ones(
///     &[2, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// assert!(a.requires_grad());
/// ```
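///
/// A minimal sketch of mixing a tracked leaf with an untracked constant
/// (assuming constants created via [`TrackedTensor::new`] have no tape node
/// and therefore receive no gradient):
///
/// ```ignore
/// use chainrules::{Tape, TrackedTensor};
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let w = tape.leaf(Tensor::ones(
///     &[2, 2],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// // A constant input: no tape, no node, requires_grad = false.
/// let c = TrackedTensor::new(Tensor::ones(
///     &[2, 2],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let y = tracked_einsum("ij,jk->ik", &[&w, &c]).unwrap();
/// let loss = tracked_einsum("ij,ij->", &[&y, &y]).unwrap();
/// let grads = tape.pullback(&loss).unwrap();
/// let _gw = grads.get(w.node_id().unwrap()); // gradient for the tracked leaf
/// assert!(c.node_id().is_none());            // the constant has no node
/// ```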
pub struct TrackedTensor<V: Differentiable> {
    value: V,
    node_id: Option<NodeId>,
    tape: Option<Tape<V>>,
    requires_grad: bool,
    tangent: Option<V::Tangent>,
}

impl<V: Differentiable> TrackedTensor<V> {
    /// Creates a tracked value with `requires_grad = false` (no tape).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::TrackedTensor;
    /// let x = TrackedTensor::new(value);
    /// assert!(!x.requires_grad());
    /// ```
    pub fn new(value: V) -> Self {
        Self {
            value,
            node_id: None,
            tape: None,
            requires_grad: false,
            tangent: None,
        }
    }

    /// Returns the underlying value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let v = tracked.value();
    /// ```
    pub fn value(&self) -> &V {
        &self.value
    }

    /// Consumes and returns the underlying value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let raw = tracked.into_value();
    /// ```
    pub fn into_value(self) -> V {
        self.value
    }

    /// Returns whether this value participates in gradient propagation.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert!(tracked.requires_grad());
    /// ```
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns the graph node ID when this value is connected to a tape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// if let Some(id) = tracked.node_id() {
    ///     println!("node = {}", id.index());
    /// }
    /// ```
    pub fn node_id(&self) -> Option<NodeId> {
        self.node_id
    }

    /// Returns the tangent for HVP, or `None` if not set.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// if let Some(t) = tracked.tangent() {
    ///     // use tangent
    /// }
    /// ```
    pub fn tangent(&self) -> Option<&V::Tangent> {
        self.tangent.as_ref()
    }

    /// Returns whether this tracked value has a tangent for HVP.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert!(tracked.has_tangent());
    /// ```
    pub fn has_tangent(&self) -> bool {
        self.tangent.is_some()
    }

    /// Consumes and returns a detached value that does not require gradients.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let detached = tracked.detach();
    /// assert!(!detached.requires_grad());
    /// ```
    pub fn detach(self) -> Self {
        todo!()
    }
}

/// Value wrapper for forward-mode AD.
///
/// Pairs a primal value with an optional tangent; a tangent of `None`
/// represents zero.
///
/// # Examples
///
/// ```ignore
/// use chainrules::DualTensor;
/// let dual = DualTensor::new(primal);
/// assert!(!dual.has_tangent());
/// ```
pub struct DualTensor<V: Differentiable> {
    primal: V,
    tangent: Option<V::Tangent>,
}

impl<V: Differentiable> DualTensor<V> {
    /// Creates a dual value with zero tangent.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::DualTensor;
    /// let x = DualTensor::new(primal);
    /// ```
    pub fn new(primal: V) -> Self {
        Self {
            primal,
            tangent: None,
        }
    }

    /// Creates a dual value with an explicit tangent.
    ///
    /// # Errors
    ///
    /// Returns [`AutodiffError::TangentShapeMismatch`] if the tangent's shape
    /// does not match the primal's shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::DualTensor;
    /// let x = DualTensor::with_tangent(primal, tangent).unwrap();
    /// ```
    pub fn with_tangent(_primal: V, _tangent: V::Tangent) -> AdResult<Self> {
        todo!()
    }

    /// Returns the primal value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let p = dual.primal();
    /// ```
    pub fn primal(&self) -> &V {
        &self.primal
    }

    /// Returns the tangent, or `None` for zero tangent.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let maybe_t = dual.tangent();
    /// ```
    pub fn tangent(&self) -> Option<&V::Tangent> {
        self.tangent.as_ref()
    }

    /// Returns whether this dual value carries an explicit (non-zero) tangent.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert!(dual.has_tangent());
    /// ```
    pub fn has_tangent(&self) -> bool {
        self.tangent.is_some()
    }

    /// Consumes and returns `(primal, tangent)`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let (p, t) = dual.into_parts();
    /// ```
    pub fn into_parts(self) -> (V, Option<V::Tangent>) {
        (self.primal, self.tangent)
    }

    /// Consumes and returns a dual value with tangent removed.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let c = dual.detach_tangent();
    /// assert!(!c.has_tangent());
    /// ```
    pub fn detach_tangent(self) -> Self {
        todo!()
    }
}

/// Accumulated gradients indexed by [`NodeId`].
///
/// # Examples
///
/// ```ignore
/// use chainrules::{Gradients, Differentiable};
/// // V::Tangent is the gradient type
/// let mut grads = Gradients::<MyType>::new();
/// ```
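///
/// A minimal sketch of reading gradients back after a pullback (assuming the
/// `tenferro` setup used in the crate-level examples); `get` looks up a
/// single node while `entries` iterates every `(node, grad)` pair:
///
/// ```ignore
/// use chainrules::Tape;
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let a = tape.leaf(Tensor::ones(
///     &[2, 2],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let loss = tracked_einsum("ij,ij->", &[&a, &a]).unwrap();
/// let grads = tape.pullback(&loss).unwrap();
///
/// // Point lookup by node ID.
/// let _ga = grads.get(a.node_id().unwrap()).unwrap();
/// // Or iterate over everything that received a gradient.
/// for (node, _grad) in grads.entries() {
///     println!("gradient for node {}", node.index());
/// }
/// ```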
pub struct Gradients<V: Differentiable> {
    entries: Vec<(NodeId, V::Tangent)>,
}

impl<V: Differentiable> Gradients<V> {
    /// Creates an empty gradient container.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use chainrules::Gradients;
    /// let grads = Gradients::<MyType>::new();
    /// ```
    pub fn new() -> Self {
        Self { entries: vec![] }
    }

    /// Returns the gradient for `node`, if present.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// if let Some(g) = grads.get(node) {
    ///     // use gradient
    /// }
    /// ```
    pub fn get(&self, _node: NodeId) -> Option<&V::Tangent> {
        todo!()
    }

    /// Inserts or accumulates a gradient for `node`.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// grads.accumulate(node, grad).unwrap();
    /// ```
    pub fn accumulate(&mut self, _node: NodeId, _grad: V::Tangent) -> AdResult<()> {
        todo!()
    }

    /// Returns all `(node, grad)` entries.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// for (node, grad) in grads.entries() {
    ///     println!("{}", node.index());
    /// }
    /// ```
    pub fn entries(&self) -> &[(NodeId, V::Tangent)] {
        &self.entries
    }
}

impl<V: Differentiable> Default for Gradients<V> {
    fn default() -> Self {
        Self::new()
    }
}

/// Compiled pullback execution plan.
///
/// # Examples
///
/// ```ignore
/// let plan = chainrules::PullbackPlan::<MyType>::build(&loss).unwrap();
/// ```
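///
/// A minimal sketch of the intended build-once / execute-many pattern
/// (assuming a plan built from one loss can be re-executed for later
/// evaluations of the same graph):
///
/// ```ignore
/// use chainrules::PullbackPlan;
///
/// // Build the plan once from the scalar loss...
/// let plan = PullbackPlan::build(&loss).unwrap();
///
/// // ...then execute it whenever gradients are needed again.
/// let grads_first = plan.execute(&loss).unwrap();
/// let grads_again = plan.execute(&loss).unwrap();
/// // The plan remembers which node the loss corresponds to.
/// let _loss_node = plan.loss_node();
/// ```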
#[derive(Debug, Clone)]
pub struct PullbackPlan<V: Differentiable> {
    loss: NodeId,
    _marker: PhantomData<V>,
}

impl<V: Differentiable> PullbackPlan<V> {
    /// Builds a pullback plan from a loss value.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let plan = chainrules::PullbackPlan::build(&loss).unwrap();
    /// ```
    pub fn build(_loss: &TrackedTensor<V>) -> AdResult<Self> {
        todo!()
    }

    /// Executes the pre-built pullback plan.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let grads = plan.execute(&loss).unwrap();
    /// ```
    pub fn execute(&self, _loss: &TrackedTensor<V>) -> AdResult<Gradients<V>> {
        todo!()
    }

    /// Returns the loss node ID for this plan.
    ///
    /// # Examples
    ///
    /// ```
    /// use chainrules::{PullbackPlan, NodeId};
    /// let _id_fn: fn(&PullbackPlan<f64>) -> NodeId = PullbackPlan::loss_node;
    /// ```
    pub fn loss_node(&self) -> NodeId {
        self.loss
    }
}

/// Result of a forward-over-reverse HVP computation.
///
/// Contains both the standard gradient and the Hessian-vector
/// product H·v, where v is the tangent direction set on leaf values
/// via [`Tape::leaf_with_tangent`].
///
/// # Examples
///
/// ```ignore
/// use chainrules::{Tape, HvpResult};
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let x = tape.leaf_with_tangent(
///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
///     Tensor::ones(&[3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor),
/// ).unwrap();
/// let loss = tracked_einsum("i,i->", &[&x, &x]).unwrap();
/// let result: HvpResult<Tensor<f64>> = tape.hvp(&loss).unwrap();
/// let _grad = result.gradients.get(x.node_id().unwrap());
/// let _hv = result.hvp.get(x.node_id().unwrap());
/// ```
pub struct HvpResult<V: Differentiable> {
    /// The gradient of the loss.
    pub gradients: Gradients<V>,
    /// The Hessian-vector product: H·v.
    pub hvp: Gradients<V>,
}