Function apply 

pub fn apply(
    op: Arc<dyn ExtensionOp>,
    inputs: &[&TracedTensor],
) -> Result<Vec<TracedTensor>, Error>

Apply an extension op in the traced graph.

The op value is cloned into a StdTensorOp::Extension(Arc<dyn ExtensionOp>) carrier. The returned vector contains one TracedTensor per declared output slot of the extension. Output shapes are inferred via ExtensionOp::infer_output_meta using the input shape hints.

inputs.len() must equal op.input_count(), and every input must carry a shape_hint; that is, the extension can only be applied to tensors whose rank is known at graph-build time. For symbolic-shape composition, bind the placeholder tensors via crate::GraphExecutor::run_with_inputs at evaluation time.

§Examples

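use std::sync::Arc;
use tenferro_runtime::extension::{apply, ExtensionOp};
use tenferro_runtime::{DType, SymDim, Tensor, TracedTensor};

// `IdentityExt` is a user-defined type implementing `ExtensionOp`;
// its definition is not shown here.
let op: Arc<dyn ExtensionOp> = Arc::new(IdentityExt);
let a = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
// `?` assumes an enclosing function that returns a compatible `Result`.
let outputs = apply(op, &[&a])?;
// One `TracedTensor` per declared output slot.
assert_eq!(outputs.len(), 1);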

§Errors

Returns Error::InvalidGraphBuild when the extension receives the wrong number of inputs or when ExtensionOp::infer_output_meta returns metadata whose count does not match ExtensionOp::output_count.
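
The two error conditions above can be illustrated with a small, self-contained sketch. It does not use tenferro_runtime: the `Error` enum, the simplified `ExtensionOp` trait, the `MismatchedExt` type, and the `check_apply` helper are all hypothetical stand-ins, and output metadata is reduced to a plain `Vec<usize>` shape per output. The real trait signatures and error payload may differ.

```rust
// Hypothetical stand-ins only: these are NOT the tenferro_runtime types.

/// Simplified error type modeling only the variant named in the docs.
#[derive(Debug, PartialEq)]
enum Error {
    InvalidGraphBuild(String),
}

/// Simplified trait; the real `ExtensionOp` signatures are assumed, not known.
/// Output metadata is reduced to one plain shape (`Vec<usize>`) per output.
trait ExtensionOp {
    fn input_count(&self) -> usize;
    fn output_count(&self) -> usize;
    fn infer_output_meta(&self, input_shapes: &[Vec<usize>]) -> Vec<Vec<usize>>;
}

/// One input, one output, same shape.
struct IdentityExt;

impl ExtensionOp for IdentityExt {
    fn input_count(&self) -> usize { 1 }
    fn output_count(&self) -> usize { 1 }
    fn infer_output_meta(&self, input_shapes: &[Vec<usize>]) -> Vec<Vec<usize>> {
        input_shapes.to_vec()
    }
}

/// Deliberately inconsistent: declares two outputs but infers metadata for one.
struct MismatchedExt;

impl ExtensionOp for MismatchedExt {
    fn input_count(&self) -> usize { 1 }
    fn output_count(&self) -> usize { 2 }
    fn infer_output_meta(&self, input_shapes: &[Vec<usize>]) -> Vec<Vec<usize>> {
        input_shapes.to_vec()
    }
}

/// Mirrors the two documented checks: input arity first, then output arity.
fn check_apply(
    op: &dyn ExtensionOp,
    input_shapes: &[Vec<usize>],
) -> Result<Vec<Vec<usize>>, Error> {
    if input_shapes.len() != op.input_count() {
        return Err(Error::InvalidGraphBuild(format!(
            "expected {} inputs, got {}",
            op.input_count(),
            input_shapes.len()
        )));
    }
    let metas = op.infer_output_meta(input_shapes);
    if metas.len() != op.output_count() {
        return Err(Error::InvalidGraphBuild(format!(
            "expected {} output metas, got {}",
            op.output_count(),
            metas.len()
        )));
    }
    Ok(metas)
}
```

In this sketch, `check_apply(&IdentityExt, &[vec![2]])` succeeds with a single output shape, while passing no inputs to `IdentityExt`, or passing one input to `MismatchedExt`, yields `Error::InvalidGraphBuild`.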