1use std::cell::RefCell;
2use std::cmp::Reverse;
3use std::collections::HashMap;
4use std::env;
5use std::fmt;
6use std::sync::{Arc, Mutex, MutexGuard, OnceLock, Weak};
7use std::time::{Duration, Instant};
8
9use crate::extension_cache::{ExtensionCacheLimits, ExtensionCacheStore};
10use crate::extension_runtime::{ExtensionExecutor, ExtensionRuntimeRegistryError};
11#[cfg(test)]
12use computegraph::graph::Graph;
13use computegraph::ValueKey;
14#[cfg(test)]
15use computegraph::ValueRef;
16use tenferro_cpu::CpuBackend;
17#[cfg(feature = "cuda")]
18use tenferro_gpu::CudaBackend;
19#[cfg(feature = "webgpu")]
20use tenferro_gpu::WebGpuBackend;
21use tenferro_ops::input_key::TensorInputKey;
22use tenferro_ops::std_tensor_op::StdTensorOp;
23use tenferro_ops::ExtensionRuleSet;
24use tenferro_ops::ShapeGuardContext;
25#[cfg(test)]
26use tenferro_tensor::BackendSessionHost;
27use tenferro_tensor::{
28 CacheStats, DType, Tensor, TensorBackend, TensorElementwise, TensorRead, TensorValue,
29 TypedTensor,
30};
31use tidu::eager::{self, EagerInput, EagerOutput, KeySource, RecordedGraph, Recorder, Trace};
32
33use self::backward::TenferroBackwardCallbacks;
34use crate::eager_backend::EagerBackend;
35#[cfg(test)]
36use crate::eager_exec::exec_standard_op_on_tensor_reads_in_session;
37use crate::eager_exec::{
38 exec_op_on_tensor_reads_with_extension_executor, exec_op_on_tensors_with_extension_executor,
39};
40use crate::error::{ContextId, Error, Result};
41#[cfg(test)]
42use crate::metadata::push_metadata_scope;
43use crate::metadata::{
44 metadata_scopes_for_scope, register_scoped_metadata_batch, register_scoped_value_metadata,
45 tensor_meta_from_tensor, GlobalMetadataScope,
46};
47use crate::traced::next_input_key;
48
49use crate::AdContext;
50
51mod backward;
52
/// Shared, lockable storage for a tensor's accumulated gradient.
///
/// Starts as `None`; `EagerRuntime::store_grads` replaces it with `Some(gradient)` (adding to
/// any gradient already present).
pub(crate) type GradSlot = Arc<Mutex<Option<Arc<Tensor>>>>;
/// Non-owning handle to a [`GradSlot`].
///
/// The runtime's registry stores these so that, once every owner of a slot is dropped, the
/// registry entry fails to upgrade and can be pruned.
pub(crate) type WeakGradSlot = Weak<Mutex<Option<Arc<Tensor>>>>;
55
/// Aggregated timing for one profiled eager-op section.
#[derive(Debug, Default, Clone)]
struct EagerOpProfileEntry {
    /// Number of recorded invocations of the section.
    calls: usize,
    /// Sum of the elapsed times (measured with `Instant`) across those invocations.
    total_time: Duration,
}
61
thread_local! {
    // Per-thread profile table keyed by section name. `record_eager_op_profile` adds to it
    // and `print_and_reset_eager_op_profile` empties it.
    static EAGER_OP_PROFILE_STATE: RefCell<HashMap<&'static str, EagerOpProfileEntry>> =
        RefCell::new(HashMap::new());
    // Test-only override for `eager_op_profile_enabled`; `Some(flag)` bypasses the env var.
    #[cfg(test)]
    static EAGER_OP_PROFILE_ENABLED_OVERRIDE: RefCell<Option<bool>> = const { RefCell::new(None) };
    // Test-only override for `eager_op_profile_print_every`. The outer `Option` says whether an
    // override is active; the inner `Option` is the value returned when it is.
    #[cfg(test)]
    static EAGER_OP_PROFILE_PRINT_EVERY_OVERRIDE: RefCell<Option<Option<usize>>> = const { RefCell::new(None) };
}
70
/// Reports whether aggregated eager-op profiling is active.
///
/// Profiling is on when `TENFERRO_PROFILE_EAGER_OP_AGG` is set to a valid Unicode value. The
/// variable is read once per process and cached. In test builds, a thread-local override
/// takes precedence over the environment.
pub(crate) fn eager_op_profile_enabled() -> bool {
    #[cfg(test)]
    {
        let forced = EAGER_OP_PROFILE_ENABLED_OVERRIDE.with(|cell| *cell.borrow());
        if let Some(flag) = forced {
            return flag;
        }
    }

    static ENABLED: OnceLock<bool> = OnceLock::new();
    let enabled = ENABLED.get_or_init(|| env::var("TENFERRO_PROFILE_EAGER_OP_AGG").is_ok());
    *enabled
}
80
81pub(crate) fn record_eager_op_profile(section: &'static str, elapsed: Duration) {
82 if !eager_op_profile_enabled() {
83 return;
84 }
85 EAGER_OP_PROFILE_STATE.with(|state| {
86 let mut state = state.borrow_mut();
87 let entry = state.entry(section).or_default();
88 entry.calls += 1;
89 entry.total_time += elapsed;
90 });
91}
92
93pub(crate) fn profile_eager_op_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
94 if !eager_op_profile_enabled() {
95 return f();
96 }
97 let started = Instant::now();
98 let result = f();
99 record_eager_op_profile(section, started.elapsed());
100 result
101}
102
103pub(crate) fn maybe_print_eager_op_profile() {
104 if !eager_op_profile_enabled() {
105 return;
106 }
107 let Some(print_every) = eager_op_profile_print_every() else {
108 return;
109 };
110 if print_every == 0 {
111 return;
112 }
113
114 let should_print = EAGER_OP_PROFILE_STATE.with(|state| {
115 state
116 .borrow()
117 .get("nary_op.total")
118 .is_some_and(|entry| entry.calls % print_every == 0)
119 });
120 if should_print {
121 print_and_reset_eager_op_profile();
122 }
123}
124
/// Returns the number of `nary_op.total` calls between profile dumps, if one is configured.
///
/// Reads `TENFERRO_PROFILE_EAGER_OP_PRINT_EVERY`, ignoring surrounding whitespace. A missing,
/// non-Unicode, or unparsable value yields `None`. The parsed result is cached for the process,
/// as [`eager_op_profile_enabled`] does for its variable. This function runs after every
/// profiled op, so caching avoids an environment lookup and a `String` allocation on that path.
///
/// In test builds, a thread-local override is consulted first (and is never cached), so tests
/// can change it freely.
fn eager_op_profile_print_every() -> Option<usize> {
    #[cfg(test)]
    if let Some(value) = EAGER_OP_PROFILE_PRINT_EVERY_OVERRIDE.with(|state| *state.borrow()) {
        return value;
    }

    static PRINT_EVERY: OnceLock<Option<usize>> = OnceLock::new();
    *PRINT_EVERY.get_or_init(|| {
        env::var("TENFERRO_PROFILE_EAGER_OP_PRINT_EVERY")
            .ok()
            .and_then(|raw| raw.trim().parse::<usize>().ok())
    })
}
136
137pub(crate) fn print_and_reset_eager_op_profile() {
138 EAGER_OP_PROFILE_STATE.with(|state| {
139 let mut entries: Vec<_> = state
140 .borrow()
141 .iter()
142 .map(|(section, entry)| (*section, entry.clone()))
143 .collect();
144 state.borrow_mut().clear();
145 entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
146
147 eprintln!("=== tenferro eager op profile ===");
148 for (section, entry) in entries {
149 eprintln!(
150 "{section}: calls={} total={:.6}ms per_call={:.3}us",
151 entry.calls,
152 entry.total_time.as_secs_f64() * 1.0e3,
153 entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64,
154 );
155 }
156 });
157}
158
/// Cache statistics reported by [`EagerRuntime::cache_stats`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct EagerRuntimeCacheStats {
    /// Statistics of the extension executor's caches.
    pub extensions: CacheStats,
}
167
/// Result of `EagerRuntime::exec_standard_graph_outputs` (test builds only).
#[cfg(test)]
pub(crate) struct EagerGraphExecution {
    /// Values of the graph outputs, in `graph.outputs()` order.
    pub(crate) outputs: Vec<Arc<Tensor>>,
    /// Every value supplied or computed during execution, keyed by value key; this includes
    /// the initial input data.
    pub(crate) retained_values: HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
}
173
/// Eager execution runtime.
///
/// It owns a backend, an extension executor, optional extension AD rules, and a registry of
/// gradient slots for tensors that track gradients. Constructors return `Arc<EagerRuntime>`,
/// and the mutable parts are behind mutexes.
pub struct EagerRuntime {
    /// Backend that executes ops.
    pub(crate) backend: Mutex<EagerBackend>,
    /// Executor for registered extension ops, together with their caches.
    pub(crate) extension_executor: Mutex<ExtensionExecutor<EagerBackend>>,
    /// Extension rules handed to the AD context in `EagerTensor::backward`, if any.
    extension_rules: Option<ExtensionRuleSet>,
    /// Weak handles to gradient slots, keyed by value key. Entries whose slot was dropped are
    /// removed when the map is swept (`retain` on failed upgrades).
    grad_slots: Mutex<HashMap<ValueKey<StdTensorOp>, WeakGradSlot>>,
}
198
199impl fmt::Debug for EagerRuntime {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 let mut debug = f.debug_struct("EagerRuntime");
202 match self.backend.try_lock() {
203 Ok(backend) => {
204 debug.field("backend", &*backend);
205 }
206 Err(_) => {
207 debug.field("backend", &"<locked>");
208 }
209 }
210 match self.extension_executor.try_lock() {
211 Ok(executor) => {
212 debug.field("extension_executor", &*executor);
213 }
214 Err(_) => {
215 debug.field("extension_executor", &"<locked>");
216 }
217 }
218 debug.field("has_extension_rules", &self.extension_rules.is_some());
219 match self.grad_slots.try_lock() {
220 Ok(slots) => {
221 debug.field("grad_slots_len", &slots.len());
222 }
223 Err(_) => {
224 debug.field("grad_slots_len", &"<locked>");
225 }
226 }
227 debug.finish_non_exhaustive()
228 }
229}
230
231impl EagerRuntime {
232 fn lock_backend(&self) -> Result<MutexGuard<'_, EagerBackend>> {
233 self.backend
234 .lock()
235 .map_err(|_| Error::Internal("backend lock poisoned".to_string()))
236 }
237
238 fn lock_extension_executor(&self) -> Result<MutexGuard<'_, ExtensionExecutor<EagerBackend>>> {
239 self.extension_executor
240 .lock()
241 .map_err(|_| Error::Internal("extension executor lock poisoned".to_string()))
242 }
243
244 fn lock_grad_slots(
245 &self,
246 ) -> Result<MutexGuard<'_, HashMap<ValueKey<StdTensorOp>, WeakGradSlot>>> {
247 self.grad_slots
248 .lock()
249 .map_err(|_| Error::Internal("gradient slot registry lock poisoned".to_string()))
250 }
251
252 fn from_backend(backend: EagerBackend) -> Self {
253 Self::from_backend_with_extension_rules(backend, None)
254 }
255
256 fn from_backend_with_extension_rules(
257 backend: EagerBackend,
258 extension_rules: Option<ExtensionRuleSet>,
259 ) -> Self {
260 Self {
261 backend: Mutex::new(backend),
262 extension_executor: Mutex::new(ExtensionExecutor::new()),
263 extension_rules,
264 grad_slots: Mutex::new(HashMap::new()),
265 }
266 }
267
268 pub fn new() -> Arc<Self> {
279 Self::with_cpu_backend(CpuBackend::new())
280 }
281
282 pub fn with_cpu_backend(backend: CpuBackend) -> Arc<Self> {
294 Arc::new(Self::from_backend(EagerBackend::cpu(backend)))
295 }
296
297 pub fn with_cpu_backend_and_ad_context(backend: CpuBackend, ad: &AdContext) -> Arc<Self> {
310 Arc::new(Self::from_backend_with_extension_rules(
311 EagerBackend::cpu(backend),
312 Some(ad.extension_rule_set()),
313 ))
314 }
315
316 #[cfg(feature = "cuda")]
328 pub fn with_cuda_backend(backend: CudaBackend) -> Arc<Self> {
329 Arc::new(Self::from_backend(EagerBackend::cuda(backend)))
330 }
331
332 #[cfg(feature = "cuda")]
344 pub fn with_cuda_backend_and_ad_context(backend: CudaBackend, ad: &AdContext) -> Arc<Self> {
345 Arc::new(Self::from_backend_with_extension_rules(
346 EagerBackend::cuda(backend),
347 Some(ad.extension_rule_set()),
348 ))
349 }
350
351 #[cfg(feature = "webgpu")]
363 pub fn with_webgpu_backend(backend: WebGpuBackend) -> Arc<Self> {
364 Arc::new(Self::from_backend(EagerBackend::webgpu(backend)))
365 }
366
367 #[cfg(feature = "webgpu")]
379 pub fn with_webgpu_backend_and_ad_context(backend: WebGpuBackend, ad: &AdContext) -> Arc<Self> {
380 Arc::new(Self::from_backend_with_extension_rules(
381 EagerBackend::webgpu(backend),
382 Some(ad.extension_rule_set()),
383 ))
384 }
385
386 pub fn id(&self) -> ContextId {
398 ContextId::from_ptr(self)
399 }
400
401 pub fn register_extension(
403 &self,
404 register: impl FnOnce(
405 &mut ExtensionExecutor<EagerBackend>,
406 ) -> std::result::Result<(), ExtensionRuntimeRegistryError>,
407 ) -> std::result::Result<(), ExtensionRuntimeRegistryError> {
408 let mut executor = self.extension_executor.lock().map_err(|_| {
409 ExtensionRuntimeRegistryError::PoisonedLock {
410 name: "extension executor lock",
411 }
412 })?;
413 register(&mut executor)
414 }
415
416 pub fn clear_extension_caches(&self) -> Result<()> {
430 self.lock_extension_executor()?.clear_caches();
431 Ok(())
432 }
433
434 pub fn clear_caches(&self) -> Result<()> {
448 self.clear_extension_caches()
449 }
450
451 pub fn cache_stats(&self) -> Result<EagerRuntimeCacheStats> {
465 Ok(EagerRuntimeCacheStats {
466 extensions: self.lock_extension_executor()?.cache_stats(),
467 })
468 }
469
470 pub fn extension_cache_limits(&self) -> Result<ExtensionCacheLimits> {
472 Ok(self.lock_extension_executor()?.cache_limits())
473 }
474
475 pub fn set_extension_cache_limits(&self, limits: ExtensionCacheLimits) -> Result<()> {
477 self.lock_extension_executor()?.set_cache_limits(limits);
478 Ok(())
479 }
480
481 pub fn with_extension_caches_mut<R>(
504 &self,
505 f: impl FnOnce(&mut ExtensionCacheStore) -> R,
506 ) -> Result<R> {
507 let mut executor = self.lock_extension_executor()?;
508 Ok(f(executor.caches_mut()))
509 }
510
511 pub fn with_backend_mut<R>(&self, f: impl FnOnce(&mut EagerBackend) -> R) -> Result<R> {
530 let mut backend = self.lock_backend()?;
531 Ok(f(&mut backend))
532 }
533
534 pub fn synchronize(&self) -> Result<()> {
549 self.lock_backend()?.synchronize().map_err(Error::from)
550 }
551
552 pub(crate) fn exec_outputs(&self, op: &StdTensorOp, inputs: &[&Tensor]) -> Result<Vec<Tensor>> {
553 let mut backend =
554 profile_eager_op_section("exec_outputs.lock_backend", || self.lock_backend())?;
555 let mut extension_executor =
556 profile_eager_op_section("exec_outputs.lock_extensions", || {
557 self.lock_extension_executor()
558 })?;
559 profile_eager_op_section("exec_outputs.exec_op", || {
560 exec_op_on_tensors_with_extension_executor(
561 op,
562 inputs,
563 &mut *backend,
564 Some(&mut *extension_executor),
565 )
566 })
567 }
568
569 pub(crate) fn exec_outputs_read(
570 &self,
571 op: &StdTensorOp,
572 inputs: &[TensorRead<'_>],
573 ) -> Result<Vec<Tensor>> {
574 let mut backend =
575 profile_eager_op_section("exec_outputs_read.lock_backend", || self.lock_backend())?;
576 let mut extension_executor =
577 profile_eager_op_section("exec_outputs_read.lock_extensions", || {
578 self.lock_extension_executor()
579 })?;
580 profile_eager_op_section("exec_outputs_read.exec_op", || {
581 exec_op_on_tensor_reads_with_extension_executor(
582 op,
583 inputs,
584 &mut *backend,
585 Some(&mut *extension_executor),
586 )
587 })
588 }
589
590 #[cfg(test)]
591 pub(crate) fn exec_standard_graph_outputs(
592 &self,
593 graph: &Graph<StdTensorOp>,
594 initial_data: &HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
595 ) -> Result<EagerGraphExecution> {
596 let mut backend =
597 profile_eager_op_section("exec_graph.lock_backend", || self.lock_backend())?;
598 let mut all_values = initial_data.clone();
599
600 profile_eager_op_section("exec_graph.with_backend_session", || {
601 backend.with_backend_session(|exec| -> Result<()> {
602 for op_node in graph.operations() {
603 let outputs = {
604 let input_values = op_node
605 .inputs
606 .iter()
607 .map(|input| {
608 let key = match input {
609 ValueRef::Local(local_id) => &graph.values()[*local_id].key,
610 ValueRef::External(key) => key,
611 };
612 all_values.get(key).cloned().ok_or_else(|| {
613 Error::Internal(format!(
614 "standard graph eager execution missing value for {key:?}"
615 ))
616 })
617 })
618 .collect::<Result<Vec<_>>>()?;
619 let input_reads = input_values
620 .iter()
621 .map(|value| TensorRead::from_tensor(value.as_ref()))
622 .collect::<Vec<_>>();
623 exec_standard_op_on_tensor_reads_in_session(
624 &op_node.operation,
625 &input_reads,
626 exec,
627 )?
628 };
629
630 if outputs.len() != op_node.outputs.len() {
631 return Err(Error::Internal(format!(
632 "standard graph eager execution expected {} outputs for {:?}, got {}",
633 op_node.outputs.len(),
634 op_node.operation,
635 outputs.len()
636 )));
637 }
638
639 for (output_id, output) in op_node.outputs.iter().zip(outputs) {
640 let key = graph.values()[*output_id].key.clone();
641 all_values.insert(key, Arc::new(output));
642 }
643 }
644 Ok(())
645 })
646 })?;
647
648 let outputs = graph
649 .outputs()
650 .iter()
651 .map(|&output_id| {
652 let key = &graph.values()[output_id].key;
653 all_values.get(key).cloned().ok_or_else(|| {
654 Error::Internal(format!(
655 "standard graph eager execution missing graph output {key:?}"
656 ))
657 })
658 })
659 .collect::<Result<Vec<_>>>()?;
660
661 Ok(EagerGraphExecution {
662 outputs,
663 retained_values: all_values,
664 })
665 }
666
667 pub(crate) fn try_register_grad_slot(
668 &self,
669 key: &ValueKey<StdTensorOp>,
670 slot: &GradSlot,
671 ) -> Result<()> {
672 self.lock_grad_slots()?
673 .insert(key.clone(), Arc::downgrade(slot));
674 Ok(())
675 }
676
677 pub fn clear_grads(&self) -> Result<()> {
701 let mut poisoned_slot = false;
702 self.lock_grad_slots()?.retain(|_, slot| {
703 if let Some(slot) = slot.upgrade() {
704 match slot.lock() {
705 Ok(mut current) => {
706 *current = None;
707 }
708 Err(_) => {
709 poisoned_slot = true;
710 }
711 }
712 true
713 } else {
714 false
715 }
716 });
717 if poisoned_slot {
718 return Err(Error::Internal("gradient slot lock poisoned".to_string()));
719 }
720 Ok(())
721 }
722
723 pub fn constant_from(self: &Arc<Self>, tensor: Tensor) -> Result<EagerTensor> {
744 EagerTensor::new_leaf(Arc::clone(self), tensor, false)
745 }
746
747 pub fn variable_from(self: &Arc<Self>, tensor: Tensor) -> Result<EagerTensor> {
768 EagerTensor::new_leaf(Arc::clone(self), tensor, true)
769 }
770
771 fn store_grads(
772 &self,
773 cotangents: &HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
774 backend: &mut EagerBackend,
775 ) -> Result<()> {
776 let mut updates = Vec::new();
777
778 {
779 let mut slots = self.lock_grad_slots()?;
780 slots.retain(|key, slot| {
781 let Some(slot) = slot.upgrade() else {
782 return false;
783 };
784
785 if let Some(incoming) = cotangents.get(key) {
786 updates.push((slot, Arc::clone(incoming)));
787 }
788
789 true
790 });
791 }
792
793 for (slot, incoming) in updates {
794 let mut current = slot
795 .lock()
796 .map_err(|_| Error::Internal("gradient slot lock poisoned".to_string()))?;
797 let next = match current.as_ref() {
798 Some(existing) => Arc::new(backend.add(existing.as_ref(), incoming.as_ref())?),
799 None => incoming,
800 };
801 *current = Some(next);
802 }
803
804 Ok(())
805 }
806}
807
/// A tensor value bound to an [`EagerRuntime`], optionally carrying an autodiff trace.
///
/// Clones share the value, the materialization cache, and the gradient slot, because each of
/// these is an `Arc`. A gradient stored through one clone is therefore visible through all of
/// them.
#[derive(Clone)]
pub struct EagerTensor {
    /// Underlying value. It may or may not already be a materialized `Tensor`.
    pub(crate) value: Arc<TensorValue>,
    /// Lazily filled materialized tensor, used when `value` is not directly a tensor.
    materialized_cache: Arc<OnceLock<Arc<Tensor>>>,
    /// Key identifying this value for recording and gradient lookup.
    pub(crate) key: ValueKey<StdTensorOp>,
    /// Recorded trace consumed by `backward`. It is `None` for leaves and untracked results.
    pub(crate) trace: Option<Trace<StdTensorOp>>,
    /// Whether this value participates in gradient computation.
    pub(crate) requires_grad: bool,
    /// Gradient storage. It is registered (weakly) with the runtime when `requires_grad` is set.
    grad_slot: GradSlot,
    /// Metadata scopes associated with this value (and, for recorded results, its inputs).
    /// These are passed to the backward callbacks.
    pub(crate) metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
    /// Runtime that owns this tensor.
    pub(crate) ctx: Arc<EagerRuntime>,
}
844
845impl fmt::Debug for EagerTensor {
846 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
847 f.debug_struct("EagerTensor")
848 .field("dtype", &self.dtype())
849 .field("shape", &self.shape())
850 .field("key", &self.key)
851 .field("requires_grad", &self.requires_grad)
852 .field("has_trace", &self.trace.is_some())
853 .field("ctx_id", &self.ctx_id())
854 .finish_non_exhaustive()
855 }
856}
857
858impl EagerTensor {
859 pub fn from_tensor_in(tensor: Tensor, ctx: Arc<EagerRuntime>) -> Result<Self> {
874 Self::new_leaf(ctx, tensor, false)
875 }
876
877 pub fn requires_grad_in(tensor: Tensor, ctx: Arc<EagerRuntime>) -> Result<Self> {
892 Self::new_leaf(ctx, tensor, true)
893 }
894
895 pub(crate) fn new_leaf(
896 ctx: Arc<EagerRuntime>,
897 tensor: Tensor,
898 requires_grad: bool,
899 ) -> Result<Self> {
900 let key = eager_val_key();
901 let metadata_scope =
902 register_scoped_value_metadata(key.clone(), tensor_meta_from_tensor(&tensor)).map_err(
903 |err| Error::Internal(format!("eager leaf metadata registration failed: {err}")),
904 )?;
905 let tensor = Arc::new(tensor);
906 let grad_slot = Arc::new(Mutex::new(None));
907 if requires_grad {
908 ctx.try_register_grad_slot(&key, &grad_slot)?;
909 }
910
911 Ok(Self {
912 value: Arc::new(TensorValue::from_tensor_arc(tensor)),
913 materialized_cache: Arc::new(OnceLock::new()),
914 key,
915 trace: None,
916 requires_grad,
917 grad_slot,
918 metadata_scopes: metadata_scopes_for_scope(metadata_scope),
919 ctx,
920 })
921 }
922
923 pub(crate) fn new_result(
924 ctx: Arc<EagerRuntime>,
925 key: ValueKey<StdTensorOp>,
926 tensor: Tensor,
927 requires_grad: bool,
928 trace: Option<Trace<StdTensorOp>>,
929 metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
930 ) -> Result<Self> {
931 Self::new_result_arc(
932 ctx,
933 key,
934 Arc::new(tensor),
935 requires_grad,
936 trace,
937 metadata_scopes,
938 )
939 }
940
941 pub(crate) fn new_result_arc(
942 ctx: Arc<EagerRuntime>,
943 key: ValueKey<StdTensorOp>,
944 tensor: Arc<Tensor>,
945 requires_grad: bool,
946 trace: Option<Trace<StdTensorOp>>,
947 metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
948 ) -> Result<Self> {
949 let grad_slot = Arc::new(Mutex::new(None));
950 if requires_grad {
951 ctx.try_register_grad_slot(&key, &grad_slot)?;
952 }
953
954 Ok(Self {
955 value: Arc::new(TensorValue::from_tensor_arc(tensor)),
956 materialized_cache: Arc::new(OnceLock::new()),
957 key,
958 trace,
959 requires_grad,
960 grad_slot,
961 metadata_scopes,
962 ctx,
963 })
964 }
965
966 pub(crate) fn new_result_value(
967 ctx: Arc<EagerRuntime>,
968 key: ValueKey<StdTensorOp>,
969 value: TensorValue,
970 requires_grad: bool,
971 trace: Option<Trace<StdTensorOp>>,
972 metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
973 ) -> Result<Self> {
974 let grad_slot = Arc::new(Mutex::new(None));
975 if requires_grad {
976 ctx.try_register_grad_slot(&key, &grad_slot)?;
977 }
978
979 Ok(Self {
980 value: Arc::new(value),
981 materialized_cache: Arc::new(OnceLock::new()),
982 key,
983 trace,
984 requires_grad,
985 grad_slot,
986 metadata_scopes,
987 ctx,
988 })
989 }
990
991 pub(crate) fn new_untracked_result(ctx: Arc<EagerRuntime>, tensor: Tensor) -> Result<Self> {
992 Self::new_result(ctx, eager_val_key(), tensor, false, None, Vec::new())
993 }
994
995 pub(crate) fn new_untracked_value_result(ctx: Arc<EagerRuntime>, value: TensorValue) -> Self {
996 Self {
997 value: Arc::new(value),
998 materialized_cache: Arc::new(OnceLock::new()),
999 key: eager_val_key(),
1000 trace: None,
1001 requires_grad: false,
1002 grad_slot: Arc::new(Mutex::new(None)),
1003 metadata_scopes: Vec::new(),
1004 ctx,
1005 }
1006 }
1007
1008 pub fn detach(&self) -> Self {
1028 Self::new_untracked_value_result(self.ctx.clone(), self.value.as_ref().clone())
1029 }
1030
1031 pub fn detach_into(&self, ctx: &Arc<EagerRuntime>) -> Result<Self> {
1050 Self::from_tensor_in(self.to_tensor()?, Arc::clone(ctx))
1051 }
1052
1053 pub fn materialized(&self) -> Result<Arc<Tensor>> {
1067 self.materialized_arc()
1068 }
1069
1070 pub fn dtype(&self) -> DType {
1073 self.value.dtype()
1074 }
1075
1076 pub fn shape(&self) -> &[usize] {
1079 self.value.shape()
1080 }
1081
1082 pub fn tensor_read(&self) -> TensorRead<'_> {
1088 self.value.tensor_read()
1089 }
1090
1091 pub fn to_tensor(&self) -> Result<Tensor> {
1097 self.value.to_tensor().map_err(Error::from)
1098 }
1099
1100 pub(crate) fn materialized_arc(&self) -> Result<Arc<Tensor>> {
1101 if let Some(tensor) = self.value.as_tensor_arc() {
1102 return Ok(Arc::clone(tensor));
1103 }
1104 if let Some(tensor) = self.materialized_cache.get() {
1105 return Ok(Arc::clone(tensor));
1106 }
1107
1108 let materialized = Arc::new(self.value.to_tensor().map_err(Error::from)?);
1109 let _ = self.materialized_cache.set(Arc::clone(&materialized));
1110 Ok(self
1111 .materialized_cache
1112 .get()
1113 .map(Arc::clone)
1114 .unwrap_or(materialized))
1115 }
1116
1117 #[cfg(test)]
1118 pub(crate) fn materialized_cache_is_initialized(&self) -> bool {
1119 self.materialized_cache.get().is_some()
1120 }
1121
1122 pub fn grad(&self) -> Result<Option<Arc<Tensor>>> {
1147 self.grad_slot
1148 .lock()
1149 .map_err(|_| Error::Internal("gradient slot lock poisoned".to_string()))
1150 .map(|slot| slot.clone())
1151 }
1152
1153 pub fn clear_grad(&self) -> Result<()> {
1178 *self
1179 .grad_slot
1180 .lock()
1181 .map_err(|_| Error::Internal("gradient slot lock poisoned".to_string()))? = None;
1182 Ok(())
1183 }
1184
1185 pub fn tracks_grad(&self) -> bool {
1206 self.requires_grad
1207 }
1208
1209 #[cfg(test)]
1210 fn debug_trace_saved_value_count(&self) -> Option<usize> {
1211 self.trace.as_ref().map(|trace| trace.saved_values().len())
1212 }
1213
1214 pub fn ctx_id(&self) -> ContextId {
1228 self.ctx.id()
1229 }
1230
1231 pub fn runtime(&self) -> &Arc<EagerRuntime> {
1233 &self.ctx
1234 }
1235
1236 pub fn same_context(&self, other: &Self) -> bool {
1251 self.ctx_id() == other.ctx_id()
1252 }
1253
1254 #[cfg(test)]
1255 pub(crate) fn standard_graph_op(
1256 inputs: &[&Self],
1257 build_graph: impl FnOnce(&[TensorInputKey]) -> Result<Arc<Graph<StdTensorOp>>>,
1258 ) -> Result<Vec<Self>> {
1259 let Some(first) = inputs.first() else {
1260 return Err(Error::Internal(
1261 "standard eager graph op requires at least one input tensor".to_string(),
1262 ));
1263 };
1264 let ctx = Arc::clone(&first.ctx);
1265 for tensor in inputs.iter().skip(1) {
1266 if !first.same_context(tensor) {
1267 return Err(Error::ContextMismatch {
1268 lhs: first.ctx_id(),
1269 rhs: tensor.ctx_id(),
1270 });
1271 }
1272 }
1273
1274 let mut recorder = Recorder::new(EagerTensorKeySource);
1275 let graph_input_keys = recorder.fresh_input_keys::<StdTensorOp>(inputs.len());
1276 let graph = build_graph(&graph_input_keys)?;
1277 let initial_data = graph_input_keys
1278 .iter()
1279 .zip(inputs.iter())
1280 .map(|(key, tensor)| Ok((ValueKey::Input(key.clone()), tensor.materialized_arc()?)))
1281 .collect::<Result<HashMap<_, _>>>()?;
1282 let execution = ctx.exec_standard_graph_outputs(graph.as_ref(), &initial_data)?;
1283 if execution.outputs.len() != graph.outputs().len() {
1284 return Err(Error::Internal(format!(
1285 "standard eager graph op expected {} graph outputs, got {}",
1286 graph.outputs().len(),
1287 execution.outputs.len()
1288 )));
1289 }
1290
1291 if !inputs.iter().any(|input| input.requires_grad) {
1292 return execution
1293 .outputs
1294 .into_iter()
1295 .map(|output| {
1296 Self::new_result_arc(
1297 Arc::clone(&ctx),
1298 eager_val_key(),
1299 output,
1300 false,
1301 None,
1302 Vec::new(),
1303 )
1304 })
1305 .collect();
1306 }
1307
1308 let output_keys = graph
1309 .outputs()
1310 .iter()
1311 .map(|&output_id| graph.values()[output_id].key.clone())
1312 .collect();
1313 let recorded_graph = RecordedGraph::new(Arc::clone(&graph), graph_input_keys, output_keys)
1314 .map_err(eager_record_error)?;
1315 let recorded = record_eager_recorded_graph_outputs(
1316 &mut recorder,
1317 recorded_graph,
1318 &execution.outputs,
1319 execution.retained_values,
1320 inputs,
1321 )?;
1322 if recorded.traces.len() != execution.outputs.len() {
1323 return Err(Error::Internal(format!(
1324 "standard eager graph op expected {} eager traces, got {}",
1325 execution.outputs.len(),
1326 recorded.traces.len()
1327 )));
1328 }
1329
1330 let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
1331 for input in inputs {
1332 for scope in &input.metadata_scopes {
1333 push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
1334 }
1335 }
1336
1337 recorded
1338 .traces
1339 .into_iter()
1340 .zip(execution.outputs)
1341 .map(|(trace, output)| {
1342 Self::new_result_arc(
1343 Arc::clone(&ctx),
1344 trace.key,
1345 output,
1346 trace.requires_grad,
1347 trace.trace,
1348 metadata_scopes.clone(),
1349 )
1350 })
1351 .collect()
1352 }
1353
1354 pub fn backward(&self) -> Result<HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>> {
1380 if !self.shape().is_empty() {
1381 return Err(Error::NonScalarGrad {
1382 shape: self.shape().to_vec(),
1383 });
1384 }
1385
1386 let value = self.materialized_arc()?;
1387 let mut backend = self.ctx.lock_backend()?;
1388 let mut extension_executor = self.ctx.lock_extension_executor()?;
1389 let seed = Arc::new(one_like_tensor(value.as_ref(), &mut *backend)?);
1390 let mut callbacks = TenferroBackwardCallbacks::new(
1391 &mut *backend,
1392 Some(&mut *extension_executor),
1393 self.metadata_scopes.clone(),
1394 );
1395 let mut ad_ctx = ShapeGuardContext::with_global_metadata();
1396 if let Some(extension_rules) = &self.ctx.extension_rules {
1397 ad_ctx = ad_ctx.with_extension_rules(extension_rules.clone());
1398 }
1399 let cotangents_result = eager::backward(
1400 &self.key,
1401 self.trace.as_ref(),
1402 seed,
1403 &mut callbacks,
1404 &mut ad_ctx,
1405 );
1406 let callback_error = callbacks.take_error();
1407 drop(callbacks);
1408 let cotangents = match (cotangents_result, callback_error) {
1409 (_, Some(err)) => return Err(Error::Internal(err.to_string())),
1410 (Err(err), None) => return Err(Error::Internal(err.to_string())),
1411 (Ok(cotangents), None) => cotangents,
1412 };
1413 self.ctx.store_grads(&cotangents, &mut backend)?;
1414 Ok(cotangents)
1415 }
1416}
1417
/// Allocates a fresh value key for an eager tensor by wrapping a new `next_input_key()` in
/// `ValueKey::Input`.
pub(crate) fn eager_val_key() -> ValueKey<StdTensorOp> {
    ValueKey::Input(next_input_key())
}
1421
/// [`KeySource`] that hands out graph input keys from `next_input_key`.
///
/// `eager_val_key` uses the same source, so recorder-created keys and eager tensor keys are
/// drawn from one key space.
pub(crate) struct EagerTensorKeySource;

