1use std::collections::HashMap;
2
3use chainrules::{ADKey, ADRuleResult, PrimitiveOp};
4use computegraph::fragment::FragmentBuilder;
5use computegraph::{GlobalValKey, LocalValId, OpMode, ValRef};
6
7use crate::LinearFragment;
8
9pub fn transpose<Op: PrimitiveOp>(
22 linear: &LinearFragment<Op>,
23 ctx: &mut Op::ADContext,
24) -> LinearFragment<Op>
25where
26 Op::InputKey: ADKey,
27{
28 match try_transpose(linear, ctx) {
29 Ok(transposed) => transposed,
30 Err(err) => panic!("{err}"),
31 }
32}
33
34pub fn try_transpose<Op: PrimitiveOp>(
39 linear: &LinearFragment<Op>,
40 ctx: &mut Op::ADContext,
41) -> ADRuleResult<LinearFragment<Op>>
42where
43 Op::InputKey: ADKey,
44{
45 let mut builder = FragmentBuilder::<Op>::new();
46 let mut cotangent_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
47 let mut cotangent_seed_inputs = Vec::new();
48
49 for (index, maybe_tangent_output) in linear.tangent_outputs.iter().enumerate() {
50 let Some(tangent_output_id) = maybe_tangent_output else {
51 continue;
52 };
53
54 let source_key = linear.fragment.vals()[*tangent_output_id].key.clone();
55 let seed_key = cotangent_seed_key(linear, index);
56 let seed_id = builder.add_input(seed_key.clone());
57 cotangent_env.insert(source_key, seed_id);
58 cotangent_seed_inputs.push((seed_key, seed_id));
59 }
60
61 for op_node in linear.fragment.ops().iter().rev() {
62 let cotangent_out: Vec<Option<LocalValId>> = op_node
63 .outputs
64 .iter()
65 .map(|output_id| {
66 cotangent_env
67 .get(&linear.fragment.vals()[*output_id].key)
68 .copied()
69 })
70 .collect();
71 if cotangent_out.iter().all(Option::is_none) {
72 continue;
73 }
74
75 let rule_inputs: Vec<ValRef<Op>> = op_node
76 .inputs
77 .iter()
78 .map(|input| match input {
79 ValRef::Local(local_id) => {
80 ValRef::External(linear.fragment.vals()[*local_id].key.clone())
81 }
82 ValRef::External(key) => ValRef::External(key.clone()),
83 })
84 .collect();
85
86 let cotangent_in = op_node.op.try_transpose_rule(
87 &mut builder,
88 &cotangent_out,
89 &rule_inputs,
90 &op_node.mode,
91 ctx,
92 )?;
93 assert_eq!(
94 cotangent_in.len(),
95 rule_inputs.len(),
96 "transpose_rule for {:?} returned {} cotangents for {} inputs",
97 op_node.op,
98 cotangent_in.len(),
99 rule_inputs.len()
100 );
101
102 for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in.into_iter()) {
103 let Some(cotangent_id) = maybe_cotangent else {
104 continue;
105 };
106 let input_key = match input {
107 ValRef::Local(_) => unreachable!("rule inputs are normalized to external refs"),
108 ValRef::External(key) => key.clone(),
109 };
110
111 match cotangent_env.get(&input_key).copied() {
112 Some(existing_id) => {
113 let sum = builder.add_op(
114 Op::add(),
115 vec![ValRef::Local(existing_id), ValRef::Local(cotangent_id)],
116 OpMode::Linear {
117 active_mask: vec![true, true],
118 },
119 );
120 cotangent_env.insert(input_key, sum[0]);
121 }
122 None => {
123 cotangent_env.insert(input_key, cotangent_id);
124 }
125 }
126 }
127 }
128
129 let tangent_outputs: Vec<Option<LocalValId>> = linear
130 .tangent_inputs
131 .iter()
132 .map(|(_, tangent_input_id)| {
133 let tangent_input_key = &linear.fragment.vals()[*tangent_input_id].key;
134 cotangent_env.get(tangent_input_key).copied()
135 })
136 .collect();
137 let active_outputs: Vec<LocalValId> = tangent_outputs.iter().filter_map(|id| *id).collect();
138 if !active_outputs.is_empty() {
139 builder.set_outputs(active_outputs);
140 }
141
142 Ok(LinearFragment {
143 fragment: builder.build(),
144 tangent_inputs: cotangent_seed_inputs,
145 tangent_outputs,
146 })
147}
148
149fn cotangent_seed_key<Op: PrimitiveOp>(linear: &LinearFragment<Op>, index: usize) -> Op::InputKey
150where
151 Op::InputKey: ADKey,
152{
153 assert!(
154 !linear.tangent_inputs.is_empty(),
155 "active tangent outputs require at least one tangent input to derive seed keys"
156 );
157
158 let base_slot = index.min(linear.tangent_inputs.len() - 1);
159 let base_key = &linear.tangent_inputs[base_slot].0;
160 base_key.tangent_of(u64::MAX - index as u64)
161}