1use std::collections::HashMap;
2
3use chainrules::{ADRuleResult, PrimitiveOp};
4use computegraph::{GlobalValKey, LocalValId, OpEmitter, OpMode, ValRef};
5
6use crate::LinearFragment;
7
8pub fn eager_transpose_fragment<Op: PrimitiveOp>(
13 linear: &LinearFragment<Op>,
14 emitter: &mut impl OpEmitter<Op>,
15 cotangent_seeds: &[Option<LocalValId>],
16 ctx: &mut Op::ADContext,
17) -> Vec<Option<LocalValId>>
18where
19 Op::InputKey: chainrules::ADKey,
20{
21 match try_eager_transpose_fragment(linear, emitter, cotangent_seeds, ctx) {
22 Ok(cotangents) => cotangents,
23 Err(err) => panic!("{err}"),
24 }
25}
26
27pub fn try_eager_transpose_fragment<Op: PrimitiveOp>(
29 linear: &LinearFragment<Op>,
30 emitter: &mut impl OpEmitter<Op>,
31 cotangent_seeds: &[Option<LocalValId>],
32 ctx: &mut Op::ADContext,
33) -> ADRuleResult<Vec<Option<LocalValId>>>
34where
35 Op::InputKey: chainrules::ADKey,
36{
37 let mut cotangent_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
38
39 for (index, maybe_tangent_output) in linear.tangent_outputs.iter().enumerate() {
40 if let (Some(output_id), Some(Some(seed_id))) =
41 (maybe_tangent_output, cotangent_seeds.get(index))
42 {
43 let key = linear.fragment.vals()[*output_id].key.clone();
44 cotangent_env.insert(key, *seed_id);
45 }
46 }
47
48 for op_node in linear.fragment.ops().iter().rev() {
49 let cotangent_out: Vec<Option<LocalValId>> = op_node
50 .outputs
51 .iter()
52 .map(|output_id| {
53 cotangent_env
54 .get(&linear.fragment.vals()[*output_id].key)
55 .copied()
56 })
57 .collect();
58 if cotangent_out.iter().all(Option::is_none) {
59 continue;
60 }
61
62 let rule_inputs: Vec<ValRef<Op>> = op_node
63 .inputs
64 .iter()
65 .map(|input| match input {
66 ValRef::Local(local_id) => {
67 ValRef::External(linear.fragment.vals()[*local_id].key.clone())
68 }
69 ValRef::External(key) => ValRef::External(key.clone()),
70 })
71 .collect();
72
73 let cotangent_in = op_node.op.try_transpose_rule(
74 emitter,
75 &cotangent_out,
76 &rule_inputs,
77 &op_node.mode,
78 ctx,
79 )?;
80 assert_eq!(
81 cotangent_in.len(),
82 rule_inputs.len(),
83 "transpose_rule for {:?} returned {} cotangents for {} inputs",
84 op_node.op,
85 cotangent_in.len(),
86 rule_inputs.len()
87 );
88
89 for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in.into_iter()) {
90 let Some(cotangent_id) = maybe_cotangent else {
91 continue;
92 };
93 let input_key = match input {
94 ValRef::Local(_) => unreachable!("rule inputs are normalized to external refs"),
95 ValRef::External(key) => key.clone(),
96 };
97
98 match cotangent_env.get(&input_key).copied() {
99 Some(existing_id) => {
100 let sum = emitter.add_op(
101 Op::add(),
102 vec![ValRef::Local(existing_id), ValRef::Local(cotangent_id)],
103 OpMode::Linear {
104 active_mask: vec![true, true],
105 },
106 );
107 cotangent_env.insert(input_key, sum[0]);
108 }
109 None => {
110 cotangent_env.insert(input_key, cotangent_id);
111 }
112 }
113 }
114 }
115
116 Ok(linear
117 .tangent_inputs
118 .iter()
119 .map(|(_, tangent_input_id)| {
120 let tangent_input_key = &linear.fragment.vals()[*tangent_input_id].key;
121 cotangent_env.get(tangent_input_key).copied()
122 })
123 .collect())
124}