1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, Weak};
3
4use computegraph::fragment::Fragment;
5use computegraph::{GlobalOpKey, GlobalValKey, OpMode, ValRef};
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_ops::ShapeGuardContext;
8use tenferro_tensor::cpu::CpuBackend;
9use tenferro_tensor::{Tensor, TensorBackend};
10use tidu::{backward_dag, topo_sort_grad_dag, BackwardCallbacks, GradNode, LinearFragment};
11
12use crate::eager_emitter::EagerEmitter;
13use crate::eager_exec::exec_op_on_tensors;
14use crate::error::{Error, Result};
15use crate::traced::next_input_key;
16
/// Shared, lockable storage cell for a tensor's accumulated gradient.
/// `None` until `backward` deposits a value; shared by all clones of a tensor.
pub(crate) type GradSlot = Arc<Mutex<Option<Arc<Tensor>>>>;
/// Weak handle to a [`GradSlot`], held by the context registry so a slot can
/// be reclaimed once every owning tensor has been dropped.
pub(crate) type WeakGradSlot = Weak<Mutex<Option<Arc<Tensor>>>>;
19
/// Shared state for eager execution: the tensor backend plus the registry of
/// gradient slots that `backward` deposits accumulated cotangents into.
pub struct EagerContext<B: TensorBackend> {
    // Mutex so a shared `Arc<EagerContext>` can hand out `&mut B` to op
    // execution and the backward pass.
    pub(crate) backend: Mutex<B>,
    // Weak references: a slot dies with its last owning tensor, and dead
    // entries are pruned lazily in `clear_grads`/`store_grads`.
    grad_slots: Mutex<HashMap<GlobalValKey<StdTensorOp>, WeakGradSlot>>,
}
41
impl<B: TensorBackend> EagerContext<B> {
    /// Builds a context around `backend` with an empty grad-slot registry.
    fn new(backend: B) -> Self {
        Self {
            backend: Mutex::new(backend),
            grad_slots: Mutex::new(HashMap::new()),
        }
    }

    /// Public constructor. Contexts are always shared behind an `Arc` because
    /// every `EagerTensor` keeps a handle to its context.
    pub fn with_backend(backend: B) -> Arc<Self> {
        Arc::new(Self::new(backend))
    }

    /// Registers a weak reference to `slot` under `key` so a later `backward`
    /// pass can deposit the accumulated gradient for that value.
    pub(crate) fn register_grad_slot(&self, key: &GlobalValKey<StdTensorOp>, slot: &GradSlot) {
        self.grad_slots
            .lock()
            .unwrap()
            .insert(key.clone(), Arc::downgrade(slot));
    }

    /// Merges `other`'s slot registrations into this context.
    ///
    /// Existing entries win on key collision.
    /// NOTE(review): `or_insert_with` keeps this context's entry even when it
    /// is a dead weak ref and `other`'s is live — confirm that is intended.
    pub(crate) fn absorb_from(&self, other: &Self) {
        let other_slots = other.grad_slots.lock().unwrap();
        let mut slots = self.grad_slots.lock().unwrap();
        for (key, slot) in other_slots.iter() {
            slots.entry(key.clone()).or_insert_with(|| slot.clone());
        }
    }

    /// Resets every live gradient slot to `None` and drops registrations
    /// whose slots have already been deallocated.
    pub fn clear_grads(&self) {
        self.grad_slots.lock().unwrap().retain(|_, slot| {
            if let Some(slot) = slot.upgrade() {
                *slot.lock().unwrap() = None;
                true
            } else {
                false
            }
        });
    }

