Skip to main content

tenferro_xla/
executor.rs

1use std::fmt;
2
3use tenferro_runtime::GraphProgram;
4use tenferro_tensor::Tensor;
5
6use crate::Error;
7use crate::{lower_to_stablehlo, Result, StableHloModule};
8
/// Configuration for the experimental XLA executor.
///
/// The type currently exposes no tunable settings: its only member is a
/// private unit field. Together with `#[non_exhaustive]` this keeps callers
/// from building the struct with a literal, so values are obtained through
/// [`Default`] (or by copying an existing value).
///
/// # Examples
///
/// ```
/// use tenferro_xla::XlaExecutorOptions;
///
/// let options = XlaExecutorOptions::default();
/// assert_eq!(options, XlaExecutorOptions::default());
/// ```
#[derive(Default, Clone, Copy, Eq, PartialEq)]
#[non_exhaustive]
pub struct XlaExecutorOptions {
    // Zero-sized placeholder; carries no data.
    _private: (),
}
24
25impl fmt::Debug for XlaExecutorOptions {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        f.debug_struct("XlaExecutorOptions").finish()
28    }
29}
30
31/// Experimental peer executor for XLA/PJRT.
32///
33/// # Examples
34///
35/// ```
36/// use tenferro_runtime::{GraphCompiler, TracedTensor};
37/// use tenferro_xla::XlaExecutor;
38///
39/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
40/// let mut compiler = GraphCompiler::new();
41/// let program = compiler.compile(&x.neg()).unwrap();
42/// let module = XlaExecutor::default().lower_to_stablehlo(&program).unwrap();
43/// assert!(module.as_str().contains("stablehlo.negate"));
44/// ```
45pub struct XlaExecutor {
46    options: XlaExecutorOptions,
47    #[cfg(feature = "pjrt")]
48    plugin: Option<crate::pjrt::PjrtPlugin>,
49}
50
51impl fmt::Debug for XlaExecutor {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("XlaExecutor")
54            .field("options", &self.options)
55            .field("has_loaded_pjrt_plugin", &self.has_loaded_pjrt_plugin())
56            .finish()
57    }
58}
59
60impl XlaExecutor {
61    /// Create an executor with explicit options.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use tenferro_xla::{XlaExecutor, XlaExecutorOptions};
67    ///
68    /// let executor = XlaExecutor::new(XlaExecutorOptions::default());
69    /// assert_eq!(executor.options(), XlaExecutorOptions::default());
70    /// ```
71    pub fn new(options: XlaExecutorOptions) -> Self {
72        Self {
73            options,
74            #[cfg(feature = "pjrt")]
75            plugin: None,
76        }
77    }
78
79    /// Create an executor by loading PJRT configuration from environment variables.
80    ///
81    /// Without the `pjrt` feature this returns [`Error::PjrtFeatureDisabled`].
82    ///
83    /// # Examples
84    ///
85    /// ```
86    /// use tenferro_xla::{Error, XlaExecutor};
87    ///
88    /// if let Err(err) = XlaExecutor::from_env() {
89    ///     assert!(matches!(err, Error::PjrtFeatureDisabled | Error::MissingEnv { .. } | Error::PluginLoad { .. }));
90    /// }
91    /// ```
92    #[cfg(not(feature = "pjrt"))]
93    pub fn from_env() -> Result<Self> {
94        Err(Error::PjrtFeatureDisabled)
95    }
96
97    /// Create an executor by loading PJRT configuration from environment variables.
98    ///
99    /// # Examples
100    ///
101    /// ```
102    /// use tenferro_xla::XlaExecutor;
103    ///
104    /// let _ = XlaExecutor::from_env();
105    /// ```
106    #[cfg(feature = "pjrt")]
107    pub fn from_env() -> Result<Self> {
108        Self::from_env_var(crate::TENFERRO_PJRT_PLUGIN_ENV)
109    }
110
111    /// Create an executor by loading a PJRT plugin path from a specific
112    /// environment variable.
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use tenferro_xla::XlaExecutor;
118    ///
119    /// let _ = XlaExecutor::from_env_var("__TENFERRO_XLA_DOCS_UNSET");
120    /// ```
121    #[cfg(feature = "pjrt")]
122    pub fn from_env_var(var: &'static str) -> Result<Self> {
123        let plugin = crate::pjrt::PjrtPlugin::load_from_env(var)?;
124        Ok(Self {
125            options: XlaExecutorOptions::default(),
126            plugin: Some(plugin),
127        })
128    }
129
130    /// Return the executor options.
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// use tenferro_xla::XlaExecutor;
136    ///
137    /// assert_eq!(XlaExecutor::default().options(), Default::default());
138    /// ```
139    pub fn options(&self) -> XlaExecutorOptions {
140        self.options
141    }
142
143    /// Return whether this executor owns a loaded PJRT plugin.
144    ///
145    /// # Examples
146    ///
147    /// ```
148    /// use tenferro_xla::XlaExecutor;
149    ///
150    /// assert!(!XlaExecutor::default().has_loaded_pjrt_plugin());
151    /// ```
152    pub fn has_loaded_pjrt_plugin(&self) -> bool {
153        #[cfg(feature = "pjrt")]
154        {
155            self.plugin.is_some()
156        }
157        #[cfg(not(feature = "pjrt"))]
158        {
159            false
160        }
161    }
162
163    /// Lower a graph program to StableHLO without executing it.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
169    /// use tenferro_xla::XlaExecutor;
170    ///
171    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
172    /// let mut compiler = GraphCompiler::new();
173    /// let program = compiler.compile(&x.neg()).unwrap();
174    /// let module = XlaExecutor::default().lower_to_stablehlo(&program).unwrap();
175    /// assert!(module.as_str().contains("stablehlo.negate"));
176    /// ```
177    pub fn lower_to_stablehlo(&self, program: &GraphProgram) -> Result<StableHloModule> {
178        lower_to_stablehlo(program)
179    }
180
181    /// Execute a graph program through a loaded PJRT plugin and return all outputs.
182    ///
183    /// Inputs must match [`GraphProgram::input_specs`] exactly. This
184    /// experimental execution path supports the same exact-static-shape,
185    /// `F32`/`F64` subset as StableHLO lowering.
186    ///
187    /// # Examples
188    ///
189    /// ```
190    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
191    /// use tenferro_tensor::Tensor;
192    /// use tenferro_xla::{Error, XlaExecutor};
193    ///
194    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
195    /// let mut compiler = GraphCompiler::new();
196    /// let program = compiler.compile(&x.neg()).unwrap();
197    /// let input = Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
198    /// let err = XlaExecutor::default()
199    ///     .run_many_with_inputs(&program, &[&input])
200    ///     .unwrap_err();
201    /// assert!(matches!(err, Error::PjrtFeatureDisabled | Error::PjrtPluginNotLoaded));
202    /// ```
203    pub fn run_many_with_inputs(
204        &self,
205        program: &GraphProgram,
206        inputs: &[&Tensor],
207    ) -> Result<Vec<Tensor>> {
208        #[cfg(feature = "pjrt")]
209        {
210            let Some(plugin) = self.plugin.as_ref() else {
211                return Err(Error::PjrtPluginNotLoaded);
212            };
213            crate::pjrt::run_many_with_inputs(plugin, program, inputs)
214        }
215        #[cfg(not(feature = "pjrt"))]
216        {
217            let _ = (program, inputs);
218            Err(Error::PjrtFeatureDisabled)
219        }
220    }
221
222    /// Execute a single-output graph program through a loaded PJRT plugin.
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// use tenferro_runtime::{GraphCompiler, TracedTensor};
228    /// use tenferro_tensor::Tensor;
229    /// use tenferro_xla::{Error, XlaExecutor};
230    ///
231    /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
232    /// let mut compiler = GraphCompiler::new();
233    /// let program = compiler.compile(&x.neg()).unwrap();
234    /// let input = Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
235    /// let err = XlaExecutor::default().run_with_inputs(&program, &[&input]).unwrap_err();
236    /// assert!(matches!(err, Error::PjrtFeatureDisabled | Error::PjrtPluginNotLoaded));
237    /// ```
238    pub fn run_with_inputs(&self, program: &GraphProgram, inputs: &[&Tensor]) -> Result<Tensor> {
239        let mut outputs = self.run_many_with_inputs(program, inputs)?;
240        if outputs.len() != 1 {
241            return Err(crate::Error::InvalidProgram {
242                message: format!(
243                    "PJRT single-output execution expected 1 output, got {}",
244                    outputs.len()
245                ),
246            });
247        }
248        Ok(outputs.remove(0))
249    }
250}
251
252impl Default for XlaExecutor {
253    fn default() -> Self {
254        Self::new(XlaExecutorOptions::default())
255    }
256}