tenferro_prims/families/
semiring_fast_path.rs

1use tenferro_algebra::Semiring;
2use tenferro_device::Result;
3use tenferro_tensor::Tensor;
4
/// Binary operations of a semiring that a backend may accelerate elementwise.
///
/// Only operations that every semiring provides (addition and multiplication)
/// are listed here, so any of them is valid for any semiring algebra.
///
/// # Examples
///
/// ```
/// use tenferro_prims::SemiringBinaryOp;
///
/// let op = SemiringBinaryOp::Mul;
/// assert_eq!(op, SemiringBinaryOp::Mul);
/// assert_ne!(op, SemiringBinaryOp::Add);
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SemiringBinaryOp {
    /// The semiring's addition, applied element by element.
    Add,
    /// The semiring's multiplication, applied element by element.
    Mul,
}
22
23/// Descriptor for optional semiring fast paths.
24///
25/// # Examples
26///
27/// ```
28/// use tenferro_prims::{SemiringBinaryOp, SemiringFastPathDescriptor};
29///
30/// let desc = SemiringFastPathDescriptor::ElementwiseBinary {
31///     op: SemiringBinaryOp::Mul,
32/// };
33/// assert!(matches!(desc, SemiringFastPathDescriptor::ElementwiseBinary { .. }));
34/// ```
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub enum SemiringFastPathDescriptor {
37    /// Optional contraction fast path.
38    Contract {
39        /// Mode labels for input A.
40        modes_a: Vec<u32>,
41        /// Mode labels for input B.
42        modes_b: Vec<u32>,
43        /// Mode labels for output C.
44        modes_c: Vec<u32>,
45    },
46    /// Optional elementwise semiring binary fast path.
47    ElementwiseBinary {
48        /// The semiring binary operation.
49        op: SemiringBinaryOp,
50    },
51}
52
53/// Optional semiring performance paths.
54///
55/// # Examples
56///
57/// ```ignore
58/// use tenferro_algebra::Standard;
59/// use tenferro_prims::{
60///     CpuBackend, SemiringBinaryOp, SemiringFastPathDescriptor, TensorSemiringFastPath,
61/// };
62///
63/// let supported =
64///     <CpuBackend as TensorSemiringFastPath<Standard<f64>>>::has_fast_path(SemiringFastPathDescriptor::ElementwiseBinary {
65///         op: SemiringBinaryOp::Mul,
66///     });
67/// assert!(supported);
68/// ```
69pub trait TensorSemiringFastPath<Alg: Semiring> {
70    /// Backend-specific plan type.
71    type Plan;
72
73    /// Backend-specific execution context.
74    type Context;
75
76    /// Plan an optional semiring fast path.
77    fn plan(
78        ctx: &mut Self::Context,
79        desc: &SemiringFastPathDescriptor,
80        shapes: &[&[usize]],
81    ) -> Result<Self::Plan>;
82
83    /// Execute an optional semiring fast path.
84    fn execute(
85        ctx: &mut Self::Context,
86        plan: &Self::Plan,
87        alpha: Alg::Scalar,
88        inputs: &[&Tensor<Alg::Scalar>],
89        beta: Alg::Scalar,
90        output: &mut Tensor<Alg::Scalar>,
91    ) -> Result<()>;
92
93    /// Query whether the optional path is available.
94    fn has_fast_path(desc: SemiringFastPathDescriptor) -> bool;
95}