Primitive AD Contract
Repo: tidu-rs
Parent: ../index.md
Depends on: computegraph-rs
I. Purpose
tidu-rs defines the AD trait (Primitive) that extends computegraph::GraphOperation with a cotangent-accumulation constructor plus JVP and transpose rules. It contains no concrete primitives.
This is the counterpart of the AD behavior that JAX stores in `primitive_jvps` and `primitive_transposes`. The same kind of information is recorded, but the representation differs:
- JAX: global registries keyed by primitive object
- tenferro: methods on the concrete primitive type itself
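The difference between the two representations can be illustrated with a minimal sketch. All names below (`JvpRule`, `registry`, `HasJvp`, `Sin`) are hypothetical and belong to neither JAX nor tidu. The sketch only contrasts a runtime lookup table keyed by the primitive with a rule that is a method on the primitive's own type.

```rust
use std::collections::HashMap;

// Hypothetical illustration only: none of these names are real tidu or JAX APIs.
// A JVP rule here maps (primal, tangent) to the output tangent of a scalar op.
type JvpRule = fn(f64, f64) -> f64;

// Registry style (as in JAX): rules live in a table keyed by the primitive
// and are looked up at runtime.
fn registry() -> HashMap<&'static str, JvpRule> {
    let mut rules: HashMap<&'static str, JvpRule> = HashMap::new();
    rules.insert("sin", |x: f64, dx: f64| x.cos() * dx);
    rules
}

// Method style (as in tidu): the rule belongs to the primitive's own type,
// so the compiler checks that it exists wherever the trait is required.
trait HasJvp {
    fn jvp(&self, x: f64, dx: f64) -> f64;
}

struct Sin;

impl HasJvp for Sin {
    fn jvp(&self, x: f64, dx: f64) -> f64 {
        x.cos() * dx
    }
}

fn main() {
    let via_registry = registry()["sin"](0.5, 2.0);
    let via_method = Sin.jvp(0.5, 2.0);
    assert_eq!(via_registry, via_method);
}
```

With the method style, using a type that lacks the rule where the trait is required is a compile-time error. With the registry, a missing entry only shows up when the lookup fails at runtime.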
II. Primitive Trait
Primitive extends GraphOperation with `add()` (cotangent-accumulation constructor), `jvp_rule` (JVP rule), and `transpose_rule` (transpose rule, which `linear_transpose` uses to derive VJPs).
Canonical trait signature: ../spec/ad-contract.md.
add() returns the primitive used by tidu::linear_transpose when multiple cotangent contributions flow to the same ValueKey. This keeps fan-out accumulation inside the generic linear_transpose pass without requiring a separate built-in Dup or Add primitive in tidu.
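The fan-out accumulation described above can be sketched as follows. This is a minimal model and not the tidu or computegraph API. It assumes hypothetical `Op`, `Graph`, and `accumulate` types and uses plain `usize` indices as `ValueKey`. The forward program lets `x` fan out to `y1 = 2*x` and `y2 = 3*x`. When the second cotangent contribution for `x` arrives, the pass emits an `Op::add()` node instead of overwriting the first one.

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-ins; not the real tidu or computegraph API.
type ValueKey = usize;

#[derive(Clone, Copy, Debug, PartialEq)]
enum Op {
    Add,        // out = a + b
    Scale(f64), // out = c * a (linear)
}

impl Op {
    // Plays the role of Primitive::add(): the op used to sum cotangents.
    fn add() -> Op {
        Op::Add
    }
}

struct Node {
    op: Op,
    inputs: Vec<ValueKey>,
}

// Reverse graph: keys 0..n_inputs are its inputs; node i has key n_inputs + i.
struct Graph {
    n_inputs: usize,
    nodes: Vec<Node>,
}

impl Graph {
    fn push(&mut self, op: Op, inputs: Vec<ValueKey>) -> ValueKey {
        self.nodes.push(Node { op, inputs });
        self.n_inputs + self.nodes.len() - 1
    }

    fn eval(&self, inputs: &[f64]) -> Vec<f64> {
        let mut vals = inputs.to_vec();
        for n in &self.nodes {
            let v = match n.op {
                Op::Add => vals[n.inputs[0]] + vals[n.inputs[1]],
                Op::Scale(c) => c * vals[n.inputs[0]],
            };
            vals.push(v);
        }
        vals
    }
}

/// Record a cotangent contribution for forward value `target`. On fan-out
/// (a contribution already exists), emit an `Op::add()` node summing both.
fn accumulate(g: &mut Graph, ct: &mut HashMap<ValueKey, ValueKey>, target: ValueKey, contrib: ValueKey) {
    let merged = match ct.get(&target) {
        Some(&prev) => g.push(Op::add(), vec![prev, contrib]),
        None => contrib,
    };
    ct.insert(target, merged);
}

/// Forward linear program: x fans out to y1 = 2*x and y2 = 3*x.
/// Returns the cotangent of x for ct_y1 = ct_y2 = 1.
fn fan_out_example() -> f64 {
    const X: ValueKey = 0; // forward key of x
    let mut g = Graph { n_inputs: 2, nodes: Vec::new() }; // inputs: ct_y1, ct_y2
    let mut ct = HashMap::new();
    let c1 = g.push(Op::Scale(2.0), vec![0]); // transpose of y1 = 2*x
    accumulate(&mut g, &mut ct, X, c1);
    let c2 = g.push(Op::Scale(3.0), vec![1]); // transpose of y2 = 3*x
    accumulate(&mut g, &mut ct, X, c2);
    g.eval(&[1.0, 1.0])[ct[&X]]
}

fn main() {
    println!("ct_x = {}", fan_out_example()); // ct_x = 5
}
```

The accumulation logic only needs to know which op sums two cotangents. This is why asking each primitive set for `add()` is enough to keep fan-out handling generic.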
III. Linearization Rules
A primitive’s `jvp_rule` must emit a computation that is linear in the tangent inputs. It may:
- reference primal inputs or outputs through `External(ValueKey)`
- add primitives in `OperationRole::Linearized`
- add `Conj` when required by `linear_transpose` semantics
It must not introduce nonlinear dependence on tangent inputs.
The intended mental model is close to JAX `linearize`:
- JAX: `linearize` applies a primitive’s JVP rule and emits a new composition of JAX primitives in a jaxpr
- tidu: `Primitive::jvp_rule` emits a new composition of downstream concrete primitives into a graph
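The linearity requirement can be made concrete with a minimal sketch of a JVP rule for multiplication, `dz = dx * y + x * dy`. It assumes a hypothetical `Builder` that records emitted nodes, plus a `Src` enum whose `External` variant stands in for `External(ValueKey)`. None of this is the real tidu API. Every `Mul` has exactly one tangent-dependent operand, so the emitted computation is linear in `(dx, dy)`. The check at the end confirms that doubling the tangents doubles the output tangent.

```rust
// Hypothetical, simplified stand-ins; not the real tidu API.
// A real jvp_rule would emit graph nodes (in OperationRole::Linearized)
// into a computegraph graph; here a tiny builder records them instead.
#[derive(Clone, Copy, Debug)]
enum Src {
    Tangent(usize),  // i-th tangent input
    External(usize), // i-th primal value, read as a constant
    Node(usize),     // output of an earlier emitted node
}

#[derive(Clone, Copy, Debug)]
enum Op {
    Mul,
    Add,
}

struct Emitted {
    op: Op,
    args: [Src; 2],
}

#[derive(Default)]
struct Builder {
    nodes: Vec<Emitted>,
}

impl Builder {
    fn emit(&mut self, op: Op, a: Src, b: Src) -> Src {
        self.nodes.push(Emitted { op, args: [a, b] });
        Src::Node(self.nodes.len() - 1)
    }
}

/// JVP rule for z = x * y:  dz = dx * y + x * dy.
/// Each Mul has exactly one tangent-dependent operand (the other is External),
/// so the emitted computation is linear in (dx, dy).
fn mul_jvp_rule(b: &mut Builder) -> Src {
    let t1 = b.emit(Op::Mul, Src::Tangent(0), Src::External(1)); // dx * y
    let t2 = b.emit(Op::Mul, Src::External(0), Src::Tangent(1)); // x * dy
    b.emit(Op::Add, t1, t2)
}

/// Evaluate the emitted nodes for concrete primal and tangent values.
fn eval(b: &Builder, out: Src, primals: &[f64], tangents: &[f64]) -> f64 {
    let read = |s: Src, vals: &[f64]| match s {
        Src::Tangent(i) => tangents[i],
        Src::External(i) => primals[i],
        Src::Node(i) => vals[i],
    };
    let mut vals: Vec<f64> = Vec::new();
    for n in &b.nodes {
        let (a, c) = (read(n.args[0], &vals[..]), read(n.args[1], &vals[..]));
        vals.push(match n.op {
            Op::Mul => a * c,
            Op::Add => a + c,
        });
    }
    read(out, &vals[..])
}

fn main() {
    let mut b = Builder::default();
    let out = mul_jvp_rule(&mut b);
    let primals = [3.0, 4.0]; // x = 3, y = 4
    let dz = eval(&b, out, &primals, &[1.0, 2.0]);
    let dz_doubled = eval(&b, out, &primals, &[2.0, 4.0]);
    // Linearity: doubling the tangents doubles the output tangent.
    println!("dz = {dz}, doubled = {dz_doubled}"); // dz = 10, doubled = 20
}
```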
IV. Transpose Rules
A primitive’s transpose_rule receives cotangent outputs and produces cotangent inputs. It must only add primitives that themselves implement Primitive.
When `linear_transpose` encounters fan-out, tidu accumulates the reverse contributions by emitting `Primitive::add()` nodes. Consequently, every downstream primitive set used with tidu must provide an addition primitive suitable for cotangent accumulation.
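A minimal sketch of transpose rules for two linear primitives follows, checked against the defining property of a transpose, `<ct, L(t)> == <L^T(ct), t>`. The `LinOp` enum and the eager evaluation are assumptions made for brevity. A real `transpose_rule` emits graph nodes rather than returning numbers. As required above, each rule expresses its result using only ops from the same set.

```rust
// Hypothetical linear primitives over f64; not the real tidu API.
// A real transpose_rule emits graph nodes; this sketch evaluates eagerly.
#[derive(Clone, Copy, Debug)]
enum LinOp {
    /// out = c * t, where c is a primal constant (an External value).
    ScaleBy(f64),
    /// out = t0 + t1.
    Add,
}

impl LinOp {
    /// Apply the linear map to tangent inputs.
    fn apply(&self, t: &[f64]) -> f64 {
        match *self {
            LinOp::ScaleBy(c) => c * t[0],
            LinOp::Add => t[0] + t[1],
        }
    }

    /// Map the output cotangent to one cotangent per linear input,
    /// using only ops from the same set (LinOp).
    fn transpose_rule(&self, ct_out: f64) -> Vec<f64> {
        match *self {
            // Real case: the transpose of scaling by c is scaling by c.
            // Complex values may also need a Conj, depending on the
            // cotangent convention used by linear_transpose.
            LinOp::ScaleBy(c) => vec![LinOp::ScaleBy(c).apply(&[ct_out])],
            // Add passes the cotangent unchanged to both inputs.
            LinOp::Add => vec![ct_out, ct_out],
        }
    }
}

/// Defining property of a transpose: <ct, L(t)> == <L^T(ct), t>.
fn dot_product_test(op: LinOp, t: &[f64], ct: f64) -> bool {
    let lhs = ct * op.apply(t);
    let rhs: f64 = op.transpose_rule(ct).iter().zip(t).map(|(a, b)| a * b).sum();
    (lhs - rhs).abs() < 1e-12
}

fn main() {
    assert!(dot_product_test(LinOp::ScaleBy(3.0), &[2.0], 0.5));
    assert!(dot_product_test(LinOp::Add, &[1.0, 4.0], 2.0));
    println!("transpose rules pass the dot-product test");
}
```

The dot-product check is a cheap way for a downstream crate to test each transpose rule in isolation, independently of the fan-out handling done by `linear_transpose`.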
V. ADKey Trait
tidu-rs defines the ADKey trait that constrains InputKey for AD use. tidu-rs uses this trait to generate tangent input keys during linearize.
```rust
use std::fmt::Debug;
use std::hash::Hash;

pub type DiffPassId = u64;

pub trait ADKey: Clone + Debug + Hash + Eq + Send + Sync + 'static {
    /// Create a tangent input key derived from this key.
    /// `pass` is a unique identifier for the `linearize` call.
    fn tangent_of(&self, pass: DiffPassId) -> Self;
}
```

Primitive requires `Self::InputKey: ADKey` (see ../spec/ad-contract.md for the canonical Primitive trait signature).
The concrete implementation of ADKey is the downstream implementor’s choice. A typical pattern is a recursive enum:
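```rust
// tenferro-rs
#[derive(Clone, Debug, Hash, PartialEq, Eq)] // required by the ADKey bounds
enum TensorInputKey {
    /// A key supplied by the user.
    User(String),
    /// A tangent key derived from `of` during linearize pass `pass`.
    Tangent { of: Box<TensorInputKey>, pass: DiffPassId },
}
```

This gives debuggable keys like `Tangent { of: Tangent { of: User("x"), pass: 1 }, pass: 3 }` for higher-order AD.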
VI. Closure Responsibility
tidu-rs defines the contract but does not enforce closure. The downstream implementor (e.g. tenferro-rs) is responsible for ensuring that the set of primitives reachable through linearize and transpose_rule is closed — i.e., every emitted op also implements Primitive.
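Because tidu-rs does not check closure, a downstream crate might verify it on its own side. The following is a minimal sketch. It assumes a hypothetical `Prim` enum and a hypothetical `emitted_ops` table listing which ops each primitive's rules can emit, where `None` marks an op without AD rules. A worklist walk from the user-facing primitives then either returns the reachable set or reports the first op that lacks rules. `Add` is always seeded because `linear_transpose` emits `Primitive::add()` on fan-out. `Sign` is deliberately left without rules so that the failure case shows up.

```rust
use std::collections::{BTreeSet, VecDeque};

// Hypothetical downstream primitive set, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Prim {
    Mul,
    Add,
    Exp,
    Abs,
    Sign,
}

/// Ops that a primitive's jvp_rule or transpose_rule may emit, or None if
/// the primitive has no AD rules (it would not implement Primitive).
fn emitted_ops(p: Prim) -> Option<Vec<Prim>> {
    match p {
        Prim::Mul => Some(vec![Prim::Mul, Prim::Add]), // dz = dx*y + x*dy
        Prim::Add => Some(vec![Prim::Add]),
        Prim::Exp => Some(vec![Prim::Mul]), // dz = exp(x) * dx, exp(x) via External
        Prim::Abs => Some(vec![Prim::Sign, Prim::Mul]), // dz = sign(x) * dx
        Prim::Sign => None, // deliberately left without AD rules here
    }
}

/// Worklist search over everything reachable from `roots` through AD rules.
/// Returns the closed set, or the first reachable op that lacks rules.
fn check_closure(roots: &[Prim]) -> Result<BTreeSet<Prim>, Prim> {
    let mut seen = BTreeSet::new();
    // Add is always reachable: linear_transpose emits Primitive::add() on fan-out.
    let mut work: VecDeque<Prim> = roots.iter().copied().chain([Prim::Add]).collect();
    while let Some(p) = work.pop_front() {
        if !seen.insert(p) {
            continue;
        }
        work.extend(emitted_ops(p).ok_or(p)?);
    }
    Ok(seen)
}

fn main() {
    println!("{:?}", check_closure(&[Prim::Mul, Prim::Exp])); // Ok({Mul, Add, Exp})
    println!("{:?}", check_closure(&[Prim::Abs])); // Err(Sign)
}
```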
VII. Design Boundaries
tidu-rs owns:
- Primitive trait (`add` + `jvp_rule` + `transpose_rule`)
- AD transforms (`linearize`, `linear_transpose`)
tidu-rs does NOT own:
- graph infrastructure → computegraph-rs
- concrete primitives → downstream (tenferro-rs)