// tidu/lib.rs — crate root

1//! Torch-like public autograd API with a generic linearize-first core.
2//!
3//! `tidu` provides a value-centered public API for reverse-mode AD built around
4//! a first-order `linearize` step. The engine stays generic over the
5//! differentiable value type, so the same runtime can power scalar examples,
6//! tensor engines, or downstream custom value types.
7//!
8//! The normal public surface is:
//! - [`Value`] for reverse-mode leaves and outputs,
//! - [`Differentiable`] for plugging custom value types into the engine,
//! - [`LinearizableOp`] for custom high-level operations,
//! - [`LinearizedOp`] for local `jvp`/`vjp` access,
//! - [`CheckpointMode`], [`AdExecutionPolicy`], and [`with_ad_policy`] for
//!   checkpoint policy scopes,
//! - [`CheckpointHint`] for advanced retain-vs-replay hints on custom ops.
15//!
//! **Companion crate:** Reusable scalar rule helpers (e.g. `powf_rrule`) for
//! building custom operations are provided by the
//! [`chainrules`](https://github.com/tensor4all/chainrules-rs) crate.
//! Add it alongside `tidu` in your `Cargo.toml`:
20//!
21//! ```toml
22//! [dependencies]
23//! tidu       = { git = "https://github.com/tensor4all/tidu-rs" }
24//! chainrules = { git = "https://github.com/tensor4all/chainrules-rs" }
25//! ```
26//!
27//! ## Table of Contents
28//! - [Value-Centered Reverse Mode](#value-centered-reverse-mode)
29//! - [Local Directional Derivatives](#local-directional-derivatives)
30//! - [Checkpoint Policy](#checkpoint-policy)
31//! - [Custom Value Type](#custom-value-type)
32//!
33//! ## Value-Centered Reverse Mode
34//! ```rust
35//! use tidu::{LinearizableOp, LinearizedOp, Schema, SlotSchema, Value};
36//!
37//! #[derive(Clone, Copy)]
38//! struct Cube;
39//!
40//! struct CubeLinearized {
41//!     x: f64,
42//! }
43//!
44//! impl LinearizedOp<f64> for CubeLinearized {
45//!     fn jvp(&self, input_tangents: &[Option<f64>]) -> tidu::AdResult<Vec<Option<f64>>> {
46//!         Ok(vec![input_tangents[0].map(|dx| 3.0 * self.x * self.x * dx)])
47//!     }
48//!
49//!     fn vjp(
50//!         &self,
51//!         output_cotangents: &[Option<f64>],
52//!         input_grad_mask: &[bool],
53//!     ) -> tidu::AdResult<Vec<Option<f64>>> {
54//!         assert_eq!(input_grad_mask, &[true]);
55//!         let grad_out = output_cotangents[0].unwrap_or(0.0);
56//!         Ok(vec![Some(3.0 * self.x * self.x * grad_out)])
57//!     }
58//! }
59//!
60//! impl LinearizableOp<f64> for Cube {
61//!     type Linearized = CubeLinearized;
62//!
63//!     fn primal(&self, inputs: &[&f64]) -> tidu::AdResult<Vec<f64>> {
64//!         Ok(vec![*inputs[0] * *inputs[0] * *inputs[0]])
65//!     }
66//!
67//!     fn input_schema(&self, _inputs: &[&f64]) -> tidu::AdResult<Schema> {
68//!         Ok(Schema {
69//!             slots: vec![SlotSchema {
70//!                 differentiable: true,
71//!                 auxiliary: false,
72//!             }],
73//!         })
74//!     }
75//!
76//!     fn output_schema(&self, _inputs: &[&f64], _outputs: &[f64]) -> tidu::AdResult<Schema> {
77//!         Ok(Schema {
78//!             slots: vec![SlotSchema {
79//!                 differentiable: true,
80//!                 auxiliary: false,
81//!             }],
82//!         })
83//!     }
84//!
85//!     fn linearize(
86//!         &self,
87//!         inputs: &[&f64],
88//!         _outputs: &[f64],
89//!     ) -> tidu::AdResult<Self::Linearized> {
90//!         Ok(CubeLinearized { x: *inputs[0] })
91//!     }
92//! }
93//!
94//! let x = Value::new(2.0).with_requires_grad(true);
95//! let y = Cube.apply_one(&[&x]).unwrap();
96//! y.backward().unwrap();
97//! assert_eq!(x.grad().unwrap().unwrap(), 12.0);
98//! ```
99//!
100//! ## Local Directional Derivatives
101//!
102//! ```rust
103//! use tidu::{LinearizableOp, LinearizedOp, Schema, SlotSchema};
104//!
105//! #[derive(Clone, Copy)]
106//! struct Square;
107//!
108//! struct SquareLinearized {
109//!     x: f64,
110//! }
111//!
112//! impl LinearizedOp<f64> for SquareLinearized {
113//!     fn jvp(&self, input_tangents: &[Option<f64>]) -> tidu::AdResult<Vec<Option<f64>>> {
114//!         Ok(vec![input_tangents[0].map(|dx| 2.0 * self.x * dx)])
115//!     }
116//!
117//!     fn vjp(
118//!         &self,
119//!         output_cotangents: &[Option<f64>],
120//!         input_grad_mask: &[bool],
121//!     ) -> tidu::AdResult<Vec<Option<f64>>> {
122//!         assert_eq!(input_grad_mask, &[true]);
123//!         let grad_out = output_cotangents[0].unwrap_or(0.0);
124//!         Ok(vec![Some(2.0 * self.x * grad_out)])
125//!     }
126//! }
127//!
128//! impl LinearizableOp<f64> for Square {
129//!     type Linearized = SquareLinearized;
130//!
131//!     fn primal(&self, inputs: &[&f64]) -> tidu::AdResult<Vec<f64>> {
132//!         Ok(vec![*inputs[0] * *inputs[0]])
133//!     }
134//!
135//!     fn input_schema(&self, _inputs: &[&f64]) -> tidu::AdResult<Schema> {
136//!         Ok(Schema {
137//!             slots: vec![SlotSchema {
138//!                 differentiable: true,
139//!                 auxiliary: false,
140//!             }],
141//!         })
142//!     }
143//!
144//!     fn output_schema(&self, _inputs: &[&f64], _outputs: &[f64]) -> tidu::AdResult<Schema> {
145//!         Ok(Schema {
146//!             slots: vec![SlotSchema {
147//!                 differentiable: true,
148//!                 auxiliary: false,
149//!             }],
150//!         })
151//!     }
152//!
153//!     fn linearize(
154//!         &self,
155//!         inputs: &[&f64],
156//!         _outputs: &[f64],
157//!     ) -> tidu::AdResult<Self::Linearized> {
158//!         Ok(SquareLinearized { x: *inputs[0] })
159//!     }
160//! }
161//!
162//! let lin = Square.linearize(&[&3.0], &[9.0]).unwrap();
163//! assert_eq!(lin.jvp(&[Some(1.0)]).unwrap(), vec![Some(6.0)]);
164//! ```
165//!
166//! ## Checkpoint Policy
167//!
168//! ```rust
169//! use tidu::{AdExecutionPolicy, CheckpointMode, with_ad_policy};
170//!
171//! let policy = AdExecutionPolicy {
172//!     checkpoint_mode: CheckpointMode::Conservative,
173//! };
174//!
175//! with_ad_policy(policy, || -> tidu::AdResult<()> {
176//!     // Record and differentiate values inside this scope.
177//!     Ok(())
178//! })
179//! .unwrap();
180//! ```
181//!
182//! ## Custom Value Type
183//!
184//! `tidu` stays generic over any type implementing [`Differentiable`]. Custom
185//! values participate through the same [`Value`] and [`LinearizableOp`] surface,
186//! while reverse-mode seeding still happens through [`Value::backward`] or
187//! [`Value::backward_with_seed`] depending on the output shape.
188
189pub use chainrules_core::{AdResult, AutodiffError, Differentiable};
190
191mod checkpoint;
192mod graph_task;
193mod linearized;
194mod reverse_graph;
195mod value;
196
197pub use checkpoint::{with_ad_policy, AdExecutionPolicy, CheckpointHint, CheckpointMode};
198pub use linearized::{LinearizableOp, LinearizedOp, Schema, SlotSchema};
199pub use value::Value;