Skip to main content

tidu/
eager_record.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use chainrules::{ADKey, PrimitiveOp};
5use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, OpMode};
6
7use crate::{GradEdge, GradNode};
8
9/// Eager frontend input descriptor for generic AD recording.
10///
11/// Downstream frontends execute the primal operation themselves, then pass one
12/// `EagerValue` per concrete input to [`record_eager_op`].
13///
14/// # Examples
15///
16/// ```ignore
17/// let input = tidu::EagerValue {
18///     key: tensor.key.clone(),
19///     node: tensor.grad_node.clone(),
20///     requires_grad: tensor.requires_grad,
21///     data: tensor.data.clone(),
22/// };
23/// ```
24pub struct EagerValue<Op: GraphOp> {
25    /// User-visible eager value key used for cotangent accumulation.
26    pub key: GlobalValKey<Op>,
27    /// Grad node that produced this value, or `None` for leaves.
28    pub node: Option<Arc<GradNode<Op>>>,
29    /// Whether this value participates in reverse-mode propagation.
30    pub requires_grad: bool,
31    /// Concrete primal data for saved forward replay.
32    pub data: Arc<Op::Operand>,
33}
34
35/// Per-output trace metadata returned by [`record_eager_op`].
36///
37/// The downstream eager value type should embed these fields next to its own
38/// concrete output data and gradient slot.
39///
40/// # Examples
41///
42/// ```ignore
43/// let traces = tidu::record_eager_op(&mut keys, op, &inputs, &outputs);
44/// let result = MyEagerValue::new(outputs[0].clone(), traces[0].key.clone(), traces[0].node.clone());
45/// ```
46pub struct EagerOutput<Op: GraphOp> {
47    /// User-visible eager output key.
48    pub key: GlobalValKey<Op>,
49    /// Shared grad node for all outputs when any input requires gradients.
50    pub node: Option<Arc<GradNode<Op>>>,
51    /// Whether this output should be tracked by the downstream frontend.
52    pub requires_grad: bool,
53    /// Output slot within the recorded primitive.
54    pub output_slot: usize,
55}
56
57/// Caller-provided source of stable eager value keys.
58///
59/// `tidu` wraps each fresh input key in `GlobalValKey::Input`. This guarantees
60/// the aliases recorded in [`GradNode::primal_in_keys`] satisfy the current
61/// single-op backward replay model.
62///
63/// # Examples
64///
65/// ```ignore
66/// impl tidu::EagerKeySource<MyOp> for MyKeySource {
67///     fn fresh_input_key(&mut self) -> MyInputKey {
68///         self.next_key()
69///     }
70/// }
71/// ```
72pub trait EagerKeySource<Op: GraphOp> {
73    /// Return a fresh input key that has not been used for another eager value.
74    fn fresh_input_key(&mut self) -> Op::InputKey;
75}
76
77/// Record a concrete eager primitive execution for reverse-mode AD.
78///
79/// The downstream frontend is responsible for executing `op` and passing its
80/// concrete `outputs`. `tidu` allocates stable input aliases and output keys,
81/// builds saved forward data, constructs input edges, and returns per-output
82/// metadata. Multi-output operations share one `GradNode`; each output receives
83/// its own key, and `backward_dag` seeds the matching output slot by key.
84///
85/// # Examples
86///
87/// ```ignore
88/// let output_data = Arc::new(execute_primal(&op, &input_data));
89/// let traces = tidu::record_eager_op(
90///     &mut key_source,
91///     op,
92///     &input_traces,
93///     &[output_data.clone()],
94/// );
95/// ```
96pub fn record_eager_op<Op: PrimitiveOp>(
97    key_source: &mut impl EagerKeySource<Op>,
98    op: Op,
99    inputs: &[EagerValue<Op>],
100    outputs: &[Arc<Op::Operand>],
101) -> Vec<EagerOutput<Op>>
102where
103    Op::InputKey: ADKey,
104{
105    assert_eq!(
106        inputs.len(),
107        op.n_inputs(),
108        "record_eager_op for {:?} expected {} inputs, got {}",
109        op,
110        op.n_inputs(),
111        inputs.len()
112    );
113    assert_eq!(
114        outputs.len(),
115        op.n_outputs(),
116        "record_eager_op for {:?} expected {} outputs, got {}",
117        op,
118        op.n_outputs(),
119        outputs.len()
120    );
121    assert!(
122        outputs.len() <= u8::MAX as usize + 1,
123        "record_eager_op for {:?} has too many outputs for GlobalValKey: {}",
124        op,
125        outputs.len()
126    );
127
128    let input_aliases = fresh_value_keys(key_source, inputs.len());
129    let output_keys = fresh_value_keys(key_source, outputs.len());
130    let requires_grad = inputs.iter().any(|input| input.requires_grad);
131
132    let node = requires_grad.then(|| {
133        Arc::new(GradNode::new(
134            op.clone(),
135            input_aliases.clone(),
136            output_keys.clone(),
137            saved_forward_values(&op, &input_aliases, inputs, outputs),
138            inputs
139                .iter()
140                .map(|input| {
141                    GradEdge::new(input.node.clone(), input.key.clone(), input.requires_grad)
142                })
143                .collect(),
144        ))
145    });
146
147    output_keys
148        .into_iter()
149        .enumerate()
150        .map(|(output_slot, key)| EagerOutput {
151            key,
152            node: node.clone(),
153            requires_grad,
154            output_slot,
155        })
156        .collect()
157}
158
159/// Construct the derived key used to save a replayed primal output value.
160///
161/// # Examples
162///
163/// ```ignore
164/// let key = tidu::derived_output_key(&op, &input_aliases, 0);
165/// ```
166pub fn derived_output_key<Op: GraphOp>(
167    op: &Op,
168    input_aliases: &[GlobalValKey<Op>],
169    output_slot: usize,
170) -> GlobalValKey<Op> {
171    assert!(
172        output_slot <= u8::MAX as usize,
173        "output slot {} is too large for GlobalValKey",
174        output_slot
175    );
176
177    GlobalValKey::Derived {
178        op: GlobalOpKey {
179            primitive: op.clone(),
180            inputs: input_aliases.to_vec(),
181            mode: OpMode::Primal,
182        },
183        output_slot: output_slot as u8,
184    }
185}
186
187/// Build saved forward data for one eager op.
188///
189/// Inputs are saved under stable input aliases. Outputs are saved under the
190/// derived keys produced by replaying `op` with those aliases.
191///
192/// # Examples
193///
194/// ```ignore
195/// let saved = tidu::saved_forward_values(&op, &input_aliases, &inputs, &outputs);
196/// ```
197pub fn saved_forward_values<Op: GraphOp>(
198    op: &Op,
199    input_aliases: &[GlobalValKey<Op>],
200    inputs: &[EagerValue<Op>],
201    outputs: &[Arc<Op::Operand>],
202) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>> {
203    assert_eq!(
204        input_aliases.len(),
205        op.n_inputs(),
206        "saved_forward_values for {:?} expected {} input aliases, got {}",
207        op,
208        op.n_inputs(),
209        input_aliases.len()
210    );
211    assert_eq!(
212        inputs.len(),
213        op.n_inputs(),
214        "saved_forward_values for {:?} expected {} inputs, got {}",
215        op,
216        op.n_inputs(),
217        inputs.len()
218    );
219    assert_eq!(
220        outputs.len(),
221        op.n_outputs(),
222        "saved_forward_values for {:?} expected {} outputs, got {}",
223        op,
224        op.n_outputs(),
225        outputs.len()
226    );
227    assert!(
228        input_aliases
229            .iter()
230            .all(|key| matches!(key, GlobalValKey::Input(_))),
231        "saved_forward_values for {:?} requires GlobalValKey::Input aliases",
232        op
233    );
234    assert!(
235        outputs.len() <= u8::MAX as usize + 1,
236        "saved_forward_values for {:?} has too many outputs for GlobalValKey: {}",
237        op,
238        outputs.len()
239    );
240    assert_eq!(
241        input_aliases.len(),
242        inputs.len(),
243        "saved_forward_values for {:?} expected one alias per input, got {} aliases for {} inputs",
244        op,
245        input_aliases.len(),
246        inputs.len()
247    );
248
249    let mut saved = HashMap::with_capacity(inputs.len() + outputs.len());
250    for (key, input) in input_aliases.iter().zip(inputs.iter()) {
251        saved.insert(key.clone(), input.data.clone());
252    }
253    for (slot, output) in outputs.iter().enumerate() {
254        saved.insert(derived_output_key(op, input_aliases, slot), output.clone());
255    }
256    saved
257}
258
259fn fresh_value_keys<Op: GraphOp>(
260    key_source: &mut impl EagerKeySource<Op>,
261    count: usize,
262) -> Vec<GlobalValKey<Op>> {
263    (0..count)
264        .map(|_| GlobalValKey::Input(key_source.fresh_input_key()))
265        .collect()
266}