Tensor primitive operations for the tenferro workspace.
This crate defines the TensorPrims<A> trait, a backend-agnostic interface
parameterized by algebra A. The API follows the cuTENSOR plan-based execution
pattern:
- Create a PrimDescriptor specifying the operation and index modes.
- Build a plan via TensorPrims::plan (pre-computes kernel selection).
- Execute the plan via TensorPrims::execute.
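The three steps above can be illustrated without the crate at all. The following is a minimal, self-contained sketch of the descriptor → plan → execute flow under stated assumptions: Descriptor, Plan, plan and execute are hypothetical stand-ins, not tenferro_prims' types; only a single GEMM is supported; buffers are column-major f64 slices; and the two scalar arguments are assumed to act like cuTENSOR's alpha and beta (C ← αAB + βC).

```rust
// Hypothetical stand-ins: these types are NOT tenferro_prims' API.
#[derive(Debug, Clone)]
enum Descriptor {
    Gemm { m: usize, n: usize, k: usize },
}

// A "plan" records the validated sizes and the kernel chosen up front.
#[derive(Debug)]
enum Plan {
    NaiveGemm { m: usize, n: usize, k: usize },
}

// Shapes are listed as [A, B, C]; validation and kernel choice happen once here.
fn plan(desc: &Descriptor, shapes: &[&[usize]]) -> Result<Plan, String> {
    match *desc {
        Descriptor::Gemm { m, n, k } => {
            let expected: [&[usize]; 3] = [&[m, k], &[k, n], &[m, n]];
            if shapes != &expected[..] {
                return Err(format!("shape mismatch for {:?}", desc));
            }
            Ok(Plan::NaiveGemm { m, n, k })
        }
    }
}

// Assumed alpha/beta semantics: C <- alpha * A * B + beta * C (column-major).
fn execute(
    plan: &Plan,
    alpha: f64,
    inputs: &[&[f64]],
    beta: f64,
    c: &mut [f64],
) -> Result<(), String> {
    match *plan {
        Plan::NaiveGemm { m, n, k } => {
            let (a, b) = match inputs {
                [a, b] => (*a, *b),
                _ => return Err("GEMM expects two inputs".into()),
            };
            for j in 0..n {
                for i in 0..m {
                    let mut acc = 0.0;
                    for p in 0..k {
                        acc += a[i + p * m] * b[p + j * k];
                    }
                    c[i + j * m] = alpha * acc + beta * c[i + j * m];
                }
            }
            Ok(())
        }
    }
}
```

Because validation and kernel selection live in plan, repeated execute calls on same-shaped data skip that work, which is the motivation for the plan-based pattern.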
§Operation categories
Core operations (every backend must implement):
- BatchedGemm: Batched matrix multiplication
- Reduce: Sum/max/min reduction over modes
- Trace: Trace (contraction of paired diagonal modes)
- Permute: Mode reordering
- AntiTrace: Scatter-add to diagonal (AD backward of trace)
- AntiDiag: Write to diagonal positions (AD backward of diag)
- ElementwiseUnary: Point-wise unary transform (negate, reciprocal, abs, sqrt)
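To make the AD pairings in this list concrete, here is a minimal sketch for the two-mode case (an n × n column-major matrix) on plain f64 slices. The function names are hypothetical, and the crate's Trace works on general paired modes; this only illustrates why AntiTrace is the backward of Trace and AntiDiag the backward of diag.

```rust
/// Trace: sum of the diagonal entries A[i, i] of an n x n column-major matrix.
fn trace(a: &[f64], n: usize) -> f64 {
    (0..n).map(|i| a[i + i * n]).sum()
}

/// AntiTrace: scatter-add the scalar cotangent `g` onto the diagonal of `grad_a`.
/// Since trace(A) = sum_i A[i, i], d trace / d A[i, j] is 1 if i == j, else 0.
fn anti_trace(g: f64, grad_a: &mut [f64], n: usize) {
    for i in 0..n {
        grad_a[i + i * n] += g;
    }
}

/// Diag: extract the diagonal of an n x n column-major matrix as a vector.
fn diag(a: &[f64], n: usize) -> Vec<f64> {
    (0..n).map(|i| a[i + i * n]).collect()
}

/// AntiDiag: write the cotangent vector `g` to the diagonal positions of `grad_a`.
/// Off-diagonal entries are left untouched (callers typically zero-initialize).
fn anti_diag(g: &[f64], grad_a: &mut [f64], n: usize) {
    for i in 0..n {
        grad_a[i + i * n] = g[i];
    }
}
```

Because trace is linear, its backward never reads the input values; the adjoint identity ⟨anti_trace(g), A⟩ = g · trace(A) holds for any A.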
Extended operations (dynamically queried via TensorPrims::has_extension_for):
- Contract: Fused permute + GEMM contraction
- ElementwiseMul: Element-wise multiplication
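The benefit of fusing can be seen on a small case. The sketch below, built from hypothetical helper functions on column-major f64 slices, contracts a transposed A (stored with its contracted mode first, i.e. modes_a = [1, 0] in the labeling of the Contract example) with B. The unfused path runs a Permute and then a GEMM, allocating a transposed copy; the fused path reads A through swapped strides and yields the same result with no intermediate buffer. It illustrates the idea only and makes no claim about how CpuBackend implements Contract.

```rust
/// Permute: materialize the transpose of a column-major `rows x cols` matrix.
fn permute_2d(x: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut out = vec![0.0; rows * cols];
    for j in 0..cols {
        for i in 0..rows {
            // x[i, j] moves to out[j, i]; `out` is `cols x rows`, column-major.
            out[j + i * cols] = x[i + j * rows];
        }
    }
    out
}

/// Naive column-major GEMM: C (m x n) = A (m x k) * B (k x n).
fn gemm(a: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec<f64> {
    let mut c = vec![0.0; m * n];
    for j in 0..n {
        for i in 0..m {
            for p in 0..k {
                c[i + j * m] += a[i + p * m] * b[p + j * k];
            }
        }
    }
    c
}

/// Unfused: A is stored as `k x m`; Permute it to `m x k`, then run GEMM.
fn contract_unfused(a_km: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec<f64> {
    let a_mk = permute_2d(a_km, k, m);
    gemm(&a_mk, b, m, n, k)
}

/// Fused: index A through swapped strides, so no permuted copy is allocated.
fn contract_fused(a_km: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec<f64> {
    let mut c = vec![0.0; m * n];
    for j in 0..n {
        for i in 0..m {
            for p in 0..k {
                // Logical A[i, p] lives at a_km[p + i * k] in the stored `k x m` layout.
                c[i + j * m] += a_km[p + i * k] * b[p + j * k];
            }
        }
    }
    c
}
```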
§Algebra parameterization
TensorPrims<A> is parameterized by an algebra A (e.g., Standard, MaxPlus).
External crates can implement TensorPrims<MyAlgebra> for CpuBackend; the orphan
rule permits this because MyAlgebra is a type local to the implementing crate.
The HasAlgebra trait on scalar types enables automatic algebra inference, e.g.
Tensor<f64> → Standard.
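As a minimal sketch of how one kernel can be written once and reused across algebras, assume an algebra supplies a zero, an add and a mul (a semiring). The traits Algebra and HasAlgebra, the associated type Alg, and the Standard/MaxPlus structs below are local stand-ins, not the definitions in tenferro_algebra; the max-plus choice (add = max, mul = +, zero = -∞) is the usual tropical semiring.

```rust
/// Hypothetical semiring-style algebra; not tenferro_algebra's actual trait.
trait Algebra {
    type Scalar: Copy;
    fn zero() -> Self::Scalar;
    fn add(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar;
    fn mul(a: Self::Scalar, b: Self::Scalar) -> Self::Scalar;
}

struct Standard;
struct MaxPlus;

impl Algebra for Standard {
    type Scalar = f64;
    fn zero() -> f64 { 0.0 }
    fn add(a: f64, b: f64) -> f64 { a + b }
    fn mul(a: f64, b: f64) -> f64 { a * b }
}

impl Algebra for MaxPlus {
    type Scalar = f64;
    fn zero() -> f64 { f64::NEG_INFINITY } // identity element for max
    fn add(a: f64, b: f64) -> f64 { a.max(b) } // semiring "addition" is max
    fn mul(a: f64, b: f64) -> f64 { a + b } // semiring "multiplication" is +
}

/// Maps a scalar type to its default algebra (the f64 -> Standard inference).
trait HasAlgebra: Copy {
    type Alg: Algebra<Scalar = Self>;
}

impl HasAlgebra for f64 {
    type Alg = Standard;
}

/// Column-major GEMM over any algebra: C[i, j] = add_p mul(A[i, p], B[p, j]).
fn gemm<A: Algebra>(a: &[A::Scalar], b: &[A::Scalar], m: usize, n: usize, k: usize) -> Vec<A::Scalar> {
    let mut c = vec![A::zero(); m * n];
    for j in 0..n {
        for i in 0..m {
            let mut acc = A::zero();
            for p in 0..k {
                acc = A::add(acc, A::mul(a[i + p * m], b[p + j * k]));
            }
            c[i + j * m] = acc;
        }
    }
    c
}

/// Uses the scalar type's default algebra, mirroring Tensor<f64> -> Standard.
fn gemm_default<T: HasAlgebra>(a: &[T], b: &[T], m: usize, n: usize, k: usize) -> Vec<T> {
    gemm::<T::Alg>(a, b, m, n, k)
}
```

Passing f64 slices to gemm_default selects Standard through HasAlgebra; naming MaxPlus explicitly turns the same loop into a max-plus (tropical) product.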
§Examples
§Plan-based GEMM
```rust
use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
use tenferro_algebra::Standard;
use strided_view::StridedArray;

// C (3x5) = A (3x4) * B (4x5), all column-major
let a = StridedArray::<f64>::col_major(&[3, 4]);
let b = StridedArray::<f64>::col_major(&[4, 5]);
let mut c = StridedArray::<f64>::col_major(&[3, 5]);
let desc = PrimDescriptor::BatchedGemm {
    batch_dims: vec![], m: 3, n: 5, k: 4,
};
// Shapes are listed as [A, B, C]
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
// cuTENSOR-style scaling: c = 1.0 * (a * b) + 0.0 * c
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
```
§Reduction (sum over an axis)
```rust
use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
use strided_view::StridedArray;

let a = StridedArray::<f64>::col_major(&[3, 4]);
let mut c = StridedArray::<f64>::col_major(&[3]);
// Sum over mode 1 (the column index): c_i = Σ_j A_{i,j}
let desc = PrimDescriptor::Reduce {
    modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
};
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
CpuBackend::execute(&plan, 1.0, &[&a.view()], 0.0, &mut c.view_mut()).unwrap();
```
§Dynamic extension check
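```rust
use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, Extension};
use strided_view::StridedArray;

let a = StridedArray::<f64>::col_major(&[3, 4]);
let b = StridedArray::<f64>::col_major(&[4, 5]);
let mut c = StridedArray::<f64>::col_major(&[3, 5]);
// Use the fused contraction only if this backend advertises it
if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
    // C_{ik} = Σ_j A_{ij} B_{jk}, with mode labels i = 0, j = 1, k = 2
    let desc = PrimDescriptor::Contract {
        modes_a: vec![0, 1], modes_b: vec![1, 2], modes_c: vec![0, 2],
    };
    let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
    CpuBackend::execute(
        &plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut(),
    ).unwrap();
}
```
Structs§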
- CpuBackend: CPU backend using strided-kernel and GEMM.
Enums§
- CpuPlan: CPU plan; a concrete enum, with no type erasure.
- Extension: Extended operation identifiers for dynamic capability queries.
- PrimDescriptor: Describes a tensor primitive operation.
- ReduceOp: Reduction operation kind.
- UnaryOp: Element-wise unary operation kind.
Traits§
- TensorPrims: Backend trait for tensor primitive operations, parameterized by algebra A.