1use std::collections::HashMap;
2use std::sync::Arc;
3
4use chainrules::{ADKey, PrimitiveOp};
5use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, OpMode};
6
7use crate::{GradEdge, GradNode};
8
9pub struct EagerValue<Op: GraphOp> {
25 pub key: GlobalValKey<Op>,
27 pub node: Option<Arc<GradNode<Op>>>,
29 pub requires_grad: bool,
31 pub data: Arc<Op::Operand>,
33}
34
35pub struct EagerOutput<Op: GraphOp> {
47 pub key: GlobalValKey<Op>,
49 pub node: Option<Arc<GradNode<Op>>>,
51 pub requires_grad: bool,
53 pub output_slot: usize,
55}
56
57pub trait EagerKeySource<Op: GraphOp> {
73 fn fresh_input_key(&mut self) -> Op::InputKey;
75}
76
77pub fn record_eager_op<Op: PrimitiveOp>(
97 key_source: &mut impl EagerKeySource<Op>,
98 op: Op,
99 inputs: &[EagerValue<Op>],
100 outputs: &[Arc<Op::Operand>],
101) -> Vec<EagerOutput<Op>>
102where
103 Op::InputKey: ADKey,
104{
105 assert_eq!(
106 inputs.len(),
107 op.n_inputs(),
108 "record_eager_op for {:?} expected {} inputs, got {}",
109 op,
110 op.n_inputs(),
111 inputs.len()
112 );
113 assert_eq!(
114 outputs.len(),
115 op.n_outputs(),
116 "record_eager_op for {:?} expected {} outputs, got {}",
117 op,
118 op.n_outputs(),
119 outputs.len()
120 );
121 assert!(
122 outputs.len() <= u8::MAX as usize + 1,
123 "record_eager_op for {:?} has too many outputs for GlobalValKey: {}",
124 op,
125 outputs.len()
126 );
127
128 let input_aliases = fresh_value_keys(key_source, inputs.len());
129 let output_keys = fresh_value_keys(key_source, outputs.len());
130 let requires_grad = inputs.iter().any(|input| input.requires_grad);
131
132 let node = requires_grad.then(|| {
133 Arc::new(GradNode::new(
134 op.clone(),
135 input_aliases.clone(),
136 output_keys.clone(),
137 saved_forward_values(&op, &input_aliases, inputs, outputs),
138 inputs
139 .iter()
140 .map(|input| {
141 GradEdge::new(input.node.clone(), input.key.clone(), input.requires_grad)
142 })
143 .collect(),
144 ))
145 });
146
147 output_keys
148 .into_iter()
149 .enumerate()
150 .map(|(output_slot, key)| EagerOutput {
151 key,
152 node: node.clone(),
153 requires_grad,
154 output_slot,
155 })
156 .collect()
157}
158
159pub fn derived_output_key<Op: GraphOp>(
167 op: &Op,
168 input_aliases: &[GlobalValKey<Op>],
169 output_slot: usize,
170) -> GlobalValKey<Op> {
171 assert!(
172 output_slot <= u8::MAX as usize,
173 "output slot {} is too large for GlobalValKey",
174 output_slot
175 );
176
177 GlobalValKey::Derived {
178 op: GlobalOpKey {
179 primitive: op.clone(),
180 inputs: input_aliases.to_vec(),
181 mode: OpMode::Primal,
182 },
183 output_slot: output_slot as u8,
184 }
185}
186
187pub fn saved_forward_values<Op: GraphOp>(
198 op: &Op,
199 input_aliases: &[GlobalValKey<Op>],
200 inputs: &[EagerValue<Op>],
201 outputs: &[Arc<Op::Operand>],
202) -> HashMap<GlobalValKey<Op>, Arc<Op::Operand>> {
203 assert_eq!(
204 input_aliases.len(),
205 op.n_inputs(),
206 "saved_forward_values for {:?} expected {} input aliases, got {}",
207 op,
208 op.n_inputs(),
209 input_aliases.len()
210 );
211 assert_eq!(
212 inputs.len(),
213 op.n_inputs(),
214 "saved_forward_values for {:?} expected {} inputs, got {}",
215 op,
216 op.n_inputs(),
217 inputs.len()
218 );
219 assert_eq!(
220 outputs.len(),
221 op.n_outputs(),
222 "saved_forward_values for {:?} expected {} outputs, got {}",
223 op,
224 op.n_outputs(),
225 outputs.len()
226 );
227 assert!(
228 input_aliases
229 .iter()
230 .all(|key| matches!(key, GlobalValKey::Input(_))),
231 "saved_forward_values for {:?} requires GlobalValKey::Input aliases",
232 op
233 );
234 assert!(
235 outputs.len() <= u8::MAX as usize + 1,
236 "saved_forward_values for {:?} has too many outputs for GlobalValKey: {}",
237 op,
238 outputs.len()
239 );
240 assert_eq!(
241 input_aliases.len(),
242 inputs.len(),
243 "saved_forward_values for {:?} expected one alias per input, got {} aliases for {} inputs",
244 op,
245 input_aliases.len(),
246 inputs.len()
247 );
248
249 let mut saved = HashMap::with_capacity(inputs.len() + outputs.len());
250 for (key, input) in input_aliases.iter().zip(inputs.iter()) {
251 saved.insert(key.clone(), input.data.clone());
252 }
253 for (slot, output) in outputs.iter().enumerate() {
254 saved.insert(derived_output_key(op, input_aliases, slot), output.clone());
255 }
256 saved
257}
258
259fn fresh_value_keys<Op: GraphOp>(
260 key_source: &mut impl EagerKeySource<Op>,
261 count: usize,
262) -> Vec<GlobalValKey<Op>> {
263 (0..count)
264 .map(|_| GlobalValKey::Input(key_source.fresh_input_key()))
265 .collect()
266}