    /// Accumulates `cotangents` into the registered gradient slots
    /// (`existing + incoming`, or just `incoming` for an empty slot).
    ///
    /// Runs in three phases to keep lock scopes small and ordered:
    /// 1. under the `grad_slots` map lock: upgrade each weak slot (pruning
    ///    dead entries) and pair live slots with their incoming cotangents;
    /// 2. with no map lock held: briefly lock each slot to read its current
    ///    value and compute the sum via the backend;
    /// 3. lock each slot once more to write the staged result back.
    ///
    /// NOTE(review): the slot lock is released between the read in phase 2
    /// and the write in phase 3, so two concurrent `store_grads` calls on the
    /// same context could lose an update — confirm callers serialize (e.g.
    /// via the backend lock held in `backward`).
    ///
    /// # Errors
    /// Propagates any backend failure from the element-wise `add`.
    fn store_grads(
        &self,
        cotangents: &HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>,
        backend: &mut B,
    ) -> Result<()> {
        let mut updates = Vec::new();
        let mut staged = Vec::new();

        {
            let mut slots = self.grad_slots.lock().unwrap();
            slots.retain(|key, slot| {
                let Some(slot) = slot.upgrade() else {
                    return false;
                };

                if let Some(incoming) = cotangents.get(key) {
                    updates.push((slot, Arc::clone(incoming)));
                }

                true
            });
        }

        for (slot, incoming) in updates {
            let next = {
                let current = slot.lock().unwrap();
                match current.as_ref() {
                    Some(existing) => Arc::new(existing.as_ref().add(incoming.as_ref(), backend)?),
                    None => incoming,
                }
            };
            staged.push((slot, next));
        }

        for (slot, next) in staged {
            *slot.lock().unwrap() = Some(next);
        }

        Ok(())
    }
}
152
/// An eagerly evaluated tensor that optionally records the backward graph.
///
/// Cloning is cheap: the payload, grad node, and context are behind `Arc`s,
/// and clones share the same gradient slot.
#[derive(Clone)]
pub struct EagerTensor<B: TensorBackend = CpuBackend> {
    // Concrete tensor value produced by eager execution.
    pub(crate) data: Arc<Tensor>,
    // Key identifying this value in the global compute graph.
    pub(crate) key: GlobalValKey<StdTensorOp>,
    // Backward-graph node; `None` for leaves (see `new_leaf`).
    pub(crate) grad_node: Option<Arc<GradNode<StdTensorOp>>>,
    // Whether this tensor participates in gradient accumulation.
    pub(crate) requires_grad: bool,
    // Where `backward` deposits this tensor's accumulated gradient.
    grad_slot: GradSlot,
    // Owning context: backend plus grad-slot registry.
    pub(crate) ctx: Arc<EagerContext<B>>,
}
184
185impl<B: TensorBackend> std::ops::Add for &EagerTensor<B> {
186 type Output = EagerTensor<B>;
187
188 fn add(self, rhs: &EagerTensor<B>) -> Self::Output {
189 EagerTensor::add(self, rhs).unwrap_or_else(|err| panic!("eager add failed: {}", err))
190 }
191}
192
193impl<B: TensorBackend> std::ops::Mul for &EagerTensor<B> {
194 type Output = EagerTensor<B>;
195
196 fn mul(self, rhs: &EagerTensor<B>) -> Self::Output {
197 EagerTensor::mul(self, rhs).unwrap_or_else(|err| panic!("eager mul failed: {}", err))
198 }
199}
200
201impl<B: TensorBackend> std::ops::Neg for &EagerTensor<B> {
202 type Output = EagerTensor<B>;
203
204 fn neg(self) -> Self::Output {
205 EagerTensor::neg(self).unwrap_or_else(|err| panic!("eager neg failed: {}", err))
206 }
207}
208
209impl EagerTensor<CpuBackend> {
210 pub fn from_tensor(tensor: Tensor) -> Self {
222 Self::from_tensor_in(tensor, EagerContext::with_backend(CpuBackend::new()))
223 }
224
225 pub fn requires_grad(tensor: Tensor) -> Self {
236 Self::requires_grad_in(tensor, EagerContext::with_backend(CpuBackend::new()))
237 }
238}
239
impl<B: TensorBackend> EagerTensor<B> {
    /// Wraps `tensor` in `ctx` as a leaf without gradient tracking.
    pub fn from_tensor_in(tensor: Tensor, ctx: Arc<EagerContext<B>>) -> Self {
        Self::new_leaf(ctx, tensor, false)
    }

