Skip to main content

tenferro_gpu/
lib.rs

1//! GPU backend implementations for tenferro tensors.
2//!
3//! # Examples
4//!
5//! ```rust
6//! #[cfg(feature = "cuda")]
7//! {
8//!     use tenferro_gpu::{download_tensor, gpu_available, upload_tensor, CudaBackend};
9//!     use tenferro_tensor::{Tensor, TensorElementwise};
10//!
11//!     if gpu_available() {
12//!         let mut backend = CudaBackend::new(0).unwrap();
13//!         let a = Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]);
14//!         let b = Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]);
15//!         let gpu_a = upload_tensor(backend.runtime(), &a).unwrap();
16//!         let gpu_b = upload_tensor(backend.runtime(), &b).unwrap();
17//!         let gpu_sum = backend.add(&gpu_a, &gpu_b).unwrap();
18//!         let sum = download_tensor(backend.runtime(), &gpu_sum).unwrap();
19//!         assert_eq!(sum.as_slice::<f64>().unwrap(), &[4.0, 6.0]);
20//!     }
21//! }
22//! ```
23
24#[cfg(feature = "cuda")]
25use std::any::Any;
26
27#[cfg(feature = "cuda")]
28mod cubecl;
29#[cfg(feature = "cuda")]
30mod kernels;
31#[cfg(feature = "webgpu")]
32mod webgpu;
33
34#[cfg(feature = "cuda")]
35pub use cubecl::{
36    device_ptr, download_tensor, gpu_available, upload_tensor, CudaBackend, CudaRuntime,
37};
38#[cfg(feature = "cuda")]
39#[doc(hidden)]
40pub use cubecl::{CudaExtensionCache, CudaExtensionCacheGuard};
41#[cfg(feature = "webgpu")]
42pub use webgpu::{
43    download_webgpu_tensor, upload_webgpu_tensor, webgpu_available, WebGpuBackend, WebGpuRuntime,
44};
45
#[cfg(feature = "cuda")]
#[doc(hidden)]
pub mod cuda_interop {
    //! Hidden re-exports: everything public in `cubecl::interop`, plus the
    //! CUDA extension cache types.
    pub use crate::cubecl::{interop::*, CudaExtensionCache, CudaExtensionCacheGuard};
}
52
53#[cfg(any(feature = "cuda", feature = "webgpu"))]
54use tenferro_tensor::*;
55
#[cfg(feature = "cuda")]
pub(crate) mod backend {
    //! Crate-internal mirror of the public items of `tenferro_tensor::backend`.
    pub use ::tenferro_tensor::backend::*;
}
60
#[cfg(feature = "cuda")]
pub(crate) mod config {
    //! Crate-internal mirror of the public items of `tenferro_tensor::config`.
    pub use ::tenferro_tensor::config::*;
}
65
#[cfg(feature = "cuda")]
pub(crate) mod types {
    //! Crate-internal mirror of `tenferro_tensor::types`, extended with this
    //! crate's `CubeclBuffer` (the explicit import takes precedence over the glob).
    pub use ::tenferro_tensor::types::*;
    pub(crate) use super::CubeclBuffer;
}
71
/// CubeCL-managed GPU buffer stored behind tensor backend-buffer trait objects.
///
/// Pairs a CubeCL server handle with an element count. The element type `T`
/// exists only at the type level (via `_marker`); no `T` value is stored.
#[cfg(feature = "cuda")]
pub(crate) struct CubeclBuffer<T> {
    /// CubeCL server handle for the buffer's data.
    handle: cubecl_runtime::server::Handle,
    /// Element count, reported by `element_len` and `BackendBuffer::len`.
    len: usize,
    /// Type-level marker for the element type `T`.
    pub(crate) _marker: std::marker::PhantomData<T>,
}

// `#[derive(Clone)]` would require `T: Clone` even though no `T` is stored.
// Implementing it by hand makes cloning depend only on the handle, matching
// the bound-free manual `Debug` impl.
#[cfg(feature = "cuda")]
impl<T> Clone for CubeclBuffer<T> {
    fn clone(&self) -> Self {
        Self {
            handle: self.handle.clone(),
            len: self.len,
            _marker: std::marker::PhantomData,
        }
    }
}
80
// Manual impl: a derive would add a `T: Debug` bound and would also try to
// format the handle. Only `len` is shown; `finish_non_exhaustive` appends `..`
// so readers can tell that the handle field was deliberately omitted.
#[cfg(feature = "cuda")]
impl<T> std::fmt::Debug for CubeclBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CubeclBuffer")
            .field("len", &self.len)
            .finish_non_exhaustive()
    }
}
89
#[cfg(feature = "cuda")]
impl<T> CubeclBuffer<T> {
    /// Builds a buffer from a CubeCL server handle and the `len` value that
    /// `element_len` will later report.
    pub(crate) fn new(handle: cubecl_runtime::server::Handle, len: usize) -> Self {
        let _marker = std::marker::PhantomData;
        Self {
            _marker,
            handle,
            len,
        }
    }

    /// Borrows the CubeCL server handle this buffer wraps.
    pub(crate) fn handle(&self) -> &cubecl_runtime::server::Handle {
        let Self { handle, .. } = self;
        handle
    }

    /// Returns the `len` value this buffer was constructed with.
    pub(crate) fn element_len(&self) -> usize {
        let Self { len, .. } = *self;
        len
    }
}
108
#[cfg(feature = "cuda")]
impl<T: Send + Sync + 'static> BackendBuffer<T> for CubeclBuffer<T> {
    /// Names the backend family this buffer belongs to.
    fn backend_family(&self) -> &'static str {
        const FAMILY: &str = "cubecl";
        FAMILY
    }

    /// Reports the stored element count via the inherent accessor.
    fn len(&self) -> usize {
        self.element_len()
    }

    /// Type-erases `self` so it can be recovered with
    /// `downcast_ref::<CubeclBuffer<T>>()`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}