// tenferro_ext_tropical/prims/impls.rs

1use std::any::TypeId;
2
3use num_traits::Float;
4use strided_view::{StridedView, StridedViewMut};
5use tenferro_algebra::{Algebra, Scalar, Semiring};
6use tenferro_device::{Error, Result};
7use tenferro_prims::{
8    CpuBackend, CpuContext, SemiringCoreDescriptor, SemiringFastPathDescriptor, TensorSemiringCore,
9    TensorSemiringFastPath,
10};
11use tenferro_tensor::Tensor;
12
13use super::execute::{execute_batched_gemm_optimized, tropical_execute, TropicalGemmDispatch};
14use super::plan::{tropical_plan, TropicalPlan};
15use super::view::{tensor_to_view, tensor_to_view_mut};
16use crate::algebra::{MaxMulAlgebra, MaxPlusAlgebra, MinPlusAlgebra};
17use crate::scalar::{MaxMul, MaxPlus, MinPlus};
18
/// Try to route a `BatchedGemm` execution to the SIMD-optimized kernel.
///
/// `$T` is the (possibly generic) scalar type at the call site. The bracketed list
/// names the concrete scalar types for which `execute_batched_gemm_optimized` is
/// available. The macro checks each listed type in order. For the first one whose
/// `TypeId` equals `TypeId::of::<$T>()`, it does three things:
/// - reinterprets the two inputs, the output, `alpha` and `beta` as that concrete type;
/// - calls `execute_batched_gemm_optimized`;
/// - **returns** that call's result from the enclosing function.
///
/// If no listed type matches, nothing is evaluated, the macro yields `()`, and
/// control continues after the invocation. The caller is expected to fall back to
/// a generic path there.
///
/// Call-site requirements:
/// - `$inputs` must index to `&StridedView<$T>` at positions 0 and 1. Only those two
///   entries are read, and indexing a shorter collection panics.
/// - `$output` must evaluate to `&mut StridedViewMut<$T>`.
/// - The enclosing function's return type must match what
///   `execute_batched_gemm_optimized` returns.
/// - `$T` and every listed type must be `'static`, because `TypeId::of` requires it.
/// - The listed types must be `Copy`, because `alpha`/`beta` are read out by value.
///
/// SAFETY rationale: every cast below sits inside a branch guarded by
/// `TypeId::of::<$T>() == TypeId::of::<$concrete_ty>()`, i.e. `$T` and
/// `$concrete_ty` are the *same* type. Each pointer cast is therefore an identity
/// cast, which trivially preserves layout and validity. No `repr` guarantee of the
/// scalar wrappers is required.
macro_rules! try_simd_dispatch {
    ($T:ty, [$($concrete_ty:ty),+ $(,)?], $inputs:expr, $alpha:expr, $beta:expr,
     $output:expr, $batch_dims:expr, $m:expr, $n:expr, $k:expr) => {{
        let tid = TypeId::of::<$T>();
        $(
            if tid == TypeId::of::<$concrete_ty>() {
                // Evaluate the input collection once; both operands are read from it.
                let operands = $inputs;
                // SAFETY: `$T` and `$concrete_ty` are the same type (TypeId check
                // above), so this is an identity cast of a live shared reference.
                let a = unsafe {
                    &*(operands[0] as *const StridedView<$T> as *const StridedView<$concrete_ty>)
                };
                // SAFETY: identity cast of a live shared reference, as for `a`.
                let b = unsafe {
                    &*(operands[1] as *const StridedView<$T> as *const StridedView<$concrete_ty>)
                };
                // SAFETY: identity cast of the caller's unique `&mut`. The caller's
                // borrow is not used again on this path because we return below.
                let out = unsafe {
                    &mut *(
                        $output as *mut StridedViewMut<$T>
                            as *mut StridedViewMut<$concrete_ty>
                    )
                };
                // SAFETY: identity casts; the concrete scalar is `Copy`, so reading
                // through the pointer copies the value without moving out of it.
                let alpha = unsafe { *(&$alpha as *const $T as *const $concrete_ty) };
                let beta = unsafe { *(&$beta as *const $T as *const $concrete_ty) };
                return execute_batched_gemm_optimized(
                    alpha,
                    &[a, b],
                    beta,
                    out,
                    $batch_dims,
                    $m,
                    $n,
                    $k,
                );
            }
        )+
    }};
}
58
/// Generates the CPU semiring primitive impls for one tropical algebra.
///
/// `$marker` is the algebra marker type (e.g. `MaxPlusAlgebra`) and `$wrapper` is
/// its scalar wrapper (e.g. `MaxPlus`). Two trait impls are produced for
/// `CpuBackend`.
///
/// `TensorSemiringCore`:
/// - Planning is delegated to `tropical_plan`.
/// - Execution views every tensor as a strided view.
/// - For a `TropicalPlan::BatchedGemm` plan, exactly two inputs are required.
///   The SIMD kernel is then tried for `$wrapper<f64>` / `$wrapper<f32>` scalars.
/// - Every other case falls back to `tropical_execute`.
///
/// `TensorSemiringFastPath`: unsupported. Planning and execution both return
/// `Error::InvalidArgument`, and `has_fast_path` always reports `false`.
macro_rules! impl_tropical_prims {
    ($marker:ident, $wrapper:ident) => {
        impl<S> TensorSemiringCore<$marker<S>> for CpuBackend
        where
            S: Scalar + Float,
            $marker<S>: Algebra<Scalar = $wrapper<S>> + Semiring<Scalar = $wrapper<S>>,
            $wrapper<S>: Scalar + TropicalGemmDispatch,
        {
            type Context = CpuContext;
            type Plan = TropicalPlan<$wrapper<S>>;

            fn plan(
                _ctx: &mut CpuContext,
                desc: &SemiringCoreDescriptor,
                shapes: &[&[usize]],
            ) -> Result<TropicalPlan<$wrapper<S>>> {
                tropical_plan(desc, shapes)
            }

            fn execute(
                _ctx: &mut CpuContext,
                plan: &TropicalPlan<$wrapper<S>>,
                alpha: $wrapper<S>,
                inputs: &[&Tensor<$wrapper<S>>],
                beta: $wrapper<S>,
                output: &mut Tensor<$wrapper<S>>,
            ) -> Result<()> {
                // View each input in order; the first tensor that cannot be viewed
                // aborts execution with its error.
                let mut in_views: Vec<StridedView<$wrapper<S>>> =
                    Vec::with_capacity(inputs.len());
                for &tensor in inputs {
                    in_views.push(tensor_to_view(tensor)?);
                }
                let in_refs: Vec<&StridedView<$wrapper<S>>> = in_views.iter().collect();
                let mut dst_view = tensor_to_view_mut(output)?;

                if let TropicalPlan::BatchedGemm { batch_dims, m, n, k, .. } = plan {
                    if in_refs.len() != 2 {
                        return Err(Error::InvalidArgument(
                            "BatchedGemm execute requires 2 input tensors".into(),
                        ));
                    }
                    // Returns from `execute` when the scalar is backed by f64/f32;
                    // otherwise control falls through to the generic path below.
                    try_simd_dispatch!(
                        $wrapper<S>,
                        [$wrapper<f64>, $wrapper<f32>],
                        &in_refs,
                        alpha,
                        beta,
                        &mut dst_view,
                        batch_dims,
                        *m,
                        *n,
                        *k
                    );
                }

                tropical_execute(plan, alpha, &in_refs, beta, &mut dst_view)
            }
        }

        impl<S> TensorSemiringFastPath<$marker<S>> for CpuBackend
        where
            S: Scalar + Float,
            $marker<S>: Algebra<Scalar = $wrapper<S>> + Semiring<Scalar = $wrapper<S>>,
            $wrapper<S>: Scalar,
        {
            type Context = CpuContext;
            type Plan = TropicalPlan<$wrapper<S>>;

            fn plan(
                _ctx: &mut CpuContext,
                _desc: &SemiringFastPathDescriptor,
                _shapes: &[&[usize]],
            ) -> Result<TropicalPlan<$wrapper<S>>> {
                let reason = "tropical algebras do not support semiring fast paths";
                Err(Error::InvalidArgument(reason.into()))
            }

            fn execute(
                _ctx: &mut CpuContext,
                _plan: &TropicalPlan<$wrapper<S>>,
                _alpha: $wrapper<S>,
                _inputs: &[&Tensor<$wrapper<S>>],
                _beta: $wrapper<S>,
                _output: &mut Tensor<$wrapper<S>>,
            ) -> Result<()> {
                let reason = "tropical algebras do not support semiring fast paths";
                Err(Error::InvalidArgument(reason.into()))
            }

            fn has_fast_path(_desc: SemiringFastPathDescriptor) -> bool {
                // No descriptor is ever served by a fast path for tropical algebras.
                false
            }
        }
    };
}
159
160impl_tropical_prims!(MaxPlusAlgebra, MaxPlus);
161impl_tropical_prims!(MinPlusAlgebra, MinPlus);
162impl_tropical_prims!(MaxMulAlgebra, MaxMul);