Skip to main content

tidu/eager/
trace.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use computegraph::{GraphOperation, ValueKey};
5
6use super::record::RecordedGraph;
7
8/// Opaque handle to an eager reverse-mode trace node.
9///
10/// Downstream eager values store this handle next to their concrete data and
11/// pass it back to [`crate::eager::try_backward`]. The node and edge layout is
12/// intentionally private.
13pub struct Trace<Op: GraphOperation> {
14    node: Arc<TraceNode<Op>>,
15}
16
17impl<Op: GraphOperation> Clone for Trace<Op> {
18    fn clone(&self) -> Self {
19        Self {
20            node: self.node.clone(),
21        }
22    }
23}
24
25impl<Op: GraphOperation> Trace<Op> {
26    pub(crate) fn new(node: Arc<TraceNode<Op>>) -> Self {
27        Self { node }
28    }
29
30    pub(crate) fn node(&self) -> &Arc<TraceNode<Op>> {
31        &self.node
32    }
33
34    /// Saved concrete primal values used as initial data during backward replay.
35    pub fn saved_values(&self) -> &HashMap<ValueKey<Op>, Arc<Op::Operand>> {
36        self.node.saved_data()
37    }
38}
39
40pub(crate) struct TraceNode<Op: GraphOperation> {
41    computation: RecordedGraph<Op>,
42    primal_out_keys: Vec<ValueKey<Op>>,
43    saved_data: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
44    input_edges: Vec<TraceEdge<Op>>,
45}
46
47impl<Op: GraphOperation> TraceNode<Op> {
48    pub(crate) fn new(
49        computation: RecordedGraph<Op>,
50        primal_out_keys: Vec<ValueKey<Op>>,
51        saved_data: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
52        input_edges: Vec<TraceEdge<Op>>,
53    ) -> Self {
54        assert_eq!(
55            primal_out_keys.len(),
56            computation.output_keys().len(),
57            "trace node expected {} primal output keys, got {}",
58            computation.output_keys().len(),
59            primal_out_keys.len()
60        );
61        assert_eq!(
62            input_edges.len(),
63            computation.input_keys().len(),
64            "trace node expected {} input edges, got {}",
65            computation.input_keys().len(),
66            input_edges.len()
67        );
68
69        Self {
70            computation,
71            primal_out_keys,
72            saved_data,
73            input_edges,
74        }
75    }
76
77    pub(crate) fn computation(&self) -> &RecordedGraph<Op> {
78        &self.computation
79    }
80
81    pub(crate) fn primal_out_keys(&self) -> &[ValueKey<Op>] {
82        &self.primal_out_keys
83    }
84
85    pub(crate) fn saved_data(&self) -> &HashMap<ValueKey<Op>, Arc<Op::Operand>> {
86        &self.saved_data
87    }
88
89    pub(crate) fn input_edges(&self) -> &[TraceEdge<Op>] {
90        &self.input_edges
91    }
92}
93
94pub(crate) struct TraceEdge<Op: GraphOperation> {
95    pub(crate) node: Option<Arc<TraceNode<Op>>>,
96    pub(crate) key: ValueKey<Op>,
97    pub(crate) requires_grad: bool,
98}
99
100impl<Op: GraphOperation> TraceEdge<Op> {
101    pub(crate) fn new(
102        node: Option<Arc<TraceNode<Op>>>,
103        key: ValueKey<Op>,
104        requires_grad: bool,
105    ) -> Self {
106        Self {
107            node,
108            key,
109            requires_grad,
110        }
111    }
112}