tenferro_internal_runtime/
dispatch.rs

1use tenferro_algebra::Standard;
2use tenferro_einsum::EinsumBackend;
3use tenferro_internal_error::{Error, Result};
4use tenferro_linalg::backend::{LinalgCapabilityOp, TensorLinalgBackend, TensorLinalgContextFor};
5use tenferro_linalg::LiftPermutationMatrixTensor;
6use tenferro_prims::{
7    CpuBackend, CpuContext, CudaBackend, CudaContext, RocmBackend, RocmContext,
8    SemiringFastPathDescriptor, TensorSemiringCore, TensorSemiringFastPath,
9};
10
11use crate::contracts::{EinsumRuntimeValue, LinalgRuntimeValue, StandardRuntimeValue};
12use crate::{with_default_runtime, RuntimeContext};
13
14pub trait DenseEinsumBackend<T, C>:
15    EinsumBackend<Standard<T>>
16    + TensorSemiringCore<Standard<T>, Context = C>
17    + TensorSemiringFastPath<
18        Standard<T>,
19        Context = C,
20        Plan = <Self as TensorSemiringCore<Standard<T>>>::Plan,
21    >
22where
23    T: StandardRuntimeValue,
24{
25}
26
27impl<T, C, B> DenseEinsumBackend<T, C> for B
28where
29    T: StandardRuntimeValue,
30    B: EinsumBackend<Standard<T>>
31        + TensorSemiringCore<Standard<T>, Context = C>
32        + TensorSemiringFastPath<
33            Standard<T>,
34            Context = C,
35            Plan = <B as TensorSemiringCore<Standard<T>>>::Plan,
36        >,
37{
38}
39
/// Type-level description of a single runtime.
///
/// Each implementor names three things:
/// - the context type handed to runtime closures;
/// - the semiring backend that einsum capability probes query;
/// - the runtime name reported in [`Error::UnsupportedRuntimeOp`].
pub trait RuntimeSlot {
    /// Runtime name reported when a capability probe fails.
    const NAME: &'static str;

