tenferro_internal_runtime/
dispatch.rs1use tenferro_algebra::Standard;
2use tenferro_einsum::EinsumBackend;
3use tenferro_internal_error::{Error, Result};
4use tenferro_linalg::backend::{LinalgCapabilityOp, TensorLinalgBackend, TensorLinalgContextFor};
5use tenferro_linalg::LiftPermutationMatrixTensor;
6use tenferro_prims::{
7 CpuBackend, CpuContext, CudaBackend, CudaContext, RocmBackend, RocmContext,
8 SemiringFastPathDescriptor, TensorSemiringCore, TensorSemiringFastPath,
9};
10
11use crate::contracts::{EinsumRuntimeValue, LinalgRuntimeValue, StandardRuntimeValue};
12use crate::{with_default_runtime, RuntimeContext};
13
14pub trait DenseEinsumBackend<T, C>:
15 EinsumBackend<Standard<T>>
16 + TensorSemiringCore<Standard<T>, Context = C>
17 + TensorSemiringFastPath<
18 Standard<T>,
19 Context = C,
20 Plan = <Self as TensorSemiringCore<Standard<T>>>::Plan,
21 >
22where
23 T: StandardRuntimeValue,
24{
25}
26
27impl<T, C, B> DenseEinsumBackend<T, C> for B
28where
29 T: StandardRuntimeValue,
30 B: EinsumBackend<Standard<T>>
31 + TensorSemiringCore<Standard<T>, Context = C>
32 + TensorSemiringFastPath<
33 Standard<T>,
34 Context = C,
35 Plan = <B as TensorSemiringCore<Standard<T>>>::Plan,
36 >,
37{
38}
39
/// Compile-time description of one runtime slot.
///
/// Ties together the context type, the semiring backend type and a
/// short name that the capability checks in this module put into
/// "unsupported runtime op" errors.
pub trait RuntimeSlot {
    /// Short runtime name reported in errors.
    const NAME: &'static str;

    /// Execution context type associated with this slot.
    type Context;
    /// Semiring backend type associated with this slot.
    type SemiringBackend;
}
46
/// Type-level tag selecting the CPU runtime slot.
pub struct CpuRuntimeSlot;

/// Type-level tag selecting the CUDA runtime slot.
pub struct CudaRuntimeSlot;

