// tidu/lib.rs
//! Torch-like public autograd API with a generic linearize-first core.
//!
//! `tidu` provides a value-centered public API for reverse-mode AD built around
//! a first-order `linearize` step. The engine stays generic over the
//! differentiable value type, so the same runtime can power scalar examples,
//! tensor engines, or downstream custom value types.
//!
//! The normal public surface is:
//! - [`Value`] for reverse-mode leaves and outputs,
//! - [`LinearizableOp`] for custom high-level operations,
//! - [`LinearizedOp`] for local `jvp`/`vjp` access,
//! - [`CheckpointMode`], [`AdExecutionPolicy`], and [`with_ad_policy`] for
//!   checkpoint policy scopes,
//! - [`CheckpointHint`] for advanced retain-vs-replay hints on custom ops.
//!
//! **Companion crate:** The doc examples below import scalar rule helpers
//! (e.g. `powf_rrule`) from the
//! [`chainrules`](https://github.com/tensor4all/chainrules-rs) crate.
//! Add it alongside `tidu` in your `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! tidu = { git = "https://github.com/tensor4all/tidu-rs" }
//! chainrules = { git = "https://github.com/tensor4all/chainrules-rs" }
//! ```
//!
//! ## Table of Contents
//! - [Value-Centered Reverse Mode](#value-centered-reverse-mode)
//! - [Local Directional Derivatives](#local-directional-derivatives)
//! - [Checkpoint Policy](#checkpoint-policy)
//! - [Custom Value Type](#custom-value-type)
//!
//! ## Value-Centered Reverse Mode
//! ```rust
//! use tidu::{LinearizableOp, LinearizedOp, Schema, SlotSchema, Value};
//!
//! #[derive(Clone, Copy)]
//! struct Cube;
//!
//! struct CubeLinearized {
//!     x: f64,
//! }
//!
//! impl LinearizedOp<f64> for CubeLinearized {
//!     fn jvp(&self, input_tangents: &[Option<f64>]) -> tidu::AdResult<Vec<Option<f64>>> {
//!         Ok(vec![input_tangents[0].map(|dx| 3.0 * self.x * self.x * dx)])
//!     }
//!
//!     fn vjp(
//!         &self,
//!         output_cotangents: &[Option<f64>],
//!         input_grad_mask: &[bool],
//!     ) -> tidu::AdResult<Vec<Option<f64>>> {
//!         assert_eq!(input_grad_mask, &[true]);
//!         let grad_out = output_cotangents[0].unwrap_or(0.0);
//!         Ok(vec![Some(3.0 * self.x * self.x * grad_out)])
//!     }
//! }
//!
//! impl LinearizableOp<f64> for Cube {
//!     type Linearized = CubeLinearized;
//!
//!     fn primal(&self, inputs: &[&f64]) -> tidu::AdResult<Vec<f64>> {
//!         Ok(vec![*inputs[0] * *inputs[0] * *inputs[0]])
//!     }
//!
//!     fn input_schema(&self, _inputs: &[&f64]) -> tidu::AdResult<Schema> {
//!         Ok(Schema {
//!             slots: vec![SlotSchema {
//!                 differentiable: true,
//!                 auxiliary: false,
//!             }],
//!         })
//!     }
//!
//!     fn output_schema(&self, _inputs: &[&f64], _outputs: &[f64]) -> tidu::AdResult<Schema> {
//!         Ok(Schema {
//!             slots: vec![SlotSchema {
//!                 differentiable: true,
//!                 auxiliary: false,
//!             }],
//!         })
//!     }
//!
//!     fn linearize(
//!         &self,
//!         inputs: &[&f64],
//!         _outputs: &[f64],
//!     ) -> tidu::AdResult<Self::Linearized> {
//!         Ok(CubeLinearized { x: *inputs[0] })
//!     }
//! }
//!
//! let x = Value::new(2.0).with_requires_grad(true);
//! let y = Cube.apply_one(&[&x]).unwrap();
//! y.backward().unwrap();
//! assert_eq!(x.grad().unwrap().unwrap(), 12.0);
//! ```
//!
//! ## Local Directional Derivatives
//!
//! ```rust
//! use tidu::{LinearizableOp, LinearizedOp, Schema, SlotSchema};
//!
//! #[derive(Clone, Copy)]
//! struct Square;
//!
//! struct SquareLinearized {
//!     x: f64,
//! }
//!
//! impl LinearizedOp<f64> for SquareLinearized {
//!     fn jvp(&self, input_tangents: &[Option<f64>]) -> tidu::AdResult<Vec<Option<f64>>> {
//!         Ok(vec![input_tangents[0].map(|dx| 2.0 * self.x * dx)])
//!     }
//!
//!     fn vjp(
//!         &self,
//!         output_cotangents: &[Option<f64>],
//!         input_grad_mask: &[bool],
//!     ) -> tidu::AdResult<Vec<Option<f64>>> {
//!         assert_eq!(input_grad_mask, &[true]);
//!         let grad_out = output_cotangents[0].unwrap_or(0.0);
//!         Ok(vec![Some(2.0 * self.x * grad_out)])
//!     }
//! }
//!
//! impl LinearizableOp<f64> for Square {
//!     type Linearized = SquareLinearized;
//!
//!     fn primal(&self, inputs: &[&f64]) -> tidu::AdResult<Vec<f64>> {
//!         Ok(vec![*inputs[0] * *inputs[0]])
//!     }
//!
//!     fn input_schema(&self, _inputs: &[&f64]) -> tidu::AdResult<Schema> {
//!         Ok(Schema {
//!             slots: vec![SlotSchema {
//!                 differentiable: true,
//!                 auxiliary: false,
//!             }],
//!         })
//!     }
//!
//!     fn output_schema(&self, _inputs: &[&f64], _outputs: &[f64]) -> tidu::AdResult<Schema> {
//!         Ok(Schema {
//!             slots: vec![SlotSchema {
//!                 differentiable: true,
//!                 auxiliary: false,
//!             }],
//!         })
//!     }
//!
//!     fn linearize(
//!         &self,
//!         inputs: &[&f64],
//!         _outputs: &[f64],
//!     ) -> tidu::AdResult<Self::Linearized> {
//!         Ok(SquareLinearized { x: *inputs[0] })
//!     }
//! }
//!
//! let lin = Square.linearize(&[&3.0], &[9.0]).unwrap();
//! assert_eq!(lin.jvp(&[Some(1.0)]).unwrap(), vec![Some(6.0)]);
//! ```
//!
//! ## Checkpoint Policy
//!
//! ```rust
//! use tidu::{AdExecutionPolicy, CheckpointMode, with_ad_policy};
//!
//! let policy = AdExecutionPolicy {
//!     checkpoint_mode: CheckpointMode::Conservative,
//! };
//!
//! with_ad_policy(policy, || -> tidu::AdResult<()> {
//!     // Record and differentiate values inside this scope.
//!     Ok(())
//! })
//! .unwrap();
//! ```
//!
//! ## Custom Value Type
//!
//! `tidu` stays generic over any type implementing [`Differentiable`]. Custom
//! values participate through the same [`Value`] and [`LinearizableOp`] surface,
//! while reverse-mode seeding still happens through [`Value::backward`] or
//! [`Value::backward_with_seed`] depending on the output shape.

189pub use chainrules_core::{AdResult, AutodiffError, Differentiable};
190
191mod checkpoint;
192mod graph_task;
193mod linearized;
194mod reverse_graph;
195mod value;
196
197pub use checkpoint::{with_ad_policy, AdExecutionPolicy, CheckpointHint, CheckpointMode};
198pub use linearized::{LinearizableOp, LinearizedOp, Schema, SlotSchema};
199pub use value::Value;