    /// Execution context type for this runtime.
    type Context;
    /// Semiring backend type for this runtime.
    type SemiringBackend;
}
46
/// Zero-sized marker selecting the CPU runtime slot.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuRuntimeSlot;
/// Zero-sized marker selecting the CUDA runtime slot.
#[derive(Debug, Clone, Copy, Default)]
pub struct CudaRuntimeSlot;
/// Zero-sized marker selecting the ROCm runtime slot.
#[derive(Debug, Clone, Copy, Default)]
pub struct RocmRuntimeSlot;
50
51impl RuntimeSlot for CpuRuntimeSlot {
52    type Context = CpuContext;
53    type SemiringBackend = CpuBackend;
54
55    const NAME: &'static str = "cpu";
56}
57
58impl RuntimeSlot for CudaRuntimeSlot {
59    type Context = CudaContext;
60    type SemiringBackend = CudaBackend;
61
62    const NAME: &'static str = "cuda";
63}
64
65impl RuntimeSlot for RocmRuntimeSlot {
66    type Context = RocmContext;
67    type SemiringBackend = RocmBackend;
68
69    const NAME: &'static str = "rocm";
70}
71
72pub trait ScaledRealLinalgDispatchValue:
73    crate::contracts::RealLinalgRuntimeValue
74    + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
75    + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
76    + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
77{
78}
79
80impl<T> ScaledRealLinalgDispatchValue for T where
81    T: crate::contracts::RealLinalgRuntimeValue
82        + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
83        + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
84        + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
85{
86}
87
88pub trait ScaledLinalgDispatchValue:
89    crate::contracts::LinalgRuntimeValue
90    + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
91    + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
92    + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
93{
94}
95
96impl<T> ScaledLinalgDispatchValue for T where
97    T: crate::contracts::LinalgRuntimeValue
98        + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
99        + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
100        + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
101{
102}
103
104pub trait NormLinalgDispatchValue:
105    crate::contracts::LinalgRuntimeValue
106    + tenferro_linalg::NormPrimal<CpuContext>
107    + tenferro_linalg::NormPrimal<CudaContext>
108    + tenferro_linalg::NormPrimal<RocmContext>
109    + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
110    + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
111    + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
112{
113}
114
115impl<T> NormLinalgDispatchValue for T where
116    T: crate::contracts::LinalgRuntimeValue
117        + tenferro_linalg::NormPrimal<CpuContext>
118        + tenferro_linalg::NormPrimal<CudaContext>
119        + tenferro_linalg::NormPrimal<RocmContext>
120        + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
121        + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
122        + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
123{
124}
125
126pub trait SlogdetLinalgDispatchValue:
127    crate::contracts::LinalgRuntimeValue
128    + tenferro_linalg::SlogdetDispatch<CpuContext>
129    + tenferro_linalg::SlogdetFruleDispatch<CpuContext>
130    + tenferro_linalg::SlogdetRruleDispatch<CpuContext>
131    + tenferro_linalg::SlogdetDispatch<CudaContext>
132    + tenferro_linalg::SlogdetFruleDispatch<CudaContext>
133    + tenferro_linalg::SlogdetRruleDispatch<CudaContext>
134    + tenferro_linalg::SlogdetDispatch<RocmContext>
135    + tenferro_linalg::SlogdetFruleDispatch<RocmContext>
136    + tenferro_linalg::SlogdetRruleDispatch<RocmContext>
137{
138}
139
140impl<T> SlogdetLinalgDispatchValue for T where
141    T: crate::contracts::LinalgRuntimeValue
142        + tenferro_linalg::SlogdetDispatch<CpuContext>
143        + tenferro_linalg::SlogdetFruleDispatch<CpuContext>
144        + tenferro_linalg::SlogdetRruleDispatch<CpuContext>
145        + tenferro_linalg::SlogdetDispatch<CudaContext>
146        + tenferro_linalg::SlogdetFruleDispatch<CudaContext>
147        + tenferro_linalg::SlogdetRruleDispatch<CudaContext>
148        + tenferro_linalg::SlogdetDispatch<RocmContext>
149        + tenferro_linalg::SlogdetFruleDispatch<RocmContext>
150        + tenferro_linalg::SlogdetRruleDispatch<RocmContext>
151{
152}
153
154pub trait MatrixExpLinalgDispatchValue:
155    crate::contracts::LinalgRuntimeValue
156    + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
157    + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
158    + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
159    + tenferro_linalg::MatrixExpAbsTensor<CpuContext>
160    + tenferro_linalg::MatrixExpAbsTensor<CudaContext>
161    + tenferro_linalg::MatrixExpAbsTensor<RocmContext>
162{
163}
164
165impl<T> MatrixExpLinalgDispatchValue for T where
166    T: crate::contracts::LinalgRuntimeValue
167        + tenferro_linalg::ScaleTensorByRealSameShape<CpuContext>
168        + tenferro_linalg::ScaleTensorByRealSameShape<CudaContext>
169        + tenferro_linalg::ScaleTensorByRealSameShape<RocmContext>
170        + tenferro_linalg::MatrixExpAbsTensor<CpuContext>
171        + tenferro_linalg::MatrixExpAbsTensor<CudaContext>
172        + tenferro_linalg::MatrixExpAbsTensor<RocmContext>
173{
174}
175
176pub trait RealMatrixExpLinalgDispatchValue:
177    crate::contracts::RealLinalgRuntimeValue + MatrixExpLinalgDispatchValue
178{
179}
180
181impl<T> RealMatrixExpLinalgDispatchValue for T where
182    T: crate::contracts::RealLinalgRuntimeValue + MatrixExpLinalgDispatchValue
183{
184}
185
186pub trait LuLinalgDispatchValue:
187    crate::contracts::LinalgRuntimeValue
188    + LiftPermutationMatrixTensor<CpuContext>
189    + LiftPermutationMatrixTensor<CudaContext>
190    + LiftPermutationMatrixTensor<RocmContext>
191{
192}
193
194impl<T> LuLinalgDispatchValue for T where
195    T: crate::contracts::LinalgRuntimeValue
196        + LiftPermutationMatrixTensor<CpuContext>
197        + LiftPermutationMatrixTensor<CudaContext>
198        + LiftPermutationMatrixTensor<RocmContext>
199{
200}
201
202pub trait RealLuLinalgDispatchValue:
203    crate::contracts::RealLinalgRuntimeValue
204    + LiftPermutationMatrixTensor<CpuContext>
205    + LiftPermutationMatrixTensor<CudaContext>
206    + LiftPermutationMatrixTensor<RocmContext>
207{
208}
209
210impl<T> RealLuLinalgDispatchValue for T where
211    T: crate::contracts::RealLinalgRuntimeValue
212        + LiftPermutationMatrixTensor<CpuContext>
213        + LiftPermutationMatrixTensor<CudaContext>
214        + LiftPermutationMatrixTensor<RocmContext>
215{
216}
217
218fn contract_capability_marker() -> SemiringFastPathDescriptor {
219    SemiringFastPathDescriptor::Contract {
220        modes_a: vec![0],
221        modes_b: vec![0],
222        modes_c: vec![0],
223    }
224}
225
226fn ensure_einsum_runtime_capability<T, Slot>(op: &'static str) -> Result<()>
227where
228    T: EinsumRuntimeValue,
229    Slot: RuntimeSlot,
230    Slot::SemiringBackend: DenseEinsumBackend<T, Slot::Context>,
231{
232    if !<Slot::SemiringBackend as TensorSemiringFastPath<Standard<T>>>::has_fast_path(
233        contract_capability_marker(),
234    ) {
235        return Err(unsupported_runtime_capability(op, Slot::NAME));
236    }
237    Ok(())
238}
239
240fn ensure_linalg_runtime_capability<T, Slot>(
241    op: &'static str,
242    capability: LinalgCapabilityOp,
243) -> Result<()>
244where
245    T: LinalgRuntimeValue,
246    Slot: RuntimeSlot,
247    Slot::Context: TensorLinalgContextFor<T>,
248    <Slot::Context as TensorLinalgContextFor<T>>::Backend:
249        TensorLinalgBackend<T, Context = Slot::Context>,
250{
251    if !<<Slot::Context as TensorLinalgContextFor<T>>::Backend as TensorLinalgBackend<T>>::has_linalg_support(capability)
252    {
253        return Err(unsupported_runtime_capability(op, Slot::NAME));
254    }
255    Ok(())
256}
257
258pub fn unsupported_runtime_capability(op: &'static str, runtime: &'static str) -> Error {
259    Error::UnsupportedRuntimeOp { op, runtime }
260}
261
262pub fn with_runtime<R>(
263    cpu: impl FnOnce(&mut CpuContext) -> Result<R>,
264    cuda: impl FnOnce(&mut CudaContext) -> Result<R>,
265    rocm: impl FnOnce(&mut RocmContext) -> Result<R>,
266) -> Result<R> {
267    with_default_runtime(|runtime| match runtime {
268        RuntimeContext::Cpu(ctx) => cpu(ctx),
269        RuntimeContext::Cuda(ctx) => cuda(ctx),
270        RuntimeContext::Rocm(ctx) => rocm(ctx),
271    })
272}
273
274pub fn with_einsum_runtime<T: EinsumRuntimeValue, R>(
275    op: &'static str,
276    cpu: impl FnOnce(&mut CpuContext) -> Result<R>,
277    cuda: impl FnOnce(&mut CudaContext) -> Result<R>,
278    rocm: impl FnOnce(&mut RocmContext) -> Result<R>,
279) -> Result<R> {
280    with_runtime(
281        cpu,
282        |ctx| {
283            ensure_einsum_runtime_capability::<T, CudaRuntimeSlot>(op)?;
284            cuda(ctx)
285        },
286        |ctx| {
287            ensure_einsum_runtime_capability::<T, RocmRuntimeSlot>(op)?;
288            rocm(ctx)
289        },
290    )
291}
292
293pub fn with_linalg_runtime<T: LinalgRuntimeValue, R>(
294    op: &'static str,
295    capability: LinalgCapabilityOp,
296    cpu: impl FnOnce(&mut CpuContext) -> Result<R>,
297    cuda: impl FnOnce(&mut CudaContext) -> Result<R>,
298    rocm: impl FnOnce(&mut RocmContext) -> Result<R>,
299) -> Result<R> {
300    with_runtime(
301        |ctx| {
302            ensure_linalg_runtime_capability::<T, CpuRuntimeSlot>(op, capability)?;
303            cpu(ctx)
304        },
305        |ctx| {
306            ensure_linalg_runtime_capability::<T, CudaRuntimeSlot>(op, capability)?;
307            cuda(ctx)
308        },
309        |ctx| {
310            ensure_linalg_runtime_capability::<T, RocmRuntimeSlot>(op, capability)?;
311            rocm(ctx)
312        },
313    )
314}