tenferro_ext_tropical/prims/
impls.rs1use std::any::TypeId;
2
3use num_traits::Float;
4use strided_view::{StridedView, StridedViewMut};
5use tenferro_algebra::{Algebra, Scalar, Semiring};
6use tenferro_device::{Error, Result};
7use tenferro_prims::{
8 CpuBackend, CpuContext, SemiringCoreDescriptor, SemiringFastPathDescriptor, TensorSemiringCore,
9 TensorSemiringFastPath,
10};
11use tenferro_tensor::Tensor;
12
13use super::execute::{execute_batched_gemm_optimized, tropical_execute, TropicalGemmDispatch};
14use super::plan::{tropical_plan, TropicalPlan};
15use super::view::{tensor_to_view, tensor_to_view_mut};
16use crate::algebra::{MaxMulAlgebra, MaxPlusAlgebra, MinPlusAlgebra};
17use crate::scalar::{MaxMul, MaxPlus, MinPlus};
18
/// Routes a batched tropical GEMM to `execute_batched_gemm_optimized` when the
/// generic scalar type `$T` is one of the listed concrete types.
///
/// The concrete types are tried in order by comparing `TypeId`s. On the first
/// match, the first two entries of `$inputs`, the `$output` view, `$alpha` and
/// `$beta` are reinterpreted as the concrete type, and the macro executes
/// `return execute_batched_gemm_optimized(...)`. That is, it returns from the
/// *enclosing function*.
///
/// If no type matches, only the `TypeId` comparisons are evaluated. Control
/// then falls through to the caller's next statement, typically a generic
/// fallback.
///
/// Call-site requirements:
/// - `$T: 'static`, because `TypeId::of` requires it.
/// - Every `$concrete_ty: Copy`, because `$alpha`/`$beta` are copied out
///   through a pointer.
/// - `$inputs` must index to at least two `&StridedView<$T>` values. Indices 0
///   and 1 are read, and out-of-range indexing on a slice/`Vec` panics.
/// - `$output` must evaluate to `&mut StridedViewMut<$T>`.
/// - `StridedView`, `StridedViewMut` and `execute_batched_gemm_optimized` are
///   resolved at the call site, so they must be in scope there.
macro_rules! try_simd_dispatch {
    ($T:ty, [$($concrete_ty:ty),+ $(,)?], $inputs:expr, $alpha:expr, $beta:expr,
     $output:expr, $batch_dims:expr, $m:expr, $n:expr, $k:expr) => {{
        // Fully qualified so the macro does not depend on the caller's imports.
        let tid = ::std::any::TypeId::of::<$T>();
        $(
            if tid == ::std::any::TypeId::of::<$concrete_ty>() {
                // Evaluate `$inputs` once; both GEMM operands are read from it.
                let operands = $inputs;
                // SAFETY: `tid` matched, so `$T` and `$concrete_ty` are the same type
                // and this is an identity cast between identical layouts. The
                // pointer comes from a live `&StridedView<$T>` owned by the caller.
                let a = unsafe {
                    &*(operands[0] as *const StridedView<$T> as *const StridedView<$concrete_ty>)
                };
                // SAFETY: same identity-cast argument as for `a`.
                let b = unsafe {
                    &*(operands[1] as *const StridedView<$T> as *const StridedView<$concrete_ty>)
                };
                // SAFETY: identity cast (types are equal). `$output` is a unique
                // `&mut` borrow, so no other alias exists for the duration of the call.
                let out = unsafe {
                    &mut *($output as *mut StridedViewMut<$T> as *mut StridedViewMut<$concrete_ty>)
                };
                // SAFETY: identity cast; the value is copied out (`$concrete_ty: Copy`).
                let alpha = unsafe { *(&$alpha as *const $T as *const $concrete_ty) };
                // SAFETY: identity cast; the value is copied out (`$concrete_ty: Copy`).
                let beta = unsafe { *(&$beta as *const $T as *const $concrete_ty) };
                return execute_batched_gemm_optimized(
                    alpha,
                    &[a, b],
                    beta,
                    out,
                    $batch_dims,
                    $m,
                    $n,
                    $k,
                );
            }
        )+
    }};
}
58
/// Generates the CPU tropical primitive impls for one algebra marker / scalar
/// wrapper pair (e.g. `MaxPlusAlgebra` / `MaxPlus`).
///
/// Two impls on `CpuBackend` are produced:
/// - `TensorSemiringCore<$marker<S>>`. Planning is delegated to `tropical_plan`.
///   For a `TropicalPlan::BatchedGemm` plan with exactly two inputs, execution
///   first tries `try_simd_dispatch!`. That call returns early for
///   `$wrapper<f64>` and `$wrapper<f32>`. Every other case runs
///   `tropical_execute`.
/// - `TensorSemiringFastPath<$marker<S>>`. `has_fast_path` is always `false`,
///   and `plan`/`execute` always return `Error::InvalidArgument`.
macro_rules! impl_tropical_prims {
    ($marker:ident, $wrapper:ident) => {
        impl<S: Scalar + Float> TensorSemiringCore<$marker<S>> for CpuBackend
        where
            $marker<S>: Algebra<Scalar = $wrapper<S>> + Semiring<Scalar = $wrapper<S>>,
            $wrapper<S>: Scalar + TropicalGemmDispatch,
        {
            type Context = CpuContext;
            type Plan = TropicalPlan<$wrapper<S>>;

            fn plan(
                _ctx: &mut CpuContext,
                desc: &SemiringCoreDescriptor,
                shapes: &[&[usize]],
            ) -> Result<TropicalPlan<$wrapper<S>>> {
                tropical_plan(desc, shapes)
            }

            fn execute(
                _ctx: &mut CpuContext,
                plan: &TropicalPlan<$wrapper<S>>,
                alpha: $wrapper<S>,
                inputs: &[&Tensor<$wrapper<S>>],
                beta: $wrapper<S>,
                output: &mut Tensor<$wrapper<S>>,
            ) -> Result<()> {
                // View every input tensor; the first conversion failure is returned.
                let mut input_views: Vec<StridedView<$wrapper<S>>> =
                    Vec::with_capacity(inputs.len());
                for tensor in inputs {
                    input_views.push(tensor_to_view(tensor)?);
                }
                let input_refs: Vec<&StridedView<$wrapper<S>>> = input_views.iter().collect();
                let mut output_view = tensor_to_view_mut(output)?;

                if let TropicalPlan::BatchedGemm { batch_dims, m, n, k, .. } = plan {
                    if input_refs.len() != 2 {
                        return Err(Error::InvalidArgument(
                            "BatchedGemm execute requires 2 input tensors".into(),
                        ));
                    }
                    // Returns from `execute` when the scalar is `$wrapper<f64>` or
                    // `$wrapper<f32>`; otherwise falls through to the generic path.
                    try_simd_dispatch!(
                        $wrapper<S>,
                        [$wrapper<f64>, $wrapper<f32>],
                        &input_refs,
                        alpha,
                        beta,
                        &mut output_view,
                        batch_dims,
                        *m,
                        *n,
                        *k
                    );
                }

                // Generic path: non-GEMM plans, or scalars without a specialized kernel.
                tropical_execute(plan, alpha, &input_refs, beta, &mut output_view)
            }
        }

        impl<S: Scalar + Float> TensorSemiringFastPath<$marker<S>> for CpuBackend
        where
            $marker<S>: Algebra<Scalar = $wrapper<S>> + Semiring<Scalar = $wrapper<S>>,
            $wrapper<S>: Scalar,
        {
            type Context = CpuContext;
            type Plan = TropicalPlan<$wrapper<S>>;

            // No tropical algebra advertises a fast path.
            fn has_fast_path(_desc: SemiringFastPathDescriptor) -> bool {
                false
            }

            // Unsupported: always rejects planning.
            fn plan(
                _ctx: &mut CpuContext,
                _desc: &SemiringFastPathDescriptor,
                _shapes: &[&[usize]],
            ) -> Result<TropicalPlan<$wrapper<S>>> {
                Err(Error::InvalidArgument(
                    "tropical algebras do not support semiring fast paths".into(),
                ))
            }

            // Unsupported: always rejects execution.
            fn execute(
                _ctx: &mut CpuContext,
                _plan: &TropicalPlan<$wrapper<S>>,
                _alpha: $wrapper<S>,
                _inputs: &[&Tensor<$wrapper<S>>],
                _beta: $wrapper<S>,
                _output: &mut Tensor<$wrapper<S>>,
            ) -> Result<()> {
                Err(Error::InvalidArgument(
                    "tropical algebras do not support semiring fast paths".into(),
                ))
            }
        }
    };
}
159
160impl_tropical_prims!(MaxPlusAlgebra, MaxPlus);
161impl_tropical_prims!(MinPlusAlgebra, MinPlus);
162impl_tropical_prims!(MaxMulAlgebra, MaxMul);