1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use crate::{ADKey, ADRuleResult, Primitive};
5use computegraph::{GraphOperation, ValueKey};
6
7use crate::{LinearizedGraph, PrimitiveGraph};
8
9use super::trace::{Trace, TraceNode};
10
11pub trait BackwardExecutor<Op: Primitive>
13where
14 Op::InputKey: ADKey,
15{
16 fn execute_forward(
19 &mut self,
20 graph: PrimitiveGraph<'_, Op>,
21 initial_data: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
22 ) -> HashMap<ValueKey<Op>, Arc<Op::Operand>>;
23
24 fn run_transposed_linear(
26 &mut self,
27 linear: &LinearizedGraph<Op>,
28 cotangent_out: &[Option<Arc<Op::Operand>>],
29 external_data: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
30 ctx: &mut Op::ADContext,
31 ) -> ADRuleResult<Vec<Option<Arc<Op::Operand>>>>;
32
33 fn add_operands(&mut self, a: &Arc<Op::Operand>, b: &Arc<Op::Operand>) -> Arc<Op::Operand>;
35}
36
37pub fn try_backward<Op: Primitive>(
39 output_key: &ValueKey<Op>,
40 output_trace: Option<&Trace<Op>>,
41 seed: Arc<Op::Operand>,
42 executor: &mut impl BackwardExecutor<Op>,
43 ctx: &mut Op::ADContext,
44) -> ADRuleResult<HashMap<ValueKey<Op>, Arc<Op::Operand>>>
45where
46 Op::InputKey: ADKey,
47{
48 let sorted_nodes = topo_sort_trace(output_trace);
49 let mut cotangents: HashMap<ValueKey<Op>, Arc<Op::Operand>> = HashMap::new();
50 cotangents.insert(output_key.clone(), seed);
51
52 for node in sorted_nodes.iter().rev() {
53 let cotangent_out: Vec<Option<Arc<Op::Operand>>> = node
54 .primal_out_keys()
55 .iter()
56 .map(|key| cotangents.get(key).cloned())
57 .collect();
58 if cotangent_out.iter().all(Option::is_none) {
59 continue;
60 }
61
62 let mut active_output_slots = Vec::new();
63 let mut active_cotangent_out = Vec::new();
64 for (slot, maybe_cotangent) in cotangent_out.into_iter().enumerate() {
65 if let Some(cotangent) = maybe_cotangent {
66 active_output_slots.push(slot);
67 active_cotangent_out.push(Some(cotangent));
68 }
69 }
70
71 let linear = node
72 .computation()
73 .try_linearize(&active_output_slots, ctx)?;
74 let replay_graph = PrimitiveGraph::new(linear.as_graph());
75 let all_values = executor.execute_forward(replay_graph, node.saved_data());
76 let cotangent_in =
77 executor.run_transposed_linear(&linear, &active_cotangent_out, &all_values, ctx)?;
78
79 for (edge, maybe_cotangent) in node.input_edges().iter().zip(cotangent_in) {
80 let cotangent = match maybe_cotangent {
81 Some(cotangent) => cotangent,
82 None => continue,
83 };
84 if !edge.requires_grad {
85 continue;
86 }
87
88 let accumulated = match cotangents.remove(&edge.key) {
89 Some(existing) => executor.add_operands(&existing, &cotangent),
90 None => cotangent,
91 };
92 cotangents.insert(edge.key.clone(), accumulated);
93 }
94 }
95
96 Ok(cotangents)
97}
98
99fn topo_sort_trace<Op: GraphOperation>(
100 output_trace: Option<&Trace<Op>>,
101) -> Vec<Arc<TraceNode<Op>>> {
102 fn visit<Op: GraphOperation>(
103 node: &Arc<TraceNode<Op>>,
104 visited: &mut HashSet<*const TraceNode<Op>>,
105 order: &mut Vec<Arc<TraceNode<Op>>>,
106 ) {
107 let ptr = Arc::as_ptr(node);
108 if !visited.insert(ptr) {
109 return;
110 }
111
112 for edge in node.input_edges() {
113 if let Some(parent) = &edge.node {
114 visit(parent, visited, order);
115 }
116 }
117
118 order.push(node.clone());
119 }
120
121 let mut visited = HashSet::new();
122 let mut order = Vec::new();
123 if let Some(trace) = output_trace {
124 visit(trace.node(), &mut visited, &mut order);
125 }
126 order
127}