Trait TensorPrims

pub trait TensorPrims<A> {
    type Plan<T: ScalarBase>;

    // Required methods
    fn plan<T: ScalarBase>(
        desc: &PrimDescriptor,
        shapes: &[&[usize]],
    ) -> Result<Self::Plan<T>>;
    fn execute<T: ScalarBase>(
        plan: &Self::Plan<T>,
        alpha: T,
        inputs: &[&StridedView<'_, T>],
        beta: T,
        output: &mut StridedViewMut<'_, T>,
    ) -> Result<()>;
    fn has_extension_for<T: ScalarBase>(ext: Extension) -> bool;
}

Backend trait for tensor primitive operations, parameterized by algebra A.

Provides a cuTENSOR-compatible plan-based execution model for all operations: each operation is described by a PrimDescriptor, planned once, and then executed against tensor views. Core ops (batched_gemm, reduce, trace, permute, anti_trace, anti_diag) must be implemented by every backend. Extended ops (contract, elementwise_mul) are optional and are queried at runtime via has_extension_for.

§Algebra parameterization

The algebra parameter A enables extensibility: an external crate can implement TensorPrims<MyAlgebra> for CpuBackend, which the orphan rule permits because the algebra type is local to that crate. A sketch of such an impl follows.
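
As an illustration (a minimal sketch, not taken from the crate: MyAlgebra, MyPlan, the todo! bodies, and the exact import paths are assumptions), a downstream crate might write:

use std::marker::PhantomData;

use tenferro_prims::{
    CpuBackend, Extension, PrimDescriptor, Result, ScalarBase,
    StridedView, StridedViewMut, TensorPrims,
};

/// Algebra marker local to this downstream crate (hypothetical).
pub struct MyAlgebra;

/// Hypothetical plan type; a real one would hold kernel choices, etc.
pub struct MyPlan<T: ScalarBase>(PhantomData<T>);

// Permitted by the orphan rule: MyAlgebra is local to this crate,
// even though TensorPrims and CpuBackend are both foreign.
impl TensorPrims<MyAlgebra> for CpuBackend {
    type Plan<T: ScalarBase> = MyPlan<T>;

    fn plan<T: ScalarBase>(
        _desc: &PrimDescriptor,
        _shapes: &[&[usize]],
    ) -> Result<Self::Plan<T>> {
        todo!("kernel selection for MyAlgebra")
    }

    fn execute<T: ScalarBase>(
        _plan: &Self::Plan<T>,
        _alpha: T,
        _inputs: &[&StridedView<'_, T>],
        _beta: T,
        _output: &mut StridedViewMut<'_, T>,
    ) -> Result<()> {
        todo!("execution for MyAlgebra")
    }

    fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
        false // this sketch implements no extended ops
    }
}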

§Associated functions (not methods)

All functions are associated functions (no &self receiver). Call as CpuBackend::plan::<f64>(...) instead of backend.plan(...).

§Examples

use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};

// Plan a batched GEMM: C[m, n] = A[m, k] * B[k, n], with no batch dims.
let desc = PrimDescriptor::BatchedGemm {
    batch_dims: vec![], m: 3, n: 5, k: 4,
};
// Shapes are listed inputs first, then output: A [3, 4], B [4, 5], C [3, 5].
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();

// Execute: C = 1.0 * A*B + 0.0 * C.
// `a`, `b`, and `c` are assumed to be f64 tensors of the matching shapes
// exposing view()/view_mut(); their construction is omitted here.
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();

Required Associated Types§

type Plan<T: ScalarBase>

Backend-specific plan type (no type erasure).

Required Methods§

fn plan<T: ScalarBase>(
    desc: &PrimDescriptor,
    shapes: &[&[usize]],
) -> Result<Self::Plan<T>>

Create an execution plan from an operation descriptor.

The plan pre-computes kernel selection and workspace sizes. shapes contains the shape of each tensor involved in the operation (inputs first, then output).

§Examples
use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};

// Sum over mode 1 of a [3, 4] tensor, leaving a [3] tensor:
// modes_a names the input modes, modes_c the surviving output modes.
let desc = PrimDescriptor::Reduce {
    modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
};
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();

fn execute<T: ScalarBase>(
    plan: &Self::Plan<T>,
    alpha: T,
    inputs: &[&StridedView<'_, T>],
    beta: T,
    output: &mut StridedViewMut<'_, T>,
) -> Result<()>

Execute a plan with the given scaling factors and tensor views.

Follows the BLAS/cuTENSOR pattern: output = alpha * op(inputs) + beta * output

§Examples
// `plan`, `a`, `b`, and `c` are assumed to be the GEMM plan and tensors
// from the planning example above.

// Overwrite: output = 1.0 * gemm(a, b) + 0.0 * output
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();

// Accumulate: output = 1.0 * gemm(a, b) + 1.0 * output
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 1.0, &mut c.view_mut()).unwrap();

fn has_extension_for<T: ScalarBase>(ext: Extension) -> bool

Query whether an extended operation is available for scalar type T.

Returns true if the backend supports the given extended operation for the specified scalar type. This enables runtime capability queries: a GPU backend may support Contract for f64 but not for tropical scalar types, in which case the caller falls back to a decomposition into core ops.

§Examples
use tenferro_prims::{CpuBackend, TensorPrims, Extension};

if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
    // Use fused contraction for better performance
} else {
    // Decompose into core ops: permute → batched_gemm
}

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
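
The blockers are the generic associated type Plan<T> and the generic associated functions, neither of which a trait object can represent. A minimal sketch of the static-dispatch alternative (with_backend is a hypothetical helper; Standard is the algebra shown under Implementors below):

use tenferro_prims::{Standard, TensorPrims};

// `Box<dyn TensorPrims<Standard>>` does not compile: the generic
// associated type Plan<T> and the generic associated functions
// (which take no receiver) cannot live behind a trait object.
// Choose the backend with a type parameter instead, so that
// plan/execute resolve statically at compile time:
fn with_backend<B: TensorPrims<Standard>>() {
    // e.g. B::plan::<f64>(&desc, &shapes) and B::execute(...) here.
}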

Implementors§

impl TensorPrims<Standard> for CpuBackend

type Plan<T: ScalarBase> = CpuPlan<T>