// tidu/grad_node.rs

use std::collections::HashMap;
use std::sync::Arc;

use computegraph::{GlobalValKey, GraphOp};

/// Backward computation node for eager reverse-mode AD.
///
/// A `GradNode` records one primal operation, the stable input aliases used to
/// replay that operation during backward, the user-visible output keys that can
/// receive cotangent seeds, and the edges to parent eager values.
///
/// # Examples
///
/// ```ignore
/// use tidu::{GradEdge, GradNode};
///
/// let node = GradNode::new(
///     op,
///     input_aliases,
///     output_keys,
///     saved_forward_values,
///     vec![GradEdge::new(parent_node, input_key, true)],
/// );
/// ```
pub struct GradNode<Op: GraphOp> {
    op: Op,
    primal_in_keys: Vec<GlobalValKey<Op>>,
    primal_out_keys: Vec<GlobalValKey<Op>>,
    saved_data: HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
    input_edges: Vec<GradEdge<Op>>,
}

impl<Op: GraphOp> GradNode<Op> {
    /// Create a grad node and validate the shape of its eager AD metadata.
    ///
    /// `primal_in_keys` must contain `GlobalValKey::Input` aliases. The eager
    /// backward path linearizes one operation at a time and rebuilds those
    /// aliases as fragment inputs.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let node = tidu::GradNode::new(
    ///     op,
    ///     input_aliases,
    ///     output_keys,
    ///     saved_data,
    ///     input_edges,
    /// );
    /// ```
    pub fn new(
        op: Op,
        primal_in_keys: Vec<GlobalValKey<Op>>,
        primal_out_keys: Vec<GlobalValKey<Op>>,
        saved_data: HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
        input_edges: Vec<GradEdge<Op>>,
    ) -> Self {
        assert_eq!(
            primal_in_keys.len(),
            op.n_inputs(),
            "grad node for {:?} expected {} primal input keys, got {}",
            op,
            op.n_inputs(),
            primal_in_keys.len()
        );
        assert_eq!(
            primal_out_keys.len(),
            op.n_outputs(),
            "grad node for {:?} expected {} primal output keys, got {}",
            op,
            op.n_outputs(),
            primal_out_keys.len()
        );
        assert_eq!(
            input_edges.len(),
            op.n_inputs(),
            "grad node for {:?} expected {} input edges, got {}",
            op,
            op.n_inputs(),
            input_edges.len()
        );
        assert!(
            primal_in_keys
                .iter()
                .all(|key| matches!(key, GlobalValKey::Input(_))),
            "grad node for {:?} requires GlobalValKey::Input aliases in primal_in_keys",
            op
        );

        Self {
            op,
            primal_in_keys,
            primal_out_keys,
            saved_data,
            input_edges,
        }
    }

    /// The primal operation recorded by this node.
    pub fn op(&self) -> &Op {
        &self.op
    }

    /// Stable input aliases used for single-op backward replay.
    pub fn primal_in_keys(&self) -> &[GlobalValKey<Op>] {
        &self.primal_in_keys
    }

    /// User-visible output keys, one per primal output slot.
    pub fn primal_out_keys(&self) -> &[GlobalValKey<Op>] {
        &self.primal_out_keys
    }

    /// Saved concrete primal input and derived output values.
    pub fn saved_data(&self) -> &HashMap<GlobalValKey<Op>, Arc<Op::Operand>> {
        &self.saved_data
    }

    /// Edges to the eager values that provided this node's inputs.
    pub fn input_edges(&self) -> &[GradEdge<Op>] {
        &self.input_edges
    }
}
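
// A minimal sketch, not part of the crate's API: shows how a backward driver
// could pair each stable input alias with its saved primal value before
// replaying the op as a single-operation fragment. It uses only items defined
// in this file plus `GraphOp` from `computegraph`; the `Eq + Hash` bound is an
// assumption of this example (it must already hold for callers to have
// populated `saved_data`). A `None` entry means the value was not saved.
#[allow(dead_code)]
fn gather_replay_inputs<'a, Op: GraphOp>(
    node: &'a GradNode<Op>,
) -> Vec<Option<&'a Arc<Op::Operand>>>
where
    GlobalValKey<Op>: Eq + std::hash::Hash,
{
    node.primal_in_keys()
        .iter()
        .map(|key| node.saved_data().get(key))
        .collect()
}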

/// Edge from a grad node to one of its primal inputs.
///
/// `node` points to the parent operation that produced the input. `None`
/// denotes a leaf eager value. `key` is the cotangent accumulation target for
/// that input.
///
/// # Examples
///
/// ```ignore
/// let edge = tidu::GradEdge::new(parent_node, input_key, requires_grad);
/// ```
pub struct GradEdge<Op: GraphOp> {
    /// Parent grad node. `None` denotes a leaf input.
    pub node: Option<Arc<GradNode<Op>>>,
    /// Gradient accumulation target for this input.
    pub key: GlobalValKey<Op>,
    /// Whether the input participates in gradient propagation.
    pub requires_grad: bool,
}

impl<Op: GraphOp> GradEdge<Op> {
    /// Create an eager backward edge.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let edge = tidu::GradEdge::new(parent_node, input_key, true);
    /// ```
    pub fn new(
        node: Option<Arc<GradNode<Op>>>,
        key: GlobalValKey<Op>,
        requires_grad: bool,
    ) -> Self {
        Self {
            node,
            key,
            requires_grad,
        }
    }
}
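
// A minimal sketch, not part of the crate's API: depth-first walk over the
// backward graph that collects the accumulation keys of reachable leaf inputs.
// It relies only on `GradNode` and `GradEdge` as defined above; the `Clone`
// bound on the key type is an assumption of this example. Nodes shared through
// `Arc` are revisited once per path; a real driver would deduplicate.
#[allow(dead_code)]
fn collect_leaf_keys<Op: GraphOp>(node: &GradNode<Op>, out: &mut Vec<GlobalValKey<Op>>)
where
    GlobalValKey<Op>: Clone,
{
    for edge in node.input_edges() {
        if !edge.requires_grad {
            // No cotangent flows through this input; skip the whole subtree.
            continue;
        }
        match &edge.node {
            // Interior value: recurse into the parent operation's node.
            Some(parent) => collect_leaf_keys(parent, out),
            // Leaf eager value: record its gradient accumulation target.
            None => out.push(edge.key.clone()),
        }
    }
}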