tidu/grad_node.rs
1use std::collections::HashMap;
2use std::sync::Arc;
3
4use computegraph::{GlobalValKey, GraphOp};
5
6/// Backward computation node for eager reverse-mode AD.
7///
8/// A `GradNode` records one primal operation, the stable input aliases used to
9/// replay that operation during backward, the user-visible output keys that can
10/// receive cotangent seeds, and the edges to parent eager values.
11///
12/// # Examples
13///
14/// ```ignore
15/// use tidu::{GradEdge, GradNode};
16///
17/// let node = GradNode::new(
18/// op,
19/// input_aliases,
20/// output_keys,
21/// saved_forward_values,
22/// vec![GradEdge::new(parent_node, input_key, true)],
23/// );
24/// ```
25pub struct GradNode<Op: GraphOp> {
26 op: Op,
27 primal_in_keys: Vec<GlobalValKey<Op>>,
28 primal_out_keys: Vec<GlobalValKey<Op>>,
29 saved_data: HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
30 input_edges: Vec<GradEdge<Op>>,
31}
32
33impl<Op: GraphOp> GradNode<Op> {
34 /// Create a grad node and validate the shape of its eager AD metadata.
35 ///
36 /// `primal_in_keys` must contain `GlobalValKey::Input` aliases. The eager
37 /// backward path linearizes one operation at a time and rebuilds those
38 /// aliases as fragment inputs.
39 ///
40 /// # Examples
41 ///
42 /// ```ignore
43 /// let node = tidu::GradNode::new(
44 /// op,
45 /// input_aliases,
46 /// output_keys,
47 /// saved_data,
48 /// input_edges,
49 /// );
50 /// ```
51 pub fn new(
52 op: Op,
53 primal_in_keys: Vec<GlobalValKey<Op>>,
54 primal_out_keys: Vec<GlobalValKey<Op>>,
55 saved_data: HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
56 input_edges: Vec<GradEdge<Op>>,
57 ) -> Self {
58 assert_eq!(
59 primal_in_keys.len(),
60 op.n_inputs(),
61 "grad node for {:?} expected {} primal input keys, got {}",
62 op,
63 op.n_inputs(),
64 primal_in_keys.len()
65 );
66 assert_eq!(
67 primal_out_keys.len(),
68 op.n_outputs(),
69 "grad node for {:?} expected {} primal output keys, got {}",
70 op,
71 op.n_outputs(),
72 primal_out_keys.len()
73 );
74 assert_eq!(
75 input_edges.len(),
76 op.n_inputs(),
77 "grad node for {:?} expected {} input edges, got {}",
78 op,
79 op.n_inputs(),
80 input_edges.len()
81 );
82 assert!(
83 primal_in_keys
84 .iter()
85 .all(|key| matches!(key, GlobalValKey::Input(_))),
86 "grad node for {:?} requires GlobalValKey::Input aliases in primal_in_keys",
87 op
88 );
89
90 Self {
91 op,
92 primal_in_keys,
93 primal_out_keys,
94 saved_data,
95 input_edges,
96 }
97 }
98
99 /// The primal operation recorded by this node.
100 pub fn op(&self) -> &Op {
101 &self.op
102 }
103
104 /// Stable input aliases used for single-op backward replay.
105 pub fn primal_in_keys(&self) -> &[GlobalValKey<Op>] {
106 &self.primal_in_keys
107 }
108
109 /// User-visible output keys, one per primal output slot.
110 pub fn primal_out_keys(&self) -> &[GlobalValKey<Op>] {
111 &self.primal_out_keys
112 }
113
114 /// Saved concrete primal input and derived output values.
115 pub fn saved_data(&self) -> &HashMap<GlobalValKey<Op>, Arc<Op::Operand>> {
116 &self.saved_data
117 }
118
119 /// Edges to the eager values that provided this node's inputs.
120 pub fn input_edges(&self) -> &[GradEdge<Op>] {
121 &self.input_edges
122 }
123}
124
125/// Edge from a grad node to one of its primal inputs.
126///
127/// `node` points to the parent operation that produced the input. `None`
128/// denotes a leaf eager value. `key` is the cotangent accumulation target for
129/// that input.
130///
131/// # Examples
132///
133/// ```ignore
134/// let edge = tidu::GradEdge::new(parent_node, input_key, requires_grad);
135/// ```
136pub struct GradEdge<Op: GraphOp> {
137 /// Parent grad node. `None` denotes a leaf input.
138 pub node: Option<Arc<GradNode<Op>>>,
139 /// Gradient accumulation target for this input.
140 pub key: GlobalValKey<Op>,
141 /// Whether the input participates in gradient propagation.
142 pub requires_grad: bool,
143}
144
145impl<Op: GraphOp> GradEdge<Op> {
146 /// Create an eager backward edge.
147 ///
148 /// # Examples
149 ///
150 /// ```ignore
151 /// let edge = tidu::GradEdge::new(parent_node, input_key, true);
152 /// ```
153 pub fn new(
154 node: Option<Arc<GradNode<Op>>>,
155 key: GlobalValKey<Op>,
156 requires_grad: bool,
157 ) -> Self {
158 Self {
159 node,
160 key,
161 requires_grad,
162 }
163 }
164}