Crate tenferro_prims

Tensor primitive operations for the tenferro workspace.

This crate defines the TensorPrims<A> trait, a backend-agnostic interface parameterized by algebra A. The API follows the cuTENSOR plan-based execution pattern:

  1. Create a PrimDescriptor specifying the operation and index modes
  2. Build a plan via TensorPrims::plan (pre-computes kernel selection)
  3. Execute the plan via TensorPrims::execute
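The three steps can be sketched without the crate itself. The block below is a minimal, self-contained illustration of the descriptor → plan → execute flow. `Descriptor`, `Plan`, `plan`, and `execute` are hypothetical stand-ins written for this sketch, not the tenferro_prims types or signatures, and the column-major layout is an assumption.

```rust
// Hypothetical stand-ins for illustration only; not the tenferro_prims API.

/// Step 1: a descriptor states *what* to compute (here, C = alpha*A*B + beta*C).
#[derive(Debug, Clone, PartialEq)]
pub enum Descriptor {
    Gemm { m: usize, n: usize, k: usize },
}

/// Step 2: a plan is a validated descriptor; a real backend would also
/// store its kernel choice here so that execution does no further analysis.
#[derive(Debug)]
pub enum Plan {
    Gemm { m: usize, n: usize, k: usize },
}

/// Validate operand shapes once, up front, and build the plan.
pub fn plan(desc: &Descriptor, shapes: &[&[usize]]) -> Result<Plan, String> {
    match *desc {
        Descriptor::Gemm { m, n, k } => {
            let expected = [vec![m, k], vec![k, n], vec![m, n]];
            let mismatch = shapes.len() != expected.len()
                || shapes.iter().zip(&expected).any(|(s, e)| *s != e.as_slice());
            if mismatch {
                return Err(format!("shape mismatch for {:?}", desc));
            }
            Ok(Plan::Gemm { m, n, k })
        }
    }
}

/// Step 3: execute the plan on column-major buffers.
pub fn execute(
    plan: &Plan,
    alpha: f64,
    a: &[f64],
    b: &[f64],
    beta: f64,
    c: &mut [f64],
) -> Result<(), String> {
    match *plan {
        Plan::Gemm { m, n, k } => {
            if a.len() != m * k || b.len() != k * n || c.len() != m * n {
                return Err("buffer length does not match plan".to_string());
            }
            for j in 0..n {
                for i in 0..m {
                    // Inner product of row i of A with column j of B.
                    let mut acc = 0.0;
                    for p in 0..k {
                        acc += a[i + p * m] * b[p + j * k];
                    }
                    c[i + j * m] = alpha * acc + beta * c[i + j * m];
                }
            }
            Ok(())
        }
    }
}
```

Separating `plan` from `execute` lets shape validation and kernel selection be paid once when the same operation is run many times on different data.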

§Operation categories

Core operations (every backend must implement):

  • BatchedGemm: Batched matrix multiplication
  • Reduce: Sum/max/min reduction over modes
  • Trace: Trace (contraction of paired diagonal modes)
  • Permute: Mode reordering
  • AntiTrace: Scatter-add to diagonal (AD backward of trace)
  • AntiDiag: Write to diagonal positions (AD backward of diag)
  • ElementwiseUnary: Point-wise unary transform (negate, reciprocal, abs, sqrt)
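To make the Trace/AntiTrace pairing concrete, here is a small sketch on plain column-major slices. The function names `trace` and `anti_trace` and their signatures are assumptions chosen for illustration and do not come from the crate. Since ∂tr(A)/∂A is the identity matrix, the backward pass adds the upstream gradient to every diagonal entry.

```rust
/// Trace of an n×n column-major matrix: the sum of its diagonal entries.
pub fn trace(a: &[f64], n: usize) -> f64 {
    assert_eq!(a.len(), n * n);
    (0..n).map(|i| a[i + i * n]).sum()
}

/// Backward of `trace`: scatter-add the scalar upstream gradient `g`
/// onto the diagonal of `grad_a`, leaving off-diagonal entries untouched.
pub fn anti_trace(g: f64, grad_a: &mut [f64], n: usize) {
    assert_eq!(grad_a.len(), n * n);
    for i in 0..n {
        grad_a[i + i * n] += g;
    }
}
```

The update is an accumulation (`+=`) rather than an overwrite, which matches the "scatter-add" wording above: gradients from other uses of the same tensor are preserved.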

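Extended operations are optional and are queried dynamically via TensorPrims::has_extension_for. Each is identified by a variant of the Extension enum, such as Extension::Contract.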

§Algebra parameterization

TensorPrims<A> is parameterized by algebra A (e.g., Standard, MaxPlus). External crates can implement TensorPrims<MyAlgebra> for CpuBackend; the orphan rule permits this because MyAlgebra is local to the implementing crate. The HasAlgebra trait on scalar types enables automatic inference of the algebra: for example, Tensor<f64> → Standard.
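The idea of swapping the algebra under a fixed kernel can be sketched as follows. Everything in this block is a hypothetical stand-in: the `Semiring` trait, the local `Standard` and `MaxPlus` marker types, and this `HasAlgebra` definition only mirror the names used above and are not the tenferro_algebra definitions. MaxPlus is assumed here to be the usual tropical semiring, where "addition" is max and "multiplication" is +.

```rust
// Hypothetical stand-ins mirroring the names above; not the crate's types.

/// Minimal algebra interface: an additive identity and two binary operations.
pub trait Semiring {
    fn zero() -> f64;
    fn add(x: f64, y: f64) -> f64;
    fn mul(x: f64, y: f64) -> f64;
}

/// Ordinary arithmetic (+, *).
pub struct Standard;
/// Tropical max-plus arithmetic (max, +).
pub struct MaxPlus;

impl Semiring for Standard {
    fn zero() -> f64 { 0.0 }
    fn add(x: f64, y: f64) -> f64 { x + y }
    fn mul(x: f64, y: f64) -> f64 { x * y }
}

impl Semiring for MaxPlus {
    fn zero() -> f64 { f64::NEG_INFINITY }
    fn add(x: f64, y: f64) -> f64 { x.max(y) }
    fn mul(x: f64, y: f64) -> f64 { x + y }
}

/// Maps a scalar type to its default algebra.
pub trait HasAlgebra {
    type Algebra: Semiring;
}

impl HasAlgebra for f64 {
    type Algebra = Standard;
}

/// One kernel, many algebras: a dot product generic over `A`.
pub fn dot<A: Semiring>(x: &[f64], y: &[f64]) -> f64 {
    x.iter()
        .zip(y)
        .fold(A::zero(), |acc, (&a, &b)| A::add(acc, A::mul(a, b)))
}

/// Picks the algebra from the scalar type, as `HasAlgebra` inference would.
pub fn dot_default<T: HasAlgebra>(x: &[f64], y: &[f64]) -> f64 {
    dot::<T::Algebra>(x, y)
}
```

With x = [1, 2, 3] and y = [4, 5, 6], `dot::<Standard>` gives 32 and `dot::<MaxPlus>` gives 9 (the largest of 1+4, 2+5, 3+6).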

§Examples

§Plan-based GEMM

use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
use tenferro_algebra::Standard;
use strided_view::StridedArray;

let a = StridedArray::<f64>::col_major(&[3, 4]);
let b = StridedArray::<f64>::col_major(&[4, 5]);
let mut c = StridedArray::<f64>::col_major(&[3, 5]);

let desc = PrimDescriptor::BatchedGemm {
    batch_dims: vec![], m: 3, n: 5, k: 4,
};
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();

§Reduction (sum over an axis)

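use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
use strided_view::StridedArray;

// Input A has shape [3, 4]; the output c has shape [3].
let a = StridedArray::<f64>::col_major(&[3, 4]);
let mut c = StridedArray::<f64>::col_major(&[3]);

// Sum over the column index j: c_i = Σ_j A_{i,j}
let desc = PrimDescriptor::Reduce {
    modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
};
let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
CpuBackend::execute(&plan, 1.0, &[&a.view()], 0.0, &mut c.view_mut()).unwrap();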
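For reference, the arithmetic that a reduction over mode 1 performs can be written directly on a column-major slice. This is a sketch only: `Op` and `reduce_mode1` are hypothetical names introduced here, not crate items, and no claim is made about how CpuBackend implements Reduce internally.

```rust
/// Reduction kinds, mirroring the Sum/max/min list above (hypothetical enum).
#[derive(Clone, Copy)]
pub enum Op {
    Sum,
    Max,
    Min,
}

/// Reduce a rows×cols column-major matrix over its second mode:
/// out[i] = op over j of a[i, j], where a[i, j] is stored at a[i + j * rows].
pub fn reduce_mode1(a: &[f64], rows: usize, cols: usize, op: Op) -> Vec<f64> {
    assert_eq!(a.len(), rows * cols);
    (0..rows)
        .map(|i| {
            // Walk row i with stride `rows`.
            let row = (0..cols).map(|j| a[i + j * rows]);
            match op {
                Op::Sum => row.sum::<f64>(),
                Op::Max => row.fold(f64::NEG_INFINITY, f64::max),
                Op::Min => row.fold(f64::INFINITY, f64::min),
            }
        })
        .collect()
}
```

For the 2×3 matrix with rows (1, 2, 3) and (4, 5, 6), stored as [1, 4, 2, 5, 3, 6], the sum reduction yields [6, 15].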

§Dynamic extension check

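use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, Extension};
use strided_view::StridedArray;

let a = StridedArray::<f64>::col_major(&[3, 4]);
let b = StridedArray::<f64>::col_major(&[4, 5]);
let mut c = StridedArray::<f64>::col_major(&[3, 5]);

// Only build a Contract plan if the backend reports support for it.
if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
    // C_{i,k} = Σ_j A_{i,j} B_{j,k}: mode 1 is shared and summed over.
    let desc = PrimDescriptor::Contract {
        modes_a: vec![0, 1], modes_b: vec![1, 2], modes_c: vec![0, 2],
    };
    let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
    CpuBackend::execute(
        &plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut(),
    ).unwrap();
}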

§Structs

CpuBackend
CPU backend using strided-kernel and GEMM.

§Enums

CpuPlan
CPU plan — concrete enum, no type erasure.
Extension
Extended operation identifiers for dynamic capability query.
PrimDescriptor
Describes a tensor primitive operation.
ReduceOp
Reduction operation kind.
UnaryOp
Element-wise unary operation kind.

§Traits

TensorPrims
Backend trait for tensor primitive operations, parameterized by algebra A.