
Trait ExtensionOp 

pub trait ExtensionOp:
    Debug
    + Send
    + Sync
    + 'static {
    // Required methods
    fn family_id(&self) -> &'static str;
    fn payload_hash(&self, hasher: &mut dyn Hasher);
    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool;
    fn clone_arc(&self) -> Arc<dyn ExtensionOp>;
    fn as_any(&self) -> &dyn Any;
    fn input_count(&self) -> usize;
    fn output_count(&self) -> usize;
    fn infer_output_meta(
        &self,
        input_dtypes: &[DType],
        input_shapes: &[&[SymDim]],
    ) -> Vec<(DType, Vec<SymDim>)>;
    fn eager_execute(&self, inputs: &[&Tensor]) -> Result<Vec<Tensor>>;

    // Provided method
    fn lower_to_standard_ops(
        &self,
        _builder: &mut GraphBuilder<StdTensorOp>,
        _inputs: &[ValueRef<StdTensorOp>],
        _input_dtypes: &[DType],
        _input_shapes: &[&[SymDim]],
    ) -> ExtensionLoweringResult { ... }
}

The contract every out-of-tree extension primitive must satisfy.

Implementations appear in the core graph as StdTensorOp::Extension(Arc<dyn ExtensionOp>). Every method is specified in the ExtensionOp spec (docs/spec/extension-op.md); the per-method documentation below gives the short form.

Downcast convention

Implementations MUST also implement Any so that ExtensionOp::payload_eq can downcast a trait-object reference to the concrete type. The trait's 'static bound already guarantees this, because the standard library implements Any for every 'static type. The helper ExtensionOp::as_any returns &dyn Any for this purpose, and implementations usually define it as fn as_any(&self) -> &dyn Any { self }.
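The downcast convention can be sketched without the real crate. The trait `MiniExtOp` and the types `Scale` and `Negate` below are hypothetical, pared-down stand-ins (not the tenferro API) that keep only `payload_eq` and `as_any`: a failed `downcast_ref` means `other` is a different concrete type, so the payloads cannot be equal.

```rust
use std::any::Any;
use std::fmt::Debug;

// Hypothetical, pared-down stand-in for ExtensionOp: only the two
// methods involved in the downcast convention are kept.
trait MiniExtOp: Debug + Send + Sync + 'static {
    fn payload_eq(&self, other: &dyn MiniExtOp) -> bool;
    fn as_any(&self) -> &dyn Any;
}

// Hypothetical extension with a payload field.
#[derive(Debug, PartialEq)]
struct Scale {
    factor: i64,
}

// Hypothetical payload-free extension.
#[derive(Debug)]
struct Negate;

impl MiniExtOp for Scale {
    fn payload_eq(&self, other: &dyn MiniExtOp) -> bool {
        // `other` is a trait object; recover the concrete type through `Any`.
        // A failed downcast means a different concrete type, so not equal.
        other
            .as_any()
            .downcast_ref::<Scale>()
            .map_or(false, |o| o == self)
    }
    // Return `self` verbatim; the `'static` bound makes `Scale: Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl MiniExtOp for Negate {
    fn payload_eq(&self, other: &dyn MiniExtOp) -> bool {
        // No payload: equal iff `other` is also a `Negate`.
        other.as_any().is::<Negate>()
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
}
```

Because `as_any` takes `&self` and has no `Self: Sized` bound, it stays callable on `&dyn MiniExtOp`, which is what lets `payload_eq` reach the concrete type.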

Examples

use std::any::Any;
use std::sync::Arc;
use tenferro_ops::ext_op::ExtensionOp;
use tenferro_ops::SymDim;
use tenferro_tensor::{DType, Tensor};

// A payload-free extension that returns its single input unchanged.
#[derive(Clone, Debug)]
struct IdentityExt;

impl ExtensionOp for IdentityExt {
    // One ID for the whole family, in "<crate-name>.<op-name>.v<major>" form.
    fn family_id(&self) -> &'static str { "example.identity.v1" }
    // No payload fields, so nothing is written to the hasher.
    fn payload_hash(&self, _hasher: &mut dyn std::hash::Hasher) {}
    // Equal iff `other` is also an IdentityExt (hence the same family).
    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
        other.as_any().downcast_ref::<IdentityExt>().is_some()
    }
    fn clone_arc(&self) -> Arc<dyn ExtensionOp> { Arc::new(self.clone()) }
    fn as_any(&self) -> &dyn Any { self }
    fn input_count(&self) -> usize { 1 }
    fn output_count(&self) -> usize { 1 }
    // The single output has the input's dtype and shape.
    fn infer_output_meta(
        &self,
        dtypes: &[DType],
        shapes: &[&[SymDim]],
    ) -> Vec<(DType, Vec<SymDim>)> {
        vec![(dtypes[0], shapes[0].to_vec())]
    }
    // Identity: return a copy of the input tensor.
    fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
        Ok(vec![inputs[0].clone()])
    }
}

let op: Arc<dyn ExtensionOp> = Arc::new(IdentityExt);
assert_eq!(op.input_count(), 1);

Required Methods

fn family_id(&self) -> &'static str

Stable, process-independent family identifier.

MUST be unique per extension family (payload schema), not per instance, and MUST follow the reserved format "<crate-name>.<op-name>.v<major>".
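A registry or test suite might want to check the reserved format mechanically. The helper `is_valid_family_id` below is hypothetical, and its rules are assumptions beyond what this page states: exactly three dot-separated segments, name segments made of ASCII letters, digits, `-` or `_`, and a major version of one or more ASCII digits after a literal `v`.

```rust
/// Hypothetical checker for the "<crate-name>.<op-name>.v<major>" format.
/// Assumed rules (not from the spec): exactly three dot-separated segments,
/// name segments use ASCII letters, digits, '-' or '_', and the version
/// segment is 'v' followed by one or more ASCII digits.
fn is_valid_family_id(id: &str) -> bool {
    let parts: Vec<&str> = id.split('.').collect();
    if parts.len() != 3 {
        return false;
    }
    // Non-empty crate or op name made of the assumed character set.
    let is_name = |s: &str| {
        !s.is_empty()
            && s.chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
    };
    // "v" followed by at least one digit, e.g. "v1" or "v12".
    let is_version = |s: &str| match s.strip_prefix('v') {
        Some(digits) => !digits.is_empty() && digits.chars().all(|c| c.is_ascii_digit()),
        None => false,
    };
    is_name(parts[0]) && is_name(parts[1]) && is_version(parts[2])
}
```

For instance, the ID used in the example above, "example.identity.v1", passes, while "example.identity" (no version) does not.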

fn payload_hash(&self, hasher: &mut dyn Hasher)

Hash the payload (everything except family_id).

Implementations MUST be pure and deterministic across calls on the same value. Hashes MUST NOT include transient state such as allocation addresses or atomically updated counters.
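As an illustration of what to hash and what to leave out, the hypothetical `ReduceSumExt` below (not a real tenferro type) hashes only its semantic fields and skips a transient call counter. Since `dyn Hasher` is not `Sized`, it hashes through `&mut hasher`, relying on the standard library's `impl<H: Hasher + ?Sized> Hasher for &mut H`.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical payload: `axis` and `keepdims` define the op's meaning;
// `calls` is transient bookkeeping that must stay out of the hash.
#[derive(Debug)]
struct ReduceSumExt {
    axis: usize,
    keepdims: bool,
    calls: AtomicUsize,
}

impl ReduceSumExt {
    fn new(axis: usize, keepdims: bool) -> Self {
        Self { axis, keepdims, calls: AtomicUsize::new(0) }
    }

    // Same signature as the trait method. Hash::hash needs a Sized hasher,
    // so we pass `&mut hasher` (a `&mut &mut dyn Hasher`).
    fn payload_hash(&self, mut hasher: &mut dyn Hasher) {
        self.axis.hash(&mut hasher);
        self.keepdims.hash(&mut hasher);
        // Deliberately NOT hashing `self.calls` or any pointer address.
    }

    // Mutates transient state only; the payload hash is unaffected.
    fn record_call(&self) {
        self.calls.fetch_add(1, Ordering::Relaxed);
    }
}

// Hash just the payload. `DefaultHasher::new()` always starts from the
// same keys, so equal payloads give equal digests within one build.
fn payload_digest(op: &ReduceSumExt) -> u64 {
    let mut h = DefaultHasher::new();
    op.payload_hash(&mut h);
    h.finish()
}
```

Bumping `calls` between two `payload_digest` calls leaves the digest unchanged, which is the determinism property the contract asks for.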

fn payload_eq(&self, other: &dyn ExtensionOp) -> bool

Structural equality against another extension value.

The carrier’s PartialEq impl first compares family IDs and calls payload_eq only when they match. Implementations MUST still return true iff the payloads are semantically equal AND other.family_id() == self.family_id().
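The two-step comparison can be sketched with hypothetical stand-ins: `Ext` is a reduced trait, `Pad` a sample extension, and `Op` plays the role of a carrier such as StdTensorOp. None of these are the crate's actual definitions. Note that `Pad::payload_eq` re-checks the family ID itself, so it stays correct even if called outside the carrier.

```rust
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

// Hypothetical, reduced extension trait: just what equality needs.
trait Ext: Debug + Send + Sync + 'static {
    fn family_id(&self) -> &'static str;
    fn payload_eq(&self, other: &dyn Ext) -> bool;
    fn as_any(&self) -> &dyn Any;
}

// Hypothetical extension with a single payload field.
#[derive(Debug, PartialEq)]
struct Pad {
    width: usize,
}

impl Ext for Pad {
    fn family_id(&self) -> &'static str {
        "example.pad.v1"
    }
    fn payload_eq(&self, other: &dyn Ext) -> bool {
        // Enforce both halves of the contract: same family AND equal payload.
        other.family_id() == self.family_id()
            && other.as_any().downcast_ref::<Pad>() == Some(self)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Hypothetical stand-in for a carrier such as StdTensorOp.
#[derive(Debug, Clone)]
enum Op {
    Neg,
    Extension(Arc<dyn Ext>),
}

impl PartialEq for Op {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Op::Neg, Op::Neg) => true,
            // Cheap family-ID comparison first; payload_eq only on a match.
            (Op::Extension(a), Op::Extension(b)) => {
                a.family_id() == b.family_id() && a.payload_eq(&**b)
            }
            _ => false,
        }
    }
}
```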

fn clone_arc(&self) -> Arc<dyn ExtensionOp>

Deep-clone the payload behind an Arc.

The carrier’s Clone impl takes the fast path of Arc::clone, which only bumps the reference count and shares the existing payload; this method exists for the rare cases that need a second, independent Arc.
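The difference between the two kinds of clone is visible with `Arc::ptr_eq`. In this sketch, `Ext` and `Gather` are hypothetical stand-ins that keep only `clone_arc`; `compare_clones` is an illustrative helper, not part of the crate.

```rust
use std::fmt::Debug;
use std::sync::Arc;

// Hypothetical, reduced trait keeping only clone_arc.
trait Ext: Debug + Send + Sync + 'static {
    fn clone_arc(&self) -> Arc<dyn Ext>;
}

// Hypothetical extension whose payload is worth copying.
#[derive(Clone, Debug)]
struct Gather {
    indices: Vec<usize>,
}

impl Ext for Gather {
    // Deep clone: copies the payload into a brand-new allocation.
    fn clone_arc(&self) -> Arc<dyn Ext> {
        Arc::new(self.clone())
    }
}

// Returns (Arc::clone shares the allocation, clone_arc shares the allocation).
fn compare_clones(op: &Arc<dyn Ext>) -> (bool, bool) {
    let shared = Arc::clone(op); // fast path: refcount bump, same allocation
    let independent = op.clone_arc(); // second, independent Arc
    (Arc::ptr_eq(op, &shared), Arc::ptr_eq(op, &independent))
}
```

Calling `compare_clones` on a freshly built `Arc<dyn Ext>` yields `(true, false)`: only the `Arc::clone` copy points at the original payload.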

fn as_any(&self) -> &dyn Any

Upcast this extension to &dyn Any for downcasting in payload_eq.

Implementations SHOULD return self verbatim. The method is object-safe (no Self: Sized bound) so it can be called on an &dyn ExtensionOp; that’s what makes other.as_any().downcast_ref::<ConcreteType>() work from Self::payload_eq implementations.

fn input_count(&self) -> usize

Number of primal inputs. MUST be constant for any given Arc<dyn ExtensionOp> value.

fn output_count(&self) -> usize

Number of outputs. MUST match the length of the vector returned by Self::infer_output_meta.

fn infer_output_meta( &self, input_dtypes: &[DType], input_shapes: &[&[SymDim]], ) -> Vec<(DType, Vec<SymDim>)>

Infer output dtypes and shapes for each output slot.

input_dtypes.len() and input_shapes.len() both equal self.input_count(). The returned vector MUST have length self.output_count(), one (dtype, shape) entry per output slot. Shapes use SymDim so extension ops compose with graph-global symbolic metadata.
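To show the one-entry-per-slot rule with more than one output, here is a sketch of a hypothetical two-output op (max and argmax over the last axis). `DType` and `SymDim` are simplified stand-ins defined locally, not the crate's types; symbolic dimensions simply pass through to the outputs.

```rust
// Hypothetical stand-ins for the crate's DType and SymDim.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F32,
    I64,
}

#[derive(Clone, Debug, PartialEq)]
enum SymDim {
    Const(usize),
    Sym(&'static str),
}

// Hypothetical two-output op: max and argmax over the last axis.
struct MaxArgmaxLastAxis;

impl MaxArgmaxLastAxis {
    fn input_count(&self) -> usize {
        1
    }
    fn output_count(&self) -> usize {
        2
    }
    fn infer_output_meta(
        &self,
        input_dtypes: &[DType],
        input_shapes: &[&[SymDim]],
    ) -> Vec<(DType, Vec<SymDim>)> {
        // Caller guarantees one entry per input.
        debug_assert_eq!(input_dtypes.len(), self.input_count());
        debug_assert_eq!(input_shapes.len(), self.input_count());
        let shape = input_shapes[0];
        // Drop the reduced (last) axis; symbolic dims pass through untouched.
        let reduced: Vec<SymDim> = shape[..shape.len().saturating_sub(1)].to_vec();
        let out = vec![
            (input_dtypes[0], reduced.clone()), // slot 0: max values
            (DType::I64, reduced),              // slot 1: argmax indices
        ];
        // Exactly one (dtype, shape) entry per output slot.
        debug_assert_eq!(out.len(), self.output_count());
        out
    }
}
```

With an input shape of `[Sym("batch"), Const(16)]`, both output slots get shape `[Sym("batch")]`, with dtypes F32 and I64 respectively.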

fn eager_execute(&self, inputs: &[&Tensor]) -> Result<Vec<Tensor>>

Eager forward execution; called from the eager path and indirectly from the compiled path.

Input tensors are on the device the caller already arranged. Output tensors MUST have shapes matching Self::infer_output_meta and MUST be placed on a device the caller can consume.
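A caller or test harness might verify the eager post-condition by comparing each output against the metadata inferred for its slot. The sketch below is hypothetical throughout: `Tensor` carries only a dtype and a concrete shape, and the binding rule (a symbol is bound on first use and must then stay consistent) is an assumption made for illustration, not documented crate behavior.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for the crate's DType and SymDim.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F32,
    I64,
}

#[derive(Clone, Debug, PartialEq)]
enum SymDim {
    Const(usize),
    Sym(&'static str),
}

// Hypothetical eager tensor: dtype plus concrete shape (data omitted).
#[derive(Clone, Debug)]
struct Tensor {
    dtype: DType,
    shape: Vec<usize>,
}

/// Hypothetical check that eager outputs agree with inferred (dtype, shape)
/// metadata. Symbolic dims resolve through `bindings`: an unbound symbol is
/// bound on first use and must match that value everywhere afterwards.
fn outputs_match_meta(
    meta: &[(DType, Vec<SymDim>)],
    outputs: &[Tensor],
    bindings: &mut HashMap<&'static str, usize>,
) -> bool {
    if meta.len() != outputs.len() {
        return false; // wrong number of outputs
    }
    for ((dtype, sym_shape), t) in meta.iter().zip(outputs) {
        if *dtype != t.dtype || sym_shape.len() != t.shape.len() {
            return false; // dtype or rank mismatch
        }
        for (dim, &actual) in sym_shape.iter().zip(&t.shape) {
            let ok = match dim {
                SymDim::Const(n) => *n == actual,
                SymDim::Sym(name) => *bindings.entry(*name).or_insert(actual) == actual,
            };
            if !ok {
                return false;
            }
        }
    }
    true
}
```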

Provided Methods

fn lower_to_standard_ops( &self, _builder: &mut GraphBuilder<StdTensorOp>, _inputs: &[ValueRef<StdTensorOp>], _input_dtypes: &[DType], _input_shapes: &[&[SymDim]], ) -> ExtensionLoweringResult

Optionally expand this extension into standard tensor graph operations.

Peer lowerers call this when all input metadata is known and extension runtime dispatch is not available. Return Ok(Some(outputs)) after adding only standard StdTensorOp operations to builder. Return Ok(None) when this extension family has no standard-op lowering for the supplied metadata; strict lowerers should surface that as an explicit unsupported-extension error. Return an Err carrying an ExtensionLoweringError when the payload is malformed or the lowering detects invalid metadata.

The default implementation returns Ok(None) so existing extension runtimes keep their native dispatch behavior until their owning crate deliberately implements this hook.
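The three possible outcomes, and how a strict lowerer might treat them, can be sketched with hypothetical stand-ins: `Lowerable`, `LoweringResult`, `LowerError`, `StrictLowerError` and `ValueId` are illustrative replacements for the crate's ExtensionLoweringResult, GraphBuilder and ValueRef, which are not reproduced here. `NativeOnly` keeps the default `Ok(None)`; `Passthrough` lowers itself (trivially, adding no ops).

```rust
// Hypothetical stand-ins; the real ExtensionLoweringResult, GraphBuilder
// and ValueRef live in the crate and are not reproduced here.
#[derive(Debug, PartialEq)]
enum LowerError {
    Malformed(String),
}

type ValueId = usize;
type LoweringResult = Result<Option<Vec<ValueId>>, LowerError>;

#[derive(Debug, PartialEq)]
enum StrictLowerError {
    UnsupportedExtension(&'static str),
    Lowering(LowerError),
}

trait Lowerable {
    fn family_id(&self) -> &'static str;
    // Default mirrors the documented behavior: "no standard-op lowering".
    fn lower_to_standard_ops(&self, _inputs: &[ValueId]) -> LoweringResult {
        Ok(None)
    }
}

// Native-dispatch-only extension: keeps the default.
struct NativeOnly;
impl Lowerable for NativeOnly {
    fn family_id(&self) -> &'static str {
        "example.native.v1"
    }
}

// Extension that lowers itself; as an identity it needs no new ops.
struct Passthrough;
impl Lowerable for Passthrough {
    fn family_id(&self) -> &'static str {
        "example.passthrough.v1"
    }
    fn lower_to_standard_ops(&self, inputs: &[ValueId]) -> LoweringResult {
        match inputs {
            [x] => Ok(Some(vec![*x])),
            _ => Err(LowerError::Malformed(format!(
                "expected 1 input, got {}",
                inputs.len()
            ))),
        }
    }
}

// A strict lowerer turns Ok(None) into an explicit unsupported-extension error.
fn strict_lower(op: &dyn Lowerable, inputs: &[ValueId]) -> Result<Vec<ValueId>, StrictLowerError> {
    match op.lower_to_standard_ops(inputs) {
        Ok(Some(outputs)) => Ok(outputs),
        Ok(None) => Err(StrictLowerError::UnsupportedExtension(op.family_id())),
        Err(e) => Err(StrictLowerError::Lowering(e)),
    }
}
```

Keeping "not supported" (Ok(None)) distinct from "invalid" (Err) lets a non-strict lowerer fall back to native dispatch while still rejecting malformed payloads.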

Implementors