// tenferro_capi/tensor_api.rs
use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use tenferro_device::LogicalMemorySpace;
4use tenferro_tensor::{MemoryOrder, Tensor};
5
6use crate::ffi_utils::read_usize_slice;
7use crate::handle::{handle_take, handle_to_ref, tensor_to_handle, TfeTensorF64};
8use crate::status::{
9 finalize_ptr, map_device_error, panic_message, set_last_error, tfe_status_t,
10 TFE_INTERNAL_ERROR, TFE_INVALID_ARGUMENT, TFE_SUCCESS,
11};
12
13#[no_mangle]
39pub unsafe extern "C" fn tfe_tensor_f64_from_data(
40 data: *const f64,
41 len: usize,
42 shape: *const usize,
43 ndim: usize,
44 status: *mut tfe_status_t,
45) -> *mut TfeTensorF64 {
46 let result = catch_unwind(AssertUnwindSafe(|| {
47 if data.is_null() && len > 0 {
48 set_last_error("from_data: data pointer is null");
49 return Err(TFE_INVALID_ARGUMENT);
50 }
51 let dims = read_usize_slice(shape, ndim, "from_data shape")?.to_vec();
52 let data_slice = if len > 0 {
53 std::slice::from_raw_parts(data, len)
54 } else {
55 &[]
56 };
57
58 Tensor::from_slice(data_slice, &dims, MemoryOrder::ColumnMajor)
59 .map(tensor_to_handle)
60 .map_err(|e| map_device_error(&e))
61 }));
62
63 finalize_ptr(result, status)
64}
65
66#[no_mangle]
82pub unsafe extern "C" fn tfe_tensor_f64_zeros(
83 shape: *const usize,
84 ndim: usize,
85 status: *mut tfe_status_t,
86) -> *mut TfeTensorF64 {
87 let result = catch_unwind(AssertUnwindSafe(|| {
88 let dims = read_usize_slice(shape, ndim, "zeros shape")?.to_vec();
89 let t = Tensor::<f64>::zeros(
90 &dims,
91 LogicalMemorySpace::MainMemory,
92 MemoryOrder::ColumnMajor,
93 )
94 .map_err(|e| map_device_error(&e))?;
95 Ok(tensor_to_handle(t))
96 }));
97
98 finalize_ptr(result, status)
99}
100
101#[no_mangle]
119pub unsafe extern "C" fn tfe_tensor_f64_clone(
120 tensor: *const TfeTensorF64,
121 status: *mut tfe_status_t,
122) -> *mut TfeTensorF64 {
123 let result = catch_unwind(AssertUnwindSafe(|| {
124 if tensor.is_null() {
125 set_last_error("clone: tensor pointer is null");
126 return Err(TFE_INVALID_ARGUMENT);
127 }
128 let src = handle_to_ref(tensor);
129 let materialized = src.contiguous(MemoryOrder::ColumnMajor);
130 let src_data = materialized.buffer().as_slice().ok_or_else(|| {
131 set_last_error("clone: tensor buffer is not contiguous host memory");
132 TFE_INTERNAL_ERROR
133 })?;
134 let copy = Tensor::from_slice(src_data, materialized.dims(), MemoryOrder::ColumnMajor)
135 .map_err(|e| map_device_error(&e))?;
136 Ok(tensor_to_handle(copy))
137 }));
138
139 finalize_ptr(result, status)
140}
141
142#[no_mangle]
155pub unsafe extern "C" fn tfe_tensor_f64_release(tensor: *mut TfeTensorF64) {
156 if tensor.is_null() {
157 return;
158 }
159 let _ = catch_unwind(AssertUnwindSafe(|| {
160 drop(handle_take(tensor));
161 }));
162}
163
164#[no_mangle]
173pub unsafe extern "C" fn tfe_tensor_f64_ndim(
174 tensor: *const TfeTensorF64,
175 status: *mut tfe_status_t,
176) -> usize {
177 if tensor.is_null() {
178 if !status.is_null() {
179 *status = TFE_INVALID_ARGUMENT;
180 }
181 return 0;
182 }
183 let result = catch_unwind(AssertUnwindSafe(|| handle_to_ref(tensor).ndim()));
184 match result {
185 Ok(n) => {
186 if !status.is_null() {
187 *status = TFE_SUCCESS;
188 }
189 n
190 }
191 Err(panic) => {
192 set_last_error(&panic_message(&*panic));
193 if !status.is_null() {
194 *status = TFE_INTERNAL_ERROR;
195 }
196 0
197 }
198 }
199}
200
201#[no_mangle]
213pub unsafe extern "C" fn tfe_tensor_f64_shape(
214 tensor: *const TfeTensorF64,
215 out_shape: *mut usize,
216 status: *mut tfe_status_t,
217) {
218 if tensor.is_null() || out_shape.is_null() {
219 if !status.is_null() {
220 *status = TFE_INVALID_ARGUMENT;
221 }
222 return;
223 }
224 let result = catch_unwind(AssertUnwindSafe(|| {
225 let t = handle_to_ref(tensor);
226 let dims = t.dims();
227 std::ptr::copy_nonoverlapping(dims.as_ptr(), out_shape, dims.len());
228 }));
229 match result {
230 Ok(()) => {
231 if !status.is_null() {
232 *status = TFE_SUCCESS;
233 }
234 }
235 Err(panic) => {
236 set_last_error(&panic_message(&*panic));
237 if !status.is_null() {
238 *status = TFE_INTERNAL_ERROR;
239 }
240 }
241 }
242}
243
244#[no_mangle]
253pub unsafe extern "C" fn tfe_tensor_f64_len(
254 tensor: *const TfeTensorF64,
255 status: *mut tfe_status_t,
256) -> usize {
257 if tensor.is_null() {
258 if !status.is_null() {
259 *status = TFE_INVALID_ARGUMENT;
260 }
261 return 0;
262 }
263 let result = catch_unwind(AssertUnwindSafe(|| handle_to_ref(tensor).len()));
264 match result {
265 Ok(n) => {
266 if !status.is_null() {
267 *status = TFE_SUCCESS;
268 }
269 n
270 }
271 Err(panic) => {
272 set_last_error(&panic_message(&*panic));
273 if !status.is_null() {
274 *status = TFE_INTERNAL_ERROR;
275 }
276 0
277 }
278 }
279}
280
/// Returns a read-only pointer to the tensor's f64 storage, advanced by the
/// tensor's element offset, or null on failure.
///
/// Sets `status` (if non-null) to `TFE_INVALID_ARGUMENT` on a null handle,
/// `TFE_INTERNAL_ERROR` when the buffer is not contiguous host memory or a
/// panic is caught, and `TFE_SUCCESS` otherwise. The returned pointer borrows
/// the tensor's buffer; NOTE(review): it presumably stays valid only while the
/// handle is alive and unmodified — confirm against the C header contract.
#[no_mangle]
pub unsafe extern "C" fn tfe_tensor_f64_data(
    tensor: *const TfeTensorF64,
    status: *mut tfe_status_t,
) -> *const f64 {
    // Null handle: report the argument error without entering catch_unwind.
    if tensor.is_null() {
        if !status.is_null() {
            *status = TFE_INVALID_ARGUMENT;
        }
        return std::ptr::null();
    }
    let result = catch_unwind(AssertUnwindSafe(|| {
        let t = handle_to_ref(tensor);
        // `as_slice()` yields None when the buffer is not viewable as
        // contiguous host memory (per the error message below).
        let slice = t.buffer().as_slice().ok_or_else(|| {
            set_last_error("data: tensor buffer is not contiguous host memory");
            TFE_INTERNAL_ERROR
        })?;
        // Skip past the tensor's start offset within the shared buffer.
        // Assumes `t.offset()` is an in-bounds element index into `slice` —
        // TODO confirm the Tensor type guarantees this invariant.
        Ok(slice.as_ptr().add(t.offset() as usize))
    }));
    // Three outcomes: pointer produced, error code set inside the closure,
    // or a panic that must not cross the FFI boundary.
    match result {
        Ok(Ok(ptr)) => {
            if !status.is_null() {
                *status = TFE_SUCCESS;
            }
            ptr
        }
        Ok(Err(code)) => {
            if !status.is_null() {
                *status = code;
            }
            std::ptr::null()
        }
        Err(panic) => {
            set_last_error(&panic_message(&*panic));
            if !status.is_null() {
                *status = TFE_INTERNAL_ERROR;
            }
            std::ptr::null()
        }
    }
}