Skip to main content

tidu/eager/
backward.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use crate::{ADKey, ADRuleResult, Primitive};
5use computegraph::{GraphOperation, ValueKey};
6
7use crate::{LinearizedGraph, PrimitiveGraph};
8
9use super::trace::{Trace, TraceNode};
10
11/// Downstream execution hooks for eager backward.
12pub trait BackwardExecutor<Op: Primitive>
13where
14    Op::InputKey: ADKey,
15{
16    /// Replay a primitive graph and return any concrete values needed by
17    /// transpose execution.
18    fn execute_forward(
19        &mut self,
20        graph: PrimitiveGraph<'_, Op>,
21        initial_data: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
22    ) -> HashMap<ValueKey<Op>, Arc<Op::Operand>>;
23
24    /// Run a transposed linear graph with concrete cotangent seeds.
25    fn run_transposed_linear(
26        &mut self,
27        linear: &LinearizedGraph<Op>,
28        cotangent_out: &[Option<Arc<Op::Operand>>],
29        external_data: &HashMap<ValueKey<Op>, Arc<Op::Operand>>,
30        ctx: &mut Op::ADContext,
31    ) -> ADRuleResult<Vec<Option<Arc<Op::Operand>>>>;
32
33    /// Add two concrete operands for cotangent accumulation.
34    fn add_operands(&mut self, a: &Arc<Op::Operand>, b: &Arc<Op::Operand>) -> Arc<Op::Operand>;
35}
36
37/// Execute reverse-mode AD over an eager trace.
38pub fn try_backward<Op: Primitive>(
39    output_key: &ValueKey<Op>,
40    output_trace: Option<&Trace<Op>>,
41    seed: Arc<Op::Operand>,
42    executor: &mut impl BackwardExecutor<Op>,
43    ctx: &mut Op::ADContext,
44) -> ADRuleResult<HashMap<ValueKey<Op>, Arc<Op::Operand>>>
45where
46    Op::InputKey: ADKey,
47{
48    let sorted_nodes = topo_sort_trace(output_trace);
49    let mut cotangents: HashMap<ValueKey<Op>, Arc<Op::Operand>> = HashMap::new();
50    cotangents.insert(output_key.clone(), seed);
51
52    for node in sorted_nodes.iter().rev() {
53        let cotangent_out: Vec<Option<Arc<Op::Operand>>> = node
54            .primal_out_keys()
55            .iter()
56            .map(|key| cotangents.get(key).cloned())
57            .collect();
58        if cotangent_out.iter().all(Option::is_none) {
59            continue;
60        }
61
62        let mut active_output_slots = Vec::new();
63        let mut active_cotangent_out = Vec::new();
64        for (slot, maybe_cotangent) in cotangent_out.into_iter().enumerate() {
65            if let Some(cotangent) = maybe_cotangent {
66                active_output_slots.push(slot);
67                active_cotangent_out.push(Some(cotangent));
68            }
69        }
70
71        let linear = node
72            .computation()
73            .try_linearize(&active_output_slots, ctx)?;
74        let replay_graph = PrimitiveGraph::new(linear.as_graph());
75        let all_values = executor.execute_forward(replay_graph, node.saved_data());
76        let cotangent_in =
77            executor.run_transposed_linear(&linear, &active_cotangent_out, &all_values, ctx)?;
78
79        for (edge, maybe_cotangent) in node.input_edges().iter().zip(cotangent_in) {
80            let cotangent = match maybe_cotangent {
81                Some(cotangent) => cotangent,
82                None => continue,
83            };
84            if !edge.requires_grad {
85                continue;
86            }
87
88            let accumulated = match cotangents.remove(&edge.key) {
89                Some(existing) => executor.add_operands(&existing, &cotangent),
90                None => cotangent,
91            };
92            cotangents.insert(edge.key.clone(), accumulated);
93        }
94    }
95
96    Ok(cotangents)
97}
98
99fn topo_sort_trace<Op: GraphOperation>(
100    output_trace: Option<&Trace<Op>>,
101) -> Vec<Arc<TraceNode<Op>>> {
102    fn visit<Op: GraphOperation>(
103        node: &Arc<TraceNode<Op>>,
104        visited: &mut HashSet<*const TraceNode<Op>>,
105        order: &mut Vec<Arc<TraceNode<Op>>>,
106    ) {
107        let ptr = Arc::as_ptr(node);
108        if !visited.insert(ptr) {
109            return;
110        }
111
112        for edge in node.input_edges() {
113            if let Some(parent) = &edge.node {
114                visit(parent, visited, order);
115            }
116        }
117
118        order.push(node.clone());
119    }
120
121    let mut visited = HashSet::new();
122    let mut order = Vec::new();
123    if let Some(trace) = output_trace {
124        visit(trace.node(), &mut visited, &mut order);
125    }
126    order
127}