tenferro_capi/
status.rs

1use std::cell::RefCell;
2
/// Status code type returned by all C-API functions.
///
/// The codes defined here use `0` for success and negative values for
/// failures. The C-style `_t` name is deliberate — it is the type C callers
/// see in every function signature — so the Rust naming lint is silenced for
/// this alias only.
#[allow(non_camel_case_types)]
pub type tfe_status_t = i32;

/// Operation completed successfully.
pub const TFE_SUCCESS: tfe_status_t = 0;

/// Invalid argument (null pointer, bad subscript string, etc.).
pub const TFE_INVALID_ARGUMENT: tfe_status_t = -1;

/// Tensor shape mismatch for the requested operation.
pub const TFE_SHAPE_MISMATCH: tfe_status_t = -2;

/// Internal error (Rust panic or unexpected failure).
pub const TFE_INTERNAL_ERROR: tfe_status_t = -3;

/// Output buffer is too small for the requested data.
pub const TFE_BUFFER_TOO_SMALL: tfe_status_t = -4;

thread_local! {
    /// Per-thread slot holding the most recent error message.
    ///
    /// Written by `set_last_error` and read by `tfe_last_error_message`.
    /// Being thread-local, every thread observes only its own message; the
    /// slot starts out as an empty string.
    static LAST_ERROR: RefCell<String> = const { RefCell::new(String::new()) };
}

25pub(crate) type StatusResult<T> = Result<T, tfe_status_t>;
26
27/// Store an error message in thread-local storage.
28pub(crate) fn set_last_error(msg: &str) {
29    LAST_ERROR.with(|cell| {
30        *cell.borrow_mut() = msg.to_string();
31    });
32}
33
/// Turn a panic payload into a printable message.
///
/// Payloads of type `&'static str` or `String` (what `panic!` produces) are
/// returned as-is; any other payload type, e.g. a custom value passed to
/// `std::panic::panic_any`, yields the fixed text `"unknown panic"`.
pub(crate) fn panic_message(payload: &dyn std::any::Any) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_owned())
}

45/// Retrieve the last error message (UTF-8, null-terminated).
46///
47/// - `buf == NULL`: query required length only (written to `*out_len`).
48/// - `buf != NULL`: copy message into buffer.
49///
50/// `out_len` receives the required buffer size including the null terminator.
51///
52/// # Returns
53///
54/// - `TFE_SUCCESS` on success (or query-only mode).
55/// - `TFE_INVALID_ARGUMENT` if `out_len` is null.
56/// - `TFE_BUFFER_TOO_SMALL` if `buf_len` is too small (required size in `*out_len`).
57///
58/// # Safety
59///
60/// - `out_len` must be a valid, non-null pointer.
61/// - If `buf` is non-null, it must point to a buffer of at least `buf_len` bytes.
62///
63/// # Examples (C)
64///
65/// ```c
66/// size_t len;
67/// tfe_last_error_message(NULL, 0, &len);
68/// if (len > 0) {
69///     char *buf = malloc(len);
70///     tfe_last_error_message((uint8_t *)buf, len, &len);
71///     printf("Error: %s\n", buf);
72///     free(buf);
73/// }
74/// ```
75#[no_mangle]
76pub unsafe extern "C" fn tfe_last_error_message(
77    buf: *mut u8,
78    buf_len: usize,
79    out_len: *mut usize,
80) -> tfe_status_t {
81    if out_len.is_null() {
82        return TFE_INVALID_ARGUMENT;
83    }
84
85    LAST_ERROR.with(|cell| {
86        let msg = cell.borrow();
87        let required = msg.len() + 1;
88        *out_len = required;
89
90        if buf.is_null() {
91            return TFE_SUCCESS;
92        }
93        if buf_len < required {
94            return TFE_BUFFER_TOO_SMALL;
95        }
96
97        std::ptr::copy_nonoverlapping(msg.as_ptr(), buf, msg.len());
98        *buf.add(msg.len()) = 0;
99        TFE_SUCCESS
100    })
101}
102
103/// Map `tenferro_device::Error` to the appropriate status code.
104pub(crate) fn map_device_error(err: &tenferro_device::Error) -> tfe_status_t {
105    set_last_error(&err.to_string());
106    use tenferro_device::Error;
107    match err {
108        Error::InvalidArgument(_)
109        | Error::StrideError(_)
110        | Error::CrossMemorySpaceOperation { .. } => TFE_INVALID_ARGUMENT,
111        Error::ShapeMismatch { .. } | Error::RankMismatch { .. } => TFE_SHAPE_MISMATCH,
112        Error::DeviceError(_) | Error::NoCompatibleComputeDevice { .. } => TFE_INTERNAL_ERROR,
113    }
114}
115
116/// Map `chainrules_core::AutodiffError` to the appropriate status code.
117pub(crate) fn map_ad_error(err: &chainrules_core::AutodiffError) -> tfe_status_t {
118    set_last_error(&err.to_string());
119    use chainrules_core::AutodiffError;
120    match err {
121        AutodiffError::InvalidArgument(_)
122        | AutodiffError::ModeNotSupported { .. }
123        | AutodiffError::NonScalarLoss { .. }
124        | AutodiffError::HvpNotSupported
125        | AutodiffError::GraphFreed => TFE_INVALID_ARGUMENT,
126        AutodiffError::TangentShapeMismatch { .. } => TFE_SHAPE_MISMATCH,
127        AutodiffError::MissingNode => TFE_INTERNAL_ERROR,
128    }
129}
130
131pub(crate) unsafe fn finalize_ptr<T>(
132    result: std::thread::Result<StatusResult<*mut T>>,
133    status: *mut tfe_status_t,
134) -> *mut T {
135    match result {
136        Ok(Ok(ptr)) => {
137            if !status.is_null() {
138                *status = TFE_SUCCESS;
139            }
140            ptr
141        }
142        Ok(Err(code)) => {
143            if !status.is_null() {
144                *status = code;
145            }
146            std::ptr::null_mut()
147        }
148        Err(panic) => {
149            set_last_error(&panic_message(&*panic));
150            if !status.is_null() {
151                *status = TFE_INTERNAL_ERROR;
152            }
153            std::ptr::null_mut()
154        }
155    }
156}
157
158pub(crate) unsafe fn finalize_void(
159    result: std::thread::Result<StatusResult<()>>,
160    status: *mut tfe_status_t,
161) {
162    match result {
163        Ok(Ok(())) => {
164            if !status.is_null() {
165                *status = TFE_SUCCESS;
166            }
167        }
168        Ok(Err(code)) => {
169            if !status.is_null() {
170                *status = code;
171            }
172        }
173        Err(panic) => {
174            set_last_error(&panic_message(&*panic));
175            if !status.is_null() {
176                *status = TFE_INTERNAL_ERROR;
177            }
178        }
179    }
180}