Skip to main content

tidu/
lib.rs

//! Automatic-differentiation transforms for primitive computation graphs.
//!
//! `tidu` is for downstream crates that define primitive operations, local AD
//! rules, graph runtimes, or eager tensor frontends. It does not define tensor
//! operations itself. Instead, downstream crates implement [`Primitive`] for
//! their primitive sets and then call the graph transforms in this crate to
//! build new primitive computation graphs.
//!
//! The main transforms are:
//!
//! - [`linearize`] / [`try_linearize`], which build a graph for the
//!   Jacobian-vector product (JVP) of selected outputs with respect to selected
//!   inputs.
//! - [`linear_transpose`] / [`try_linear_transpose`], which transpose a
//!   linearized graph so that cotangents flow backward through the
//!   corresponding linear map.
//! - [`eager::try_backward`], which supports downstream eager frontends that
//!   record graph invocations and want a reverse-mode `backward()` workflow.
//!
//! The fallible variants ([`try_linearize`], [`try_linear_transpose`], and
//! [`eager::try_backward`]) propagate an [`ADRuleError`] when a required
//! primitive or extension AD rule is missing.
//!
//! See the `docs/` tree in the repository for the terminology guide,
//! tutorials, and implementer guides.
//!
//! # Examples
//!
//! The snippet below is not compiled (`ignore`): `source_graph`, `output_key`,
//! and `input_key` are placeholders for values a downstream crate would provide.
//!
//! ```ignore
//! use computegraph::resolve::resolve;
//! use tidu::{try_linear_transpose, try_linearize};
//!
//! let view = resolve(vec![source_graph]);
//! let mut ctx = ();
//! let aliases = std::collections::HashMap::new();
//! let linear = try_linearize(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases)?;
//! let _transposed = try_linear_transpose(&linear, &mut ctx)?;
//! # Ok::<(), tidu::ADRuleError>(())
//! ```
40
41pub mod eager;
42mod linear_transpose;
43mod linearize;
44mod linearized_graph;
45mod primitive_graph;
46pub mod rules;
47
48pub use linear_transpose::{
49    linear_transpose, try_linear_transpose, try_linear_transpose_with_builder,
50};
51pub use linearize::{linearize, try_linearize};
52pub use linearized_graph::LinearizedGraph;
53pub use primitive_graph::PrimitiveGraph;
54pub use rules::{
55    ADKey, ADRuleError, ADRuleKind, ADRuleResult, DiffPassId, Primitive, PrimitiveBuilder,
56    PrimitiveValue,
57};