Skip to main content

tidu/
linearize.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use crate::rules::GraphPrimitiveBuilder;
5use crate::{ADKey, ADRuleResult, DiffPassId, Primitive};
6use computegraph::graph::GraphBuilder;
7use computegraph::resolve::{ResolvedView, ValueDef};
8use computegraph::{GraphOperation, LocalValueId, OperationKey, ValueKey};
9
10use crate::LinearizedGraph;
11
12/// Linearize a resolved computation graph, producing a linear graph.
13///
14/// The transform walks the reachable DAG from `outputs` in dependency-first
15/// order and delegates primitive-specific JVP generation to
16/// [`crate::Primitive::try_jvp_rule`].
17///
18/// # Examples
19///
20/// ```ignore
21/// use computegraph::resolve::resolve;
22/// use tidu::try_linearize;
23///
24/// let view = resolve(vec![primal_graph]);
25/// let mut ctx = ();
26/// let aliases = std::collections::HashMap::new();
27/// let linear = try_linearize(&view, &[output_key], &[input_key], 1, &mut ctx, &aliases)?;
28/// assert_eq!(linear.tangent_outputs().len(), 1);
29/// # Ok::<(), crate::ADRuleError>(())
30/// ```
31pub fn linearize<Op: Primitive>(
32    view: &ResolvedView<Op>,
33    outputs: &[ValueKey<Op>],
34    wrt: &[Op::InputKey],
35    pass: DiffPassId,
36    ctx: &mut Op::ADContext,
37    aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
38) -> LinearizedGraph<Op>
39where
40    Op::InputKey: ADKey,
41{
42    match try_linearize(view, outputs, wrt, pass, ctx, aliases) {
43        Ok(linear) => linear,
44        Err(err) => panic!("{}", err),
45    }
46}
47
48/// Fallible form of [`linearize`].
49///
50/// This returns [`crate::ADRuleError`] when a primitive cannot emit a JVP
51/// rule, allowing downstream frontends to surface missing extension rules as
52/// normal errors instead of panics.
53pub fn try_linearize<Op: Primitive>(
54    view: &ResolvedView<Op>,
55    outputs: &[ValueKey<Op>],
56    wrt: &[Op::InputKey],
57    pass: DiffPassId,
58    ctx: &mut Op::ADContext,
59    aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
60) -> ADRuleResult<LinearizedGraph<Op>>
61where
62    Op::InputKey: ADKey,
63{
64    let mut builder = GraphBuilder::<Op>::new();
65    let topo_keys = topological_order(view, outputs, aliases);
66    let mut tangent_env: HashMap<ValueKey<Op>, Option<LocalValueId>> = HashMap::new();
67    let mut processed_ops = HashSet::new();
68
69    let mut tangent_inputs = Vec::with_capacity(wrt.len());
70    for wrt_key in wrt {
71        let tangent_key = wrt_key.tangent_of(pass);
72        let tangent_id = builder.add_input(tangent_key);
73        tangent_env.insert(ValueKey::Input(wrt_key.clone()), Some(tangent_id));
74        tangent_inputs.push((wrt_key.clone(), tangent_id));
75    }
76
77    for key in topo_keys {
78        if tangent_env.contains_key(&key) {
79            continue;
80        }
81
82        let val_def = match view.resolve_value(&key) {
83            Some(val_def) => val_def,
84            None => continue,
85        };
86
87        match val_def {
88            ValueDef::Input { key: ref input_key } => {
89                if let Some(aliased_key) = aliases.get(input_key) {
90                    let aliased_tangent = tangent_env.get(aliased_key).copied().flatten();
91                    tangent_env.insert(key, aliased_tangent);
92                } else {
93                    tangent_env.insert(key, None);
94                }
95            }
96            ValueDef::Produced {
97                operation,
98                input_keys,
99                role,
100                ..
101            } => {
102                let global_op_key =
103                    OperationKey::new(operation.clone(), input_keys.clone(), role.clone());
104                if !processed_ops.insert(global_op_key.clone()) {
105                    continue;
106                }
107
108                let tangent_in: Vec<Option<LocalValueId>> = input_keys
109                    .iter()
110                    .map(|input_key| tangent_env.get(input_key).copied().flatten())
111                    .collect();
112                let output_keys = output_keys(&global_op_key, operation.output_count());
113
114                if tangent_in.iter().all(Option::is_none) {
115                    for output_key in output_keys {
116                        tangent_env.insert(output_key, None);
117                    }
118                    continue;
119                }
120
121                let mut primitive_builder = GraphPrimitiveBuilder::new(&mut builder);
122                let tangent_out = operation.try_jvp_rule(
123                    &mut primitive_builder,
124                    &input_keys,
125                    &output_keys,
126                    &tangent_in,
127                    ctx,
128                )?;
129                assert_eq!(
130                    tangent_out.len(),
131                    output_keys.len(),
132                    "jvp_rule for {:?} returned {} tangents for {} outputs",
133                    operation,
134                    tangent_out.len(),
135                    output_keys.len()
136                );
137
138                for (output_key, tangent_output) in output_keys.into_iter().zip(tangent_out) {
139                    tangent_env.insert(output_key, tangent_output);
140                }
141            }
142        }
143    }
144
145    let tangent_outputs: Vec<Option<LocalValueId>> = outputs
146        .iter()
147        .map(|key| tangent_env.get(key).copied().flatten())
148        .collect();
149    let active_outputs: Vec<LocalValueId> = tangent_outputs.iter().filter_map(|id| *id).collect();
150    if !active_outputs.is_empty() {
151        builder.set_outputs(active_outputs);
152    }
153
154    Ok(LinearizedGraph::from_parts(
155        builder.build(),
156        tangent_inputs,
157        tangent_outputs,
158    ))
159}
160
161fn output_keys<Op: GraphOperation>(
162    op_key: &OperationKey<Op>,
163    output_count: usize,
164) -> Vec<ValueKey<Op>> {
165    let op_key = Arc::new(op_key.clone());
166    (0..output_count)
167        .map(|output_slot| ValueKey::Derived {
168            operation: Arc::clone(&op_key),
169            output_slot: output_slot as u8,
170        })
171        .collect()
172}
173
174fn topological_order<Op: GraphOperation>(
175    view: &ResolvedView<Op>,
176    outputs: &[ValueKey<Op>],
177    aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
178) -> Vec<ValueKey<Op>> {
179    fn visit<Op: GraphOperation>(
180        key: &ValueKey<Op>,
181        view: &ResolvedView<Op>,
182        aliases: &HashMap<Op::InputKey, ValueKey<Op>>,
183        visited: &mut HashSet<ValueKey<Op>>,
184        order: &mut Vec<ValueKey<Op>>,
185    ) {
186        if !visited.insert(key.clone()) {
187            return;
188        }
189
190        match view.resolve_value(key) {
191            Some(ValueDef::Produced { input_keys, .. }) => {
192                for input_key in input_keys {
193                    visit(&input_key, view, aliases, visited, order);
194                }
195            }
196            Some(ValueDef::Input { key: input_key }) => {
197                if let Some(aliased_key) = aliases.get(&input_key) {
198                    visit(aliased_key, view, aliases, visited, order);
199                }
200            }
201            None => {}
202        }
203
204        order.push(key.clone());
205    }
206
207    let mut visited = HashSet::new();
208    let mut order = Vec::new();
209    for output_key in outputs {
210        visit(output_key, view, aliases, &mut visited, &mut order);
211    }
212    order
213}