1use std::collections::HashMap;
2
3use chainrules::PrimitiveOp;
4use computegraph::{GlobalValKey, LocalValId, OpEmitter, OpMode, ValRef};
5
6use crate::LinearFragment;
7
8pub fn eager_transpose_fragment<Op: PrimitiveOp>(
13 linear: &LinearFragment<Op>,
14 emitter: &mut impl OpEmitter<Op>,
15 cotangent_seeds: &[Option<LocalValId>],
16 ctx: &mut Op::ADContext,
17) -> Vec<Option<LocalValId>>
18where
19 Op::InputKey: chainrules::ADKey,
20{
21 let mut cotangent_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
22
23 for (index, maybe_tangent_output) in linear.tangent_outputs.iter().enumerate() {
24 if let (Some(output_id), Some(Some(seed_id))) =
25 (maybe_tangent_output, cotangent_seeds.get(index))
26 {
27 let key = linear.fragment.vals()[*output_id].key.clone();
28 cotangent_env.insert(key, *seed_id);
29 }
30 }
31
32 for op_node in linear.fragment.ops().iter().rev() {
33 let cotangent_out: Vec<Option<LocalValId>> = op_node
34 .outputs
35 .iter()
36 .map(|output_id| {
37 cotangent_env
38 .get(&linear.fragment.vals()[*output_id].key)
39 .copied()
40 })
41 .collect();
42 if cotangent_out.iter().all(Option::is_none) {
43 continue;
44 }
45
46 let rule_inputs: Vec<ValRef<Op>> = op_node
47 .inputs
48 .iter()
49 .map(|input| match input {
50 ValRef::Local(local_id) => {
51 ValRef::External(linear.fragment.vals()[*local_id].key.clone())
52 }
53 ValRef::External(key) => ValRef::External(key.clone()),
54 })
55 .collect();
56
57 let cotangent_in =
58 op_node
59 .op
60 .transpose_rule(emitter, &cotangent_out, &rule_inputs, &op_node.mode, ctx);
61 assert_eq!(
62 cotangent_in.len(),
63 rule_inputs.len(),
64 "transpose_rule for {:?} returned {} cotangents for {} inputs",
65 op_node.op,
66 cotangent_in.len(),
67 rule_inputs.len()
68 );
69
70 for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in.into_iter()) {
71 let Some(cotangent_id) = maybe_cotangent else {
72 continue;
73 };
74 let input_key = match input {
75 ValRef::Local(_) => unreachable!("rule inputs are normalized to external refs"),
76 ValRef::External(key) => key.clone(),
77 };
78
79 match cotangent_env.get(&input_key).copied() {
80 Some(existing_id) => {
81 let sum = emitter.add_op(
82 Op::add(),
83 vec![ValRef::Local(existing_id), ValRef::Local(cotangent_id)],
84 OpMode::Linear {
85 active_mask: vec![true, true],
86 },
87 );
88 cotangent_env.insert(input_key, sum[0]);
89 }
90 None => {
91 cotangent_env.insert(input_key, cotangent_id);
92 }
93 }
94 }
95 }
96
97 linear
98 .tangent_inputs
99 .iter()
100 .map(|(_, tangent_input_id)| {
101 let tangent_input_key = &linear.fragment.vals()[*tangent_input_id].key;
102 cotangent_env.get(tangent_input_key).copied()
103 })
104 .collect()
105}