Skip to main content

tidu/
differentiate.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use chainrules::{ADKey, ADRuleResult, DiffPassId, PrimitiveOp};
5use computegraph::fragment::FragmentBuilder;
6use computegraph::resolve::{ResolvedView, ValDef};
7use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, LocalValId};
8
9use crate::LinearFragment;
10
11/// Differentiate a resolved computation graph, producing a linear fragment.
12///
13/// The transform walks the reachable DAG from `outputs` in dependency-first
14/// order and delegates primitive-specific JVP generation to
15/// [`chainrules::PrimitiveOp::try_linearize`].
16///
17/// # Examples
18///
19/// ```ignore
20/// use computegraph::resolve::resolve;
21/// use tidu::try_differentiate;
22///
23/// let view = resolve(vec![primal_fragment]);
24/// let mut ctx = ();
25/// let aliases = std::collections::HashMap::new();
26/// let linear = try_differentiate(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases)?;
27/// assert_eq!(linear.tangent_outputs.len(), 1);
28/// # Ok::<(), chainrules::ADRuleError>(())
29/// ```
30pub fn differentiate<Op: PrimitiveOp>(
31    view: &ResolvedView<Op>,
32    outputs: &[GlobalValKey<Op>],
33    wrt: &[Op::InputKey],
34    pass: DiffPassId,
35    ctx: &mut Op::ADContext,
36    aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
37) -> LinearFragment<Op>
38where
39    Op::InputKey: ADKey,
40{
41    match try_differentiate(view, outputs, wrt, pass, ctx, aliases) {
42        Ok(linear) => linear,
43        Err(err) => panic!("{err}"),
44    }
45}
46
47/// Fallible form of [`differentiate`].
48///
49/// This returns [`chainrules::ADRuleError`] when a primitive cannot emit a JVP
50/// rule, allowing downstream frontends to surface missing extension rules as
51/// normal errors instead of panics.
52pub fn try_differentiate<Op: PrimitiveOp>(
53    view: &ResolvedView<Op>,
54    outputs: &[GlobalValKey<Op>],
55    wrt: &[Op::InputKey],
56    pass: DiffPassId,
57    ctx: &mut Op::ADContext,
58    aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
59) -> ADRuleResult<LinearFragment<Op>>
60where
61    Op::InputKey: ADKey,
62{
63    let mut builder = FragmentBuilder::<Op>::new();
64    let topo_keys = topological_order(view, outputs, aliases);
65    let mut tangent_env: HashMap<GlobalValKey<Op>, Option<LocalValId>> = HashMap::new();
66    let mut processed_ops = HashSet::new();
67
68    let mut tangent_inputs = Vec::with_capacity(wrt.len());
69    for wrt_key in wrt {
70        let tangent_key = wrt_key.tangent_of(pass);
71        let tangent_id = builder.add_input(tangent_key);
72        tangent_env.insert(GlobalValKey::Input(wrt_key.clone()), Some(tangent_id));
73        tangent_inputs.push((wrt_key.clone(), tangent_id));
74    }
75
76    for key in topo_keys {
77        if tangent_env.contains_key(&key) {
78            continue;
79        }
80
81        let Some(val_def) = view.resolve_val(&key) else {
82            continue;
83        };
84
85        match val_def {
86            ValDef::Input { key: ref input_key } => {
87                if let Some(aliased_key) = aliases.get(input_key) {
88                    let aliased_tangent = tangent_env.get(aliased_key).copied().flatten();
89                    tangent_env.insert(key, aliased_tangent);
90                } else {
91                    tangent_env.insert(key, None);
92                }
93            }
94            ValDef::Produced {
95                op,
96                input_keys,
97                mode,
98                ..
99            } => {
100                let global_op_key = GlobalOpKey::new(op.clone(), input_keys.clone(), mode.clone());
101                if !processed_ops.insert(global_op_key.clone()) {
102                    continue;
103                }
104
105                let tangent_in: Vec<Option<LocalValId>> = input_keys
106                    .iter()
107                    .map(|input_key| tangent_env.get(input_key).copied().flatten())
108                    .collect();
109                let output_keys = output_keys(&global_op_key, op.n_outputs());
110
111                if tangent_in.iter().all(Option::is_none) {
112                    for output_key in output_keys {
113                        tangent_env.insert(output_key, None);
114                    }
115                    continue;
116                }
117
118                let tangent_out =
119                    op.try_linearize(&mut builder, &input_keys, &output_keys, &tangent_in, ctx)?;
120                assert_eq!(
121                    tangent_out.len(),
122                    output_keys.len(),
123                    "linearize for {:?} returned {} tangents for {} outputs",
124                    op,
125                    tangent_out.len(),
126                    output_keys.len()
127                );
128
129                for (output_key, tangent_output) in
130                    output_keys.into_iter().zip(tangent_out.into_iter())
131                {
132                    tangent_env.insert(output_key, tangent_output);
133                }
134            }
135        }
136    }
137
138    let tangent_outputs: Vec<Option<LocalValId>> = outputs
139        .iter()
140        .map(|key| tangent_env.get(key).copied().flatten())
141        .collect();
142    let active_outputs: Vec<LocalValId> = tangent_outputs.iter().filter_map(|id| *id).collect();
143    if !active_outputs.is_empty() {
144        builder.set_outputs(active_outputs);
145    }
146
147    Ok(LinearFragment {
148        fragment: builder.build(),
149        tangent_inputs,
150        tangent_outputs,
151    })
152}
153
154fn output_keys<Op: GraphOp>(op_key: &GlobalOpKey<Op>, n_outputs: usize) -> Vec<GlobalValKey<Op>> {
155    let op_key = Arc::new(op_key.clone());
156    (0..n_outputs)
157        .map(|output_slot| GlobalValKey::Derived {
158            op: Arc::clone(&op_key),
159            output_slot: output_slot as u8,
160        })
161        .collect()
162}
163
164fn topological_order<Op: GraphOp>(
165    view: &ResolvedView<Op>,
166    outputs: &[GlobalValKey<Op>],
167    aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
168) -> Vec<GlobalValKey<Op>> {
169    fn visit<Op: GraphOp>(
170        key: &GlobalValKey<Op>,
171        view: &ResolvedView<Op>,
172        aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
173        visited: &mut HashSet<GlobalValKey<Op>>,
174        order: &mut Vec<GlobalValKey<Op>>,
175    ) {
176        if !visited.insert(key.clone()) {
177            return;
178        }
179
180        match view.resolve_val(key) {
181            Some(ValDef::Produced { input_keys, .. }) => {
182                for input_key in input_keys {
183                    visit(&input_key, view, aliases, visited, order);
184                }
185            }
186            Some(ValDef::Input { key: input_key }) => {
187                if let Some(aliased_key) = aliases.get(&input_key) {
188                    visit(aliased_key, view, aliases, visited, order);
189                }
190            }
191            None => {}
192        }
193
194        order.push(key.clone());
195    }
196
197    let mut visited = HashSet::new();
198    let mut order = Vec::new();
199    for output_key in outputs {
200        visit(output_key, view, aliases, &mut visited, &mut order);
201    }
202    order
203}