Skip to main content

tenferro_runtime/
extension_runtime.rs

1//! Backend-parametric runtime dispatch for extension ops.
2//!
3//! This module is intentionally generic: extension crates can register an
4//! executor for a family and keep runtime cache state outside both the
5//! semantic [`ExtensionOp`] payload and the
6//! tensor backend implementation.
7
8use std::collections::HashMap;
9use std::fmt::{self, Debug};
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use tenferro_ops::ext_op::ExtensionOp;
14use tenferro_tensor::{CacheStats, Tensor, TensorBackend, TensorRead};
15
16use crate::extension_cache::{ExtensionCacheLimits, ExtensionCacheSelector, ExtensionCacheStore};
17
18/// Errors returned by backend-parametric extension runtime registries.
19#[derive(Debug, thiserror::Error)]
20pub enum ExtensionRuntimeRegistryError {
21    /// The `family_id` does not match the namespaced format
22    /// `"<crate-name>.<op-name>.v<major>"`.
23    #[error("family_id {family_id:?} does not match the namespaced format")]
24    MalformedFamilyId { family_id: &'static str },
25    /// A registry lock was poisoned by a panic in another thread.
26    #[error("{name} poisoned")]
27    PoisonedLock { name: &'static str },
28}
29
30/// Backend and cache state passed to one extension execution.
31pub struct ExtensionExecutionContext<'a, B: TensorBackend> {
32    backend: &'a mut B,
33    caches: &'a mut ExtensionCacheStore,
34}
35
36impl<B: TensorBackend> fmt::Debug for ExtensionExecutionContext<'_, B> {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        f.debug_struct("ExtensionExecutionContext")
39            .field("backend_type", &std::any::type_name::<B>())
40            .field("caches", &self.caches)
41            .finish_non_exhaustive()
42    }
43}
44
45impl<'a, B: TensorBackend> ExtensionExecutionContext<'a, B> {
46    /// Build a context from externally-owned backend and cache state.
47    pub fn new(backend: &'a mut B, caches: &'a mut ExtensionCacheStore) -> Self {
48        Self { backend, caches }
49    }
50
51    /// Borrow the backend for non-mutating inspection.
52    pub fn backend(&self) -> &B {
53        self.backend
54    }
55
56    /// Borrow the backend mutably for extension execution.
57    pub fn backend_mut(&mut self) -> &mut B {
58        self.backend
59    }
60
61    /// Borrow the extension runtime cache store.
62    pub fn caches(&self) -> &ExtensionCacheStore {
63        self.caches
64    }
65
66    /// Borrow the extension runtime cache store mutably.
67    pub fn caches_mut(&mut self) -> &mut ExtensionCacheStore {
68        self.caches
69    }
70
71    /// Execute a core-only execution program one instruction at a time.
72    ///
73    /// This is for extension runtimes that lower their own operation into a
74    /// temporary `ExecProgram` containing only core tensor ops. Nested
75    /// `ExecOp::Extension` instructions are rejected so extension dispatch
76    /// cannot bypass the owning runtime registry.
77    ///
78    /// # Examples
79    ///
80    /// ```
81    /// use tenferro_cpu::CpuBackend;
82    /// use tenferro_ops::dim_expr::DimExpr;
83    /// use tenferro_runtime::extension::{ExecInstruction, ExecOp, ExecProgram};
84    /// use tenferro_runtime::{DType, ExtensionCacheStore, ExtensionExecutionContext, Tensor};
85    ///
86    /// let program = ExecProgram {
87    ///     instructions: vec![ExecInstruction {
88    ///         op: ExecOp::Add,
89    ///         input_slots: vec![0, 1],
90    ///         output_slots: vec![2],
91    ///         dtype: DType::F64,
92    ///         output_shapes: vec![vec![]].into(),
93    ///         output_extents: vec![vec![]].into(),
94    ///         last_use: vec![true, true],
95    ///     }],
96    ///     input_slots: vec![0, 1],
97    ///     output_slots: vec![2],
98    ///     n_slots: 3,
99    /// };
100    /// let lhs = Tensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
101    /// let rhs = Tensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap();
102    ///
103    /// let mut backend = CpuBackend::new();
104    /// let mut caches = ExtensionCacheStore::new();
105    /// let mut ctx = ExtensionExecutionContext::new(&mut backend, &mut caches);
106    /// let outputs = ctx
107    ///     .execute_core_exec_program_unsegmented(&program, vec![lhs, rhs])
108    ///     .unwrap();
109    /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[3.0]);
110    /// ```
111    pub fn execute_core_exec_program_unsegmented(
112        &mut self,
113        program: &crate::extension::ExecProgram,
114        inputs: Vec<Tensor>,
115    ) -> crate::error::Result<Vec<Tensor>>
116    where
117        B: 'static,
118    {
119        crate::exec::ensure_core_exec_program(
120            program,
121            "ExtensionExecutionContext::execute_core_exec_program_unsegmented",
122        )?;
123        crate::exec::eval_exec_ir_unsegmented_with_cache(self.backend, program, inputs)
124    }
125
126    /// Borrow backend and extension cache store as disjoint mutable parts.
127    pub fn parts_mut(&mut self) -> (&mut B, &mut ExtensionCacheStore) {
128        (self.backend, self.caches)
129    }
130}
131
132/// A backend-specific runtime executor for one extension family.
133pub trait ExtensionRuntime<B: TensorBackend + 'static>: Debug + Send + Sync + 'static {
134    /// Extension family handled by this executor.
135    fn family_id(&self) -> &'static str;
136
137    /// Execute the extension op with backend and cache state supplied by core.
138    fn execute(
139        &self,
140        op: &dyn ExtensionOp,
141        inputs: &[&Tensor],
142        ctx: &mut ExtensionExecutionContext<'_, B>,
143    ) -> tenferro_tensor::Result<Vec<Tensor>>;
144
145    /// Execute the extension op on borrowed tensor reads.
146    ///
147    /// Implementations that need compact tensors must materialize inputs here
148    /// explicitly. Keeping this method required prevents implicit read-path
149    /// fallbacks from hiding backend or view handling bugs.
150    fn execute_reads(
151        &self,
152        op: &dyn ExtensionOp,
153        inputs: &[TensorRead<'_>],
154        ctx: &mut ExtensionExecutionContext<'_, B>,
155    ) -> tenferro_tensor::Result<Vec<Tensor>>;
156}
157
158fn validate_runtime_output_count(
159    op: &dyn ExtensionOp,
160    outputs: Vec<Tensor>,
161) -> tenferro_tensor::Result<Vec<Tensor>> {
162    let expected = op.output_count();
163    if outputs.len() != expected {
164        return Err(tenferro_tensor::Error::InvalidConfig {
165            op: "extension",
166            message: format!(
167                "family_id {:?}: runtime returned {} outputs but op declared {} outputs",
168                op.family_id(),
169                outputs.len(),
170                expected
171            ),
172        });
173    }
174    Ok(outputs)
175}
176
177fn validate_runtime_input_count(
178    op: &dyn ExtensionOp,
179    actual: usize,
180) -> tenferro_tensor::Result<()> {
181    let expected = op.input_count();
182    if actual != expected {
183        return Err(tenferro_tensor::Error::InvalidConfig {
184            op: "extension",
185            message: format!(
186                "family_id {:?}: op expects {} inputs, got {}",
187                op.family_id(),
188                expected,
189                actual
190            ),
191        });
192    }
193    Ok(())
194}
195
196/// Registry of backend-specific extension runtime executors.
197pub struct ExtensionRegistry<B: TensorBackend + 'static> {
198    executors: HashMap<&'static str, Arc<dyn ExtensionRuntime<B>>>,
199}
200
201impl<B: TensorBackend + 'static> fmt::Debug for ExtensionRegistry<B> {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        let mut families = self.executors.keys().copied().collect::<Vec<_>>();
204        families.sort_unstable();
205        f.debug_struct("ExtensionRegistry")
206            .field("backend_type", &std::any::type_name::<B>())
207            .field("len", &self.executors.len())
208            .field("families", &families)
209            .finish_non_exhaustive()
210    }
211}
212
213impl<B: TensorBackend + 'static> ExtensionRegistry<B> {
214    /// Create an empty extension runtime registry.
215    ///
216    /// # Examples
217    ///
218    /// ```
219    /// use tenferro_runtime::ExtensionRegistry;
220    /// use tenferro_cpu::CpuBackend;
221    ///
222    /// let registry = ExtensionRegistry::<CpuBackend>::new();
223    /// assert!(!registry.contains("example.identity.v1"));
224    /// ```
225    pub fn new() -> Self {
226        Self {
227            executors: HashMap::new(),
228        }
229    }
230
231    /// Register one runtime executor.
232    ///
233    /// Registration is idempotent by family id: registering the same extension
234    /// family more than once succeeds and keeps the first runtime. This lets
235    /// extension crates register their own dependency extensions defensively.
236    pub fn register(
237        &mut self,
238        executor: Arc<dyn ExtensionRuntime<B>>,
239    ) -> Result<(), ExtensionRuntimeRegistryError> {
240        let family_id = executor.family_id();
241        if !is_valid_family_id(family_id) {
242            return Err(ExtensionRuntimeRegistryError::MalformedFamilyId { family_id });
243        }
244        if self.executors.contains_key(family_id) {
245            return Ok(());
246        }
247        self.executors.insert(family_id, executor);
248        Ok(())
249    }
250
251    /// Look up an executor by extension family id.
252    pub fn get(&self, family_id: &str) -> Option<Arc<dyn ExtensionRuntime<B>>> {
253        self.executors.get(family_id).cloned()
254    }
255
256    /// Return whether an executor is registered for `family_id`.
257    pub fn contains(&self, family_id: &str) -> bool {
258        self.executors.contains_key(family_id)
259    }
260
261    /// Number of registered runtime executors.
262    pub fn len(&self) -> usize {
263        self.executors.len()
264    }
265
266    /// Return whether no runtime executors are registered.
267    pub fn is_empty(&self) -> bool {
268        self.executors.is_empty()
269    }
270}
271
272impl<B: TensorBackend + 'static> Default for ExtensionRegistry<B> {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
278/// Runtime owner for backend-specific extension dispatch and caches.
279pub struct ExtensionExecutor<B: TensorBackend + 'static> {
280    registry: ExtensionRegistry<B>,
281    caches: ExtensionCacheStore,
282    _backend: PhantomData<fn() -> B>,
283}
284
285impl<B: TensorBackend + 'static> fmt::Debug for ExtensionExecutor<B> {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        f.debug_struct("ExtensionExecutor")
288            .field("backend_type", &std::any::type_name::<B>())
289            .field("registry", &self.registry)
290            .field("caches", &self.caches)
291            .finish_non_exhaustive()
292    }
293}
294
295impl<B: TensorBackend + 'static> ExtensionExecutor<B> {
296    /// Create an executor with an empty registry and default cache limits.
297    ///
298    /// # Examples
299    ///
300    /// ```
301    /// use tenferro_runtime::ExtensionExecutor;
302    /// use tenferro_cpu::CpuBackend;
303    ///
304    /// let executor = ExtensionExecutor::<CpuBackend>::new();
305    /// assert_eq!(executor.cache_stats().entries, 0);
306    /// ```
307    pub fn new() -> Self {
308        Self {
309            registry: ExtensionRegistry::new(),
310            caches: ExtensionCacheStore::new(),
311            _backend: PhantomData,
312        }
313    }
314
315    /// Create an executor from explicit registry and cache store.
316    pub fn with_parts(registry: ExtensionRegistry<B>, caches: ExtensionCacheStore) -> Self {
317        Self {
318            registry,
319            caches,
320            _backend: PhantomData,
321        }
322    }
323
324    /// Borrow the runtime executor registry.
325    pub fn registry(&self) -> &ExtensionRegistry<B> {
326        &self.registry
327    }
328
329    /// Borrow the runtime executor registry mutably.
330    pub fn registry_mut(&mut self) -> &mut ExtensionRegistry<B> {
331        &mut self.registry
332    }
333
334    /// Borrow the extension cache store.
335    pub fn caches(&self) -> &ExtensionCacheStore {
336        &self.caches
337    }
338
339    /// Borrow the extension cache store mutably.
340    pub fn caches_mut(&mut self) -> &mut ExtensionCacheStore {
341        &mut self.caches
342    }
343
344    /// Execute an extension using a registered runtime executor.
345    pub fn execute(
346        &mut self,
347        backend: &mut B,
348        op: &dyn ExtensionOp,
349        inputs: &[&Tensor],
350    ) -> tenferro_tensor::Result<Vec<Tensor>> {
351        validate_runtime_input_count(op, inputs.len())?;
352        let Some(executor) = self.registry.get(op.family_id()) else {
353            return Err(tenferro_tensor::Error::InvalidConfig {
354                op: "extension",
355                message: format!(
356                    "missing runtime for family_id {:?}; register the extension on this runtime owner, for example `executor.register_extension(<extension_crate>::register_runtime)` or `eager_runtime.register_extension(<extension_crate>::register_runtime)`",
357                    op.family_id()
358                ),
359            });
360        };
361        let mut ctx = ExtensionExecutionContext::new(backend, &mut self.caches);
362        validate_runtime_output_count(op, executor.execute(op, inputs, &mut ctx)?)
363    }
364
365    /// Execute an extension using borrowed tensor reads.
366    ///
367    /// # Examples
368    ///
369    /// ```
370    /// use std::any::Any;
371    /// use std::hash::Hasher;
372    /// use std::sync::Arc;
373    ///
374    /// use tenferro_cpu::CpuBackend;
375    /// use tenferro_ops::{ext_op::ExtensionOp, SymDim};
376    /// use tenferro_runtime::{
377    ///     DType, ExtensionExecutionContext, ExtensionExecutor, ExtensionRuntime, Tensor,
378    /// };
379    /// use tenferro_tensor::TensorRead;
380    ///
381    /// #[derive(Clone, Debug)]
382    /// struct IdentityOp;
383    ///
384    /// impl ExtensionOp for IdentityOp {
385    ///     fn family_id(&self) -> &'static str {
386    ///         "example.identity.v1"
387    ///     }
388    ///
389    ///     fn payload_hash(&self, _hasher: &mut dyn Hasher) {}
390    ///
391    ///     fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
392    ///         other.as_any().is::<IdentityOp>()
393    ///     }
394    ///
395    ///     fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
396    ///         Arc::new(self.clone())
397    ///     }
398    ///
399    ///     fn as_any(&self) -> &dyn Any {
400    ///         self
401    ///     }
402    ///
403    ///     fn input_count(&self) -> usize {
404    ///         1
405    ///     }
406    ///
407    ///     fn output_count(&self) -> usize {
408    ///         1
409    ///     }
410    ///
411    ///     fn infer_output_meta(
412    ///         &self,
413    ///         input_dtypes: &[DType],
414    ///         input_shapes: &[&[SymDim]],
415    ///     ) -> Vec<(DType, Vec<SymDim>)> {
416    ///         vec![(input_dtypes[0], input_shapes[0].to_vec())]
417    ///     }
418    ///
419    ///     fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
420    ///         Ok(vec![inputs[0].clone()])
421    ///     }
422    /// }
423    ///
424    /// #[derive(Debug)]
425    /// struct IdentityRuntime;
426    ///
427    /// impl ExtensionRuntime<CpuBackend> for IdentityRuntime {
428    ///     fn family_id(&self) -> &'static str {
429    ///         "example.identity.v1"
430    ///     }
431    ///
432    ///     fn execute(
433    ///         &self,
434    ///         op: &dyn ExtensionOp,
435    ///         inputs: &[&Tensor],
436    ///         _ctx: &mut ExtensionExecutionContext<'_, CpuBackend>,
437    ///     ) -> tenferro_tensor::Result<Vec<Tensor>> {
438    ///         op.eager_execute(inputs)
439    ///     }
440    ///
441    ///     fn execute_reads(
442    ///         &self,
443    ///         op: &dyn ExtensionOp,
444    ///         inputs: &[TensorRead<'_>],
445    ///         ctx: &mut ExtensionExecutionContext<'_, CpuBackend>,
446    ///     ) -> tenferro_tensor::Result<Vec<Tensor>> {
447    ///         let materialized_inputs: Vec<Tensor> = inputs
448    ///             .iter()
449    ///             .map(TensorRead::to_tensor)
450    ///             .collect::<tenferro_tensor::Result<_>>()?;
451    ///         let input_refs: Vec<&Tensor> = materialized_inputs.iter().collect();
452    ///         self.execute(op, &input_refs, ctx)
453    ///     }
454    /// }
455    ///
456    /// let mut executor = ExtensionExecutor::<CpuBackend>::new();
457    /// executor.registry_mut().register(Arc::new(IdentityRuntime))?;
458    /// let input = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
459    /// let read = TensorRead::from_tensor(&input);
460    /// let mut backend = CpuBackend::new();
461    ///
462    /// let outputs = executor.execute_reads(&mut backend, &IdentityOp, &[read])?;
463    ///
464    /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[1.0, 2.0]);
465    /// # Ok::<(), Box<dyn std::error::Error>>(())
466    /// ```
467    pub fn execute_reads(
468        &mut self,
469        backend: &mut B,
470        op: &dyn ExtensionOp,
471        inputs: &[TensorRead<'_>],
472    ) -> tenferro_tensor::Result<Vec<Tensor>> {
473        validate_runtime_input_count(op, inputs.len())?;
474        let Some(executor) = self.registry.get(op.family_id()) else {
475            return Err(tenferro_tensor::Error::InvalidConfig {
476                op: "extension",
477                message: format!(
478                    "missing runtime for family_id {:?}; register the extension on this runtime owner, for example `executor.register_extension(<extension_crate>::register_runtime)` or `eager_runtime.register_extension(<extension_crate>::register_runtime)`",
479                    op.family_id()
480                ),
481            });
482        };
483        let mut ctx = ExtensionExecutionContext::new(backend, &mut self.caches);
484        validate_runtime_output_count(op, executor.execute_reads(op, inputs, &mut ctx)?)
485    }
486
487    /// Clear every runtime extension cache entry.
488    pub fn clear_caches(&mut self) {
489        self.caches.clear();
490    }
491
492    /// Return extension cache stats for all entries.
493    pub fn cache_stats(&self) -> CacheStats {
494        self.caches.stats(ExtensionCacheSelector::All)
495    }
496
497    /// Return the extension cache retention limits.
498    pub fn cache_limits(&self) -> ExtensionCacheLimits {
499        self.caches.limits()
500    }
501
502    /// Replace extension cache retention limits.
503    pub fn set_cache_limits(&mut self, limits: ExtensionCacheLimits) {
504        self.caches.set_limits(limits);
505    }
506}
507
508impl<B: TensorBackend + 'static> Default for ExtensionExecutor<B> {
509    fn default() -> Self {
510        Self::new()
511    }
512}
513
514#[cfg(test)]
515mod tests;
516
/// Check that `family_id` has the shape `"<crate-name>.<op-name>.v<major>"`.
///
/// The text after the last `.` must be `v` followed by one or more ASCII
/// digits. The text before it is split at its first `.` into a crate name and
/// an op name; both must be non-empty and consist only of ASCII,
/// non-whitespace characters. The op name may itself contain further dots.
fn is_valid_family_id(family_id: &str) -> bool {
    // Peel the version suffix off at the last dot; ids without a dot fail here.
    let Some((prefix, version_part)) = family_id.rsplit_once('.') else {
        return false;
    };

    // The version is a literal `v` followed by at least one ASCII digit.
    let Some(digits) = version_part.strip_prefix('v') else {
        return false;
    };
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return false;
    }

    // The first dot of the remainder separates the crate name from the op name.
    let Some((crate_name, op_name)) = prefix.split_once('.') else {
        return false;
    };

    let well_formed = |segment: &str| {
        !segment.is_empty() && segment.chars().all(|c| c.is_ascii() && !c.is_whitespace())
    };
    well_formed(crate_name) && well_formed(op_name)
}