tenferro_xla/pjrt/plugin.rs
1use std::fmt;
2use std::path::{Path, PathBuf};
3
4use libloading::Library;
5
6use crate::{Error, Result};
7
8use super::sys::{GetPjrtApiFn, PJRT_Api};
9
10/// Dynamically loaded PJRT plugin.
11///
12/// # Examples
13///
14/// ```
15/// use tenferro_xla::{Error, PjrtPlugin};
16///
17/// let err = PjrtPlugin::load_from_env("__TENFERRO_XLA_DOCS_UNSET").unwrap_err();
18/// assert!(matches!(err, Error::MissingEnv { .. }));
19/// ```
20pub struct PjrtPlugin {
21 path: PathBuf,
22 _api: *const PJRT_Api,
23 _library: Library,
24}
25
26impl fmt::Debug for PjrtPlugin {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("PjrtPlugin")
29 .field("path", &self.path)
30 .finish_non_exhaustive()
31 }
32}
33
34impl PjrtPlugin {
35 /// Load a PJRT plugin path from an environment variable.
36 ///
37 /// # Examples
38 ///
39 /// ```
40 /// use tenferro_xla::{Error, PjrtPlugin};
41 ///
42 /// let err = PjrtPlugin::load_from_env("__TENFERRO_XLA_DOCS_UNSET").unwrap_err();
43 /// assert!(matches!(err, Error::MissingEnv { .. }));
44 /// ```
45 pub fn load_from_env(var: &'static str) -> Result<Self> {
46 let path = super::plugin_path_from_env(var)?;
47 Self::load_path(path)
48 }
49
50 /// Load a PJRT plugin from an explicit dynamic-library path.
51 ///
52 /// # Examples
53 ///
54 /// ```
55 /// use tenferro_xla::{Error, PjrtPlugin};
56 ///
57 /// let err = PjrtPlugin::load_path("/definitely/missing/pjrt.so").unwrap_err();
58 /// assert!(matches!(err, Error::PluginLoad { .. }));
59 /// ```
60 pub fn load_path(path: impl Into<PathBuf>) -> Result<Self> {
61 let path = path.into();
62 // SAFETY: Loading a dynamic library is inherently unsafe. The library is
63 // retained in `Self` for at least as long as the API pointer is exposed.
64 let library = unsafe { Library::new(&path) }.map_err(|err| Error::PluginLoad {
65 path: path.clone(),
66 message: err.to_string(),
67 })?;
68 // SAFETY: OpenXLA PJRT plugins export `GetPjrtApi` with this signature.
69 // The returned table is owned by the plugin and remains valid while the
70 // dynamic library is loaded.
71 let api = unsafe {
72 let symbol: libloading::Symbol<'_, GetPjrtApiFn> =
73 library
74 .get(b"GetPjrtApi")
75 .map_err(|err| Error::PluginLoad {
76 path: path.clone(),
77 message: err.to_string(),
78 })?;
79 symbol()
80 };
81 if api.is_null() {
82 return Err(Error::PluginLoad {
83 path,
84 message: "GetPjrtApi returned null".to_string(),
85 });
86 }
87 Ok(Self {
88 path,
89 _api: api,
90 _library: library,
91 })
92 }
93
94 /// Return the path used to load the plugin.
95 ///
96 /// # Examples
97 ///
98 /// ```
99 /// use std::path::Path;
100 /// use tenferro_xla::PjrtPlugin;
101 ///
102 /// let _path_type = std::any::type_name::<&Path>();
103 /// let _method: fn(&PjrtPlugin) -> &Path = PjrtPlugin::path;
104 /// ```
105 pub fn path(&self) -> &Path {
106 &self.path
107 }
108
109 pub(crate) fn api(&self) -> *const PJRT_Api {
110 self._api
111 }
112}