Skip to main content

tidu/
backward.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use chainrules::{ADKey, ADRuleResult, PrimitiveOp};
5use computegraph::fragment::{Fragment, FragmentBuilder};
6use computegraph::resolve::resolve;
7use computegraph::{GlobalValKey, GraphOp, OpMode, ValRef};
8
9use crate::grad_node::GradNode;
10use crate::LinearFragment;
11
12/// Caller-provided execution hooks for eager backward.
13pub trait BackwardCallbacks<Op: PrimitiveOp>
14where
15    Op::InputKey: ADKey,
16{
17    /// Execute a linear fragment forward and return any concrete values needed
18    /// by eager transpose.
19    fn execute_forward(
20        &mut self,
21        fragment: &Fragment<Op>,
22        initial_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
23    ) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>;
24
25    /// Execute transpose eagerly for a linear fragment with concrete seeds.
26    fn eager_transpose(
27        &mut self,
28        linear: &LinearFragment<Op>,
29        cotangent_out: &[Option<Arc<Op::Operand>>],
30        external_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
31        ctx: &mut Op::ADContext,
32    ) -> Vec<Option<Arc<Op::Operand>>>;
33
34    /// Fallible eager transpose hook.
35    ///
36    /// Frontends that execute extension rules should override this method so
37    /// missing AD rules can propagate out of [`try_backward_dag`]. The default
38    /// implementation preserves the existing infallible callback contract.
39    fn try_eager_transpose(
40        &mut self,
41        linear: &LinearFragment<Op>,
42        cotangent_out: &[Option<Arc<Op::Operand>>],
43        external_data: &HashMap<GlobalValKey<Op>, Arc<Op::Operand>>,
44        ctx: &mut Op::ADContext,
45    ) -> ADRuleResult<Vec<Option<Arc<Op::Operand>>>> {
46        Ok(self.eager_transpose(linear, cotangent_out, external_data, ctx))
47    }
48
49    /// Add two concrete operands for cotangent accumulation.
50    fn add_operands(&mut self, a: &Arc<Op::Operand>, b: &Arc<Op::Operand>) -> Arc<Op::Operand>;
51}
52
53/// Topologically sort the reachable grad DAG in dependency-first order.
54pub fn topo_sort_grad_dag<Op: GraphOp>(
55    output_node: &Option<Arc<GradNode<Op>>>,
56) -> Vec<Arc<GradNode<Op>>> {
57    fn visit<Op: GraphOp>(
58        node: &Arc<GradNode<Op>>,
59        visited: &mut HashSet<*const GradNode<Op>>,
60        order: &mut Vec<Arc<GradNode<Op>>>,
61    ) {
62        let ptr = Arc::as_ptr(node);
63        if !visited.insert(ptr) {
64            return;
65        }
66
67        for edge in node.input_edges() {
68            if let Some(parent) = &edge.node {
69                visit(parent, visited, order);
70            }
71        }
72
73        order.push(node.clone());
74    }
75
76    let mut visited = HashSet::new();
77    let mut order = Vec::new();
78    if let Some(node) = output_node {
79        visit(node, &mut visited, &mut order);
80    }
81    order
82}
83
84/// Execute reverse-mode AD over a grad DAG.
85pub fn backward_dag<Op: PrimitiveOp>(
86    sorted_nodes: &[Arc<GradNode<Op>>],
87    output_key: &GlobalValKey<Op>,
88    seed: Arc<Op::Operand>,
89    callbacks: &mut impl BackwardCallbacks<Op>,
90    ctx: &mut Op::ADContext,
91) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>>
92where
93    Op::InputKey: ADKey,
94{
95    match try_backward_dag(sorted_nodes, output_key, seed, callbacks, ctx) {
96        Ok(cotangents) => cotangents,
97        Err(err) => panic!("{err}"),
98    }
99}
100
101/// Fallible form of [`backward_dag`].
102pub fn try_backward_dag<Op: PrimitiveOp>(
103    sorted_nodes: &[Arc<GradNode<Op>>],
104    output_key: &GlobalValKey<Op>,
105    seed: Arc<Op::Operand>,
106    callbacks: &mut impl BackwardCallbacks<Op>,
107    ctx: &mut Op::ADContext,
108) -> ADRuleResult<HashMap<GlobalValKey<Op>, Arc<Op::Operand>>>
109where
110    Op::InputKey: ADKey,
111{
112    let mut cotangents: HashMap<GlobalValKey<Op>, Arc<Op::Operand>> = HashMap::new();
113    cotangents.insert(output_key.clone(), seed);
114
115    for node in sorted_nodes.iter().rev() {
116        let cotangent_out: Vec<Option<Arc<Op::Operand>>> = node
117            .primal_out_keys()
118            .iter()
119            .map(|key| cotangents.get(key).cloned())
120            .collect();
121        if cotangent_out.iter().all(Option::is_none) {
122            continue;
123        }
124
125        let mut active_output_slots = Vec::new();
126        let mut active_cotangent_out = Vec::new();
127        for (slot, maybe_cotangent) in cotangent_out.into_iter().enumerate() {
128            if let Some(cotangent) = maybe_cotangent {
129                active_output_slots.push(slot);
130                active_cotangent_out.push(Some(cotangent));
131            }
132        }
133
134        let linear = try_build_single_op_linear(node, &active_output_slots, ctx)?;
135        let all_values = callbacks.execute_forward(&linear.fragment, node.saved_data());
136        let cotangent_in =
137            callbacks.try_eager_transpose(&linear, &active_cotangent_out, &all_values, ctx)?;
138
139        for (edge, maybe_cotangent) in node.input_edges().iter().zip(cotangent_in.into_iter()) {
140            let Some(cotangent) = maybe_cotangent else {
141                continue;
142            };
143            if !edge.requires_grad {
144                continue;
145            }
146
147            let accumulated = match cotangents.remove(&edge.key) {
148                Some(existing) => callbacks.add_operands(&existing, &cotangent),
149                None => cotangent,
150            };
151            cotangents.insert(edge.key.clone(), accumulated);
152        }
153    }
154
155    Ok(cotangents)
156}
157
158fn try_build_single_op_linear<Op: PrimitiveOp>(
159    node: &GradNode<Op>,
160    output_slots: &[usize],
161    ctx: &mut Op::ADContext,
162) -> ADRuleResult<LinearFragment<Op>>
163where
164    Op::InputKey: ADKey,
165{
166    let mut builder = FragmentBuilder::new();
167
168    let input_local_ids: Vec<_> = node
169        .primal_in_keys()
170        .iter()
171        .map(|key| match key {
172            GlobalValKey::Input(input_key) => builder.add_input(input_key.clone()),
173            GlobalValKey::Derived { .. } => {
174                panic!(
175                    "build_single_op_linear requires GlobalValKey::Input aliases in node.primal_in_keys"
176                )
177            }
178        })
179        .collect();
180
181    let outputs = builder.add_op(
182        node.op().clone(),
183        input_local_ids
184            .iter()
185            .map(|local_id| ValRef::Local(*local_id))
186            .collect(),
187        OpMode::Primal,
188    );
189    let selected_outputs: Vec<_> = output_slots
190        .iter()
191        .map(|&slot| {
192            *outputs.get(slot).unwrap_or_else(|| {
193                panic!(
194                    "build_single_op_linear got output slot {slot} for {:?}, \
195                     which has only {} outputs",
196                    node.op(),
197                    outputs.len()
198                )
199            })
200        })
201        .collect();
202    builder.set_outputs(selected_outputs.clone());
203
204    let fragment = Arc::new(builder.build());
205    let view = resolve(vec![fragment.clone()]);
206    let output_keys: Vec<_> = selected_outputs
207        .iter()
208        .map(|output_id| fragment.vals()[*output_id].key.clone())
209        .collect();
210    let wrt_keys: Vec<_> = node
211        .primal_in_keys()
212        .iter()
213        .map(|key| match key {
214            GlobalValKey::Input(input_key) => input_key.clone(),
215            GlobalValKey::Derived { .. } => {
216                panic!(
217                    "build_single_op_linear requires GlobalValKey::Input aliases in node.primal_in_keys"
218                )
219            }
220        })
221        .collect();
222    let aliases = HashMap::new();
223
224    crate::try_differentiate(&view, &output_keys, &wrt_keys, 0, ctx, &aliases)
225}