
tidu/backward.rs

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use chainrules::{ADKey, PrimitiveOp};
use computegraph::fragment::{Fragment, FragmentBuilder};
use computegraph::resolve::resolve;
use computegraph::{GlobalValKey, GraphOp, OpMode, ValRef};

use crate::grad_node::GradNode;
use crate::LinearFragment;
/// Caller-provided execution hooks for eager backward.
pub trait BackwardCallbacks<Op: PrimitiveOp>
where
    Op::InputKey: ADKey,
{
    /// Execute a linear fragment forward and return any concrete values needed
    /// by eager transpose.
    fn execute_forward(
        &mut self,
        fragment: &Fragment<Op>,
        initial_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
    ) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>;

    /// Execute transpose eagerly for a linear fragment with concrete seeds.
    fn eager_transpose(
        &mut self,
        linear: &LinearFragment<Op>,
        cotangent_out: &[Option<Arc<Op::Operand>>],
        external_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
        ctx: &mut Op::ADContext,
    ) -> Vec<Option<Arc<Op::Operand>>>;

    /// Add two concrete operands for cotangent accumulation.
    fn add_operands(&mut self, a: &Arc<Op::Operand>, b: &Arc<Op::Operand>) -> Arc<Op::Operand>;
}
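
// Illustrative sketch (not part of this crate): a minimal `BackwardCallbacks`
// implementation for a hypothetical scalar backend. `ScalarOp` is an assumed
// `PrimitiveOp` with `Operand = f64` and `ADContext = ()`, and
// `interpret_fragment` / `interpret_transpose` are assumed interpreter entry
// points; all three exist only for this sketch.
//
//     struct ScalarCallbacks;
//
//     impl BackwardCallbacks<ScalarOp> for ScalarCallbacks {
//         fn execute_forward(
//             &mut self,
//             fragment: &Fragment<ScalarOp>,
//             initial_data: &HashMap<GlobalValKey<ScalarOp>, Arc<f64>>,
//         ) -> HashMap<GlobalValKey<ScalarOp>, Arc<f64>> {
//             // Evaluate the fragment and return every value it produced.
//             interpret_fragment(fragment, initial_data)
//         }
//
//         fn eager_transpose(
//             &mut self,
//             linear: &LinearFragment<ScalarOp>,
//             cotangent_out: &[Option<Arc<f64>>],
//             external_data: &HashMap<GlobalValKey<ScalarOp>, Arc<f64>>,
//             _ctx: &mut (),
//         ) -> Vec<Option<Arc<f64>>> {
//             // Pull output cotangents back through the linear fragment.
//             interpret_transpose(linear, cotangent_out, external_data)
//         }
//
//         fn add_operands(&mut self, a: &Arc<f64>, b: &Arc<f64>) -> Arc<f64> {
//             // Scalar cotangent accumulation is plain addition.
//             Arc::new(**a + **b)
//         }
//     }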

/// Topologically sort the reachable grad DAG in dependency-first order.
pub fn topo_sort_grad_dag<Op: GraphOp>(
    output_node: &Option<Arc<GradNode<Op>>>,
) -> Vec<Arc<GradNode<Op>>> {
    fn visit<Op: GraphOp>(
        node: &Arc<GradNode<Op>>,
        visited: &mut HashSet<*const GradNode<Op>>,
        order: &mut Vec<Arc<GradNode<Op>>>,
    ) {
        // Dedup by pointer identity: the same node can be reached through
        // several edges of the DAG.
        let ptr = Arc::as_ptr(node);
        if !visited.insert(ptr) {
            return;
        }

        for edge in node.input_edges() {
            if let Some(parent) = &edge.node {
                visit(parent, visited, order);
            }
        }

        // Post-order push: every dependency lands before its consumer.
        order.push(node.clone());
    }

    let mut visited = HashSet::new();
    let mut order = Vec::new();
    if let Some(node) = output_node {
        visit(node, &mut visited, &mut order);
    }
    order
}
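
// Worked example: in the diamond DAG below, `d` consumes `b` and `c`, which
// both consume `a`. `topo_sort_grad_dag(&Some(d))` returns `[a, b, c, d]`
// (or `[a, c, b, d]`, depending on edge order), so iterating the result in
// reverse visits each node only after all of its consumers:
//
//         a
//        / \
//       b   c
//        \ /
//         d   <- output_node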

/// Execute reverse-mode AD over a grad DAG.
pub fn backward_dag<Op: PrimitiveOp>(
    sorted_nodes: &[Arc<GradNode<Op>>],
    output_key: &GlobalValKey<Op>,
    seed: Arc<Op::Operand>,
    callbacks: &mut impl BackwardCallbacks<Op>,
    ctx: &mut Op::ADContext,
) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>
where
    Op::InputKey: ADKey,
{
    // Seed the cotangent of the requested output, then sweep the DAG in
    // reverse topological order so each node runs after all of its consumers.
    let mut cotangents: HashMap<GlobalValKey<Op>, Arc<Op::Operand>> = HashMap::new();
    cotangents.insert(output_key.clone(), seed);

    for node in sorted_nodes.iter().rev() {
        let cotangent_out: Vec<Option<Arc<Op::Operand>>> = node
            .primal_out_keys()
            .iter()
            .map(|key| cotangents.get(key).cloned())
            .collect();
        // A node with no incoming cotangent cannot influence the seeded output.
        if cotangent_out.iter().all(Option::is_none) {
            continue;
        }

        let linear = build_single_op_linear(node, ctx);
        let all_values = callbacks.execute_forward(&linear.fragment, node.saved_data());
        let cotangent_in = callbacks.eager_transpose(&linear, &cotangent_out, &all_values, ctx);

        // Accumulate per-input cotangents, summing when an input already has one.
        for (edge, maybe_cotangent) in node.input_edges().iter().zip(cotangent_in.into_iter()) {
            let Some(cotangent) = maybe_cotangent else {
                continue;
            };
            if !edge.requires_grad {
                continue;
            }

            let accumulated = match cotangents.remove(&edge.key) {
                Some(existing) => callbacks.add_operands(&existing, &cotangent),
                None => cotangent,
            };
            cotangents.insert(edge.key.clone(), accumulated);
        }
    }

    cotangents
}
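
// Usage sketch (hypothetical driver): seed the final output with a ones-like
// cotangent and read gradients out of the returned map. `output_node`,
// `loss_key`, `x_key`, `ones_like(&loss)`, `ScalarCallbacks`, and `ctx` are
// assumed to exist for this sketch.
//
//     let order = topo_sort_grad_dag(&output_node);
//     let mut callbacks = ScalarCallbacks;
//     let grads = backward_dag(&order, &loss_key, ones_like(&loss), &mut callbacks, &mut ctx);
//     let d_loss_d_x = grads.get(&x_key);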

fn build_single_op_linear<Op: PrimitiveOp>(
    node: &GradNode<Op>,
    ctx: &mut Op::ADContext,
) -> LinearFragment<Op>
where
    Op::InputKey: ADKey,
{
    let mut builder = FragmentBuilder::new();

    // Every primal input must already be aliased to a `GlobalValKey::Input`;
    // derived keys cannot be re-imported into a fresh single-op fragment.
    let wrt_keys: Vec<_> = node
        .primal_in_keys()
        .iter()
        .map(|key| match key {
            GlobalValKey::Input(input_key) => input_key.clone(),
            GlobalValKey::Derived { .. } => panic!(
                "build_single_op_linear requires GlobalValKey::Input aliases in node.primal_in_keys"
            ),
        })
        .collect();
    let input_local_ids: Vec<_> = wrt_keys
        .iter()
        .map(|input_key| builder.add_input(input_key.clone()))
        .collect();

    let outputs = builder.add_op(
        node.op().clone(),
        input_local_ids
            .iter()
            .map(|local_id| ValRef::Local(*local_id))
            .collect(),
        OpMode::Primal,
    );
    builder.set_outputs(outputs.clone());

    let fragment = Arc::new(builder.build());
    let view = resolve(vec![fragment.clone()]);
    let output_keys: Vec<_> = outputs
        .iter()
        .map(|output_id| fragment.vals()[*output_id].key.clone())
        .collect();
    let aliases = HashMap::new();

    crate::differentiate(&view, &output_keys, &wrt_keys, 0, ctx, &aliases)
}
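
// What this builds, by example: for a node whose op computes
// `y = matmul(x0, x1)` (an illustrative op name only), the helper assembles
// the one-op fragment
//
//     inputs:  x0, x1          (re-imported as fragment inputs)
//     op:      matmul(x0, x1)  (OpMode::Primal)
//     outputs: y
//
// resolves it, and hands the resolved view to `crate::differentiate`, which
// yields the `LinearFragment` that `backward_dag` transposes eagerly.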