// tenferro_ad/eager_backend.rs

1use tenferro_cpu::CpuBackend;
2#[cfg(feature = "cuda")]
3use tenferro_gpu::CudaBackend;
4#[cfg(feature = "webgpu")]
5use tenferro_gpu::WebGpuBackend;
6use tenferro_tensor::backend::ElementwiseFusionPlan;
7use tenferro_tensor::{
8    BackendCachedDot, BackendRuntimeCache, BackendSession, BackendSessionHost, CompareDir, DType,
9    DotGeneralConfig, GatherConfig, PadConfig, Result as TensorResult, ScatterConfig, SliceConfig,
10    Tensor, TensorAnalytic, TensorBackend, TensorBuffer, TensorDeviceTransfer, TensorDot,
11    TensorElementwise, TensorFusion, TensorIndexing, TensorRead, TensorReduction, TensorStructural,
12    TensorValue,
13};
14
15pub enum EagerBackend {
16    Cpu(CpuBackend),
17    #[cfg(feature = "cuda")]
18    Cuda(CudaBackend),
19    #[cfg(feature = "webgpu")]
20    WebGpu(WebGpuBackend),
21}
22
23impl std::fmt::Debug for EagerBackend {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::Cpu(backend) => f.debug_tuple("Cpu").field(backend).finish(),
27            #[cfg(feature = "cuda")]
28            Self::Cuda(backend) => f.debug_tuple("Cuda").field(backend).finish(),
29            #[cfg(feature = "webgpu")]
30            Self::WebGpu(backend) => f.debug_tuple("WebGpu").field(backend).finish(),
31        }
32    }
33}
34
35impl EagerBackend {
36    pub(crate) fn cpu(backend: CpuBackend) -> Self {
37        Self::Cpu(backend)
38    }
39
40    #[cfg(feature = "cuda")]
41    pub(crate) fn cuda(backend: CudaBackend) -> Self {
42        Self::Cuda(backend)
43    }
44
45    #[cfg(feature = "webgpu")]
46    pub(crate) fn webgpu(backend: WebGpuBackend) -> Self {
47        Self::WebGpu(backend)
48    }
49
50    pub(crate) fn synchronize(&mut self) -> TensorResult<()> {
51        match self {
52            Self::Cpu(_) => Ok(()),
53            #[cfg(feature = "cuda")]
54            Self::Cuda(backend) => backend.runtime().synchronize(),
55            #[cfg(feature = "webgpu")]
56            Self::WebGpu(backend) => backend.synchronize(),
57        }
58    }
59}
60
/// Forwards a method call to whichever concrete backend an `EagerBackend` holds.
///
/// Usage: `dispatch!(target, method(arg, ...))`. `target` is matched against
/// every compiled-in variant and `method` is called on the variant's payload
/// with the given arguments. A trailing comma after the last argument is
/// accepted.
macro_rules! dispatch {
    ($target:expr, $op:ident($($param:expr),* $(,)?)) => {
        match $target {
            EagerBackend::Cpu(inner) => inner.$op($($param),*),
            #[cfg(feature = "cuda")]
            EagerBackend::Cuda(inner) => inner.$op($($param),*),
            #[cfg(feature = "webgpu")]
            EagerBackend::WebGpu(inner) => inner.$op($($param),*),
        }
    };
}
72
/// Generates `&mut self` methods that forward to the wrapped backend.
///
/// Each input entry has the form `fn name(arg: Type, ...) -> Return;`. For every
/// entry the macro emits `fn name(&mut self, arg: Type, ...) -> Return` whose
/// body is `dispatch!(self, name(arg, ...))`, so `dispatch!` must be in scope
/// wherever this macro is invoked.
macro_rules! delegate_tensor_backend_methods {
    ($(fn $name:ident($($param:ident: $param_ty:ty),* $(,)?) -> $output:ty;)*) => {
        $(
            fn $name(&mut self, $($param: $param_ty),*) -> $output {
                dispatch!(self, $name($($param),*))
            }
        )*
    };
}
82
/// `EagerBackend` sets the `RuntimeCache` associated type to the unit type `()`.
impl BackendRuntimeCache for EagerBackend {
    type RuntimeCache = ();
}
86
87impl TensorElementwise for EagerBackend {
88    delegate_tensor_backend_methods! {
89        fn add(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
90        fn add_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
91        fn mul(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
92        fn mul_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
93        fn neg(input: &Tensor) -> TensorResult<Tensor>;
94        fn neg_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
95        fn conj(input: &Tensor) -> TensorResult<Tensor>;
96        fn conj_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
97        fn div(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
98        fn div_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
99        fn abs(input: &Tensor) -> TensorResult<Tensor>;
100        fn abs_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
101        fn sign(input: &Tensor) -> TensorResult<Tensor>;
102        fn sign_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
103        fn maximum(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
104        fn maximum_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
105        fn minimum(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
106        fn minimum_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
107        fn compare(lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> TensorResult<Tensor>;
108        fn compare_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>, dir: &CompareDir) -> TensorResult<Tensor>;
109        fn select(pred: &Tensor, on_true: &Tensor, on_false: &Tensor) -> TensorResult<Tensor>;
110        fn select_read(pred: TensorRead<'_>, on_true: TensorRead<'_>, on_false: TensorRead<'_>) -> TensorResult<Tensor>;
111        fn clamp(input: &Tensor, lower: &Tensor, upper: &Tensor) -> TensorResult<Tensor>;
112        fn clamp_read(input: TensorRead<'_>, lower: TensorRead<'_>, upper: TensorRead<'_>) -> TensorResult<Tensor>;
113    }
114}
115
116impl TensorAnalytic for EagerBackend {
117    delegate_tensor_backend_methods! {
118        fn exp(input: &Tensor) -> TensorResult<Tensor>;
119        fn exp_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
120        fn log(input: &Tensor) -> TensorResult<Tensor>;
121        fn log_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
122        fn sin(input: &Tensor) -> TensorResult<Tensor>;
123        fn sin_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
124        fn cos(input: &Tensor) -> TensorResult<Tensor>;
125        fn cos_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
126        fn tanh(input: &Tensor) -> TensorResult<Tensor>;
127        fn tanh_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
128        fn sqrt(input: &Tensor) -> TensorResult<Tensor>;
129        fn sqrt_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
130        fn rsqrt(input: &Tensor) -> TensorResult<Tensor>;
131        fn rsqrt_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
132        fn pow(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
133        fn pow_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
134        fn expm1(input: &Tensor) -> TensorResult<Tensor>;
135        fn expm1_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
136        fn log1p(input: &Tensor) -> TensorResult<Tensor>;
137        fn log1p_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
138    }
139}
140
141impl TensorStructural for EagerBackend {
142    delegate_tensor_backend_methods! {
143        fn transpose(input: &Tensor, perm: &[usize]) -> TensorResult<Tensor>;
144        fn reshape(input: &Tensor, shape: &[usize]) -> TensorResult<Tensor>;
145        fn reshape_read(input: TensorRead<'_>, shape: &[usize]) -> TensorResult<Tensor>;
146        fn broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> TensorResult<Tensor>;
147        fn broadcast_in_dim_read(input: TensorRead<'_>, shape: &[usize], dims: &[usize]) -> TensorResult<Tensor>;
148        fn cast(input: &Tensor, to: DType) -> TensorResult<Tensor>;
149        fn extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> TensorResult<Tensor>;
150        fn embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> TensorResult<Tensor>;
151        fn tril(input: &Tensor, k: i64) -> TensorResult<Tensor>;
152        fn triu(input: &Tensor, k: i64) -> TensorResult<Tensor>;
153    }
154}
155
156impl TensorReduction for EagerBackend {
157    delegate_tensor_backend_methods! {
158        fn reduce_sum(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
159        fn reduce_prod(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
160        fn reduce_max(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
161        fn reduce_min(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
162    }
163}
164
165impl TensorDot for EagerBackend {
166    delegate_tensor_backend_methods! {
167        fn dot_general(lhs: &Tensor, rhs: &Tensor, config: &DotGeneralConfig) -> TensorResult<Tensor>;
168        fn dot_general_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>, config: &DotGeneralConfig) -> TensorResult<Tensor>;
169        fn dot_general_with_conj(lhs: &Tensor, rhs: &Tensor, config: &DotGeneralConfig, lhs_conj: bool, rhs_conj: bool) -> TensorResult<Tensor>;
170    }
171}
172
173impl TensorIndexing for EagerBackend {
174    delegate_tensor_backend_methods! {
175        fn gather(operand: &Tensor, start_indices: &Tensor, config: &GatherConfig) -> TensorResult<Tensor>;
176        fn scatter(operand: &Tensor, scatter_indices: &Tensor, updates: &Tensor, config: &ScatterConfig) -> TensorResult<Tensor>;
177        fn slice(input: &Tensor, config: &SliceConfig) -> TensorResult<Tensor>;
178        fn dynamic_slice(input: &Tensor, starts: &Tensor, slice_sizes: &[usize]) -> TensorResult<Tensor>;
179        fn dynamic_update_slice(operand: &Tensor, update: &Tensor, starts: &Tensor) -> TensorResult<Tensor>;
180        fn pad(input: &Tensor, config: &PadConfig) -> TensorResult<Tensor>;
181        fn concatenate(inputs: &[&Tensor], axis: usize) -> TensorResult<Tensor>;
182        fn reverse(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
183    }
184}
185
186impl BackendSessionHost for EagerBackend {
187    fn with_backend_session<R: Send>(
188        &mut self,
189        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
190    ) -> R {
191        dispatch!(self, with_backend_session(f))
192    }
193}
194
195impl TensorDeviceTransfer for EagerBackend {
196    delegate_tensor_backend_methods! {
197        fn download_to_host(tensor: &Tensor) -> TensorResult<Tensor>;
198        fn upload_host_tensor(tensor: &Tensor) -> TensorResult<Tensor>;
199    }
200}
201
202impl TensorBuffer for EagerBackend {
203    delegate_tensor_backend_methods! {
204        fn reclaim_buffer(tensor: Tensor) -> ();
205    }
206}
207
208impl TensorFusion for EagerBackend {
209    delegate_tensor_backend_methods! {
210        fn execute_elementwise_fusion(inputs: &[&Tensor], plan: &ElementwiseFusionPlan) -> TensorResult<Option<Vec<Tensor>>>;
211        fn execute_broadcast_multiply(lhs: TensorRead<'_>, lhs_shape: &[usize], lhs_dims: &[usize], rhs: TensorRead<'_>, rhs_shape: &[usize], rhs_dims: &[usize]) -> TensorResult<Option<Tensor>>;
212        fn execute_broadcast_multiply_value(lhs: TensorRead<'_>, lhs_shape: &[usize], lhs_dims: &[usize], rhs: TensorRead<'_>, rhs_shape: &[usize], rhs_dims: &[usize]) -> TensorResult<Option<TensorValue>>;
213    }
214}
215
/// Empty impl: `EagerBackend` overrides nothing here, so it relies entirely on
/// whatever default items `BackendCachedDot` provides.
impl BackendCachedDot for EagerBackend {}
217
/// Empty impl: marks `EagerBackend` as a `TensorBackend` without overriding any
/// item. NOTE(review): presumably `TensorBackend` is an umbrella over the
/// operation traits implemented in this file — confirm against its definition.
impl TensorBackend for EagerBackend {}