tenferro_xla/executor.rs
1use std::fmt;
2
3use tenferro_runtime::GraphProgram;
4use tenferro_tensor::Tensor;
5
6use crate::Error;
7use crate::{lower_to_stablehlo, Result, StableHloModule};
8
/// Configuration for the experimental XLA executor.
///
/// The struct currently carries no user-visible settings. It is marked
/// `#[non_exhaustive]`, so fields can be added later without breaking callers.
/// Obtain a value through [`Default::default`].
///
/// # Examples
///
/// ```
/// use tenferro_xla::XlaExecutorOptions;
///
/// let options = XlaExecutorOptions::default();
/// assert_eq!(options, XlaExecutorOptions::default());
/// ```
#[non_exhaustive]
#[derive(Clone, Copy, Default, PartialEq, Eq)]
pub struct XlaExecutorOptions {
    // Private placeholder: prevents struct-literal construction outside this module.
    _private: (),
}
24
25impl fmt::Debug for XlaExecutorOptions {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 f.debug_struct("XlaExecutorOptions").finish()
28 }
29}
30
31/// Experimental peer executor for XLA/PJRT.
32///
33/// # Examples
34///
35/// ```
36/// use tenferro_runtime::{GraphCompiler, TracedTensor};
37/// use tenferro_xla::XlaExecutor;
38///
39/// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
40/// let mut compiler = GraphCompiler::new();
41/// let program = compiler.compile(&x.neg()).unwrap();
42/// let module = XlaExecutor::default().lower_to_stablehlo(&program).unwrap();
43/// assert!(module.as_str().contains("stablehlo.negate"));
44/// ```
45pub struct XlaExecutor {
46 options: XlaExecutorOptions,
47 #[cfg(feature = "pjrt")]
48 plugin: Option<crate::pjrt::PjrtPlugin>,
49}
50
51impl fmt::Debug for XlaExecutor {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 f.debug_struct("XlaExecutor")
54 .field("options", &self.options)
55 .field("has_loaded_pjrt_plugin", &self.has_loaded_pjrt_plugin())
56 .finish()
57 }
58}
59
60impl XlaExecutor {
61 /// Create an executor with explicit options.
62 ///
63 /// # Examples
64 ///
65 /// ```
66 /// use tenferro_xla::{XlaExecutor, XlaExecutorOptions};
67 ///
68 /// let executor = XlaExecutor::new(XlaExecutorOptions::default());
69 /// assert_eq!(executor.options(), XlaExecutorOptions::default());
70 /// ```
71 pub fn new(options: XlaExecutorOptions) -> Self {
72 Self {
73 options,
74 #[cfg(feature = "pjrt")]
75 plugin: None,
76 }
77 }
78
79 /// Create an executor by loading PJRT configuration from environment variables.
80 ///
81 /// Without the `pjrt` feature this returns [`Error::PjrtFeatureDisabled`].
82 ///
83 /// # Examples
84 ///
85 /// ```
86 /// use tenferro_xla::{Error, XlaExecutor};
87 ///
88 /// if let Err(err) = XlaExecutor::from_env() {
89 /// assert!(matches!(err, Error::PjrtFeatureDisabled | Error::MissingEnv { .. } | Error::PluginLoad { .. }));
90 /// }
91 /// ```
92 #[cfg(not(feature = "pjrt"))]
93 pub fn from_env() -> Result<Self> {
94 Err(Error::PjrtFeatureDisabled)
95 }
96
97 /// Create an executor by loading PJRT configuration from environment variables.
98 ///
99 /// # Examples
100 ///
101 /// ```
102 /// use tenferro_xla::XlaExecutor;
103 ///
104 /// let _ = XlaExecutor::from_env();
105 /// ```
106 #[cfg(feature = "pjrt")]
107 pub fn from_env() -> Result<Self> {
108 Self::from_env_var(crate::TENFERRO_PJRT_PLUGIN_ENV)
109 }
110
111 /// Create an executor by loading a PJRT plugin path from a specific
112 /// environment variable.
113 ///
114 /// # Examples
115 ///
116 /// ```
117 /// use tenferro_xla::XlaExecutor;
118 ///
119 /// let _ = XlaExecutor::from_env_var("__TENFERRO_XLA_DOCS_UNSET");
120 /// ```
121 #[cfg(feature = "pjrt")]
122 pub fn from_env_var(var: &'static str) -> Result<Self> {
123 let plugin = crate::pjrt::PjrtPlugin::load_from_env(var)?;
124 Ok(Self {
125 options: XlaExecutorOptions::default(),
126 plugin: Some(plugin),
127 })
128 }
129
130 /// Return the executor options.
131 ///
132 /// # Examples
133 ///
134 /// ```
135 /// use tenferro_xla::XlaExecutor;
136 ///
137 /// assert_eq!(XlaExecutor::default().options(), Default::default());
138 /// ```
139 pub fn options(&self) -> XlaExecutorOptions {
140 self.options
141 }
142
143 /// Return whether this executor owns a loaded PJRT plugin.
144 ///
145 /// # Examples
146 ///
147 /// ```
148 /// use tenferro_xla::XlaExecutor;
149 ///
150 /// assert!(!XlaExecutor::default().has_loaded_pjrt_plugin());
151 /// ```
152 pub fn has_loaded_pjrt_plugin(&self) -> bool {
153 #[cfg(feature = "pjrt")]
154 {
155 self.plugin.is_some()
156 }
157 #[cfg(not(feature = "pjrt"))]
158 {
159 false
160 }
161 }
162
163 /// Lower a graph program to StableHLO without executing it.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
169 /// use tenferro_xla::XlaExecutor;
170 ///
171 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
172 /// let mut compiler = GraphCompiler::new();
173 /// let program = compiler.compile(&x.neg()).unwrap();
174 /// let module = XlaExecutor::default().lower_to_stablehlo(&program).unwrap();
175 /// assert!(module.as_str().contains("stablehlo.negate"));
176 /// ```
177 pub fn lower_to_stablehlo(&self, program: &GraphProgram) -> Result<StableHloModule> {
178 lower_to_stablehlo(program)
179 }
180
181 /// Execute a graph program through a loaded PJRT plugin and return all outputs.
182 ///
183 /// Inputs must match [`GraphProgram::input_specs`] exactly. This
184 /// experimental execution path supports the same exact-static-shape,
185 /// `F32`/`F64` subset as StableHLO lowering.
186 ///
187 /// # Examples
188 ///
189 /// ```
190 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
191 /// use tenferro_tensor::Tensor;
192 /// use tenferro_xla::{Error, XlaExecutor};
193 ///
194 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
195 /// let mut compiler = GraphCompiler::new();
196 /// let program = compiler.compile(&x.neg()).unwrap();
197 /// let input = Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
198 /// let err = XlaExecutor::default()
199 /// .run_many_with_inputs(&program, &[&input])
200 /// .unwrap_err();
201 /// assert!(matches!(err, Error::PjrtFeatureDisabled | Error::PjrtPluginNotLoaded));
202 /// ```
203 pub fn run_many_with_inputs(
204 &self,
205 program: &GraphProgram,
206 inputs: &[&Tensor],
207 ) -> Result<Vec<Tensor>> {
208 #[cfg(feature = "pjrt")]
209 {
210 let Some(plugin) = self.plugin.as_ref() else {
211 return Err(Error::PjrtPluginNotLoaded);
212 };
213 crate::pjrt::run_many_with_inputs(plugin, program, inputs)
214 }
215 #[cfg(not(feature = "pjrt"))]
216 {
217 let _ = (program, inputs);
218 Err(Error::PjrtFeatureDisabled)
219 }
220 }
221
222 /// Execute a single-output graph program through a loaded PJRT plugin.
223 ///
224 /// # Examples
225 ///
226 /// ```
227 /// use tenferro_runtime::{GraphCompiler, TracedTensor};
228 /// use tenferro_tensor::Tensor;
229 /// use tenferro_xla::{Error, XlaExecutor};
230 ///
231 /// let x = TracedTensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
232 /// let mut compiler = GraphCompiler::new();
233 /// let program = compiler.compile(&x.neg()).unwrap();
234 /// let input = Tensor::from_vec_col_major(vec![1], vec![2.0_f64]).unwrap();
235 /// let err = XlaExecutor::default().run_with_inputs(&program, &[&input]).unwrap_err();
236 /// assert!(matches!(err, Error::PjrtFeatureDisabled | Error::PjrtPluginNotLoaded));
237 /// ```
238 pub fn run_with_inputs(&self, program: &GraphProgram, inputs: &[&Tensor]) -> Result<Tensor> {
239 let mut outputs = self.run_many_with_inputs(program, inputs)?;
240 if outputs.len() != 1 {
241 return Err(crate::Error::InvalidProgram {
242 message: format!(
243 "PJRT single-output execution expected 1 output, got {}",
244 outputs.len()
245 ),
246 });
247 }
248 Ok(outputs.remove(0))
249 }
250}
251
252impl Default for XlaExecutor {
253 fn default() -> Self {
254 Self::new(XlaExecutorOptions::default())
255 }
256}