1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use chainrules::{ADKey, ADRuleResult, PrimitiveOp};
5use computegraph::fragment::{Fragment, FragmentBuilder};
6use computegraph::resolve::resolve;
7use computegraph::{GlobalValKey, GraphOp, OpMode, ValRef};
8
9use crate::grad_node::GradNode;
10use crate::LinearFragment;
11
12pub trait BackwardCallbacks<Op: PrimitiveOp>
14where
15 Op::InputKey: ADKey,
16{
17 fn execute_forward(
20 &mut self,
21 fragment: &Fragment<Op>,
22 initial_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
23 ) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>;
24
25 fn eager_transpose(
27 &mut self,
28 linear: &LinearFragment<Op>,
29 cotangent_out: &[Option<Arc<Op::Operand>>],
30 external_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
31 ctx: &mut Op::ADContext,
32 ) -> Vec<Option<Arc<Op::Operand>>>;
33
34 fn try_eager_transpose(
40 &mut self,
41 linear: &LinearFragment<Op>,
42 cotangent_out: &[Option<Arc<Op::Operand>>],
43 external_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
44 ctx: &mut Op::ADContext,
45 ) -> ADRuleResult<Vec<Option<Arc<Op::Operand>>>> {
46 Ok(self.eager_transpose(linear, cotangent_out, external_data, ctx))
47 }
48
49 fn add_operands(&mut self, a: &Arc<Op::Operand>, b: &Arc<Op::Operand>) -> Arc<Op::Operand>;
51}
52
53pub fn topo_sort_grad_dag<Op: GraphOp>(
55 output_node: &Option<Arc<GradNode<Op>>>,
56) -> Vec<Arc<GradNode<Op>>> {
57 fn visit<Op: GraphOp>(
58 node: &Arc<GradNode<Op>>,
59 visited: &mut HashSet<*const GradNode<Op>>,
60 order: &mut Vec<Arc<GradNode<Op>>>,
61 ) {
62 let ptr = Arc::as_ptr(node);
63 if !visited.insert(ptr) {
64 return;
65 }
66
67 for edge in node.input_edges() {
68 if let Some(parent) = &edge.node {
69 visit(parent, visited, order);
70 }
71 }
72
73 order.push(node.clone());
74 }
75
76 let mut visited = HashSet::new();
77 let mut order = Vec::new();
78 if let Some(node) = output_node {
79 visit(node, &mut visited, &mut order);
80 }
81 order
82}
83
84pub fn backward_dag<Op: PrimitiveOp>(
86 sorted_nodes: &[Arc<GradNode<Op>>],
87 output_key: &GlobalValKey<Op>,
88 seed: Arc<Op::Operand>,
89 callbacks: &mut impl BackwardCallbacks<Op>,
90 ctx: &mut Op::ADContext,
91) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>
92where
93 Op::InputKey: ADKey,
94{
95 match try_backward_dag(sorted_nodes, output_key, seed, callbacks, ctx) {
96 Ok(cotangents) => cotangents,
97 Err(err) => panic!("{err}"),
98 }
99}
100
101pub fn try_backward_dag<Op: PrimitiveOp>(
103 sorted_nodes: &[Arc<GradNode<Op>>],
104 output_key: &GlobalValKey<Op>,
105 seed: Arc<Op::Operand>,
106 callbacks: &mut impl BackwardCallbacks<Op>,
107 ctx: &mut Op::ADContext,
108) -> ADRuleResult<HashMap<GlobalValKey<Op>, Arc<Op::Operand>>>
109where
110 Op::InputKey: ADKey,
111{
112 let mut cotangents: HashMap<GlobalValKey<Op>, Arc<Op::Operand>> = HashMap::new();
113 cotangents.insert(output_key.clone(), seed);
114
115 for node in sorted_nodes.iter().rev() {
116 let cotangent_out: Vec<Option<Arc<Op::Operand>>> = node
117 .primal_out_keys()
118 .iter()
119 .map(|key| cotangents.get(key).cloned())
120 .collect();
121 if cotangent_out.iter().all(Option::is_none) {
122 continue;
123 }
124
125 let mut active_output_slots = Vec::new();
126 let mut active_cotangent_out = Vec::new();
127 for (slot, maybe_cotangent) in cotangent_out.into_iter().enumerate() {
128 if let Some(cotangent) = maybe_cotangent {
129 active_output_slots.push(slot);
130 active_cotangent_out.push(Some(cotangent));
131 }
132 }
133
134 let linear = try_build_single_op_linear(node, &active_output_slots, ctx)?;
135 let all_values = callbacks.execute_forward(&linear.fragment, node.saved_data());
136 let cotangent_in =
137 callbacks.try_eager_transpose(&linear, &active_cotangent_out, &all_values, ctx)?;
138
139 for (edge, maybe_cotangent) in node.input_edges().iter().zip(cotangent_in.into_iter()) {
140 let Some(cotangent) = maybe_cotangent else {
141 continue;
142 };
143 if !edge.requires_grad {
144 continue;
145 }
146
147 let accumulated = match cotangents.remove(&edge.key) {
148 Some(existing) => callbacks.add_operands(&existing, &cotangent),
149 None => cotangent,
150 };
151 cotangents.insert(edge.key.clone(), accumulated);
152 }
153 }
154
155 Ok(cotangents)
156}
157
158fn try_build_single_op_linear<Op: PrimitiveOp>(
159 node: &GradNode<Op>,
160 output_slots: &[usize],
161 ctx: &mut Op::ADContext,
162) -> ADRuleResult<LinearFragment<Op>>
163where
164 Op::InputKey: ADKey,
165{
166 let mut builder = FragmentBuilder::new();
167
168 let input_local_ids: Vec<_> = node
169 .primal_in_keys()
170 .iter()
171 .map(|key| match key {
172 GlobalValKey::Input(input_key) => builder.add_input(input_key.clone()),
173 GlobalValKey::Derived { .. } => {
174 panic!(
175 "build_single_op_linear requires GlobalValKey::Input aliases in node.primal_in_keys"
176 )
177 }
178 })
179 .collect();
180
181 let outputs = builder.add_op(
182 node.op().clone(),
183 input_local_ids
184 .iter()
185 .map(|local_id| ValRef::Local(*local_id))
186 .collect(),
187 OpMode::Primal,
188 );
189 let selected_outputs: Vec<_> = output_slots
190 .iter()
191 .map(|&slot| {
192 *outputs.get(slot).unwrap_or_else(|| {
193 panic!(
194 "build_single_op_linear got output slot {slot} for {:?}, \
195 which has only {} outputs",
196 node.op(),
197 outputs.len()
198 )
199 })
200 })
201 .collect();
202 builder.set_outputs(selected_outputs.clone());
203
204 let fragment = Arc::new(builder.build());
205 let view = resolve(vec![fragment.clone()]);
206 let output_keys: Vec<_> = selected_outputs
207 .iter()
208 .map(|output_id| fragment.vals()[*output_id].key.clone())
209 .collect();
210 let wrt_keys: Vec<_> = node
211 .primal_in_keys()
212 .iter()
213 .map(|key| match key {
214 GlobalValKey::Input(input_key) => input_key.clone(),
215 GlobalValKey::Derived { .. } => {
216 panic!(
217 "build_single_op_linear requires GlobalValKey::Input aliases in node.primal_in_keys"
218 )
219 }
220 })
221 .collect();
222 let aliases = HashMap::new();
223
224 crate::try_differentiate(&view, &output_keys, &wrt_keys, 0, ctx, &aliases)
225}