    /// Wraps `tensor` in `ctx` as a leaf with gradient tracking enabled.
    pub fn requires_grad_in(tensor: Tensor, ctx: Arc<EagerContext<B>>) -> Self {
        Self::new_leaf(ctx, tensor, true)
    }

    /// Builds a leaf tensor (no grad node) under a fresh value key and, when
    /// tracking is on, registers its gradient slot with the context.
    pub(crate) fn new_leaf(ctx: Arc<EagerContext<B>>, tensor: Tensor, requires_grad: bool) -> Self {
        let key = eager_val_key();
        let grad_slot = Arc::new(Mutex::new(None));
        if requires_grad {
            ctx.register_grad_slot(&key, &grad_slot);
        }

        Self {
            data: Arc::new(tensor),
            key,
            grad_node: None,
            requires_grad,
            grad_slot,
            ctx,
        }
    }

    /// Builds an op-result tensor under a caller-supplied `key`, optionally
    /// attached to a backward-graph node, registering a grad slot on demand.
    pub(crate) fn new_result(
        ctx: Arc<EagerContext<B>>,
        key: GlobalValKey<StdTensorOp>,
        tensor: Tensor,
        requires_grad: bool,
        grad_node: Option<Arc<GradNode<StdTensorOp>>>,
    ) -> Self {
        let grad_slot = Arc::new(Mutex::new(None));
        if requires_grad {
            ctx.register_grad_slot(&key, &grad_slot);
        }

        Self {
            data: Arc::new(tensor),
            key,
            grad_node,
            requires_grad,
            grad_slot,
            ctx,
        }
    }

    /// Returns a grad-free leaf sharing this tensor's context.
    /// Note the payload is deep-cloned for the new leaf.
    pub fn detach(&self) -> Self {
        Self::new_leaf(self.ctx.clone(), self.data.as_ref().clone(), false)
    }

    /// Borrows the underlying tensor value.
    pub fn data(&self) -> &Tensor {
        self.data.as_ref()
    }

    /// Returns the gradient accumulated by `backward`, if any.
    pub fn grad(&self) -> Option<Arc<Tensor>> {
        self.grad_slot.lock().unwrap().clone()
    }

    /// Clears only this tensor's gradient; `EagerContext::clear_grads` is the
    /// context-wide variant.
    pub fn clear_grad(&self) {
        *self.grad_slot.lock().unwrap() = None;
    }

    /// Whether this tensor participates in gradient accumulation.
    pub fn tracks_grad(&self) -> bool {
        self.requires_grad
    }

