Enum PrimDescriptor

pub enum PrimDescriptor {
    BatchedGemm {
        batch_dims: Vec<usize>,
        m: usize,
        n: usize,
        k: usize,
    },
    Reduce {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        op: ReduceOp,
    },
    Trace {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        paired: Vec<(u32, u32)>,
    },
    Permute {
        modes_a: Vec<u32>,
        modes_b: Vec<u32>,
    },
    AntiTrace {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        paired: Vec<(u32, u32)>,
    },
    AntiDiag {
        modes_a: Vec<u32>,
        modes_c: Vec<u32>,
        paired: Vec<(u32, u32)>,
    },
    ElementwiseUnary {
        op: UnaryOp,
    },
    Contract {
        modes_a: Vec<u32>,
        modes_b: Vec<u32>,
        modes_c: Vec<u32>,
    },
    ElementwiseMul,
}

Describes a tensor primitive operation.

All operations follow the cuTENSOR pattern: describe → plan → execute. Core operations must be supported by every backend; extended operations are optional and must be queried at runtime via TensorPrims::has_extension_for before use.
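
The three phases can be illustrated with a small self-contained toy model. Every name below (ToyDescriptor, ToyPlan, ToyBackend, MinimalBackend, supports) is hypothetical and is not the tenferro_prims API. The sketch only mirrors the idea: query capability, build a plan once, then execute it. It assumes an alpha/beta-scaled elementwise kernel for brevity.

```rust
/// Hypothetical toy descriptor: one core op and one extended op.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ToyDescriptor {
    /// Core op every backend must support: C = alpha * A + beta * C.
    Scale,
    /// Extended op that a backend may lack: C = alpha * (A * B) + beta * C.
    ElementwiseMul,
}

/// Hypothetical plan: a validated descriptor ready to execute repeatedly.
#[derive(Debug)]
struct ToyPlan {
    desc: ToyDescriptor,
}

trait ToyBackend {
    /// Runtime capability query (the role has_extension_for plays for extended ops).
    fn supports(&self, desc: ToyDescriptor) -> bool;

    /// Plan phase: reject unsupported descriptors before any data is touched.
    fn plan(&self, desc: ToyDescriptor) -> Result<ToyPlan, String> {
        if self.supports(desc) {
            Ok(ToyPlan { desc })
        } else {
            Err(format!("{:?} is not supported by this backend", desc))
        }
    }

    /// Execute phase: C = alpha * op(A, B) + beta * C, elementwise.
    fn execute(&self, plan: &ToyPlan, alpha: f64, a: &[f64], b: &[f64], beta: f64, c: &mut [f64]) {
        for i in 0..c.len() {
            let v = match plan.desc {
                ToyDescriptor::Scale => a[i],
                ToyDescriptor::ElementwiseMul => a[i] * b[i],
            };
            c[i] = alpha * v + beta * c[i];
        }
    }
}

/// A backend that implements only the core operation.
struct MinimalBackend;

impl ToyBackend for MinimalBackend {
    fn supports(&self, desc: ToyDescriptor) -> bool {
        matches!(desc, ToyDescriptor::Scale)
    }
}
```

In this sketch an unsupported extended operation is rejected at plan time, so the execute phase never has to handle a missing kernel.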

Modes are u32 integer labels matching cuTENSOR conventions. Modes shared between input and output tensors are batch/free dimensions; modes present only in inputs are contracted.
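
The labeling rule can be made concrete with a small helper. classify_modes is hypothetical (not part of tenferro_prims). It splits the labels of a binary operation into batch modes (in A, B and C), free modes (in C and exactly one input), and contracted modes (absent from C).

```rust
use std::collections::BTreeSet;

/// Hypothetical helper: returns (batch, free, contracted) mode labels, each sorted.
fn classify_modes(modes_a: &[u32], modes_b: &[u32], modes_c: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
    let a: BTreeSet<u32> = modes_a.iter().copied().collect();
    let b: BTreeSet<u32> = modes_b.iter().copied().collect();
    let c: BTreeSet<u32> = modes_c.iter().copied().collect();
    let (mut batch, mut free, mut contracted) = (Vec::new(), Vec::new(), Vec::new());
    for &m in a.union(&b) {
        match (a.contains(&m) && b.contains(&m), c.contains(&m)) {
            // Shared by both inputs and kept in the output.
            (true, true) => batch.push(m),
            // Kept in the output but owned by only one input.
            (false, true) => free.push(m),
            // Present only in the inputs: summed over.
            (_, false) => contracted.push(m),
        }
    }
    (batch, free, contracted)
}
```

For the matrix product in the example below (A: [0, 1], B: [1, 2], C: [0, 2]), mode 1 is contracted while modes 0 and 2 are free; prefixing a shared label to all three lists turns it into a batch mode.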

Examples

use tenferro_prims::PrimDescriptor;

// Matrix multiplication: C_{m,n} = A_{m,k} * B_{k,n}
let desc = PrimDescriptor::Contract {
    modes_a: vec![0, 1],  // m=0, k=1
    modes_b: vec![1, 2],  // k=1, n=2
    modes_c: vec![0, 2],  // m=0, n=2
};

Variants

BatchedGemm

Batched matrix multiplication.

C[batch, m, n] = alpha * sum_k A[batch, m, k] * B[batch, k, n] + beta * C[batch, m, n]
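
A naive reference pins down the formula above. It is a sketch under assumptions not stated by the descriptor: alpha and beta are passed in separately, all batch dimensions are flattened into one loop, and A, B and C are dense row-major buffers. batched_gemm_ref is a hypothetical name.

```rust
/// Naive reference: C[p, i, j] = alpha * sum_l A[p, i, l] * B[p, l, j] + beta * C[p, i, j].
fn batched_gemm_ref(
    batch_dims: &[usize],
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a: &[f64],
    b: &[f64],
    beta: f64,
    c: &mut [f64],
) {
    // Total number of batches; an empty batch_dims means a single GEMM.
    let batch: usize = batch_dims.iter().product();
    assert_eq!(a.len(), batch * m * k);
    assert_eq!(b.len(), batch * k * n);
    assert_eq!(c.len(), batch * m * n);
    for p in 0..batch {
        let (a_off, b_off, c_off) = (p * m * k, p * k * n, p * m * n);
        for i in 0..m {
            for j in 0..n {
                // Dot product of row i of A_p with column j of B_p.
                let mut acc = 0.0;
                for l in 0..k {
                    acc += a[a_off + i * k + l] * b[b_off + l * n + j];
                }
                let idx = c_off + i * n + j;
                c[idx] = alpha * acc + beta * c[idx];
            }
        }
    }
}
```

A production backend would call an optimized GEMM kernel instead; the loop nest only fixes which index is summed and how alpha and beta apply.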

Fields

batch_dims: Vec<usize>

Batch dimension sizes.

m: usize

Number of rows in A / C.

n: usize

Number of columns in B / C.

k: usize

Contraction dimension (columns of A / rows of B).

Reduce

Reduction over modes not present in the output.

C[modes_c] = alpha * reduce_op(A[modes_a]) + beta * C[modes_c]
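
The reduction semantics can be sketched as a loop over every element of A that accumulates into the output position selected by the kept modes. This is a hypothetical reference: ToyReduceOp stands in for ReduceOp, and the sketch assumes a dense row-major A whose extents dims_a are listed in modes_a order.

```rust
/// Hypothetical stand-in for ReduceOp.
#[derive(Clone, Copy)]
enum ToyReduceOp {
    Sum,
    Max,
    Min,
}

/// Naive reference for C[modes_c] = alpha * reduce_op(A[modes_a]) + beta * C[modes_c].
/// Labels of modes_a that are missing from modes_c are reduced away.
fn reduce_ref(
    modes_a: &[u32],
    dims_a: &[usize],
    modes_c: &[u32],
    op: ToyReduceOp,
    alpha: f64,
    a: &[f64],
    beta: f64,
    c: &mut [f64],
) {
    assert_eq!(a.len(), dims_a.iter().product::<usize>());
    // Axis of A feeding each output axis, and the resulting output extents.
    let pos: Vec<usize> = modes_c
        .iter()
        .map(|m| modes_a.iter().position(|x| x == m).expect("output mode missing from A"))
        .collect();
    let dims_c: Vec<usize> = pos.iter().map(|&p| dims_a[p]).collect();
    assert_eq!(c.len(), dims_c.iter().product::<usize>());
    let init = match op {
        ToyReduceOp::Sum => 0.0,
        ToyReduceOp::Max => f64::NEG_INFINITY,
        ToyReduceOp::Min => f64::INFINITY,
    };
    let mut acc = vec![init; c.len()];
    let mut idx = vec![0usize; dims_a.len()];
    for &x in a {
        // Row-major offset into C computed from A's current multi-index.
        let off = pos.iter().zip(&dims_c).fold(0, |o, (&p, &d)| o * d + idx[p]);
        acc[off] = match op {
            ToyReduceOp::Sum => acc[off] + x,
            ToyReduceOp::Max => acc[off].max(x),
            ToyReduceOp::Min => acc[off].min(x),
        };
        // Advance A's multi-index, last axis fastest.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < dims_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    for (ci, r) in c.iter_mut().zip(acc) {
        *ci = alpha * r + beta * *ci;
    }
}
```

With a 2 x 3 matrix labeled [0, 1], modes_c = [0] yields row sums, modes_c = [1] with Max yields column maxima, and an empty modes_c reduces to a single scalar.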

Fields

modes_a: Vec<u32>

Mode labels for input tensor A.

modes_c: Vec<u32>

Mode labels for output tensor C (subset of modes_a).

op: ReduceOp

Reduction operation (Sum, Max, Min).

Trace

Trace: contraction of paired diagonal modes.

For each pair (i, j) in paired, sums over the diagonal on which the indices of modes i and j are equal.
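
A hypothetical reference makes the diagonal condition explicit: an element of A contributes only when its indices agree on every pair. The sketch assumes a dense row-major A with extents dims_a, paired modes of equal extent, and a modes_c that lists the surviving (unpaired) modes; alpha and beta scaling are omitted for brevity.

```rust
/// Naive reference for Trace with alpha = 1 and beta = 0.
fn trace_ref(
    modes_a: &[u32],
    dims_a: &[usize],
    modes_c: &[u32],
    paired: &[(u32, u32)],
    a: &[f64],
) -> Vec<f64> {
    let find = |m: u32| modes_a.iter().position(|&x| x == m).expect("mode not in A");
    let pos_c: Vec<usize> = modes_c.iter().map(|&m| find(m)).collect();
    let pairs: Vec<(usize, usize)> = paired.iter().map(|&(i, j)| (find(i), find(j))).collect();
    let dims_c: Vec<usize> = pos_c.iter().map(|&p| dims_a[p]).collect();
    let len_c: usize = dims_c.iter().product();
    let mut c = vec![0.0; len_c];
    let mut idx = vec![0usize; dims_a.len()];
    for &x in a {
        // Only elements on the diagonal of every pair contribute to the sum.
        if pairs.iter().all(|&(p, q)| idx[p] == idx[q]) {
            let off = pos_c.iter().zip(&dims_c).fold(0, |o, (&p, &d)| o * d + idx[p]);
            c[off] += x;
        }
        // Advance A's multi-index, last axis fastest.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < dims_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    c
}
```

Pairing modes (0, 1) of a square matrix with an empty modes_c gives the ordinary matrix trace; keeping a leading mode gives one trace per batch entry.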

Fields

modes_a: Vec<u32>

Mode labels for input tensor A.

modes_c: Vec<u32>

Mode labels for output tensor C.

paired: Vec<(u32, u32)>

Pairs of modes to trace over.

Permute

Permute (reorder) tensor modes.

B[modes_b] = alpha * A[modes_a]
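
The relabeling can be sketched as a scatter: each element of A is written to the offset its multi-index has under B's mode order. permute_ref is hypothetical and assumes dense row-major storage with dims_a given in modes_a order.

```rust
/// Naive reference for B[modes_b] = alpha * A[modes_a].
fn permute_ref(modes_a: &[u32], dims_a: &[usize], modes_b: &[u32], alpha: f64, a: &[f64]) -> Vec<f64> {
    // perm[k] is the axis of A that becomes axis k of B.
    let perm: Vec<usize> = modes_b
        .iter()
        .map(|m| modes_a.iter().position(|x| x == m).expect("mode not in A"))
        .collect();
    let dims_b: Vec<usize> = perm.iter().map(|&p| dims_a[p]).collect();
    let mut b = vec![0.0; a.len()];
    let mut idx = vec![0usize; dims_a.len()];
    for &x in a {
        // Row-major offset of the same element in B's layout.
        let off = perm.iter().zip(&dims_b).fold(0, |o, (&p, &d)| o * d + idx[p]);
        b[off] = alpha * x;
        // Advance A's multi-index, last axis fastest.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < dims_a[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    b
}
```

Swapping the labels of a 2 x 3 matrix ([0, 1] to [1, 0]) produces its 3 x 2 transpose.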

Fields

modes_a: Vec<u32>

Mode labels for input tensor A.

modes_b: Vec<u32>

Mode labels for output tensor B (same labels, different order).

AntiTrace

Anti-trace: scatter-add gradient to diagonal (AD backward of trace).
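
For the matrix case, t = sum_i X[i, i] has dt/dX[i, j] = 1 when i == j and 0 otherwise, so the backward pass adds the incoming gradient to every diagonal entry. The sketch below is hypothetical: it assumes a batch of n x n row-major matrices, one incoming scalar gradient per matrix, and an existing gradient buffer to accumulate into.

```rust
/// Hypothetical AntiTrace sketch for a batch of n x n row-major matrices.
/// g[p] is the incoming gradient of trace(X_p); it is added to every diagonal
/// entry of the matching block of dx, and off-diagonal entries are left unchanged.
fn anti_trace_batched(g: &[f64], n: usize, dx: &mut [f64]) {
    assert_eq!(dx.len(), g.len() * n * n);
    for (p, &gp) in g.iter().enumerate() {
        for i in 0..n {
            // Scatter-add: accumulate instead of overwrite.
            dx[p * n * n + i * n + i] += gp;
        }
    }
}
```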

Fields

modes_a: Vec<u32>

Mode labels for input tensor A.

modes_c: Vec<u32>

Mode labels for output tensor C.

paired: Vec<(u32, u32)>

Pairs of modes for diagonal scatter.

AntiDiag

Anti-diag: write gradient to diagonal positions (AD backward of diag).
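
For the matrix case, extracting v[i] = X[i, i] means the gradient with respect to X is zero off the diagonal and equals dv[i] at position (i, i). Unlike the trace case, each diagonal slot receives its own value. The function below is a hypothetical sketch that returns a fresh n x n row-major buffer.

```rust
/// Hypothetical AntiDiag sketch: builds dX from the gradient dv of v = diag(X).
fn anti_diag_matrix(dv: &[f64]) -> Vec<f64> {
    let n = dv.len();
    let mut dx = vec![0.0; n * n];
    for (i, &g) in dv.iter().enumerate() {
        // Write (not add) each gradient entry to its diagonal slot.
        dx[i * n + i] = g;
    }
    dx
}
```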

Fields

modes_a: Vec<u32>

Mode labels for input tensor A.

modes_c: Vec<u32>

Mode labels for output tensor C.

paired: Vec<(u32, u32)>

Pairs of modes for diagonal write.

ElementwiseUnary

Element-wise unary operation.

C[modes] = alpha * op(A[modes]) + beta * C[modes]

Examples

use tenferro_prims::{PrimDescriptor, UnaryOp};

// Reciprocal: C = 1/A
let desc = PrimDescriptor::ElementwiseUnary {
    op: UnaryOp::Reciprocal,
};
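
The alpha/beta form above can be sketched on dense buffers of equal shape. ToyUnary is a hypothetical stand-in for UnaryOp: only Reciprocal comes from this page, and Neg is added purely for illustration.

```rust
/// Hypothetical stand-in for UnaryOp; the real enum's variants may differ.
#[derive(Clone, Copy)]
enum ToyUnary {
    Reciprocal,
    Neg,
}

fn apply(op: ToyUnary, x: f64) -> f64 {
    match op {
        ToyUnary::Reciprocal => 1.0 / x,
        ToyUnary::Neg => -x,
    }
}

/// Reference for C[modes] = alpha * op(A[modes]) + beta * C[modes].
fn elementwise_unary_ref(op: ToyUnary, alpha: f64, a: &[f64], beta: f64, c: &mut [f64]) {
    assert_eq!(a.len(), c.len());
    for (ci, &ai) in c.iter_mut().zip(a) {
        *ci = alpha * apply(op, ai) + beta * *ci;
    }
}
```

Setting beta = 0 overwrites C; a nonzero beta blends the result into the existing contents.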

Fields

op: UnaryOp

Unary operation to apply.

Contract

Fused contraction: permute + GEMM in one operation.

C[modes_c] = alpha * contract(A[modes_a], B[modes_b]) + beta * C

Available when has_extension_for::<T>(Extension::Contract) returns true.
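
The semantics of a labeled contraction can be pinned down with a naive loop over every distinct label. This is a hypothetical reference, not a fused kernel. It assumes dense row-major operands, an extent map giving the size of every label, and labels that appear at most once per operand.

```rust
use std::collections::HashMap;

/// Naive reference for C[modes_c] = alpha * contract(A[modes_a], B[modes_b]) + beta * C.
fn contract_ref(
    modes_a: &[u32],
    modes_b: &[u32],
    modes_c: &[u32],
    extent: &HashMap<u32, usize>,
    alpha: f64,
    a: &[f64],
    b: &[f64],
    beta: f64,
    c: &mut [f64],
) {
    // Every distinct label (batch, free and contracted) gets one loop axis.
    let mut all: Vec<u32> = Vec::new();
    for &m in modes_a.iter().chain(modes_b).chain(modes_c) {
        if !all.contains(&m) {
            all.push(m);
        }
    }
    let dims: Vec<usize> = all.iter().map(|m| extent[m]).collect();
    // Row-major offset of an operand, given the current value of every label.
    let offset = |modes: &[u32], cur: &[usize]| -> usize {
        modes.iter().fold(0, |o, m| {
            let p = all.iter().position(|x| x == m).unwrap();
            o * dims[p] + cur[p]
        })
    };
    let len_c: usize = modes_c.iter().map(|m| extent[m]).product();
    assert_eq!(c.len(), len_c);
    let mut acc = vec![0.0; len_c];
    let total: usize = dims.iter().product();
    let mut idx = vec![0usize; all.len()];
    for _ in 0..total {
        let cur = &idx[..];
        // Labels absent from modes_c are summed over implicitly.
        acc[offset(modes_c, cur)] += a[offset(modes_a, cur)] * b[offset(modes_b, cur)];
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < dims[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    for (ci, r) in c.iter_mut().zip(acc) {
        *ci = alpha * r + beta * *ci;
    }
}
```

With the labels from the matrix-multiplication example at the top of this page, this reproduces an ordinary matrix product; labels [0], [0], [] give a dot product.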

Fields

modes_a: Vec<u32>

Mode labels for input tensor A.

modes_b: Vec<u32>

Mode labels for input tensor B.

modes_c: Vec<u32>

Mode labels for output tensor C.

ElementwiseMul

Element-wise multiplication of two tensors.

Available when has_extension_for::<T>(Extension::ElementwiseMul) returns true.
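
This variant carries no fields, so its exact contract is not documented here. The hypothetical sketch below assumes the simplest reading, C[i] = A[i] * B[i] on dense buffers of identical shape, and deliberately omits scaling factors and broadcasting.

```rust
/// Hypothetical reference for ElementwiseMul on equally shaped dense buffers.
fn elementwise_mul_ref(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "operands must have the same number of elements");
    a.iter().zip(b).map(|(x, y)| x * y).collect()
}
```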

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.