1use std::collections::HashMap;
2use std::sync::Arc;
3
4use computegraph::{GraphOperation, ValueKey};
5
6use super::record::RecordedGraph;
7
8pub struct Trace<Op: GraphOperation> {
14 node: Arc<TraceNode<Op>>,
15}
16
17impl<Op: GraphOperation> Clone for Trace<Op> {
18 fn clone(&self) -> Self {
19 Self {
20 node: self.node.clone(),
21 }
22 }
23}
24
25impl<Op: GraphOperation> Trace<Op> {
26 pub(crate) fn new(node: Arc<TraceNode<Op>>) -> Self {
27 Self { node }
28 }
29
30 pub(crate) fn node(&self) -> &Arc<TraceNode<Op>> {
31 &self.node
32 }
33
34 pub fn saved_values(&self) -> &HashMap<ValueKey<Op>, Arc<Op::Operand>> {
36 self.node.saved_data()
37 }
38}
39
40pub(crate) struct TraceNode<Op: GraphOperation> {
41 computation: RecordedGraph<Op>,
42 primal_out_keys: Vec<ValueKey<Op>>,
43 saved_data: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
44 input_edges: Vec<TraceEdge<Op>>,
45}
46
47impl<Op: GraphOperation> TraceNode<Op> {
48 pub(crate) fn new(
49 computation: RecordedGraph<Op>,
50 primal_out_keys: Vec<ValueKey<Op>>,
51 saved_data: HashMap<ValueKey<Op>, Arc<Op::Operand>>,
52 input_edges: Vec<TraceEdge<Op>>,
53 ) -> Self {
54 assert_eq!(
55 primal_out_keys.len(),
56 computation.output_keys().len(),
57 "trace node expected {} primal output keys, got {}",
58 computation.output_keys().len(),
59 primal_out_keys.len()
60 );
61 assert_eq!(
62 input_edges.len(),
63 computation.input_keys().len(),
64 "trace node expected {} input edges, got {}",
65 computation.input_keys().len(),
66 input_edges.len()
67 );
68
69 Self {
70 computation,
71 primal_out_keys,
72 saved_data,
73 input_edges,
74 }
75 }
76
77 pub(crate) fn computation(&self) -> &RecordedGraph<Op> {
78 &self.computation
79 }
80
81 pub(crate) fn primal_out_keys(&self) -> &[ValueKey<Op>] {
82 &self.primal_out_keys
83 }
84
85 pub(crate) fn saved_data(&self) -> &HashMap<ValueKey<Op>, Arc<Op::Operand>> {
86 &self.saved_data
87 }
88
89 pub(crate) fn input_edges(&self) -> &[TraceEdge<Op>] {
90 &self.input_edges
91 }
92}
93
94pub(crate) struct TraceEdge<Op: GraphOperation> {
95 pub(crate) node: Option<Arc<TraceNode<Op>>>,
96 pub(crate) key: ValueKey<Op>,
97 pub(crate) requires_grad: bool,
98}
99
100impl<Op: GraphOperation> TraceEdge<Op> {
101 pub(crate) fn new(
102 node: Option<Arc<TraceNode<Op>>>,
103 key: ValueKey<Op>,
104 requires_grad: bool,
105 ) -> Self {
106 Self {
107 node,
108 key,
109 requires_grad,
110 }
111 }
112}