Skip to main content

tidu/
eager_transpose.rs

1use std::collections::HashMap;
2
3use chainrules::PrimitiveOp;
4use computegraph::{GlobalValKey, LocalValId, OpEmitter, OpMode, ValRef};
5
6use crate::LinearFragment;
7
8/// Execute the transpose of a linear fragment using an eager emitter.
9///
10/// This mirrors [`crate::transpose`] but leaves execution strategy to the
11/// caller-provided [`computegraph::OpEmitter`].
12pub fn eager_transpose_fragment<Op: PrimitiveOp>(
13    linear: &LinearFragment<Op>,
14    emitter: &mut impl OpEmitter<Op>,
15    cotangent_seeds: &[Option<LocalValId>],
16    ctx: &mut Op::ADContext,
17) -> Vec<Option<LocalValId>>
18where
19    Op::InputKey: chainrules::ADKey,
20{
21    let mut cotangent_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
22
23    for (index, maybe_tangent_output) in linear.tangent_outputs.iter().enumerate() {
24        if let (Some(output_id), Some(Some(seed_id))) =
25            (maybe_tangent_output, cotangent_seeds.get(index))
26        {
27            let key = linear.fragment.vals()[*output_id].key.clone();
28            cotangent_env.insert(key, *seed_id);
29        }
30    }
31
32    for op_node in linear.fragment.ops().iter().rev() {
33        let cotangent_out: Vec<Option<LocalValId>> = op_node
34            .outputs
35            .iter()
36            .map(|output_id| {
37                cotangent_env
38                    .get(&linear.fragment.vals()[*output_id].key)
39                    .copied()
40            })
41            .collect();
42        if cotangent_out.iter().all(Option::is_none) {
43            continue;
44        }
45
46        let rule_inputs: Vec<ValRef<Op>> = op_node
47            .inputs
48            .iter()
49            .map(|input| match input {
50                ValRef::Local(local_id) => {
51                    ValRef::External(linear.fragment.vals()[*local_id].key.clone())
52                }
53                ValRef::External(key) => ValRef::External(key.clone()),
54            })
55            .collect();
56
57        let cotangent_in =
58            op_node
59                .op
60                .transpose_rule(emitter, &cotangent_out, &rule_inputs, &op_node.mode, ctx);
61        assert_eq!(
62            cotangent_in.len(),
63            rule_inputs.len(),
64            "transpose_rule for {:?} returned {} cotangents for {} inputs",
65            op_node.op,
66            cotangent_in.len(),
67            rule_inputs.len()
68        );
69
70        for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in.into_iter()) {
71            let Some(cotangent_id) = maybe_cotangent else {
72                continue;
73            };
74            let input_key = match input {
75                ValRef::Local(_) => unreachable!("rule inputs are normalized to external refs"),
76                ValRef::External(key) => key.clone(),
77            };
78
79            match cotangent_env.get(&input_key).copied() {
80                Some(existing_id) => {
81                    let sum = emitter.add_op(
82                        Op::add(),
83                        vec![ValRef::Local(existing_id), ValRef::Local(cotangent_id)],
84                        OpMode::Linear {
85                            active_mask: vec![true, true],
86                        },
87                    );
88                    cotangent_env.insert(input_key, sum[0]);
89                }
90                None => {
91                    cotangent_env.insert(input_key, cotangent_id);
92                }
93            }
94        }
95    }
96
97    linear
98        .tangent_inputs
99        .iter()
100        .map(|(_, tangent_input_id)| {
101            let tangent_input_key = &linear.fragment.vals()[*tangent_input_id].key;
102            cotangent_env.get(tangent_input_key).copied()
103        })
104        .collect()
105}