tenferro_prims/infra/
registry.rs

1use tenferro_device::{Error, Result};
2
3use crate::cpu::CpuBackend;
4
5#[cfg(feature = "cuda")]
6use crate::cuda::CudaBackend;
7#[cfg(not(feature = "cuda"))]
8use crate::gpu_stubs::CudaBackend;
9
10use crate::gpu_stubs::RocmBackend;
11
12// ===========================================================================
13// Backend Registry
14// ===========================================================================
15
16/// Registry of available compute backends.
17///
18/// **Current behavior:** Only the CPU backend is available.
19/// [`load_cutensor`](BackendRegistry::load_cutensor) and
20/// [`load_hiptensor`](BackendRegistry::load_hiptensor) always return
21/// errors. GPU backend loading is not yet implemented.
22///
23/// When GPU support is implemented, this registry will hold the CPU
24/// backend (always available) and optional GPU backends loaded at
25/// runtime.
26///
27/// # Examples
28///
29/// ```ignore
30/// // Aspirational API — GPU loading not yet functional.
31/// use tenferro_prims::BackendRegistry;
32///
33/// let mut registry = BackendRegistry::new(); // CPU only
34/// registry.load_cutensor("/usr/lib/libcutensor.so").unwrap();
35/// assert!(registry.cuda().is_some());
36/// ```
37pub struct BackendRegistry {
38    cpu: CpuBackend,
39    cuda: Option<CudaBackend>,
40    rocm: Option<RocmBackend>,
41}
42
43impl BackendRegistry {
44    /// Create a registry with CPU backend only.
45    pub fn new() -> Self {
46        Self {
47            cpu: CpuBackend,
48            cuda: None,
49            rocm: None,
50        }
51    }
52
53    /// Load the cuTENSOR library from the given path.
54    ///
55    /// When the `cuda` feature is enabled, delegates to
56    /// [`CudaBackend::load`] which initializes cudarc and populates
57    /// the cuTENSOR vtable. Without the `cuda` feature, always returns
58    /// `Err(DeviceError)`.
59    ///
60    /// The caller (Julia, Python, or standalone Rust) provides the path
61    /// to the shared library. No auto-search.
62    #[cfg(feature = "cuda")]
63    pub fn load_cutensor(&mut self, path: &str) -> Result<()> {
64        let (backend, _ctx) = CudaBackend::load(path)?;
65        self.cuda = Some(backend);
66        Ok(())
67    }
68
69    /// Load the cuTENSOR library from the given path.
70    ///
71    /// **Status: Not available.** The `cuda` feature is not enabled.
72    /// Rebuild with `--features cuda` to enable cuTENSOR support.
73    #[cfg(not(feature = "cuda"))]
74    pub fn load_cutensor(&mut self, _path: &str) -> Result<()> {
75        Err(Error::DeviceError(
76            "cuTENSOR runtime loading not available: rebuild with --features cuda".into(),
77        ))
78    }
79
80    /// Load the hipTENSOR library from the given path.
81    ///
82    /// **Status: Not yet implemented.** Always returns
83    /// `Err(DeviceError)`.
84    ///
85    /// When implemented, the caller (Julia, Python, or standalone Rust)
86    /// will provide the path to the shared library. No auto-search.
87    pub fn load_hiptensor(&mut self, _path: &str) -> Result<()> {
88        Err(Error::DeviceError(
89            "hipTENSOR runtime loading not yet implemented".into(),
90        ))
91    }
92
93    /// Returns a reference to the CPU backend.
94    pub fn cpu(&self) -> &CpuBackend {
95        &self.cpu
96    }
97
98    /// Returns a reference to the CUDA backend, if loaded.
99    pub fn cuda(&self) -> Option<&CudaBackend> {
100        self.cuda.as_ref()
101    }
102
103    /// Returns a reference to the ROCm backend, if loaded.
104    pub fn rocm(&self) -> Option<&RocmBackend> {
105        self.rocm.as_ref()
106    }
107}
108
109impl Default for BackendRegistry {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115#[cfg(test)]
116mod tests;