// tenferro_prims/infra/registry.rs
1use tenferro_device::{Error, Result};
2
3use crate::cpu::CpuBackend;
4
5#[cfg(feature = "cuda")]
6use crate::cuda::CudaBackend;
7#[cfg(not(feature = "cuda"))]
8use crate::gpu_stubs::CudaBackend;
9
10use crate::gpu_stubs::RocmBackend;
11
12// ===========================================================================
13// Backend Registry
14// ===========================================================================
15
/// Registry of available compute backends.
///
/// Holds the CPU backend (always available) plus optional GPU backends
/// loaded at runtime.
///
/// **Current behavior:** the CPU backend is always present. With the
/// `cuda` feature enabled,
/// [`load_cutensor`](BackendRegistry::load_cutensor) loads the cuTENSOR
/// shared library and populates the CUDA slot; without the feature it
/// always returns `Err(DeviceError)`.
/// [`load_hiptensor`](BackendRegistry::load_hiptensor) is not yet
/// implemented and always returns `Err(DeviceError)`.
///
/// # Examples
///
/// ```ignore
/// // Requires the `cuda` feature and a cuTENSOR shared library on disk.
/// use tenferro_prims::BackendRegistry;
///
/// let mut registry = BackendRegistry::new(); // CPU only
/// registry.load_cutensor("/usr/lib/libcutensor.so").unwrap();
/// assert!(registry.cuda().is_some());
/// ```
pub struct BackendRegistry {
    /// Always-available CPU backend.
    cpu: CpuBackend,
    /// CUDA backend, populated by a successful `load_cutensor` call
    /// (only possible with the `cuda` feature enabled).
    cuda: Option<CudaBackend>,
    /// ROCm backend; currently never populated (`load_hiptensor` is
    /// unimplemented).
    rocm: Option<RocmBackend>,
}
42
43impl BackendRegistry {
44 /// Create a registry with CPU backend only.
45 pub fn new() -> Self {
46 Self {
47 cpu: CpuBackend,
48 cuda: None,
49 rocm: None,
50 }
51 }
52
53 /// Load the cuTENSOR library from the given path.
54 ///
55 /// When the `cuda` feature is enabled, delegates to
56 /// [`CudaBackend::load`] which initializes cudarc and populates
57 /// the cuTENSOR vtable. Without the `cuda` feature, always returns
58 /// `Err(DeviceError)`.
59 ///
60 /// The caller (Julia, Python, or standalone Rust) provides the path
61 /// to the shared library. No auto-search.
62 #[cfg(feature = "cuda")]
63 pub fn load_cutensor(&mut self, path: &str) -> Result<()> {
64 let (backend, _ctx) = CudaBackend::load(path)?;
65 self.cuda = Some(backend);
66 Ok(())
67 }
68
69 /// Load the cuTENSOR library from the given path.
70 ///
71 /// **Status: Not available.** The `cuda` feature is not enabled.
72 /// Rebuild with `--features cuda` to enable cuTENSOR support.
73 #[cfg(not(feature = "cuda"))]
74 pub fn load_cutensor(&mut self, _path: &str) -> Result<()> {
75 Err(Error::DeviceError(
76 "cuTENSOR runtime loading not available: rebuild with --features cuda".into(),
77 ))
78 }
79
80 /// Load the hipTENSOR library from the given path.
81 ///
82 /// **Status: Not yet implemented.** Always returns
83 /// `Err(DeviceError)`.
84 ///
85 /// When implemented, the caller (Julia, Python, or standalone Rust)
86 /// will provide the path to the shared library. No auto-search.
87 pub fn load_hiptensor(&mut self, _path: &str) -> Result<()> {
88 Err(Error::DeviceError(
89 "hipTENSOR runtime loading not yet implemented".into(),
90 ))
91 }
92
93 /// Returns a reference to the CPU backend.
94 pub fn cpu(&self) -> &CpuBackend {
95 &self.cpu
96 }
97
98 /// Returns a reference to the CUDA backend, if loaded.
99 pub fn cuda(&self) -> Option<&CudaBackend> {
100 self.cuda.as_ref()
101 }
102
103 /// Returns a reference to the ROCm backend, if loaded.
104 pub fn rocm(&self) -> Option<&RocmBackend> {
105 self.rocm.as_ref()
106 }
107}
108
109impl Default for BackendRegistry {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115#[cfg(test)]
116mod tests;