Skip to main content

tenferro_ad/
eager.rs

1use std::cell::RefCell;
2use std::cmp::Reverse;
3use std::collections::HashMap;
4use std::env;
5use std::fmt;
6use std::sync::{Arc, Mutex, MutexGuard, OnceLock, Weak};
7use std::time::{Duration, Instant};
8
9use crate::extension_cache::{ExtensionCacheLimits, ExtensionCacheStore};
10use crate::extension_runtime::{ExtensionExecutor, ExtensionRuntimeRegistryError};
11#[cfg(test)]
12use computegraph::graph::Graph;
13use computegraph::ValueKey;
14#[cfg(test)]
15use computegraph::ValueRef;
16use tenferro_cpu::CpuBackend;
17#[cfg(feature = "cuda")]
18use tenferro_gpu::CudaBackend;
19#[cfg(feature = "webgpu")]
20use tenferro_gpu::WebGpuBackend;
21use tenferro_ops::input_key::TensorInputKey;
22use tenferro_ops::std_tensor_op::StdTensorOp;
23use tenferro_ops::ExtensionRuleSet;
24use tenferro_ops::ShapeGuardContext;
25#[cfg(test)]
26use tenferro_tensor::BackendSessionHost;
27use tenferro_tensor::{
28    CacheStats, DType, Tensor, TensorBackend, TensorElementwise, TensorRead, TensorValue,
29    TypedTensor,
30};
31use tidu::eager::{self, EagerInput, EagerOutput, KeySource, RecordedGraph, Recorder, Trace};
32
33use self::backward::TenferroBackwardCallbacks;
34use crate::eager_backend::EagerBackend;
35#[cfg(test)]
36use crate::eager_exec::exec_standard_op_on_tensor_reads_in_session;
37use crate::eager_exec::{
38    exec_op_on_tensor_reads_with_extension_executor, exec_op_on_tensors_with_extension_executor,
39};
40use crate::error::{ContextId, Error, Result};
41#[cfg(test)]
42use crate::metadata::push_metadata_scope;
43use crate::metadata::{
44    metadata_scopes_for_scope, register_scoped_metadata_batch, register_scoped_value_metadata,
45    tensor_meta_from_tensor, GlobalMetadataScope,
46};
47use crate::traced::next_input_key;
48
49use crate::AdContext;
50
51mod backward;
52
/// Shared, lockable storage for an optional gradient tensor.
///
/// `None` means no gradient is currently stored in the slot.
pub(crate) type GradSlot = Arc<Mutex<Option<Arc<Tensor>>>>;
/// Non-owning ([`Weak`]) handle to a [`GradSlot`]; it does not keep the slot alive.
pub(crate) type WeakGradSlot = Weak<Mutex<Option<Arc<Tensor>>>>;
55
/// Accumulated timing statistics for one named profiling section.
#[derive(Debug, Default, Clone)]
struct EagerOpProfileEntry {
    /// Number of samples recorded for the section.
    calls: usize,
    /// Sum of the elapsed times of all recorded samples.
    total_time: Duration,
}
61
thread_local! {
    // Per-thread profile accumulator, keyed by section name.
    static EAGER_OP_PROFILE_STATE: RefCell<HashMap<&'static str, EagerOpProfileEntry>> =
        RefCell::new(HashMap::new());
    // Test-only override for `eager_op_profile_enabled`; `None` means "no override".
    #[cfg(test)]
    static EAGER_OP_PROFILE_ENABLED_OVERRIDE: RefCell<Option<bool>> = const { RefCell::new(None) };
    // Test-only override for `eager_op_profile_print_every`. The outer `None` means
    // "no override"; the inner `Option` is the value returned when an override is set.
    #[cfg(test)]
    static EAGER_OP_PROFILE_PRINT_EVERY_OVERRIDE: RefCell<Option<Option<usize>>> = const { RefCell::new(None) };
}
70
/// Report whether aggregated eager-op profiling is switched on.
///
/// In test builds a thread-local override, when set, takes precedence.
/// Otherwise the answer is derived once per process from the *presence* of
/// the `TENFERRO_PROFILE_EAGER_OP_AGG` environment variable and then cached;
/// the variable's value is never inspected.
pub(crate) fn eager_op_profile_enabled() -> bool {
    #[cfg(test)]
    if let Some(value) = EAGER_OP_PROFILE_ENABLED_OVERRIDE.with(|state| *state.borrow()) {
        return value;
    }

    static ENABLED: OnceLock<bool> = OnceLock::new();
    // `var_os` is a pure presence check. `env::var` would return an error, and
    // thereby report "disabled", when the variable is set to a non-UTF-8 value.
    *ENABLED.get_or_init(|| env::var_os("TENFERRO_PROFILE_EAGER_OP_AGG").is_some())
}
80
81pub(crate) fn record_eager_op_profile(section: &'static str, elapsed: Duration) {
82    if !eager_op_profile_enabled() {
83        return;
84    }
85    EAGER_OP_PROFILE_STATE.with(|state| {
86        let mut state = state.borrow_mut();
87        let entry = state.entry(section).or_default();
88        entry.calls += 1;
89        entry.total_time += elapsed;
90    });
91}
92
93pub(crate) fn profile_eager_op_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
94    if !eager_op_profile_enabled() {
95        return f();
96    }
97    let started = Instant::now();
98    let result = f();
99    record_eager_op_profile(section, started.elapsed());
100    result
101}
102
103pub(crate) fn maybe_print_eager_op_profile() {
104    if !eager_op_profile_enabled() {
105        return;
106    }
107    let Some(print_every) = eager_op_profile_print_every() else {
108        return;
109    };
110    if print_every == 0 {
111        return;
112    }
113
114    let should_print = EAGER_OP_PROFILE_STATE.with(|state| {
115        state
116            .borrow()
117            .get("nary_op.total")
118            .is_some_and(|entry| entry.calls % print_every == 0)
119    });
120    if should_print {
121        print_and_reset_eager_op_profile();
122    }
123}
124
/// Read the configured print interval for the eager-op profile.
///
/// In test builds a thread-local override, when set, is returned as-is.
/// Otherwise the value comes from `TENFERRO_PROFILE_EAGER_OP_PRINT_EVERY`;
/// an unset, non-Unicode, or unparsable value yields `None`.
fn eager_op_profile_print_every() -> Option<usize> {
    #[cfg(test)]
    if let Some(forced) = EAGER_OP_PROFILE_PRINT_EVERY_OVERRIDE.with(|cell| *cell.borrow()) {
        return forced;
    }

    match env::var("TENFERRO_PROFILE_EAGER_OP_PRINT_EVERY") {
        Ok(raw) => raw.parse::<usize>().ok(),
        Err(_) => None,
    }
}
136
137pub(crate) fn print_and_reset_eager_op_profile() {
138    EAGER_OP_PROFILE_STATE.with(|state| {
139        let mut entries: Vec<_> = state
140            .borrow()
141            .iter()
142            .map(|(section, entry)| (*section, entry.clone()))
143            .collect();
144        state.borrow_mut().clear();
145        entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
146
147        eprintln!("=== tenferro eager op profile ===");
148        for (section, entry) in entries {
149            eprintln!(
150                "{section}: calls={} total={:.6}ms per_call={:.3}us",
151                entry.calls,
152                entry.total_time.as_secs_f64() * 1.0e3,
153                entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64,
154            );
155        }
156    });
157}
158
/// Stats for caches owned by an [`EagerRuntime`].
///
/// Values of this type are returned by [`EagerRuntime::cache_stats`].
///
/// `retained_bytes` fields are logical payload estimates, not process RSS.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct EagerRuntimeCacheStats {
    /// Stats reported by the runtime's extension executor for the generic
    /// extension runtime caches.
    pub extensions: CacheStats,
}
167
/// Result of running a standard-op graph eagerly (test builds only).
#[cfg(test)]
pub(crate) struct EagerGraphExecution {
    /// Values of the graph's outputs, in graph output order.
    pub(crate) outputs: Vec<Arc<Tensor>>,
    /// Every value known after execution: the initial data plus the outputs
    /// of all executed operations, keyed by value key.
    pub(crate) retained_values: HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
}
173
/// Shared eager execution context for tensors on a backend.
///
/// Reusing one context lets eager tensors share backend state, extension
/// runtime caches, and gradient storage across a computation.
///
/// # Examples
///
/// ```
/// use tenferro_cpu::CpuBackend;
/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
///
/// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
/// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap(), ctx.clone()).unwrap();
/// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap(), ctx).unwrap();
/// let z = x.add(&y).unwrap();
///
/// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0]);
/// ```
pub struct EagerRuntime {
    /// Backend that executes operations; guarded by a mutex.
    pub(crate) backend: Mutex<EagerBackend>,
    /// Executor for registered extension runtimes, which also owns their caches.
    pub(crate) extension_executor: Mutex<ExtensionExecutor<EagerBackend>>,
    /// AD extension rule set supplied at construction, if any.
    extension_rules: Option<ExtensionRuleSet>,
    /// Weak handles to registered gradient slots, keyed by value key.
    /// Entries whose slot has been dropped are pruned by `clear_grads` and
    /// `store_grads`.
    grad_slots: Mutex<HashMap<ValueKey<StdTensorOp>, WeakGradSlot>>,
}
198
199impl fmt::Debug for EagerRuntime {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        let mut debug = f.debug_struct("EagerRuntime");
202        match self.backend.try_lock() {
203            Ok(backend) => {
204                debug.field("backend", &*backend);
205            }
206            Err(_) => {
207                debug.field("backend", &"<locked>");
208            }
209        }
210        match self.extension_executor.try_lock() {
211            Ok(executor) => {
212                debug.field("extension_executor", &*executor);
213            }
214            Err(_) => {
215                debug.field("extension_executor", &"<locked>");
216            }
217        }
218        debug.field("has_extension_rules", &self.extension_rules.is_some());
219        match self.grad_slots.try_lock() {
220            Ok(slots) => {
221                debug.field("grad_slots_len", &slots.len());
222            }
223            Err(_) => {
224                debug.field("grad_slots_len", &"<locked>");
225            }
226        }
227        debug.finish_non_exhaustive()
228    }
229}
230
231impl EagerRuntime {
232    fn lock_backend(&self) -> Result<MutexGuard<'_, EagerBackend>> {
233        self.backend
234            .lock()
235            .map_err(|_| Error::Internal("backend lock poisoned".to_string()))
236    }
237
238    fn lock_extension_executor(&self) -> Result<MutexGuard<'_, ExtensionExecutor<EagerBackend>>> {
239        self.extension_executor
240            .lock()
241            .map_err(|_| Error::Internal("extension executor lock poisoned".to_string()))
242    }
243
244    fn lock_grad_slots(
245        &self,
246    ) -> Result<MutexGuard<'_, HashMap<ValueKey<StdTensorOp>, WeakGradSlot>>> {
247        self.grad_slots
248            .lock()
249            .map_err(|_| Error::Internal("gradient slot registry lock poisoned".to_string()))
250    }
251
252    fn from_backend(backend: EagerBackend) -> Self {
253        Self::from_backend_with_extension_rules(backend, None)
254    }
255
256    fn from_backend_with_extension_rules(
257        backend: EagerBackend,
258        extension_rules: Option<ExtensionRuleSet>,
259    ) -> Self {
260        Self {
261            backend: Mutex::new(backend),
262            extension_executor: Mutex::new(ExtensionExecutor::new()),
263            extension_rules,
264            grad_slots: Mutex::new(HashMap::new()),
265        }
266    }
267
268    /// Create a shared CPU eager execution context.
269    ///
270    /// # Examples
271    ///
272    /// ```
273    /// use tenferro_ad::EagerRuntime;
274    ///
275    /// let ctx = EagerRuntime::new();
276    /// assert_eq!(std::sync::Arc::strong_count(&ctx), 1);
277    /// ```
278    pub fn new() -> Arc<Self> {
279        Self::with_cpu_backend(CpuBackend::new())
280    }
281
282    /// Create a shared eager execution context from a configured CPU backend.
283    ///
284    /// # Examples
285    ///
286    /// ```
287    /// use tenferro_cpu::CpuBackend;
288    /// use tenferro_ad::{EagerRuntime};
289    ///
290    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::with_threads(1).unwrap());
291    /// assert_eq!(std::sync::Arc::strong_count(&ctx), 1);
292    /// ```
293    pub fn with_cpu_backend(backend: CpuBackend) -> Arc<Self> {
294        Arc::new(Self::from_backend(EagerBackend::cpu(backend)))
295    }
296
297    /// Create a shared CPU eager context with explicit AD extension rules.
298    ///
299    /// # Examples
300    ///
301    /// ```rust
302    /// use tenferro_cpu::CpuBackend;
303    /// use tenferro_ad::{AdContext, EagerRuntime};
304    ///
305    /// let ad = AdContext::builder().build().unwrap();
306    /// let ctx = EagerRuntime::with_cpu_backend_and_ad_context(CpuBackend::new(), &ad);
307    /// assert_eq!(std::sync::Arc::strong_count(&ctx), 1);
308    /// ```
309    pub fn with_cpu_backend_and_ad_context(backend: CpuBackend, ad: &AdContext) -> Arc<Self> {
310        Arc::new(Self::from_backend_with_extension_rules(
311            EagerBackend::cpu(backend),
312            Some(ad.extension_rule_set()),
313        ))
314    }
315
316    /// Create a shared eager execution context from a configured CUDA backend.
317    ///
318    /// # Examples
319    ///
320    /// ```
321    /// use tenferro_gpu::CudaBackend;
322    /// use tenferro_ad::EagerRuntime;
323    ///
324    /// let _ctor: fn(CudaBackend) -> std::sync::Arc<EagerRuntime> =
325    ///     EagerRuntime::with_cuda_backend;
326    /// ```
327    #[cfg(feature = "cuda")]
328    pub fn with_cuda_backend(backend: CudaBackend) -> Arc<Self> {
329        Arc::new(Self::from_backend(EagerBackend::cuda(backend)))
330    }
331
332    /// Create a shared CUDA eager context with explicit AD extension rules.
333    ///
334    /// # Examples
335    ///
336    /// ```rust
337    /// use tenferro_ad::{AdContext, EagerRuntime};
338    /// use tenferro_gpu::CudaBackend;
339    ///
340    /// let _ctor: fn(CudaBackend, &AdContext) -> std::sync::Arc<EagerRuntime> =
341    ///     EagerRuntime::with_cuda_backend_and_ad_context;
342    /// ```
343    #[cfg(feature = "cuda")]
344    pub fn with_cuda_backend_and_ad_context(backend: CudaBackend, ad: &AdContext) -> Arc<Self> {
345        Arc::new(Self::from_backend_with_extension_rules(
346            EagerBackend::cuda(backend),
347            Some(ad.extension_rule_set()),
348        ))
349    }
350
351    /// Create a shared eager execution context from a configured WebGPU backend.
352    ///
353    /// # Examples
354    ///
355    /// ```
356    /// use tenferro_ad::EagerRuntime;
357    /// use tenferro_gpu::WebGpuBackend;
358    ///
359    /// let _ctor: fn(WebGpuBackend) -> std::sync::Arc<EagerRuntime> =
360    ///     EagerRuntime::with_webgpu_backend;
361    /// ```
362    #[cfg(feature = "webgpu")]
363    pub fn with_webgpu_backend(backend: WebGpuBackend) -> Arc<Self> {
364        Arc::new(Self::from_backend(EagerBackend::webgpu(backend)))
365    }
366
367    /// Create a shared WebGPU eager context with explicit AD extension rules.
368    ///
369    /// # Examples
370    ///
371    /// ```rust
372    /// use tenferro_ad::{AdContext, EagerRuntime};
373    /// use tenferro_gpu::WebGpuBackend;
374    ///
375    /// let _ctor: fn(WebGpuBackend, &AdContext) -> std::sync::Arc<EagerRuntime> =
376    ///     EagerRuntime::with_webgpu_backend_and_ad_context;
377    /// ```
378    #[cfg(feature = "webgpu")]
379    pub fn with_webgpu_backend_and_ad_context(backend: WebGpuBackend, ad: &AdContext) -> Arc<Self> {
380        Arc::new(Self::from_backend_with_extension_rules(
381            EagerBackend::webgpu(backend),
382            Some(ad.extension_rule_set()),
383        ))
384    }
385
386    /// Return an opaque identifier for this context.
387    ///
388    /// # Examples
389    ///
390    /// ```
391    /// use tenferro_cpu::CpuBackend;
392    /// use tenferro_ad::{EagerRuntime};
393    ///
394    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
395    /// assert_ne!(ctx.id(), EagerRuntime::with_cpu_backend(CpuBackend::new()).id());
396    /// ```
397    pub fn id(&self) -> ContextId {
398        ContextId::from_ptr(self)
399    }
400
401    /// Register one extension runtime on this eager context.
402    pub fn register_extension(
403        &self,
404        register: impl FnOnce(
405            &mut ExtensionExecutor<EagerBackend>,
406        ) -> std::result::Result<(), ExtensionRuntimeRegistryError>,
407    ) -> std::result::Result<(), ExtensionRuntimeRegistryError> {
408        let mut executor = self.extension_executor.lock().map_err(|_| {
409            ExtensionRuntimeRegistryError::PoisonedLock {
410                name: "extension executor lock",
411            }
412        })?;
413        register(&mut executor)
414    }
415
416    /// Clear generic extension runtime cache entries.
417    ///
418    /// # Examples
419    ///
420    /// ```
421    /// use tenferro_cpu::CpuBackend;
422    /// use tenferro_ad::{EagerRuntime};
423    ///
424    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
425    /// ctx.clear_extension_caches()?;
426    /// assert_eq!(ctx.cache_stats()?.extensions.entries, 0);
427    /// # Ok::<(), tenferro_ad::Error>(())
428    /// ```
429    pub fn clear_extension_caches(&self) -> Result<()> {
430        self.lock_extension_executor()?.clear_caches();
431        Ok(())
432    }
433
434    /// Clear every cache owned by this eager context.
435    ///
436    /// # Examples
437    ///
438    /// ```
439    /// use tenferro_cpu::CpuBackend;
440    /// use tenferro_ad::{EagerRuntime};
441    ///
442    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
443    /// ctx.clear_caches()?;
444    /// assert_eq!(ctx.cache_stats()?.extensions.entries, 0);
445    /// # Ok::<(), tenferro_ad::Error>(())
446    /// ```
447    pub fn clear_caches(&self) -> Result<()> {
448        self.clear_extension_caches()
449    }
450
451    /// Return eager runtime cache-entry and retained-byte stats.
452    ///
453    /// # Examples
454    ///
455    /// ```
456    /// use tenferro_cpu::CpuBackend;
457    /// use tenferro_ad::{EagerRuntime};
458    ///
459    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
460    /// let stats = ctx.cache_stats()?;
461    /// assert_eq!(stats.extensions.entries, 0);
462    /// # Ok::<(), tenferro_ad::Error>(())
463    /// ```
464    pub fn cache_stats(&self) -> Result<EagerRuntimeCacheStats> {
465        Ok(EagerRuntimeCacheStats {
466            extensions: self.lock_extension_executor()?.cache_stats(),
467        })
468    }
469
470    /// Return the extension cache retention limits.
471    pub fn extension_cache_limits(&self) -> Result<ExtensionCacheLimits> {
472        Ok(self.lock_extension_executor()?.cache_limits())
473    }
474
475    /// Replace extension cache retention limits.
476    pub fn set_extension_cache_limits(&self, limits: ExtensionCacheLimits) -> Result<()> {
477        self.lock_extension_executor()?.set_cache_limits(limits);
478        Ok(())
479    }
480
481    /// Mutably borrow generic extension runtime cache storage.
482    ///
483    /// This hook is for standard extension crates that need cache entries
484    /// owned by an eager runtime while preserving eager value semantics outside
485    /// a registered extension execution boundary.
486    ///
487    /// # Examples
488    ///
489    /// ```
490    /// use tenferro_ad::EagerRuntime;
491    /// use tenferro_cpu::CpuBackend;
492    /// use tenferro_runtime::ExtensionCacheKey;
493    ///
494    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
495    /// let key = ExtensionCacheKey::new("example.cache.v1", "plans", 1);
496    ///
497    /// ctx.with_extension_caches_mut(|caches| {
498    ///     caches.put(key, 7_usize, std::mem::size_of::<usize>());
499    /// });
500    ///
501    /// assert_eq!(ctx.cache_stats().unwrap().extensions.entries, 1);
502    /// ```
503    pub fn with_extension_caches_mut<R>(
504        &self,
505        f: impl FnOnce(&mut ExtensionCacheStore) -> R,
506    ) -> Result<R> {
507        let mut executor = self.lock_extension_executor()?;
508        Ok(f(executor.caches_mut()))
509    }
510
511    /// Mutably borrow this runtime's backend.
512    ///
513    /// This hook lets standard extension crates run a whole contraction program
514    /// in a single backend session (instead of one eager op per step) while
515    /// preserving eager value semantics for untracked tensors.
516    ///
517    /// # Examples
518    ///
519    /// ```
520    /// use tenferro_ad::EagerRuntime;
521    /// use tenferro_cpu::CpuBackend;
522    ///
523    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
524    /// // The closure receives `&mut EagerBackend`; standard extension crates
525    /// // use it to open one backend session for a whole contraction program.
526    /// let answer = ctx.with_backend_mut(|_backend| 42).unwrap();
527    /// assert_eq!(answer, 42);
528    /// ```
529    pub fn with_backend_mut<R>(&self, f: impl FnOnce(&mut EagerBackend) -> R) -> Result<R> {
530        let mut backend = self.lock_backend()?;
531        Ok(f(&mut backend))
532    }
533
534    /// Block the current thread until backend work submitted by this eager runtime completes.
535    ///
536    /// CPU runtimes return immediately. CUDA and WebGPU runtimes synchronize
537    /// their current backend work queue.
538    ///
539    /// # Examples
540    ///
541    /// ```
542    /// use tenferro_cpu::CpuBackend;
543    /// use tenferro_ad::EagerRuntime;
544    ///
545    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
546    /// ctx.synchronize().unwrap();
547    /// ```
548    pub fn synchronize(&self) -> Result<()> {
549        self.lock_backend()?.synchronize().map_err(Error::from)
550    }
551
552    pub(crate) fn exec_outputs(&self, op: &StdTensorOp, inputs: &[&Tensor]) -> Result<Vec<Tensor>> {
553        let mut backend =
554            profile_eager_op_section("exec_outputs.lock_backend", || self.lock_backend())?;
555        let mut extension_executor =
556            profile_eager_op_section("exec_outputs.lock_extensions", || {
557                self.lock_extension_executor()
558            })?;
559        profile_eager_op_section("exec_outputs.exec_op", || {
560            exec_op_on_tensors_with_extension_executor(
561                op,
562                inputs,
563                &mut *backend,
564                Some(&mut *extension_executor),
565            )
566        })
567    }
568
569    pub(crate) fn exec_outputs_read(
570        &self,
571        op: &StdTensorOp,
572        inputs: &[TensorRead<'_>],
573    ) -> Result<Vec<Tensor>> {
574        let mut backend =
575            profile_eager_op_section("exec_outputs_read.lock_backend", || self.lock_backend())?;
576        let mut extension_executor =
577            profile_eager_op_section("exec_outputs_read.lock_extensions", || {
578                self.lock_extension_executor()
579            })?;
580        profile_eager_op_section("exec_outputs_read.exec_op", || {
581            exec_op_on_tensor_reads_with_extension_executor(
582                op,
583                inputs,
584                &mut *backend,
585                Some(&mut *extension_executor),
586            )
587        })
588    }
589
590    #[cfg(test)]
591    pub(crate) fn exec_standard_graph_outputs(
592        &self,
593        graph: &Graph<StdTensorOp>,
594        initial_data: &HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
595    ) -> Result<EagerGraphExecution> {
596        let mut backend =
597            profile_eager_op_section("exec_graph.lock_backend", || self.lock_backend())?;
598        let mut all_values = initial_data.clone();
599
600        profile_eager_op_section("exec_graph.with_backend_session", || {
601            backend.with_backend_session(|exec| -> Result<()> {
602                for op_node in graph.operations() {
603                    let outputs = {
604                        let input_values = op_node
605                            .inputs
606                            .iter()
607                            .map(|input| {
608                                let key = match input {
609                                    ValueRef::Local(local_id) => &graph.values()[*local_id].key,
610                                    ValueRef::External(key) => key,
611                                };
612                                all_values.get(key).cloned().ok_or_else(|| {
613                                    Error::Internal(format!(
614                                        "standard graph eager execution missing value for {key:?}"
615                                    ))
616                                })
617                            })
618                            .collect::<Result<Vec<_>>>()?;
619                        let input_reads = input_values
620                            .iter()
621                            .map(|value| TensorRead::from_tensor(value.as_ref()))
622                            .collect::<Vec<_>>();
623                        exec_standard_op_on_tensor_reads_in_session(
624                            &op_node.operation,
625                            &input_reads,
626                            exec,
627                        )?
628                    };
629
630                    if outputs.len() != op_node.outputs.len() {
631                        return Err(Error::Internal(format!(
632                            "standard graph eager execution expected {} outputs for {:?}, got {}",
633                            op_node.outputs.len(),
634                            op_node.operation,
635                            outputs.len()
636                        )));
637                    }
638
639                    for (output_id, output) in op_node.outputs.iter().zip(outputs) {
640                        let key = graph.values()[*output_id].key.clone();
641                        all_values.insert(key, Arc::new(output));
642                    }
643                }
644                Ok(())
645            })
646        })?;
647
648        let outputs = graph
649            .outputs()
650            .iter()
651            .map(|&output_id| {
652                let key = &graph.values()[output_id].key;
653                all_values.get(key).cloned().ok_or_else(|| {
654                    Error::Internal(format!(
655                        "standard graph eager execution missing graph output {key:?}"
656                    ))
657                })
658            })
659            .collect::<Result<Vec<_>>>()?;
660
661        Ok(EagerGraphExecution {
662            outputs,
663            retained_values: all_values,
664        })
665    }
666
667    pub(crate) fn try_register_grad_slot(
668        &self,
669        key: &ValueKey<StdTensorOp>,
670        slot: &GradSlot,
671    ) -> Result<()> {
672        self.lock_grad_slots()?
673            .insert(key.clone(), Arc::downgrade(slot));
674        Ok(())
675    }
676
677    /// Clear all live gradient slots tracked by this context.
678    ///
679    /// This resets the stored gradients to `None` without unregistering the
680    /// tensors, so future `backward()` calls can accumulate again.
681    ///
682    /// # Examples
683    ///
684    /// ```
685    /// use tenferro_cpu::CpuBackend;
686    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
687    ///
688    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
689    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx.clone()).unwrap();
690    /// let y = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![3], vec![4.0_f64, 5.0, 6.0]).unwrap(), ctx.clone()).unwrap();
691    /// let loss = x.mul(&y).unwrap().reduce_sum(&[0]).unwrap();
692    /// let _ = loss.backward().unwrap();
693    ///
694    /// ctx.clear_grads()?;
695    ///
696    /// assert!(x.grad()?.is_none());
697    /// assert!(y.grad()?.is_none());
698    /// # Ok::<(), tenferro_ad::Error>(())
699    /// ```
700    pub fn clear_grads(&self) -> Result<()> {
701        let mut poisoned_slot = false;
702        self.lock_grad_slots()?.retain(|_, slot| {
703            if let Some(slot) = slot.upgrade() {
704                match slot.lock() {
705                    Ok(mut current) => {
706                        *current = None;
707                    }
708                    Err(_) => {
709                        poisoned_slot = true;
710                    }
711                }
712                true
713            } else {
714                false
715            }
716        });
717        if poisoned_slot {
718            return Err(Error::Internal("gradient slot lock poisoned".to_string()));
719        }
720        Ok(())
721    }
722
723    /// Import a concrete tensor into this context as an untracked constant.
724    ///
725    /// The returned tensor does not participate in gradient tracking.
726    /// Use this for fixed masks, quadrature weights, physical constants,
727    /// and other data that should not receive gradients.
728    ///
729    /// # Examples
730    ///
731    /// ```
732    /// use tenferro_cpu::CpuBackend;
733    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
734    ///
735    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
736    /// let c = ctx.constant_from(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap())?;
737    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx)?;
738    /// let z = x.add(&c).unwrap();
739    ///
740    /// assert_eq!(z.materialized()?.as_slice::<f64>().unwrap(), &[4.0, 6.0]);
741    /// # Ok::<(), tenferro_ad::Error>(())
742    /// ```
743    pub fn constant_from(self: &Arc<Self>, tensor: Tensor) -> Result<EagerTensor> {
744        EagerTensor::new_leaf(Arc::clone(self), tensor, false)
745    }
746
747    /// Import a concrete tensor into this context as a trainable variable.
748    ///
749    /// The returned tensor participates in gradient tracking; its gradient
750    /// slot is registered in this context.
751    ///
752    /// # Examples
753    ///
754    /// ```
755    /// use tenferro_cpu::CpuBackend;
756    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
757    ///
758    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
759    /// let p = ctx.variable_from(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap())?;
760    /// let loss = p.exp().unwrap().reduce_sum(&[0]).unwrap();
761    /// let _ = loss.backward().unwrap();
762    ///
763    /// let grad = p.grad().unwrap().unwrap();
764    /// assert_eq!(grad.shape(), &[2]);
765    /// # Ok::<(), tenferro_ad::Error>(())
766    /// ```
767    pub fn variable_from(self: &Arc<Self>, tensor: Tensor) -> Result<EagerTensor> {
768        EagerTensor::new_leaf(Arc::clone(self), tensor, true)
769    }
770
771    fn store_grads(
772        &self,
773        cotangents: &HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
774        backend: &mut EagerBackend,
775    ) -> Result<()> {
776        let mut updates = Vec::new();
777
778        {
779            let mut slots = self.lock_grad_slots()?;
780            slots.retain(|key, slot| {
781                let Some(slot) = slot.upgrade() else {
782                    return false;
783                };
784
785                if let Some(incoming) = cotangents.get(key) {
786                    updates.push((slot, Arc::clone(incoming)));
787                }
788
789                true
790            });
791        }
792
793        for (slot, incoming) in updates {
794            let mut current = slot
795                .lock()
796                .map_err(|_| Error::Internal("gradient slot lock poisoned".to_string()))?;
797            let next = match current.as_ref() {
798                Some(existing) => Arc::new(backend.add(existing.as_ref(), incoming.as_ref())?),
799                None => incoming,
800            };
801            *current = Some(next);
802        }
803
804        Ok(())
805    }
806}
807
/// Eager tensor with reverse-mode autodiff over concrete tensor values.
///
/// This executes each primitive immediately and records a lightweight reverse
/// DAG for `backward()`. Gradients accumulate across repeated `backward()`
/// calls until they are cleared explicitly.
///
/// # Examples
///
/// ```
/// use tenferro_cpu::CpuBackend;
/// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
///
/// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
/// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx)?;
/// let loss = x.mul(&x).unwrap().reduce_sum(&[0]).unwrap();
/// let _cotangents = loss.backward().unwrap();
/// let loss = x.mul(&x).unwrap().reduce_sum(&[0]).unwrap();
/// let _cotangents = loss.backward().unwrap();
///
/// assert_eq!(x.grad().unwrap().unwrap().as_slice::<f64>().unwrap(), &[4.0, 8.0, 12.0]);
/// x.clear_grad();
///
/// assert!(x.grad().unwrap().is_none());
/// # Ok::<(), tenferro_ad::Error>(())
/// ```
#[derive(Clone)]
pub struct EagerTensor {
    /// Value held by this tensor; shared between clones.
    pub(crate) value: Arc<TensorValue>,
    /// Write-once cache holding a `Tensor`; shared between clones.
    /// NOTE(review): presumably the materialized form of `value` — confirm.
    materialized_cache: Arc<OnceLock<Arc<Tensor>>>,
    /// Value key identifying this tensor; gradient slots are registered under it.
    pub(crate) key: ValueKey<StdTensorOp>,
    /// Recorded trace for this tensor, if it has one.
    pub(crate) trace: Option<Trace<StdTensorOp>>,
    /// Whether this tensor participates in gradient tracking.
    pub(crate) requires_grad: bool,
    /// Storage for this tensor's gradient; shared between clones.
    grad_slot: GradSlot,
    /// Metadata scopes associated with this tensor.
    pub(crate) metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
    /// Runtime this tensor was created in.
    pub(crate) ctx: Arc<EagerRuntime>,
}
844
845impl fmt::Debug for EagerTensor {
846    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
847        f.debug_struct("EagerTensor")
848            .field("dtype", &self.dtype())
849            .field("shape", &self.shape())
850            .field("key", &self.key)
851            .field("requires_grad", &self.requires_grad)
852            .field("has_trace", &self.trace.is_some())
853            .field("ctx_id", &self.ctx_id())
854            .finish_non_exhaustive()
855    }
856}
857
858impl EagerTensor {
859    /// Create an untracked eager tensor inside an existing eager context.
860    ///
861    /// # Examples
862    ///
863    /// ```
864    /// use tenferro_cpu::CpuBackend;
865    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
866    ///
867    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
868    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx)?;
869    ///
870    /// assert_eq!(x.materialized()?.as_slice::<f64>().unwrap(), &[1.0, 2.0]);
871    /// # Ok::<(), tenferro_ad::Error>(())
872    /// ```
873    pub fn from_tensor_in(tensor: Tensor, ctx: Arc<EagerRuntime>) -> Result<Self> {
874        Self::new_leaf(ctx, tensor, false)
875    }
876
877    /// Create a tracked eager leaf inside an existing eager context.
878    ///
879    /// # Examples
880    ///
881    /// ```
882    /// use tenferro_cpu::CpuBackend;
883    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
884    ///
885    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
886    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx)?;
887    ///
888    /// assert!(x.grad().unwrap().is_none());
889    /// # Ok::<(), tenferro_ad::Error>(())
890    /// ```
891    pub fn requires_grad_in(tensor: Tensor, ctx: Arc<EagerRuntime>) -> Result<Self> {
892        Self::new_leaf(ctx, tensor, true)
893    }
894
895    pub(crate) fn new_leaf(
896        ctx: Arc<EagerRuntime>,
897        tensor: Tensor,
898        requires_grad: bool,
899    ) -> Result<Self> {
900        let key = eager_val_key();
901        let metadata_scope =
902            register_scoped_value_metadata(key.clone(), tensor_meta_from_tensor(&tensor)).map_err(
903                |err| Error::Internal(format!("eager leaf metadata registration failed: {err}")),
904            )?;
905        let tensor = Arc::new(tensor);
906        let grad_slot = Arc::new(Mutex::new(None));
907        if requires_grad {
908            ctx.try_register_grad_slot(&key, &grad_slot)?;
909        }
910
911        Ok(Self {
912            value: Arc::new(TensorValue::from_tensor_arc(tensor)),
913            materialized_cache: Arc::new(OnceLock::new()),
914            key,
915            trace: None,
916            requires_grad,
917            grad_slot,
918            metadata_scopes: metadata_scopes_for_scope(metadata_scope),
919            ctx,
920        })
921    }
922
923    pub(crate) fn new_result(
924        ctx: Arc<EagerRuntime>,
925        key: ValueKey<StdTensorOp>,
926        tensor: Tensor,
927        requires_grad: bool,
928        trace: Option<Trace<StdTensorOp>>,
929        metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
930    ) -> Result<Self> {
931        Self::new_result_arc(
932            ctx,
933            key,
934            Arc::new(tensor),
935            requires_grad,
936            trace,
937            metadata_scopes,
938        )
939    }
940
941    pub(crate) fn new_result_arc(
942        ctx: Arc<EagerRuntime>,
943        key: ValueKey<StdTensorOp>,
944        tensor: Arc<Tensor>,
945        requires_grad: bool,
946        trace: Option<Trace<StdTensorOp>>,
947        metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
948    ) -> Result<Self> {
949        let grad_slot = Arc::new(Mutex::new(None));
950        if requires_grad {
951            ctx.try_register_grad_slot(&key, &grad_slot)?;
952        }
953
954        Ok(Self {
955            value: Arc::new(TensorValue::from_tensor_arc(tensor)),
956            materialized_cache: Arc::new(OnceLock::new()),
957            key,
958            trace,
959            requires_grad,
960            grad_slot,
961            metadata_scopes,
962            ctx,
963        })
964    }
965
966    pub(crate) fn new_result_value(
967        ctx: Arc<EagerRuntime>,
968        key: ValueKey<StdTensorOp>,
969        value: TensorValue,
970        requires_grad: bool,
971        trace: Option<Trace<StdTensorOp>>,
972        metadata_scopes: Vec<Arc<GlobalMetadataScope>>,
973    ) -> Result<Self> {
974        let grad_slot = Arc::new(Mutex::new(None));
975        if requires_grad {
976            ctx.try_register_grad_slot(&key, &grad_slot)?;
977        }
978
979        Ok(Self {
980            value: Arc::new(value),
981            materialized_cache: Arc::new(OnceLock::new()),
982            key,
983            trace,
984            requires_grad,
985            grad_slot,
986            metadata_scopes,
987            ctx,
988        })
989    }
990
991    pub(crate) fn new_untracked_result(ctx: Arc<EagerRuntime>, tensor: Tensor) -> Result<Self> {
992        Self::new_result(ctx, eager_val_key(), tensor, false, None, Vec::new())
993    }
994
995    pub(crate) fn new_untracked_value_result(ctx: Arc<EagerRuntime>, value: TensorValue) -> Self {
996        Self {
997            value: Arc::new(value),
998            materialized_cache: Arc::new(OnceLock::new()),
999            key: eager_val_key(),
1000            trace: None,
1001            requires_grad: false,
1002            grad_slot: Arc::new(Mutex::new(None)),
1003            metadata_scopes: Vec::new(),
1004            ctx,
1005        }
1006    }
1007
1008    /// Detach this tensor from the reverse graph.
1009    ///
1010    /// The returned tensor keeps the concrete value but no longer contributes
1011    /// gradients to the original graph.
1012    ///
1013    /// # Examples
1014    ///
1015    /// ```
1016    /// use tenferro_cpu::CpuBackend;
1017    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1018    ///
1019    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1020    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx)?;
1021    /// let y = x.detach();
1022    ///
1023    /// assert_eq!(y.materialized()?.as_slice::<f64>().unwrap(), &[1.0, 2.0]);
1024    /// assert!(y.grad().unwrap().is_none());
1025    /// # Ok::<(), tenferro_ad::Error>(())
1026    /// ```
1027    pub fn detach(&self) -> Self {
1028        Self::new_untracked_value_result(self.ctx.clone(), self.value.as_ref().clone())
1029    }
1030
1031    /// Detach this tensor from its graph and re-register it in a different
1032    /// context as an untracked leaf.
1033    ///
1034    /// # Examples
1035    ///
1036    /// ```
1037    /// use tenferro_cpu::CpuBackend;
1038    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1039    ///
1040    /// let ctx_a = EagerRuntime::with_cpu_backend(CpuBackend::new());
1041    /// let ctx_b = EagerRuntime::with_cpu_backend(CpuBackend::new());
1042    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx_a)?;
1043    /// let d = x.detach_into(&ctx_b)?;
1044    ///
1045    /// assert!(!d.tracks_grad());
1046    /// assert_eq!(d.ctx_id(), ctx_b.id());
1047    /// # Ok::<(), tenferro_ad::Error>(())
1048    /// ```
1049    pub fn detach_into(&self, ctx: &Arc<EagerRuntime>) -> Result<Self> {
1050        Self::from_tensor_in(self.to_tensor()?, Arc::clone(ctx))
1051    }
1052
1053    /// Materialize and share the concrete tensor value.
1054    ///
1055    /// # Examples
1056    ///
1057    /// ```
1058    /// use tenferro_cpu::CpuBackend;
1059    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1060    ///
1061    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1062    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap(), ctx)?;
1063    /// assert_eq!(x.materialized()?.as_slice::<f64>().unwrap(), &[3.0]);
1064    /// # Ok::<(), tenferro_ad::Error>(())
1065    /// ```
1066    pub fn materialized(&self) -> Result<Arc<Tensor>> {
1067        self.materialized_arc()
1068    }
1069
1070    /// Return this tensor's scalar dtype without materializing through
1071    /// [`materialized`](Self::materialized).
1072    pub fn dtype(&self) -> DType {
1073        self.value.dtype()
1074    }
1075
1076    /// Return this tensor's logical shape without materializing through
1077    /// [`materialized`](Self::materialized).
1078    pub fn shape(&self) -> &[usize] {
1079        self.value.shape()
1080    }
1081
1082    /// Borrow this tensor value as a [`TensorRead`].
1083    ///
1084    /// This is the preferred borrowed input boundary for executor calls. It
1085    /// preserves the option to replace eager storage with non-contiguous views
1086    /// without forcing callers through [`materialized`](Self::materialized).
1087    pub fn tensor_read(&self) -> TensorRead<'_> {
1088        self.value.tensor_read()
1089    }
1090
1091    /// Materialize this eager tensor as an owned [`Tensor`].
1092    ///
1093    /// This is the owned materialization boundary for callers that need a
1094    /// standalone compact tensor. The operation is fallible because eager
1095    /// values may be backed by lazy or backend-resident storage.
1096    pub fn to_tensor(&self) -> Result<Tensor> {
1097        self.value.to_tensor().map_err(Error::from)
1098    }
1099
1100    pub(crate) fn materialized_arc(&self) -> Result<Arc<Tensor>> {
1101        if let Some(tensor) = self.value.as_tensor_arc() {
1102            return Ok(Arc::clone(tensor));
1103        }
1104        if let Some(tensor) = self.materialized_cache.get() {
1105            return Ok(Arc::clone(tensor));
1106        }
1107
1108        let materialized = Arc::new(self.value.to_tensor().map_err(Error::from)?);
1109        let _ = self.materialized_cache.set(Arc::clone(&materialized));
1110        Ok(self
1111            .materialized_cache
1112            .get()
1113            .map(Arc::clone)
1114            .unwrap_or(materialized))
1115    }
1116
1117    #[cfg(test)]
1118    pub(crate) fn materialized_cache_is_initialized(&self) -> bool {
1119        self.materialized_cache.get().is_some()
1120    }
1121
1122    /// Return the accumulated gradient currently stored for this tensor.
1123    ///
1124    /// The stored gradient accumulates across repeated `backward()` calls
1125    /// until it is cleared explicitly.
1126    ///
1127    /// For complex scalar losses, stored gradients use tenferro's
1128    /// Hermitian-adjoint cotangent convention. See
1129    /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
1130    ///
1131    /// # Examples
1132    ///
1133    /// ```
1134    /// use tenferro_cpu::CpuBackend;
1135    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1136    ///
1137    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1138    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx).unwrap();
1139    /// let loss = x.exp().unwrap().reduce_sum(&[0]).unwrap();
1140    /// let _cotangents = loss.backward().unwrap();
1141    ///
1142    /// let grad = x.grad()?.unwrap();
1143    /// assert_eq!(grad.shape(), &[2]);
1144    /// # Ok::<(), tenferro_ad::Error>(())
1145    /// ```
1146    pub fn grad(&self) -> Result<Option<Arc<Tensor>>> {
1147        self.grad_slot
1148            .lock()
1149            .map_err(|_| Error::Internal("gradient slot lock poisoned".to_string()))
1150            .map(|slot| slot.clone())
1151    }
1152
1153    /// Clear the accumulated gradient stored for this tensor.
1154    ///
1155    /// This only affects this tensor's gradient slot. Other tensors in the
1156    /// same context retain their gradients until they are cleared explicitly or
1157    /// overwritten by later accumulation.
1158    ///
1159    /// # Examples
1160    ///
1161    /// ```
1162    /// use tenferro_cpu::CpuBackend;
1163    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1164    ///
1165    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1166    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx.clone()).unwrap();
1167    /// let y = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![3], vec![4.0_f64, 5.0, 6.0]).unwrap(), ctx).unwrap();
1168    /// let loss = x.mul(&y).unwrap().reduce_sum(&[0]).unwrap();
1169    /// let _ = loss.backward().unwrap();
1170    ///
1171    /// x.clear_grad()?;
1172    ///
1173    /// assert!(x.grad()?.is_none());
1174    /// assert!(y.grad()?.is_some());
1175    /// # Ok::<(), tenferro_ad::Error>(())
1176    /// ```
1177    pub fn clear_grad(&self) -> Result<()> {
1178        *self
1179            .grad_slot
1180            .lock()
1181            .map_err(|_| Error::Internal("gradient slot lock poisoned".to_string()))? = None;
1182        Ok(())
1183    }
1184
1185    /// Report whether this tensor participates in gradient tracking.
1186    ///
1187    /// Tracked tensors keep a gradient slot in their eager context; untracked
1188    /// tensors and detached tensors do not.
1189    ///
1190    /// # Examples
1191    ///
1192    /// ```
1193    /// use tenferro_cpu::CpuBackend;
1194    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1195    ///
1196    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1197    /// let plain = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
1198    /// let tracked = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
1199    /// let detached = tracked.detach();
1200    ///
1201    /// assert!(!plain.tracks_grad());
1202    /// assert!(tracked.tracks_grad());
1203    /// assert!(!detached.tracks_grad());
1204    /// ```
1205    pub fn tracks_grad(&self) -> bool {
1206        self.requires_grad
1207    }
1208
1209    #[cfg(test)]
1210    fn debug_trace_saved_value_count(&self) -> Option<usize> {
1211        self.trace.as_ref().map(|trace| trace.saved_values().len())
1212    }
1213
1214    /// Return the opaque identifier of the context this tensor belongs to.
1215    ///
1216    /// # Examples
1217    ///
1218    /// ```
1219    /// use tenferro_cpu::CpuBackend;
1220    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1221    ///
1222    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1223    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap(), ctx.clone()).unwrap();
1224    ///
1225    /// assert_eq!(x.ctx_id(), ctx.id());
1226    /// ```
1227    pub fn ctx_id(&self) -> ContextId {
1228        self.ctx.id()
1229    }
1230
1231    /// Borrow the eager runtime context that owns this tensor.
1232    pub fn runtime(&self) -> &Arc<EagerRuntime> {
1233        &self.ctx
1234    }
1235
1236    /// Check whether two tensors belong to the same eager context.
1237    ///
1238    /// # Examples
1239    ///
1240    /// ```
1241    /// use tenferro_cpu::CpuBackend;
1242    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1243    ///
1244    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1245    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![1.0_f64]).unwrap(), ctx.clone()).unwrap();
1246    /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap(), ctx).unwrap();
1247    ///
1248    /// assert!(x.same_context(&y));
1249    /// ```
1250    pub fn same_context(&self, other: &Self) -> bool {
1251        self.ctx_id() == other.ctx_id()
1252    }
1253
1254    #[cfg(test)]
1255    pub(crate) fn standard_graph_op(
1256        inputs: &[&Self],
1257        build_graph: impl FnOnce(&[TensorInputKey]) -> Result<Arc<Graph<StdTensorOp>>>,
1258    ) -> Result<Vec<Self>> {
1259        let Some(first) = inputs.first() else {
1260            return Err(Error::Internal(
1261                "standard eager graph op requires at least one input tensor".to_string(),
1262            ));
1263        };
1264        let ctx = Arc::clone(&first.ctx);
1265        for tensor in inputs.iter().skip(1) {
1266            if !first.same_context(tensor) {
1267                return Err(Error::ContextMismatch {
1268                    lhs: first.ctx_id(),
1269                    rhs: tensor.ctx_id(),
1270                });
1271            }
1272        }
1273
1274        let mut recorder = Recorder::new(EagerTensorKeySource);
1275        let graph_input_keys = recorder.fresh_input_keys::<StdTensorOp>(inputs.len());
1276        let graph = build_graph(&graph_input_keys)?;
1277        let initial_data = graph_input_keys
1278            .iter()
1279            .zip(inputs.iter())
1280            .map(|(key, tensor)| Ok((ValueKey::Input(key.clone()), tensor.materialized_arc()?)))
1281            .collect::<Result<HashMap<_, _>>>()?;
1282        let execution = ctx.exec_standard_graph_outputs(graph.as_ref(), &initial_data)?;
1283        if execution.outputs.len() != graph.outputs().len() {
1284            return Err(Error::Internal(format!(
1285                "standard eager graph op expected {} graph outputs, got {}",
1286                graph.outputs().len(),
1287                execution.outputs.len()
1288            )));
1289        }
1290
1291        if !inputs.iter().any(|input| input.requires_grad) {
1292            return execution
1293                .outputs
1294                .into_iter()
1295                .map(|output| {
1296                    Self::new_result_arc(
1297                        Arc::clone(&ctx),
1298                        eager_val_key(),
1299                        output,
1300                        false,
1301                        None,
1302                        Vec::new(),
1303                    )
1304                })
1305                .collect();
1306        }
1307
1308        let output_keys = graph
1309            .outputs()
1310            .iter()
1311            .map(|&output_id| graph.values()[output_id].key.clone())
1312            .collect();
1313        let recorded_graph = RecordedGraph::new(Arc::clone(&graph), graph_input_keys, output_keys)
1314            .map_err(eager_record_error)?;
1315        let recorded = record_eager_recorded_graph_outputs(
1316            &mut recorder,
1317            recorded_graph,
1318            &execution.outputs,
1319            execution.retained_values,
1320            inputs,
1321        )?;
1322        if recorded.traces.len() != execution.outputs.len() {
1323            return Err(Error::Internal(format!(
1324                "standard eager graph op expected {} eager traces, got {}",
1325                execution.outputs.len(),
1326                recorded.traces.len()
1327            )));
1328        }
1329
1330        let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
1331        for input in inputs {
1332            for scope in &input.metadata_scopes {
1333                push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
1334            }
1335        }
1336
1337        recorded
1338            .traces
1339            .into_iter()
1340            .zip(execution.outputs)
1341            .map(|(trace, output)| {
1342                Self::new_result_arc(
1343                    Arc::clone(&ctx),
1344                    trace.key,
1345                    output,
1346                    trace.requires_grad,
1347                    trace.trace,
1348                    metadata_scopes.clone(),
1349                )
1350            })
1351            .collect()
1352    }
1353
1354    /// Run reverse-mode AD from this scalar output.
1355    ///
1356    /// Returns the full cotangent map produced by the reverse pass and also
1357    /// accumulates into `grad()` for tracked eager tensors reachable from this
1358    /// output.
1359    ///
1360    /// For complex scalar outputs, cotangents use tenferro's Hermitian
1361    /// real-inner-product convention. See
1362    /// <https://tensor4all.org/tenferro-rs/guides/complex-ad.html>.
1363    ///
1364    /// # Examples
1365    ///
1366    /// ```
1367    /// use tenferro_cpu::CpuBackend;
1368    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
1369    ///
1370    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
1371    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx).unwrap();
1372    /// let loss = x.add(&x).unwrap().reduce_sum(&[0]).unwrap();
1373    /// let _cotangents = loss.backward().unwrap();
1374    /// let loss = x.add(&x).unwrap().reduce_sum(&[0]).unwrap();
1375    /// let _cotangents = loss.backward().unwrap();
1376    ///
1377    /// assert_eq!(x.grad().unwrap().unwrap().as_slice::<f64>().unwrap(), &[4.0, 4.0, 4.0]);
1378    /// ```
1379    pub fn backward(&self) -> Result<HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>> {
1380        if !self.shape().is_empty() {
1381            return Err(Error::NonScalarGrad {
1382                shape: self.shape().to_vec(),
1383            });
1384        }
1385
1386        let value = self.materialized_arc()?;
1387        let mut backend = self.ctx.lock_backend()?;
1388        let mut extension_executor = self.ctx.lock_extension_executor()?;
1389        let seed = Arc::new(one_like_tensor(value.as_ref(), &mut *backend)?);
1390        let mut callbacks = TenferroBackwardCallbacks::new(
1391            &mut *backend,
1392            Some(&mut *extension_executor),
1393            self.metadata_scopes.clone(),
1394        );
1395        let mut ad_ctx = ShapeGuardContext::with_global_metadata();
1396        if let Some(extension_rules) = &self.ctx.extension_rules {
1397            ad_ctx = ad_ctx.with_extension_rules(extension_rules.clone());
1398        }
1399        let cotangents_result = eager::backward(
1400            &self.key,
1401            self.trace.as_ref(),
1402            seed,
1403            &mut callbacks,
1404            &mut ad_ctx,
1405        );
1406        let callback_error = callbacks.take_error();
1407        drop(callbacks);
1408        let cotangents = match (cotangents_result, callback_error) {
1409            (_, Some(err)) => return Err(Error::Internal(err.to_string())),
1410            (Err(err), None) => return Err(Error::Internal(err.to_string())),
1411            (Ok(cotangents), None) => cotangents,
1412        };
1413        self.ctx.store_grads(&cotangents, &mut backend)?;
1414        Ok(cotangents)
1415    }
1416}
1417
1418pub(crate) fn eager_val_key() -> ValueKey<StdTensorOp> {
1419    ValueKey::Input(next_input_key())
1420}
1421
/// Stateless [`KeySource`] that hands out fresh input keys for eager
/// recordings.
pub(crate) struct EagerTensorKeySource;
1423
1424impl KeySource<StdTensorOp> for EagerTensorKeySource {
1425    fn fresh_input_key(&mut self) -> TensorInputKey {
1426        next_input_key()
1427    }
1428}
1429
1430pub(crate) fn eager_value(tensor: &EagerTensor) -> Result<EagerInput<StdTensorOp>> {
1431    Ok(EagerInput {
1432        key: tensor.key.clone(),
1433        trace: tensor.trace.clone(),
1434        requires_grad: tensor.requires_grad,
1435        data: tensor.materialized_arc()?,
1436    })
1437}
1438
1439pub(crate) struct RecordedEagerOutputs {
1440    pub(crate) traces: Vec<EagerOutput<StdTensorOp>>,
1441    pub(crate) metadata_scope: Arc<GlobalMetadataScope>,
1442}
1443
1444pub(crate) fn record_eager_outputs(
1445    op: &StdTensorOp,
1446    outputs: &[Arc<Tensor>],
1447    inputs: &[&EagerTensor],
1448) -> Result<RecordedEagerOutputs> {
1449    let mut recorder = Recorder::new(EagerTensorKeySource);
1450    let graph_input_keys = recorder.fresh_input_keys::<StdTensorOp>(inputs.len());
1451    let graph =
1452        RecordedGraph::from_primitive(op.clone(), graph_input_keys).map_err(eager_record_error)?;
1453    let retained_values = graph
1454        .output_keys()
1455        .iter()
1456        .cloned()
1457        .zip(outputs.iter().cloned())
1458        .collect();
1459    record_eager_recorded_graph_outputs(&mut recorder, graph, outputs, retained_values, inputs)
1460}
1461
1462pub(crate) fn record_eager_recorded_graph_outputs(
1463    recorder: &mut Recorder<EagerTensorKeySource>,
1464    graph: RecordedGraph<StdTensorOp>,
1465    outputs: &[Arc<Tensor>],
1466    retained_values: HashMap<ValueKey<StdTensorOp>, Arc<Tensor>>,
1467    inputs: &[&EagerTensor],
1468) -> Result<RecordedEagerOutputs> {
1469    let input_values: Vec<_> = inputs
1470        .iter()
1471        .map(|tensor| eager_value(tensor))
1472        .collect::<Result<_>>()?;
1473    let traces = recorder
1474        .record_graph(graph, &input_values, outputs, retained_values)
1475        .map_err(eager_record_error)?;
1476
1477    let mut registrations = Vec::new();
1478    for trace in &traces {
1479        if let Some(output) = outputs.get(trace.output_slot) {
1480            registrations.push((trace.key.clone(), tensor_meta_from_tensor(output.as_ref())));
1481        }
1482    }
1483
1484    if let Some(trace) = traces.iter().find_map(|output| output.trace.as_ref()) {
1485        for (key, value) in trace.saved_values() {
1486            registrations.push((key.clone(), tensor_meta_from_tensor(value.as_ref())));
1487        }
1488    }
1489
1490    Ok(RecordedEagerOutputs {
1491        traces,
1492        metadata_scope: Arc::new(register_scoped_metadata_batch(registrations)?),
1493    })
1494}
1495
1496fn eager_record_error(err: tidu::eager::EagerRecordError) -> Error {
1497    Error::Internal(format!("invalid eager recording metadata: {err}"))
1498}
1499
1500pub(crate) fn exec_single_output(
1501    op: &StdTensorOp,
1502    inputs: &[&Tensor],
1503    ctx: &EagerRuntime,
1504) -> Result<Tensor> {
1505    let mut outputs = ctx.exec_outputs(op, inputs)?;
1506    if outputs.len() != 1 {
1507        return Err(Error::Internal(format!(
1508            "expected one eager output for {:?}, got {}",
1509            op,
1510            outputs.len()
1511        )));
1512    }
1513    Ok(profile_eager_op_section(
1514        "exec_single_output.remove_output",
1515        || outputs.remove(0),
1516    ))
1517}
1518
1519pub(crate) fn exec_single_output_read(
1520    op: &StdTensorOp,
1521    inputs: &[TensorRead<'_>],
1522    ctx: &EagerRuntime,
1523) -> Result<Tensor> {
1524    let mut outputs = ctx.exec_outputs_read(op, inputs)?;
1525    if outputs.len() != 1 {
1526        return Err(Error::Internal(format!(
1527            "expected one eager output for {:?}, got {}",
1528            op,
1529            outputs.len()
1530        )));
1531    }
1532    Ok(profile_eager_op_section(
1533        "exec_single_output_read.remove_output",
1534        || outputs.remove(0),
1535    ))
1536}
1537
1538pub(crate) fn zero_like_tensor<B: TensorBackend>(
1539    input: &Tensor,
1540    backend: &mut B,
1541) -> Result<Tensor> {
1542    let host = match input {
1543        Tensor::F32(tensor) => Tensor::F32(TypedTensor::zeros(tensor.shape().to_vec())?),
1544        Tensor::F64(tensor) => Tensor::F64(TypedTensor::zeros(tensor.shape().to_vec())?),
1545        Tensor::I32(tensor) => Tensor::I32(TypedTensor::zeros(tensor.shape().to_vec())?),
1546        Tensor::I64(tensor) => Tensor::I64(TypedTensor::zeros(tensor.shape().to_vec())?),
1547        Tensor::Bool(tensor) => Tensor::Bool(TypedTensor::from_vec_col_major(
1548            tensor.shape().to_vec(),
1549            vec![false; tensor.n_elements()],
1550        )?),
1551        Tensor::C32(tensor) => Tensor::C32(TypedTensor::zeros(tensor.shape().to_vec())?),
1552        Tensor::C64(tensor) => Tensor::C64(TypedTensor::zeros(tensor.shape().to_vec())?),
1553    };
1554    backend.upload_host_tensor(&host).map_err(Error::from)
1555}
1556
1557pub(crate) fn one_like_tensor<B: TensorBackend>(input: &Tensor, backend: &mut B) -> Result<Tensor> {
1558    let zero = zero_like_tensor(input, backend)?;
1559    backend.exp(&zero).map_err(Error::from)
1560}
1561
1562#[cfg(test)]
1563mod tests;