1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use chainrules::{ADKey, PrimitiveOp};
5use computegraph::fragment::{Fragment, FragmentBuilder};
6use computegraph::resolve::resolve;
7use computegraph::{GlobalValKey, GraphOp, OpMode, ValRef};
8
9use crate::grad_node::GradNode;
10use crate::LinearFragment;
11
12pub trait BackwardCallbacks<Op: PrimitiveOp>
14where
15 Op::InputKey: ADKey,
16{
17 fn execute_forward(
20 &mut self,
21 fragment: &Fragment<Op>,
22 initial_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
23 ) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>;
24
25 fn eager_transpose(
27 &mut self,
28 linear: &LinearFragment<Op>,
29 cotangent_out: &[Option<Arc<Op::Operand>>],
30 external_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
31 ctx: &mut Op::ADContext,
32 ) -> Vec<Option<Arc<Op::Operand>>>;
33
34 fn add_operands(&mut self, a: &Arc<Op::Operand>, b: &Arc<Op::Operand>) -> Arc<Op::Operand>;
36}
37
38pub fn topo_sort_grad_dag<Op: GraphOp>(
40 output_node: &Option<Arc<GradNode<Op>>>,
41) -> Vec<Arc<GradNode<Op>>> {
42 fn visit<Op: GraphOp>(
43 node: &Arc<GradNode<Op>>,
44 visited: &mut HashSet<*const GradNode<Op>>,
45 order: &mut Vec<Arc<GradNode<Op>>>,
46 ) {
47 let ptr = Arc::as_ptr(node);
48 if !visited.insert(ptr) {
49 return;
50 }
51
52 for edge in node.input_edges() {
53 if let Some(parent) = &edge.node {
54 visit(parent, visited, order);
55 }
56 }
57
58 order.push(node.clone());
59 }
60
61 let mut visited = HashSet::new();
62 let mut order = Vec::new();
63 if let Some(node) = output_node {
64 visit(node, &mut visited, &mut order);
65 }
66 order
67}
68
69pub fn backward_dag<Op: PrimitiveOp>(
71 sorted_nodes: &[Arc<GradNode<Op>>],
72 output_key: &GlobalValKey<Op>,
73 seed: Arc<Op::Operand>,
74 callbacks: &mut impl BackwardCallbacks<Op>,
75 ctx: &mut Op::ADContext,
76) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>
77where
78 Op::InputKey: ADKey,
79{
80 let mut cotangents: HashMap<GlobalValKey<Op>, Arc<Op::Operand>> = HashMap::new();
81 cotangents.insert(output_key.clone(), seed);
82
83 for node in sorted_nodes.iter().rev() {
84 let cotangent_out: Vec<Option<Arc<Op::Operand>>> = node
85 .primal_out_keys()
86 .iter()
87 .map(|key| cotangents.get(key).cloned())
88 .collect();
89 if cotangent_out.iter().all(Option::is_none) {
90 continue;
91 }
92
93 let linear = build_single_op_linear(node, ctx);
94 let all_values = callbacks.execute_forward(&linear.fragment, node.saved_data());
95 let cotangent_in = callbacks.eager_transpose(&linear, &cotangent_out, &all_values, ctx);
96
97 for (edge, maybe_cotangent) in node.input_edges().iter().zip(cotangent_in.into_iter()) {
98 let Some(cotangent) = maybe_cotangent else {
99 continue;
100 };
101 if !edge.requires_grad {
102 continue;
103 }
104
105 let accumulated = match cotangents.remove(&edge.key) {
106 Some(existing) => callbacks.add_operands(&existing, &cotangent),
107 None => cotangent,
108 };
109 cotangents.insert(edge.key.clone(), accumulated);
110 }
111 }
112
113 cotangents
114}
115
116fn build_single_op_linear<Op: PrimitiveOp>(
117 node: &GradNode<Op>,
118 ctx: &mut Op::ADContext,
119) -> LinearFragment<Op>
120where
121 Op::InputKey: ADKey,
122{
123 let mut builder = FragmentBuilder::new();
124
125 let input_local_ids: Vec<_> = node
126 .primal_in_keys()
127 .iter()
128 .map(|key| match key {
129 GlobalValKey::Input(input_key) => builder.add_input(input_key.clone()),
130 GlobalValKey::Derived { .. } => {
131 panic!(
132 "build_single_op_linear requires GlobalValKey::Input aliases in node.primal_in_keys"
133 )
134 }
135 })
136 .collect();
137
138 let outputs = builder.add_op(
139 node.op().clone(),
140 input_local_ids
141 .iter()
142 .map(|local_id| ValRef::Local(*local_id))
143 .collect(),
144 OpMode::Primal,
145 );
146 builder.set_outputs(outputs.clone());
147
148 let fragment = Arc::new(builder.build());
149 let view = resolve(vec![fragment.clone()]);
150 let output_keys: Vec<_> = outputs
151 .iter()
152 .map(|output_id| fragment.vals()[*output_id].key.clone())
153 .collect();
154 let wrt_keys: Vec<_> = node
155 .primal_in_keys()
156 .iter()
157 .map(|key| match key {
158 GlobalValKey::Input(input_key) => input_key.clone(),
159 GlobalValKey::Derived { .. } => {
160 panic!(
161 "build_single_op_linear requires GlobalValKey::Input aliases in node.primal_in_keys"
162 )
163 }
164 })
165 .collect();
166 let aliases = HashMap::new();
167
168 crate::differentiate(&view, &output_keys, &wrt_keys, 0, ctx, &aliases)
169}