tenferro_prims/
lib.rs

1//! Tensor primitive operations for the tenferro workspace.
2//!
3//! This crate defines the [`TensorPrims<A>`] trait, a backend-agnostic interface
4//! parameterized by algebra `A`. The API follows the cuTENSOR plan-based execution
5//! pattern:
6//!
7//! 1. Create a [`PrimDescriptor`] specifying the operation and index modes
8//! 2. Build a plan via [`TensorPrims::plan`] (pre-computes kernel selection)
9//! 3. Execute the plan via [`TensorPrims::execute`]
10//!
11//! # Operation categories
12//!
13//! **Core operations** (every backend must implement):
14//! - [`BatchedGemm`](PrimDescriptor::BatchedGemm): Batched matrix multiplication
15//! - [`Reduce`](PrimDescriptor::Reduce): Sum/max/min reduction over modes
16//! - [`Trace`](PrimDescriptor::Trace): Trace (contraction of paired diagonal modes)
17//! - [`Permute`](PrimDescriptor::Permute): Mode reordering
18//! - [`AntiTrace`](PrimDescriptor::AntiTrace): Scatter-add to diagonal (AD backward of trace)
19//! - [`AntiDiag`](PrimDescriptor::AntiDiag): Write to diagonal positions (AD backward of diag)
20//! - [`ElementwiseUnary`](PrimDescriptor::ElementwiseUnary): Point-wise unary transform (negate, reciprocal, abs, sqrt)
21//!
22//! **Extended operations** (dynamically queried via [`TensorPrims::has_extension_for`]):
23//! - [`Contract`](PrimDescriptor::Contract): Fused permute + GEMM contraction
24//! - [`ElementwiseMul`](PrimDescriptor::ElementwiseMul): Element-wise multiplication
25//!
26//! # Algebra parameterization
27//!
28//! [`TensorPrims<A>`] is parameterized by algebra `A` (e.g.,
29//! [`Standard`], `MaxPlus`).
30//! External crates implement `TensorPrims<MyAlgebra> for CpuBackend` (orphan rule
31//! compatible). The [`HasAlgebra`](tenferro_algebra::HasAlgebra) trait on scalar types
32//! enables automatic inference: `Tensor<f64>` → `Standard`.
33//!
34//! # Examples
35//!
36//! ## Plan-based GEMM
37//!
38//! ```ignore
39//! use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
40//! use tenferro_algebra::Standard;
41//! use strided_view::StridedArray;
42//!
43//! let a = StridedArray::<f64>::col_major(&[3, 4]);
44//! let b = StridedArray::<f64>::col_major(&[4, 5]);
45//! let mut c = StridedArray::<f64>::col_major(&[3, 5]);
46//!
47//! let desc = PrimDescriptor::BatchedGemm {
48//!     batch_dims: vec![], m: 3, n: 5, k: 4,
49//! };
50//! let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
51//! CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
52//! ```
53//!
54//! ## Reduction (sum over an axis)
55//!
56//! ```ignore
//! use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
//! use strided_view::StridedArray;
//!
//! let a = StridedArray::<f64>::col_major(&[3, 4]);
//! let mut c = StridedArray::<f64>::col_major(&[3]);
//!
//! // Row sums (reduce mode 1): c_i = Σ_j A_{i,j}
//! let desc = PrimDescriptor::Reduce {
//!     modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
//! };
//! let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
//! CpuBackend::execute(&plan, 1.0, &[&a.view()], 0.0, &mut c.view_mut()).unwrap();
65//! ```
66//!
67//! ## Dynamic extension check
68//!
69//! ```ignore
//! use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, Extension};
//! use strided_view::StridedArray;
//!
//! let a = StridedArray::<f64>::col_major(&[3, 4]);
//! let b = StridedArray::<f64>::col_major(&[4, 5]);
//! let mut c = StridedArray::<f64>::col_major(&[3, 5]);
//!
//! if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
//!     let desc = PrimDescriptor::Contract {
//!         modes_a: vec![0, 1], modes_b: vec![1, 2], modes_c: vec![0, 2],
//!     };
//!     let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
//!     CpuBackend::execute(
//!         &plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut(),
//!     ).unwrap();
//! }
81//! ```
82
83use std::marker::PhantomData;
84
85use strided_traits::ScalarBase;
86use strided_view::{StridedView, StridedViewMut};
87use tenferro_algebra::Standard;
88use tenferro_device::Result;
89
/// Reduction operation kind.
///
/// Selects how [`PrimDescriptor::Reduce`] combines the values along the
/// modes that are dropped from the output.
///
/// Implements `Hash` (in addition to `Eq`) so it can be used as a key in
/// hash-based collections such as plan caches.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ReduceOp;
///
/// let op = ReduceOp::Sum;
/// assert_eq!(op, ReduceOp::Sum);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum reduction.
    Sum,
    /// Maximum value reduction.
    Max,
    /// Minimum value reduction.
    Min,
}
109
/// Element-wise unary operation kind.
///
/// Used with [`PrimDescriptor::ElementwiseUnary`] for point-wise
/// transformations. Maps to `cutensorElementwiseTrinary` (unary case)
/// on GPU.
///
/// Note: square (`x²`) is omitted — expressible as
/// `ElementwiseMul(x, x)` without an extra copy.
///
/// Implements `Hash` (in addition to `Eq`) so it can be used as a key in
/// hash-based collections such as plan caches.
///
/// # Examples
///
/// ```
/// use tenferro_prims::UnaryOp;
///
/// let op = UnaryOp::Reciprocal;
/// assert_eq!(op, UnaryOp::Reciprocal);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Negate: `-x`.
    Negate,
    /// Reciprocal: `1 / x`.
    Reciprocal,
    /// Absolute value: `|x|`.
    Abs,
    /// Square root: `√x`.
    Sqrt,
}
138
/// Extended operation identifiers for dynamic capability query.
///
/// Used with [`TensorPrims::has_extension_for`] to check at runtime whether
/// a backend supports an optimized extended operation for a given scalar type.
///
/// Implements `Hash` (in addition to `Eq`) so query results can be cached
/// per extension in hash-based collections.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::{CpuBackend, TensorPrims, Extension};
///
/// // Check if fused contraction is available for f64
/// let available = CpuBackend::has_extension_for::<f64>(Extension::Contract);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Extension {
    /// Fused contraction (permute + GEMM). Maps to `cutensorContract` on GPU.
    Contract,
    /// Element-wise multiplication. Maps to `cutensorElementwiseBinary` on GPU.
    ElementwiseMul,
}
159
160/// Describes a tensor primitive operation.
161///
162/// All operations follow the cuTENSOR pattern: describe → plan → execute.
163/// Core operations must be supported by every backend. Extended operations
164/// are dynamically queried via [`TensorPrims::has_extension_for`].
165///
166/// Modes are `u32` integer labels matching cuTENSOR conventions. Modes
167/// shared between input and output tensors are batch/free dimensions;
168/// modes present only in inputs are contracted.
169///
170/// # Examples
171///
172/// ```
173/// use tenferro_prims::PrimDescriptor;
174///
175/// // Matrix multiplication: C_{m,n} = A_{m,k} * B_{k,n}
176/// let desc = PrimDescriptor::Contract {
177///     modes_a: vec![0, 1],  // m=0, k=1
178///     modes_b: vec![1, 2],  // k=1, n=2
179///     modes_c: vec![0, 2],  // m=0, n=2
180/// };
181/// ```
182pub enum PrimDescriptor {
183    // ====================================================================
184    // Core operations (every backend must implement)
185    // ====================================================================
186    /// Batched matrix multiplication.
187    ///
188    /// `C[batch, m, n] = alpha * A[batch, m, k] * B[batch, k, n] + beta * C`
189    BatchedGemm {
190        /// Batch dimension sizes.
191        batch_dims: Vec<usize>,
192        /// Number of rows in A / C.
193        m: usize,
194        /// Number of columns in B / C.
195        n: usize,
196        /// Contraction dimension (columns of A / rows of B).
197        k: usize,
198    },
199
200    /// Reduction over modes not present in the output.
201    ///
202    /// `C[modes_c] = alpha * reduce_op(A[modes_a]) + beta * C[modes_c]`
203    Reduce {
204        /// Mode labels for input tensor A.
205        modes_a: Vec<u32>,
206        /// Mode labels for output tensor C (subset of modes_a).
207        modes_c: Vec<u32>,
208        /// Reduction operation (Sum, Max, Min).
209        op: ReduceOp,
210    },
211
212    /// Trace: contraction of paired diagonal modes.
213    ///
214    /// For each pair `(i, j)`, sums over the diagonal where mode i == mode j.
215    Trace {
216        /// Mode labels for input tensor A.
217        modes_a: Vec<u32>,
218        /// Mode labels for output tensor C.
219        modes_c: Vec<u32>,
220        /// Pairs of modes to trace over.
221        paired: Vec<(u32, u32)>,
222    },
223
224    /// Permute (reorder) tensor modes.
225    ///
226    /// `B[modes_b] = alpha * A[modes_a]`
227    Permute {
228        /// Mode labels for input tensor A.
229        modes_a: Vec<u32>,
230        /// Mode labels for output tensor B (same labels, different order).
231        modes_b: Vec<u32>,
232    },
233
234    /// Anti-trace: scatter-add gradient to diagonal (AD backward of trace).
235    AntiTrace {
236        /// Mode labels for input tensor A.
237        modes_a: Vec<u32>,
238        /// Mode labels for output tensor C.
239        modes_c: Vec<u32>,
240        /// Pairs of modes for diagonal scatter.
241        paired: Vec<(u32, u32)>,
242    },
243
244    /// Anti-diag: write gradient to diagonal positions (AD backward of diag).
245    AntiDiag {
246        /// Mode labels for input tensor A.
247        modes_a: Vec<u32>,
248        /// Mode labels for output tensor C.
249        modes_c: Vec<u32>,
250        /// Pairs of modes for diagonal write.
251        paired: Vec<(u32, u32)>,
252    },
253
254    /// Element-wise unary operation.
255    ///
256    /// `C[modes] = alpha * op(A[modes]) + beta * C[modes]`
257    ///
258    /// # Examples
259    ///
260    /// ```
261    /// use tenferro_prims::{PrimDescriptor, UnaryOp};
262    ///
263    /// // Reciprocal: C = 1/A
264    /// let desc = PrimDescriptor::ElementwiseUnary {
265    ///     op: UnaryOp::Reciprocal,
266    /// };
267    /// ```
268    ElementwiseUnary {
269        /// Unary operation to apply.
270        op: UnaryOp,
271    },
272
273    // ====================================================================
274    // Extended operations (dynamically queried)
275    // ====================================================================
276    /// Fused contraction: permute + GEMM in one operation.
277    ///
278    /// `C[modes_c] = alpha * contract(A[modes_a], B[modes_b]) + beta * C`
279    ///
280    /// Available when `has_extension_for::<T>(Extension::Contract)` returns true.
281    Contract {
282        /// Mode labels for input tensor A.
283        modes_a: Vec<u32>,
284        /// Mode labels for input tensor B.
285        modes_b: Vec<u32>,
286        /// Mode labels for output tensor C.
287        modes_c: Vec<u32>,
288    },
289
290    /// Element-wise multiplication of two tensors.
291    ///
292    /// Available when `has_extension_for::<T>(Extension::ElementwiseMul)` returns true.
293    ElementwiseMul,
294}
295
296/// Backend trait for tensor primitive operations, parameterized by algebra `A`.
297///
298/// Provides a cuTENSOR-compatible plan-based execution model for all
299/// operations. Core ops (batched_gemm, reduce, trace, permute, anti_trace,
300/// anti_diag) must be implemented. Extended ops (contract, elementwise_mul)
301/// are dynamically queried via [`has_extension_for`](TensorPrims::has_extension_for).
302///
303/// # Algebra parameterization
304///
305/// The algebra parameter `A` enables extensibility: external crates can
306/// implement `TensorPrims<MyAlgebra> for CpuBackend` (orphan rule compatible).
307///
308/// # Associated functions (not methods)
309///
310/// All functions are associated functions (no `&self` receiver). Call as
311/// `CpuBackend::plan::<f64>(...)` instead of `backend.plan(...)`.
312///
313/// # Examples
314///
315/// ```ignore
316/// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
317///
318/// // Plan a batched GEMM
319/// let desc = PrimDescriptor::BatchedGemm {
320///     batch_dims: vec![], m: 3, n: 5, k: 4,
321/// };
322/// let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 5], &[3, 5]]).unwrap();
323///
324/// // Execute: C = 1.0 * A*B + 0.0 * C
325/// CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
326/// ```
327pub trait TensorPrims<A> {
328    /// Backend-specific plan type (no type erasure).
329    type Plan<T: ScalarBase>;
330
331    /// Create an execution plan from an operation descriptor.
332    ///
333    /// The plan pre-computes kernel selection and workspace sizes.
334    /// `shapes` contains the shape of each tensor involved in the operation
335    /// (inputs first, then output).
336    ///
337    /// # Examples
338    ///
339    /// ```ignore
340    /// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor, ReduceOp};
341    ///
342    /// let desc = PrimDescriptor::Reduce {
343    ///     modes_a: vec![0, 1], modes_c: vec![0], op: ReduceOp::Sum,
344    /// };
345    /// let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[3]]).unwrap();
346    /// ```
347    fn plan<T: ScalarBase>(desc: &PrimDescriptor, shapes: &[&[usize]]) -> Result<Self::Plan<T>>;
348
349    /// Execute a plan with the given scaling factors and tensor views.
350    ///
351    /// Follows the BLAS/cuTENSOR pattern:
352    /// `output = alpha * op(inputs) + beta * output`
353    ///
354    /// # Examples
355    ///
356    /// ```ignore
357    /// // Execute: output = 1.0 * gemm(a, b) + 0.0 * output  (overwrite)
358    /// CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 0.0, &mut c.view_mut()).unwrap();
359    ///
360    /// // Accumulate: output = 1.0 * gemm(a, b) + 1.0 * output  (add)
361    /// CpuBackend::execute(&plan, 1.0, &[&a.view(), &b.view()], 1.0, &mut c.view_mut()).unwrap();
362    /// ```
363    fn execute<T: ScalarBase>(
364        plan: &Self::Plan<T>,
365        alpha: T,
366        inputs: &[&StridedView<T>],
367        beta: T,
368        output: &mut StridedViewMut<T>,
369    ) -> Result<()>;
370
371    /// Query whether an extended operation is available for scalar type `T`.
372    ///
373    /// Returns `true` if the backend supports the given extended operation
374    /// for the specified scalar type. This enables dynamic dispatch:
375    /// GPU may support Contract for f64 but not for tropical types.
376    ///
377    /// # Examples
378    ///
379    /// ```ignore
380    /// use tenferro_prims::{CpuBackend, TensorPrims, Extension};
381    ///
382    /// if CpuBackend::has_extension_for::<f64>(Extension::Contract) {
383    ///     // Use fused contraction for better performance
384    /// } else {
385    ///     // Decompose into core ops: permute → batched_gemm
386    /// }
387    /// ```
388    fn has_extension_for<T: ScalarBase>(ext: Extension) -> bool;
389}
390
391/// CPU plan — concrete enum, no type erasure.
392///
393/// Created by [`CpuBackend::plan`](TensorPrims::plan) and consumed by
394/// [`CpuBackend::execute`](TensorPrims::execute).
395pub enum CpuPlan<T: ScalarBase> {
396    /// Plan for batched GEMM.
397    BatchedGemm {
398        /// Number of rows.
399        m: usize,
400        /// Number of columns.
401        n: usize,
402        /// Contraction dimension.
403        k: usize,
404        _marker: PhantomData<T>,
405    },
406    /// Plan for reduction.
407    Reduce {
408        /// Axis to reduce over.
409        axis: usize,
410        /// Reduction operation.
411        op: ReduceOp,
412        _marker: PhantomData<T>,
413    },
414    /// Plan for trace.
415    Trace {
416        /// Paired modes.
417        paired: Vec<(u32, u32)>,
418        _marker: PhantomData<T>,
419    },
420    /// Plan for permutation.
421    Permute {
422        /// Permutation mapping.
423        perm: Vec<usize>,
424        _marker: PhantomData<T>,
425    },
426    /// Plan for anti-trace (AD backward).
427    AntiTrace {
428        /// Paired modes.
429        paired: Vec<(u32, u32)>,
430        _marker: PhantomData<T>,
431    },
432    /// Plan for anti-diag (AD backward).
433    AntiDiag {
434        /// Paired modes.
435        paired: Vec<(u32, u32)>,
436        _marker: PhantomData<T>,
437    },
438    /// Plan for element-wise unary operation.
439    ElementwiseUnary {
440        /// Unary operation.
441        op: UnaryOp,
442        _marker: PhantomData<T>,
443    },
444    /// Plan for fused contraction (extended op).
445    Contract { _marker: PhantomData<T> },
446    /// Plan for element-wise multiplication (extended op).
447    ElementwiseMul { _marker: PhantomData<T> },
448}
449
/// CPU backend using strided-kernel and GEMM.
///
/// Dispatched automatically when tensors reside on
/// [`LogicalMemorySpace::MainMemory`](tenferro_device::LogicalMemorySpace::MainMemory).
/// Implements [`TensorPrims<Standard>`] for standard arithmetic.
///
/// `CpuBackend` is a stateless, zero-sized marker type. All [`TensorPrims`]
/// entry points are associated functions, so a value is rarely needed. It still
/// derives the common traits (`Debug`, `Copy`, `Default`, …) so it can be
/// stored or passed around freely in generic code.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::{CpuBackend, TensorPrims, PrimDescriptor};
/// use strided_view::StridedArray;
///
/// // Transpose a matrix
/// let desc = PrimDescriptor::Permute {
///     modes_a: vec![0, 1],
///     modes_b: vec![1, 0],
/// };
/// let plan = CpuBackend::plan::<f64>(&desc, &[&[3, 4], &[4, 3]]).unwrap();
/// let a = StridedArray::<f64>::col_major(&[3, 4]);
/// let mut b = StridedArray::<f64>::col_major(&[4, 3]);
/// CpuBackend::execute(&plan, 1.0, &[&a.view()], 0.0, &mut b.view_mut()).unwrap();
/// ```
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CpuBackend;
473
474impl TensorPrims<Standard> for CpuBackend {
475    type Plan<T: ScalarBase> = CpuPlan<T>;
476
477    fn plan<T: ScalarBase>(_desc: &PrimDescriptor, _shapes: &[&[usize]]) -> Result<CpuPlan<T>> {
478        todo!()
479    }
480
481    fn execute<T: ScalarBase>(
482        _plan: &CpuPlan<T>,
483        _alpha: T,
484        _inputs: &[&StridedView<T>],
485        _beta: T,
486        _output: &mut StridedViewMut<T>,
487    ) -> Result<()> {
488        todo!()
489    }
490
491    fn has_extension_for<T: ScalarBase>(_ext: Extension) -> bool {
492        todo!()
493    }
494}