1use std::collections::HashMap;
2
3use chainrules::{ADKey, PrimitiveOp};
4use computegraph::fragment::FragmentBuilder;
5use computegraph::{GlobalValKey, LocalValId, OpMode, ValRef};
6
7use crate::LinearFragment;
8
9pub fn transpose<Op: PrimitiveOp>(
22 linear: &LinearFragment<Op>,
23 ctx: &mut Op::ADContext,
24) -> LinearFragment<Op>
25where
26 Op::InputKey: ADKey,
27{
28 let mut builder = FragmentBuilder::<Op>::new();
29 let mut cotangent_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
30 let mut cotangent_seed_inputs = Vec::new();
31
32 for (index, maybe_tangent_output) in linear.tangent_outputs.iter().enumerate() {
33 let Some(tangent_output_id) = maybe_tangent_output else {
34 continue;
35 };
36
37 let source_key = linear.fragment.vals()[*tangent_output_id].key.clone();
38 let seed_key = cotangent_seed_key(linear, index);
39 let seed_id = builder.add_input(seed_key.clone());
40 cotangent_env.insert(source_key, seed_id);
41 cotangent_seed_inputs.push((seed_key, seed_id));
42 }
43
44 for op_node in linear.fragment.ops().iter().rev() {
45 let cotangent_out: Vec<Option<LocalValId>> = op_node
46 .outputs
47 .iter()
48 .map(|output_id| {
49 cotangent_env
50 .get(&linear.fragment.vals()[*output_id].key)
51 .copied()
52 })
53 .collect();
54 if cotangent_out.iter().all(Option::is_none) {
55 continue;
56 }
57
58 let rule_inputs: Vec<ValRef<Op>> = op_node
59 .inputs
60 .iter()
61 .map(|input| match input {
62 ValRef::Local(local_id) => {
63 ValRef::External(linear.fragment.vals()[*local_id].key.clone())
64 }
65 ValRef::External(key) => ValRef::External(key.clone()),
66 })
67 .collect();
68
69 let cotangent_in = op_node.op.transpose_rule(
70 &mut builder,
71 &cotangent_out,
72 &rule_inputs,
73 &op_node.mode,
74 ctx,
75 );
76 assert_eq!(
77 cotangent_in.len(),
78 rule_inputs.len(),
79 "transpose_rule for {:?} returned {} cotangents for {} inputs",
80 op_node.op,
81 cotangent_in.len(),
82 rule_inputs.len()
83 );
84
85 for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in.into_iter()) {
86 let Some(cotangent_id) = maybe_cotangent else {
87 continue;
88 };
89 let input_key = match input {
90 ValRef::Local(_) => unreachable!("rule inputs are normalized to external refs"),
91 ValRef::External(key) => key.clone(),
92 };
93
94 match cotangent_env.get(&input_key).copied() {
95 Some(existing_id) => {
96 let sum = builder.add_op(
97 Op::add(),
98 vec![ValRef::Local(existing_id), ValRef::Local(cotangent_id)],
99 OpMode::Linear {
100 active_mask: vec![true, true],
101 },
102 );
103 cotangent_env.insert(input_key, sum[0]);
104 }
105 None => {
106 cotangent_env.insert(input_key, cotangent_id);
107 }
108 }
109 }
110 }
111
112 let tangent_outputs: Vec<Option<LocalValId>> = linear
113 .tangent_inputs
114 .iter()
115 .map(|(_, tangent_input_id)| {
116 let tangent_input_key = &linear.fragment.vals()[*tangent_input_id].key;
117 cotangent_env.get(tangent_input_key).copied()
118 })
119 .collect();
120 let active_outputs: Vec<LocalValId> = tangent_outputs.iter().filter_map(|id| *id).collect();
121 if !active_outputs.is_empty() {
122 builder.set_outputs(active_outputs);
123 }
124
125 LinearFragment {
126 fragment: builder.build(),
127 tangent_inputs: cotangent_seed_inputs,
128 tangent_outputs,
129 }
130}
131
132fn cotangent_seed_key<Op: PrimitiveOp>(linear: &LinearFragment<Op>, index: usize) -> Op::InputKey
133where
134 Op::InputKey: ADKey,
135{
136 assert!(
137 !linear.tangent_inputs.is_empty(),
138 "active tangent outputs require at least one tangent input to derive seed keys"
139 );
140
141 let base_slot = index.min(linear.tangent_inputs.len() - 1);
142 let base_key = &linear.tangent_inputs[base_slot].0;
143 base_key.tangent_of(u64::MAX - index as u64)
144}