1use tenferro_cpu::CpuBackend;
2#[cfg(feature = "cuda")]
3use tenferro_gpu::CudaBackend;
4#[cfg(feature = "webgpu")]
5use tenferro_gpu::WebGpuBackend;
6use tenferro_tensor::backend::ElementwiseFusionPlan;
7use tenferro_tensor::{
8 BackendCachedDot, BackendRuntimeCache, BackendSession, BackendSessionHost, CompareDir, DType,
9 DotGeneralConfig, GatherConfig, PadConfig, Result as TensorResult, ScatterConfig, SliceConfig,
10 Tensor, TensorAnalytic, TensorBackend, TensorBuffer, TensorDeviceTransfer, TensorDot,
11 TensorElementwise, TensorFusion, TensorIndexing, TensorRead, TensorReduction, TensorStructural,
12 TensorValue,
13};
14
15pub enum EagerBackend {
16 Cpu(CpuBackend),
17 #[cfg(feature = "cuda")]
18 Cuda(CudaBackend),
19 #[cfg(feature = "webgpu")]
20 WebGpu(WebGpuBackend),
21}
22
23impl std::fmt::Debug for EagerBackend {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Self::Cpu(backend) => f.debug_tuple("Cpu").field(backend).finish(),
27 #[cfg(feature = "cuda")]
28 Self::Cuda(backend) => f.debug_tuple("Cuda").field(backend).finish(),
29 #[cfg(feature = "webgpu")]
30 Self::WebGpu(backend) => f.debug_tuple("WebGpu").field(backend).finish(),
31 }
32 }
33}
34
35impl EagerBackend {
36 pub(crate) fn cpu(backend: CpuBackend) -> Self {
37 Self::Cpu(backend)
38 }
39
40 #[cfg(feature = "cuda")]
41 pub(crate) fn cuda(backend: CudaBackend) -> Self {
42 Self::Cuda(backend)
43 }
44
45 #[cfg(feature = "webgpu")]
46 pub(crate) fn webgpu(backend: WebGpuBackend) -> Self {
47 Self::WebGpu(backend)
48 }
49
50 pub(crate) fn synchronize(&mut self) -> TensorResult<()> {
51 match self {
52 Self::Cpu(_) => Ok(()),
53 #[cfg(feature = "cuda")]
54 Self::Cuda(backend) => backend.runtime().synchronize(),
55 #[cfg(feature = "webgpu")]
56 Self::WebGpu(backend) => backend.synchronize(),
57 }
58 }
59}
60
/// Calls one method on whichever concrete backend an `EagerBackend` holds.
///
/// Usage: `dispatch!(backend_expr, method(arg0, arg1, ...))`. The macro
/// expands to a `match` over every variant compiled into this build and
/// invokes `method` with the given arguments on the wrapped backend. A
/// trailing comma in the argument list is accepted.
macro_rules! dispatch {
    ($target:expr, $op:ident($($param:expr),* $(,)?)) => {
        match $target {
            EagerBackend::Cpu(inner) => inner.$op($($param),*),
            #[cfg(feature = "cuda")]
            EagerBackend::Cuda(inner) => inner.$op($($param),*),
            #[cfg(feature = "webgpu")]
            EagerBackend::WebGpu(inner) => inner.$op($($param),*),
        }
    };
}
72
/// Generates `&mut self` trait methods that forward to the wrapped backend.
///
/// Each entry is a signature without the receiver, terminated by `;`:
///
/// ```ignore
/// delegate_tensor_backend_methods! {
///     fn add(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
///     fn reclaim_buffer(tensor: Tensor);
/// }
/// ```
///
/// For every entry the macro emits `fn name(&mut self, args...) [-> Ret]`
/// whose body is `dispatch!(self, name(args...))`, so the `dispatch!` macro
/// must be in scope wherever this macro is invoked.
///
/// The `-> Ret` part is optional. When it is omitted the generated method
/// returns `()`, as an ordinary `fn` without a return type does. Writing
/// `-> ()` explicitly is still accepted.
macro_rules! delegate_tensor_backend_methods {
    ($(fn $method:ident($($arg:ident: $ty:ty),* $(,)?) $(-> $ret:ty)?;)*) => {
        $(
            fn $method(&mut self, $($arg: $ty),*) $(-> $ret)? {
                dispatch!(self, $method($($arg),*))
            }
        )*
    };
}
82
83impl BackendRuntimeCache for EagerBackend {
84 type RuntimeCache = ();
85}
86
87impl TensorElementwise for EagerBackend {
88 delegate_tensor_backend_methods! {
89 fn add(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
90 fn add_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
91 fn mul(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
92 fn mul_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
93 fn neg(input: &Tensor) -> TensorResult<Tensor>;
94 fn neg_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
95 fn conj(input: &Tensor) -> TensorResult<Tensor>;
96 fn conj_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
97 fn div(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
98 fn div_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
99 fn abs(input: &Tensor) -> TensorResult<Tensor>;
100 fn abs_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
101 fn sign(input: &Tensor) -> TensorResult<Tensor>;
102 fn sign_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
103 fn maximum(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
104 fn maximum_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
105 fn minimum(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
106 fn minimum_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
107 fn compare(lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> TensorResult<Tensor>;
108 fn compare_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>, dir: &CompareDir) -> TensorResult<Tensor>;
109 fn select(pred: &Tensor, on_true: &Tensor, on_false: &Tensor) -> TensorResult<Tensor>;
110 fn select_read(pred: TensorRead<'_>, on_true: TensorRead<'_>, on_false: TensorRead<'_>) -> TensorResult<Tensor>;
111 fn clamp(input: &Tensor, lower: &Tensor, upper: &Tensor) -> TensorResult<Tensor>;
112 fn clamp_read(input: TensorRead<'_>, lower: TensorRead<'_>, upper: TensorRead<'_>) -> TensorResult<Tensor>;
113 }
114}
115
116impl TensorAnalytic for EagerBackend {
117 delegate_tensor_backend_methods! {
118 fn exp(input: &Tensor) -> TensorResult<Tensor>;
119 fn exp_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
120 fn log(input: &Tensor) -> TensorResult<Tensor>;
121 fn log_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
122 fn sin(input: &Tensor) -> TensorResult<Tensor>;
123 fn sin_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
124 fn cos(input: &Tensor) -> TensorResult<Tensor>;
125 fn cos_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
126 fn tanh(input: &Tensor) -> TensorResult<Tensor>;
127 fn tanh_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
128 fn sqrt(input: &Tensor) -> TensorResult<Tensor>;
129 fn sqrt_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
130 fn rsqrt(input: &Tensor) -> TensorResult<Tensor>;
131 fn rsqrt_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
132 fn pow(lhs: &Tensor, rhs: &Tensor) -> TensorResult<Tensor>;
133 fn pow_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> TensorResult<Tensor>;
134 fn expm1(input: &Tensor) -> TensorResult<Tensor>;
135 fn expm1_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
136 fn log1p(input: &Tensor) -> TensorResult<Tensor>;
137 fn log1p_read(input: TensorRead<'_>) -> TensorResult<Tensor>;
138 }
139}
140
141impl TensorStructural for EagerBackend {
142 delegate_tensor_backend_methods! {
143 fn transpose(input: &Tensor, perm: &[usize]) -> TensorResult<Tensor>;
144 fn reshape(input: &Tensor, shape: &[usize]) -> TensorResult<Tensor>;
145 fn reshape_read(input: TensorRead<'_>, shape: &[usize]) -> TensorResult<Tensor>;
146 fn broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> TensorResult<Tensor>;
147 fn broadcast_in_dim_read(input: TensorRead<'_>, shape: &[usize], dims: &[usize]) -> TensorResult<Tensor>;
148 fn cast(input: &Tensor, to: DType) -> TensorResult<Tensor>;
149 fn extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> TensorResult<Tensor>;
150 fn embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> TensorResult<Tensor>;
151 fn tril(input: &Tensor, k: i64) -> TensorResult<Tensor>;
152 fn triu(input: &Tensor, k: i64) -> TensorResult<Tensor>;
153 }
154}
155
156impl TensorReduction for EagerBackend {
157 delegate_tensor_backend_methods! {
158 fn reduce_sum(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
159 fn reduce_prod(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
160 fn reduce_max(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
161 fn reduce_min(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
162 }
163}
164
165impl TensorDot for EagerBackend {
166 delegate_tensor_backend_methods! {
167 fn dot_general(lhs: &Tensor, rhs: &Tensor, config: &DotGeneralConfig) -> TensorResult<Tensor>;
168 fn dot_general_read(lhs: TensorRead<'_>, rhs: TensorRead<'_>, config: &DotGeneralConfig) -> TensorResult<Tensor>;
169 fn dot_general_with_conj(lhs: &Tensor, rhs: &Tensor, config: &DotGeneralConfig, lhs_conj: bool, rhs_conj: bool) -> TensorResult<Tensor>;
170 }
171}
172
173impl TensorIndexing for EagerBackend {
174 delegate_tensor_backend_methods! {
175 fn gather(operand: &Tensor, start_indices: &Tensor, config: &GatherConfig) -> TensorResult<Tensor>;
176 fn scatter(operand: &Tensor, scatter_indices: &Tensor, updates: &Tensor, config: &ScatterConfig) -> TensorResult<Tensor>;
177 fn slice(input: &Tensor, config: &SliceConfig) -> TensorResult<Tensor>;
178 fn dynamic_slice(input: &Tensor, starts: &Tensor, slice_sizes: &[usize]) -> TensorResult<Tensor>;
179 fn dynamic_update_slice(operand: &Tensor, update: &Tensor, starts: &Tensor) -> TensorResult<Tensor>;
180 fn pad(input: &Tensor, config: &PadConfig) -> TensorResult<Tensor>;
181 fn concatenate(inputs: &[&Tensor], axis: usize) -> TensorResult<Tensor>;
182 fn reverse(input: &Tensor, axes: &[usize]) -> TensorResult<Tensor>;
183 }
184}
185
186impl BackendSessionHost for EagerBackend {
187 fn with_backend_session<R: Send>(
188 &mut self,
189 f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
190 ) -> R {
191 dispatch!(self, with_backend_session(f))
192 }
193}
194
195impl TensorDeviceTransfer for EagerBackend {
196 delegate_tensor_backend_methods! {
197 fn download_to_host(tensor: &Tensor) -> TensorResult<Tensor>;
198 fn upload_host_tensor(tensor: &Tensor) -> TensorResult<Tensor>;
199 }
200}
201
202impl TensorBuffer for EagerBackend {
203 delegate_tensor_backend_methods! {
204 fn reclaim_buffer(tensor: Tensor) -> ();
205 }
206}
207
208impl TensorFusion for EagerBackend {
209 delegate_tensor_backend_methods! {
210 fn execute_elementwise_fusion(inputs: &[&Tensor], plan: &ElementwiseFusionPlan) -> TensorResult<Option<Vec<Tensor>>>;
211 fn execute_broadcast_multiply(lhs: TensorRead<'_>, lhs_shape: &[usize], lhs_dims: &[usize], rhs: TensorRead<'_>, rhs_shape: &[usize], rhs_dims: &[usize]) -> TensorResult<Option<Tensor>>;
212 fn execute_broadcast_multiply_value(lhs: TensorRead<'_>, lhs_shape: &[usize], lhs_dims: &[usize], rhs: TensorRead<'_>, rhs_shape: &[usize], rhs_dims: &[usize]) -> TensorResult<Option<TensorValue>>;
213 }
214}
215
216impl BackendCachedDot for EagerBackend {}
217
218impl TensorBackend for EagerBackend {}