Skip to main content

tidu/eager/
record.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::{ADKey, ADRuleResult, Primitive};
5use computegraph::graph::{Graph, GraphBuilder};
6use computegraph::resolve::resolve;
7use computegraph::{GraphOperation, OperationRole, ValueKey, ValueRef};
8
9use crate::LinearizedGraph;
10
11use super::trace::{Trace, TraceEdge, TraceNode};
12
13/// Graph invocation recorded as one eager reverse-mode trace node.
14///
15/// # Examples
16///
17/// ```
18/// use tidu::eager::RecordedGraph;
19/// use computegraph::GraphOperation;
20///
21/// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
22/// enum Op { Add }
23///
24/// impl GraphOperation for Op {
25///     type Operand = f64;
26///     type Context = ();
27///     type InputKey = &'static str;
28///
29///     fn input_count(&self) -> usize { 2 }
30///     fn output_count(&self) -> usize { 1 }
31/// }
32///
33/// let recorded = RecordedGraph::from_primitive(Op::Add, vec!["x", "y"]);
34/// assert_eq!(recorded.input_keys(), &["x", "y"]);
35/// assert_eq!(recorded.output_keys().len(), 1);
36/// ```
37pub struct RecordedGraph<Op: GraphOperation> {
38    graph: Arc<Graph<Op>>,
39    input_keys: Vec<Op::InputKey>,
40    output_keys: Vec<ValueKey<Op>>,
41}
42
43impl<Op: GraphOperation> RecordedGraph<Op> {
44    /// Create a recorded graph from an already-built graph and aligned keys.
45    ///
46    /// # Examples
47    ///
48    /// ```
49    /// use std::sync::Arc;
50    /// use computegraph::graph::GraphBuilder;
51    /// use computegraph::{GraphOperation, OperationRole, ValueRef};
52    /// use tidu::eager::RecordedGraph;
53    ///
54    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
55    /// enum Op { Id }
56    ///
57    /// impl GraphOperation for Op {
58    ///     type Operand = f64;
59    ///     type Context = ();
60    ///     type InputKey = &'static str;
61    ///
62    ///     fn input_count(&self) -> usize { 1 }
63    ///     fn output_count(&self) -> usize { 1 }
64    /// }
65    ///
66    /// let mut builder = GraphBuilder::new();
67    /// let x = builder.add_input("x");
68    /// let y = builder.add_operation(Op::Id, vec![ValueRef::Local(x)], OperationRole::Primary);
69    /// builder.set_outputs(y.clone());
70    /// let graph = Arc::new(builder.build());
71    /// let output_keys = y.iter().map(|id| graph.values()[*id].key.clone()).collect();
72    /// let recorded = RecordedGraph::new(graph, vec!["x"], output_keys);
73    ///
74    /// assert_eq!(recorded.input_keys(), &["x"]);
75    /// ```
76    pub fn new(
77        graph: Arc<Graph<Op>>,
78        input_keys: Vec<Op::InputKey>,
79        output_keys: Vec<ValueKey<Op>>,
80    ) -> Self {
81        assert_eq!(
82            graph.inputs().len(),
83            input_keys.len(),
84            "RecordedGraph expected {} input keys, got {}",
85            graph.inputs().len(),
86            input_keys.len()
87        );
88        assert_eq!(
89            graph.outputs().len(),
90            output_keys.len(),
91            "RecordedGraph expected {} output keys, got {}",
92            graph.outputs().len(),
93            output_keys.len()
94        );
95        for (&input_id, input_key) in graph.inputs().iter().zip(input_keys.iter()) {
96            assert_eq!(
97                &graph.values()[input_id].key,
98                &ValueKey::Input(input_key.clone()),
99                "RecordedGraph input key order must match graph inputs"
100            );
101        }
102        for (&output_id, output_key) in graph.outputs().iter().zip(output_keys.iter()) {
103            assert_eq!(
104                &graph.values()[output_id].key,
105                output_key,
106                "RecordedGraph output key order must match graph outputs"
107            );
108        }
109
110        Self {
111            graph,
112            input_keys,
113            output_keys,
114        }
115    }
116
117    /// Build a one-op recorded graph for an eager primitive invocation.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// use tidu::eager::RecordedGraph;
123    /// use computegraph::GraphOperation;
124    ///
125    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
126    /// enum Op { Add }
127    ///
128    /// impl GraphOperation for Op {
129    ///     type Operand = f64;
130    ///     type Context = ();
131    ///     type InputKey = &'static str;
132    ///
133    ///     fn input_count(&self) -> usize { 2 }
134    ///     fn output_count(&self) -> usize { 1 }
135    /// }
136    ///
137    /// let recorded = RecordedGraph::from_primitive(Op::Add, vec!["x", "y"]);
138    /// assert_eq!(recorded.as_graph().operations().len(), 1);
139    /// ```
140    pub fn from_primitive(op: Op, input_keys: Vec<Op::InputKey>) -> Self {
141        let mut builder = GraphBuilder::new();
142        let input_ids: Vec<_> = input_keys
143            .iter()
144            .cloned()
145            .map(|key| builder.add_input(key))
146            .collect();
147        let output_ids = builder.add_operation(
148            op,
149            input_ids
150                .iter()
151                .map(|local_id| ValueRef::Local(*local_id))
152                .collect(),
153            OperationRole::Primary,
154        );
155        builder.set_outputs(output_ids.clone());
156        let graph = Arc::new(builder.build());
157        let output_keys = output_ids
158            .iter()
159            .map(|output_id| graph.values()[*output_id].key.clone())
160            .collect();
161        Self::new(graph, input_keys, output_keys)
162    }
163
164    /// Borrow the recorded graph.
165    pub fn as_graph(&self) -> &Graph<Op> {
166        &self.graph
167    }
168
169    /// Graph input keys aligned with eager input edges.
170    pub fn input_keys(&self) -> &[Op::InputKey] {
171        &self.input_keys
172    }
173
174    /// Graph output keys aligned with eager output slots.
175    pub fn output_keys(&self) -> &[ValueKey<Op>] {
176        &self.output_keys
177    }
178}
179
180impl<Op: Primitive> RecordedGraph<Op>
181where
182    Op::InputKey: ADKey,
183{
184    pub(crate) fn try_linearize(
185        &self,
186        output_slots: &[usize],
187        ctx: &mut Op::ADContext,
188    ) -> ADRuleResult<LinearizedGraph<Op>> {
189        let selected_outputs: Vec<_> = output_slots
190            .iter()
191            .map(|&slot| {
192                self.output_keys.get(slot).cloned().unwrap_or_else(|| {
193                    panic!(
194                        "RecordedGraph got output slot {}, but graph has {} outputs",
195                        slot,
196                        self.output_keys.len()
197                    )
198                })
199            })
200            .collect();
201        let view = resolve(vec![Arc::clone(&self.graph)]);
202        let aliases = HashMap::new();
203        crate::try_linearize(&view, &selected_outputs, &self.input_keys, 0, ctx, &aliases)
204    }
205}
206
207/// Input descriptor for recording one eager graph invocation.
208pub struct EagerInput<Op: GraphOperation> {
209    /// User-visible eager value key used for cotangent accumulation.
210    pub key: ValueKey<Op>,
211    /// Trace node that produced this value, or `None` for leaves.
212    pub trace: Option<Trace<Op>>,
213    /// Whether this value participates in reverse-mode propagation.
214    pub requires_grad: bool,
215    /// Concrete primal data for saved forward replay.
216    pub data: Arc<Op::Operand>,
217}
218
219/// Per-output trace metadata returned by [`Recorder::record_graph`].
220pub struct EagerOutput<Op: GraphOperation> {
221    /// User-visible eager output key.
222    pub key: ValueKey<Op>,
223    /// Shared trace node for all outputs when any input requires gradients.
224    pub trace: Option<Trace<Op>>,
225    /// Whether this output should be tracked by the downstream frontend.
226    pub requires_grad: bool,
227    /// Output slot within the recorded graph invocation.
228    pub output_slot: usize,
229}
230
231/// Caller-provided source of stable eager value keys.
232pub trait KeySource<Op: GraphOperation> {
233    /// Return a fresh input key that has not been used for another eager value.
234    fn fresh_input_key(&mut self) -> Op::InputKey;
235}
236
/// Records eager operations; it is stateful because it owns the key source
/// from which fresh keys are drawn.
pub struct Recorder<K> {
    /// Source of fresh keys, supplied at construction time.
    key_source: K,
}
241
242impl<K> Recorder<K> {
243    /// Create a recorder from a downstream key source.
244    pub fn new(key_source: K) -> Self {
245        Self { key_source }
246    }
247
248    /// Borrow the underlying key source.
249    pub fn key_source_mut(&mut self) -> &mut K {
250        &mut self.key_source
251    }
252
253    /// Return the underlying key source.
254    pub fn into_key_source(self) -> K {
255        self.key_source
256    }
257
258    /// Return fresh graph input keys for one eager graph invocation.
259    ///
260    /// # Examples
261    ///
262    /// ```
263    /// use tidu::eager::{KeySource, Recorder};
264    /// use computegraph::GraphOperation;
265    ///
266    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
267    /// enum Op { Id }
268    ///
269    /// impl GraphOperation for Op {
270    ///     type Operand = f64;
271    ///     type Context = ();
272    ///     type InputKey = usize;
273    ///
274    ///     fn input_count(&self) -> usize { 1 }
275    ///     fn output_count(&self) -> usize { 1 }
276    /// }
277    ///
278    /// struct Keys(usize);
279    ///
280    /// impl KeySource<Op> for Keys {
281    ///     fn fresh_input_key(&mut self) -> usize {
282    ///         let key = self.0;
283    ///         self.0 += 1;
284    ///         key
285    ///     }
286    /// }
287    ///
288    /// let mut recorder = Recorder::new(Keys(0));
289    /// assert_eq!(recorder.fresh_input_keys::<Op>(2), vec![0, 1]);
290    /// ```
291    pub fn fresh_input_keys<Op>(&mut self, count: usize) -> Vec<Op::InputKey>
292    where
293        Op: GraphOperation,
294        K: KeySource<Op>,
295    {
296        (0..count)
297            .map(|_| self.key_source.fresh_input_key())
298            .collect()
299    }
300
301    /// Record a concrete eager graph invocation for reverse-mode AD.
302    ///
303    /// # Examples
304    ///
305    /// ```
306    /// use std::collections::HashMap;
307    /// use std::sync::Arc;
308    /// use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueKey};
309    /// use tidu::{
310    ///     ADKey, DiffPassId, Primitive, PrimitiveBuilder, PrimitiveValue,
311    /// };
312    /// use tidu::eager::{EagerInput, KeySource, RecordedGraph, Recorder};
313    ///
314    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
315    /// enum Key {
316    ///     User(&'static str),
317    ///     Generated(usize),
318    ///     Tangent(Box<Key>, DiffPassId),
319    /// }
320    ///
321    /// impl ADKey for Key {
322    ///     fn tangent_of(&self, pass: DiffPassId) -> Self {
323    ///         Self::Tangent(Box::new(self.clone()), pass)
324    ///     }
325    /// }
326    ///
327    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
328    /// enum Op { Id }
329    ///
330    /// impl GraphOperation for Op {
331    ///     type Operand = f64;
332    ///     type Context = ();
333    ///     type InputKey = Key;
334    ///
335    ///     fn input_count(&self) -> usize { 1 }
336    ///     fn output_count(&self) -> usize { 1 }
337    /// }
338    ///
339    /// impl Primitive for Op {
340    ///     type ADContext = ();
341    ///     fn add() -> Self { Self::Id }
342    ///
343    ///     fn jvp_rule(
344    ///         &self,
345    ///         _builder: &mut impl PrimitiveBuilder<Self>,
346    ///         _primal_in: &[ValueKey<Self>],
347    ///         _primal_out: &[ValueKey<Self>],
348    ///         tangent_in: &[Option<LocalValueId>],
349    ///         _ctx: &mut (),
350    ///     ) -> Vec<Option<LocalValueId>> {
351    ///         vec![tangent_in[0]]
352    ///     }
353    ///
354    ///     fn transpose_rule(
355    ///         &self,
356    ///         _builder: &mut impl PrimitiveBuilder<Self>,
357    ///         cotangent_out: &[Option<LocalValueId>],
358    ///         _inputs: &[PrimitiveValue<Self>],
359    ///         _role: &OperationRole,
360    ///         _ctx: &mut (),
361    ///     ) -> Vec<Option<LocalValueId>> {
362    ///         vec![cotangent_out[0]]
363    ///     }
364    /// }
365    ///
366    /// struct Keys(usize);
367    /// impl KeySource<Op> for Keys {
368    ///     fn fresh_input_key(&mut self) -> Key {
369    ///         let key = Key::Generated(self.0);
370    ///         self.0 += 1;
371    ///         key
372    ///     }
373    /// }
374    ///
375    /// let mut recorder = Recorder::new(Keys(0));
376    /// let graph_inputs = recorder.fresh_input_keys::<Op>(1);
377    /// let graph = RecordedGraph::from_primitive(Op::Id, graph_inputs);
378    /// let input = EagerInput {
379    ///     key: ValueKey::Input(Key::User("x")),
380    ///     trace: None,
381    ///     requires_grad: true,
382    ///     data: Arc::new(2.0),
383    /// };
384    ///
385    /// let outputs = recorder.record_graph(
386    ///     graph,
387    ///     &[input],
388    ///     &[Arc::new(2.0)],
389    ///     HashMap::new(),
390    /// );
391    /// assert!(outputs[0].trace.is_some());
392    /// ```
393    pub fn record_graph<Op>(
394        &mut self,
395        graph: RecordedGraph<Op>,
396        inputs: &[EagerInput<Op>],
397        outputs: &[Arc<Op::Operand>],
398        retained_values: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
399    ) -> Vec<EagerOutput<Op>>
400    where
401        Op: Primitive,
402        Op::InputKey: ADKey,
403        K: KeySource<Op>,
404    {
405        assert_eq!(
406            inputs.len(),
407            graph.input_keys().len(),
408            "Recorder::record_graph expected {} inputs, got {}",
409            graph.input_keys().len(),
410            inputs.len()
411        );
412        assert_eq!(
413            outputs.len(),
414            graph.output_keys().len(),
415            "Recorder::record_graph expected {} outputs, got {}",
416            graph.output_keys().len(),
417            outputs.len()
418        );
419        assert!(
420            outputs.len() <= u8::MAX as usize + 1,
421            "Recorder::record_graph has too many outputs for ValueKey: {}",
422            outputs.len()
423        );
424
425        let output_keys = fresh_value_keys(&mut self.key_source, outputs.len());
426        let requires_grad = inputs.iter().any(|input| input.requires_grad);
427
428        let trace = requires_grad.then(|| {
429            let saved_data = saved_graph_values(&graph, inputs, &retained_values);
430            Trace::new(Arc::new(TraceNode::new(
431                graph,
432                output_keys.clone(),
433                saved_data,
434                inputs
435                    .iter()
436                    .map(|input| {
437                        TraceEdge::new(
438                            input.trace.as_ref().map(|trace| trace.node().clone()),
439                            input.key.clone(),
440                            input.requires_grad,
441                        )
442                    })
443                    .collect(),
444            )))
445        });
446
447        output_keys
448            .into_iter()
449            .enumerate()
450            .map(|(output_slot, key)| EagerOutput {
451                key,
452                trace: trace.clone(),
453                requires_grad,
454                output_slot,
455            })
456            .collect()
457    }
458}
459
460fn saved_graph_values<Op: GraphOperation>(
461    graph: &RecordedGraph<Op>,
462    inputs: &[EagerInput<Op>],
463    retained_values: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
464) -> HashMap<ValueKey<Op>, Arc<Op::Operand>> {
465    let mut saved = HashMap::with_capacity(inputs.len() + retained_values.len());
466    for (input_key, input) in graph.input_keys().iter().zip(inputs.iter()) {
467        saved.insert(ValueKey::Input(input_key.clone()), input.data.clone());
468    }
469    for (key, value) in retained_values {
470        saved.insert(key.clone(), value.clone());
471    }
472    saved
473}
474
475fn fresh_value_keys<Op: GraphOperation>(
476    key_source: &mut impl KeySource<Op>,
477    count: usize,
478) -> Vec<ValueKey<Op>> {
479    (0..count)
480        .map(|_| ValueKey::Input(key_source.fresh_input_key()))
481        .collect()
482}