Skip to main content

tidu/
eager_transpose.rs

1use std::collections::HashMap;
2
3use chainrules::{ADRuleResult, PrimitiveOp};
4use computegraph::{GlobalValKey, LocalValId, OpEmitter, OpMode, ValRef};
5
6use crate::LinearFragment;
7
8/// Execute the transpose of a linear fragment using an eager emitter.
9///
10/// This mirrors [`crate::transpose`] but leaves execution strategy to the
11/// caller-provided [`computegraph::OpEmitter`].
12pub fn eager_transpose_fragment<Op: PrimitiveOp>(
13    linear: &LinearFragment<Op>,
14    emitter: &mut impl OpEmitter<Op>,
15    cotangent_seeds: &[Option<LocalValId>],
16    ctx: &mut Op::ADContext,
17) -> Vec<Option<LocalValId>>
18where
19    Op::InputKey: chainrules::ADKey,
20{
21    match try_eager_transpose_fragment(linear, emitter, cotangent_seeds, ctx) {
22        Ok(cotangents) => cotangents,
23        Err(err) => panic!("{err}"),
24    }
25}
26
27/// Fallible form of [`eager_transpose_fragment`].
28pub fn try_eager_transpose_fragment<Op: PrimitiveOp>(
29    linear: &LinearFragment<Op>,
30    emitter: &mut impl OpEmitter<Op>,
31    cotangent_seeds: &[Option<LocalValId>],
32    ctx: &mut Op::ADContext,
33) -> ADRuleResult<Vec<Option<LocalValId>>>
34where
35    Op::InputKey: chainrules::ADKey,
36{
37    let mut cotangent_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
38
39    for (index, maybe_tangent_output) in linear.tangent_outputs.iter().enumerate() {
40        if let (Some(output_id), Some(Some(seed_id))) =
41            (maybe_tangent_output, cotangent_seeds.get(index))
42        {
43            let key = linear.fragment.vals()[*output_id].key.clone();
44            cotangent_env.insert(key, *seed_id);
45        }
46    }
47
48    for op_node in linear.fragment.ops().iter().rev() {
49        let cotangent_out: Vec<Option<LocalValId>> = op_node
50            .outputs
51            .iter()
52            .map(|output_id| {
53                cotangent_env
54                    .get(&linear.fragment.vals()[*output_id].key)
55                    .copied()
56            })
57            .collect();
58        if cotangent_out.iter().all(Option::is_none) {
59            continue;
60        }
61
62        let rule_inputs: Vec<ValRef<Op>> = op_node
63            .inputs
64            .iter()
65            .map(|input| match input {
66                ValRef::Local(local_id) => {
67                    ValRef::External(linear.fragment.vals()[*local_id].key.clone())
68                }
69                ValRef::External(key) => ValRef::External(key.clone()),
70            })
71            .collect();
72
73        let cotangent_in = op_node.op.try_transpose_rule(
74            emitter,
75            &cotangent_out,
76            &rule_inputs,
77            &op_node.mode,
78            ctx,
79        )?;
80        assert_eq!(
81            cotangent_in.len(),
82            rule_inputs.len(),
83            "transpose_rule for {:?} returned {} cotangents for {} inputs",
84            op_node.op,
85            cotangent_in.len(),
86            rule_inputs.len()
87        );
88
89        for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in.into_iter()) {
90            let Some(cotangent_id) = maybe_cotangent else {
91                continue;
92            };
93            let input_key = match input {
94                ValRef::Local(_) => unreachable!("rule inputs are normalized to external refs"),
95                ValRef::External(key) => key.clone(),
96            };
97
98            match cotangent_env.get(&input_key).copied() {
99                Some(existing_id) => {
100                    let sum = emitter.add_op(
101                        Op::add(),
102                        vec![ValRef::Local(existing_id), ValRef::Local(cotangent_id)],
103                        OpMode::Linear {
104                            active_mask: vec![true, true],
105                        },
106                    );
107                    cotangent_env.insert(input_key, sum[0]);
108                }
109                None => {
110                    cotangent_env.insert(input_key, cotangent_id);
111                }
112            }
113        }
114    }
115
116    Ok(linear
117        .tangent_inputs
118        .iter()
119        .map(|(_, tangent_input_id)| {
120            let tangent_input_key = &linear.fragment.vals()[*tangent_input_id].key;
121            cotangent_env.get(tangent_input_key).copied()
122        })
123        .collect())
124}