1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use crate::rules::GraphPrimitiveBuilder;
5use crate::{ADKey, ADRuleResult, DiffPassId, Primitive};
6use computegraph::graph::GraphBuilder;
7use computegraph::resolve::{ResolvedView, ValueDef};
8use computegraph::{GraphOperation, LocalValueId, OperationKey, ValueKey};
9
10use crate::LinearizedGraph;
11
12pub fn linearize<Op: Primitive>(
32 view: &ResolvedView<Op>,
33 outputs: &[ValueKey<Op>],
34 wrt: &[Op::InputKey],
35 pass: DiffPassId,
36 ctx: &mut Op::ADContext,
37 aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
38) -> LinearizedGraph<Op>
39where
40 Op::InputKey: ADKey,
41{
42 match try_linearize(view, outputs, wrt, pass, ctx, aliases) {
43 Ok(linear) => linear,
44 Err(err) => panic!("{}", err),
45 }
46}
47
48pub fn try_linearize<Op: Primitive>(
54 view: &ResolvedView<Op>,
55 outputs: &[ValueKey<Op>],
56 wrt: &[Op::InputKey],
57 pass: DiffPassId,
58 ctx: &mut Op::ADContext,
59 aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
60) -> ADRuleResult<LinearizedGraph<Op>>
61where
62 Op::InputKey: ADKey,
63{
64 let mut builder = GraphBuilder::<Op>::new();
65 let topo_keys = topological_order(view, outputs, aliases);
66 let mut tangent_env: HashMap<ValueKey<Op>, Option<LocalValueId>> = HashMap::new();
67 let mut processed_ops = HashSet::new();
68
69 let mut tangent_inputs = Vec::with_capacity(wrt.len());
70 for wrt_key in wrt {
71 let tangent_key = wrt_key.tangent_of(pass);
72 let tangent_id = builder.add_input(tangent_key);
73 tangent_env.insert(ValueKey::Input(wrt_key.clone()), Some(tangent_id));
74 tangent_inputs.push((wrt_key.clone(), tangent_id));
75 }
76
77 for key in topo_keys {
78 if tangent_env.contains_key(&key) {
79 continue;
80 }
81
82 let val_def = match view.resolve_value(&key) {
83 Some(val_def) => val_def,
84 None => continue,
85 };
86
87 match val_def {
88 ValueDef::Input { key: ref input_key } => {
89 if let Some(aliased_key) = aliases.get(input_key) {
90 let aliased_tangent = tangent_env.get(aliased_key).copied().flatten();
91 tangent_env.insert(key, aliased_tangent);
92 } else {
93 tangent_env.insert(key, None);
94 }
95 }
96 ValueDef::Produced {
97 operation,
98 input_keys,
99 role,
100 ..
101 } => {
102 let global_op_key =
103 OperationKey::new(operation.clone(), input_keys.clone(), role.clone());
104 if !processed_ops.insert(global_op_key.clone()) {
105 continue;
106 }
107
108 let tangent_in: Vec<Option<LocalValueId>> = input_keys
109 .iter()
110 .map(|input_key| tangent_env.get(input_key).copied().flatten())
111 .collect();
112 let output_keys = output_keys(&global_op_key, operation.output_count());
113
114 if tangent_in.iter().all(Option::is_none) {
115 for output_key in output_keys {
116 tangent_env.insert(output_key, None);
117 }
118 continue;
119 }
120
121 let mut primitive_builder = GraphPrimitiveBuilder::new(&mut builder);
122 let tangent_out = operation.try_jvp_rule(
123 &mut primitive_builder,
124 &input_keys,
125 &output_keys,
126 &tangent_in,
127 ctx,
128 )?;
129 assert_eq!(
130 tangent_out.len(),
131 output_keys.len(),
132 "jvp_rule for {:?} returned {} tangents for {} outputs",
133 operation,
134 tangent_out.len(),
135 output_keys.len()
136 );
137
138 for (output_key, tangent_output) in output_keys.into_iter().zip(tangent_out) {
139 tangent_env.insert(output_key, tangent_output);
140 }
141 }
142 }
143 }
144
145 let tangent_outputs: Vec<Option<LocalValueId>> = outputs
146 .iter()
147 .map(|key| tangent_env.get(key).copied().flatten())
148 .collect();
149 let active_outputs: Vec<LocalValueId> = tangent_outputs.iter().filter_map(|id| *id).collect();
150 if !active_outputs.is_empty() {
151 builder.set_outputs(active_outputs);
152 }
153
154 Ok(LinearizedGraph::from_parts(
155 builder.build(),
156 tangent_inputs,
157 tangent_outputs,
158 ))
159}
160
161fn output_keys<Op: GraphOperation>(
162 op_key: &OperationKey<Op>,
163 output_count: usize,
164) -> Vec<ValueKey<Op>> {
165 let op_key = Arc::new(op_key.clone());
166 (0..output_count)
167 .map(|output_slot| ValueKey::Derived {
168 operation: Arc::clone(&op_key),
169 output_slot: output_slot as u8,
170 })
171 .collect()
172}
173
174fn topological_order<Op: GraphOperation>(
175 view: &ResolvedView<Op>,
176 outputs: &[ValueKey<Op>],
177 aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
178) -> Vec<ValueKey<Op>> {
179 fn visit<Op: GraphOperation>(
180 key: &ValueKey<Op>,
181 view: &ResolvedView<Op>,
182 aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
183 visited: &mut HashSet<ValueKey<Op>>,
184 order: &mut Vec<ValueKey<Op>>,
185 ) {
186 if !visited.insert(key.clone()) {
187 return;
188 }
189
190 match view.resolve_value(key) {
191 Some(ValueDef::Produced { input_keys, .. }) => {
192 for input_key in input_keys {
193 visit(&input_key, view, aliases, visited, order);
194 }
195 }
196 Some(ValueDef::Input { key: input_key }) => {
197 if let Some(aliased_key) = aliases.get(&input_key) {
198 visit(aliased_key, view, aliases, visited, order);
199 }
200 }
201 None => {}
202 }
203
204 order.push(key.clone());
205 }
206
207 let mut visited = HashSet::new();
208 let mut order = Vec::new();
209 for output_key in outputs {
210 visit(output_key, view, aliases, &mut visited, &mut order);
211 }
212 order
213}