pub trait ExtensionOp:
    Debug
    + Send
    + Sync
    + 'static {
    // Required methods
    fn family_id(&self) -> &'static str;
    fn payload_hash(&self, hasher: &mut dyn Hasher);
    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool;
    fn clone_arc(&self) -> Arc<dyn ExtensionOp>;
    fn as_any(&self) -> &dyn Any;
    fn input_count(&self) -> usize;
    fn output_count(&self) -> usize;
    fn infer_output_meta(
        &self,
        input_dtypes: &[DType],
        input_shapes: &[&[SymDim]],
    ) -> Vec<(DType, Vec<SymDim>)>;
    fn eager_execute(&self, inputs: &[&Tensor]) -> Result<Vec<Tensor>>;

    // Provided method
    fn lower_to_standard_ops(
        &self,
        _builder: &mut GraphBuilder<StdTensorOp>,
        _inputs: &[ValueRef<StdTensorOp>],
        _input_dtypes: &[DType],
        _input_shapes: &[&[SymDim]],
    ) -> ExtensionLoweringResult { ... }
}
The contract every out-of-tree extension primitive must satisfy.
Implementations appear in the core graph as
StdTensorOp::Extension(Arc<dyn ExtensionOp>). Every method is part of the
ExtensionOp spec (docs/spec/extension-op.md); the short form:
- identity via family_id + payload_hash;
- fixed arity via input_count / output_count;
- shape / dtype inference via infer_output_meta;
- host/reference forward execution via eager_execute (runtime-owned eager and compiled paths dispatch through registered extension runtimes instead of falling back to this method);
- optional fixed-shape standard-op expansion via lower_to_standard_ops, for peer lowerers such as XLA that cannot execute extension runtimes;
- AD via a separately registered ExtensionAdRule.
Downcast convention
ExtensionOp requires 'static, so every implementation automatically
implements Any through the standard library's blanket impl; this is what
lets ExtensionOp::payload_eq downcast a trait-object reference to the
concrete type. The helper ExtensionOp::as_any returns &dyn Any for this
purpose. Implementations usually define it as
fn as_any(&self) -> &dyn Any { self }.
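To illustrate the convention without depending on tenferro, here is a std-only sketch. The trait MiniExt and the types Scale and Negate are hypothetical stand-ins that keep only as_any and payload_eq; they are not part of the crate's API.

```rust
use std::any::Any;
use std::fmt::Debug;

// Hypothetical, trimmed-down stand-in for ExtensionOp: only the two
// methods needed to show the downcast convention.
trait MiniExt: Debug + Send + Sync + 'static {
    fn as_any(&self) -> &dyn Any;
    fn payload_eq(&self, other: &dyn MiniExt) -> bool;
}

// Hypothetical payload: an f64 factor stored as raw bits so equality is exact.
#[derive(Debug, PartialEq)]
struct Scale {
    factor_bits: u64,
}

// Hypothetical payload-free extension.
#[derive(Debug)]
struct Negate;

impl MiniExt for Scale {
    fn as_any(&self) -> &dyn Any {
        self // Scale: 'static, so it is Any via the blanket impl
    }
    fn payload_eq(&self, other: &dyn MiniExt) -> bool {
        // Downcasting to the wrong concrete type yields None, hence false.
        other
            .as_any()
            .downcast_ref::<Scale>()
            .map_or(false, |o| o == self)
    }
}

impl MiniExt for Negate {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn payload_eq(&self, other: &dyn MiniExt) -> bool {
        // No payload: equal to any other Negate.
        other.as_any().downcast_ref::<Negate>().is_some()
    }
}
```

With these definitions, two Scale values with the same factor compare equal, while a Scale never equals a Negate because the downcast fails.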
Examples
use std::any::Any;
use std::sync::Arc;
use tenferro_ops::ext_op::ExtensionOp;
use tenferro_ops::SymDim;
use tenferro_tensor::{DType, Tensor};

// A payload-free extension that returns its single input unchanged.
#[derive(Clone, Debug)]
struct IdentityExt;

impl ExtensionOp for IdentityExt {
    fn family_id(&self) -> &'static str { "example.identity.v1" }
    // No payload fields, so there is nothing to hash.
    fn payload_hash(&self, _hasher: &mut dyn std::hash::Hasher) {}
    // Equal iff `other` is also an IdentityExt (there is no payload to compare).
    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
        other.as_any().downcast_ref::<IdentityExt>().is_some()
    }
    fn clone_arc(&self) -> Arc<dyn ExtensionOp> { Arc::new(self.clone()) }
    fn as_any(&self) -> &dyn Any { self }
    fn input_count(&self) -> usize { 1 }
    fn output_count(&self) -> usize { 1 }
    // Output dtype and shape equal the input's.
    fn infer_output_meta(
        &self,
        dtypes: &[DType],
        shapes: &[&[SymDim]],
    ) -> Vec<(DType, Vec<SymDim>)> {
        vec![(dtypes[0], shapes[0].to_vec())]
    }
    fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
        Ok(vec![inputs[0].clone()])
    }
}

let op: Arc<dyn ExtensionOp> = Arc::new(IdentityExt);
assert_eq!(op.input_count(), 1);
Required Methods
fn family_id(&self) -> &'static str
Stable, process-independent family identifier.
MUST be unique per extension family (payload schema), not per
instance, and MUST follow the reserved format
"<crate-name>.<op-name>.v<major>".
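A family ID can be checked mechanically against the reserved format. The sketch below is a hypothetical helper, not a tenferro function, and its character rules (ASCII alphanumerics, '_' and '-' in the crate and op segments, decimal digits after 'v') are assumptions rather than rules taken from docs/spec/extension-op.md.

```rust
/// Hypothetical checker for the "<crate-name>.<op-name>.v<major>" format.
/// Assumption: segments contain no dots and use only [A-Za-z0-9_-].
fn is_valid_family_id(id: &str) -> bool {
    let parts: Vec<&str> = id.split('.').collect();
    if parts.len() != 3 {
        return false;
    }
    // Crate and op segments: non-empty, restricted character set.
    let segment_ok = |s: &str| {
        !s.is_empty()
            && s.chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
    };
    // Version segment: a literal 'v' followed by one or more digits.
    let version_ok = parts[2].strip_prefix('v').map_or(false, |major| {
        !major.is_empty() && major.chars().all(|c| c.is_ascii_digit())
    });
    segment_ok(parts[0]) && segment_ok(parts[1]) && version_ok
}
```

Under these assumptions "example.identity.v1" from the example above passes, while "example.identity" (no version) and "example.identity.1" (no 'v') are rejected.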
fn payload_hash(&self, hasher: &mut dyn Hasher)
Hash the payload (everything except family_id).
Implementations MUST be pure and deterministic across calls on the same value. Hashes MUST NOT include transient state such as allocation addresses or atomically updated counters.
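A minimal sketch of a deterministic payload_hash, assuming a hypothetical NormPayload with an axis and an f64 epsilon; the struct and the hash_of helper are illustrative only. It writes plain field values into the &mut dyn Hasher and never touches addresses or counters.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

// Hypothetical payload for an illustrative normalization extension.
#[derive(Debug, Clone)]
struct NormPayload {
    axis: usize,
    eps: f64,
}

impl NormPayload {
    // Same parameter type as ExtensionOp::payload_hash.
    fn payload_hash(&self, hasher: &mut dyn Hasher) {
        // Feed only value fields: no pointers, no counters, no family_id.
        hasher.write_usize(self.axis);
        // f64 is not Hash; its bit pattern is a deterministic stand-in.
        hasher.write_u64(self.eps.to_bits());
    }
}

// Hash a payload with the standard library's SipHash-based hasher.
fn hash_of(p: &NormPayload) -> u64 {
    // DefaultHasher::new() always starts from the same keys, so equal
    // payloads hash equally within a process.
    let mut hasher = DefaultHasher::new();
    p.payload_hash(&mut hasher);
    hasher.finish()
}
```

Hashing float bits has a consequence for payload_eq: with ==, 0.0 and -0.0 compare equal but hash differently (and NaN is unequal to itself). Comparing to_bits() in payload_eq as well keeps the two methods consistent.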
fn payload_eq(&self, other: &dyn ExtensionOp) -> bool
Structural equality against another extension value.
The carrier’s PartialEq impl first compares family IDs and calls
payload_eq only when they match. Implementations MUST return
true iff the payloads are semantically equal AND
other.family_id() == self.family_id().
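The two-stage comparison can be modelled with std types alone. In this sketch Ext, Carrier and Transpose are hypothetical stand-ins for ExtensionOp, the StdTensorOp::Extension carrier and a concrete extension; the real carrier may be structured differently.

```rust
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

// Hypothetical stand-in for ExtensionOp, reduced to the identity methods.
trait Ext: Debug + Send + Sync + 'static {
    fn family_id(&self) -> &'static str;
    fn as_any(&self) -> &dyn Any;
    fn payload_eq(&self, other: &dyn Ext) -> bool;
}

// Hypothetical carrier, modelling StdTensorOp::Extension(Arc<dyn ExtensionOp>).
#[derive(Debug, Clone)]
struct Carrier(Arc<dyn Ext>);

impl PartialEq for Carrier {
    fn eq(&self, other: &Self) -> bool {
        // Cheap check first: different families can never be equal.
        self.0.family_id() == other.0.family_id() && self.0.payload_eq(&*other.0)
    }
}

// Hypothetical extension with a permutation payload.
#[derive(Debug, PartialEq)]
struct Transpose {
    perm: Vec<usize>,
}

impl Ext for Transpose {
    fn family_id(&self) -> &'static str {
        "example.transpose.v1"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn payload_eq(&self, other: &dyn Ext) -> bool {
        // Re-check the family as the contract requires, then compare payloads.
        other.family_id() == self.family_id()
            && other
                .as_any()
                .downcast_ref::<Transpose>()
                .map_or(false, |o| o == self)
    }
}
```

Cloning a Carrier only bumps the Arc's reference count, so a clone compares equal to its source through the same path.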
fn clone_arc(&self) -> Arc<dyn ExtensionOp>
Deep-clone the payload behind an Arc.
The carrier’s Clone impl uses Arc::clone on the fast path; this
method exists for rare cases that need a second independent Arc.
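A std-only sketch of the difference, assuming a hypothetical Gather payload whose clone_arc returns a concrete Arc<Gather> (the real method returns Arc<dyn ExtensionOp>). Arc::clone shares the existing allocation, while clone_arc produces an equal value in a new allocation.

```rust
use std::sync::Arc;

// Hypothetical payload type standing in for an ExtensionOp implementation.
#[derive(Debug, Clone, PartialEq)]
struct Gather {
    axis: usize,
}

impl Gather {
    // Mirrors the usual clone_arc body: clone the value, wrap it in a fresh Arc.
    fn clone_arc(&self) -> Arc<Gather> {
        Arc::new(self.clone())
    }
}
```

Usage: Arc::ptr_eq distinguishes the two paths, since only the fast path points at the same allocation.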
fn as_any(&self) -> &dyn Any
Upcast this extension to &dyn Any for downcasting in payload_eq.
Implementations SHOULD return self verbatim. The method is
object-safe (no Self: Sized bound) so it can be called on an
&dyn ExtensionOp; that’s what makes
other.as_any().downcast_ref::<ConcreteType>() work from
Self::payload_eq implementations.
fn input_count(&self) -> usize
Number of primal inputs. MUST be constant for any given
Arc<dyn ExtensionOp> value.
fn output_count(&self) -> usize
Number of outputs. MUST match the length of the vector returned by
Self::infer_output_meta.
fn infer_output_meta(
    &self,
    input_dtypes: &[DType],
    input_shapes: &[&[SymDim]],
) -> Vec<(DType, Vec<SymDim>)>
Infer output dtypes and shapes for each output slot.
input_dtypes.len() and input_shapes.len() both equal
self.input_count(). The returned vector MUST have length
self.output_count(), one (dtype, shape) entry per output slot.
Shapes use SymDim so extension ops compose with graph-global
symbolic metadata.
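The sketch below shows the arity and length invariants for a hypothetical elementwise binary extension. DType and SymDim here are simplified stand-ins (SymDim is modelled as either a constant or a named symbol), not the tenferro types.

```rust
// Hypothetical stand-ins for DType and SymDim; the real tenferro types differ.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F32,
    F64,
}

#[derive(Debug, Clone, PartialEq)]
enum SymDim {
    Const(usize),
    Sym(&'static str),
}

// Hypothetical elementwise op: two inputs of identical dtype and shape, one output.
struct ElementwiseAdd;

impl ElementwiseAdd {
    fn input_count(&self) -> usize {
        2
    }
    fn output_count(&self) -> usize {
        1
    }
    fn infer_output_meta(
        &self,
        input_dtypes: &[DType],
        input_shapes: &[&[SymDim]],
    ) -> Vec<(DType, Vec<SymDim>)> {
        // The caller guarantees one entry per input.
        debug_assert_eq!(input_dtypes.len(), self.input_count());
        debug_assert_eq!(input_shapes.len(), self.input_count());
        // One entry per output slot; symbolic dims pass through unchanged.
        vec![(input_dtypes[0], input_shapes[0].to_vec())]
    }
}
```

Passing SymDim values through untouched is what lets a symbolic dimension such as "n" keep its identity in downstream graph metadata.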
fn eager_execute(&self, inputs: &[&Tensor]) -> Result<Vec<Tensor>>
Eager forward execution; called from the eager path and indirectly from the compiled path.
Input tensors are on the device the caller already arranged. Output
tensors MUST have shapes matching Self::infer_output_meta and MUST
be placed on a device the caller can consume.
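A std-only sketch of the output-shape obligation. HostTensor, eager_execute_double and shapes_match are hypothetical; the real method works on tenferro_tensor::Tensor with its own Result type, and device placement is not modelled here.

```rust
// Hypothetical host tensor: a shape plus row-major f32 data.
#[derive(Debug, Clone, PartialEq)]
struct HostTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

// Hypothetical "multiply by 2" extension with one input and one output.
fn eager_execute_double(inputs: &[&HostTensor]) -> Result<Vec<HostTensor>, String> {
    if inputs.len() != 1 {
        return Err(format!("expected 1 input, got {}", inputs.len()));
    }
    let x = inputs[0];
    let out = HostTensor {
        // Must agree with the inferred output shape (here: same as the input).
        shape: x.shape.clone(),
        data: x.data.iter().map(|&v| v * 2.0).collect(),
    };
    Ok(vec![out])
}

// Caller-side check: output count and every shape match the concrete
// shapes obtained from infer_output_meta.
fn shapes_match(expected: &[Vec<usize>], outputs: &[HostTensor]) -> bool {
    expected.len() == outputs.len()
        && expected.iter().zip(outputs).all(|(e, t)| *e == t.shape)
}
```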
Provided Methods
fn lower_to_standard_ops(
    &self,
    _builder: &mut GraphBuilder<StdTensorOp>,
    _inputs: &[ValueRef<StdTensorOp>],
    _input_dtypes: &[DType],
    _input_shapes: &[&[SymDim]],
) -> ExtensionLoweringResult
Optionally expand this extension into standard tensor graph operations.
Peer lowerers call this when all input metadata is known and extension
runtime dispatch is not available. Return Ok(Some(outputs)) after
adding only standard StdTensorOp operations to builder. Return
Ok(None) when this extension family has no standard-op lowering for
the supplied metadata; strict lowerers should surface that as an
explicit unsupported-extension error. Return Err with an ExtensionLoweringError
when the payload is malformed or the lowering detects invalid metadata.
The default implementation returns Ok(None) so existing extension
runtimes keep their native dispatch behavior until their owning crate
deliberately implements this hook.
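The three outcomes map naturally onto what a strict peer lowerer does next. This is a hypothetical sketch: LoweringResult, LoweringError and PeerError are stand-ins, and usize substitutes for ValueRef<StdTensorOp>.

```rust
// Hypothetical models; the real ExtensionLoweringResult and error types differ.
#[derive(Debug, PartialEq)]
enum LoweringError {
    Malformed(String),
}

// usize stands in for ValueRef<StdTensorOp>.
type LoweringResult = Result<Option<Vec<usize>>, LoweringError>;

#[derive(Debug, PartialEq)]
enum PeerError {
    UnsupportedExtension(&'static str),
    Lowering(LoweringError),
}

// How a strict peer lowerer (e.g. an XLA backend) might consume the hook's result.
fn lower_strict(family_id: &'static str, result: LoweringResult) -> Result<Vec<usize>, PeerError> {
    match result {
        // Expanded into standard ops: use the returned output values.
        Ok(Some(outputs)) => Ok(outputs),
        // No standard-op lowering for this metadata: surface an explicit error.
        Ok(None) => Err(PeerError::UnsupportedExtension(family_id)),
        // Malformed payload or invalid metadata: propagate.
        Err(e) => Err(PeerError::Lowering(e)),
    }
}

// Mirrors the default hook body: "no lowering", so native dispatch is unaffected.
fn default_lower_to_standard_ops() -> LoweringResult {
    Ok(None)
}
```

Keeping Ok(None) distinct from Err lets a strict lowerer report "unsupported" separately from "invalid", while non-strict paths can simply fall back to extension runtime dispatch.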