// tidu/eager/record.rs
1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::{ADKey, ADRuleResult, Primitive};
5use computegraph::graph::{Graph, GraphBuilder};
6use computegraph::resolve::resolve;
7use computegraph::{GraphOperation, OperationRole, ValueKey, ValueRef};
8
9use crate::LinearizedGraph;
10
11use super::trace::{Trace, TraceEdge, TraceNode};
12
13/// Graph invocation recorded as one eager reverse-mode trace node.
14///
15/// # Examples
16///
17/// ```
18/// use tidu::eager::RecordedGraph;
19/// use computegraph::GraphOperation;
20///
21/// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
22/// enum Op { Add }
23///
24/// impl GraphOperation for Op {
25/// type Operand = f64;
26/// type Context = ();
27/// type InputKey = &'static str;
28///
29/// fn input_count(&self) -> usize { 2 }
30/// fn output_count(&self) -> usize { 1 }
31/// }
32///
33/// let recorded = RecordedGraph::from_primitive(Op::Add, vec!["x", "y"]);
34/// assert_eq!(recorded.input_keys(), &["x", "y"]);
35/// assert_eq!(recorded.output_keys().len(), 1);
36/// ```
37pub struct RecordedGraph<Op: GraphOperation> {
38 graph: Arc<Graph<Op>>,
39 input_keys: Vec<Op::InputKey>,
40 output_keys: Vec<ValueKey<Op>>,
41}
42
43impl<Op: GraphOperation> RecordedGraph<Op> {
44 /// Create a recorded graph from an already-built graph and aligned keys.
45 ///
46 /// # Examples
47 ///
48 /// ```
49 /// use std::sync::Arc;
50 /// use computegraph::graph::GraphBuilder;
51 /// use computegraph::{GraphOperation, OperationRole, ValueRef};
52 /// use tidu::eager::RecordedGraph;
53 ///
54 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
55 /// enum Op { Id }
56 ///
57 /// impl GraphOperation for Op {
58 /// type Operand = f64;
59 /// type Context = ();
60 /// type InputKey = &'static str;
61 ///
62 /// fn input_count(&self) -> usize { 1 }
63 /// fn output_count(&self) -> usize { 1 }
64 /// }
65 ///
66 /// let mut builder = GraphBuilder::new();
67 /// let x = builder.add_input("x");
68 /// let y = builder.add_operation(Op::Id, vec![ValueRef::Local(x)], OperationRole::Primary);
69 /// builder.set_outputs(y.clone());
70 /// let graph = Arc::new(builder.build());
71 /// let output_keys = y.iter().map(|id| graph.values()[*id].key.clone()).collect();
72 /// let recorded = RecordedGraph::new(graph, vec!["x"], output_keys);
73 ///
74 /// assert_eq!(recorded.input_keys(), &["x"]);
75 /// ```
76 pub fn new(
77 graph: Arc<Graph<Op>>,
78 input_keys: Vec<Op::InputKey>,
79 output_keys: Vec<ValueKey<Op>>,
80 ) -> Self {
81 assert_eq!(
82 graph.inputs().len(),
83 input_keys.len(),
84 "RecordedGraph expected {} input keys, got {}",
85 graph.inputs().len(),
86 input_keys.len()
87 );
88 assert_eq!(
89 graph.outputs().len(),
90 output_keys.len(),
91 "RecordedGraph expected {} output keys, got {}",
92 graph.outputs().len(),
93 output_keys.len()
94 );
95 for (&input_id, input_key) in graph.inputs().iter().zip(input_keys.iter()) {
96 assert_eq!(
97 &graph.values()[input_id].key,
98 &ValueKey::Input(input_key.clone()),
99 "RecordedGraph input key order must match graph inputs"
100 );
101 }
102 for (&output_id, output_key) in graph.outputs().iter().zip(output_keys.iter()) {
103 assert_eq!(
104 &graph.values()[output_id].key,
105 output_key,
106 "RecordedGraph output key order must match graph outputs"
107 );
108 }
109
110 Self {
111 graph,
112 input_keys,
113 output_keys,
114 }
115 }
116
117 /// Build a one-op recorded graph for an eager primitive invocation.
118 ///
119 /// # Examples
120 ///
121 /// ```
122 /// use tidu::eager::RecordedGraph;
123 /// use computegraph::GraphOperation;
124 ///
125 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
126 /// enum Op { Add }
127 ///
128 /// impl GraphOperation for Op {
129 /// type Operand = f64;
130 /// type Context = ();
131 /// type InputKey = &'static str;
132 ///
133 /// fn input_count(&self) -> usize { 2 }
134 /// fn output_count(&self) -> usize { 1 }
135 /// }
136 ///
137 /// let recorded = RecordedGraph::from_primitive(Op::Add, vec!["x", "y"]);
138 /// assert_eq!(recorded.as_graph().operations().len(), 1);
139 /// ```
140 pub fn from_primitive(op: Op, input_keys: Vec<Op::InputKey>) -> Self {
141 let mut builder = GraphBuilder::new();
142 let input_ids: Vec<_> = input_keys
143 .iter()
144 .cloned()
145 .map(|key| builder.add_input(key))
146 .collect();
147 let output_ids = builder.add_operation(
148 op,
149 input_ids
150 .iter()
151 .map(|local_id| ValueRef::Local(*local_id))
152 .collect(),
153 OperationRole::Primary,
154 );
155 builder.set_outputs(output_ids.clone());
156 let graph = Arc::new(builder.build());
157 let output_keys = output_ids
158 .iter()
159 .map(|output_id| graph.values()[*output_id].key.clone())
160 .collect();
161 Self::new(graph, input_keys, output_keys)
162 }
163
164 /// Borrow the recorded graph.
165 pub fn as_graph(&self) -> &Graph<Op> {
166 &self.graph
167 }
168
169 /// Graph input keys aligned with eager input edges.
170 pub fn input_keys(&self) -> &[Op::InputKey] {
171 &self.input_keys
172 }
173
174 /// Graph output keys aligned with eager output slots.
175 pub fn output_keys(&self) -> &[ValueKey<Op>] {
176 &self.output_keys
177 }
178}
179
180impl<Op: Primitive> RecordedGraph<Op>
181where
182 Op::InputKey: ADKey,
183{
184 pub(crate) fn try_linearize(
185 &self,
186 output_slots: &[usize],
187 ctx: &mut Op::ADContext,
188 ) -> ADRuleResult<LinearizedGraph<Op>> {
189 let selected_outputs: Vec<_> = output_slots
190 .iter()
191 .map(|&slot| {
192 self.output_keys.get(slot).cloned().unwrap_or_else(|| {
193 panic!(
194 "RecordedGraph got output slot {}, but graph has {} outputs",
195 slot,
196 self.output_keys.len()
197 )
198 })
199 })
200 .collect();
201 let view = resolve(vec![Arc::clone(&self.graph)]);
202 let aliases = HashMap::new();
203 crate::try_linearize(&view, &selected_outputs, &self.input_keys, 0, ctx, &aliases)
204 }
205}
206
207/// Input descriptor for recording one eager graph invocation.
208pub struct EagerInput<Op: GraphOperation> {
209 /// User-visible eager value key used for cotangent accumulation.
210 pub key: ValueKey<Op>,
211 /// Trace node that produced this value, or `None` for leaves.
212 pub trace: Option<Trace<Op>>,
213 /// Whether this value participates in reverse-mode propagation.
214 pub requires_grad: bool,
215 /// Concrete primal data for saved forward replay.
216 pub data: Arc<Op::Operand>,
217}
218
219/// Per-output trace metadata returned by [`Recorder::record_graph`].
220pub struct EagerOutput<Op: GraphOperation> {
221 /// User-visible eager output key.
222 pub key: ValueKey<Op>,
223 /// Shared trace node for all outputs when any input requires gradients.
224 pub trace: Option<Trace<Op>>,
225 /// Whether this output should be tracked by the downstream frontend.
226 pub requires_grad: bool,
227 /// Output slot within the recorded graph invocation.
228 pub output_slot: usize,
229}
230
231/// Caller-provided source of stable eager value keys.
232pub trait KeySource<Op: GraphOperation> {
233 /// Return a fresh input key that has not been used for another eager value.
234 fn fresh_input_key(&mut self) -> Op::InputKey;
235}
236
/// Recorder for eager operations.
///
/// It is stateful only through the key source it owns, which hands out the
/// fresh keys used while recording.
pub struct Recorder<K> {
    /// Key source that supplies every fresh key.
    key_source: K,
}
241
242impl<K> Recorder<K> {
243 /// Create a recorder from a downstream key source.
244 pub fn new(key_source: K) -> Self {
245 Self { key_source }
246 }
247
248 /// Borrow the underlying key source.
249 pub fn key_source_mut(&mut self) -> &mut K {
250 &mut self.key_source
251 }
252
253 /// Return the underlying key source.
254 pub fn into_key_source(self) -> K {
255 self.key_source
256 }
257
258 /// Return fresh graph input keys for one eager graph invocation.
259 ///
260 /// # Examples
261 ///
262 /// ```
263 /// use tidu::eager::{KeySource, Recorder};
264 /// use computegraph::GraphOperation;
265 ///
266 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
267 /// enum Op { Id }
268 ///
269 /// impl GraphOperation for Op {
270 /// type Operand = f64;
271 /// type Context = ();
272 /// type InputKey = usize;
273 ///
274 /// fn input_count(&self) -> usize { 1 }
275 /// fn output_count(&self) -> usize { 1 }
276 /// }
277 ///
278 /// struct Keys(usize);
279 ///
280 /// impl KeySource<Op> for Keys {
281 /// fn fresh_input_key(&mut self) -> usize {
282 /// let key = self.0;
283 /// self.0 += 1;
284 /// key
285 /// }
286 /// }
287 ///
288 /// let mut recorder = Recorder::new(Keys(0));
289 /// assert_eq!(recorder.fresh_input_keys::<Op>(2), vec![0, 1]);
290 /// ```
291 pub fn fresh_input_keys<Op>(&mut self, count: usize) -> Vec<Op::InputKey>
292 where
293 Op: GraphOperation,
294 K: KeySource<Op>,
295 {
296 (0..count)
297 .map(|_| self.key_source.fresh_input_key())
298 .collect()
299 }
300
301 /// Record a concrete eager graph invocation for reverse-mode AD.
302 ///
303 /// # Examples
304 ///
305 /// ```
306 /// use std::collections::HashMap;
307 /// use std::sync::Arc;
308 /// use computegraph::{GraphOperation, LocalValueId, OperationRole, ValueKey};
309 /// use tidu::{
310 /// ADKey, DiffPassId, Primitive, PrimitiveBuilder, PrimitiveValue,
311 /// };
312 /// use tidu::eager::{EagerInput, KeySource, RecordedGraph, Recorder};
313 ///
314 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
315 /// enum Key {
316 /// User(&'static str),
317 /// Generated(usize),
318 /// Tangent(Box<Key>, DiffPassId),
319 /// }
320 ///
321 /// impl ADKey for Key {
322 /// fn tangent_of(&self, pass: DiffPassId) -> Self {
323 /// Self::Tangent(Box::new(self.clone()), pass)
324 /// }
325 /// }
326 ///
327 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
328 /// enum Op { Id }
329 ///
330 /// impl GraphOperation for Op {
331 /// type Operand = f64;
332 /// type Context = ();
333 /// type InputKey = Key;
334 ///
335 /// fn input_count(&self) -> usize { 1 }
336 /// fn output_count(&self) -> usize { 1 }
337 /// }
338 ///
339 /// impl Primitive for Op {
340 /// type ADContext = ();
341 /// fn add() -> Self { Self::Id }
342 ///
343 /// fn jvp_rule(
344 /// &self,
345 /// _builder: &mut impl PrimitiveBuilder<Self>,
346 /// _primal_in: &[ValueKey<Self>],
347 /// _primal_out: &[ValueKey<Self>],
348 /// tangent_in: &[Option<LocalValueId>],
349 /// _ctx: &mut (),
350 /// ) -> Vec<Option<LocalValueId>> {
351 /// vec![tangent_in[0]]
352 /// }
353 ///
354 /// fn transpose_rule(
355 /// &self,
356 /// _builder: &mut impl PrimitiveBuilder<Self>,
357 /// cotangent_out: &[Option<LocalValueId>],
358 /// _inputs: &[PrimitiveValue<Self>],
359 /// _role: &OperationRole,
360 /// _ctx: &mut (),
361 /// ) -> Vec<Option<LocalValueId>> {
362 /// vec![cotangent_out[0]]
363 /// }
364 /// }
365 ///
366 /// struct Keys(usize);
367 /// impl KeySource<Op> for Keys {
368 /// fn fresh_input_key(&mut self) -> Key {
369 /// let key = Key::Generated(self.0);
370 /// self.0 += 1;
371 /// key
372 /// }
373 /// }
374 ///
375 /// let mut recorder = Recorder::new(Keys(0));
376 /// let graph_inputs = recorder.fresh_input_keys::<Op>(1);
377 /// let graph = RecordedGraph::from_primitive(Op::Id, graph_inputs);
378 /// let input = EagerInput {
379 /// key: ValueKey::Input(Key::User("x")),
380 /// trace: None,
381 /// requires_grad: true,
382 /// data: Arc::new(2.0),
383 /// };
384 ///
385 /// let outputs = recorder.record_graph(
386 /// graph,
387 /// &[input],
388 /// &[Arc::new(2.0)],
389 /// HashMap::new(),
390 /// );
391 /// assert!(outputs[0].trace.is_some());
392 /// ```
393 pub fn record_graph<Op>(
394 &mut self,
395 graph: RecordedGraph<Op>,
396 inputs: &[EagerInput<Op>],
397 outputs: &[Arc<Op::Operand>],
398 retained_values: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
399 ) -> Vec<EagerOutput<Op>>
400 where
401 Op: Primitive,
402 Op::InputKey: ADKey,
403 K: KeySource<Op>,
404 {
405 assert_eq!(
406 inputs.len(),
407 graph.input_keys().len(),
408 "Recorder::record_graph expected {} inputs, got {}",
409 graph.input_keys().len(),
410 inputs.len()
411 );
412 assert_eq!(
413 outputs.len(),
414 graph.output_keys().len(),
415 "Recorder::record_graph expected {} outputs, got {}",
416 graph.output_keys().len(),
417 outputs.len()
418 );
419 assert!(
420 outputs.len() <= u8::MAX as usize + 1,
421 "Recorder::record_graph has too many outputs for ValueKey: {}",
422 outputs.len()
423 );
424
425 let output_keys = fresh_value_keys(&mut self.key_source, outputs.len());
426 let requires_grad = inputs.iter().any(|input| input.requires_grad);
427
428 let trace = requires_grad.then(|| {
429 let saved_data = saved_graph_values(&graph, inputs, &retained_values);
430 Trace::new(Arc::new(TraceNode::new(
431 graph,
432 output_keys.clone(),
433 saved_data,
434 inputs
435 .iter()
436 .map(|input| {
437 TraceEdge::new(
438 input.trace.as_ref().map(|trace| trace.node().clone()),
439 input.key.clone(),
440 input.requires_grad,
441 )
442 })
443 .collect(),
444 )))
445 });
446
447 output_keys
448 .into_iter()
449 .enumerate()
450 .map(|(output_slot, key)| EagerOutput {
451 key,
452 trace: trace.clone(),
453 requires_grad,
454 output_slot,
455 })
456 .collect()
457 }
458}
459
460fn saved_graph_values<Op: GraphOperation>(
461 graph: &RecordedGraph<Op>,
462 inputs: &[EagerInput<Op>],
463 retained_values: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
464) -> HashMap<ValueKey<Op>, Arc<Op::Operand>> {
465 let mut saved = HashMap::with_capacity(inputs.len() + retained_values.len());
466 for (input_key, input) in graph.input_keys().iter().zip(inputs.iter()) {
467 saved.insert(ValueKey::Input(input_key.clone()), input.data.clone());
468 }
469 for (key, value) in retained_values {
470 saved.insert(key.clone(), value.clone());
471 }
472 saved
473}
474
475fn fresh_value_keys<Op: GraphOperation>(
476 key_source: &mut impl KeySource<Op>,
477 count: usize,
478) -> Vec<ValueKey<Op>> {
479 (0..count)
480 .map(|_| ValueKey::Input(key_source.fresh_input_key()))
481 .collect()
482}