tenferro_xla/lib.rs
1//! Experimental StableHLO lowering and runtime PJRT plugin loading for tenferro.
2//!
3//! This crate is an optional peer executor over `tenferro-runtime`
4//! [`GraphProgram`](tenferro_runtime::GraphProgram) values. It does not
5//! implement `TensorBackend` and it does not change native CPU, CUDA, or
6//! WebGPU execution.
7//!
8//! # Examples
9//!
10//! ```
11//! use tenferro_runtime::{GraphCompiler, TracedTensor};
12//! use tenferro_xla::lower_to_stablehlo;
13//!
14//! let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
15//! let y = (&x + &x).unwrap();
16//! let mut compiler = GraphCompiler::new();
17//! let program = compiler.compile(&y).unwrap();
18//! let module = lower_to_stablehlo(&program).unwrap();
19//! assert!(module.as_str().contains("stablehlo.add"));
20//! ```
21
22mod error;
23mod executor;
24mod layout;
25mod lowering;
26mod stablehlo;
27
28#[cfg(feature = "pjrt")]
29mod pjrt;
30
31pub use error::{Error, Result};
32pub use executor::{XlaExecutor, XlaExecutorOptions};
33#[cfg(feature = "pjrt")]
34pub use pjrt::PjrtPlugin;
35pub use stablehlo::{StableHloModule, StableHloModuleFingerprint};
36
37/// Environment variable used for the default PJRT plugin path.
38///
39/// # Examples
40///
41/// ```
42/// use tenferro_xla::TENFERRO_PJRT_PLUGIN_ENV;
43///
44/// assert_eq!(TENFERRO_PJRT_PLUGIN_ENV, "TENFERRO_PJRT_PLUGIN");
45/// ```
46pub const TENFERRO_PJRT_PLUGIN_ENV: &str = "TENFERRO_PJRT_PLUGIN";
47
48/// Environment variable used for a GPU-specific PJRT plugin path.
49///
50/// # Examples
51///
52/// ```
53/// use tenferro_xla::TENFERRO_PJRT_GPU_PLUGIN_ENV;
54///
55/// assert_eq!(TENFERRO_PJRT_GPU_PLUGIN_ENV, "TENFERRO_PJRT_GPU_PLUGIN");
56/// ```
57pub const TENFERRO_PJRT_GPU_PLUGIN_ENV: &str = "TENFERRO_PJRT_GPU_PLUGIN";
58
59/// Lower a static-shaped graph program to StableHLO MLIR text.
60///
61/// # Examples
62///
63/// ```
64/// use tenferro_runtime::{GraphCompiler, TracedTensor};
65/// use tenferro_xla::lower_to_stablehlo;
66///
67/// let x = TracedTensor::from_vec_col_major(vec![1], vec![3.0_f64]).unwrap();
68/// let mut compiler = GraphCompiler::new();
69/// let program = compiler.compile(&x.neg()).unwrap();
70/// let module = lower_to_stablehlo(&program).unwrap();
71/// assert!(module.as_str().contains("stablehlo.negate"));
72/// ```
73pub fn lower_to_stablehlo(program: &tenferro_runtime::GraphProgram) -> Result<StableHloModule> {
74 lowering::lower_to_stablehlo(program)
75}