// tenferro_capi/tensor_api.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use tenferro_device::LogicalMemorySpace;
4use tenferro_tensor::{MemoryOrder, Tensor};
5
6use crate::ffi_utils::read_usize_slice;
7use crate::handle::{handle_take, handle_to_ref, tensor_to_handle, TfeTensorF64};
8use crate::status::{
9    finalize_ptr, map_device_error, panic_message, set_last_error, tfe_status_t,
10    TFE_INTERNAL_ERROR, TFE_INVALID_ARGUMENT, TFE_SUCCESS,
11};
12
13/// Create a tensor from caller-provided data (copy semantics).
14///
15/// The data is **copied** into Rust-owned storage. The caller retains
16/// ownership of the `data` pointer and may free it after this call.
17/// The internal memory layout is implementation-defined.
18///
19/// For zero-copy tensor exchange with specific memory layouts, use
20/// [`crate::tfe_tensor_f64_from_dlpack`] instead.
21///
22/// # Safety
23///
24/// - `data` must point to at least `len` valid `f64` values.
25/// - `shape` must point to at least `ndim` valid `usize` values.
26/// - `status` must be a valid, non-null pointer.
27///
28/// # Examples (C)
29///
30/// ```c
31/// double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
32/// size_t shape[] = {2, 3};
33/// tfe_status_t status;
34/// tfe_tensor_f64 *t = tfe_tensor_f64_from_data(data, 6, shape, 2, &status);
35/// assert(status == TFE_SUCCESS);
36/// tfe_tensor_f64_release(t);
37/// ```
38#[no_mangle]
39pub unsafe extern "C" fn tfe_tensor_f64_from_data(
40    data: *const f64,
41    len: usize,
42    shape: *const usize,
43    ndim: usize,
44    status: *mut tfe_status_t,
45) -> *mut TfeTensorF64 {
46    let result = catch_unwind(AssertUnwindSafe(|| {
47        if data.is_null() && len > 0 {
48            set_last_error("from_data: data pointer is null");
49            return Err(TFE_INVALID_ARGUMENT);
50        }
51        let dims = read_usize_slice(shape, ndim, "from_data shape")?.to_vec();
52        let data_slice = if len > 0 {
53            std::slice::from_raw_parts(data, len)
54        } else {
55            &[]
56        };
57
58        Tensor::from_slice(data_slice, &dims, MemoryOrder::ColumnMajor)
59            .map(tensor_to_handle)
60            .map_err(|e| map_device_error(&e))
61    }));
62
63    finalize_ptr(result, status)
64}
65
66/// Create a tensor filled with zeros.
67///
68/// # Safety
69///
70/// - `shape` must point to at least `ndim` valid `usize` values.
71/// - `status` must be a valid, non-null pointer.
72///
73/// # Examples (C)
74///
75/// ```c
76/// size_t shape[] = {3, 4};
77/// tfe_status_t status;
78/// tfe_tensor_f64 *t = tfe_tensor_f64_zeros(shape, 2, &status);
79/// tfe_tensor_f64_release(t);
80/// ```
81#[no_mangle]
82pub unsafe extern "C" fn tfe_tensor_f64_zeros(
83    shape: *const usize,
84    ndim: usize,
85    status: *mut tfe_status_t,
86) -> *mut TfeTensorF64 {
87    let result = catch_unwind(AssertUnwindSafe(|| {
88        let dims = read_usize_slice(shape, ndim, "zeros shape")?.to_vec();
89        let t = Tensor::<f64>::zeros(
90            &dims,
91            LogicalMemorySpace::MainMemory,
92            MemoryOrder::ColumnMajor,
93        )
94        .map_err(|e| map_device_error(&e))?;
95        Ok(tensor_to_handle(t))
96    }));
97
98    finalize_ptr(result, status)
99}
100
101/// Deep-copy a tensor.
102///
103/// `Tensor::clone()` is a shallow copy (Arc refcount increment).
104/// This C API function produces a new tensor with its own independent data buffer.
105///
106/// # Safety
107///
108/// - `tensor` must be a valid pointer returned by a `tfe_tensor_f64_*`
109///   creation function that has not yet been released.
110/// - `status` must be a valid, non-null pointer.
111///
112/// # Examples (C)
113///
114/// ```c
115/// tfe_tensor_f64 *copy = tfe_tensor_f64_clone(original, &status);
116/// tfe_tensor_f64_release(copy);
117/// ```
118#[no_mangle]
119pub unsafe extern "C" fn tfe_tensor_f64_clone(
120    tensor: *const TfeTensorF64,
121    status: *mut tfe_status_t,
122) -> *mut TfeTensorF64 {
123    let result = catch_unwind(AssertUnwindSafe(|| {
124        if tensor.is_null() {
125            set_last_error("clone: tensor pointer is null");
126            return Err(TFE_INVALID_ARGUMENT);
127        }
128        let src = handle_to_ref(tensor);
129        let materialized = src.contiguous(MemoryOrder::ColumnMajor);
130        let src_data = materialized.buffer().as_slice().ok_or_else(|| {
131            set_last_error("clone: tensor buffer is not contiguous host memory");
132            TFE_INTERNAL_ERROR
133        })?;
134        let copy = Tensor::from_slice(src_data, materialized.dims(), MemoryOrder::ColumnMajor)
135            .map_err(|e| map_device_error(&e))?;
136        Ok(tensor_to_handle(copy))
137    }));
138
139    finalize_ptr(result, status)
140}
141
142/// Release (free) a tensor.
143///
144/// After this call, `tensor` is invalid and must not be used.
145/// Passing a null pointer is a no-op.
146///
147/// For tensors imported via DLPack, this calls the DLPack deleter
148/// to notify the external owner that the data is no longer needed.
149///
150/// # Safety
151///
152/// `tensor` must be null or a valid pointer returned by a
153/// `tfe_tensor_f64_*` creation function that has not yet been released.
154#[no_mangle]
155pub unsafe extern "C" fn tfe_tensor_f64_release(tensor: *mut TfeTensorF64) {
156    if tensor.is_null() {
157        return;
158    }
159    let _ = catch_unwind(AssertUnwindSafe(|| {
160        drop(handle_take(tensor));
161    }));
162}
163
164/// Return the number of dimensions (rank) of the tensor.
165///
166/// Returns 0 if `tensor` is null (and sets `status` to `TFE_INVALID_ARGUMENT`).
167///
168/// # Safety
169///
170/// - `tensor` must be a valid tensor pointer or null.
171/// - `status` must be a valid, non-null pointer.
172#[no_mangle]
173pub unsafe extern "C" fn tfe_tensor_f64_ndim(
174    tensor: *const TfeTensorF64,
175    status: *mut tfe_status_t,
176) -> usize {
177    if tensor.is_null() {
178        if !status.is_null() {
179            *status = TFE_INVALID_ARGUMENT;
180        }
181        return 0;
182    }
183    let result = catch_unwind(AssertUnwindSafe(|| handle_to_ref(tensor).ndim()));
184    match result {
185        Ok(n) => {
186            if !status.is_null() {
187                *status = TFE_SUCCESS;
188            }
189            n
190        }
191        Err(panic) => {
192            set_last_error(&panic_message(&*panic));
193            if !status.is_null() {
194                *status = TFE_INTERNAL_ERROR;
195            }
196            0
197        }
198    }
199}
200
201/// Write the shape of the tensor into the caller-provided buffer.
202///
203/// The caller must allocate `out_shape` with at least
204/// `tfe_tensor_f64_ndim(tensor)` elements.
205///
206/// # Safety
207///
208/// - `tensor` must be a valid, non-null tensor pointer.
209/// - `out_shape` must point to a buffer with at least
210///   `tfe_tensor_f64_ndim(tensor)` elements.
211/// - `status` must be a valid, non-null pointer.
212#[no_mangle]
213pub unsafe extern "C" fn tfe_tensor_f64_shape(
214    tensor: *const TfeTensorF64,
215    out_shape: *mut usize,
216    status: *mut tfe_status_t,
217) {
218    if tensor.is_null() || out_shape.is_null() {
219        if !status.is_null() {
220            *status = TFE_INVALID_ARGUMENT;
221        }
222        return;
223    }
224    let result = catch_unwind(AssertUnwindSafe(|| {
225        let t = handle_to_ref(tensor);
226        let dims = t.dims();
227        std::ptr::copy_nonoverlapping(dims.as_ptr(), out_shape, dims.len());
228    }));
229    match result {
230        Ok(()) => {
231            if !status.is_null() {
232                *status = TFE_SUCCESS;
233            }
234        }
235        Err(panic) => {
236            set_last_error(&panic_message(&*panic));
237            if !status.is_null() {
238                *status = TFE_INTERNAL_ERROR;
239            }
240        }
241    }
242}
243
244/// Return the total number of elements in the tensor.
245///
246/// Returns 0 if `tensor` is null (and sets `status` to `TFE_INVALID_ARGUMENT`).
247///
248/// # Safety
249///
250/// - `tensor` must be a valid tensor pointer or null.
251/// - `status` must be a valid, non-null pointer.
252#[no_mangle]
253pub unsafe extern "C" fn tfe_tensor_f64_len(
254    tensor: *const TfeTensorF64,
255    status: *mut tfe_status_t,
256) -> usize {
257    if tensor.is_null() {
258        if !status.is_null() {
259            *status = TFE_INVALID_ARGUMENT;
260        }
261        return 0;
262    }
263    let result = catch_unwind(AssertUnwindSafe(|| handle_to_ref(tensor).len()));
264    match result {
265        Ok(n) => {
266            if !status.is_null() {
267                *status = TFE_SUCCESS;
268            }
269            n
270        }
271        Err(panic) => {
272            set_last_error(&panic_message(&*panic));
273            if !status.is_null() {
274                *status = TFE_INTERNAL_ERROR;
275            }
276            0
277        }
278    }
279}
280
281/// Return a pointer to the tensor's raw data buffer.
282///
283/// The pointer is valid until `tfe_tensor_f64_release` is called on
284/// the tensor. Returns null if `tensor` is null.
285///
286/// # Safety
287///
288/// - `tensor` must be a valid tensor pointer or null.
289/// - `status` must be a valid, non-null pointer.
290/// - The returned pointer must not be used after `tfe_tensor_f64_release(tensor)`.
291#[no_mangle]
292pub unsafe extern "C" fn tfe_tensor_f64_data(
293    tensor: *const TfeTensorF64,
294    status: *mut tfe_status_t,
295) -> *const f64 {
296    if tensor.is_null() {
297        if !status.is_null() {
298            *status = TFE_INVALID_ARGUMENT;
299        }
300        return std::ptr::null();
301    }
302    let result = catch_unwind(AssertUnwindSafe(|| {
303        let t = handle_to_ref(tensor);
304        let slice = t.buffer().as_slice().ok_or_else(|| {
305            set_last_error("data: tensor buffer is not contiguous host memory");
306            TFE_INTERNAL_ERROR
307        })?;
308        Ok(slice.as_ptr().add(t.offset() as usize))
309    }));
310    match result {
311        Ok(Ok(ptr)) => {
312            if !status.is_null() {
313                *status = TFE_SUCCESS;
314            }
315            ptr
316        }
317        Ok(Err(code)) => {
318            if !status.is_null() {
319                *status = code;
320            }
321            std::ptr::null()
322        }
323        Err(panic) => {
324            set_last_error(&panic_message(&*panic));
325            if !status.is_null() {
326                *status = TFE_INTERNAL_ERROR;
327            }
328            std::ptr::null()
329        }
330    }
331}