// tidu/linear_transpose.rs

1use std::collections::HashMap;
2
3use crate::rules::GraphPrimitiveBuilder;
4use crate::{ADKey, ADRuleResult, Primitive, PrimitiveBuilder, PrimitiveValue};
5use computegraph::graph::GraphBuilder;
6use computegraph::{LocalValueId, OperationRole, ValueKey, ValueRef};
7
8use crate::LinearizedGraph;
9
10/// Transpose a linearized graph, reversing linear flow.
11///
12/// Fan-out accumulation is emitted explicitly with [`crate::Primitive::add`];
13/// no duplication primitive is assumed by the graph transform.
14///
15/// # Examples
16///
17/// ```ignore
18/// let mut ctx = ();
19/// let transposed = tidu::linear_transpose(&linear, &mut ctx);
20/// assert_eq!(transposed.tangent_outputs().len(), linear.tangent_inputs().len());
21/// ```
22pub fn linear_transpose<Op: Primitive>(
23    linear: &LinearizedGraph<Op>,
24    ctx: &mut Op::ADContext,
25) -> LinearizedGraph<Op>
26where
27    Op::InputKey: ADKey,
28{
29    match try_linear_transpose(linear, ctx) {
30        Ok(transposed) => transposed,
31        Err(err) => panic!("{}", err),
32    }
33}
34
35/// Fallible form of [`linear_transpose`].
36///
37/// This returns [`crate::ADRuleError`] when a primitive cannot emit a
38/// transpose rule.
39pub fn try_linear_transpose<Op: Primitive>(
40    linear: &LinearizedGraph<Op>,
41    ctx: &mut Op::ADContext,
42) -> ADRuleResult<LinearizedGraph<Op>>
43where
44    Op::InputKey: ADKey,
45{
46    let mut builder = GraphBuilder::<Op>::new();
47    let mut cotangent_env: HashMap<ValueKey<Op>, LocalValueId> = HashMap::new();
48    let mut cotangent_seed_inputs = Vec::new();
49    let graph = linear.as_graph();
50
51    for (index, maybe_tangent_output) in linear.tangent_outputs().iter().enumerate() {
52        let tangent_output_id = match maybe_tangent_output {
53            Some(tangent_output_id) => tangent_output_id,
54            None => continue,
55        };
56
57        let source_key = graph.values()[*tangent_output_id].key.clone();
58        let seed_key = cotangent_seed_key(linear, index);
59        let seed_id = builder.add_input(seed_key.clone());
60        cotangent_env.insert(source_key, seed_id);
61        cotangent_seed_inputs.push((seed_key, seed_id));
62    }
63
64    for op_node in graph.operations().iter().rev() {
65        let cotangent_out: Vec<Option<LocalValueId>> = op_node
66            .outputs
67            .iter()
68            .map(|output_id| cotangent_env.get(&graph.values()[*output_id].key).copied())
69            .collect();
70        if cotangent_out.iter().all(Option::is_none) {
71            continue;
72        }
73
74        let rule_inputs: Vec<PrimitiveValue<Op>> = op_node
75            .inputs
76            .iter()
77            .map(|input| match input {
78                ValueRef::Local(local_id) => {
79                    PrimitiveValue::External(graph.values()[*local_id].key.clone())
80                }
81                ValueRef::External(key) => PrimitiveValue::External(key.clone()),
82            })
83            .collect();
84
85        let mut primitive_builder = GraphPrimitiveBuilder::new(&mut builder);
86        let cotangent_in = op_node.operation.try_linear_transpose_rule(
87            &mut primitive_builder,
88            &cotangent_out,
89            &rule_inputs,
90            &op_node.role,
91            ctx,
92        )?;
93        assert_eq!(
94            cotangent_in.len(),
95            rule_inputs.len(),
96            "transpose_rule for {:?} returned {} cotangents for {} inputs",
97            op_node.operation,
98            cotangent_in.len(),
99            rule_inputs.len()
100        );
101
102        for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in) {
103            let cotangent_id = match maybe_cotangent {
104                Some(cotangent_id) => cotangent_id,
105                None => continue,
106            };
107            let input_key = match input {
108                PrimitiveValue::Local(_) => {
109                    unreachable!("rule inputs are normalized to external refs")
110                }
111                PrimitiveValue::External(key) => key.clone(),
112            };
113
114            match cotangent_env.get(&input_key).copied() {
115                Some(existing_id) => {
116                    let mut primitive_builder = GraphPrimitiveBuilder::new(&mut builder);
117                    let sum = primitive_builder.add_primitive(
118                        Op::add(),
119                        vec![
120                            PrimitiveValue::Local(existing_id),
121                            PrimitiveValue::Local(cotangent_id),
122                        ],
123                        OperationRole::Linearized {
124                            active_mask: vec![true, true],
125                        },
126                    );
127                    cotangent_env.insert(input_key, sum[0]);
128                }
129                None => {
130                    cotangent_env.insert(input_key, cotangent_id);
131                }
132            }
133        }
134    }
135
136    let tangent_outputs: Vec<Option<LocalValueId>> = linear
137        .tangent_inputs()
138        .iter()
139        .map(|(_, tangent_input_id)| {
140            let tangent_input_key = &graph.values()[*tangent_input_id].key;
141            cotangent_env.get(tangent_input_key).copied()
142        })
143        .collect();
144    let active_outputs: Vec<LocalValueId> = tangent_outputs.iter().filter_map(|id| *id).collect();
145    if !active_outputs.is_empty() {
146        builder.set_outputs(active_outputs);
147    }
148
149    Ok(LinearizedGraph::from_parts(
150        builder.build(),
151        cotangent_seed_inputs,
152        tangent_outputs,
153    ))
154}
155
156/// Execute the transpose of a linearized graph using a caller-provided builder.
157pub fn try_linear_transpose_with_builder<Op: Primitive>(
158    linear: &LinearizedGraph<Op>,
159    builder: &mut impl PrimitiveBuilder<Op>,
160    cotangent_seeds: &[Option<LocalValueId>],
161    ctx: &mut Op::ADContext,
162) -> ADRuleResult<Vec<Option<LocalValueId>>>
163where
164    Op::InputKey: ADKey,
165{
166    let mut cotangent_env: HashMap<ValueKey<Op>, LocalValueId> = HashMap::new();
167    let graph = linear.as_graph();
168
169    for (index, maybe_tangent_output) in linear.tangent_outputs().iter().enumerate() {
170        if let (Some(output_id), Some(Some(seed_id))) =
171            (maybe_tangent_output, cotangent_seeds.get(index))
172        {
173            let key = graph.values()[*output_id].key.clone();
174            cotangent_env.insert(key, *seed_id);
175        }
176    }
177
178    for op_node in graph.operations().iter().rev() {
179        let cotangent_out: Vec<Option<LocalValueId>> = op_node
180            .outputs
181            .iter()
182            .map(|output_id| cotangent_env.get(&graph.values()[*output_id].key).copied())
183            .collect();
184        if cotangent_out.iter().all(Option::is_none) {
185            continue;
186        }
187
188        let rule_inputs: Vec<PrimitiveValue<Op>> = op_node
189            .inputs
190            .iter()
191            .map(|input| match input {
192                ValueRef::Local(local_id) => {
193                    PrimitiveValue::External(graph.values()[*local_id].key.clone())
194                }
195                ValueRef::External(key) => PrimitiveValue::External(key.clone()),
196            })
197            .collect();
198
199        let cotangent_in = op_node.operation.try_linear_transpose_rule(
200            builder,
201            &cotangent_out,
202            &rule_inputs,
203            &op_node.role,
204            ctx,
205        )?;
206        assert_eq!(
207            cotangent_in.len(),
208            rule_inputs.len(),
209            "transpose_rule for {:?} returned {} cotangents for {} inputs",
210            op_node.operation,
211            cotangent_in.len(),
212            rule_inputs.len()
213        );
214
215        for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in) {
216            let cotangent_id = match maybe_cotangent {
217                Some(cotangent_id) => cotangent_id,
218                None => continue,
219            };
220            let input_key = match input {
221                PrimitiveValue::Local(_) => {
222                    unreachable!("rule inputs are normalized to external refs")
223                }
224                PrimitiveValue::External(key) => key.clone(),
225            };
226
227            match cotangent_env.get(&input_key).copied() {
228                Some(existing_id) => {
229                    let sum = builder.add_primitive(
230                        Op::add(),
231                        vec![
232                            PrimitiveValue::Local(existing_id),
233                            PrimitiveValue::Local(cotangent_id),
234                        ],
235                        OperationRole::Linearized {
236                            active_mask: vec![true, true],
237                        },
238                    );
239                    cotangent_env.insert(input_key, sum[0]);
240                }
241                None => {
242                    cotangent_env.insert(input_key, cotangent_id);
243                }
244            }
245        }
246    }
247
248    Ok(linear
249        .tangent_inputs()
250        .iter()
251        .map(|(_, tangent_input_id)| {
252            let tangent_input_key = &graph.values()[*tangent_input_id].key;
253            cotangent_env.get(tangent_input_key).copied()
254        })
255        .collect())
256}
257
258fn cotangent_seed_key<Op: Primitive>(linear: &LinearizedGraph<Op>, index: usize) -> Op::InputKey
259where
260    Op::InputKey: ADKey,
261{
262    assert!(
263        !linear.tangent_inputs().is_empty(),
264        "active tangent outputs require at least one tangent input to derive seed keys"
265    );
266
267    let base_slot = index.min(linear.tangent_inputs().len() - 1);
268    let base_key = &linear.tangent_inputs()[base_slot].0;
269    base_key.tangent_of(u64::MAX - index as u64)
270}