pub trait TensorPrims<A> {
    type Plan<T: ScalarBase>;

    // Required methods
    fn plan<T: ScalarBase>(
        desc: &PrimDescriptor,
        shapes: &[&[usize]],
    ) -> Result<Self::Plan<T>>;
    fn execute<T: ScalarBase>(
        plan: &Self::Plan<T>,
        alpha: T,
        inputs: &[&StridedView<'_, T>],
        beta: T,
        output: &mut StridedViewMut<'_, T>,
    ) -> Result<()>;
    fn has_extension_for<T: ScalarBase>(ext: Extension) -> bool;
}
Backend trait for tensor primitive operations, parameterized by the algebra A.
Provides a cuTENSOR-compatible, plan-based execution model for all
operations. Core ops (batched_gemm, reduce, trace, permute, anti_trace,
anti_diag) must be implemented. Extended ops (contract, elementwise_mul)
are optional and are queried at runtime via has_extension_for.
§Algebra parameterization
The algebra parameter A enables extensibility: an external crate can
implement TensorPrims<MyAlgebra> for CpuBackend. This is compatible with
the orphan rule because MyAlgebra is a type local to that crate.
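As a rough sketch of that extension point (MaxPlus is a made-up marker type,
and the imports assume tenferro_prims re-exports ScalarBase, Result, and the
strided view types):

use tenferro_prims::{CpuBackend, Extension, PrimDescriptor, Result, ScalarBase,
                     StridedView, StridedViewMut, TensorPrims};

// Hypothetical tropical (max-plus) algebra marker defined in an external crate.
pub struct MaxPlus;

// MaxPlus is local to this crate, so the orphan rule permits implementing
// the foreign trait TensorPrims for the foreign type CpuBackend.
impl TensorPrims<MaxPlus> for CpuBackend {
    type Plan<T: ScalarBase> = (); // placeholder plan type for the sketch

    fn plan<T: ScalarBase>(
        _desc: &PrimDescriptor,
        _shapes: &[&[usize]],
    ) -> Result<Self::Plan<T>> {
        todo!("select max-plus kernels and workspace sizes")
    }

    fn execute<T: ScalarBase>(
        _plan: &Self::Plan<T>,
        _alpha: T,
        _inputs: &[&StridedView<'_, T>],
        _beta: T,
        _output: &mut StridedViewMut<'_, T>,
    ) -> Result<()> {
        todo!("run max-plus kernels")
    }

    fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
        false // this sketch provides no extended ops
    }
}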
§Associated functions (not methods)
All functions are associated functions (no &self receiver). Call as
CpuBackend::plan::<f64>(...) instead of backend.plan(...).
§Examples
use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};

// Plan a batched GEMM: C (3×5) = A (3×4) · B (4×5), with no batch dims.
let desc = PrimDescriptor::BatchedGemm {
    batch_dims: vec![], m: 3, n: 5, k: 4,
};
// Shapes are passed inputs first, then output.
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();

// Execute: C = 1.0 * A*B + 0.0 * C
// (a, b, and c are assumed to be f64 tensors of the shapes above.)
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
Required Associated Types§
type Plan<T: ScalarBase>
Backend-specific execution plan, produced by plan and consumed by execute.
Required Methods§
fn plan<T: ScalarBase>(
    desc: &PrimDescriptor,
    shapes: &[&[usize]],
) -> Result<Self::Plan<T>>
Create an execution plan from an operation descriptor.
The plan pre-computes kernel selection and workspace sizes.
shapes contains the shape of each tensor involved in the operation
(inputs first, then output).
§Examples
use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};

// Sum a 3×4 matrix over mode 1, producing a length-3 vector over mode 0.
let desc = PrimDescriptor::Reduce {
    modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
};
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();

fn execute<T: ScalarBase>(
    plan: &Self::Plan<T>,
    alpha: T,
    inputs: &[&StridedView<'_, T>],
    beta: T,
    output: &mut StridedViewMut<'_, T>,
) -> Result<()>
Execute a plan with the given scaling factors and tensor views.
Follows the BLAS/cuTENSOR pattern:
output = alpha * op(inputs) + beta * output
§Examples
// Execute: output = 1.0 * gemm(a, b) + 0.0 * output (overwrite)
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
// Accumulate: output = 1.0 * gemm(a, b) + 1.0 * output (add)
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 1.0, &mut c.view_mut()).unwrap();

fn has_extension_for<T: ScalarBase>(ext: Extension) -> bool
Query whether an extended operation is available for scalar type T.
Returns true if the backend supports the given extended operation
for the specified scalar type. This enables runtime capability checks:
for example, a GPU backend may support Contract for f64 but not for
tropical scalar types.
§Examples
use tenferro_prims::{CpuBackend, TensorPrims, Extension};

if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
    // Use the fused contraction for better performance.
} else {
    // Decompose into core ops: permute → batched_gemm (see the sketch below).
}
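A rough sketch of that fallback, planning the two stages separately. Permute
is one of the core ops named above, but its descriptor fields are not
documented here, so the perm field below is purely hypothetical:

use tenferro_prims::{CpuBackend, PrimDescriptor, TensorPrims};

// Stage 1: permute A from [4, 3] into the [3, 4] layout the GEMM expects.
// NOTE: `Permute { perm }` is an assumed variant shape, for illustration only.
let permute_desc = PrimDescriptor::Permute { perm: vec![1, 0] };
let permute_plan = CpuBackend::plan::<f64>(&permute_desc, &[&[4, 3], &[3, 4]]).unwrap();

// Stage 2: the batched GEMM itself, planned exactly as in the examples above.
let gemm_desc = PrimDescriptor::BatchedGemm {
    batch_dims: vec![], m: 3, n: 5, k: 4,
};
let gemm_plan = CpuBackend::plan::<f64>(&gemm_desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
// Execute permute_plan into a scratch tensor, then gemm_plan on its output.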
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
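(The generic methods and the Plan GAT are what rule out dyn TensorPrims.)
Backends can instead be chosen at compile time through a generic parameter.
A minimal sketch, assuming Result, StridedView, and StridedViewMut are
re-exported by tenferro_prims as the signatures above suggest, and that
f64 implements ScalarBase:

use tenferro_prims::{PrimDescriptor, Result, StridedView, StridedViewMut, TensorPrims};

// Monomorphized per backend B and algebra A, so no trait object is needed.
fn plan_and_run<A, B: TensorPrims<A>>(
    desc: &PrimDescriptor,
    shapes: &[&[usize]],
    inputs: &[&StridedView<'_, f64>],
    output: &mut StridedViewMut<'_, f64>,
) -> Result<()> {
    let plan = B::plan::<f64>(desc, shapes)?;
    B::execute(&plan, 1.0, inputs, 0.0, output)
}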