pub enum PrimDescriptor {
    BatchedGemm {
        batch_dims: Vec<usize>,
        m: usize,
        n: usize,
        k: usize,
    },
    Reduce {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        op: ReduceOp,
    },
    Trace {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        paired: Vec<(u32, u32)>,
    },
    Permute {
        modes_a: Vec<u32>,
        modes_b: Vec<u32>,
    },
    AntiTrace {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        paired: Vec<(u32, u32)>,
    },
    AntiDiag {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        paired: Vec<(u32, u32)>,
    },
    ElementwiseUnary {
        op: UnaryOp,
    },
    Contract {
        modes_a: Vec<u32>,
        modes_b: Vec<u32>,
        modes_c: Vec<u32>,
    },
    ElementwiseMul,
}
Describes a tensor primitive operation.
All operations follow the cuTENSOR pattern: describe → plan → execute.
Core operations must be supported by every backend. Extended operations
are optional; check whether a backend provides one with
TensorPrims::has_extension_for before using it.
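The core/extended split can be pictured with a small, self-contained sketch of the capability check. Everything in it is hypothetical: `Extension`, `Backend`, `has_extension`, `FixedBackend`, and `Step` are stand-ins rather than tenferro_prims items, and the fallback path assumes that `Permute` and `BatchedGemm` are core operations a caller can use when the fused `Contract` extension is missing.

```rust
/// Hypothetical capability flags (stand-ins for the crate's extension enum).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Extension {
    Contract,
    ElementwiseMul,
}

/// Hypothetical stand-in for a backend's `has_extension_for` query.
trait Backend {
    fn has_extension(&self, ext: Extension) -> bool;
}

/// A toy backend whose only optional feature is the fused contraction.
struct FixedBackend {
    contract: bool,
}

impl Backend for FixedBackend {
    fn has_extension(&self, ext: Extension) -> bool {
        ext == Extension::Contract && self.contract
    }
}

/// Operations a caller would describe, plan, and execute, in order.
#[derive(Debug, PartialEq, Eq)]
enum Step {
    Contract,
    Permute,
    BatchedGemm,
}

/// Prefer the fused extension; otherwise fall back to (assumed) core operations.
fn contraction_steps(backend: &dyn Backend) -> Vec<Step> {
    if backend.has_extension(Extension::Contract) {
        vec![Step::Contract]
    } else {
        // Permute A and B into GEMM layout, then multiply.
        // A final Permute of C may also be needed, depending on modes_c.
        vec![Step::Permute, Step::Permute, Step::BatchedGemm]
    }
}
```

The query happens once, before any descriptor is planned, so the choice between the fused and the composed path is made up front rather than per call.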
Modes are u32 integer labels matching cuTENSOR conventions. Modes
shared between input and output tensors are batch/free dimensions;
modes present only in inputs are contracted.
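A minimal sketch of the mode rule above, using a hypothetical helper `classify_modes` (not part of tenferro_prims) that splits the labels of a two-input operation into free/batch modes (present in the output) and contracted modes (present only in the inputs):

```rust
use std::collections::BTreeSet;

/// Hypothetical helper: returns (free/batch modes, contracted modes), each sorted.
/// A mode is free if it appears in the output, contracted if it appears only in inputs.
fn classify_modes(modes_a: &[u32], modes_b: &[u32], modes_c: &[u32]) -> (Vec<u32>, Vec<u32>) {
    let output: BTreeSet<u32> = modes_c.iter().copied().collect();
    let inputs: BTreeSet<u32> = modes_a.iter().chain(modes_b).copied().collect();
    // BTreeSet iterates in sorted order, so both halves come out sorted.
    let (free, contracted): (Vec<u32>, Vec<u32>) =
        inputs.into_iter().partition(|m| output.contains(m));
    (free, contracted)
}
```

For the matrix-multiplication example, `classify_modes(&[0, 1], &[1, 2], &[0, 2])` reports modes 0 and 2 as free and mode 1 as contracted.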
Examples
use tenferro_prims::PrimDescriptor;
// Matrix multiplication: C_{m,n} = A_{m,k} * B_{k,n}
let desc = PrimDescriptor::Contract {
    modes_a: vec![0, 1], // m=0, k=1
    modes_b: vec![1, 2], // k=1, n=2
    modes_c: vec![0, 2], // m=0, n=2
};
Variants
BatchedGemm
Batched matrix multiplication.
C[batch, m, n] = alpha * sum_k A[batch, m, k] * B[batch, k, n] + beta * C[batch, m, n]
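A naive reference for the BatchedGemm formula, as a sketch only: `batched_gemm_ref` is a hypothetical helper, and it assumes contiguous row-major `f64` buffers with all batch dimensions flattened into one leading dimension of size `batch_dims.iter().product()`.

```rust
/// Hypothetical naive reference for BatchedGemm on row-major f64 buffers:
/// a is [batch, m, k], b is [batch, k, n], c is [batch, m, n].
fn batched_gemm_ref(
    batch_dims: &[usize],
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a: &[f64],
    b: &[f64],
    beta: f64,
    c: &mut [f64],
) {
    let batch: usize = batch_dims.iter().product(); // 1 when batch_dims is empty
    assert_eq!(a.len(), batch * m * k);
    assert_eq!(b.len(), batch * k * n);
    assert_eq!(c.len(), batch * m * n);
    for p in 0..batch {
        let a_p = &a[p * m * k..(p + 1) * m * k];
        let b_p = &b[p * k * n..(p + 1) * k * n];
        let c_p = &mut c[p * m * n..(p + 1) * m * n];
        for i in 0..m {
            for j in 0..n {
                // Dot product of row i of A with column j of B.
                let acc: f64 = (0..k).map(|l| a_p[i * k + l] * b_p[l * n + j]).sum();
                c_p[i * n + j] = alpha * acc + beta * c_p[i * n + j];
            }
        }
    }
}
```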
Reduce
Reduction over modes not present in the output.
C[modes_c] = alpha * reduce_op(A[modes_a]) + beta * C[modes_c]
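The Reduce formula can be sketched the same way. `reduce_ref` and the two-variant `ReduceOp` below are hypothetical stand-ins (the real `ReduceOp` variants are not listed here); buffers are assumed to be row-major, with `shape_a[i]` the extent of mode `modes_a[i]`.

```rust
/// Hypothetical reduction operators; the real `ReduceOp` variants may differ.
#[derive(Clone, Copy)]
enum ReduceOp {
    Sum,
    Max,
}

impl ReduceOp {
    fn identity(self) -> f64 {
        match self {
            ReduceOp::Sum => 0.0,
            ReduceOp::Max => f64::NEG_INFINITY,
        }
    }
    fn combine(self, x: f64, y: f64) -> f64 {
        match self {
            ReduceOp::Sum => x + y,
            ReduceOp::Max => x.max(y),
        }
    }
}

/// Hypothetical naive reference for Reduce on row-major buffers.
/// Every mode of A that is absent from modes_c is reduced away.
fn reduce_ref(
    shape_a: &[usize],
    modes_a: &[u32],
    modes_c: &[u32],
    op: ReduceOp,
    alpha: f64,
    a: &[f64],
    beta: f64,
    c: &mut [f64],
) {
    // Axis of A that each output mode comes from.
    let pos: Vec<usize> = modes_c
        .iter()
        .map(|m| modes_a.iter().position(|x| x == m).expect("output mode missing from A"))
        .collect();
    let shape_c: Vec<usize> = pos.iter().map(|&p| shape_a[p]).collect();
    assert_eq!(c.len(), shape_c.iter().product::<usize>());
    let mut acc = vec![op.identity(); c.len()];
    let mut idx = vec![0usize; shape_a.len()]; // multi-index of the current element of A
    for &val in a {
        // Row-major offset into C, built from the kept modes only.
        let off = pos.iter().zip(&shape_c).fold(0, |o, (&p, &d)| o * d + idx[p]);
        acc[off] = op.combine(acc[off], val);
        // Advance idx to the next element of A in row-major order.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < shape_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    for (ci, r) in c.iter_mut().zip(acc) {
        *ci = alpha * r + beta * *ci;
    }
}
```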
Trace
Trace: contraction of paired diagonal modes.
For each pair (i, j) in paired, sums A over the entries where the index along mode i equals the index along mode j.
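A sketch of the trace semantics under the same row-major assumptions. `trace_ref` is hypothetical, treats `modes_c` as the ordered output modes, and omits alpha/beta scaling because no formula is given for Trace. Any mode that is neither paired nor listed in `modes_c` would also be summed by this code.

```rust
/// Hypothetical naive reference for Trace on a row-major buffer.
/// Keeps the entries of A whose indices agree on every pair in `paired`
/// and sums them into C, whose modes are `modes_c` (alpha/beta omitted).
fn trace_ref(
    shape_a: &[usize],
    modes_a: &[u32],
    modes_c: &[u32],
    paired: &[(u32, u32)],
    a: &[f64],
) -> Vec<f64> {
    let find = |m: &u32| modes_a.iter().position(|x| x == m).expect("mode missing from A");
    let pairs: Vec<(usize, usize)> = paired.iter().map(|(i, j)| (find(i), find(j))).collect();
    let pos: Vec<usize> = modes_c.iter().map(find).collect();
    let shape_c: Vec<usize> = pos.iter().map(|&p| shape_a[p]).collect();
    let mut c = vec![0.0; shape_c.iter().product::<usize>()];
    let mut idx = vec![0usize; shape_a.len()];
    for &val in a {
        // Only diagonal entries (equal indices on every pair) contribute.
        if pairs.iter().all(|&(p, q)| idx[p] == idx[q]) {
            let off = pos.iter().zip(&shape_c).fold(0, |o, (&p, &d)| o * d + idx[p]);
            c[off] += val;
        }
        // Advance idx to the next element of A in row-major order.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < shape_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    c
}
```

With `modes_c` empty this is the ordinary matrix trace; keeping a third mode in `modes_c` gives one trace per batch entry.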
Permute
Permute (reorder) tensor modes.
B[modes_b] = alpha * A[modes_a]
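A reference permute, again as a hypothetical helper (`permute_ref`) over row-major buffers, where `modes_b` must be a reordering of `modes_a`:

```rust
/// Hypothetical naive reference for Permute on row-major buffers:
/// B[modes_b] = alpha * A[modes_a], where modes_b reorders modes_a.
fn permute_ref(shape_a: &[usize], modes_a: &[u32], modes_b: &[u32], alpha: f64, a: &[f64]) -> Vec<f64> {
    assert_eq!(modes_a.len(), modes_b.len(), "modes_b must be a reordering of modes_a");
    // For each axis of B, the axis of A it is read from.
    let src: Vec<usize> = modes_b
        .iter()
        .map(|m| modes_a.iter().position(|x| x == m).expect("mode missing from A"))
        .collect();
    let shape_b: Vec<usize> = src.iter().map(|&p| shape_a[p]).collect();
    let mut b = vec![0.0; a.len()];
    let mut idx = vec![0usize; shape_a.len()];
    for &val in a {
        // Row-major offset of the same element in B's mode order.
        let off = src.iter().zip(&shape_b).fold(0, |o, (&p, &d)| o * d + idx[p]);
        b[off] = alpha * val;
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < shape_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    b
}
```

Swapping the two modes of a 2×3 matrix (`modes_a = [0, 1]`, `modes_b = [1, 0]`) is an ordinary transpose; the labels themselves need not be consecutive.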
AntiTrace
Anti-trace: scatter-add gradient to diagonal (AD backward of trace).
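Because the forward trace sums every diagonal entry into one output element, its backward pass sends that output's gradient back to each of those entries. The hypothetical `anti_trace_ref` sketches this scatter-add; its arguments describe the forward trace (A → C), and how they map onto the `AntiTrace` descriptor fields is not specified here.

```rust
/// Hypothetical naive reference for AntiTrace on row-major buffers.
/// shape_a, modes_a, modes_c and paired describe the *forward* trace A -> C.
/// Each diagonal entry of grad_a gets grad_c at the matching output index ADDED to it.
fn anti_trace_ref(
    shape_a: &[usize],
    modes_a: &[u32],
    modes_c: &[u32],
    paired: &[(u32, u32)],
    grad_c: &[f64],
    grad_a: &mut [f64],
) {
    assert_eq!(grad_a.len(), shape_a.iter().product::<usize>());
    let find = |m: &u32| modes_a.iter().position(|x| x == m).expect("mode missing from A");
    let pairs: Vec<(usize, usize)> = paired.iter().map(|(i, j)| (find(i), find(j))).collect();
    let pos: Vec<usize> = modes_c.iter().map(find).collect();
    let shape_c: Vec<usize> = pos.iter().map(|&p| shape_a[p]).collect();
    let mut idx = vec![0usize; shape_a.len()];
    for g in grad_a.iter_mut() {
        if pairs.iter().all(|&(p, q)| idx[p] == idx[q]) {
            let off = pos.iter().zip(&shape_c).fold(0, |o, (&p, &d)| o * d + idx[p]);
            *g += grad_c[off]; // scatter-add onto the diagonal
        }
        // Advance idx to the next element of grad_a in row-major order.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < shape_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
}
```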
AntiDiag
Anti-diag: write gradient to diagonal positions (AD backward of diag).
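For contrast with AntiTrace, a sketch assuming the forward diag extracts the diagonal and keeps one mode of each pair (for a matrix, `c[i] = a[i][i]`). Each diagonal entry then has its own output gradient, which is written rather than added. `anti_diag_ref` is hypothetical; it leaves off-diagonal entries untouched, so the caller is expected to zero-initialize `grad_a` (whether the real operation does this itself is not stated).

```rust
/// Hypothetical naive reference for AntiDiag on row-major buffers.
/// shape_a, modes_a, modes_c and paired describe the *forward* diag A -> C,
/// where modes_c keeps one mode of each pair. Diagonal entries of grad_a are
/// OVERWRITTEN with grad_c; off-diagonal entries are left unchanged.
fn anti_diag_ref(
    shape_a: &[usize],
    modes_a: &[u32],
    modes_c: &[u32],
    paired: &[(u32, u32)],
    grad_c: &[f64],
    grad_a: &mut [f64],
) {
    assert_eq!(grad_a.len(), shape_a.iter().product::<usize>());
    let find = |m: &u32| modes_a.iter().position(|x| x == m).expect("mode missing from A");
    let pairs: Vec<(usize, usize)> = paired.iter().map(|(i, j)| (find(i), find(j))).collect();
    let pos: Vec<usize> = modes_c.iter().map(find).collect();
    let shape_c: Vec<usize> = pos.iter().map(|&p| shape_a[p]).collect();
    let mut idx = vec![0usize; shape_a.len()];
    for g in grad_a.iter_mut() {
        if pairs.iter().all(|&(p, q)| idx[p] == idx[q]) {
            let off = pos.iter().zip(&shape_c).fold(0, |o, (&p, &d)| o * d + idx[p]);
            *g = grad_c[off]; // write, not add
        }
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < shape_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
}
```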
ElementwiseUnary
Element-wise unary operation.
C[modes] = alpha * op(A[modes]) + beta * C[modes]
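A sketch of the elementwise formula. The two-variant `UnaryOp` here is a hypothetical stand-in (only `Reciprocal` is confirmed by the example that follows), and `elementwise_unary_ref` assumes A and C share one contiguous layout.

```rust
/// Hypothetical subset of unary operators; the real `UnaryOp` may differ.
#[derive(Clone, Copy)]
enum UnaryOp {
    Reciprocal,
    Neg,
}

impl UnaryOp {
    fn apply(self, x: f64) -> f64 {
        match self {
            UnaryOp::Reciprocal => 1.0 / x,
            UnaryOp::Neg => -x,
        }
    }
}

/// Hypothetical naive reference: C = alpha * op(A) + beta * C, element by element.
fn elementwise_unary_ref(op: UnaryOp, alpha: f64, a: &[f64], beta: f64, c: &mut [f64]) {
    assert_eq!(a.len(), c.len(), "A and C must have the same number of elements");
    for (ci, &ai) in c.iter_mut().zip(a) {
        *ci = alpha * op.apply(ai) + beta * *ci;
    }
}
```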
Examples
use tenferro_prims::{PrimDescriptor, UnaryOp};
// Reciprocal: C = 1/A
let desc = PrimDescriptor::ElementwiseUnary {
    op: UnaryOp::Reciprocal,
};
Contract
Fused contraction: permute + GEMM in one operation.
C[modes_c] = alpha * contract(A[modes_a], B[modes_b]) + beta * C[modes_c]
Available when has_extension_for::<T>(Extension::Contract) returns true.
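A naive einsum-style sketch of the contraction formula. `contract_ref` and `offset` are hypothetical; they assume row-major buffers and take the extent of each mode label from a `HashMap`. The code loops over every combination of mode indices, so it is only a correctness reference, not a substitute for the fused permute + GEMM path.

```rust
use std::collections::HashMap;

/// Row-major offset of a tensor whose axes are the loop axes listed in `pos`.
fn offset(pos: &[usize], dims: &[usize], idx: &[usize]) -> usize {
    pos.iter().fold(0, |o, &p| o * dims[p] + idx[p])
}

/// Hypothetical naive reference for Contract on row-major buffers:
/// C[modes_c] = alpha * sum(A[modes_a] * B[modes_b]) + beta * C[modes_c],
/// summing over every mode that is absent from modes_c.
fn contract_ref(
    extent: &HashMap<u32, usize>,
    modes_a: &[u32],
    a: &[f64],
    modes_b: &[u32],
    b: &[f64],
    modes_c: &[u32],
    alpha: f64,
    beta: f64,
    c: &mut [f64],
) {
    // Loop axes: output modes first, then modes that appear only in the inputs.
    let mut all: Vec<u32> = modes_c.to_vec();
    for &m in modes_a.iter().chain(modes_b) {
        if !all.contains(&m) {
            all.push(m);
        }
    }
    let dims: Vec<usize> = all.iter().map(|m| extent[m]).collect(); // panics on unknown mode
    let axes = |modes: &[u32]| -> Vec<usize> {
        modes.iter().map(|m| all.iter().position(|x| x == m).unwrap()).collect()
    };
    let (pa, pb, pc) = (axes(modes_a), axes(modes_b), axes(modes_c));
    assert_eq!(c.len(), pc.iter().map(|&p| dims[p]).product::<usize>());
    let mut acc = vec![0.0; c.len()];
    let mut idx = vec![0usize; all.len()];
    for _ in 0..dims.iter().product::<usize>() {
        acc[offset(&pc, &dims, &idx)] += a[offset(&pa, &dims, &idx)] * b[offset(&pb, &dims, &idx)];
        // Advance the loop multi-index in row-major order.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < dims[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    for (ci, r) in c.iter_mut().zip(acc) {
        *ci = alpha * r + beta * *ci;
    }
}
```

Called with the modes from the matrix-multiplication example at the top of this page, it computes an ordinary 2×2 matrix product; with `modes_a = modes_b = modes_c` it computes an elementwise product instead.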
ElementwiseMul
Element-wise multiplication of two tensors.
Available when has_extension_for::<T>(Extension::ElementwiseMul) returns true.
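A one-function reference for the same-shape case. `elementwise_mul_ref` is hypothetical, and because no formula is given for ElementwiseMul it applies no alpha/beta scaling. In mode notation this is the contraction in which every mode appears in A, B, and C.

```rust
/// Hypothetical naive reference for ElementwiseMul: C[i] = A[i] * B[i]
/// for two tensors with identical shape and layout.
fn elementwise_mul_ref(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "operands must have the same number of elements");
    a.iter().zip(b).map(|(x, y)| x * y).collect()
}
```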