tenferro_prims/families/semiring_core.rs
1use tenferro_algebra::Semiring;
2use tenferro_device::Result;
3use tenferro_tensor::Tensor;
4
/// Describes one semiring-core operation for a backend to plan and execute.
///
/// The variants form the minimal protocol family that `tenferro-einsum` is
/// allowed to depend on; richer operations live in other families.
///
/// # Examples
///
/// ```
/// use tenferro_prims::SemiringCoreDescriptor;
///
/// let gemm = SemiringCoreDescriptor::BatchedGemm {
///     batch_dims: vec![4],
///     m: 2,
///     n: 3,
///     k: 5,
/// };
/// assert_eq!(gemm.clone(), gemm);
/// assert_ne!(gemm, SemiringCoreDescriptor::MakeContiguous);
///
/// let desc = SemiringCoreDescriptor::MakeContiguous;
/// assert!(matches!(desc, SemiringCoreDescriptor::MakeContiguous));
/// ```
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum SemiringCoreDescriptor {
    /// Batched GEMM evaluated with the semiring's operations.
    BatchedGemm {
        /// Size of each batch dimension.
        batch_dims: Vec<usize>,
        /// Number of rows shared by A and C.
        m: usize,
        /// Number of columns shared by B and C.
        n: usize,
        /// Size of the contracted dimension.
        k: usize,
    },
    /// Reduction performed with the semiring's addition.
    ReduceAdd {
        /// Mode labels of the input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor C.
        modes_c: Vec<u32>,
    },
    /// Diagonal contraction over paired modes.
    Trace {
        /// Mode labels of the input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor C.
        modes_c: Vec<u32>,
        /// Pairs of mode labels that together form a diagonal.
        paired: Vec<(u32, u32)>,
    },
    /// Diagonal scatter-add, used for adjoints of
    /// [`SemiringCoreDescriptor::Trace`].
    AntiTrace {
        /// Mode labels of the input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor C.
        modes_c: Vec<u32>,
        /// Pairs of mode labels that together form a diagonal.
        paired: Vec<(u32, u32)>,
    },
    /// Diagonal scatter/write, used for adjoints of diagonal extraction.
    AntiDiag {
        /// Mode labels of the input tensor A.
        modes_a: Vec<u32>,
        /// Mode labels of the output tensor C.
        modes_c: Vec<u32>,
        /// Pairs of mode labels that together form a diagonal.
        paired: Vec<(u32, u32)>,
    },
    /// Materialize the result into a contiguous output tensor.
    MakeContiguous,
}
67
68/// Minimal semiring execution protocol.
69///
70/// # Examples
71///
72/// ```ignore
73/// use tenferro_algebra::Standard;
74/// use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
75///
76/// let mut ctx = CpuContext::new(1);
77/// let desc = SemiringCoreDescriptor::MakeContiguous;
78/// let _plan = <CpuBackend as TensorSemiringCore<Standard<f64>>>::plan(
79/// &mut ctx,
80/// &desc,
81/// &[&[2, 2], &[2, 2]],
82/// )
83/// .unwrap();
84/// ```
85pub trait TensorSemiringCore<Alg: Semiring> {
86 /// Backend-specific plan type.
87 type Plan;
88
89 /// Backend-specific execution context.
90 type Context;
91
92 /// Plan a semiring-core operation.
93 fn plan(
94 ctx: &mut Self::Context,
95 desc: &SemiringCoreDescriptor,
96 shapes: &[&[usize]],
97 ) -> Result<Self::Plan>;
98
99 /// Execute a semiring-core operation.
100 fn execute(
101 ctx: &mut Self::Context,
102 plan: &Self::Plan,
103 alpha: Alg::Scalar,
104 inputs: &[&Tensor<Alg::Scalar>],
105 beta: Alg::Scalar,
106 output: &mut Tensor<Alg::Scalar>,
107 ) -> Result<()>;
108}