1#[cfg(feature = "cuda")]
25use std::any::Any;
26
27#[cfg(feature = "cuda")]
28mod cubecl;
29#[cfg(feature = "cuda")]
30mod kernels;
31#[cfg(feature = "webgpu")]
32mod webgpu;
33
34#[cfg(feature = "cuda")]
35pub use cubecl::{
36 device_ptr, download_tensor, gpu_available, upload_tensor, CudaBackend, CudaRuntime,
37};
38#[cfg(feature = "cuda")]
39#[doc(hidden)]
40pub use cubecl::{CudaExtensionCache, CudaExtensionCacheGuard};
41#[cfg(feature = "webgpu")]
42pub use webgpu::{
43 download_webgpu_tensor, upload_webgpu_tensor, webgpu_available, WebGpuBackend, WebGpuRuntime,
44};
45
#[cfg(feature = "cuda")]
#[doc(hidden)]
pub mod cuda_interop {
    //! Hidden public path exposing everything from the private
    //! `cubecl::interop` module, together with the CUDA extension-cache
    //! types defined in `cubecl`.
    pub use super::cubecl::{interop::*, CudaExtensionCache, CudaExtensionCacheGuard};
}
52
53#[cfg(any(feature = "cuda", feature = "webgpu"))]
54use tenferro_tensor::*;
55
#[cfg(feature = "cuda")]
pub(crate) mod backend {
    //! Crate-internal alias that re-exports every public item of
    //! `tenferro_tensor::backend`.
    pub use ::tenferro_tensor::backend::*;
}
60
#[cfg(feature = "cuda")]
pub(crate) mod config {
    //! Crate-internal alias that re-exports every public item of
    //! `tenferro_tensor::config`.
    pub use ::tenferro_tensor::config::*;
}
65
#[cfg(feature = "cuda")]
pub(crate) mod types {
    //! Crate-internal path gathering every public item of
    //! `tenferro_tensor::types` together with this crate's `CubeclBuffer`.
    //! The explicit `CubeclBuffer` import takes precedence over any
    //! same-named item brought in by the glob.
    pub use ::tenferro_tensor::types::*;

    pub(crate) use super::CubeclBuffer;
}
71
/// Typed buffer backed by a cubecl runtime server handle.
///
/// Stores the runtime `Handle` and a length. The element type `T` is tracked
/// only at the type level through `PhantomData`; no `T` value is stored here.
#[cfg(feature = "cuda")]
pub(crate) struct CubeclBuffer<T> {
    handle: cubecl_runtime::server::Handle,
    len: usize,
    pub(crate) _marker: std::marker::PhantomData<T>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require
// `T: Clone` even though only the handle and the length are cloned, which
// needlessly restricts generic callers.
#[cfg(feature = "cuda")]
impl<T> Clone for CubeclBuffer<T> {
    fn clone(&self) -> Self {
        Self {
            handle: self.handle.clone(),
            len: self.len,
            _marker: std::marker::PhantomData,
        }
    }
}
80
#[cfg(feature = "cuda")]
impl<T> std::fmt::Debug for CubeclBuffer<T> {
    /// Formats as `CubeclBuffer { len: .., .. }`.
    ///
    /// Only the length is printed. The runtime handle is intentionally left
    /// out, and `finish_non_exhaustive` marks that omission with a trailing `..`
    /// so the output does not read as the complete struct.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CubeclBuffer")
            .field("len", &self.len)
            .finish_non_exhaustive()
    }
}
89
#[cfg(feature = "cuda")]
impl<T> CubeclBuffer<T> {
    /// Wraps a cubecl runtime `handle` together with its length `len`.
    pub(crate) fn new(handle: cubecl_runtime::server::Handle, len: usize) -> Self {
        let _marker = std::marker::PhantomData;
        Self {
            len,
            handle,
            _marker,
        }
    }

    /// Borrows the cubecl runtime handle this buffer wraps.
    pub(crate) fn handle(&self) -> &cubecl_runtime::server::Handle {
        let Self { handle, .. } = self;
        handle
    }

    /// Returns the `len` value supplied to [`CubeclBuffer::new`].
    ///
    /// Per this method's name, the value is presumably an element count rather
    /// than a byte size; confirm against the call sites.
    pub(crate) fn element_len(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }
}
108
#[cfg(feature = "cuda")]
impl<T> BackendBuffer<T> for CubeclBuffer<T>
where
    T: Send + Sync + 'static,
{
    /// Reports that this buffer belongs to the `"cubecl"` backend family.
    fn backend_family(&self) -> &'static str {
        const FAMILY: &str = "cubecl";
        FAMILY
    }

    /// Returns the length recorded when the buffer was constructed.
    fn len(&self) -> usize {
        self.element_len()
    }

    /// Exposes this buffer as `dyn Any` so callers can downcast it back to
    /// `CubeclBuffer<T>`.
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}