tenferro_prims/lib.rs
//! Tensor primitive operations for the tenferro workspace.
//!
//! This crate defines the [`TensorPrims<A>`] trait, a backend-agnostic interface
//! parameterized by algebra `A`. The API follows the cuTENSOR plan-based execution
//! pattern:
//!
//! 1. Create a [`PrimDescriptor`] specifying the operation and index modes
//! 2. Build a plan via [`TensorPrims::plan`] (pre-computes kernel selection)
//! 3. Execute the plan via [`TensorPrims::execute`]
//!
//! # Operation categories
//!
//! **Core operations** (every backend must implement):
//! - [`BatchedGemm`](PrimDescriptor::BatchedGemm): Batched matrix multiplication
//! - [`Reduce`](PrimDescriptor::Reduce): Sum/max/min reduction over modes
//! - [`Trace`](PrimDescriptor::Trace): Trace (contraction of paired diagonal modes)
//! - [`Permute`](PrimDescriptor::Permute): Mode reordering
//! - [`AntiTrace`](PrimDescriptor::AntiTrace): Scatter-add to diagonal (AD backward of trace)
//! - [`AntiDiag`](PrimDescriptor::AntiDiag): Write to diagonal positions (AD backward of diag)
//! - [`ElementwiseUnary`](PrimDescriptor::ElementwiseUnary): Point-wise unary transform (negate, reciprocal, abs, sqrt)
//!
//! **Extended operations** (dynamically queried via [`TensorPrims::has_extension_for`]):
//! - [`Contract`](PrimDescriptor::Contract): Fused permute + GEMM contraction
//! - [`ElementwiseMul`](PrimDescriptor::ElementwiseMul): Element-wise multiplication
//!
//! # Algebra parameterization
//!
//! [`TensorPrims<A>`] is parameterized by algebra `A` (e.g., [`Standard`], `MaxPlus`).
//! External crates implement `TensorPrims<MyAlgebra> for CpuBackend` (orphan-rule
//! compatible). The [`HasAlgebra`](tenferro_algebra::HasAlgebra) trait on scalar types
//! enables automatic inference: `Tensor<f64>` → `Standard`.
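//!
//! A sketch of wiring in a custom algebra from a downstream crate (the
//! `MaxPlus` marker and `MaxPlusPlan` plan type are hypothetical; real method
//! bodies are backend-specific):
//!
//! ```ignore
//! use std::marker::PhantomData;
//!
//! use strided_traits::ScalarBase;
//! use strided_view::{StridedView, StridedViewMut};
//! use tenferro_device::Result;
//! use tenferro_prims::{CpuBackend, Extension, PrimDescriptor, TensorPrims};
//!
//! /// Hypothetical max-plus (tropical) algebra marker. It is local to the
//! /// downstream crate, which is what keeps the impl orphan-rule compatible.
//! struct MaxPlus;
//!
//! /// Hypothetical plan type for the max-plus kernels.
//! struct MaxPlusPlan<T>(PhantomData<T>);
//!
//! impl TensorPrims<MaxPlus> for CpuBackend {
//!     type Plan<T: ScalarBase> = MaxPlusPlan<T>;
//!
//!     fn plan<T: ScalarBase>(
//!         desc: &PrimDescriptor,
//!         shapes: &[&[usize]],
//!     ) -> Result<Self::Plan<T>> {
//!         todo!("select a max-plus kernel for `desc`")
//!     }
//!
//!     fn execute<T: ScalarBase>(
//!         plan: &Self::Plan<T>,
//!         alpha: T,
//!         inputs: &[&StridedView<T>],
//!         beta: T,
//!         output: &mut StridedViewMut<T>,
//!     ) -> Result<()> {
//!         todo!("run the kernel with (⊕, ⊗) = (max, +)")
//!     }
//!
//!     fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
//!         false // no fused extended ops for the tropical semiring
//!     }
//! }
//! ```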
//!
//! # Examples
//!
//! ## Plan-based GEMM
//!
//! ```ignore
//! use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
//! use tenferro_algebra::Standard;
//! use strided_view::StridedArray;
//!
//! let a = StridedArray::<f64>::col_major(&[3, 4]);
//! let b = StridedArray::<f64>::col_major(&[4, 5]);
//! let mut c = StridedArray::<f64>::col_major(&[3, 5]);
//!
//! let desc = PrimDescriptor::BatchedGemm {
//!     batch_dims: vec![], m: 3, n: 5, k: 4,
//! };
//! let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
//! CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
//! ```
//!
//! ## Reduction (sum over an axis)
//!
//! ```ignore
//! use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
//! use strided_view::StridedArray;
//!
//! let a = StridedArray::<f64>::col_major(&[3, 4]);
//! let mut c = StridedArray::<f64>::col_major(&[3]);
//!
//! // Sum over columns: c_i = Σ_j A_{i,j}
//! let desc = PrimDescriptor::Reduce {
//!     modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
//! };
//! let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
//! CpuBackend::execute(&plan, 1.0, &[&a.view()], 0.0, &mut c.view_mut()).unwrap();
//! ```
//!
//! ## Dynamic extension check
//!
//! ```ignore
//! use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, Extension};
//! use strided_view::StridedArray;
//!
//! let a = StridedArray::<f64>::col_major(&[3, 4]);
//! let b = StridedArray::<f64>::col_major(&[4, 5]);
//! let mut c = StridedArray::<f64>::col_major(&[3, 5]);
//!
//! if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
//!     let desc = PrimDescriptor::Contract {
//!         modes_a: vec![0, 1], modes_b: vec![1, 2], modes_c: vec![0, 2],
//!     };
//!     let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
//!     CpuBackend::execute(
//!         &plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut(),
//!     ).unwrap();
//! }
//! ```

use std::marker::PhantomData;

use strided_traits::ScalarBase;
use strided_view::{StridedView, StridedViewMut};
use tenferro_algebra::Standard;
use tenferro_device::Result;

/// Reduction operation kind.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ReduceOp;
///
/// let op = ReduceOp::Sum;
/// assert_eq!(op, ReduceOp::Sum);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum reduction.
    Sum,
    /// Maximum value reduction.
    Max,
    /// Minimum value reduction.
    Min,
}

/// Element-wise unary operation kind.
///
/// Used with [`PrimDescriptor::ElementwiseUnary`] for point-wise
/// transformations. Maps to `cutensorElementwiseTrinary` (unary case)
/// on GPU.
///
/// Note: square (`x²`) is omitted — expressible as
/// `ElementwiseMul(x, x)` without an extra copy.
///
/// # Examples
///
/// ```
/// use tenferro_prims::UnaryOp;
///
/// let op = UnaryOp::Reciprocal;
/// assert_eq!(op, UnaryOp::Reciprocal);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Negate: `-x`.
    Negate,
    /// Reciprocal: `1 / x`.
    Reciprocal,
    /// Absolute value: `|x|`.
    Abs,
    /// Square root: `√x`.
    Sqrt,
}

/// Extended operation identifiers for dynamic capability query.
///
/// Used with [`TensorPrims::has_extension_for`] to check at runtime whether
/// a backend supports an optimized extended operation for a given scalar type.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::{CpuBackend, TensorPrims, Extension};
///
/// // Check if fused contraction is available for f64
/// let available = CpuBackend::has_extension_for::<f64>(Extension::Contract);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Extension {
    /// Fused contraction (permute + GEMM). Maps to `cutensorContract` on GPU.
    Contract,
    /// Element-wise multiplication. Maps to `cutensorElementwiseBinary` on GPU.
    ElementwiseMul,
}

/// Describes a tensor primitive operation.
///
/// All operations follow the cuTENSOR pattern: describe → plan → execute.
/// Core operations must be supported by every backend. Extended operations
/// are dynamically queried via [`TensorPrims::has_extension_for`].
///
/// Modes are `u32` integer labels matching cuTENSOR conventions. Modes
/// shared between input and output tensors are batch/free dimensions;
/// modes present only in inputs are contracted.
///
/// # Examples
///
/// ```
/// use tenferro_prims::PrimDescriptor;
///
/// // Matrix multiplication: C_{m,n} = A_{m,k} * B_{k,n}
/// let desc = PrimDescriptor::Contract {
///     modes_a: vec![0, 1], // m=0, k=1
///     modes_b: vec![1, 2], // k=1, n=2
///     modes_c: vec![0, 2], // m=0, n=2
/// };
/// ```
pub enum PrimDescriptor {
    // ====================================================================
    // Core operations (every backend must implement)
    // ====================================================================
    /// Batched matrix multiplication.
    ///
    /// `C[batch, m, n] = alpha * A[batch, m, k] * B[batch, k, n] + beta * C`
    BatchedGemm {
        /// Batch dimension sizes.
        batch_dims: Vec<usize>,
        /// Number of rows in A / C.
        m: usize,
        /// Number of columns in B / C.
        n: usize,
        /// Contraction dimension (columns of A / rows of B).
        k: usize,
    },

    /// Reduction over modes not present in the output.
    ///
    /// `C[modes_c] = alpha * reduce_op(A[modes_a]) + beta * C[modes_c]`
    Reduce {
        /// Mode labels for input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels for output tensor C (subset of modes_a).
        modes_c: Vec<u32>,
        /// Reduction operation (Sum, Max, Min).
        op: ReduceOp,
    },

    /// Trace: contraction of paired diagonal modes.
    ///
    /// For each pair `(i, j)`, sums over the diagonal where mode i == mode j.
    Trace {
        /// Mode labels for input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels for output tensor C.
        modes_c: Vec<u32>,
        /// Pairs of modes to trace over.
        paired: Vec<(u32, u32)>,
    },

    /// Permute (reorder) tensor modes.
    ///
    /// `B[modes_b] = alpha * A[modes_a]`
    Permute {
        /// Mode labels for input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels for output tensor B (same labels, different order).
        modes_b: Vec<u32>,
    },

    /// Anti-trace: scatter-add gradient to diagonal (AD backward of trace).
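    ///
    /// # Examples
    ///
    /// A sketch of the backward pass of the matrix trace above, assuming the
    /// mode convention mirrors [`Trace`](PrimDescriptor::Trace) with the
    /// input/output roles swapped (scalar cotangent in, matrix gradient out):
    ///
    /// ```
    /// use tenferro_prims::PrimDescriptor;
    ///
    /// // dA_{i,i} += dc: scatter the scalar cotangent onto the diagonal.
    /// let desc = PrimDescriptor::AntiTrace {
    ///     modes_a: vec![],
    ///     modes_c: vec![0, 1],
    ///     paired: vec![(0, 1)],
    /// };
    /// ```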
    AntiTrace {
        /// Mode labels for input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels for output tensor C.
        modes_c: Vec<u32>,
        /// Pairs of modes for diagonal scatter.
        paired: Vec<(u32, u32)>,
    },

    /// Anti-diag: write gradient to diagonal positions (AD backward of diag).
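    ///
    /// # Examples
    ///
    /// A sketch of the backward pass of a diagonal extraction, assuming the
    /// vector cotangent carries the diagonal's mode label:
    ///
    /// ```
    /// use tenferro_prims::PrimDescriptor;
    ///
    /// // dA_{i,i} = dv_i (off-diagonal entries are zero).
    /// let desc = PrimDescriptor::AntiDiag {
    ///     modes_a: vec![0],
    ///     modes_c: vec![0, 1],
    ///     paired: vec![(0, 1)],
    /// };
    /// ```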
    AntiDiag {
        /// Mode labels for input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels for output tensor C.
        modes_c: Vec<u32>,
        /// Pairs of modes for diagonal write.
        paired: Vec<(u32, u32)>,
    },

    /// Element-wise unary operation.
    ///
    /// `C[modes] = alpha * op(A[modes]) + beta * C[modes]`
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_prims::{PrimDescriptor, UnaryOp};
    ///
    /// // Reciprocal: C = 1/A
    /// let desc = PrimDescriptor::ElementwiseUnary {
    ///     op: UnaryOp::Reciprocal,
    /// };
    /// ```
    ElementwiseUnary {
        /// Unary operation to apply.
        op: UnaryOp,
    },

    // ====================================================================
    // Extended operations (dynamically queried)
    // ====================================================================
    /// Fused contraction: permute + GEMM in one operation.
    ///
    /// `C[modes_c] = alpha * contract(A[modes_a], B[modes_b]) + beta * C`
    ///
    /// Available when `has_extension_for::<T>(Extension::Contract)` returns true.
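    ///
    /// # Examples
    ///
    /// A descriptor for an ordinary matrix multiplication expressed as a
    /// fused contraction (the same encoding as the type-level example):
    ///
    /// ```
    /// use tenferro_prims::PrimDescriptor;
    ///
    /// // C_{m,n} = Σ_k A_{m,k} * B_{k,n}: mode 1 (k) appears only in inputs.
    /// let desc = PrimDescriptor::Contract {
    ///     modes_a: vec![0, 1],
    ///     modes_b: vec![1, 2],
    ///     modes_c: vec![0, 2],
    /// };
    /// ```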
    Contract {
        /// Mode labels for input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels for input tensor B.
        modes_b: Vec<u32>,
        /// Mode labels for output tensor C.
        modes_c: Vec<u32>,
    },

    /// Element-wise multiplication of two tensors.
    ///
    /// Available when `has_extension_for::<T>(Extension::ElementwiseMul)` returns true.
    ElementwiseMul,
}

/// Backend trait for tensor primitive operations, parameterized by algebra `A`.
///
/// Provides a cuTENSOR-compatible plan-based execution model for all
/// operations. Core ops (batched_gemm, reduce, trace, permute, anti_trace,
/// anti_diag, elementwise_unary) must be implemented. Extended ops (contract,
/// elementwise_mul) are dynamically queried via
/// [`has_extension_for`](TensorPrims::has_extension_for).
///
/// # Algebra parameterization
///
/// The algebra parameter `A` enables extensibility: external crates can
/// implement `TensorPrims<MyAlgebra> for CpuBackend` (orphan-rule compatible).
///
/// # Associated functions (not methods)
///
/// All functions are associated functions (no `&self` receiver). Call as
/// `CpuBackend::plan::<f64>(...)` instead of `backend.plan(...)`.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
/// use strided_view::StridedArray;
///
/// let a = StridedArray::<f64>::col_major(&[3, 4]);
/// let b = StridedArray::<f64>::col_major(&[4, 5]);
/// let mut c = StridedArray::<f64>::col_major(&[3, 5]);
///
/// // Plan a batched GEMM
/// let desc = PrimDescriptor::BatchedGemm {
///     batch_dims: vec![], m: 3, n: 5, k: 4,
/// };
/// let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
///
/// // Execute: C = 1.0 * A*B + 0.0 * C
/// CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
/// ```
pub trait TensorPrims<A> {
    /// Backend-specific plan type (no type erasure).
    type Plan<T: ScalarBase>;

    /// Create an execution plan from an operation descriptor.
    ///
    /// The plan pre-computes kernel selection and workspace sizes.
    /// `shapes` contains the shape of each tensor involved in the operation
    /// (inputs first, then output).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
    ///
    /// let desc = PrimDescriptor::Reduce {
    ///     modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
    /// };
    /// let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
    /// ```
    fn plan<T: ScalarBase>(desc: &PrimDescriptor, shapes: &[&[usize]]) -> Result<Self::Plan<T>>;

    /// Execute a plan with the given scaling factors and tensor views.
    ///
    /// Follows the BLAS/cuTENSOR pattern:
    /// `output = alpha * op(inputs) + beta * output`
    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Execute: output = 1.0 * gemm(a, b) + 0.0 * output (overwrite)
    /// CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
    ///
    /// // Accumulate: output = 1.0 * gemm(a, b) + 1.0 * output (add)
    /// CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 1.0, &mut c.view_mut()).unwrap();
    /// ```
    fn execute<T: ScalarBase>(
        plan: &Self::Plan<T>,
        alpha: T,
        inputs: &[&StridedView<T>],
        beta: T,
        output: &mut StridedViewMut<T>,
    ) -> Result<()>;

    /// Query whether an extended operation is available for scalar type `T`.
    ///
    /// Returns `true` if the backend supports the given extended operation
    /// for the specified scalar type. This enables dynamic dispatch: a GPU
    /// backend may support `Contract` for `f64` but not for tropical types.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_prims::{CpuBackend, TensorPrims, Extension};
    ///
    /// if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
    ///     // Use fused contraction for better performance
    /// } else {
    ///     // Decompose into core ops: permute → batched_gemm
    /// }
    /// ```
    fn has_extension_for<T: ScalarBase>(ext: Extension) -> bool;
}

/// CPU plan — concrete enum, no type erasure.
///
/// Created by [`CpuBackend::plan`](TensorPrims::plan) and consumed by
/// [`CpuBackend::execute`](TensorPrims::execute).
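///
/// # Examples
///
/// A sketch of inspecting a plan; this assumes `plan` derives `perm` from the
/// mode lists (the CPU implementation is still `todo!()`):
///
/// ```ignore
/// use tenferro_prims::{CpuBackend, CpuPlan, TensorPrims, PrimDescriptor};
///
/// let desc = PrimDescriptor::Permute { modes_a: vec![0, 1], modes_b: vec![1, 0] };
/// let plan: CpuPlan<f64> = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 3]]).unwrap();
///
/// // A concrete enum: variants can be matched directly, no downcasting.
/// if let CpuPlan::Permute { perm, .. } = plan {
///     assert_eq!(perm, vec![1, 0]);
/// }
/// ```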
pub enum CpuPlan<T: ScalarBase> {
    /// Plan for batched GEMM.
    BatchedGemm {
        /// Number of rows.
        m: usize,
        /// Number of columns.
        n: usize,
        /// Contraction dimension.
        k: usize,
        _marker: PhantomData<T>,
    },
    /// Plan for reduction.
    Reduce {
        /// Axis to reduce over.
        axis: usize,
        /// Reduction operation.
        op: ReduceOp,
        _marker: PhantomData<T>,
    },
    /// Plan for trace.
    Trace {
        /// Paired modes.
        paired: Vec<(u32, u32)>,
        _marker: PhantomData<T>,
    },
    /// Plan for permutation.
    Permute {
        /// Permutation mapping.
        perm: Vec<usize>,
        _marker: PhantomData<T>,
    },
    /// Plan for anti-trace (AD backward).
    AntiTrace {
        /// Paired modes.
        paired: Vec<(u32, u32)>,
        _marker: PhantomData<T>,
    },
    /// Plan for anti-diag (AD backward).
    AntiDiag {
        /// Paired modes.
        paired: Vec<(u32, u32)>,
        _marker: PhantomData<T>,
    },
    /// Plan for element-wise unary operation.
    ElementwiseUnary {
        /// Unary operation.
        op: UnaryOp,
        _marker: PhantomData<T>,
    },
    /// Plan for fused contraction (extended op).
    Contract { _marker: PhantomData<T> },
    /// Plan for element-wise multiplication (extended op).
    ElementwiseMul { _marker: PhantomData<T> },
}

/// CPU backend using strided-kernel and GEMM.
///
/// Dispatched automatically when tensors reside on
/// [`LogicalMemorySpace::MainMemory`](tenferro_device::LogicalMemorySpace::MainMemory).
/// Implements [`TensorPrims<Standard>`] for standard arithmetic.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
/// use strided_view::StridedArray;
///
/// // Transpose a matrix
/// let desc = PrimDescriptor::Permute {
///     modes_a: vec![0, 1],
///     modes_b: vec![1, 0],
/// };
/// let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 3]]).unwrap();
/// let a = StridedArray::<f64>::col_major(&[3, 4]);
/// let mut b = StridedArray::<f64>::col_major(&[4, 3]);
/// CpuBackend::execute(&plan, 1.0, &[&a.view()], 0.0, &mut b.view_mut()).unwrap();
/// ```
pub struct CpuBackend;

impl TensorPrims<Standard> for CpuBackend {
    type Plan<T: ScalarBase> = CpuPlan<T>;

    fn plan<T: ScalarBase>(_desc: &PrimDescriptor, _shapes: &[&[usize]]) -> Result<CpuPlan<T>> {
        todo!()
    }

    fn execute<T: ScalarBase>(
        _plan: &CpuPlan<T>,
        _alpha: T,
        _inputs: &[&StridedView<T>],
        _beta: T,
        _output: &mut StridedViewMut<T>,
    ) -> Result<()> {
        todo!()
    }

    fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
        todo!()
    }
}