tidu-rs Design
Repo: tidu-rs
Parent: ../index.md
Depends on: computegraph-rs
I. Purpose
tidu-rs provides AD-specific graph transforms (linearize, linear_transpose) that are fully generic over Op: Primitive. It owns no graph infrastructure (that belongs to computegraph-rs) and references no specific primitives.
The closest JAX analogue of linearize is jax.linearize. linearize traverses a primal computation and builds a new linear computation by calling each primitive’s local linearization rule. Its output is neither StableHLO nor a backend kernel plan; it is another graph composed of the same downstream primitive vocabulary.
II. Transforms
linearize
Consumes a resolved view and returns a new linear graph (JVP).
use computegraph::{ResolvedView, ValueKey, LocalValueId, Graph};
use tidu::Primitive;

// The `Op: Primitive` bound is needed so that `Op::InputKey` resolves.
struct LinearizedGraph<Op: Primitive> {
    graph: Graph<Op>,
    tangent_inputs: Vec<(Op::InputKey, LocalValueId)>,
    tangent_outputs: Vec<Option<LocalValueId>>,
}

fn linearize<Op: Primitive>(
    view: &ResolvedView<Op>,
    outputs: &[ValueKey<Op>],
    wrt: &[Op::InputKey],
) -> LinearizedGraph<Op>;

Each call to linearize receives a unique DiffPassId (a monotonically increasing counter). Tangent input keys are generated via wrt_key.tangent_of(pass_id) (see ADKey in the primitive AD contract).
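A minimal sketch of how a per-pass counter can keep tangent keys from colliding across repeated linearize calls. DiffPassId and tangent_of are named in this document, but their definitions are not shown here; the atomic counter, the next_pass_id helper, and the ToyKey enum below are assumptions made purely for illustration.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Hypothetical stand-in for the real DiffPassId.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DiffPassId(pub u64);

static NEXT_PASS: AtomicU64 = AtomicU64::new(0);

/// Hands out a fresh, strictly increasing id (assumed: one per linearize call).
pub fn next_pass_id() -> DiffPassId {
    DiffPassId(NEXT_PASS.fetch_add(1, Ordering::Relaxed))
}

/// Toy input key: a user-named input, or the tangent of another key
/// created during a particular differentiation pass.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ToyKey {
    User(String),
    Tangent { of: Box<ToyKey>, pass: DiffPassId },
}

impl ToyKey {
    /// Illustrative counterpart of ADKey::tangent_of.
    pub fn tangent_of(&self, pass: DiffPassId) -> ToyKey {
        ToyKey::Tangent { of: Box::new(self.clone()), pass }
    }
}
```

Because the pass id is part of the key, linearizing twice with respect to the same input yields two distinct tangent keys, which is what nested (higher-order) passes need.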
Algorithm:
- Traverse the reachable logical DAG in topological order.
- Seed tangent inputs for the requested primal InputKeys (keys generated via ADKey::tangent_of).
- For each reachable primitive, call Op::linearize.
- Emit new local linear nodes into the new graph.
- Reference primal values through External(ValueKey).
- Skip unreachable tangent flow with zero propagation.
This is the graph-level analogue of JAX building a jaxpr whose linearized body is itself a composition of primitives.
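The steps above can be sketched on a toy graph. This is not the tidu-rs implementation: Prim, Lin, Ref, and this linearize signature are hypothetical, node indices stand in for ValueKey, the per-op rules are inlined instead of going through Op::linearize, and reachability from the requested outputs is ignored (every node is visited).

```rust
/// Toy primal ops; node i may only reference nodes < i (topological order).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Prim { Input, Add(usize, usize), Mul(usize, usize) }

/// Operand of a linear node: a tangent in the new graph or a primal value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Ref { Local(usize), External(usize) }

/// Toy linear ops, reusing the same Add/Mul vocabulary as the primal graph.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Lin { Input, Add(Ref, Ref), Mul(Ref, Ref) }

/// Appends a node and returns its local id.
fn push(g: &mut Vec<Lin>, op: Lin) -> usize {
    g.push(op);
    g.len() - 1
}

/// Returns the linear graph and, per primal node, its tangent id (None = zero).
pub fn linearize(primal: &[Prim], wrt: usize) -> (Vec<Lin>, Vec<Option<usize>>) {
    let mut lin = Vec::new();
    let mut tan: Vec<Option<usize>> = vec![None; primal.len()];
    for (i, op) in primal.iter().enumerate() {
        tan[i] = match *op {
            // Seed a tangent input only for the requested primal input.
            Prim::Input => if i == wrt { Some(push(&mut lin, Lin::Input)) } else { None },
            // d(a + b) = da + db; a zero tangent is simply dropped.
            Prim::Add(a, b) => match (tan[a], tan[b]) {
                (None, None) => None,
                (Some(t), None) | (None, Some(t)) => Some(t),
                (Some(ta), Some(tb)) => {
                    Some(push(&mut lin, Lin::Add(Ref::Local(ta), Ref::Local(tb))))
                }
            },
            // d(a * b) = a * db + da * b; primal operands are External refs.
            Prim::Mul(a, b) => {
                let left = tan[b]
                    .map(|tb| push(&mut lin, Lin::Mul(Ref::External(a), Ref::Local(tb))));
                let right = tan[a]
                    .map(|ta| push(&mut lin, Lin::Mul(Ref::Local(ta), Ref::External(b))));
                match (left, right) {
                    (Some(l), Some(r)) => {
                        Some(push(&mut lin, Lin::Add(Ref::Local(l), Ref::Local(r))))
                    }
                    (l, r) => l.or(r),
                }
            }
        };
    }
    (lin, tan)
}
```

Calling linearize(&[Prim::Input, Prim::Add(0, 0), Prim::Mul(1, 0)], 0) on the graph of f(x) = (x+x)*x produces five linear nodes: one tangent input, one Add, two Mul nodes that each reference a primal value through External, and a final Add.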
linear_transpose
Consumes a linear graph and produces another with active inputs and outputs reversed.
fn linear_transpose<Op: Primitive>(
    linear: &LinearizedGraph<Op>,
) -> LinearizedGraph<Op>;

linear_transpose traverses the linear graph in reverse topological order and, for each op node, calls Op::transpose_rule to obtain the local transposed contribution.
Transpose accumulation must use global identity: when multiple reverse contributions flow back to the same original tangent node, bucket by the global key of that tangent value, not by a graph-local id.
Fan-out accumulation is handled internally by linear_transpose, not by an explicit Dup primitive. When multiple cotangents flow to the same ValueKey, linear_transpose accumulates them by emitting Op::add() nodes. This follows JAX, where cotangent accumulation (add_jaxvals) is built into the transposition pass rather than expressed as a separate primitive in the graph. Downstream primitive implementors do not need to implement Dup.
Transpose algorithm
use std::collections::HashMap;
use std::collections::hash_map::Entry::{Occupied, Vacant};

fn linear_transpose<Op: Primitive>(linear: &LinearizedGraph<Op>) -> LinearizedGraph<Op> {
    let mut builder = GraphBuilder::new();
    let mut ct_env: HashMap<ValueKey<Op>, LocalValueId> = HashMap::new();

    // 1. Seed cotangent outputs.
    //    cotangent_seeds: one (linear output key, cotangent input id) pair
    //    per output of `linear` (its construction is elided here).
    for (out_key, ct_input_id) in cotangent_seeds {
        ct_env.insert(out_key, ct_input_id);
    }

    // 2. Reverse topological traversal
    for op_node in linear.as_graph().operations().iter().rev() {
        // Look up the cotangent of each output of this op (None = zero)
        let ct_outs: Vec<Option<LocalValueId>> = op_node.outputs.iter()
            .map(|out_id| ct_env.get(&global_key(out_id)).copied())
            .collect();

        // Delegate to the per-op transpose_rule
        let ct_ins = op_node.operation.transpose_rule(
            &mut builder, &ct_outs, &op_node.inputs, &op_node.role,
        );

        // 3. Accumulate cotangents by ValueKey (global identity, not local id)
        for (input, ct_in) in op_node.inputs.iter().zip(ct_ins) {
            if let Some(ct) = ct_in {
                let key = global_key_of(input);
                match ct_env.entry(key) {
                    // First contribution: record it as-is
                    Vacant(e) => { e.insert(ct); }
                    Occupied(e) => {
                        // Fan-out: emit an Add node that sums old and new
                        let existing = *e.get();
                        let sum = builder.add_operation(
                            Op::add(),
                            vec![ValueRef::Local(existing), ValueRef::Local(ct)],
                            OperationRole::Linearized { active_mask: vec![true, true] },
                        );
                        *e.into_mut() = sum[0];
                    }
                }
            }
        }
    }

    // Build transposed LinearizedGraph from builder + ct_env
}

The accumulation Add nodes emitted during linear_transpose are normal graph nodes in the transposed graph. They carry OperationRole::Linearized { active_mask: [active, active] } and participate in subsequent AD transforms like any other node. This is why Primitive includes add(): tidu needs one generic way to construct those accumulation nodes.
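A self-contained toy version of the same accumulation scheme, for readers who want something that compiles on its own. It is a sketch under assumptions, not the tidu-rs code: Lin, Ref, and this transpose signature are hypothetical, linear-graph node indices stand in for global ValueKeys, and the per-op transpose rules are inlined rather than dispatched through Op::transpose_rule.

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;

/// Operand: a value in the current graph (Local) or a primal value (External).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Ref { Local(usize), External(usize) }

/// Toy linear ops; node i may only reference Local nodes < i.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Lin { Input, Add(Ref, Ref), Mul(Ref, Ref) }

/// Transposes `linear` with respect to node `output`. Returns the transposed
/// graph and ct_env, which maps each linear node to its cotangent id.
pub fn transpose(linear: &[Lin], output: usize) -> (Vec<Lin>, HashMap<usize, usize>) {
    let mut out = vec![Lin::Input]; // node 0 is the cotangent seed
    let mut ct_env: HashMap<usize, usize> = HashMap::new();
    ct_env.insert(output, 0);

    for (i, op) in linear.iter().enumerate().rev() {
        // A node with no recorded cotangent contributes nothing (zero).
        let ct = match ct_env.get(&i) {
            Some(&c) => c,
            None => continue,
        };
        // Local rule: list of (linear node receiving it, cotangent id).
        let mut contribs: Vec<(usize, usize)> = Vec::new();
        match *op {
            Lin::Input => {}
            // Add passes the cotangent to every active (Local) operand.
            Lin::Add(a, b) => {
                for r in [a, b] {
                    if let Ref::Local(k) = r {
                        contribs.push((k, ct));
                    }
                }
            }
            // Mul scales the cotangent by its fixed (External) operand.
            Lin::Mul(a, b) => match (a, b) {
                (Ref::External(p), Ref::Local(k)) | (Ref::Local(k), Ref::External(p)) => {
                    out.push(Lin::Mul(Ref::External(p), Ref::Local(ct)));
                    contribs.push((k, out.len() - 1));
                }
                _ => panic!("a linear Mul needs exactly one active operand"),
            },
        }
        // Accumulate by key; a second contribution emits an Add node.
        for (key, c) in contribs {
            match ct_env.entry(key) {
                Entry::Vacant(e) => {
                    e.insert(c);
                }
                Entry::Occupied(mut e) => {
                    out.push(Lin::Add(Ref::Local(*e.get()), Ref::Local(c)));
                    e.insert(out.len() - 1);
                }
            }
        }
    }
    (out, ct_env)
}
```

Keying ct_env by the identity of the original tangent value, rather than by ids local to the graph being built, is what lets two contributions to the same tangent meet in the Occupied branch.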
Worked example: linear_transpose of f(x) = (x+x)*x
Primal graph F0:
p0 = Input(x)
p1 = Add(p0, p0) // 2x
p2 = Mul(p1, p0) // 2x²
Linearize wrt x → L1:
t0 = Input(dx)
t1 = Add(t0, t0) Linear{[active, active]} // 2·dx
t2 = Mul(External(p1), Local(t0)) Linear{[fixed, active]} // 2x·dx
t3 = Mul(Local(t1), External(p0)) Linear{[active, fixed]} // 2·dx·x
t4 = Add(Local(t2), Local(t3)) Linear{[active, active]} // 4x·dx
Transpose L1, seed ct_y. ct_env state after each step:
seed: c0 = Input(ct_y)
    ct_env = { t4.key → c0 }
Reverse t4 = Add(t2, t3):
    Add transpose_rule → ct_t2 = c0, ct_t3 = c0
    ct_env = { t4.key → c0, t2.key → c0, t3.key → c0 }
Reverse t3 = Mul(t1, p0) [active, fixed]:
    Mul transpose_rule wrt the active input → ct_t1 = Mul(p0, c0)
    c1 = Mul(External(p0), Local(c0))
    ct_env = { ..., t1.key → c1 }
Reverse t2 = Mul(p1, t0) [fixed, active]:
    Mul transpose_rule wrt the active input → ct_t0 = Mul(p1, c0)
    c2 = Mul(External(p1), Local(c0))
    ct_env = { ..., t0.key → c2 } ← 1st entry for t0
Reverse t1 = Add(t0, t0) [active, active]:
    Add transpose_rule → both inputs get ct_t1 = c1
    Left input t0: ct_env[t0.key] = c2 (existing) → emit Add
        c3 = Add(c2, c1) ← accumulation #1
        ct_env[t0.key] = c3
    Right input t0: ct_env[t0.key] = c3 (existing) → emit Add
        c4 = Add(c3, c1) ← accumulation #2
        ct_env[t0.key] = c4
Transposed graph T1:
c0 = Input(ct_y)
c1 = Mul(External(p0), Local(c0)) // x · ct_y
c2 = Mul(External(p1), Local(c0)) // 2x · ct_y
c3 = Add(Local(c2), Local(c1)) // accumulation Add #1
c4 = Add(Local(c3), Local(c1)) // accumulation Add #2
output: c4 = 2x·ct_y + x·ct_y + x·ct_y = 4x·ct_y ✓ (f'=4x)
Note: c1 is referenced by both c3 and c4 — fan-out in the transposed graph itself. This is handled correctly by subsequent transforms (see next section).
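The result of the worked example can be checked numerically. The sketch below hand-translates F0 and T1 into straight-line Rust with one local variable per graph node; the function name t1 is hypothetical, and this is not how graphs are actually evaluated (that goes through compile and eval in computegraph-rs).

```rust
/// Evaluates the transposed graph T1 for f(x) = (x + x) * x.
/// Returns c4, the cotangent of x for a given output cotangent ct_y.
pub fn t1(x: f64, ct_y: f64) -> f64 {
    // Primal values referenced through External(..)
    let p0 = x;
    let p1 = p0 + p0; // 2x

    // Transposed graph
    let c0 = ct_y;
    let c1 = p0 * c0; // x * ct_y
    let c2 = p1 * c0; // 2x * ct_y
    let c3 = c2 + c1; // accumulation Add #1
    let c4 = c3 + c1; // accumulation Add #2
    c4
}
```

For any x, t1(x, 1.0) should equal f'(x) = 4x; for example, x = 3 gives c1 = 3, c2 = 6, c3 = 9, c4 = 12.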
III. Higher-Order AD and Accumulation Correctness
FoR (forward-over-reverse): linearize the transposed graph
The transposed graph T1 computes ct_x = 4x · ct_y as a function of (x, ct_y). To get the second derivative, apply forward mode on top of the reverse-mode result (FoR): call resolve([F0, T1]) and linearize the resolved view wrt x.
Primal tangents:
dp0 = dx2
dp1 = d(Add(p0, p0)) = Add(dx2, dx2) = 2·dx2
Tangent of each T1 node (dc0 = None because ct_y does not depend on x):
dc1 = d(Mul(p0, c0)): dp0 = dx2, dc0 = None
    → Mul(dx2, c0) = dx2 · ct_y
dc2 = d(Mul(p1, c0)): dp1 = 2·dx2, dc0 = None
    → Mul(Add(dx2, dx2), c0) = 2·dx2 · ct_y
dc3 = d(Add(c2, c1)): ← accumulation Add, linearized normally
    → Add(dc2, dc1) = 2·dx2·ct_y + dx2·ct_y = 3·dx2·ct_y
dc4 = d(Add(c3, c1)): ← accumulation Add, linearized normally
    → Add(dc3, dc1) = 3·dx2·ct_y + dx2·ct_y = 4·dx2·ct_y
Result: dc4 = 4·dx2·ct_y → f’’ = 4 ✓ (f=2x², f’=4x, f’’=4)
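The forward-over-reverse derivation above can also be checked numerically. This sketch hand-translates T1 and its tangent nodes into straight-line Rust; the function name t1_jvp is hypothetical and the code only mirrors the node-by-node derivation, not the tidu-rs API.

```rust
/// Evaluates T1 together with its linearization wrt x.
/// Returns (c4, dc4): the first-order cotangent and its tangent along dx2.
pub fn t1_jvp(x: f64, ct_y: f64, dx2: f64) -> (f64, f64) {
    // Primal values and their tangents
    let p0 = x;
    let p1 = p0 + p0;
    let dp0 = dx2;
    let dp1 = dp0 + dp0; // 2 * dx2

    // T1 nodes; c0 = ct_y does not depend on x, so it has no tangent
    let c0 = ct_y;
    let c1 = p0 * c0;
    let c2 = p1 * c0;
    let c3 = c2 + c1;
    let c4 = c3 + c1;

    // Tangent of each T1 node; accumulation Adds are linearized like any Add
    let dc1 = dp0 * c0;
    let dc2 = dp1 * c0;
    let dc3 = dc2 + dc1;
    let dc4 = dc3 + dc1;
    (c4, dc4)
}
```

With ct_y = 1 and dx2 = 1, the second component is 4 for every x, matching f’’ = 4.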
Why this is self-consistent
- Accumulation produces normal graph nodes. The Add nodes emitted during linear_transpose carry role=Linearized{[active, active]}. They have the same linearize and transpose_rule as any other Add node.
- Fan-out in transposed graphs is safe. c1 is used by both c3 and c4. In the forward direction (FoR), dc1 feeds into the linearization of both dc3 and dc4; this is just multiple references to the same tangent value, which is always correct in forward mode.
- Further linear_transpose (RoR) also works. If we linear_transpose the FoR graph, dc1 being used twice would cause two cotangents to flow to dc1’s key. The same HashMap accumulation mechanism handles this recursively.
- No special-casing at any level. The linearize and linear_transpose algorithms are uniform: linearize calls Op::linearize for each node, linear_transpose calls Op::transpose_rule and accumulates. The accumulation Add is indistinguishable from any other Add in the graph.
IV. Typical Pipelines
JVP:
build -> resolve -> linearize -> materialize_merge -> compile -> eval
VJP (grad):
build -> resolve -> linearize -> linear_transpose -> materialize_merge -> compile -> eval
2nd directional derivative (FoF):
build -> resolve -> linearize -> resolve -> linearize
-> materialize_merge -> compile -> eval
HVP (FoR = jvp(vjp(f))):
build -> resolve -> linearize -> linear_transpose -> resolve -> linearize
-> materialize_merge -> compile -> eval
n-th derivative:
build -> (resolve -> linearize) x n -> [linear_transpose] -> materialize_merge -> compile -> eval
resolve, materialize_merge, compile, eval are provided by computegraph-rs. tidu-rs only adds linearize and linear_transpose.
V. Linear Nodes
The linear graph uses the same primitive set as the primal graph. There is no dedicated Scale primitive.
Mul(a, dx) role=Linearized { active_mask=[fixed, active] }
Add(dx, dy) role=Linearized { active_mask=[active, active] }
Exp(x) role=Primary
The linearization of Exp(x) emits:
Mul(External(exp(x)), dx) role=Linearized { active_mask=[fixed, active] }
The active mask is part of a node's identity (OperationRole::Linearized vs Primary). Nodes that evaluate the same way but transpose differently must not alias.
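One way to picture this rule is to make the role part of whatever is hashed and compared when nodes are deduplicated. The sketch below is an assumption about shape only: Role mirrors the OperationRole variants named in this document, while NodeIdentity and distinct are hypothetical and do not describe how computegraph-rs actually keys its nodes.

```rust
use std::collections::HashSet;

/// Toy mirror of OperationRole; a `true` mask entry marks an active input.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Role {
    Primary,
    Linearized { active_mask: Vec<bool> },
}

/// Hypothetical identity of a node: op name, input ids, and role.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct NodeIdentity {
    pub op: &'static str,
    pub inputs: Vec<u32>,
    pub role: Role,
}

/// Counts how many nodes remain after deduplicating by identity.
pub fn distinct(nodes: &[NodeIdentity]) -> usize {
    nodes.iter().collect::<HashSet<_>>().len()
}
```

Under this scheme, Mul(a, b) with mask [fixed, active], Mul(a, b) with mask [active, fixed], and a Primary Mul(a, b) all compute the same product but remain three separate nodes, so each keeps its own transpose behavior.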
VI. Design Boundaries
tidu-rs owns:
- Primitive trait
- linearize (JVP transform)
- linear_transpose (reverse linear flow)
- LinearizedGraph data structure
tidu-rs does NOT own:
- graph infrastructure → computegraph-rs
- concrete primitives → downstream (tenferro-rs)