tenferro_einsum/error.rs
//! Error types owned by the einsum crate.
//!
//! # Examples
//!
//! ```rust
//! use tenferro_einsum::Error;
//!
//! let err = Error::InvalidArgument("bad subscripts".into());
//! assert_eq!(err.to_string(), "invalid argument: bad subscripts");
//! ```
11
12/// Errors produced while parsing, planning, or lowering einsum expressions.
13///
14/// # Examples
15///
16/// ```rust
17/// use tenferro_einsum::Error;
18///
19/// let err = Error::ShapeMismatch {
20/// expected: vec![2, 3],
21/// got: vec![2, 4],
22/// };
23/// assert!(err.to_string().contains("shape mismatch"));
24/// ```
25#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
26pub enum Error {
27 /// Tensor shapes are incompatible for an einsum expression.
28 #[error("shape mismatch: expected {expected:?}, got {got:?}")]
29 ShapeMismatch {
30 /// Expected shape or dimension sizes.
31 expected: Vec<usize>,
32 /// Actual shape or dimension sizes.
33 got: Vec<usize>,
34 },
35
36 /// An invalid einsum argument was provided.
37 #[error("invalid argument: {0}")]
38 InvalidArgument(String),
39}
40
41impl Error {
42 /// Convert this einsum error into a tensor backend failure.
43 ///
44 /// # Examples
45 ///
46 /// ```rust
47 /// use tenferro_einsum::Error;
48 /// use tenferro_tensor::Error as TensorError;
49 ///
50 /// let err = Error::InvalidArgument("bad subscripts".into())
51 /// .to_tensor_error("einsum_extension");
52 ///
53 /// assert!(matches!(
54 /// err,
55 /// TensorError::BackendFailure {
56 /// op: "einsum_extension",
57 /// ..
58 /// }
59 /// ));
60 /// ```
61 #[must_use]
62 pub fn to_tensor_error(&self, op: &'static str) -> tenferro_tensor::Error {
63 tenferro_tensor::Error::backend_failure(op, self)
64 }
65}
66
67/// Result type alias for einsum parsing, planning, and lowering.
68///
69/// # Examples
70///
71/// ```rust
72/// use tenferro_einsum::{Error, Result};
73///
74/// let result: Result<()> = Err(Error::InvalidArgument("bad input".into()));
75/// assert!(result.is_err());
76/// ```
77pub type Result<T> = std::result::Result<T, Error>;