1use std::collections::{HashMap, HashSet};
2
3use chainrules::{ADKey, DiffPassId, PrimitiveOp};
4use computegraph::fragment::FragmentBuilder;
5use computegraph::resolve::{ResolvedView, ValDef};
6use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, LocalValId};
7
8use crate::LinearFragment;
9
10pub fn differentiate<Op: PrimitiveOp>(
29 view: &ResolvedView<Op>,
30 outputs: &[GlobalValKey<Op>],
31 wrt: &[Op::InputKey],
32 pass: DiffPassId,
33 ctx: &mut Op::ADContext,
34 aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
35) -> LinearFragment<Op>
36where
37 Op::InputKey: ADKey,
38{
39 let mut builder = FragmentBuilder::<Op>::new();
40 let topo_keys = topological_order(view, outputs, aliases);
41 let mut tangent_env: HashMap<GlobalValKey<Op>, Option<LocalValId>> = HashMap::new();
42 let mut processed_ops = HashSet::new();
43
44 let mut tangent_inputs = Vec::with_capacity(wrt.len());
45 for wrt_key in wrt {
46 let tangent_key = wrt_key.tangent_of(pass);
47 let tangent_id = builder.add_input(tangent_key);
48 tangent_env.insert(GlobalValKey::Input(wrt_key.clone()), Some(tangent_id));
49 tangent_inputs.push((wrt_key.clone(), tangent_id));
50 }
51
52 for key in topo_keys {
53 if tangent_env.contains_key(&key) {
54 continue;
55 }
56
57 let Some(val_def) = view.resolve_val(&key) else {
58 continue;
59 };
60
61 match val_def {
62 ValDef::Input { key: ref input_key } => {
63 if let Some(aliased_key) = aliases.get(input_key) {
64 let aliased_tangent = tangent_env.get(aliased_key).copied().flatten();
65 tangent_env.insert(key, aliased_tangent);
66 } else {
67 tangent_env.insert(key, None);
68 }
69 }
70 ValDef::Produced {
71 op,
72 input_keys,
73 mode,
74 ..
75 } => {
76 let global_op_key = GlobalOpKey {
77 primitive: op.clone(),
78 inputs: input_keys.clone(),
79 mode: mode.clone(),
80 };
81 if !processed_ops.insert(global_op_key.clone()) {
82 continue;
83 }
84
85 let tangent_in: Vec<Option<LocalValId>> = input_keys
86 .iter()
87 .map(|input_key| tangent_env.get(input_key).copied().flatten())
88 .collect();
89 let output_keys = output_keys(&global_op_key, op.n_outputs());
90
91 if tangent_in.iter().all(Option::is_none) {
92 for output_key in output_keys {
93 tangent_env.insert(output_key, None);
94 }
95 continue;
96 }
97
98 let tangent_out =
99 op.linearize(&mut builder, &input_keys, &output_keys, &tangent_in, ctx);
100 assert_eq!(
101 tangent_out.len(),
102 output_keys.len(),
103 "linearize for {:?} returned {} tangents for {} outputs",
104 op,
105 tangent_out.len(),
106 output_keys.len()
107 );
108
109 for (output_key, tangent_output) in
110 output_keys.into_iter().zip(tangent_out.into_iter())
111 {
112 tangent_env.insert(output_key, tangent_output);
113 }
114 }
115 }
116 }
117
118 let tangent_outputs: Vec<Option<LocalValId>> = outputs
119 .iter()
120 .map(|key| tangent_env.get(key).copied().flatten())
121 .collect();
122 let active_outputs: Vec<LocalValId> = tangent_outputs.iter().filter_map(|id| *id).collect();
123 if !active_outputs.is_empty() {
124 builder.set_outputs(active_outputs);
125 }
126
127 LinearFragment {
128 fragment: builder.build(),
129 tangent_inputs,
130 tangent_outputs,
131 }
132}
133
134fn output_keys<Op: GraphOp>(op_key: &GlobalOpKey<Op>, n_outputs: usize) -> Vec<GlobalValKey<Op>> {
135 (0..n_outputs)
136 .map(|output_slot| GlobalValKey::Derived {
137 op: op_key.clone(),
138 output_slot: output_slot as u8,
139 })
140 .collect()
141}
142
143fn topological_order<Op: GraphOp>(
144 view: &ResolvedView<Op>,
145 outputs: &[GlobalValKey<Op>],
146 aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
147) -> Vec<GlobalValKey<Op>> {
148 fn visit<Op: GraphOp>(
149 key: &GlobalValKey<Op>,
150 view: &ResolvedView<Op>,
151 aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
152 visited: &mut HashSet<GlobalValKey<Op>>,
153 order: &mut Vec<GlobalValKey<Op>>,
154 ) {
155 if !visited.insert(key.clone()) {
156 return;
157 }
158
159 match view.resolve_val(key) {
160 Some(ValDef::Produced { input_keys, .. }) => {
161 for input_key in input_keys {
162 visit(&input_key, view, aliases, visited, order);
163 }
164 }
165 Some(ValDef::Input { key: input_key }) => {
166 if let Some(aliased_key) = aliases.get(&input_key) {
167 visit(aliased_key, view, aliases, visited, order);
168 }
169 }
170 None => {}
171 }
172
173 order.push(key.clone());
174 }
175
176 let mut visited = HashSet::new();
177 let mut order = Vec::new();
178 for output_key in outputs {
179 visit(output_key, view, aliases, &mut visited, &mut order);
180 }
181 order
182}