impl KeySource<StdTensorOp> for EagerTensorKeySource {
    fn fresh_input_key(&mut self) -> TensorInputKey {
        next_input_key()
    }
}
1429
1430pub(crate) fn eager_value(tensor: &EagerTensor) -> Result<EagerInput<StdTensorOp>> {
1431 Ok(EagerInput {
1432 key: tensor.key.clone(),
1433 trace: tensor.trace.clone(),
1434 requires_grad: tensor.requires_grad,
1435 data: tensor.materialized_arc()?,
1436 })
1437}
1438
/// Result of recording an eager op or graph.
///
/// It holds the per-output traces and the metadata scope registered for the recorded keys.
pub(crate) struct RecordedEagerOutputs {
    /// Recorded output entries. Callers such as `standard_graph_op` expect one per output.
    pub(crate) traces: Vec<EagerOutput<StdTensorOp>>,
    /// Scope holding the metadata registered for the output keys and saved-value keys.
    pub(crate) metadata_scope: Arc<GlobalMetadataScope>,
}
1443
1444pub(crate) fn record_eager_outputs(
1445 op: &StdTensorOp,
1446 outputs: &[Arc<Tensor>],
1447 inputs: &[&EagerTensor],
1448) -> Result<RecordedEagerOutputs> {
1449 let mut recorder = Recorder::new(EagerTensorKeySource);
1450 let graph_input_keys = recorder.fresh_input_keys::<StdTensorOp>(inputs.len());
1451 let graph =
1452 RecordedGraph::from_primitive(op.clone(), graph_input_keys).map_err(eager_record_error)?;
1453 let retained_values = graph
1454 .output_keys()
1455 .iter()
1456 .cloned()
1457 .zip(outputs.iter().cloned())
1458 .collect();
1459 record_eager_recorded_graph_outputs(&mut recorder, graph, outputs, retained_values, inputs)
1460}
1461
1462pub(crate) fn record_eager_recorded_graph_outputs(
1463 recorder: &mut Recorder<EagerTensorKeySource>,
1464 graph: RecordedGraph<StdTensorOp>,
1465 outputs: &[Arc<Tensor>],
1466 retained_values: HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
1467 inputs: &[&EagerTensor],
1468) -> Result<RecordedEagerOutputs> {
1469 let input_values: Vec<_> = inputs
1470 .iter()
1471 .map(|tensor| eager_value(tensor))
1472 .collect::<Result<_>>()?;
1473 let traces = recorder
1474 .record_graph(graph, &input_values, outputs, retained_values)
1475 .map_err(eager_record_error)?;
1476
1477 let mut registrations = Vec::new();
1478 for trace in &traces {
1479 if let Some(output) = outputs.get(trace.output_slot) {
1480 registrations.push((trace.key.clone(), tensor_meta_from_tensor(output.as_ref())));
1481 }
1482 }
1483
1484 if let Some(trace) = traces.iter().find_map(|output| output.trace.as_ref()) {
1485 for (key, value) in trace.saved_values() {
1486 registrations.push((key.clone(), tensor_meta_from_tensor(value.as_ref())));
1487 }
1488 }
1489
1490 Ok(RecordedEagerOutputs {
1491 traces,
1492 metadata_scope: Arc::new(register_scoped_metadata_batch(registrations)?),
1493 })
1494}
1495
1496fn eager_record_error(err: tidu::eager::EagerRecordError) -> Error {
1497 Error::Internal(format!("invalid eager recording metadata: {err}"))
1498}
1499
1500pub(crate) fn exec_single_output(
1501 op: &StdTensorOp,
1502 inputs: &[&Tensor],
1503 ctx: &EagerRuntime,
1504) -> Result<Tensor> {
1505 let mut outputs = ctx.exec_outputs(op, inputs)?;
1506 if outputs.len() != 1 {
1507 return Err(Error::Internal(format!(
1508 "expected one eager output for {:?}, got {}",
1509 op,
1510 outputs.len()
1511 )));
1512 }
1513 Ok(profile_eager_op_section(
1514 "exec_single_output.remove_output",
1515 || outputs.remove(0),
1516 ))
1517}
1518
1519pub(crate) fn exec_single_output_read(
1520 op: &StdTensorOp,
1521 inputs: &[TensorRead<'_>],
1522 ctx: &EagerRuntime,
1523) -> Result<Tensor> {
1524 let mut outputs = ctx.exec_outputs_read(op, inputs)?;
1525 if outputs.len() != 1 {
1526 return Err(Error::Internal(format!(
1527 "expected one eager output for {:?}, got {}",
1528 op,
1529 outputs.len()
1530 )));
1531 }
1532 Ok(profile_eager_op_section(
1533 "exec_single_output_read.remove_output",
1534 || outputs.remove(0),
1535 ))
1536}
1537
1538pub(crate) fn zero_like_tensor<B: TensorBackend>(
1539 input: &Tensor,
1540 backend: &mut B,
1541) -> Result<Tensor> {
1542 let host = match input {
1543 Tensor::F32(tensor) => Tensor::F32(TypedTensor::zeros(tensor.shape().to_vec())?),
1544 Tensor::F64(tensor) => Tensor::F64(TypedTensor::zeros(tensor.shape().to_vec())?),
1545 Tensor::I32(tensor) => Tensor::I32(TypedTensor::zeros(tensor.shape().to_vec())?),
1546 Tensor::I64(tensor) => Tensor::I64(TypedTensor::zeros(tensor.shape().to_vec())?),
1547 Tensor::Bool(tensor) => Tensor::Bool(TypedTensor::from_vec_col_major(
1548 tensor.shape().to_vec(),
1549 vec![false; tensor.n_elements()],
1550 )?),
1551 Tensor::C32(tensor) => Tensor::C32(TypedTensor::zeros(tensor.shape().to_vec())?),
1552 Tensor::C64(tensor) => Tensor::C64(TypedTensor::zeros(tensor.shape().to_vec())?),
1553 };
1554 backend.upload_host_tensor(&host).map_err(Error::from)
1555}
1556
1557pub(crate) fn one_like_tensor<B: TensorBackend>(input: &Tensor, backend: &mut B) -> Result<Tensor> {
1558 let zero = zero_like_tensor(input, backend)?;
1559 backend.exp(&zero).map_err(Error::from)
1560}
1561
1562#[cfg(test)]
1563mod tests;