/// Type-level tag selecting the ROCm runtime slot.
pub struct RocmRuntimeSlot;
50
51impl RuntimeSlot for CpuRuntimeSlot {
52 type Context = CpuContext;
53 type SemiringBackend = CpuBackend;
54
55 const NAME: &'static str = "cpu";
56}
57
58impl RuntimeSlot for CudaRuntimeSlot {
59 type Context = CudaContext;
60 type SemiringBackend = CudaBackend;
61
62 const NAME: &'static str = "cuda";
63}
64
65impl RuntimeSlot for RocmRuntimeSlot {
66 type Context = RocmContext;
67 type SemiringBackend = RocmBackend;
68
69 const NAME: &'static str = "rocm";
70}
71
72pub trait ScaledRealLinalgDispatchValue:
73 crate::contracts::RealLinalgRuntimeValue
74 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
75 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
76 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
77{
78}
79
80impl<T> ScaledRealLinalgDispatchValue for T where
81 T: crate::contracts::RealLinalgRuntimeValue
82 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
83 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
84 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
85{
86}
87
88pub trait ScaledLinalgDispatchValue:
89 crate::contracts::LinalgRuntimeValue
90 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
91 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
92 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
93{
94}
95
96impl<T> ScaledLinalgDispatchValue for T where
97 T: crate::contracts::LinalgRuntimeValue
98 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
99 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
100 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
101{
102}
103
104pub trait NormLinalgDispatchValue:
105 crate::contracts::LinalgRuntimeValue
106 + tenferro_linalg::NormPrimal<CpuContext>
107 + tenferro_linalg::NormPrimal<CudaContext>
108 + tenferro_linalg::NormPrimal<RocmContext>
109 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
110 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
111 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
112{
113}
114
115impl<T> NormLinalgDispatchValue for T where
116 T: crate::contracts::LinalgRuntimeValue
117 + tenferro_linalg::NormPrimal<CpuContext>
118 + tenferro_linalg::NormPrimal<CudaContext>
119 + tenferro_linalg::NormPrimal<RocmContext>
120 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
121 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
122 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
123{
124}
125
126pub trait SlogdetLinalgDispatchValue:
127 crate::contracts::LinalgRuntimeValue
128 + tenferro_linalg::SlogdetDispatch<CpuContext>
129 + tenferro_linalg::SlogdetFruleDispatch<CpuContext>
130 + tenferro_linalg::SlogdetRruleDispatch<CpuContext>
131 + tenferro_linalg::SlogdetDispatch<CudaContext>
132 + tenferro_linalg::SlogdetFruleDispatch<CudaContext>
133 + tenferro_linalg::SlogdetRruleDispatch<CudaContext>
134 + tenferro_linalg::SlogdetDispatch<RocmContext>
135 + tenferro_linalg::SlogdetFruleDispatch<RocmContext>
136 + tenferro_linalg::SlogdetRruleDispatch<RocmContext>
137{
138}
139
140impl<T> SlogdetLinalgDispatchValue for T where
141 T: crate::contracts::LinalgRuntimeValue
142 + tenferro_linalg::SlogdetDispatch<CpuContext>
143 + tenferro_linalg::SlogdetFruleDispatch<CpuContext>
144 + tenferro_linalg::SlogdetRruleDispatch<CpuContext>
145 + tenferro_linalg::SlogdetDispatch<CudaContext>
146 + tenferro_linalg::SlogdetFruleDispatch<CudaContext>
147 + tenferro_linalg::SlogdetRruleDispatch<CudaContext>
148 + tenferro_linalg::SlogdetDispatch<RocmContext>
149 + tenferro_linalg::SlogdetFruleDispatch<RocmContext>
150 + tenferro_linalg::SlogdetRruleDispatch<RocmContext>
151{
152}
153
154pub trait MatrixExpLinalgDispatchValue:
155 crate::contracts::LinalgRuntimeValue
156 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
157 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
158 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
159 + tenferro_linalg::MatrixExpAbsTensor<CpuContext>
160 + tenferro_linalg::MatrixExpAbsTensor<CudaContext>
161 + tenferro_linalg::MatrixExpAbsTensor<RocmContext>
162{
163}
164
165impl<T> MatrixExpLinalgDispatchValue for T where
166 T: crate::contracts::LinalgRuntimeValue
167 + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
168 + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
169 + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
170 + tenferro_linalg::MatrixExpAbsTensor<CpuContext>
171 + tenferro_linalg::MatrixExpAbsTensor<CudaContext>
172 + tenferro_linalg::MatrixExpAbsTensor<RocmContext>
173{
174}
175
176pub trait RealMatrixExpLinalgDispatchValue:
177 crate::contracts::RealLinalgRuntimeValue + MatrixExpLinalgDispatchValue
178{
179}
180
181impl<T> RealMatrixExpLinalgDispatchValue for T where
182 T: crate::contracts::RealLinalgRuntimeValue + MatrixExpLinalgDispatchValue
183{
184}
185
186pub trait LuLinalgDispatchValue:
187 crate::contracts::LinalgRuntimeValue
188 + LiftPermutationMatrixTensor<CpuContext>
189 + LiftPermutationMatrixTensor<CudaContext>
190 + LiftPermutationMatrixTensor<RocmContext>
191{
192}
193
194impl<T> LuLinalgDispatchValue for T where
195 T: crate::contracts::LinalgRuntimeValue
196 + LiftPermutationMatrixTensor<CpuContext>
197 + LiftPermutationMatrixTensor<CudaContext>
198 + LiftPermutationMatrixTensor<RocmContext>
199{
200}
201
202pub trait RealLuLinalgDispatchValue:
203 crate::contracts::RealLinalgRuntimeValue
204 + LiftPermutationMatrixTensor<CpuContext>
205 + LiftPermutationMatrixTensor<CudaContext>
206 + LiftPermutationMatrixTensor<RocmContext>
207{
208}
209
210impl<T> RealLuLinalgDispatchValue for T where
211 T: crate::contracts::RealLinalgRuntimeValue
212 + LiftPermutationMatrixTensor<CpuContext>
213 + LiftPermutationMatrixTensor<CudaContext>
214 + LiftPermutationMatrixTensor<RocmContext>
215{
216}
217
218fn contract_capability_marker() -> SemiringFastPathDescriptor {
219 SemiringFastPathDescriptor::Contract {
220 modes_a: vec![0],
221 modes_b: vec![0],
222 modes_c: vec![0],
223 }
224}
225
226fn ensure_einsum_runtime_capability<T, Slot>(op: &'static str) -> Result<()>
227where
228 T: EinsumRuntimeValue,
229 Slot: RuntimeSlot,
230 Slot::SemiringBackend: DenseEinsumBackend<T, Slot::Context>,
231{
232 if !<Slot::SemiringBackend as TensorSemiringFastPath<Standard<T>>>::has_fast_path(
233 contract_capability_marker(),
234 ) {
235 return Err(unsupported_runtime_capability(op, Slot::NAME));
236 }
237 Ok(())
238}
239
240fn ensure_linalg_runtime_capability<T, Slot>(
241 op: &'static str,
242 capability: LinalgCapabilityOp,
243) -> Result<()>
244where
245 T: LinalgRuntimeValue,
246 Slot: RuntimeSlot,
247 Slot::Context: TensorLinalgContextFor<T>,
248 <Slot::Context as TensorLinalgContextFor<T>>::Backend:
249 TensorLinalgBackend<T, Context = Slot::Context>,
250{
251 if !<<Slot::Context as TensorLinalgContextFor<T>>::Backend as TensorLinalgBackend<T>>::has_linalg_support(capability)
252 {
253 return Err(unsupported_runtime_capability(op, Slot::NAME));
254 }
255 Ok(())
256}
257
258pub fn unsupported_runtime_capability(op: &'static str, runtime: &'static str) -> Error {
259 Error::UnsupportedRuntimeOp { op, runtime }
260}
261
262pub fn with_runtime<R>(
263 cpu: impl FnOnce(&mut CpuContext) -> Result<R>,
264 cuda: impl FnOnce(&mut CudaContext) -> Result<R>,
265 rocm: impl FnOnce(&mut RocmContext) -> Result<R>,
266) -> Result<R> {
267 with_default_runtime(|runtime| match runtime {
268 RuntimeContext::Cpu(ctx) => cpu(ctx),
269 RuntimeContext::Cuda(ctx) => cuda(ctx),
270 RuntimeContext::Rocm(ctx) => rocm(ctx),
271 })
272}
273
274pub fn with_einsum_runtime<T: EinsumRuntimeValue, R>(
275 op: &'static str,
276 cpu: impl FnOnce(&mut CpuContext) -> Result<R>,
277 cuda: impl FnOnce(&mut CudaContext) -> Result<R>,
278 rocm: impl FnOnce(&mut RocmContext) -> Result<R>,
279) -> Result<R> {
280 with_runtime(
281 cpu,
282 |ctx| {
283 ensure_einsum_runtime_capability::<T, CudaRuntimeSlot>(op)?;
284 cuda(ctx)
285 },
286 |ctx| {
287 ensure_einsum_runtime_capability::<T, RocmRuntimeSlot>(op)?;
288 rocm(ctx)
289 },
290 )
291}
292
293pub fn with_linalg_runtime<T: LinalgRuntimeValue, R>(
294 op: &'static str,
295 capability: LinalgCapabilityOp,
296 cpu: impl FnOnce(&mut CpuContext) -> Result<R>,
297 cuda: impl FnOnce(&mut CudaContext) -> Result<R>,
298 rocm: impl FnOnce(&mut RocmContext) -> Result<R>,
299) -> Result<R> {
300 with_runtime(
301 |ctx| {
302 ensure_linalg_runtime_capability::<T, CpuRuntimeSlot>(op, capability)?;
303 cpu(ctx)
304 },
305 |ctx| {
306 ensure_linalg_runtime_capability::<T, CudaRuntimeSlot>(op, capability)?;
307 cuda(ctx)
308 },
309 |ctx| {
310 ensure_linalg_runtime_capability::<T, RocmRuntimeSlot>(op, capability)?;
311 rocm(ctx)
312 },
313 )
314}