// tenferro_prims/families/semiring_core.rs

1use tenferro_algebra::Semiring;
2use tenferro_device::Result;
3use tenferro_tensor::Tensor;
4
/// Descriptor for semiring-core execution operations.
///
/// This is the minimal protocol family that `tenferro-einsum` may depend on.
/// A descriptor names one operation, together with its static parameters,
/// and is handed to a backend through [`TensorSemiringCore::plan`].
///
/// # Examples
///
/// ```
/// use tenferro_prims::SemiringCoreDescriptor;
///
/// let gemm = SemiringCoreDescriptor::BatchedGemm {
///     batch_dims: vec![2],
///     m: 3,
///     n: 4,
///     k: 5,
/// };
/// assert_ne!(gemm, SemiringCoreDescriptor::MakeContiguous);
///
/// let desc = SemiringCoreDescriptor::MakeContiguous;
/// assert!(matches!(desc, SemiringCoreDescriptor::MakeContiguous));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SemiringCoreDescriptor {
    /// Batched GEMM (general matrix multiply) evaluated over the semiring.
    BatchedGemm {
        /// Sizes of the batch dimensions.
        batch_dims: Vec<usize>,
        /// Number of rows shared by `A` and `C`.
        m: usize,
        /// Number of columns shared by `B` and `C`.
        n: usize,
        /// Size of the contracted dimension.
        k: usize,
    },
    /// Reduction of `A` into `C` using the semiring's addition.
    ReduceAdd {
        /// Mode labels of the input tensor `A`.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor `C`.
        modes_c: Vec<u32>,
    },
    /// Diagonal contraction (trace) of `A` into `C`.
    Trace {
        /// Mode labels of the input tensor `A`.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor `C`.
        modes_c: Vec<u32>,
        /// Pairs of mode labels that are tied together along a diagonal.
        paired: Vec<(u32, u32)>,
    },
    /// Diagonal scatter-add, used for the adjoint of a trace.
    AntiTrace {
        /// Mode labels of the input tensor `A`.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor `C`.
        modes_c: Vec<u32>,
        /// Pairs of mode labels that are tied together along a diagonal.
        paired: Vec<(u32, u32)>,
    },
    /// Diagonal scatter/write, used for the adjoint of a diagonal extraction.
    AntiDiag {
        /// Mode labels of the input tensor `A`.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor `C`.
        modes_c: Vec<u32>,
        /// Pairs of mode labels that are tied together along a diagonal.
        paired: Vec<(u32, u32)>,
    },
    /// Materialize the output as a contiguous tensor.
    MakeContiguous,
}

68/// Minimal semiring execution protocol.
69///
70/// # Examples
71///
72/// ```ignore
73/// use tenferro_algebra::Standard;
74/// use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
75///
76/// let mut ctx = CpuContext::new(1);
77/// let desc = SemiringCoreDescriptor::MakeContiguous;
78/// let _plan = <CpuBackend as TensorSemiringCore<Standard<f64>>>::plan(
79///     &mut ctx,
80///     &desc,
81///     &[&[2, 2], &[2, 2]],
82/// )
83/// .unwrap();
84/// ```
85pub trait TensorSemiringCore<Alg: Semiring> {
86    /// Backend-specific plan type.
87    type Plan;
88
89    /// Backend-specific execution context.
90    type Context;
91
92    /// Plan a semiring-core operation.
93    fn plan(
94        ctx: &mut Self::Context,
95        desc: &SemiringCoreDescriptor,
96        shapes: &[&[usize]],
97    ) -> Result<Self::Plan>;
98
99    /// Execute a semiring-core operation.
100    fn execute(
101        ctx: &mut Self::Context,
102        plan: &Self::Plan,
103        alpha: Alg::Scalar,
104        inputs: &[&Tensor<Alg::Scalar>],
105        beta: Alg::Scalar,
106        output: &mut Tensor<Alg::Scalar>,
107    ) -> Result<()>;
108}