    /// Runs reverse-mode autodiff from this tensor.
    ///
    /// The tensor must be a scalar (empty shape); the pass is seeded with a
    /// ones-like scalar cotangent. Accumulated gradients are stored into the
    /// context's registered slots, and the raw cotangent map is also returned.
    ///
    /// The backend lock is held for the entire pass, serializing concurrent
    /// `backward` calls on the same context.
    ///
    /// # Errors
    /// `Error::NonScalarGrad` if this tensor is not a scalar, plus any backend
    /// failure while accumulating gradients into their slots.
    pub fn backward(&self) -> Result<HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>> {
        if !self.data.shape().is_empty() {
            return Err(Error::NonScalarGrad {
                shape: self.data.shape().to_vec(),
            });
        }

        let sorted = topo_sort_grad_dag(&self.grad_node);
        let mut backend = self.ctx.backend.lock().unwrap();
        let seed = Arc::new(one_like_tensor(self.data.as_ref(), &mut *backend));
        let mut callbacks = TenferroBackwardCallbacks {
            backend: &mut *backend,
        };
        let mut ad_ctx = ShapeGuardContext::default();
        let cotangents = backward_dag(&sorted, &self.key, seed, &mut callbacks, &mut ad_ctx);
        self.ctx.store_grads(&cotangents, &mut *backend)?;
        Ok(cotangents)
    }
}
454
/// Bridges `tidu`'s backward pass onto a concrete tenferro backend by running
/// forward replays, fragment transposes, and cotangent sums eagerly.
pub(crate) struct TenferroBackwardCallbacks<'a, B: TensorBackend> {
    backend: &'a mut B,
}
458
impl<B: TensorBackend> BackwardCallbacks<StdTensorOp> for TenferroBackwardCallbacks<'_, B> {
    /// Re-executes `fragment`'s ops eagerly, starting from `initial_data`,
    /// and returns every value (inputs, intermediates, outputs) keyed by its
    /// global key. Panics on any missing value or op failure, since those
    /// indicate a broken graph rather than a recoverable condition.
    fn execute_forward(
        &mut self,
        fragment: &Fragment<StdTensorOp>,
        initial_data: &HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>,
    ) -> HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>> {
        let mut all_values = initial_data.clone();

        // Fragment inputs with no concrete value must be tangent inputs; seed
        // each with a zeros tensor shaped like its base (primal) value, which
        // is required to be present in `initial_data`.
        for &input_id in fragment.inputs() {
            let key = fragment.vals()[input_id].key.clone();
            all_values.entry(key.clone()).or_insert_with(|| {
                let GlobalValKey::Input(tangent_key) = &key else {
                    panic!("expected input key for eager forward: {:?}", key);
                };
                let tenferro_ops::input_key::TensorInputKey::Tangent { of, .. } = tangent_key
                else {
                    panic!("missing concrete eager value for {:?}", key);
                };
                let base_key = GlobalValKey::Input((**of).clone());
                let base = initial_data
                    .get(&base_key)
                    .unwrap_or_else(|| panic!("missing base eager value for {:?}", base_key));
                Arc::new(zero_like_tensor(base.as_ref(), self.backend))
            });
        }

        // Single pass over the ops: assumes `fragment.ops()` is topologically
        // ordered so every input is already in `all_values` when its consumer
        // runs; panics otherwise.
        for op_node in fragment.ops() {
            let resolved_inputs: Vec<&Tensor> = op_node
                .inputs
                .iter()
                .map(|input| match input {
                    ValRef::Local(local_id) => {
                        let key = &fragment.vals()[*local_id].key;
                        all_values
                            .get(key)
                            .unwrap_or_else(|| panic!("missing eager value for local {:?}", key))
                            .as_ref()
                    }
                    ValRef::External(key) => all_values
                        .get(key)
                        .unwrap_or_else(|| panic!("missing eager value for external {:?}", key))
                        .as_ref(),
                })
                .collect();
            let outputs = exec_op_on_tensors(&op_node.op, &resolved_inputs, self.backend)
                .unwrap_or_else(|err| {
                    panic!("eager forward exec failed for {:?}: {}", op_node.op, err)
                });

            for (output_id, output) in op_node.outputs.iter().zip(outputs.into_iter()) {
                let key = fragment.vals()[*output_id].key.clone();
                all_values.insert(key, Arc::new(output));
            }
        }

        all_values
    }

    /// Transposes `linear` eagerly: pushes the output cotangents into an
    /// `EagerEmitter`, runs tidu's transpose over it, then pulls the input
    /// cotangents back out (`None` where no cotangent flows).
    fn eager_transpose(
        &mut self,
        linear: &LinearFragment<StdTensorOp>,
        cotangent_out: &[Option<Arc<Tensor>>],
        external_data: &HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>,
        ctx: &mut ShapeGuardContext,
    ) -> Vec<Option<Arc<Tensor>>> {
        let mut emitter = EagerEmitter::new(self.backend);
        emitter.external_data = external_data.clone();
        let cotangent_seed_ids = cotangent_out
            .iter()
            .map(|maybe_seed| {
                maybe_seed
                    .as_ref()
                    .map(|seed| emitter.push_tensor(Arc::clone(seed)))
            })
            .collect::<Vec<_>>();

        tidu::eager_transpose_fragment(linear, &mut emitter, &cotangent_seed_ids, ctx)
            .into_iter()
            .map(|maybe_id| maybe_id.map(|id| emitter.tensor(id)))
            .collect()
    }

    /// Sums two cotangents with the backend's element-wise add; panics on
    /// backend failure.
    fn add_operands(&mut self, a: &Arc<Tensor>, b: &Arc<Tensor>) -> Arc<Tensor> {
        Arc::new(
            a.as_ref()
                .add(b.as_ref(), self.backend)
                .unwrap_or_else(|err| panic!("eager cotangent add failed: {}", err)),
        )
    }
}
549
550pub(crate) fn eager_val_key() -> GlobalValKey<StdTensorOp> {
551 GlobalValKey::Input(next_input_key())
552}
553
554pub(crate) fn saved_forward_values(
555 op: &StdTensorOp,
556 input_keys: &[GlobalValKey<StdTensorOp>],
557 inputs: &[Arc<Tensor>],
558 output: Arc<Tensor>,
559) -> HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>> {
560 let mut saved = HashMap::with_capacity(input_keys.len() + 1);
561 for (key, value) in input_keys.iter().zip(inputs.iter()) {
562 saved.insert(key.clone(), Arc::clone(value));
563 }
564 saved.insert(derived_output_key(op, input_keys, 0), output);
565 saved
566}
567
568pub(crate) fn saved_forward_values_multi(
569 op: &StdTensorOp,
570 input_keys: &[GlobalValKey<StdTensorOp>],
571 inputs: &[Arc<Tensor>],
572 num_outputs: usize,
573 outputs: &[Arc<Tensor>],
574) -> HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>> {
575 let mut saved = HashMap::with_capacity(input_keys.len() + num_outputs);
576 for (key, value) in input_keys.iter().zip(inputs.iter()) {
577 saved.insert(key.clone(), Arc::clone(value));
578 }
579 for slot in 0..num_outputs {
580 saved.insert(
581 derived_output_key(op, input_keys, slot),
582 Arc::clone(&outputs[slot]),
583 );
584 }
585 saved
586}
587
588pub(crate) fn derived_output_key(
589 op: &StdTensorOp,
590 input_keys: &[GlobalValKey<StdTensorOp>],
591 output_slot: usize,
592) -> GlobalValKey<StdTensorOp> {
593 GlobalValKey::Derived {
594 op: GlobalOpKey {
595 primitive: op.clone(),
596 inputs: input_keys.to_vec(),
597 mode: OpMode::Primal,
598 },
599 output_slot: output_slot as u8,
600 }
601}
602
603pub(crate) fn exec_single_output<B: TensorBackend>(
604 op: &StdTensorOp,
605 inputs: &[&Tensor],
606 ctx: &EagerContext<B>,
607) -> Result<Tensor> {
608 let mut backend = ctx.backend.lock().unwrap();
609 let mut outputs = exec_op_on_tensors(op, inputs, &mut *backend)?;
610 if outputs.len() != 1 {
611 return Err(Error::Internal(format!(
612 "expected one eager output for {:?}, got {}",
613 op,
614 outputs.len()
615 )));
616 }
617 Ok(outputs.remove(0))
618}
619
620pub(crate) fn zero_like_tensor<B: TensorBackend>(input: &Tensor, backend: &mut B) -> Tensor {
621 let neg = input
622 .neg(backend)
623 .unwrap_or_else(|err| panic!("zero_like neg failed: {}", err));
624 input
625 .add(&neg, backend)
626 .unwrap_or_else(|err| panic!("zero_like add failed: {}", err))
627}
628
629pub(crate) fn one_like_tensor<B: TensorBackend>(input: &Tensor, backend: &mut B) -> Tensor {
630 let zero = zero_like_tensor(input, backend);
631 backend
632 .exp(&zero)
633 .unwrap_or_else(|err| panic!("one_like exp failed: {}", err))
634}