Skip to main content

tenferro_xla/
error.rs

1use std::path::PathBuf;
2
3use tenferro_tensor::DType;
4
5/// Error type for StableHLO lowering and runtime PJRT plugin loading.
6///
7/// # Examples
8///
9/// ```
10/// use tenferro_runtime::{DType, GraphCompiler, TracedTensor};
11/// use tenferro_xla::{lower_to_stablehlo, Error};
12///
13/// let x = TracedTensor::input_symbolic_shape(DType::I64, 1).unwrap();
14/// let mut compiler = GraphCompiler::new();
15/// let program = compiler
16///     .compile_with_input_specs(&x.neg(), &[(&x, DType::I64, &[2])])
17///     .unwrap();
18/// let err = lower_to_stablehlo(&program).unwrap_err();
19/// assert!(matches!(err, Error::UnsupportedDType { .. }));
20/// ```
21#[derive(Debug, thiserror::Error)]
22pub enum Error {
23    #[error("XLA lowering does not support dtype {dtype:?} in {context}")]
24    UnsupportedDType { dtype: DType, context: &'static str },
25    #[error("XLA lowering does not support ExecOp::{op}: {reason}")]
26    UnsupportedOp {
27        op: &'static str,
28        reason: &'static str,
29    },
30    #[error(
31        "XLA lowering supports only exact static shapes; ExecOp::{op} output {output_index} axis {axis} is {kind}"
32    )]
33    NonStaticShape {
34        op: &'static str,
35        output_index: usize,
36        axis: usize,
37        kind: &'static str,
38    },
39    #[error("invalid XLA program: {message}")]
40    InvalidProgram { message: String },
41    #[error("PJRT support requires enabling the tenferro-xla `pjrt` feature")]
42    PjrtFeatureDisabled,
43    #[error("PJRT execution requires an executor created from a loaded plugin")]
44    PjrtPluginNotLoaded,
45    #[error("PJRT call {call} failed: {message}")]
46    PjrtCall { call: &'static str, message: String },
47    #[error("environment variable {var} is not set; set it to a PJRT plugin .so path")]
48    MissingEnv { var: &'static str },
49    #[error("failed to load PJRT plugin from {path}: {message}")]
50    PluginLoad { path: PathBuf, message: String },
51}
52
53/// Result alias for `tenferro-xla`.
54///
55/// # Examples
56///
57/// ```
58/// use tenferro_xla::Result;
59///
60/// fn ok() -> Result<()> {
61///     Ok(())
62/// }
63///
64/// ok().unwrap();
65/// ```
66pub type Result<T> = std::result::Result<T, Error>;