tidu-rs Design
Repo: tidu-rs Parent: ../index.md Depends on: computegraph-rs, chainrules-rs
I. Purpose
tidu-rs provides AD-specific graph transforms (differentiate, transpose) that are fully generic over Op: PrimitiveOp. It owns no graph infrastructure (that belongs to computegraph-rs) and references no specific primitives.
Among the JAX concepts, differentiate is the closest analogue to jax.linearize: it traverses a primal computation and builds a new linear computation by calling each primitive’s local linearization rule. The output is not StableHLO and not a backend kernel plan; it is another fragment composed of the same downstream primitive vocabulary.
II. Transforms
differentiate
Consumes a resolved view and returns a new linear fragment (JVP).
use computegraph::{ResolvedView, GlobalValKey, LocalValId, Fragment};
use chainrules::PrimitiveOp;
struct LinearFragment<Op> {
fragment: Fragment<Op>,
tangent_inputs: Vec<(Op::InputKey, LocalValId)>,
tangent_outputs: Vec<Option<LocalValId>>,
}
fn differentiate<Op: PrimitiveOp>(
view: &ResolvedView<Op>,
outputs: &[GlobalValKey<Op>],
wrt: &[Op::InputKey],
) -> LinearFragment<Op>;
Each call to differentiate receives a unique DiffPassId (a monotonically increasing counter). Tangent input keys are generated via wrt_key.tangent_of(pass_id) (see the ADKey trait in chainrules.md).
Algorithm:
- Traverse the reachable logical DAG in topological order.
- Seed tangent inputs for the requested primal InputKeys (keys generated via ADKey::tangent_of).
- For each reachable primitive, call Op::linearize.
- Emit new local linear nodes into the new fragment.
- Reference primal values through External(GlobalValKey).
- Skip unreachable tangent flow with zero propagation.
This is the fragment-level analogue of JAX building a jaxpr whose linearized body is itself a composition of primitives.
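The per-node linearization rules can be illustrated with a minimal dual-number sketch. This is not tidu's API (tidu traverses fragments and calls Op::linearize per node rather than carrying tangents alongside primals); the Dual type, add, and mul below are hypothetical stand-ins that show the same local rules the traversal applies.

```rust
// Minimal dual-number sketch of forward-mode linearization (illustrative only).
#[derive(Clone, Copy, Debug)]
struct Dual {
    primal: f64,
    tangent: f64,
}

// Local linearization rule for Add: tangents add.
fn add(a: Dual, b: Dual) -> Dual {
    Dual { primal: a.primal + b.primal, tangent: a.tangent + b.tangent }
}

// Local linearization rule for Mul: the product rule.
fn mul(a: Dual, b: Dual) -> Dual {
    Dual {
        primal: a.primal * b.primal,
        tangent: a.tangent * b.primal + a.primal * b.tangent,
    }
}

fn main() {
    // f(x) = (x + x) * x, seeded with tangent dx = 1 at x = 3.
    let x = Dual { primal: 3.0, tangent: 1.0 };
    let y = mul(add(x, x), x);
    assert_eq!(y.primal, 18.0);  // f(3)  = 2·3² = 18
    assert_eq!(y.tangent, 12.0); // f'(3) = 4·3  = 12
    println!("ok");
}
```

The same two local rules, applied node by node in topological order, are what differentiate emits as the tangent fragment in the worked example below.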
transpose
Consumes a linear fragment and produces another with active inputs and outputs reversed.
fn transpose<Op: PrimitiveOp>(
linear: &LinearFragment<Op>,
) -> LinearFragment<Op>;
Traverses the linear fragment in reverse topological order and, for each op node, calls Op::transpose_rule to obtain the local transposed contribution.
Transpose accumulation must use global identity: when multiple reverse contributions flow back to the same original tangent node, bucket by the global key of that tangent value, not by a fragment-local id.
Fan-out accumulation is handled internally by transpose, not by an explicit Dup primitive. When multiple cotangents flow to the same GlobalValKey, transpose accumulates them by emitting Op::add() nodes. This follows the JAX approach where add_jaxvals is built into the transpose pass rather than expressed as a separate primitive in the graph. Downstream primitive implementors do not need to implement Dup.
Transpose algorithm
fn transpose<Op: PrimitiveOp>(linear: &LinearFragment<Op>) -> LinearFragment<Op> {
let mut builder = FragmentBuilder::new();
let mut ct_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
// 1. Seed cotangent outputs
for (out_key, ct_input_id) in cotangent_seeds {
ct_env.insert(out_key, ct_input_id);
}
// 2. Reverse topological traversal
for op_node in linear.fragment.ops().iter().rev() {
// Look up cotangent for this op's outputs
let ct_outs: Vec<Option<LocalValId>> = op_node.outputs.iter()
.map(|out_id| ct_env.get(&global_key(out_id)).copied())
.collect();
// Delegate to per-op transpose rule
let ct_ins = op_node.op.transpose_rule(
&mut builder, &ct_outs, &op_node.inputs, &op_node.mode,
);
// 3. Accumulate cotangents by GlobalValKey
for (input, ct_in) in op_node.inputs.iter().zip(ct_ins) {
if let Some(ct) = ct_in {
let key = global_key_of(input);
match ct_env.entry(key) {
Vacant(e) => { e.insert(ct); }
Occupied(e) => {
// Fan-out: emit Add node for accumulation
let existing = *e.get();
let sum = builder.add_op(
Op::add(),
vec![ValRef::Local(existing), ValRef::Local(ct)],
OpMode::Linear { active_mask: vec![true, true] },
);
*e.into_mut() = sum[0];
}
}
}
}
}
// Build transposed LinearFragment from builder + ct_env
}
The accumulation Add nodes emitted during transpose are normal graph nodes in the transposed fragment. They carry OpMode::Linear { active_mask: [active, active] } and participate in subsequent AD transforms like any other node. This is why PrimitiveOp includes add(): tidu needs one generic way to construct those accumulation nodes.
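The bucketing step can be reduced to a numeric sketch. The string keys below are hypothetical stand-ins for GlobalValKey, and the sum stands in for the emitted Op::add() node; the point is only that fan-out contributions to the same key accumulate rather than overwrite.

```rust
use std::collections::HashMap;

// Numeric analogue of ct_env accumulation (illustrative; real transpose
// emits Op::add() nodes instead of summing scalars).
fn accumulate(contributions: &[(&'static str, f64)]) -> HashMap<&'static str, f64> {
    let mut ct_env: HashMap<&'static str, f64> = HashMap::new();
    for &(key, ct) in contributions {
        // Fan-out: a second contribution to the same global key
        // accumulates rather than overwriting.
        *ct_env.entry(key).or_insert(0.0) += ct;
    }
    ct_env
}

fn main() {
    // Two cotangent contributions flow back to the same tangent value "t0".
    let env = accumulate(&[("t2", 1.0), ("t3", 1.0), ("t0", 6.0), ("t0", 3.0)]);
    assert_eq!(env["t0"], 9.0); // bucketed by key, summed once per fan-out
    println!("ok");
}
```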
Worked example: transpose of f(x) = (x+x)*x
Primal fragment F0:
p0 = Input(x)
p1 = Add(p0, p0) // 2x
p2 = Mul(p1, p0) // 2x²
Linearize wrt x → L1:
t0 = Input(dx)
t1 = Add(t0, t0) Linear{[active, active]} // 2·dx
t2 = Mul(External(p1), Local(t0)) Linear{[fixed, active]} // 2x·dx
t3 = Mul(Local(t1), External(p0)) Linear{[active, fixed]} // 2·dx·x
t4 = Add(Local(t2), Local(t3)) Linear{[active, active]} // 4x·dx
Transpose L1, seed ct_y. ct_env state after each step:
seed: ct_env = { t4.key → c0 } c0 = Input(ct_y)
Reverse t4 = Add(t2, t3):
Add transpose → ct_t2 = c0, ct_t3 = c0
ct_env = { t4.key → c0, t2.key → c0, t3.key → c0 }
Reverse t3 = Mul(t1, p0) [active, fixed]:
Mul transpose wrt active → ct_t1 = Mul(p0, c0)
c1 = Mul(External(p0), Local(c0))
ct_env = { ..., t1.key → c1 }
Reverse t2 = Mul(p1, t0) [fixed, active]:
Mul transpose wrt active → ct_t0 = Mul(p1, c0)
c2 = Mul(External(p1), Local(c0))
ct_env = { ..., t0.key → c2 } ← 1st entry for t0
Reverse t1 = Add(t0, t0) [active, active]:
Add transpose → both inputs get ct_t1 = c1
Left input t0: ct_env[t0.key] = c2 (existing) → emit Add
c3 = Add(c2, c1) ← accumulation #1
ct_env[t0.key] = c3
Right input t0: ct_env[t0.key] = c3 (existing) → emit Add
c4 = Add(c3, c1) ← accumulation #2
ct_env[t0.key] = c4
Transposed fragment T1:
c0 = Input(ct_y)
c1 = Mul(External(p0), Local(c0)) // x · ct_y
c2 = Mul(External(p1), Local(c0)) // 2x · ct_y
c3 = Add(Local(c2), Local(c1)) // accumulation Add #1
c4 = Add(Local(c3), Local(c1)) // accumulation Add #2
output: c4 = 2x·ct_y + x·ct_y + x·ct_y = 4x·ct_y ✓ (f'=4x)
Note: c1 is referenced by both c3 and c4 — fan-out in the transposed fragment itself. This is handled correctly by subsequent transforms (see next section).
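The transposed fragment T1 can be checked by hand-evaluating it numerically. This sketch inlines the External references p0 and p1 from the primal fragment F0; the node-per-line comments map back to T1 above.

```rust
// Hand-evaluation of the transposed fragment T1 (illustrative check).
fn transposed_t1(x: f64, ct_y: f64) -> f64 {
    let p0 = x;        // primal Input(x)
    let p1 = p0 + p0;  // primal Add(p0, p0) = 2x
    let c0 = ct_y;     // Input(ct_y)
    let c1 = p0 * c0;  // Mul(External(p0), Local(c0)) = x·ct_y
    let c2 = p1 * c0;  // Mul(External(p1), Local(c0)) = 2x·ct_y
    let c3 = c2 + c1;  // accumulation Add #1
    c3 + c1            // c4 = accumulation Add #2 = 4x·ct_y
}

fn main() {
    let x = 3.0;
    assert_eq!(transposed_t1(x, 1.0), 4.0 * x); // matches f'(x) = 4x
    println!("ok");
}
```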
III. Higher-Order AD and Accumulation Correctness
FoR: differentiate the transposed fragment
The transposed fragment T1 computes ct_x = 4x · ct_y as a function of (x, ct_y). To get the second derivative (FoR), differentiate T1 wrt x via resolve([F0, T1]).
Primal tangents:
dp0 = dx2
dp1 = d(Add(p0, p0)) = Add(dx2, dx2) = 2·dx2
Tangent of each T1 node (dc0 = None because ct_y does not depend on x):
dc1 = d(Mul(p0, c0)): dp0 = dx2, dc0 = None
→ Mul(dx2, c0) = dx2 · ct_y
dc2 = d(Mul(p1, c0)): dp1 = 2·dx2, dc0 = None
→ Mul(Add(dx2, dx2), c0) = 2·dx2 · ct_y
dc3 = d(Add(c2, c1)): ← accumulation Add, linearized normally
→ Add(dc2, dc1) = 2·dx2·ct_y + dx2·ct_y = 3·dx2·ct_y
dc4 = d(Add(c3, c1)): ← accumulation Add, linearized normally
→ Add(dc3, dc1) = 3·dx2·ct_y + dx2·ct_y = 4·dx2·ct_y
Result: dc4 = 4·dx2·ct_y → f’’ = 4 ✓ (f=2x², f’=4x, f’’=4)
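The FoR tangent computation above can also be hand-evaluated. Since f'' is the constant 4, x itself drops out of the tangent values; the sketch follows the dc1..dc4 derivation line by line.

```rust
// Hand-evaluation of the FoR tangent fragment (illustrative check).
fn for_tangent(dx2: f64, ct_y: f64) -> f64 {
    let c0 = ct_y;        // ct_y does not depend on x, so dc0 = 0
    let dp0 = dx2;        // tangent of p0
    let dp1 = dx2 + dx2;  // tangent of p1 = Add(p0, p0)
    let dc1 = dp0 * c0;   // d(Mul(p0, c0)) = dp0·c0
    let dc2 = dp1 * c0;   // d(Mul(p1, c0)) = dp1·c0
    let dc3 = dc2 + dc1;  // accumulation Add, linearized normally
    dc3 + dc1             // dc4 = 4·dx2·ct_y
}

fn main() {
    assert_eq!(for_tangent(1.0, 1.0), 4.0); // f'' = 4 for f = 2x²
    println!("ok");
}
```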
Why this is self-consistent
- Accumulation produces normal graph nodes. The Add nodes emitted during transpose carry mode=Linear{[active, active]}. They have the same linearize and transpose_rule as any other Add node.
- Fan-out in transposed fragments is safe. c1 is used by both c3 and c4. In the forward direction (FoR), dc1 feeds into the linearize of both dc3 and dc4; this is just multiple references to the same tangent value, which is always correct in forward mode.
- Further transpose (RoR) also works. If we transpose the FoR fragment, dc1 being used twice would cause two cotangents to flow to dc1's key. The same HashMap accumulation mechanism handles this recursively.
- No special-casing at any level. The differentiate and transpose algorithms are uniform: differentiate calls Op::linearize for each node; transpose calls Op::transpose_rule and accumulates. The accumulation Add is indistinguishable from any other Add in the graph.
IV. Typical Pipelines
JVP:
build -> resolve -> differentiate -> materialize_merge -> compile -> eval
VJP (grad):
build -> resolve -> differentiate -> transpose -> materialize_merge -> compile -> eval
2nd directional derivative (FoF):
build -> resolve -> differentiate -> resolve -> differentiate
-> materialize_merge -> compile -> eval
HVP (FoR = jvp(vjp(f))):
build -> resolve -> differentiate -> transpose -> resolve -> differentiate
-> materialize_merge -> compile -> eval
n-th derivative:
build -> (resolve -> differentiate) x n -> [transpose] -> materialize_merge -> compile -> eval
resolve, materialize_merge, compile, eval are provided by computegraph-rs. tidu-rs only adds differentiate and transpose.
V. Linear Nodes
The linear graph uses the same primitive set as the primal graph. There is no dedicated Scale primitive.
Mul(a, dx) mode=Linear { active_mask=[fixed, active] }
Add(dx, dy) mode=Linear { active_mask=[active, active] }
Exp(x) mode=Primal
The linearization of Exp(x) emits:
Mul(External(exp(x)), dx) mode=Linear { active_mask=[fixed, active] }
Active mask is part of identity (OpMode::Linear vs Primal). Nodes that evaluate the same way but transpose differently must not alias.
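The emitted linear node for Exp can be checked numerically: Mul(External(exp(x)), dx) computes exp(x)·dx, which should agree with a finite-difference directional derivative. A minimal sketch:

```rust
// The linearization of Exp emits Mul(External(exp(x)), dx); check that rule
// against a central finite difference (illustrative only).
fn exp_jvp(x: f64, dx: f64) -> f64 {
    x.exp() * dx // the emitted Mul node, with exp(x) as the fixed operand
}

fn main() {
    let (x, dx, h) = (0.7_f64, 1.0, 1e-6);
    let fd = ((x + h).exp() - (x - h).exp()) / (2.0 * h);
    assert!((exp_jvp(x, dx) - fd).abs() < 1e-8);
    println!("ok");
}
```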
VI. Design Boundaries
tidu-rs owns:
- differentiate (JVP transform)
- transpose (reverse linear flow)
- LinearFragment data structure
tidu-rs does NOT own:
- graph infrastructure → computegraph-rs
- PrimitiveOp trait → chainrules-rs
- concrete primitives → downstream (tenferro-rs)