1use std::cell::RefCell;
2
/// Status code returned by the fallible C-facing (`extern "C"`) entry points
/// of this crate.
///
/// The snake_case name is kept deliberately because it is part of the C-facing
/// API surface, so rustc's `non_camel_case_types` lint is silenced here.
#[allow(non_camel_case_types)]
pub type tfe_status_t = i32;

/// The operation completed successfully.
pub const TFE_SUCCESS: tfe_status_t = 0;

/// An argument was invalid (for example, a required output pointer was null).
pub const TFE_INVALID_ARGUMENT: tfe_status_t = -1;

/// Tensor shapes or ranks were incompatible with the requested operation.
pub const TFE_SHAPE_MISMATCH: tfe_status_t = -2;

/// An internal failure, including a Rust panic caught at the FFI boundary.
pub const TFE_INTERNAL_ERROR: tfe_status_t = -3;

/// A caller-supplied output buffer was too small; the required size is
/// reported separately so the caller can retry.
pub const TFE_BUFFER_TOO_SMALL: tfe_status_t = -4;
20
21thread_local! {
22 static LAST_ERROR: RefCell<String> = const { RefCell::new(String::new()) };
23}
24
25pub(crate) type StatusResult<T> = Result<T, tfe_status_t>;
26
27pub(crate) fn set_last_error(msg: &str) {
29 LAST_ERROR.with(|cell| {
30 *cell.borrow_mut() = msg.to_string();
31 });
32}
33
/// Extracts a human-readable message from a panic payload.
///
/// Payloads of type `&str` or `String` are returned as text. Any other
/// payload type yields the fixed string `"unknown panic"`.
pub(crate) fn panic_message(payload: &dyn std::any::Any) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("unknown panic"))
}
44
45#[no_mangle]
76pub unsafe extern "C" fn tfe_last_error_message(
77 buf: *mut u8,
78 buf_len: usize,
79 out_len: *mut usize,
80) -> tfe_status_t {
81 if out_len.is_null() {
82 return TFE_INVALID_ARGUMENT;
83 }
84
85 LAST_ERROR.with(|cell| {
86 let msg = cell.borrow();
87 let required = msg.len() + 1;
88 *out_len = required;
89
90 if buf.is_null() {
91 return TFE_SUCCESS;
92 }
93 if buf_len < required {
94 return TFE_BUFFER_TOO_SMALL;
95 }
96
97 std::ptr::copy_nonoverlapping(msg.as_ptr(), buf, msg.len());
98 *buf.add(msg.len()) = 0;
99 TFE_SUCCESS
100 })
101}
102
103pub(crate) fn map_device_error(err: &tenferro_device::Error) -> tfe_status_t {
105 set_last_error(&err.to_string());
106 use tenferro_device::Error;
107 match err {
108 Error::InvalidArgument(_)
109 | Error::StrideError(_)
110 | Error::CrossMemorySpaceOperation { .. } => TFE_INVALID_ARGUMENT,
111 Error::ShapeMismatch { .. } | Error::RankMismatch { .. } => TFE_SHAPE_MISMATCH,
112 Error::DeviceError(_) | Error::NoCompatibleComputeDevice { .. } => TFE_INTERNAL_ERROR,
113 }
114}
115
116pub(crate) fn map_ad_error(err: &chainrules_core::AutodiffError) -> tfe_status_t {
118 set_last_error(&err.to_string());
119 use chainrules_core::AutodiffError;
120 match err {
121 AutodiffError::InvalidArgument(_)
122 | AutodiffError::ModeNotSupported { .. }
123 | AutodiffError::NonScalarLoss { .. }
124 | AutodiffError::HvpNotSupported
125 | AutodiffError::GraphFreed => TFE_INVALID_ARGUMENT,
126 AutodiffError::TangentShapeMismatch { .. } => TFE_SHAPE_MISMATCH,
127 AutodiffError::MissingNode => TFE_INTERNAL_ERROR,
128 }
129}
130
131pub(crate) unsafe fn finalize_ptr<T>(
132 result: std::thread::Result<StatusResult<*mut T>>,
133 status: *mut tfe_status_t,
134) -> *mut T {
135 match result {
136 Ok(Ok(ptr)) => {
137 if !status.is_null() {
138 *status = TFE_SUCCESS;
139 }
140 ptr
141 }
142 Ok(Err(code)) => {
143 if !status.is_null() {
144 *status = code;
145 }
146 std::ptr::null_mut()
147 }
148 Err(panic) => {
149 set_last_error(&panic_message(&*panic));
150 if !status.is_null() {
151 *status = TFE_INTERNAL_ERROR;
152 }
153 std::ptr::null_mut()
154 }
155 }
156}
157
158pub(crate) unsafe fn finalize_void(
159 result: std::thread::Result<StatusResult<()>>,
160 status: *mut tfe_status_t,
161) {
162 match result {
163 Ok(Ok(())) => {
164 if !status.is_null() {
165 *status = TFE_SUCCESS;
166 }
167 }
168 Ok(Err(code)) => {
169 if !status.is_null() {
170 *status = code;
171 }
172 }
173 Err(panic) => {
174 set_last_error(&panic_message(&*panic));
175 if !status.is_null() {
176 *status = TFE_INTERNAL_ERROR;
177 }
178 }
179 }
180}