Skip to main content

tidu/
differentiate.rs

1use std::collections::{HashMap, HashSet};
2
3use chainrules::{ADKey, DiffPassId, PrimitiveOp};
4use computegraph::fragment::FragmentBuilder;
5use computegraph::resolve::{ResolvedView, ValDef};
6use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, LocalValId};
7
8use crate::LinearFragment;
9
10/// Differentiate a resolved computation graph, producing a linear fragment.
11///
12/// The transform walks the reachable DAG from `outputs` in dependency-first
13/// order and delegates primitive-specific JVP generation to
14/// [`chainrules::PrimitiveOp::linearize`].
15///
16/// # Examples
17///
18/// ```ignore
19/// use computegraph::resolve::resolve;
20/// use tidu::differentiate;
21///
22/// let view = resolve(vec![primal_fragment]);
23/// let mut ctx = ();
24/// let aliases = std::collections::HashMap::new();
25/// let linear = differentiate(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases);
26/// assert_eq!(linear.tangent_outputs.len(), 1);
27/// ```
28pub fn differentiate<Op: PrimitiveOp>(
29    view: &ResolvedView<Op>,
30    outputs: &[GlobalValKey<Op>],
31    wrt: &[Op::InputKey],
32    pass: DiffPassId,
33    ctx: &mut Op::ADContext,
34    aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
35) -> LinearFragment<Op>
36where
37    Op::InputKey: ADKey,
38{
39    let mut builder = FragmentBuilder::<Op>::new();
40    let topo_keys = topological_order(view, outputs, aliases);
41    let mut tangent_env: HashMap<GlobalValKey<Op>, Option<LocalValId>> = HashMap::new();
42    let mut processed_ops = HashSet::new();
43
44    let mut tangent_inputs = Vec::with_capacity(wrt.len());
45    for wrt_key in wrt {
46        let tangent_key = wrt_key.tangent_of(pass);
47        let tangent_id = builder.add_input(tangent_key);
48        tangent_env.insert(GlobalValKey::Input(wrt_key.clone()), Some(tangent_id));
49        tangent_inputs.push((wrt_key.clone(), tangent_id));
50    }
51
52    for key in topo_keys {
53        if tangent_env.contains_key(&key) {
54            continue;
55        }
56
57        let Some(val_def) = view.resolve_val(&key) else {
58            continue;
59        };
60
61        match val_def {
62            ValDef::Input { key: ref input_key } => {
63                if let Some(aliased_key) = aliases.get(input_key) {
64                    let aliased_tangent = tangent_env.get(aliased_key).copied().flatten();
65                    tangent_env.insert(key, aliased_tangent);
66                } else {
67                    tangent_env.insert(key, None);
68                }
69            }
70            ValDef::Produced {
71                op,
72                input_keys,
73                mode,
74                ..
75            } => {
76                let global_op_key = GlobalOpKey {
77                    primitive: op.clone(),
78                    inputs: input_keys.clone(),
79                    mode: mode.clone(),
80                };
81                if !processed_ops.insert(global_op_key.clone()) {
82                    continue;
83                }
84
85                let tangent_in: Vec<Option<LocalValId>> = input_keys
86                    .iter()
87                    .map(|input_key| tangent_env.get(input_key).copied().flatten())
88                    .collect();
89                let output_keys = output_keys(&global_op_key, op.n_outputs());
90
91                if tangent_in.iter().all(Option::is_none) {
92                    for output_key in output_keys {
93                        tangent_env.insert(output_key, None);
94                    }
95                    continue;
96                }
97
98                let tangent_out =
99                    op.linearize(&mut builder, &input_keys, &output_keys, &tangent_in, ctx);
100                assert_eq!(
101                    tangent_out.len(),
102                    output_keys.len(),
103                    "linearize for {:?} returned {} tangents for {} outputs",
104                    op,
105                    tangent_out.len(),
106                    output_keys.len()
107                );
108
109                for (output_key, tangent_output) in
110                    output_keys.into_iter().zip(tangent_out.into_iter())
111                {
112                    tangent_env.insert(output_key, tangent_output);
113                }
114            }
115        }
116    }
117
118    let tangent_outputs: Vec<Option<LocalValId>> = outputs
119        .iter()
120        .map(|key| tangent_env.get(key).copied().flatten())
121        .collect();
122    let active_outputs: Vec<LocalValId> = tangent_outputs.iter().filter_map(|id| *id).collect();
123    if !active_outputs.is_empty() {
124        builder.set_outputs(active_outputs);
125    }
126
127    LinearFragment {
128        fragment: builder.build(),
129        tangent_inputs,
130        tangent_outputs,
131    }
132}
133
134fn output_keys<Op: GraphOp>(op_key: &GlobalOpKey<Op>, n_outputs: usize) -> Vec<GlobalValKey<Op>> {
135    (0..n_outputs)
136        .map(|output_slot| GlobalValKey::Derived {
137            op: op_key.clone(),
138            output_slot: output_slot as u8,
139        })
140        .collect()
141}
142
143fn topological_order<Op: GraphOp>(
144    view: &ResolvedView<Op>,
145    outputs: &[GlobalValKey<Op>],
146    aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
147) -> Vec<GlobalValKey<Op>> {
148    fn visit<Op: GraphOp>(
149        key: &GlobalValKey<Op>,
150        view: &ResolvedView<Op>,
151        aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
152        visited: &mut HashSet<GlobalValKey<Op>>,
153        order: &mut Vec<GlobalValKey<Op>>,
154    ) {
155        if !visited.insert(key.clone()) {
156            return;
157        }
158
159        match view.resolve_val(key) {
160            Some(ValDef::Produced { input_keys, .. }) => {
161                for input_key in input_keys {
162                    visit(&input_key, view, aliases, visited, order);
163                }
164            }
165            Some(ValDef::Input { key: input_key }) => {
166                if let Some(aliased_key) = aliases.get(&input_key) {
167                    visit(aliased_key, view, aliases, visited, order);
168                }
169            }
170            None => {}
171        }
172
173        order.push(key.clone());
174    }
175
176    let mut visited = HashSet::new();
177    let mut order = Vec::new();
178    for output_key in outputs {
179        visit(output_key, view, aliases, &mut visited, &mut order);
180    }
181    order
182}