tenferro_capi/dlpack.rs

use std::ffi::c_void;
use std::panic::{catch_unwind, AssertUnwindSafe};

use tenferro_device::LogicalMemorySpace;
use tenferro_tensor::Tensor;

use crate::handle::{handle_take, tensor_to_handle, TfeTensorF64};
use crate::status::{
    finalize_ptr, map_device_error, set_last_error, tfe_status_t, TFE_INTERNAL_ERROR,
    TFE_INVALID_ARGUMENT,
};

/// DLPack version information.
#[repr(C)]
pub struct DLPackVersion {
    /// Major version (1 for DLPack v1.0).
    pub major: u32,
    /// Minor version.
    pub minor: u32,
}

/// DLPack device descriptor.
#[repr(C)]
pub struct DLDevice {
    /// Device type.
    pub device_type: i32,
    /// Device ID.
    pub device_id: i32,
}

/// DLPack data type descriptor.
#[repr(C)]
pub struct DLDataType {
    /// Type code.
    pub code: u8,
    /// Number of bits per element.
    pub bits: u8,
    /// Number of lanes.
    pub lanes: u16,
}

/// DLPack tensor descriptor (unmanaged).
#[repr(C)]
pub struct DLTensor {
    /// Pointer to the data.
    pub data: *mut c_void,
    /// Device where the data resides.
    pub device: DLDevice,
    /// Number of dimensions.
    pub ndim: i32,
    /// Data type.
    pub dtype: DLDataType,
    /// Shape array.
    pub shape: *mut i64,
    /// Strides array in element units.
    pub strides: *mut i64,
    /// Byte offset from `data`.
    pub byte_offset: u64,
}

/// DLPack managed tensor with version and ownership.
#[repr(C)]
pub struct DLManagedTensorVersioned {
    /// DLPack version.
    pub version: DLPackVersion,
    /// Opaque pointer for the producer's use.
    pub manager_ctx: *mut c_void,
    /// Callback to free the producer's resources.
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensorVersioned)>,
    /// Bitmask flags.
    pub flags: u64,
    /// The tensor descriptor.
    pub dl_tensor: DLTensor,
}

/// CPU device.
pub const KDLCPU: i32 = 1;
/// NVIDIA CUDA GPU device memory.
pub const KDLCUDA: i32 = 2;
/// Pinned CUDA CPU memory.
pub const KDLCUDA_HOST: i32 = 3;
/// AMD ROCm GPU device memory.
pub const KDLROCM: i32 = 10;
/// Pinned ROCm CPU memory.
pub const KDLROCM_HOST: i32 = 11;
/// CUDA managed/unified memory.
pub const KDLCUDA_MANAGED: i32 = 13;

/// Integer type code.
pub const KDLINT: u8 = 0;
/// Floating-point type code.
pub const KDLFLOAT: u8 = 2;
/// Complex type code.
pub const KDLCOMPLEX: u8 = 5;

/// Data is read-only.
pub const DLPACK_FLAG_BITMASK_READ_ONLY: u64 = 1 << 0;
/// Data was copied.
pub const DLPACK_FLAG_BITMASK_IS_COPIED: u64 = 1 << 1;
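
/// Producer-side context for a tensor exported via `tfe_tensor_f64_to_dlpack`.
///
/// It owns the exported tensor together with the `i64` shape and stride arrays
/// pointed to by the `DLTensor`, and is reclaimed by the installed deleter.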
struct ExportedDLPackTensor {
    tensor: Box<Tensor<f64>>,
    shape: Box<[i64]>,
    strides: Box<[i64]>,
}
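
/// RAII guard around a `DLManagedTensorVersioned` received from a DLPack producer.
///
/// While armed, dropping the guard returns the managed tensor to the producer
/// (through its deleter when one is set, otherwise by freeing the box), so that
/// early validation failures in `tfe_tensor_f64_from_dlpack` do not leak it.
/// `disarm` is called once ownership has been handed to the imported tensor.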
struct ImportedDLPackGuard {
    managed: *mut DLManagedTensorVersioned,
    armed: bool,
}

impl ImportedDLPackGuard {
    unsafe fn new(managed: *mut DLManagedTensorVersioned) -> Self {
        Self {
            managed,
            armed: true,
        }
    }

    fn as_ptr(&self) -> *mut DLManagedTensorVersioned {
        self.managed
    }

    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl Drop for ImportedDLPackGuard {
    fn drop(&mut self) {
        if !self.armed || self.managed.is_null() {
            return;
        }

        unsafe {
            if let Some(deleter) = (*self.managed).deleter {
                deleter(self.managed);
            } else {
                drop(Box::from_raw(self.managed));
            }
        }
    }
}
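
/// Compute contiguous row-major (C-order) strides, in elements, for `dims`.
///
/// Returns an error if an intermediate stride product overflows `isize`.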
pub(crate) fn row_major_strides(dims: &[usize]) -> tenferro_device::Result<Vec<isize>> {
    let ndim = dims.len();
    if ndim == 0 {
        return Ok(vec![]);
    }
    let mut strides = vec![0isize; ndim];
    strides[ndim - 1] = 1;
    for i in (0..ndim - 1).rev() {
        let axis = i + 1;
        let dim = isize::try_from(dims[axis]).map_err(|_| {
            tenferro_device::Error::InvalidArgument(format!(
                "dimension {axis} (size={}) too large for stride calculation",
                dims[axis]
            ))
        })?;
        strides[i] = strides[i + 1].checked_mul(dim).ok_or_else(|| {
            tenferro_device::Error::StrideError(format!(
                "stride overflow in row-major inference at dimension {axis} (size={dim})"
            ))
        })?;
    }
    Ok(strides)
}
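
/// Smallest buffer length, in elements, that covers every element addressed by
/// `dims` and `strides` starting from `offset`.
///
/// Returns 0 for empty tensors, and an error if the layout reaches a negative
/// buffer position or the extent computation overflows.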
fn required_buffer_len(
    dims: &[usize],
    strides: &[isize],
    offset: isize,
) -> tenferro_device::Result<usize> {
    if dims.len() != strides.len() {
        return Err(tenferro_device::Error::InvalidArgument(format!(
            "strides length {} doesn't match dims length {}",
            strides.len(),
            dims.len()
        )));
    }
    if dims.iter().product::<usize>() == 0 {
        return Ok(0);
    }

    let mut min_pos = offset;
    let mut max_pos = offset;
    for (axis, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
        if dim == 0 {
            continue;
        }
        let extent = isize::try_from(dim - 1)
            .ok()
            .and_then(|d| d.checked_mul(stride))
            .ok_or_else(|| {
                tenferro_device::Error::StrideError(format!(
                    "extent overflow for dimension {axis} (size={dim}, stride={stride})"
                ))
            })?;
        if extent >= 0 {
            max_pos += extent;
        } else {
            min_pos += extent;
        }
    }

    if min_pos < 0 {
        return Err(tenferro_device::Error::StrideError(format!(
            "layout accesses negative buffer position {min_pos}"
        )));
    }

    Ok(max_pos as usize + 1)
}
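
/// Map a tenferro `LogicalMemorySpace` to the corresponding DLPack device.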
fn logical_memory_space_to_dl_device(space: LogicalMemorySpace) -> DLDevice {
    match space {
        LogicalMemorySpace::MainMemory => DLDevice {
            device_type: KDLCPU,
            device_id: 0,
        },
        LogicalMemorySpace::PinnedMemory => DLDevice {
            device_type: KDLCUDA_HOST,
            device_id: 0,
        },
        LogicalMemorySpace::GpuMemory { device_id } => DLDevice {
            device_type: KDLCUDA,
            device_id: device_id as i32,
        },
        LogicalMemorySpace::ManagedMemory => DLDevice {
            device_type: KDLCUDA_MANAGED,
            device_id: 0,
        },
    }
}
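
/// Deleter installed on exported tensors: frees the `DLManagedTensorVersioned`
/// box and the `ExportedDLPackTensor` context stored in `manager_ctx`.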
unsafe extern "C" fn exported_dlpack_tensor_deleter(managed: *mut DLManagedTensorVersioned) {
    if managed.is_null() {
        return;
    }

    let managed = Box::from_raw(managed);
    if !managed.manager_ctx.is_null() {
        drop(Box::from_raw(
            managed.manager_ctx as *mut ExportedDLPackTensor,
        ));
    }
}

/// Export a tensor as a DLPack managed tensor (zero-copy).
///
/// The tensor handle is **consumed** by this call and must not be
/// used afterwards.
///
/// # Safety
///
/// - `tensor` must be a valid tensor pointer or NULL.
/// - `status` must be a valid, non-null pointer.
///
/// # Examples (C)
///
/// ```c
/// tfe_status_t status;
/// DLManagedTensorVersioned *dl = tfe_tensor_f64_to_dlpack(t, &status);
/// ```
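///
/// When the consumer is done with the exported tensor, it releases the
/// producer's resources by invoking the installed deleter, e.g.
/// `dl->deleter(dl)`.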
#[no_mangle]
pub unsafe extern "C" fn tfe_tensor_f64_to_dlpack(
    tensor: *mut TfeTensorF64,
    status: *mut tfe_status_t,
) -> *mut DLManagedTensorVersioned {
    let result = catch_unwind(AssertUnwindSafe(
        || -> Result<*mut DLManagedTensorVersioned, tfe_status_t> {
            if tensor.is_null() {
                set_last_error("to_dlpack: tensor pointer is null");
                return Err(TFE_INVALID_ARGUMENT);
            }

            let tensor = handle_take(tensor);
            let shape = tensor
                .dims()
                .iter()
                .map(|&dim| i64::try_from(dim))
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|_| {
                    set_last_error("to_dlpack: shape dimension does not fit into i64");
                    TFE_INVALID_ARGUMENT
                })?
                .into_boxed_slice();
            let strides = tensor
                .strides()
                .iter()
                .map(|&stride| i64::try_from(stride))
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|_| {
                    set_last_error("to_dlpack: stride does not fit into i64");
                    TFE_INVALID_ARGUMENT
                })?
                .into_boxed_slice();

            // Prefer a host pointer; fall back to the device pointer for
            // GPU-resident buffers.
            let data_ptr = if let Some(ptr) = tensor.buffer().as_ptr() {
                ptr as *mut c_void
            } else if let Some(ptr) = tensor.buffer().as_device_ptr() {
                ptr as *mut c_void
            } else {
                set_last_error("to_dlpack: tensor buffer has no exported pointer");
                return Err(TFE_INTERNAL_ERROR);
            };

            let offset = tensor.offset();
            if offset < 0 {
                set_last_error(
                    "to_dlpack: negative tensor offset is not supported for DLPack export",
                );
                return Err(TFE_INVALID_ARGUMENT);
            }
            let byte_offset = u64::try_from(
                (offset as i128) * (std::mem::size_of::<f64>() as i128),
            )
            .map_err(|_| {
                set_last_error("to_dlpack: byte offset overflow");
                TFE_INVALID_ARGUMENT
            })?;

            // The export context owns the tensor and the i64 shape/stride
            // arrays; it stays alive behind `manager_ctx` until the consumer
            // runs the deleter.
            let mut ctx = Box::new(ExportedDLPackTensor {
                tensor,
                shape,
                strides,
            });

            let shape_ptr = if ctx.shape.is_empty() {
                std::ptr::null_mut()
            } else {
                ctx.shape.as_mut_ptr()
            };
            let strides_ptr = if ctx.strides.is_empty() {
                std::ptr::null_mut()
            } else {
                ctx.strides.as_mut_ptr()
            };
            let ndim = i32::try_from(ctx.shape.len()).map_err(|_| {
                set_last_error("to_dlpack: rank does not fit into i32");
                TFE_INVALID_ARGUMENT
            })?;
            let device = logical_memory_space_to_dl_device(ctx.tensor.logical_memory_space());

            let managed = Box::new(DLManagedTensorVersioned {
                version: DLPackVersion { major: 1, minor: 0 },
                manager_ctx: Box::into_raw(ctx) as *mut c_void,
                deleter: Some(exported_dlpack_tensor_deleter),
                flags: 0,
                dl_tensor: DLTensor {
                    data: data_ptr,
                    device,
                    ndim,
                    dtype: DLDataType {
                        code: KDLFLOAT,
                        bits: 64,
                        lanes: 1,
                    },
                    shape: shape_ptr,
                    strides: strides_ptr,
                    byte_offset,
                },
            });

            Ok(Box::into_raw(managed))
        },
    ));
    finalize_ptr(result, status)
}

/// Import a DLPack managed tensor as a tenferro tensor (zero-copy).
///
/// Currently only `kDLCPU` device and float64 dtype are accepted.
///
/// # Safety
///
/// - `managed` must be a valid pointer to a `DLManagedTensorVersioned`.
/// - `status` must be a valid, non-null pointer.
///
/// # Examples (C)
///
/// ```c
/// tfe_status_t status;
/// tfe_tensor_f64 *t = tfe_tensor_f64_from_dlpack(dl, &status);
/// ```
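///
/// On success, ownership of `managed` passes to the returned tensor and the
/// producer's deleter runs through the tensor's release callback; on failure
/// (other than a null `managed` pointer), it runs before this function returns.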
#[no_mangle]
pub unsafe extern "C" fn tfe_tensor_f64_from_dlpack(
    managed: *mut DLManagedTensorVersioned,
    status: *mut tfe_status_t,
) -> *mut TfeTensorF64 {
    let result = catch_unwind(AssertUnwindSafe(
        || -> Result<*mut TfeTensorF64, tfe_status_t> {
            if managed.is_null() {
                set_last_error("from_dlpack: managed tensor pointer is null");
                return Err(TFE_INVALID_ARGUMENT);
            }

            // The guard releases the managed tensor through its deleter if any
            // of the validation steps below bail out early.
            let mut guard = ImportedDLPackGuard::new(managed);
            let managed_ref = &*guard.as_ptr();
            if managed_ref.version.major != 1 {
                set_last_error("from_dlpack: unsupported DLPack major version");
                return Err(TFE_INVALID_ARGUMENT);
            }

            let dl = &managed_ref.dl_tensor;
            if dl.device.device_type != KDLCPU || dl.device.device_id != 0 {
                set_last_error("from_dlpack: only kDLCPU device_id=0 is supported");
                return Err(TFE_INVALID_ARGUMENT);
            }
            if dl.dtype.code != KDLFLOAT || dl.dtype.bits != 64 || dl.dtype.lanes != 1 {
                set_last_error("from_dlpack: only float64 tensors are supported");
                return Err(TFE_INVALID_ARGUMENT);
            }
            if dl.ndim < 0 {
                set_last_error("from_dlpack: negative ndim is invalid");
                return Err(TFE_INVALID_ARGUMENT);
            }

            let ndim = dl.ndim as usize;
            if ndim > 0 && dl.shape.is_null() {
                set_last_error("from_dlpack: shape pointer is null for non-scalar tensor");
                return Err(TFE_INVALID_ARGUMENT);
            }

            let dims_i64 = if ndim == 0 {
                &[][..]
            } else {
                std::slice::from_raw_parts(dl.shape, ndim)
            };
            let dims = dims_i64
                .iter()
                .map(|&dim| usize::try_from(dim))
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|_| {
                    set_last_error("from_dlpack: shape contains negative or oversized dimension");
                    TFE_INVALID_ARGUMENT
                })?;

            // A null strides pointer means a compact row-major layout per the
            // DLPack convention.
            let strides = if dl.strides.is_null() {
                row_major_strides(&dims).map_err(|err| {
                    set_last_error(&format!("from_dlpack: {}", err));
                    TFE_INVALID_ARGUMENT
                })?
            } else {
                std::slice::from_raw_parts(dl.strides, ndim)
                    .iter()
                    .map(|&stride| isize::try_from(stride))
                    .collect::<std::result::Result<Vec<_>, _>>()
                    .map_err(|_| {
                        set_last_error("from_dlpack: stride does not fit into isize");
                        TFE_INVALID_ARGUMENT
                    })?
            };

            let elem_size = std::mem::size_of::<f64>() as u64;
            if dl.byte_offset % elem_size != 0 {
                set_last_error("from_dlpack: byte offset is not aligned to f64 elements");
                return Err(TFE_INVALID_ARGUMENT);
            }
            let offset = isize::try_from(dl.byte_offset / elem_size).map_err(|_| {
                set_last_error("from_dlpack: element offset does not fit into isize");
                TFE_INVALID_ARGUMENT
            })?;
            let len = required_buffer_len(&dims, &strides, offset)
                .map_err(|err| map_device_error(&err))?;

            let data = dl.data as *const f64;
            if data.is_null() && len > 0 {
                set_last_error("from_dlpack: data pointer is null for non-empty tensor");
                return Err(TFE_INVALID_ARGUMENT);
            }

            // The release callback captures the managed tensor as a plain
            // address, rebuilds the pointer, and hands the memory back to the
            // producer (via its deleter, or by freeing the box directly when
            // no deleter is set).
            let managed_addr = guard.as_ptr() as usize;
            let release = move || unsafe {
                let managed = managed_addr as *mut DLManagedTensorVersioned;
                if let Some(deleter) = (*managed).deleter {
                    deleter(managed);
                } else {
                    drop(Box::from_raw(managed));
                }
            };
            let tensor = Tensor::from_external_parts(data, len, &dims, &strides, offset, release)
                .map_err(|err| map_device_error(&err))?;
            // Ownership has transferred to the tensor's release callback, so
            // the guard must no longer run the deleter.
            guard.disarm();
            Ok(tensor_to_handle(tensor))
        },
    ));
    finalize_ptr(result, status)
}