// Source: tenferro_xla/pjrt/plugin.rs

1use std::fmt;
2use std::path::{Path, PathBuf};
3
4use libloading::Library;
5
6use crate::{Error, Result};
7
8use super::sys::{GetPjrtApiFn, PJRT_Api};
9
/// Dynamically loaded PJRT plugin.
///
/// Holds the opened plugin library together with the `PJRT_Api` table returned
/// by the library's `GetPjrtApi` entry point. The library handle is kept for
/// the whole lifetime of the value because the API table is owned by the
/// plugin and is only valid while the library stays loaded.
///
/// # Examples
///
/// ```
/// use tenferro_xla::{Error, PjrtPlugin};
///
/// let err = PjrtPlugin::load_from_env("__TENFERRO_XLA_DOCS_UNSET").unwrap_err();
/// assert!(matches!(err, Error::MissingEnv { .. }));
/// ```
pub struct PjrtPlugin {
    /// Path the plugin library was loaded from.
    path: PathBuf,
    /// `PJRT_Api` table returned by `GetPjrtApi` (`load_path` rejects null).
    /// Only valid while `_library` remains loaded.
    _api: *const PJRT_Api,
    /// Keeps the dynamic library loaded for as long as `_api` may be used.
    _library: Library,
}
25
26impl fmt::Debug for PjrtPlugin {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("PjrtPlugin")
29            .field("path", &self.path)
30            .finish_non_exhaustive()
31    }
32}
33
34impl PjrtPlugin {
35    /// Load a PJRT plugin path from an environment variable.
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// use tenferro_xla::{Error, PjrtPlugin};
41    ///
42    /// let err = PjrtPlugin::load_from_env("__TENFERRO_XLA_DOCS_UNSET").unwrap_err();
43    /// assert!(matches!(err, Error::MissingEnv { .. }));
44    /// ```
45    pub fn load_from_env(var: &'static str) -> Result<Self> {
46        let path = super::plugin_path_from_env(var)?;
47        Self::load_path(path)
48    }
49
50    /// Load a PJRT plugin from an explicit dynamic-library path.
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use tenferro_xla::{Error, PjrtPlugin};
56    ///
57    /// let err = PjrtPlugin::load_path("/definitely/missing/pjrt.so").unwrap_err();
58    /// assert!(matches!(err, Error::PluginLoad { .. }));
59    /// ```
60    pub fn load_path(path: impl Into<PathBuf>) -> Result<Self> {
61        let path = path.into();
62        // SAFETY: Loading a dynamic library is inherently unsafe. The library is
63        // retained in `Self` for at least as long as the API pointer is exposed.
64        let library = unsafe { Library::new(&path) }.map_err(|err| Error::PluginLoad {
65            path: path.clone(),
66            message: err.to_string(),
67        })?;
68        // SAFETY: OpenXLA PJRT plugins export `GetPjrtApi` with this signature.
69        // The returned table is owned by the plugin and remains valid while the
70        // dynamic library is loaded.
71        let api = unsafe {
72            let symbol: libloading::Symbol<'_, GetPjrtApiFn> =
73                library
74                    .get(b"GetPjrtApi")
75                    .map_err(|err| Error::PluginLoad {
76                        path: path.clone(),
77                        message: err.to_string(),
78                    })?;
79            symbol()
80        };
81        if api.is_null() {
82            return Err(Error::PluginLoad {
83                path,
84                message: "GetPjrtApi returned null".to_string(),
85            });
86        }
87        Ok(Self {
88            path,
89            _api: api,
90            _library: library,
91        })
92    }
93
94    /// Return the path used to load the plugin.
95    ///
96    /// # Examples
97    ///
98    /// ```
99    /// use std::path::Path;
100    /// use tenferro_xla::PjrtPlugin;
101    ///
102    /// let _path_type = std::any::type_name::<&Path>();
103    /// let _method: fn(&PjrtPlugin) -> &Path = PjrtPlugin::path;
104    /// ```
105    pub fn path(&self) -> &Path {
106        &self.path
107    }
108
109    pub(crate) fn api(&self) -> *const PJRT_Api {
110        self._api
111    }
112}