Skip to main content

tidu/
transpose.rs

1use std::collections::HashMap;
2
3use chainrules::{ADKey, ADRuleResult, PrimitiveOp};
4use computegraph::fragment::FragmentBuilder;
5use computegraph::{GlobalValKey, LocalValId, OpMode, ValRef};
6
7use crate::LinearFragment;
8
9/// Transpose a linear fragment, reversing linear flow.
10///
11/// Fan-out accumulation is emitted explicitly with [`chainrules::PrimitiveOp::add`];
12/// no duplication primitive is assumed by the graph transform.
13///
14/// # Examples
15///
16/// ```ignore
17/// let mut ctx = ();
18/// let transposed = tidu::transpose(&linear_fragment, &mut ctx);
19/// assert_eq!(transposed.tangent_outputs.len(), linear_fragment.tangent_inputs.len());
20/// ```
21pub fn transpose<Op: PrimitiveOp>(
22    linear: &LinearFragment<Op>,
23    ctx: &mut Op::ADContext,
24) -> LinearFragment<Op>
25where
26    Op::InputKey: ADKey,
27{
28    match try_transpose(linear, ctx) {
29        Ok(transposed) => transposed,
30        Err(err) => panic!("{err}"),
31    }
32}
33
34/// Fallible form of [`transpose`].
35///
36/// This returns [`chainrules::ADRuleError`] when a primitive cannot emit a
37/// transpose rule.
38pub fn try_transpose<Op: PrimitiveOp>(
39    linear: &LinearFragment<Op>,
40    ctx: &mut Op::ADContext,
41) -> ADRuleResult<LinearFragment<Op>>
42where
43    Op::InputKey: ADKey,
44{
45    let mut builder = FragmentBuilder::<Op>::new();
46    let mut cotangent_env: HashMap<GlobalValKey<Op>, LocalValId> = HashMap::new();
47    let mut cotangent_seed_inputs = Vec::new();
48
49    for (index, maybe_tangent_output) in linear.tangent_outputs.iter().enumerate() {
50        let Some(tangent_output_id) = maybe_tangent_output else {
51            continue;
52        };
53
54        let source_key = linear.fragment.vals()[*tangent_output_id].key.clone();
55        let seed_key = cotangent_seed_key(linear, index);
56        let seed_id = builder.add_input(seed_key.clone());
57        cotangent_env.insert(source_key, seed_id);
58        cotangent_seed_inputs.push((seed_key, seed_id));
59    }
60
61    for op_node in linear.fragment.ops().iter().rev() {
62        let cotangent_out: Vec<Option<LocalValId>> = op_node
63            .outputs
64            .iter()
65            .map(|output_id| {
66                cotangent_env
67                    .get(&linear.fragment.vals()[*output_id].key)
68                    .copied()
69            })
70            .collect();
71        if cotangent_out.iter().all(Option::is_none) {
72            continue;
73        }
74
75        let rule_inputs: Vec<ValRef<Op>> = op_node
76            .inputs
77            .iter()
78            .map(|input| match input {
79                ValRef::Local(local_id) => {
80                    ValRef::External(linear.fragment.vals()[*local_id].key.clone())
81                }
82                ValRef::External(key) => ValRef::External(key.clone()),
83            })
84            .collect();
85
86        let cotangent_in = op_node.op.try_transpose_rule(
87            &mut builder,
88            &cotangent_out,
89            &rule_inputs,
90            &op_node.mode,
91            ctx,
92        )?;
93        assert_eq!(
94            cotangent_in.len(),
95            rule_inputs.len(),
96            "transpose_rule for {:?} returned {} cotangents for {} inputs",
97            op_node.op,
98            cotangent_in.len(),
99            rule_inputs.len()
100        );
101
102        for (input, maybe_cotangent) in rule_inputs.iter().zip(cotangent_in.into_iter()) {
103            let Some(cotangent_id) = maybe_cotangent else {
104                continue;
105            };
106            let input_key = match input {
107                ValRef::Local(_) => unreachable!("rule inputs are normalized to external refs"),
108                ValRef::External(key) => key.clone(),
109            };
110
111            match cotangent_env.get(&input_key).copied() {
112                Some(existing_id) => {
113                    let sum = builder.add_op(
114                        Op::add(),
115                        vec![ValRef::Local(existing_id), ValRef::Local(cotangent_id)],
116                        OpMode::Linear {
117                            active_mask: vec![true, true],
118                        },
119                    );
120                    cotangent_env.insert(input_key, sum[0]);
121                }
122                None => {
123                    cotangent_env.insert(input_key, cotangent_id);
124                }
125            }
126        }
127    }
128
129    let tangent_outputs: Vec<Option<LocalValId>> = linear
130        .tangent_inputs
131        .iter()
132        .map(|(_, tangent_input_id)| {
133            let tangent_input_key = &linear.fragment.vals()[*tangent_input_id].key;
134            cotangent_env.get(tangent_input_key).copied()
135        })
136        .collect();
137    let active_outputs: Vec<LocalValId> = tangent_outputs.iter().filter_map(|id| *id).collect();
138    if !active_outputs.is_empty() {
139        builder.set_outputs(active_outputs);
140    }
141
142    Ok(LinearFragment {
143        fragment: builder.build(),
144        tangent_inputs: cotangent_seed_inputs,
145        tangent_outputs,
146    })
147}
148
149fn cotangent_seed_key<Op: PrimitiveOp>(linear: &LinearFragment<Op>, index: usize) -> Op::InputKey
150where
151    Op::InputKey: ADKey,
152{
153    assert!(
154        !linear.tangent_inputs.is_empty(),
155        "active tangent outputs require at least one tangent input to derive seed keys"
156    );
157
158    let base_slot = index.min(linear.tangent_inputs.len() - 1);
159    let base_key = &linear.tangent_inputs[base_slot].0;
160    base_key.tangent_of(u64::MAX - index as u64)
161}