tenferro_prims/families/semiring_fast_path.rs
1use tenferro_algebra::Semiring;
2use tenferro_device::Result;
3use tenferro_tensor::Tensor;
4
/// Elementwise binary operations that are valid in any semiring and may have an
/// optional backend fast path.
///
/// `Add` selects the semiring's addition and `Mul` its multiplication; both are
/// applied element by element.
///
/// # Examples
///
/// ```
/// use tenferro_prims::SemiringBinaryOp;
///
/// let add = SemiringBinaryOp::Add;
/// assert_ne!(add, SemiringBinaryOp::Mul);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SemiringBinaryOp {
    /// Semiring addition, applied element by element.
    Add,
    /// Semiring multiplication, applied element by element.
    Mul,
}
22
23/// Descriptor for optional semiring fast paths.
24///
25/// # Examples
26///
27/// ```
28/// use tenferro_prims::{SemiringBinaryOp, SemiringFastPathDescriptor};
29///
30/// let desc = SemiringFastPathDescriptor::ElementwiseBinary {
31/// op: SemiringBinaryOp::Mul,
32/// };
33/// assert!(matches!(desc, SemiringFastPathDescriptor::ElementwiseBinary { .. }));
34/// ```
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub enum SemiringFastPathDescriptor {
37 /// Optional contraction fast path.
38 Contract {
39 /// Mode labels for input A.
40 modes_a: Vec<u32>,
41 /// Mode labels for input B.
42 modes_b: Vec<u32>,
43 /// Mode labels for output C.
44 modes_c: Vec<u32>,
45 },
46 /// Optional elementwise semiring binary fast path.
47 ElementwiseBinary {
48 /// The semiring binary operation.
49 op: SemiringBinaryOp,
50 },
51}
52
53/// Optional semiring performance paths.
54///
55/// # Examples
56///
57/// ```ignore
58/// use tenferro_algebra::Standard;
59/// use tenferro_prims::{
60/// CpuBackend, SemiringBinaryOp, SemiringFastPathDescriptor, TensorSemiringFastPath,
61/// };
62///
63/// let supported =
64/// <CpuBackend as TensorSemiringFastPath<Standard<f64>>>::has_fast_path(SemiringFastPathDescriptor::ElementwiseBinary {
65/// op: SemiringBinaryOp::Mul,
66/// });
67/// assert!(supported);
68/// ```
69pub trait TensorSemiringFastPath<Alg: Semiring> {
70 /// Backend-specific plan type.
71 type Plan;
72
73 /// Backend-specific execution context.
74 type Context;
75
76 /// Plan an optional semiring fast path.
77 fn plan(
78 ctx: &mut Self::Context,
79 desc: &SemiringFastPathDescriptor,
80 shapes: &[&[usize]],
81 ) -> Result<Self::Plan>;
82
83 /// Execute an optional semiring fast path.
84 fn execute(
85 ctx: &mut Self::Context,
86 plan: &Self::Plan,
87 alpha: Alg::Scalar,
88 inputs: &[&Tensor<Alg::Scalar>],
89 beta: Alg::Scalar,
90 output: &mut Tensor<Alg::Scalar>,
91 ) -> Result<()>;
92
93 /// Query whether the optional path is available.
94 fn has_fast_path(desc: SemiringFastPathDescriptor) -> bool;
95}