1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use chainrules::{ADKey, ADRuleResult, DiffPassId, PrimitiveOp};
5use computegraph::fragment::FragmentBuilder;
6use computegraph::resolve::{ResolvedView, ValDef};
7use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, LocalValId};
8
9use crate::LinearFragment;
10
11pub fn differentiate<Op: PrimitiveOp>(
31 view: &ResolvedView<Op>,
32 outputs: &[GlobalValKey<Op>],
33 wrt: &[Op::InputKey],
34 pass: DiffPassId,
35 ctx: &mut Op::ADContext,
36 aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
37) -> LinearFragment<Op>
38where
39 Op::InputKey: ADKey,
40{
41 match try_differentiate(view, outputs, wrt, pass, ctx, aliases) {
42 Ok(linear) => linear,
43 Err(err) => panic!("{err}"),
44 }
45}
46
47pub fn try_differentiate<Op: PrimitiveOp>(
53 view: &ResolvedView<Op>,
54 outputs: &[GlobalValKey<Op>],
55 wrt: &[Op::InputKey],
56 pass: DiffPassId,
57 ctx: &mut Op::ADContext,
58 aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
59) -> ADRuleResult<LinearFragment<Op>>
60where
61 Op::InputKey: ADKey,
62{
63 let mut builder = FragmentBuilder::<Op>::new();
64 let topo_keys = topological_order(view, outputs, aliases);
65 let mut tangent_env: HashMap<GlobalValKey<Op>, Option<LocalValId>> = HashMap::new();
66 let mut processed_ops = HashSet::new();
67
68 let mut tangent_inputs = Vec::with_capacity(wrt.len());
69 for wrt_key in wrt {
70 let tangent_key = wrt_key.tangent_of(pass);
71 let tangent_id = builder.add_input(tangent_key);
72 tangent_env.insert(GlobalValKey::Input(wrt_key.clone()), Some(tangent_id));
73 tangent_inputs.push((wrt_key.clone(), tangent_id));
74 }
75
76 for key in topo_keys {
77 if tangent_env.contains_key(&key) {
78 continue;
79 }
80
81 let Some(val_def) = view.resolve_val(&key) else {
82 continue;
83 };
84
85 match val_def {
86 ValDef::Input { key: ref input_key } => {
87 if let Some(aliased_key) = aliases.get(input_key) {
88 let aliased_tangent = tangent_env.get(aliased_key).copied().flatten();
89 tangent_env.insert(key, aliased_tangent);
90 } else {
91 tangent_env.insert(key, None);
92 }
93 }
94 ValDef::Produced {
95 op,
96 input_keys,
97 mode,
98 ..
99 } => {
100 let global_op_key = GlobalOpKey::new(op.clone(), input_keys.clone(), mode.clone());
101 if !processed_ops.insert(global_op_key.clone()) {
102 continue;
103 }
104
105 let tangent_in: Vec<Option<LocalValId>> = input_keys
106 .iter()
107 .map(|input_key| tangent_env.get(input_key).copied().flatten())
108 .collect();
109 let output_keys = output_keys(&global_op_key, op.n_outputs());
110
111 if tangent_in.iter().all(Option::is_none) {
112 for output_key in output_keys {
113 tangent_env.insert(output_key, None);
114 }
115 continue;
116 }
117
118 let tangent_out =
119 op.try_linearize(&mut builder, &input_keys, &output_keys, &tangent_in, ctx)?;
120 assert_eq!(
121 tangent_out.len(),
122 output_keys.len(),
123 "linearize for {:?} returned {} tangents for {} outputs",
124 op,
125 tangent_out.len(),
126 output_keys.len()
127 );
128
129 for (output_key, tangent_output) in
130 output_keys.into_iter().zip(tangent_out.into_iter())
131 {
132 tangent_env.insert(output_key, tangent_output);
133 }
134 }
135 }
136 }
137
138 let tangent_outputs: Vec<Option<LocalValId>> = outputs
139 .iter()
140 .map(|key| tangent_env.get(key).copied().flatten())
141 .collect();
142 let active_outputs: Vec<LocalValId> = tangent_outputs.iter().filter_map(|id| *id).collect();
143 if !active_outputs.is_empty() {
144 builder.set_outputs(active_outputs);
145 }
146
147 Ok(LinearFragment {
148 fragment: builder.build(),
149 tangent_inputs,
150 tangent_outputs,
151 })
152}
153
154fn output_keys<Op: GraphOp>(op_key: &GlobalOpKey<Op>, n_outputs: usize) -> Vec<GlobalValKey<Op>> {
155 let op_key = Arc::new(op_key.clone());
156 (0..n_outputs)
157 .map(|output_slot| GlobalValKey::Derived {
158 op: Arc::clone(&op_key),
159 output_slot: output_slot as u8,
160 })
161 .collect()
162}
163
164fn topological_order<Op: GraphOp>(
165 view: &ResolvedView<Op>,
166 outputs: &[GlobalValKey<Op>],
167 aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
168) -> Vec<GlobalValKey<Op>> {
169 fn visit<Op: GraphOp>(
170 key: &GlobalValKey<Op>,
171 view: &ResolvedView<Op>,
172 aliases: &HashMap<Op::InputKey, GlobalValKey<Op>>,
173 visited: &mut HashSet<GlobalValKey<Op>>,
174 order: &mut Vec<GlobalValKey<Op>>,
175 ) {
176 if !visited.insert(key.clone()) {
177 return;
178 }
179
180 match view.resolve_val(key) {
181 Some(ValDef::Produced { input_keys, .. }) => {
182 for input_key in input_keys {
183 visit(&input_key, view, aliases, visited, order);
184 }
185 }
186 Some(ValDef::Input { key: input_key }) => {
187 if let Some(aliased_key) = aliases.get(&input_key) {
188 visit(aliased_key, view, aliases, visited, order);
189 }
190 }
191 None => {}
192 }
193
194 order.push(key.clone());
195 }
196
197 let mut visited = HashSet::new();
198 let mut order = Vec::new();
199 for output_key in outputs {
200 visit(output_key, view, aliases, &mut visited, &mut order);
201 }
202 order
203}