AD Architecture
Repos: computegraph-rs, tidu-rs, tenferro-rs
Parent: ../index.md
Related: computegraph.md, primitive-ad.md, tidu.md, ../spec/backend-contract.md, ../spec/primitive-catalog.md
I. Vision
Build a differentiable programming stack in Rust where:
- linearize is the only derivative-producing transform. It consumes a resolved logical view of computation and produces a new linear graph.
- linear_transpose is not differentiation. It reverses active linear flow in a linear graph and reuses the same derivative information.
- AD transforms operate on graphs, not on a single eagerly merged graph.
- What higher-order AD needs is resolve, not physical merge. External references must be traceable; they do not need to be copied into one graph at every stage.
- Physical flattening happens only in materialize_merge, typically just before compile.
- Evaluation is always forward on a materialized graph after flattening, CSE, and slot assignment.
This is the intended operation set:
build user constructs a primal graph
resolve create a logical DAG view over one or more graphs
linearize resolved view -> new linear graph (JVP)
linear_transpose linear graph -> new linear graph (reverse linear flow)
materialize_merge resolved view -> MaterializedGraph (flatten + CSE)
compile MaterializedGraph -> CompiledProgram
eval CompiledProgram + input values -> output values
The key pipeline distinction is:
linearize -> resolve -> linearize -> resolve -> ...
not
linearize -> physical merge -> linearize -> physical merge -> ...
materialize_merge is still required, but only when a backend, serializer, or debugger needs one concrete graph.
Typical pipelines:
JVP:
build -> resolve -> linearize -> materialize_merge -> compile -> eval
VJP:
build -> resolve -> linearize -> linear_transpose -> materialize_merge -> compile -> eval
2nd directional derivative:
build -> resolve -> linearize -> resolve -> linearize -> materialize_merge -> compile -> eval
n-th derivative:
build -> (resolve -> linearize) x n -> [linear_transpose] -> materialize_merge -> compile -> eval
Three crates, strictly layered:
computegraph GraphOperation + Operand traits, Graph, resolve,
materialize_merge, compile (SSA), eval,
compilation cache
↓
tidu Primitive: GraphOperation (adds add + JVP + transpose_rule),
linearize, linear_transpose; no graph infrastructure of its own
↓
tenferro Concrete tensor primitives + execution lowering
computegraph provides the general-purpose computation graph engine. It is usable without AD (e.g. multi-tensor einsum as a graph of binary contractions). tidu is a thin layer that adds AD-specific graph transforms (linearize, linear_transpose), fully generic over Op: Primitive. Neither computegraph nor tidu references specific primitives. The responsibility for ensuring that linearize and transpose_rule produce valid, closed graphs belongs entirely to the downstream primitive implementor (tenferro).
II. Core Model
Graph vs MaterializedGraph
A Graph is the unit produced by build, linearize, and linear_transpose.
A graph:
- owns only its local nodes and ops
- may reference values defined elsewhere through external references
- is valid as long as those external references are resolvable
A MaterializedGraph is different. It is the fully flattened graph produced by materialize_merge:
- all reachable definitions are collected
- same-key nodes are unified
- a concrete DAG exists for compile, serialization, and debug printing
So the intended mental model is:
Graph = transform-time object
ResolvedView = logical traversal object
MaterializedGraph = compile-time object
Local ids are local only
Local ids are graph-scoped. They must not be used as cross-graph identity.
type LocalValueId = usize;
type LocalOperationId = usize;
enum ValueRef<Op: GraphOperation> {
Local(LocalValueId),
External(ValueKey<Op>),
}
enum ValueKey<Op: GraphOperation> {
Input(Op::InputKey),
Derived {
operation: Arc<OperationKey<Op>>,
output_slot: u8,
},
}
struct OperationKey<Op: GraphOperation> {
operation: Op,
inputs: Vec<ValueKey<Op>>,
role: OperationRole,
}
enum OperationRole {
Primary,
Linearized { active_mask: Vec<bool> },
}
ValueKey is the identity that matters across graphs. It is structural:
- inputs are keyed by InputKey
- derived values are keyed by primitive, global input keys, output slot, and linear metadata
This is what makes the following possible:
- external reference resolution
- cross-graph CSE
- higher-order tracing through earlier graphs
- linear_transpose accumulation bucketed by global identity
Active mask is part of identity
Linear nodes use the same primitive set as primal nodes, but the linear mode is not optional metadata. It changes the meaning of linear_transpose and therefore must participate in identity.
Examples:
Mul(a, b) role=Primary
Mul(a, dx) role=Linearized { active_mask=[fixed, active] }
Add(dx, dy) role=Linearized { active_mask=[active, active] }
The first and second nodes both evaluate as multiplication, but they are not the same graph object: linear_transpose treats them differently, so they must not alias.
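To make structural identity concrete, here is a minimal sketch. It assumes a hypothetical toy primitive enum (ToyOp) and string input keys in place of a real GraphOperation and its InputKey. Because the keys derive Eq and Hash, two independently built keys for Mul(x, a) collapse into one HashSet entry. Cross-graph CSE and accumulation bucketing both rely on this property. The same Mul in a Linearized role stays distinct.

```rust
use std::collections::HashSet;
use std::sync::Arc;

// Hypothetical toy primitive set standing in for a real GraphOperation.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum ToyOp {
    Mul,
    Exp,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum OperationRole {
    Primary,
    // One flag per operand: true = active (tangent), false = fixed (primal).
    Linearized { active_mask: Vec<bool> },
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct OperationKey {
    operation: ToyOp,
    inputs: Vec<ValueKey>,
    role: OperationRole,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum ValueKey {
    Input(String),
    Derived {
        operation: Arc<OperationKey>,
        output_slot: u8,
    },
}

fn input(name: &str) -> ValueKey {
    ValueKey::Input(name.to_string())
}

// Builds the key of output slot 0 of an operation; equality is purely structural.
fn derived(operation: ToyOp, inputs: Vec<ValueKey>, role: OperationRole) -> ValueKey {
    ValueKey::Derived {
        operation: Arc::new(OperationKey { operation, inputs, role }),
        output_slot: 0,
    }
}

// Keys two separate graphs might build for exp(x * a), plus one linearized Mul.
fn demo_keys() -> Vec<ValueKey> {
    let (x, a) = (input("x"), input("a"));
    let mul_in_graph_1 = derived(ToyOp::Mul, vec![x.clone(), a.clone()], OperationRole::Primary);
    let mul_in_graph_2 = derived(ToyOp::Mul, vec![x.clone(), a.clone()], OperationRole::Primary);
    let exp_in_graph_1 = derived(ToyOp::Exp, vec![mul_in_graph_1.clone()], OperationRole::Primary);
    let exp_in_graph_2 = derived(ToyOp::Exp, vec![mul_in_graph_2.clone()], OperationRole::Primary);
    // Same primitive and operands, different role: must not alias.
    let linear_mul = derived(
        ToyOp::Mul,
        vec![x, a],
        OperationRole::Linearized { active_mask: vec![false, true] },
    );
    vec![mul_in_graph_1, mul_in_graph_2, exp_in_graph_1, exp_in_graph_2, linear_mul]
}

fn main() {
    let keys = demo_keys();
    let unique: HashSet<&ValueKey> = keys.iter().collect();
    println!("{} keys, {} distinct", keys.len(), unique.len()); // 5 keys, 3 distinct
}
```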
Graph data structure
Conceptually:
struct ValueNode<Op: GraphOperation> {
key: ValueKey<Op>,
producer: Option<(LocalOperationId, usize)>, // None for graph inputs
}
struct OperationNode<Op: GraphOperation> {
operation: Op,
inputs: Vec<ValueRef<Op>>,
outputs: Vec<LocalValueId>,
role: OperationRole,
}
struct Graph<Op: GraphOperation> {
values: Vec<ValueNode<Op>>,
operations: Vec<OperationNode<Op>>,
inputs: Vec<LocalValueId>,
outputs: Vec<LocalValueId>,
parents: Vec<Arc<Graph<Op>>>,
}
parents are not eager ownership: no parent node is copied into this graph. They are the lookup base used by resolve.
Resolver and ResolvedView
resolve does not copy nodes into one graph. It builds a logical lookup view over graphs.
enum ValueDef<Op: GraphOperation> {
Input {
key: Op::InputKey,
},
Produced {
operation: Op,
inputs: Vec<ValueRef<Op>>,
role: OperationRole,
output_slot: usize,
},
}
trait Resolver<Op: GraphOperation> {
fn resolve_value(&self, key: &ValueKey<Op>) -> Option<ValueDef<Op>>;
}
struct ResolvedView<Op: GraphOperation> {
roots: Vec<Arc<Graph<Op>>>,
resolver: Arc<dyn Resolver<Op>>,
}
The intended implementation is a resolver assembled from parent graphs, not a mandatory central registry.
resolve therefore means:
- an external value is not dangling
- its definition can be found
- its dependencies can be followed recursively
It does not mean:
- nodes are copied into one physical graph
- CSE has already run
- slot assignment has already happened
Logical traversal
All transform-time walkers operate on the same logical rule:
Local(LocalValueId) -> follow the local producer
External(ValueKey) -> ask the resolver for the defining op
This must work recursively through any number of graph boundaries.
Topological traversal at transform time is therefore logical, not physical:
- visitation is keyed by ValueKey
- the ordering is computed on the resolved logical DAG
- local ids only matter inside the graph currently being built
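The traversal rule can be sketched with a deliberately simplified model. None of these types are the computegraph API:
- strings stand in for structural ValueKeys (including the linear values t1 and t2);
- each toy graph lists its locally produced values;
- a resolver built from the participating graphs maps a global key to the graph and local index that define it.

The walk follows Local refs inside the defining graph and External refs through the resolver. It records a post-order keyed by global key, here over the [F0, L1] view of the scalar example in Section IV.

```rust
use std::collections::{HashMap, HashSet};

// Toy value reference: a local index or a (stand-in) global key.
#[derive(Clone, Debug)]
enum ValueRef {
    Local(usize),
    External(String),
}

// A locally defined value: its global key, producing op, and operands.
struct Node {
    key: String,
    op: &'static str,
    inputs: Vec<ValueRef>,
}

struct Graph {
    nodes: Vec<Node>,
}

// Resolver assembled from the participating graphs, not a central registry.
struct Resolver<'a> {
    graphs: Vec<&'a Graph>,
    index: HashMap<String, (usize, usize)>, // global key -> (graph, local index)
}

impl<'a> Resolver<'a> {
    fn new(graphs: Vec<&'a Graph>) -> Self {
        let mut index = HashMap::new();
        for (g, graph) in graphs.iter().enumerate() {
            for (i, node) in graph.nodes.iter().enumerate() {
                index.entry(node.key.clone()).or_insert((g, i));
            }
        }
        Resolver { graphs, index }
    }

    // The defining graph and node for a global key, if any graph defines it.
    fn lookup(&self, key: &str) -> Option<(&'a Graph, &'a Node)> {
        let &(g, i) = self.index.get(key)?;
        let graph: &'a Graph = self.graphs[g];
        Some((graph, &graph.nodes[i]))
    }
}

// Post-order DFS keyed by global key; crosses any number of graph boundaries.
fn visit(resolver: &Resolver, key: &str, seen: &mut HashSet<String>, order: &mut Vec<String>) {
    if !seen.insert(key.to_string()) {
        return;
    }
    let (graph, node) = resolver.lookup(key).expect("dangling external reference");
    for input in &node.inputs {
        let dep = match input {
            ValueRef::Local(i) => graph.nodes[*i].key.as_str(), // follow the local producer
            ValueRef::External(k) => k.as_str(),                // ask the resolver
        };
        visit(resolver, dep, seen, order);
    }
    order.push(key.to_string());
}

fn topo_order(resolver: &Resolver, outputs: &[&str]) -> Vec<String> {
    let (mut seen, mut order) = (HashSet::new(), Vec::new());
    for out in outputs {
        visit(resolver, out, &mut seen, &mut order);
    }
    order
}

fn node(key: &str, op: &'static str, inputs: Vec<ValueRef>) -> Node {
    Node { key: key.to_string(), op, inputs }
}

// F0 (primal) and L1 (linear) from the scalar example f(x) = exp(a * x).
fn example_graphs() -> (Graph, Graph) {
    use ValueRef::{External, Local};
    let f0 = Graph { nodes: vec![
        node("x", "Input", vec![]),
        node("a", "Input", vec![]),
        node("mul(x,a)", "Mul", vec![Local(0), Local(1)]),
        node("exp(mul(x,a))", "Exp", vec![Local(2)]),
    ] };
    let l1 = Graph { nodes: vec![
        node("t_x", "Input", vec![]),
        node("t1", "Mul", vec![External("a".into()), Local(0)]),
        node("t2", "Mul", vec![External("exp(mul(x,a))".into()), Local(1)]),
    ] };
    (f0, l1)
}

fn main() {
    let (f0, l1) = example_graphs();
    let resolver = Resolver::new(vec![&f0, &l1]);
    for key in topo_order(&resolver, &["t2"]) {
        let (_, n) = resolver.lookup(&key).unwrap();
        println!("{key} = {}", n.op);
    }
}
```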
Materialized graph
Compile does not consume a graph. It consumes the result of materialize_merge.
struct MaterializedGraph<Op: GraphOperation> {
values: Vec<MaterializedVal<Op>>,
operations: Vec<MaterializedOp<Op>>,
inputs: Vec<ValueKey<Op>>,
outputs: Vec<ValueKey<Op>>,
}
materialize_merge is the stage that:
- collects the reachable subgraph
- deduplicates by global identity
- computes one concrete topological order
- produces compile-ready graph state
III. Transformations
build
build creates a primal graph.
- all nodes are OperationRole::Primary
- graph inputs use ValueKey::Input
- no eager merge is implied
resolve
Conceptually:
fn resolve<Op: GraphOperation>(roots: Vec<Arc<Graph<Op>>>) -> ResolvedView<Op>;
resolve is cheap and logical. It prepares a traversal view over graph parents and external references.
This is the correct replacement for the old statement:
"merge must precede next linearize"
The precise rule is:
resolve must precede any transform that needs to trace through external refs
In practice:
- higher-order linearize requires resolve
- dependency analysis requires resolve
- linear_transpose usually needs only the linear graph itself plus active masks
linearize
linearize consumes a resolved view and returns a new linear graph.
struct LinearizedGraph<Op: GraphOperation> {
graph: Graph<Op>,
tangent_inputs: Vec<(Op::InputKey, LocalValueId)>,
tangent_outputs: Vec<Option<LocalValueId>>, // None when an output's tangent is structurally zero
}
fn linearize<Op: Primitive>(
view: &ResolvedView<Op>,
outputs: &[ValueKey<Op>],
wrt: &[Op::InputKey],
) -> LinearizedGraph<Op>;
Important consequences:
- callers specify which primal inputs they linearize with respect to
- tangent inputs are created inside the returned graph
- those tangent inputs receive fresh InputKeys and are returned to the caller
- primal values are referenced by External(ValueKey), not copied
Algorithm sketch:
1. Traverse the reachable logical DAG in topological order.
2. Seed tangent inputs for the requested primal InputKeys.
3. For each reachable primitive, call its linearization rule.
4. Emit new local linear nodes into the new graph.
5. Reference primal values through External(ValueKey).
6. Propagate structural zeros: skip tangent flow that is unreachable from the wrt inputs instead of emitting explicit zero nodes.
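The algorithm sketch can be illustrated on a toy scalar IR. This is a minimal sketch under simplifying assumptions, not tidu's implementation:
- the primal graph is already resolved and topologically ordered;
- primal positions stand in for External(ValueKey) references;
- the hypothetical LinOp::Mul { fixed, active } plays the role of Mul with role=Linearized { active_mask=[fixed, active] }.

A None tangent is a structural zero, so no linear node is emitted for it.

```rust
use std::collections::HashMap;

// Toy primal ops; operands are positions of earlier nodes (topological order).
#[derive(Clone, Debug)]
enum PrimalOp {
    Input(&'static str),
    Mul(usize, usize),
    Exp(usize),
}

// Toy linear ops emitted into the new graph.
#[derive(Clone, Debug, PartialEq)]
enum LinOp {
    TangentInput(String),                // fresh tangent input for a wrt primal input
    Mul { fixed: usize, active: usize }, // fixed = primal position, active = linear position
    Add(usize, usize),                   // both operands active
}

fn emit(lin: &mut Vec<LinOp>, op: LinOp) -> usize {
    lin.push(op);
    lin.len() - 1
}

// Returns the linear graph and, per primal node, the linear node holding its tangent.
fn linearize(primal: &[PrimalOp], wrt: &[&str]) -> (Vec<LinOp>, Vec<Option<usize>>) {
    let mut lin = Vec::new();
    let mut tangent: Vec<Option<usize>> = Vec::with_capacity(primal.len());
    for (i, op) in primal.iter().enumerate() {
        let t = match op {
            // Seed tangent inputs only for the requested primal inputs.
            PrimalOp::Input(name) if wrt.contains(name) => {
                Some(emit(&mut lin, LinOp::TangentInput(format!("t_{name}"))))
            }
            PrimalOp::Input(_) => None,
            // d(l * r) = r * dl + l * dr; a missing tangent contributes nothing.
            PrimalOp::Mul(l, r) => {
                let left = tangent[*l].map(|dl| emit(&mut lin, LinOp::Mul { fixed: *r, active: dl }));
                let right = tangent[*r].map(|dr| emit(&mut lin, LinOp::Mul { fixed: *l, active: dr }));
                match (left, right) {
                    (Some(a), Some(b)) => Some(emit(&mut lin, LinOp::Add(a, b))),
                    (a, b) => a.or(b),
                }
            }
            // d exp(x) = exp(x) * dx, referencing this node's own primal output.
            PrimalOp::Exp(x) => tangent[*x].map(|dx| emit(&mut lin, LinOp::Mul { fixed: i, active: dx })),
        };
        tangent.push(t);
    }
    (lin, tangent)
}

fn eval_primal(primal: &[PrimalOp], inputs: &HashMap<&str, f64>) -> Vec<f64> {
    let mut v: Vec<f64> = Vec::with_capacity(primal.len());
    for op in primal {
        let y = match op {
            PrimalOp::Input(name) => inputs[*name],
            PrimalOp::Mul(l, r) => v[*l] * v[*r],
            PrimalOp::Exp(x) => v[*x].exp(),
        };
        v.push(y);
    }
    v
}

fn eval_linear(lin: &[LinOp], primal_vals: &[f64], seeds: &HashMap<String, f64>) -> Vec<f64> {
    let mut t: Vec<f64> = Vec::with_capacity(lin.len());
    for op in lin {
        let y = match op {
            LinOp::TangentInput(name) => seeds[name],
            LinOp::Mul { fixed, active } => primal_vals[*fixed] * t[*active],
            LinOp::Add(a, b) => t[*a] + t[*b],
        };
        t.push(y);
    }
    t
}

fn main() {
    // f(x) = exp(a * x): p0 = x, p1 = a, p2 = Mul(p0, p1), p3 = Exp(p2)
    let primal = [PrimalOp::Input("x"), PrimalOp::Input("a"), PrimalOp::Mul(0, 1), PrimalOp::Exp(2)];
    let (lin, tangent) = linearize(&primal, &["x"]);
    println!("{lin:?}");
    let vals = eval_primal(&primal, &HashMap::from([("x", 0.5), ("a", 2.0)]));
    let dy = eval_linear(&lin, &vals, &HashMap::from([("t_x".to_string(), 1.0)]))[tangent[3].unwrap()];
    println!("dy = {dy}, exp(a*x) * a * dx = {}", vals[3] * 2.0);
}
```

The emitted nodes mirror t0, t1, and t2 of the scalar example in Section IV. Linearizing Mul(x, x) instead emits two Mul nodes and one Add.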
There is no physical merge in this step.
Linear nodes use the primal primitive set
The linear graph uses the same primitive set as the primal graph. There is no dedicated Scale primitive.
Examples:
Mul(a, dx) role=Linearized { active_mask=[fixed, active] }
Add(dx, dy) role=Linearized { active_mask=[active, active] }
Exp(x) role=Primary
That last line matters: Exp itself is not a linear node. The linearization of Exp(x) emits linear nodes such as:
Mul(External(exp(x)), dx) role=Linearized { active_mask=[fixed, active] }
The design rule is:
- linearization may reference primal inputs or outputs as fixed operands
- linearization must stay linear in tangent inputs
- active-vs-fixed information is recorded in OperationRole::Linearized
linear_transpose
linear_transpose consumes a linear graph and produces another linear graph with active inputs and outputs reversed.
fn linear_transpose<Op: Primitive>(
linear: &LinearizedGraph<Op>,
) -> LinearizedGraph<Op>;
It does not linearize again. It reuses the same local linear rules with direction reversed.
tidu::linear_transpose is generic. It traverses the linear graph in reverse topological order and, for each op node, calls Op::transpose_rule to obtain the local transposed contribution. tidu does not know which primitives exist; it only requires that every op in the linear graph implements Primitive::transpose_rule.
Transpose accumulation must use global identity. When multiple reverse contributions flow back to the same original tangent node, bucket by the global key of that tangent value, not by a graph-local id.
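The accumulation rule can be sketched as follows, under simplifying assumptions:
- the toy LinOp and LinGraph types are hypothetical;
- string keys stand in for global ValueKeys;
- the per-op transpose rules are hard-coded rather than dispatched through Op::transpose_rule.

Nodes are visited in reverse topological order. Each reverse contribution is bucketed by the global key of the tangent it flows to. A second contribution to the same bucket emits an Add node, in the same spirit as tidu accumulating with Op::add().

```rust
use std::collections::HashMap;

// Toy linear ops; operands are global keys (strings stand in for ValueKey).
#[derive(Clone, Debug, PartialEq)]
enum LinOp {
    Input,
    Mul { fixed: String, active: String }, // active_mask = [fixed, active]
    Add(String, String),                   // active_mask = [active, active]
}

// (global key, op) pairs in topological order.
type LinGraph = Vec<(String, LinOp)>;

struct Transposer {
    out: LinGraph,
    ct: HashMap<String, String>, // tangent key -> key of its accumulated cotangent
    fresh: usize,
}

impl Transposer {
    fn emit(&mut self, op: LinOp) -> String {
        let key = format!("c{}", self.fresh);
        self.fresh += 1;
        self.out.push((key.clone(), op));
        key
    }

    // Bucket by the global key of the target tangent; repeated hits emit Add.
    fn accumulate(&mut self, target: &str, contribution: String) {
        let merged = match self.ct.get(target).cloned() {
            Some(previous) => self.emit(LinOp::Add(previous, contribution)),
            None => contribution,
        };
        self.ct.insert(target.to_string(), merged);
    }
}

// Transposes `lin`, seeding `output` with a fresh cotangent input named `ct_output`.
// Returns the transposed graph and the cotangent key for each reached tangent.
fn linear_transpose(lin: &LinGraph, output: &str, ct_output: &str) -> (LinGraph, HashMap<String, String>) {
    let mut t = Transposer {
        out: vec![(ct_output.to_string(), LinOp::Input)],
        ct: HashMap::from([(output.to_string(), ct_output.to_string())]),
        fresh: 0,
    };
    for (key, op) in lin.iter().rev() {
        // No bucket means a zero cotangent: nothing to propagate.
        let Some(ct_y) = t.ct.get(key).cloned() else { continue };
        match op {
            LinOp::Input => {}
            // y = f * a  =>  ct_a += f * ct_y
            LinOp::Mul { fixed, active } => {
                let c = t.emit(LinOp::Mul { fixed: fixed.clone(), active: ct_y });
                t.accumulate(active, c);
            }
            // y = a + b  =>  ct_a += ct_y, ct_b += ct_y
            LinOp::Add(a, b) => {
                t.accumulate(a, ct_y.clone());
                t.accumulate(b, ct_y);
            }
        }
    }
    (t.out, t.ct)
}

fn main() {
    // x + x: both Add operands are the same global tangent key t_x.
    let lin: LinGraph = vec![
        ("t_x".into(), LinOp::Input),
        ("dy".into(), LinOp::Add("t_x".into(), "t_x".into())),
    ];
    let (transposed, ct) = linear_transpose(&lin, "dy", "ct_y");
    println!("{transposed:?}");
    println!("cotangent of t_x lives in {}", ct["t_x"]);
}
```

For x + x, both contributions land in the t_x bucket, so the result is c0 = Add(ct_y, ct_y). This is the behavior golden test 1 checks.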
materialize_merge
materialize_merge is the physical graph-building step.
fn materialize_merge<Op: GraphOperation>(
view: &ResolvedView<Op>,
outputs: &[ValueKey<Op>],
) -> MaterializedGraph<Op>;
This step:
- walks the resolved logical DAG from the requested outputs
- collects reachable definitions
- deduplicates by ValueKey/OperationKey
- computes one physical DAG
- prepares the input to compile
This is where “merge” actually belongs.
The terminology should therefore be:
resolve = make external references traceable
materialize_merge = flatten graphs into one concrete graph
compile and eval
compile consumes a MaterializedGraph, not a graph.
let view = resolve(vec![graph_a, graph_b, graph_c]);
let materialized = materialize_merge(&view, &requested_outputs);
let prog = compile(&materialized);
let values = prog.run(&runtime_inputs); // eval
This separation is deliberate:
- transforms stay graph-based
- compile stays backend-oriented
- flattening and CSE happen once, late
IV. Scalar Example: f(x) = exp(a * x)
Step 1: build the primal graph F0
p0 = Input(x)
p1 = Input(a)
p2 = Mul(p0, p1)
p3 = Exp(p2)
Key identities:
key(p0) = Input(x)
key(p1) = Input(a)
key(p2) = Derived { op=Mul(Input(x), Input(a)), output_slot=0 }
key(p3) = Derived { op=Exp(key(p2)), output_slot=0 }
Step 2: linearize the resolved primal view
L1 = linearize(resolve([F0]), outputs=[key(p3)], wrt=[x])
One possible linear graph:
t0 = Input(t_x) // new tangent input key
t1 = Mul(External(key(p1)), Local(t0)) role=Linearized { active_mask=[fixed, active] }
t2 = Mul(External(key(p3)), Local(t1)) role=Linearized { active_mask=[fixed, active] }
Important facts:
- t0, t1, and t2 are local to L1
- key(p1) and key(p3) are external references into F0
- no physical merge has happened
Step 3: resolve for higher-order tracing
If we want a second derivative, we resolve the combined logical view:
R1 = resolve([F0, L1])
Now linearize can trace the output of L1 through key(p3) and then further through the primal chain back to x.
This is the critical distinction:
R1 is enough for higher-order AD.
No physical merge is required yet.
Step 4: linear_transpose the linear graph
T1 = linear_transpose(L1)
One possible transposed graph:
c0 = Input(ct_y)
c1 = Mul(External(key(p3)), Local(c0)) role=Linearized { active_mask=[fixed, active] }
c2 = Mul(External(key(p1)), Local(c1)) role=Linearized { active_mask=[fixed, active] }
This computes the cotangent with respect to x.
Step 5: materialize only when compiling
view = resolve([F0, T1])
graph = materialize_merge(view, [key(p3), key(c2)])
prog = compile(graph)
Conceptually, the materialized graph contains:
Input(x)
Input(a)
Mul(x, a)
Exp(a*x)
Input(ct_y)
Mul(exp(a*x), ct_y)
Mul(a, exp(a*x) * ct_y)
Resulting formulas
y = exp(a*x)
dy = exp(a*x) * a * dx
ct_x = a * exp(a*x) * ct_y
Higher order
Second directional derivative uses resolve, then linearize again:
L2 = linearize(resolve([F0, L1]), outputs=[key(t2)], wrt=[x])
Again, no physical merge is required before this step.
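The closed forms in this section can be sanity-checked numerically. This includes the second directional derivative a^2 exp(a*x) dx dx2 that L2 should produce. The sketch checks only the formulas, not the graph transforms. It uses central finite differences and the transpose identity <ct_y, dy> = <ct_x, dx>. The sample point, step size, and helper names are arbitrary illustrative choices.

```rust
// Closed forms from the scalar example f(x) = exp(a * x).
fn f(x: f64, a: f64) -> f64 {
    (a * x).exp()
}

// dy = exp(a*x) * a * dx (the linear graph L1)
fn jvp(x: f64, a: f64, dx: f64) -> f64 {
    (a * x).exp() * a * dx
}

// ct_x = a * exp(a*x) * ct_y (the transposed graph T1)
fn vjp(x: f64, a: f64, ct_y: f64) -> f64 {
    a * ((a * x).exp() * ct_y)
}

// Second directional derivative expected from linearizing L1 again.
fn second(x: f64, a: f64, dx: f64, dx2: f64) -> f64 {
    a * a * (a * x).exp() * dx * dx2
}

fn central_diff(g: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (g(x + h) - g(x - h)) / (2.0 * h)
}

fn main() {
    let (x, a, h) = (0.3, 1.7, 1e-5);
    println!("dy  = {:.8}, fd = {:.8}", jvp(x, a, 1.0), central_diff(|x| f(x, a), x, h));
    println!("d2y = {:.8}, fd = {:.8}", second(x, a, 1.0, 1.0), central_diff(|x| jvp(x, a, 1.0), x, h));
    // Transpose check: <ct_y, dy> == <ct_x, dx>
    let (dx, ct_y) = (0.8, -1.3);
    println!("<ct_y, dy> = {:.12}, <ct_x, dx> = {:.12}", ct_y * jvp(x, a, dx), vjp(x, a, ct_y) * dx);
}
```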
V. Vector Examples
The vector examples remain mathematically identical to the earlier version. What changes is only the graph interpretation: graphs stay separate until materialize_merge.
For readability, Sum below is shorthand for ReduceSum over all axes.
Vector example 1: elementwise y = exp(a * x) with x, a in R^2
Primal graph:
u0 = Input(x:[2])
u1 = Input(a:[2])
u2 = Mul(u0, u1)
u3 = Exp(u2)
Linear graph from linearize(resolve([F0]), outputs=[key(u3)], wrt=[x]):
u4 = Input(t_x:[2])
u5 = Mul(External(key(u1)), Local(u4)) role=Linearized { active_mask=[fixed, active] }
u6 = Mul(External(key(u3)), Local(u5)) role=Linearized { active_mask=[fixed, active] }
Transposed graph:
u7 = Input(ct_y:[2])
u8 = Mul(External(key(u3)), Local(u7)) role=Linearized { active_mask=[fixed, active] }
u9 = Mul(External(key(u1)), Local(u8)) role=Linearized { active_mask=[fixed, active] }
Resulting formulas:
y = [exp(a0*x0), exp(a1*x1)]
dy = [exp(a0*x0) * a0 * dx0,
exp(a1*x1) * a1 * dx1]
ct_x = [a0 * exp(a0*x0) * ct_y0,
a1 * exp(a1*x1) * ct_y1]
This stays purely elementwise. The JVP matches finite differences and the linear_transpose satisfies <ct_y, dy> = <ct_x, t_x>.
Vector example 2: reduction y = Sum(exp(a * x)) with x, a in R^2
Primal graph:
r0 = Input(x:[2])
r1 = Input(a:[2])
r2 = Mul(r0, r1)
r3 = Exp(r2)
r4 = Sum(r3)
Linear graph:
r5 = Input(t_x:[2])
r6 = Mul(External(key(r1)), Local(r5)) role=Linearized { active_mask=[fixed, active] }
r7 = Mul(External(key(r3)), Local(r6)) role=Linearized { active_mask=[fixed, active] }
r8 = Sum(Local(r7)) role=Linearized { active_mask=[active] }
Transposed graph:
r9 = Input(ct_y:[])
r10 = BroadcastInDim(Local(r9), shape=[2], dims=[]) role=Linearized { active_mask=[active] }
r11 = Mul(External(key(r3)), Local(r10)) role=Linearized { active_mask=[fixed, active] }
r12 = Mul(External(key(r1)), Local(r11)) role=Linearized { active_mask=[fixed, active] }
Resulting formulas:
y = exp(a0*x0) + exp(a1*x1)
dy = exp(a0*x0) * a0 * dx0 + exp(a1*x1) * a1 * dx1
ct_x = [a0 * exp(a0*x0) * ct_y,
a1 * exp(a1*x1) * ct_y]
This is the smallest vector example that makes reduction linear_transpose explicit without requiring eager merge.
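The key pair in this example is Sum and BroadcastInDim. Broadcasting is the transpose of summing over all axes, so <ct, Sum(v)> = <BroadcastInDim(ct), v> for every v. Below is a minimal numeric sketch of the linear graph (r5 to r8) and the transposed graph (r9 to r12). It is written as plain Rust functions rather than graph nodes, and the helper names and sample values are illustrative assumptions.

```rust
use std::array::from_fn;

// r8: Sum over all axes of a length-2 vector.
fn sum(v: [f64; 2]) -> f64 {
    v[0] + v[1]
}

// r10: BroadcastInDim(ct, shape=[2], dims=[]), the transpose of sum.
fn broadcast(s: f64) -> [f64; 2] {
    [s; 2]
}

fn dot(u: [f64; 2], v: [f64; 2]) -> f64 {
    u[0] * v[0] + u[1] * v[1]
}

// r6 = a * dx, r7 = exp(a*x) * r6, r8 = Sum(r7)
fn jvp(x: [f64; 2], a: [f64; 2], dx: [f64; 2]) -> f64 {
    sum(from_fn(|i| (a[i] * x[i]).exp() * (a[i] * dx[i])))
}

// r10 = broadcast(ct_y), r11 = exp(a*x) * r10, r12 = a * r11
fn vjp(x: [f64; 2], a: [f64; 2], ct_y: f64) -> [f64; 2] {
    let r10 = broadcast(ct_y);
    from_fn(|i| a[i] * ((a[i] * x[i]).exp() * r10[i]))
}

fn main() {
    let (x, a) = ([0.4, -0.2], [1.5, 0.7]);
    let (dx, ct_y) = ([0.3, -1.1], 2.0);
    // Whole example: <ct_y, dy> == <ct_x, dx>
    println!("{} vs {}", ct_y * jvp(x, a, dx), dot(vjp(x, a, ct_y), dx));
    // Reduction step alone: <c, Sum(v)> == <BroadcastInDim(c), v>
    let v = [3.0, -5.0];
    println!("{} vs {}", ct_y * sum(v), dot(broadcast(ct_y), v));
}
```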
A reproducible checker for these two examples is in scripts/vector_ad_examples_check.py.
VI. Primitive Set and Traits
Operand
Operand is the runtime value type (tensor-like; scalars are rank-0 tensors). Canonical signature in ../spec/primitive-catalog.md.
Primitive
Primitive extends GraphOperation with add() (cotangent accumulation), linearize, and transpose_rule. tidu is fully generic over this trait. Canonical signature in ../spec/ad-contract.md.
Linearization and linear_transpose rules
A primitive’s linearize must be linear in tangent inputs. It may:
- reference primal inputs or outputs through External(ValueKey)
- add primal primitives in OperationRole::Linearized
- add Conj when required by linear_transpose semantics
It must not introduce nonlinear dependence on tangent inputs.
Fan-out accumulation (when multiple cotangents flow to the same tangent node during linear_transpose) is handled internally by tidu::linear_transpose, not by an explicit Dup primitive. tidu buckets reverse contributions by ValueKey and accumulates them by emitting Op::add() nodes.
A primitive’s transpose_rule receives cotangent outputs and must produce cotangent inputs. It must only add primitives that themselves implement Primitive. The downstream implementor is responsible for ensuring that the set of primitives reachable through linearize and transpose_rule is closed.
Closure responsibility
tidu does not define or constrain the primitive set. It is fully generic over Op: Primitive. The only rule is:
linearize and transpose_rule must add only ops that themselves implement Primitive.
This ensures that tidu can apply linearize and linear_transpose to any graph without knowledge of the specific primitives involved. The concrete primitive set and its closure guarantees are entirely the downstream implementor’s responsibility (e.g. tenferro).
There is no dedicated Scale primitive in this design.
VII. Compilation and Execution
Pipeline
Graphs (primal / linear / transposed)
|
| resolve ← computegraph
v
Resolved logical DAG
|
| materialize_merge ← computegraph
v
MaterializedGraph
|
| compile (SSA) ← computegraph
v
CompiledProgram
|
| eval ← computegraph
| or lower to StableHLO backends ← tenferro
v
Runtime values
CompiledProgram
CompiledProgram is an SSA-form instruction sequence produced by computegraph::compile. Each slot is written exactly once.
struct CompiledProgram<Op: GraphOperation> {
instructions: Vec<Instruction<Op>>,
input_slots: Vec<usize>,
output_slots: Vec<usize>,
n_slots: usize,
}
struct Instruction<Op> {
operation: Op,
inputs: Vec<usize>,
outputs: Vec<usize>,
}
compile runs only after materialize_merge:
- one physical DAG exists
- same-key nodes are already unified
- slot assignment can be computed once
Compile once, eval many times. computegraph caches compiled programs keyed by graph structure to avoid recompilation.
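The SSA contract can be stated as an interpreter loop. The sketch below is a hypothetical, scalar-only stand-in for CompiledProgram::run; the toy Op enum and its two cases are assumptions, not the computegraph API. It hand-writes the program for exp(a * x), with x in slot 0 and a in slot 1. It panics if a slot is read before it is written or written twice. The same program value is then evaluated with two input sets, matching the compile-once, eval-many pattern.

```rust
// Hypothetical scalar ops standing in for a real GraphOperation.
#[derive(Clone, Debug)]
enum Op {
    Mul,
    Exp,
}

struct Instruction {
    operation: Op,
    inputs: Vec<usize>,
    outputs: Vec<usize>,
}

struct CompiledProgram {
    instructions: Vec<Instruction>,
    input_slots: Vec<usize>,
    output_slots: Vec<usize>,
    n_slots: usize,
}

impl CompiledProgram {
    fn run(&self, inputs: &[f64]) -> Vec<f64> {
        let mut slots: Vec<Option<f64>> = vec![None; self.n_slots];
        for (&slot, &value) in self.input_slots.iter().zip(inputs) {
            slots[slot] = Some(value);
        }
        for inst in &self.instructions {
            let args: Vec<f64> = inst
                .inputs
                .iter()
                .map(|&s| slots[s].expect("slot read before it was written"))
                .collect();
            let result = match inst.operation {
                Op::Mul => args[0] * args[1],
                Op::Exp => args[0].exp(),
            };
            let out = inst.outputs[0]; // both toy ops have a single output
            assert!(slots[out].is_none(), "SSA violation: slot {out} written twice");
            slots[out] = Some(result);
        }
        self.output_slots.iter().map(|&s| slots[s].expect("output never written")).collect()
    }
}

// exp(a * x) with x in slot 0 and a in slot 1.
fn exp_ax_program() -> CompiledProgram {
    CompiledProgram {
        instructions: vec![
            Instruction { operation: Op::Mul, inputs: vec![0, 1], outputs: vec![2] },
            Instruction { operation: Op::Exp, inputs: vec![2], outputs: vec![3] },
        ],
        input_slots: vec![0, 1],
        output_slots: vec![3],
        n_slots: 4,
    }
}

fn main() {
    let prog = exp_ax_program(); // compile once
    for inputs in [[0.5, 2.0], [1.0, -1.0]] {
        println!("{:?} -> {:?}", inputs, prog.run(&inputs)); // eval many times
    }
}
```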
Backend boundary
This document stops at MaterializedGraph -> CompiledProgram. Backend-specific details (execution lowering, GPU dispatch) remain in ../spec/backend-contract.md.
The important contract is:
- AD transforms (tidu) are graph-based and resolver-backed
- graph infrastructure (computegraph) is AD-agnostic
- backends only see the materialized or compiled result
VIII. Advantages of the Graph + Resolver Model
No eager merge between transforms
Higher-order AD requires traceability, not a physically merged graph. Delaying physical merge avoids repeated flattening and repeated global CSE.
Better fit for partial transforms
Cross-country evaluation and partial linear_transpose are more natural when transforms operate on graphs rather than on one giant graph.
Global identity is explicit
ValueKey gives one identity mechanism for:
- external refs
- accumulation buckets
- CSE
- logical reachability
Compile-time work is isolated
Only materialize_merge and compile need one concrete graph. Transform-time code can stay light and local.
Higher-order AD stays clean
The rule for higher order is simple:
resolve before the next linearize
materialize_merge before compile
IX. Golden Tests
Minimal tests that validate the graph-based transform procedure:
| # | Function | What it checks |
|---|---|---|
| 1 | x + x | linear_transpose accumulation buckets by global identity |
| 2 | x * y | binary linearization with distinct reverse sinks |
| 3 | c * z (complex) | Conj appears only in linear_transpose |
| 4 | x^2 | higher-order AD without eager physical merge |
| 5 | exp(a*x) | external refs, resolve-before-linearize, linear_transpose correctness |
| 6 | Sum(exp(a*x)) | reduction linear_transpose via BroadcastInDim |
| 7 | exp(a*x) 3rd order | repeated higher-order closure over graphs |
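Expected second-order result for x^2 with unit seeds (FoF = forward-over-forward, FoR = forward-over-reverse, RoF = reverse-over-forward, RoR = reverse-over-reverse):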
| Mode | Output |
|---|---|
| FoF | 2 |
| FoR | 2 |
| RoF | 2 |
| RoR | 2 |
Expected second-order result for exp(a*x) with unit seeds:
| Mode | Output |
|---|---|
| FoF | a^2 exp(ax) |
| FoR | a^2 exp(ax) |
| RoF | a^2 exp(ax) |
| RoR | a^2 exp(ax) |
X. Implementation Status
Phases 1–3 (scalar graph AD, tensor primitives, backend compilation) are implemented and tested. Current work focuses on:
- Logical-DAG-aware checkpoint scheduling
- Partial linear_transpose / cross-country mode
- Late materialization heuristics
- Operator fusion in compiled IR
XI. Superseded Issues
This document unifies and supersedes:
- tidu-rs#12: tape AD design
- tidu-rs#13: graph-based AD design
- tidu-rs#7: trait unification
- tidu-rs#8: DifferentiableOp trait
- tenferro-rs#616: Traced Tensor + StableHLO IR
- tenferro-rs#618: tenferro roadmap (AD portions)