Eager Integration

tidu::eager is for downstream frontends that execute operations immediately and want a reverse-mode backward() workflow.

Recording

Use Recorder to record each eager graph invocation. A single primitive eager operation is represented as a one-operation RecordedGraph; a composite eager operation can record a larger primitive graph as one tape node. A Recorder is created from a KeySource, whose fresh_input_key method (fn fresh_input_key(&mut self) -> Op::InputKey) hands out fresh, collision-free input keys for the internal graphs the Recorder records.
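
As a rough illustration of the KeySource contract, the sketch below uses a counter to hand out keys. The trait here is a local stand-in so the example compiles on its own: tidu's real trait returns Op::InputKey and may carry bounds not shown here, and CounterKeys and the u64 key type are hypothetical.

```rust
/// Local stand-in for tidu's KeySource (assumption: the real trait
/// returns `Op::InputKey`; an associated type replaces it here).
pub trait KeySource {
    type InputKey;

    /// Returns a key that no earlier call on this source has returned.
    fn fresh_input_key(&mut self) -> Self::InputKey;
}

/// Hypothetical key source backed by a monotonically increasing counter.
#[derive(Debug, Default)]
pub struct CounterKeys {
    next: u64,
}

impl KeySource for CounterKeys {
    type InputKey = u64;

    fn fresh_input_key(&mut self) -> u64 {
        let key = self.next;
        // Advance the counter so later calls never repeat this key.
        self.next += 1;
        key
    }
}
```

A counter is the simplest way to get collision-free keys within one source; a frontend that already has its own key space would plug that in instead.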

Each input is described with EagerInput:

  • key is the user-visible value key used for cotangent accumulation.
  • trace points to the graph invocation that produced the value, or None for a leaf input.
  • requires_grad controls whether cotangents should flow through the value.
  • data stores concrete primal data for later replay.
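
The four fields above can be pictured with a small stand-in struct. This is only a sketch: the field types (u64 keys, f64 operands, an Arc-wrapped TraceNode) and the helpers leaf and produced_by are assumptions made for illustration, not tidu's definitions.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for a recorded graph invocation.
#[derive(Debug)]
pub struct TraceNode {
    pub name: &'static str,
}

/// Local mirror of the four EagerInput fields described above.
/// All field types are assumptions chosen to keep the sketch small.
#[derive(Debug, Clone)]
pub struct EagerInput {
    pub key: u64,
    pub trace: Option<Arc<TraceNode>>,
    pub requires_grad: bool,
    pub data: Arc<f64>,
}

/// A leaf input: supplied by the user, so no invocation produced it.
pub fn leaf(key: u64, value: f64, requires_grad: bool) -> EagerInput {
    EagerInput { key, trace: None, requires_grad, data: Arc::new(value) }
}

/// An intermediate input: the output of an earlier recorded invocation.
pub fn produced_by(
    key: u64,
    value: f64,
    node: Arc<TraceNode>,
    requires_grad: bool,
) -> EagerInput {
    EagerInput { key, trace: Some(node), requires_grad, data: Arc::new(value) }
}
```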

Recorder::record_graph(graph, inputs, outputs, retained_values) returns one EagerOutput per recorded graph output. retained_values is a HashMap<ValueKey<Op>, Arc<Op::Operand>> of concrete primal values, typically the invocation’s outputs, that are kept for the backward pass, where reverse mode reads them during replay.
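
Because the retained map holds Arc handles, keeping a value for the backward pass shares the existing allocation rather than copying it. The sketch below shows that with plain standard-library types; the ValueKey and Operand aliases and the retain_outputs helper are hypothetical.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Assumptions for the sketch: u64 keys and f64 operands.
type ValueKey = u64;
type Operand = f64;

/// Builds a retained-values map from an invocation's outputs.
/// `Arc::clone` only bumps a reference count; the operand is not copied.
pub fn retain_outputs(
    outputs: &[(ValueKey, Arc<Operand>)],
) -> HashMap<ValueKey, Arc<Operand>> {
    outputs
        .iter()
        .map(|(key, value)| (*key, Arc::clone(value)))
        .collect()
}
```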

Backward Execution

The downstream runtime implements BackwardExecutor. Its three methods are the hooks tidu calls during try_backward:

pub trait BackwardExecutor<Op: Primitive>
where
    Op::InputKey: ADKey,
{
    // Replay a primitive graph and return the concrete values transpose needs.
    fn execute_forward(
        &mut self,
        graph: PrimitiveGraph<'_, Op>,
        initial_data: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
    ) -> HashMap<ValueKey<Op>, Arc<Op::Operand>>;

    // Run a transposed linear graph with concrete cotangent seeds.
    fn run_transposed_linear(
        &mut self,
        linear: &LinearizedGraph<Op>,
        cotangent_out: &[Option<Arc<Op::Operand>>],
        external_data: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
        ctx: &mut Op::ADContext,
    ) -> ADRuleResult<Vec<Option<Arc<Op::Operand>>>>;

    // Sum two concrete cotangents where multiple paths meet.
    fn add_operands(&mut self, a: &Arc<Op::Operand>, b: &Arc<Op::Operand>) -> Arc<Op::Operand>;
}

execute_forward receives a PrimitiveGraph; call as_graph() to reach the underlying computegraph::Graph and iterate it (see Computegraph Integration). The downstream runtime still owns tensor allocation, gradient storage, device selection, shape metadata, and user-facing error reporting.
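
To make the division of labor concrete, here is a toy executor over scalars. It does not implement tidu's BackwardExecutor trait: ToyOp, ToyNode, and ScalarExecutor are hypothetical, the graph is a plain slice instead of a PrimitiveGraph, and run_transposed_linear is left out. It only illustrates what replaying a graph from initial data and summing two cotangents look like.

```rust
use std::collections::HashMap;
use std::sync::Arc;

type Key = u64;

/// Hypothetical primitive set; tidu's graphs are generic over `Op: Primitive`.
#[derive(Debug, Clone, Copy)]
pub enum ToyOp {
    Add,
    Mul,
}

/// One node of a toy graph; nodes are stored in topological order.
#[derive(Debug, Clone)]
pub struct ToyNode {
    pub op: ToyOp,
    pub inputs: [Key; 2],
    pub output: Key,
}

#[derive(Debug, Default)]
pub struct ScalarExecutor;

impl ScalarExecutor {
    /// Replays `graph` starting from `initial_data` and returns every
    /// value seen, inputs included. Panics if an input key is missing.
    pub fn execute_forward(
        &mut self,
        graph: &[ToyNode],
        initial_data: &HashMap<Key, Arc<f64>>,
    ) -> HashMap<Key, Arc<f64>> {
        let mut values = initial_data.clone();
        for node in graph {
            let a = *values[&node.inputs[0]];
            let b = *values[&node.inputs[1]];
            let out = match node.op {
                ToyOp::Add => a + b,
                ToyOp::Mul => a * b,
            };
            values.insert(node.output, Arc::new(out));
        }
        values
    }

    /// Sums two cotangents; for scalars this is plain addition.
    pub fn add_operands(&mut self, a: &Arc<f64>, b: &Arc<f64>) -> Arc<f64> {
        Arc::new(**a + **b)
    }
}
```

A real runtime would dispatch to its own tensor kernels in both methods; the scalar arithmetic here only stands in for that.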

The Backward Call Sequence

The sequence below runs from recording an invocation to producing cotangents.

sequenceDiagram
  participant DS as Downstream frontend
  participant R as Recorder
  participant TB as tidu::eager::try_backward
  participant EX as BackwardExecutor
  DS->>R: record_graph(graph, inputs, outputs, retained_values)
  R-->>DS: Vec of EagerOutput (key, trace, requires_grad, output_slot)
  Note over DS: forward continues, each output trace links to its inputs
  DS->>TB: try_backward(output_key, output_trace, seed, executor, ctx)
  loop per trace node, reverse topological order
    TB->>EX: execute_forward(PrimitiveGraph, initial_data)
    EX-->>TB: concrete primal values
    TB->>EX: run_transposed_linear(LinearizedGraph, cotangent_out, external_data, ctx)
    EX-->>TB: input cotangents
    opt value reached by more than one path
      TB->>EX: add_operands(a, b)
      EX-->>TB: accumulated cotangent
    end
  end
  TB-->>DS: cotangents (HashMap of ValueKey to Arc of Operand)

  1. The frontend records each invocation with Recorder::record_graph, which returns one EagerOutput per output. Each EagerOutput carries key, trace, requires_grad, and output_slot (the output’s position within the recorded graph invocation).
  2. To start reverse mode, the frontend calls eager::try_backward(output_key, output_trace, seed, executor, ctx).
  3. try_backward walks the trace in reverse topological order. For each node it linearizes and transposes the recorded graph, driving the downstream BackwardExecutor: execute_forward replays the primitive graph to recover concrete primal values, run_transposed_linear runs the transposed linear map with the incoming cotangents, and add_operands accumulates cotangents where multiple paths meet.
  4. The result is a HashMap<ValueKey<Op>, Arc<Op::Operand>> of cotangents for the inputs that required gradients.
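
The reverse walk in steps 3 and 4 can be sketched on a scalar tape. This is not tidu's implementation: ToyOp, TapeNode, and toy_backward are hypothetical, primal values are passed in directly instead of being replayed through execute_forward, the transpose rules are hard-coded per primitive, and the returned map also contains cotangents for intermediate values.

```rust
use std::collections::HashMap;

type Key = u64;

/// Hypothetical primitive set for the sketch.
#[derive(Debug, Clone, Copy)]
pub enum ToyOp {
    Add,
    Mul,
}

/// One recorded operation; the tape is stored in topological order.
#[derive(Debug, Clone)]
pub struct TapeNode {
    pub op: ToyOp,
    pub inputs: [Key; 2],
    pub output: Key,
}

/// Walks `tape` in reverse topological order, starting from `seed` at
/// `output_key`, and returns the cotangent of every value it reaches.
pub fn toy_backward(
    tape: &[TapeNode],
    primals: &HashMap<Key, f64>,
    output_key: Key,
    seed: f64,
) -> HashMap<Key, f64> {
    let mut cotangents: HashMap<Key, f64> = HashMap::new();
    cotangents.insert(output_key, seed);

    for node in tape.iter().rev() {
        // Nodes whose output received no cotangent contribute nothing.
        let ct_out = match cotangents.get(&node.output) {
            Some(&ct) => ct,
            None => continue,
        };
        let a = primals[&node.inputs[0]];
        let b = primals[&node.inputs[1]];
        // Transposed linearization of each primitive at (a, b).
        let ct_in = match node.op {
            ToyOp::Add => [ct_out, ct_out],
            ToyOp::Mul => [ct_out * b, ct_out * a],
        };
        for (key, ct) in node.inputs.iter().zip(ct_in.iter()) {
            // Sum contributions where paths meet; this is the role
            // add_operands plays for a real executor.
            *cotangents.entry(*key).or_insert(0.0) += *ct;
        }
    }
    cotangents
}
```

For f(x, y) = x * y + x, the input x feeds both the multiply and the add, so its cotangent is the sum of two contributions, y and 1.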