1use std::collections::HashMap;
2
3use crate::rules::GraphPrimitiveBuilder;
4use crate::{ADKey, ADRuleResult, Primitive, PrimitiveBuilder, PrimitiveValue};
5use computegraph::graph::GraphBuilder;
6use computegraph::{LocalValueId, OperationRole, ValueKey, ValueRef};
7
8use crate::LinearizedGraph;
9
10pub fn linear_transpose<Op: Primitive>(
23 linear: &LinearizedGraph<Op>,
24 ctx: &mut Op::ADContext,
25) -> LinearizedGraph<Op>
26where
27 Op::InputKey: ADKey,
28{
29 match try_linear_transpose(linear, ctx) {
30 Ok(transposed) => transposed,
31 Err(err) => panic!("{}", err),
32 }
33}
34
35pub fn try_linear_transpose<Op: Primitive>(
40 linear: &LinearizedGraph<Op>,
41 ctx: &mut Op::ADContext,
42) -> ADRuleResult<LinearizedGraph<Op>>
43where
44 Op::InputKey: ADKey,
45{
46 let mut builder = GraphBuilder::<Op>::new();
47 let mut cotangent_env: HashMap<ValueKey<Op>, LocalValueId> = HashMap::new();
48 let mut cotangent_seed_inputs = Vec::new();
49 let graph = linear.as_graph();
50
51 for (index, maybe_tangent_output) in linear.tangent_outputs().iter().enumerate() {
52 let tangent_output_id = match maybe_tangent_output {
53 Some(tangent_output_id) => tangent_output_id,
54 None => continue,
55 };
56
57 let source_key = graph.values()[*tangent_output_id].key.clone();
58 let seed_key = cotangent_seed_key(linear, index);
59 let seed_id = builder.add_input(seed_key.clone());
60 cotangent_env.insert(source_key, seed_id);
61 cotangent_seed_inputs.push((seed_key, seed_id));
62 }
63
64 for op_node in graph.operations().iter().rev() {
65 let cotangent_out: Vec<Option<LocalValueId>> = op_node
66 .outputs
67 .iter()
68 .map(|output_id| cotangent_env.get(&graph.values()[*output_id].key).copied())
69 .collect();
70 if cotangent_out.iter().all(Option::is_none) {
71 continue;
72 }
73
74 let rule_inputs: Vec<PrimitiveValue<Op>> = op_node
75 .inputs
76 .iter()
77 .map(|input| match input {
78 ValueRef::Local(local_id) => {
79 PrimitiveValue::External(graph.values()[*local_id].key.clone())
80 }
81 ValueRef::External(key) => PrimitiveValue::External(key.clone()),
82 })
83 .collect();
84
85 let mut primitive_builder = GraphPrimitiveBuilder::new(&mut builder);
86 let cotangent_in = op_node.operation.try_linear_transpose_rule(
87 &mut primitive_builder,
88 &cotangent_out,
89 &rule_inputs,
90 &op_node.role,
91 ctx,
92 )?;
93 assert_eq!(
94 cotangent_in.len(),
95 rule_inputs.len(),
96 "transpose_rule for {:?} returned {} cotangents for {} inputs",
97 op_node.operation,
98 cotangent_in.len(),
99 rule_inputs.len()
100 );
101
102 for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in) {
103 let cotangent_id = match maybe_cotangent {
104 Some(cotangent_id) => cotangent_id,
105 None => continue,
106 };
107 let input_key = match input {
108 PrimitiveValue::Local(_) => {
109 unreachable!("rule inputs are normalized to external refs")
110 }
111 PrimitiveValue::External(key) => key.clone(),
112 };
113
114 match cotangent_env.get(&input_key).copied() {
115 Some(existing_id) => {
116 let mut primitive_builder = GraphPrimitiveBuilder::new(&mut builder);
117 let sum = primitive_builder.add_primitive(
118 Op::add(),
119 vec![
120 PrimitiveValue::Local(existing_id),
121 PrimitiveValue::Local(cotangent_id),
122 ],
123 OperationRole::Linearized {
124 active_mask: vec![true, true],
125 },
126 );
127 cotangent_env.insert(input_key, sum[0]);
128 }
129 None => {
130 cotangent_env.insert(input_key, cotangent_id);
131 }
132 }
133 }
134 }
135
136 let tangent_outputs: Vec<Option<LocalValueId>> = linear
137 .tangent_inputs()
138 .iter()
139 .map(|(_, tangent_input_id)| {
140 let tangent_input_key = &graph.values()[*tangent_input_id].key;
141 cotangent_env.get(tangent_input_key).copied()
142 })
143 .collect();
144 let active_outputs: Vec<LocalValueId> = tangent_outputs.iter().filter_map(|id| *id).collect();
145 if !active_outputs.is_empty() {
146 builder.set_outputs(active_outputs);
147 }
148
149 Ok(LinearizedGraph::from_parts(
150 builder.build(),
151 cotangent_seed_inputs,
152 tangent_outputs,
153 ))
154}
155
156pub fn try_linear_transpose_with_builder<Op: Primitive>(
158 linear: &LinearizedGraph<Op>,
159 builder: &mut impl PrimitiveBuilder<Op>,
160 cotangent_seeds: &[Option<LocalValueId>],
161 ctx: &mut Op::ADContext,
162) -> ADRuleResult<Vec<Option<LocalValueId>>>
163where
164 Op::InputKey: ADKey,
165{
166 let mut cotangent_env: HashMap<ValueKey<Op>, LocalValueId> = HashMap::new();
167 let graph = linear.as_graph();
168
169 for (index, maybe_tangent_output) in linear.tangent_outputs().iter().enumerate() {
170 if let (Some(output_id), Some(Some(seed_id))) =
171 (maybe_tangent_output, cotangent_seeds.get(index))
172 {
173 let key = graph.values()[*output_id].key.clone();
174 cotangent_env.insert(key, *seed_id);
175 }
176 }
177
178 for op_node in graph.operations().iter().rev() {
179 let cotangent_out: Vec<Option<LocalValueId>> = op_node
180 .outputs
181 .iter()
182 .map(|output_id| cotangent_env.get(&graph.values()[*output_id].key).copied())
183 .collect();
184 if cotangent_out.iter().all(Option::is_none) {
185 continue;
186 }
187
188 let rule_inputs: Vec<PrimitiveValue<Op>> = op_node
189 .inputs
190 .iter()
191 .map(|input| match input {
192 ValueRef::Local(local_id) => {
193 PrimitiveValue::External(graph.values()[*local_id].key.clone())
194 }
195 ValueRef::External(key) => PrimitiveValue::External(key.clone()),
196 })
197 .collect();
198
199 let cotangent_in = op_node.operation.try_linear_transpose_rule(
200 builder,
201 &cotangent_out,
202 &rule_inputs,
203 &op_node.role,
204 ctx,
205 )?;
206 assert_eq!(
207 cotangent_in.len(),
208 rule_inputs.len(),
209 "transpose_rule for {:?} returned {} cotangents for {} inputs",
210 op_node.operation,
211 cotangent_in.len(),
212 rule_inputs.len()
213 );
214
215 for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in) {
216 let cotangent_id = match maybe_cotangent {
217 Some(cotangent_id) => cotangent_id,
218 None => continue,
219 };
220 let input_key = match input {
221 PrimitiveValue::Local(_) => {
222 unreachable!("rule inputs are normalized to external refs")
223 }
224 PrimitiveValue::External(key) => key.clone(),
225 };
226
227 match cotangent_env.get(&input_key).copied() {
228 Some(existing_id) => {
229 let sum = builder.add_primitive(
230 Op::add(),
231 vec![
232 PrimitiveValue::Local(existing_id),
233 PrimitiveValue::Local(cotangent_id),
234 ],
235 OperationRole::Linearized {
236 active_mask: vec![true, true],
237 },
238 );
239 cotangent_env.insert(input_key, sum[0]);
240 }
241 None => {
242 cotangent_env.insert(input_key, cotangent_id);
243 }
244 }
245 }
246 }
247
248 Ok(linear
249 .tangent_inputs()
250 .iter()
251 .map(|(_, tangent_input_id)| {
252 let tangent_input_key = &graph.values()[*tangent_input_id].key;
253 cotangent_env.get(tangent_input_key).copied()
254 })
255 .collect())
256}
257
258fn cotangent_seed_key<Op: Primitive>(linear: &LinearizedGraph<Op>, index: usize) -> Op::InputKey
259where
260 Op::InputKey: ADKey,
261{
262 assert!(
263 !linear.tangent_inputs().is_empty(),
264 "active tangent outputs require at least one tangent input to derive seed keys"
265 );
266
267 let base_slot = index.min(linear.tangent_inputs().len() - 1);
268 let base_key = &linear.tangent_inputs()[base_slot].0;
269 base_key.tangent_of(u64::MAX - index as u64)
270}