tenferro_runtime/extension_runtime.rs
1//! Backend-parametric runtime dispatch for extension ops.
2//!
3//! This module is intentionally generic: extension crates can register an
4//! executor for a family and keep runtime cache state outside both the
5//! semantic [`ExtensionOp`] payload and the
6//! tensor backend implementation.
7
8use std::collections::HashMap;
9use std::fmt::{self, Debug};
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13use tenferro_ops::ext_op::ExtensionOp;
14use tenferro_tensor::{CacheStats, Tensor, TensorBackend, TensorRead};
15
16use crate::extension_cache::{ExtensionCacheLimits, ExtensionCacheSelector, ExtensionCacheStore};
17
18/// Errors returned by backend-parametric extension runtime registries.
19#[derive(Debug, thiserror::Error)]
20pub enum ExtensionRuntimeRegistryError {
21 /// The `family_id` does not match the namespaced format
22 /// `"<crate-name>.<op-name>.v<major>"`.
23 #[error("family_id {family_id:?} does not match the namespaced format")]
24 MalformedFamilyId { family_id: &'static str },
25 /// A registry lock was poisoned by a panic in another thread.
26 #[error("{name} poisoned")]
27 PoisonedLock { name: &'static str },
28}
29
30/// Backend and cache state passed to one extension execution.
31pub struct ExtensionExecutionContext<'a, B: TensorBackend> {
32 backend: &'a mut B,
33 caches: &'a mut ExtensionCacheStore,
34}
35
36impl<B: TensorBackend> fmt::Debug for ExtensionExecutionContext<'_, B> {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 f.debug_struct("ExtensionExecutionContext")
39 .field("backend_type", &std::any::type_name::<B>())
40 .field("caches", &self.caches)
41 .finish_non_exhaustive()
42 }
43}
44
45impl<'a, B: TensorBackend> ExtensionExecutionContext<'a, B> {
46 /// Build a context from externally-owned backend and cache state.
47 pub fn new(backend: &'a mut B, caches: &'a mut ExtensionCacheStore) -> Self {
48 Self { backend, caches }
49 }
50
51 /// Borrow the backend for non-mutating inspection.
52 pub fn backend(&self) -> &B {
53 self.backend
54 }
55
56 /// Borrow the backend mutably for extension execution.
57 pub fn backend_mut(&mut self) -> &mut B {
58 self.backend
59 }
60
61 /// Borrow the extension runtime cache store.
62 pub fn caches(&self) -> &ExtensionCacheStore {
63 self.caches
64 }
65
66 /// Borrow the extension runtime cache store mutably.
67 pub fn caches_mut(&mut self) -> &mut ExtensionCacheStore {
68 self.caches
69 }
70
71 /// Execute a core-only execution program one instruction at a time.
72 ///
73 /// This is for extension runtimes that lower their own operation into a
74 /// temporary `ExecProgram` containing only core tensor ops. Nested
75 /// `ExecOp::Extension` instructions are rejected so extension dispatch
76 /// cannot bypass the owning runtime registry.
77 ///
78 /// # Examples
79 ///
80 /// ```
81 /// use tenferro_cpu::CpuBackend;
82 /// use tenferro_ops::dim_expr::DimExpr;
83 /// use tenferro_runtime::extension::{ExecInstruction, ExecOp, ExecProgram};
84 /// use tenferro_runtime::{DType, ExtensionCacheStore, ExtensionExecutionContext, Tensor};
85 ///
86 /// let program = ExecProgram {
87 /// instructions: vec![ExecInstruction {
88 /// op: ExecOp::Add,
89 /// input_slots: vec![0, 1],
90 /// output_slots: vec![2],
91 /// dtype: DType::F64,
92 /// output_shapes: vec![vec![]].into(),
93 /// output_extents: vec![vec![]].into(),
94 /// last_use: vec![true, true],
95 /// }],
96 /// input_slots: vec![0, 1],
97 /// output_slots: vec![2],
98 /// n_slots: 3,
99 /// };
100 /// let lhs = Tensor::from_vec_col_major(vec![], vec![1.0_f64]).unwrap();
101 /// let rhs = Tensor::from_vec_col_major(vec![], vec![2.0_f64]).unwrap();
102 ///
103 /// let mut backend = CpuBackend::new();
104 /// let mut caches = ExtensionCacheStore::new();
105 /// let mut ctx = ExtensionExecutionContext::new(&mut backend, &mut caches);
106 /// let outputs = ctx
107 /// .execute_core_exec_program_unsegmented(&program, vec![lhs, rhs])
108 /// .unwrap();
109 /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[3.0]);
110 /// ```
111 pub fn execute_core_exec_program_unsegmented(
112 &mut self,
113 program: &crate::extension::ExecProgram,
114 inputs: Vec<Tensor>,
115 ) -> crate::error::Result<Vec<Tensor>>
116 where
117 B: 'static,
118 {
119 crate::exec::ensure_core_exec_program(
120 program,
121 "ExtensionExecutionContext::execute_core_exec_program_unsegmented",
122 )?;
123 crate::exec::eval_exec_ir_unsegmented_with_cache(self.backend, program, inputs)
124 }
125
126 /// Borrow backend and extension cache store as disjoint mutable parts.
127 pub fn parts_mut(&mut self) -> (&mut B, &mut ExtensionCacheStore) {
128 (self.backend, self.caches)
129 }
130}
131
132/// A backend-specific runtime executor for one extension family.
133pub trait ExtensionRuntime<B: TensorBackend + 'static>: Debug + Send + Sync + 'static {
134 /// Extension family handled by this executor.
135 fn family_id(&self) -> &'static str;
136
137 /// Execute the extension op with backend and cache state supplied by core.
138 fn execute(
139 &self,
140 op: &dyn ExtensionOp,
141 inputs: &[&Tensor],
142 ctx: &mut ExtensionExecutionContext<'_, B>,
143 ) -> tenferro_tensor::Result<Vec<Tensor>>;
144
145 /// Execute the extension op on borrowed tensor reads.
146 ///
147 /// Implementations that need compact tensors must materialize inputs here
148 /// explicitly. Keeping this method required prevents implicit read-path
149 /// fallbacks from hiding backend or view handling bugs.
150 fn execute_reads(
151 &self,
152 op: &dyn ExtensionOp,
153 inputs: &[TensorRead<'_>],
154 ctx: &mut ExtensionExecutionContext<'_, B>,
155 ) -> tenferro_tensor::Result<Vec<Tensor>>;
156}
157
158fn validate_runtime_output_count(
159 op: &dyn ExtensionOp,
160 outputs: Vec<Tensor>,
161) -> tenferro_tensor::Result<Vec<Tensor>> {
162 let expected = op.output_count();
163 if outputs.len() != expected {
164 return Err(tenferro_tensor::Error::InvalidConfig {
165 op: "extension",
166 message: format!(
167 "family_id {:?}: runtime returned {} outputs but op declared {} outputs",
168 op.family_id(),
169 outputs.len(),
170 expected
171 ),
172 });
173 }
174 Ok(outputs)
175}
176
177fn validate_runtime_input_count(
178 op: &dyn ExtensionOp,
179 actual: usize,
180) -> tenferro_tensor::Result<()> {
181 let expected = op.input_count();
182 if actual != expected {
183 return Err(tenferro_tensor::Error::InvalidConfig {
184 op: "extension",
185 message: format!(
186 "family_id {:?}: op expects {} inputs, got {}",
187 op.family_id(),
188 expected,
189 actual
190 ),
191 });
192 }
193 Ok(())
194}
195
196/// Registry of backend-specific extension runtime executors.
197pub struct ExtensionRegistry<B: TensorBackend + 'static> {
198 executors: HashMap<&'static str, Arc<dyn ExtensionRuntime<B>>>,
199}
200
201impl<B: TensorBackend + 'static> fmt::Debug for ExtensionRegistry<B> {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 let mut families = self.executors.keys().copied().collect::<Vec<_>>();
204 families.sort_unstable();
205 f.debug_struct("ExtensionRegistry")
206 .field("backend_type", &std::any::type_name::<B>())
207 .field("len", &self.executors.len())
208 .field("families", &families)
209 .finish_non_exhaustive()
210 }
211}
212
213impl<B: TensorBackend + 'static> ExtensionRegistry<B> {
214 /// Create an empty extension runtime registry.
215 ///
216 /// # Examples
217 ///
218 /// ```
219 /// use tenferro_runtime::ExtensionRegistry;
220 /// use tenferro_cpu::CpuBackend;
221 ///
222 /// let registry = ExtensionRegistry::<CpuBackend>::new();
223 /// assert!(!registry.contains("example.identity.v1"));
224 /// ```
225 pub fn new() -> Self {
226 Self {
227 executors: HashMap::new(),
228 }
229 }
230
231 /// Register one runtime executor.
232 ///
233 /// Registration is idempotent by family id: registering the same extension
234 /// family more than once succeeds and keeps the first runtime. This lets
235 /// extension crates register their own dependency extensions defensively.
236 pub fn register(
237 &mut self,
238 executor: Arc<dyn ExtensionRuntime<B>>,
239 ) -> Result<(), ExtensionRuntimeRegistryError> {
240 let family_id = executor.family_id();
241 if !is_valid_family_id(family_id) {
242 return Err(ExtensionRuntimeRegistryError::MalformedFamilyId { family_id });
243 }
244 if self.executors.contains_key(family_id) {
245 return Ok(());
246 }
247 self.executors.insert(family_id, executor);
248 Ok(())
249 }
250
251 /// Look up an executor by extension family id.
252 pub fn get(&self, family_id: &str) -> Option<Arc<dyn ExtensionRuntime<B>>> {
253 self.executors.get(family_id).cloned()
254 }
255
256 /// Return whether an executor is registered for `family_id`.
257 pub fn contains(&self, family_id: &str) -> bool {
258 self.executors.contains_key(family_id)
259 }
260
261 /// Number of registered runtime executors.
262 pub fn len(&self) -> usize {
263 self.executors.len()
264 }
265
266 /// Return whether no runtime executors are registered.
267 pub fn is_empty(&self) -> bool {
268 self.executors.is_empty()
269 }
270}
271
272impl<B: TensorBackend + 'static> Default for ExtensionRegistry<B> {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
278/// Runtime owner for backend-specific extension dispatch and caches.
279pub struct ExtensionExecutor<B: TensorBackend + 'static> {
280 registry: ExtensionRegistry<B>,
281 caches: ExtensionCacheStore,
282 _backend: PhantomData<fn() -> B>,
283}
284
285impl<B: TensorBackend + 'static> fmt::Debug for ExtensionExecutor<B> {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 f.debug_struct("ExtensionExecutor")
288 .field("backend_type", &std::any::type_name::<B>())
289 .field("registry", &self.registry)
290 .field("caches", &self.caches)
291 .finish_non_exhaustive()
292 }
293}
294
295impl<B: TensorBackend + 'static> ExtensionExecutor<B> {
296 /// Create an executor with an empty registry and default cache limits.
297 ///
298 /// # Examples
299 ///
300 /// ```
301 /// use tenferro_runtime::ExtensionExecutor;
302 /// use tenferro_cpu::CpuBackend;
303 ///
304 /// let executor = ExtensionExecutor::<CpuBackend>::new();
305 /// assert_eq!(executor.cache_stats().entries, 0);
306 /// ```
307 pub fn new() -> Self {
308 Self {
309 registry: ExtensionRegistry::new(),
310 caches: ExtensionCacheStore::new(),
311 _backend: PhantomData,
312 }
313 }
314
315 /// Create an executor from explicit registry and cache store.
316 pub fn with_parts(registry: ExtensionRegistry<B>, caches: ExtensionCacheStore) -> Self {
317 Self {
318 registry,
319 caches,
320 _backend: PhantomData,
321 }
322 }
323
324 /// Borrow the runtime executor registry.
325 pub fn registry(&self) -> &ExtensionRegistry<B> {
326 &self.registry
327 }
328
329 /// Borrow the runtime executor registry mutably.
330 pub fn registry_mut(&mut self) -> &mut ExtensionRegistry<B> {
331 &mut self.registry
332 }
333
334 /// Borrow the extension cache store.
335 pub fn caches(&self) -> &ExtensionCacheStore {
336 &self.caches
337 }
338
339 /// Borrow the extension cache store mutably.
340 pub fn caches_mut(&mut self) -> &mut ExtensionCacheStore {
341 &mut self.caches
342 }
343
344 /// Execute an extension using a registered runtime executor.
345 pub fn execute(
346 &mut self,
347 backend: &mut B,
348 op: &dyn ExtensionOp,
349 inputs: &[&Tensor],
350 ) -> tenferro_tensor::Result<Vec<Tensor>> {
351 validate_runtime_input_count(op, inputs.len())?;
352 let Some(executor) = self.registry.get(op.family_id()) else {
353 return Err(tenferro_tensor::Error::InvalidConfig {
354 op: "extension",
355 message: format!(
356 "missing runtime for family_id {:?}; register the extension on this runtime owner, for example `executor.register_extension(<extension_crate>::register_runtime)` or `eager_runtime.register_extension(<extension_crate>::register_runtime)`",
357 op.family_id()
358 ),
359 });
360 };
361 let mut ctx = ExtensionExecutionContext::new(backend, &mut self.caches);
362 validate_runtime_output_count(op, executor.execute(op, inputs, &mut ctx)?)
363 }
364
365 /// Execute an extension using borrowed tensor reads.
366 ///
367 /// # Examples
368 ///
369 /// ```
370 /// use std::any::Any;
371 /// use std::hash::Hasher;
372 /// use std::sync::Arc;
373 ///
374 /// use tenferro_cpu::CpuBackend;
375 /// use tenferro_ops::{ext_op::ExtensionOp, SymDim};
376 /// use tenferro_runtime::{
377 /// DType, ExtensionExecutionContext, ExtensionExecutor, ExtensionRuntime, Tensor,
378 /// };
379 /// use tenferro_tensor::TensorRead;
380 ///
381 /// #[derive(Clone, Debug)]
382 /// struct IdentityOp;
383 ///
384 /// impl ExtensionOp for IdentityOp {
385 /// fn family_id(&self) -> &'static str {
386 /// "example.identity.v1"
387 /// }
388 ///
389 /// fn payload_hash(&self, _hasher: &mut dyn Hasher) {}
390 ///
391 /// fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
392 /// other.as_any().is::<IdentityOp>()
393 /// }
394 ///
395 /// fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
396 /// Arc::new(self.clone())
397 /// }
398 ///
399 /// fn as_any(&self) -> &dyn Any {
400 /// self
401 /// }
402 ///
403 /// fn input_count(&self) -> usize {
404 /// 1
405 /// }
406 ///
407 /// fn output_count(&self) -> usize {
408 /// 1
409 /// }
410 ///
411 /// fn infer_output_meta(
412 /// &self,
413 /// input_dtypes: &[DType],
414 /// input_shapes: &[&[SymDim]],
415 /// ) -> Vec<(DType, Vec<SymDim>)> {
416 /// vec![(input_dtypes[0], input_shapes[0].to_vec())]
417 /// }
418 ///
419 /// fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
420 /// Ok(vec![inputs[0].clone()])
421 /// }
422 /// }
423 ///
424 /// #[derive(Debug)]
425 /// struct IdentityRuntime;
426 ///
427 /// impl ExtensionRuntime<CpuBackend> for IdentityRuntime {
428 /// fn family_id(&self) -> &'static str {
429 /// "example.identity.v1"
430 /// }
431 ///
432 /// fn execute(
433 /// &self,
434 /// op: &dyn ExtensionOp,
435 /// inputs: &[&Tensor],
436 /// _ctx: &mut ExtensionExecutionContext<'_, CpuBackend>,
437 /// ) -> tenferro_tensor::Result<Vec<Tensor>> {
438 /// op.eager_execute(inputs)
439 /// }
440 ///
441 /// fn execute_reads(
442 /// &self,
443 /// op: &dyn ExtensionOp,
444 /// inputs: &[TensorRead<'_>],
445 /// ctx: &mut ExtensionExecutionContext<'_, CpuBackend>,
446 /// ) -> tenferro_tensor::Result<Vec<Tensor>> {
447 /// let materialized_inputs: Vec<Tensor> = inputs
448 /// .iter()
449 /// .map(TensorRead::to_tensor)
450 /// .collect::<tenferro_tensor::Result<_>>()?;
451 /// let input_refs: Vec<&Tensor> = materialized_inputs.iter().collect();
452 /// self.execute(op, &input_refs, ctx)
453 /// }
454 /// }
455 ///
456 /// let mut executor = ExtensionExecutor::<CpuBackend>::new();
457 /// executor.registry_mut().register(Arc::new(IdentityRuntime))?;
458 /// let input = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
459 /// let read = TensorRead::from_tensor(&input);
460 /// let mut backend = CpuBackend::new();
461 ///
462 /// let outputs = executor.execute_reads(&mut backend, &IdentityOp, &[read])?;
463 ///
464 /// assert_eq!(outputs[0].as_slice::<f64>().unwrap(), &[1.0, 2.0]);
465 /// # Ok::<(), Box<dyn std::error::Error>>(())
466 /// ```
467 pub fn execute_reads(
468 &mut self,
469 backend: &mut B,
470 op: &dyn ExtensionOp,
471 inputs: &[TensorRead<'_>],
472 ) -> tenferro_tensor::Result<Vec<Tensor>> {
473 validate_runtime_input_count(op, inputs.len())?;
474 let Some(executor) = self.registry.get(op.family_id()) else {
475 return Err(tenferro_tensor::Error::InvalidConfig {
476 op: "extension",
477 message: format!(
478 "missing runtime for family_id {:?}; register the extension on this runtime owner, for example `executor.register_extension(<extension_crate>::register_runtime)` or `eager_runtime.register_extension(<extension_crate>::register_runtime)`",
479 op.family_id()
480 ),
481 });
482 };
483 let mut ctx = ExtensionExecutionContext::new(backend, &mut self.caches);
484 validate_runtime_output_count(op, executor.execute_reads(op, inputs, &mut ctx)?)
485 }
486
487 /// Clear every runtime extension cache entry.
488 pub fn clear_caches(&mut self) {
489 self.caches.clear();
490 }
491
492 /// Return extension cache stats for all entries.
493 pub fn cache_stats(&self) -> CacheStats {
494 self.caches.stats(ExtensionCacheSelector::All)
495 }
496
497 /// Return the extension cache retention limits.
498 pub fn cache_limits(&self) -> ExtensionCacheLimits {
499 self.caches.limits()
500 }
501
502 /// Replace extension cache retention limits.
503 pub fn set_cache_limits(&mut self, limits: ExtensionCacheLimits) {
504 self.caches.set_limits(limits);
505 }
506}
507
508impl<B: TensorBackend + 'static> Default for ExtensionExecutor<B> {
509 fn default() -> Self {
510 Self::new()
511 }
512}
513
514#[cfg(test)]
515mod tests;
516
/// Check that `family_id` has the namespaced form
/// `"<crate-name>.<op-name>.v<major>"`.
///
/// Rules enforced:
/// - the last dot-separated component is `v` followed by one or more ASCII
///   digits;
/// - at least two components precede it (a crate name and an op name; the op
///   name may itself span several dot-separated components);
/// - every component before the version is non-empty and consists only of
///   ASCII characters that are not whitespace.
///
/// Ids with an empty component, such as `"crate.op..v1"` or `"crate..op.v1"`,
/// are rejected.
fn is_valid_family_id(family_id: &str) -> bool {
    // Peel off the trailing `v<major>` component.
    let Some((prefix, version)) = family_id.rsplit_once('.') else {
        return false;
    };
    let Some(major) = version.strip_prefix('v') else {
        return false;
    };
    if major.is_empty() || !major.bytes().all(|b| b.is_ascii_digit()) {
        return false;
    }

    let is_valid_segment = |segment: &str| {
        !segment.is_empty() && segment.chars().all(|c| c.is_ascii() && !c.is_whitespace())
    };
    // A dot in the prefix guarantees both a crate name and an op name exist.
    prefix.contains('.') && prefix.split('.').all(is_valid_segment)
}