use std::ffi::c_void;
use std::panic::{catch_unwind, AssertUnwindSafe};

use tenferro_device::LogicalMemorySpace;
use tenferro_tensor::Tensor;

use crate::handle::{handle_take, tensor_to_handle, TfeTensorF64};
use crate::status::{
    finalize_ptr, map_device_error, set_last_error, tfe_status_t, TFE_INTERNAL_ERROR,
    TFE_INVALID_ARGUMENT,
};

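/// DLPack version tag carried at the head of every `DLManagedTensorVersioned`.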
#[repr(C)]
pub struct DLPackVersion {
    pub major: u32,
    pub minor: u32,
}

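/// DLPack device descriptor: a `DLDeviceType` discriminant plus a device ordinal.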
#[repr(C)]
pub struct DLDevice {
    pub device_type: i32,
    pub device_id: i32,
}

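/// DLPack element type: type-class code, bit width, and vector lane count.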
#[repr(C)]
pub struct DLDataType {
    pub code: u8,
    pub bits: u8,
    pub lanes: u16,
}

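/// Plain DLPack tensor view. `shape` and `strides` are in elements, not
/// bytes; a null `strides` pointer means compact row-major. `byte_offset` is
/// the offset of element zero from `data`, in bytes.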
#[repr(C)]
pub struct DLTensor {
    pub data: *mut c_void,
    pub device: DLDevice,
    pub ndim: i32,
    pub dtype: DLDataType,
    pub shape: *mut i64,
    pub strides: *mut i64,
    pub byte_offset: u64,
}

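/// Owning DLPack 1.x capsule. The consumer releases the tensor by calling
/// `deleter` with the capsule pointer; `manager_ctx` is opaque producer state.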
#[repr(C)]
pub struct DLManagedTensorVersioned {
    pub version: DLPackVersion,
    pub manager_ctx: *mut c_void,
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensorVersioned)>,
    pub flags: u64,
    pub dl_tensor: DLTensor,
}

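// `DLDeviceType` values from the DLPack specification, restricted to the
// device kinds this crate can map.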
pub const KDLCPU: i32 = 1;
pub const KDLCUDA: i32 = 2;
pub const KDLCUDA_HOST: i32 = 3;
pub const KDLROCM: i32 = 10;
pub const KDLROCM_HOST: i32 = 11;
pub const KDLCUDA_MANAGED: i32 = 13;

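// `DLDataTypeCode` values from the DLPack specification.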
pub const KDLINT: u8 = 0;
pub const KDLFLOAT: u8 = 2;
pub const KDLCOMPLEX: u8 = 5;

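// Bit flags carried in `DLManagedTensorVersioned::flags` (DLPack 1.0).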
pub const DLPACK_FLAG_BITMASK_READ_ONLY: u64 = 1 << 0;
pub const DLPACK_FLAG_BITMASK_IS_COPIED: u64 = 1 << 1;

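/// Producer-side state kept alive behind `manager_ctx` for an exported
/// tensor: the tensor itself plus the boxed shape/stride arrays that the
/// exported `DLTensor` points into.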
struct ExportedDLPackTensor {
    tensor: Box<Tensor<f64>>,
    shape: Box<[i64]>,
    strides: Box<[i64]>,
}

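/// RAII guard for an imported capsule: while armed, dropping the guard
/// invokes the producer's deleter, so early error returns in
/// `tfe_tensor_f64_from_dlpack` cannot leak the capsule.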
struct ImportedDLPackGuard {
    managed: *mut DLManagedTensorVersioned,
    armed: bool,
}

impl ImportedDLPackGuard {
    unsafe fn new(managed: *mut DLManagedTensorVersioned) -> Self {
        Self {
            managed,
            armed: true,
        }
    }

    fn as_ptr(&self) -> *mut DLManagedTensorVersioned {
        self.managed
    }

    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl Drop for ImportedDLPackGuard {
    fn drop(&mut self) {
        if !self.armed || self.managed.is_null() {
            return;
        }

        unsafe {
            // Per the DLPack contract, a null deleter means the producer
            // retains ownership of the capsule; the consumer must not free
            // it itself (the capsule may not come from Rust's allocator).
            if let Some(deleter) = (*self.managed).deleter {
                deleter(self.managed);
            }
        }
    }
}

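/// Infers compact row-major (C-contiguous) strides, in elements, for `dims`,
/// as DLPack prescribes when a capsule's `strides` pointer is null.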
pub(crate) fn row_major_strides(dims: &[usize]) -> tenferro_device::Result<Vec<isize>> {
    let ndim = dims.len();
    if ndim == 0 {
        return Ok(vec![]);
    }
    let mut strides = vec![0isize; ndim];
    strides[ndim - 1] = 1;
    for i in (0..ndim - 1).rev() {
        let axis = i + 1;
        let dim = isize::try_from(dims[axis]).map_err(|_| {
            tenferro_device::Error::InvalidArgument(format!(
                "dimension {axis} (size={}) too large for stride calculation",
                dims[axis]
            ))
        })?;
        strides[i] = strides[i + 1].checked_mul(dim).ok_or_else(|| {
            tenferro_device::Error::StrideError(format!(
                "stride overflow in row-major inference at dimension {axis} (size={dim})"
            ))
        })?;
    }
    Ok(strides)
}

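/// Computes the minimum buffer length, in elements, that a `(dims, strides,
/// offset)` layout can address, rejecting layouts that reach negative
/// positions or overflow `isize`. Zero-size tensors need no backing storage.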
fn required_buffer_len(
    dims: &[usize],
    strides: &[isize],
    offset: isize,
) -> tenferro_device::Result<usize> {
    if dims.len() != strides.len() {
        return Err(tenferro_device::Error::InvalidArgument(format!(
            "strides length {} doesn't match dims length {}",
            strides.len(),
            dims.len()
        )));
    }
    if dims.iter().product::<usize>() == 0 {
        return Ok(0);
    }

    let mut min_pos = offset;
    let mut max_pos = offset;
    for (axis, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
        if dim == 0 {
            continue;
        }
        let extent = isize::try_from(dim - 1)
            .ok()
            .and_then(|d| d.checked_mul(stride))
            .ok_or_else(|| {
                tenferro_device::Error::StrideError(format!(
                    "extent overflow for dimension {axis} (size={dim}, stride={stride})"
                ))
            })?;
        // Accumulate with checked arithmetic so pathological layouts cannot
        // wrap around and under-report the required buffer length.
        if extent >= 0 {
            max_pos = max_pos.checked_add(extent).ok_or_else(|| {
                tenferro_device::Error::StrideError(format!(
                    "buffer position overflow at dimension {axis} (size={dim}, stride={stride})"
                ))
            })?;
        } else {
            min_pos = min_pos.checked_add(extent).ok_or_else(|| {
                tenferro_device::Error::StrideError(format!(
                    "buffer position overflow at dimension {axis} (size={dim}, stride={stride})"
                ))
            })?;
        }
    }

    if min_pos < 0 {
        return Err(tenferro_device::Error::StrideError(format!(
            "layout accesses negative buffer position {min_pos}"
        )));
    }

    Ok(max_pos as usize + 1)
}

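/// Maps this crate's logical memory spaces onto DLPack device descriptors.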
fn logical_memory_space_to_dl_device(space: LogicalMemorySpace) -> DLDevice {
    match space {
        LogicalMemorySpace::MainMemory => DLDevice {
            device_type: KDLCPU,
            device_id: 0,
        },
        LogicalMemorySpace::PinnedMemory => DLDevice {
            device_type: KDLCUDA_HOST,
            device_id: 0,
        },
        LogicalMemorySpace::GpuMemory { device_id } => DLDevice {
            device_type: KDLCUDA,
            device_id: device_id as i32,
        },
        LogicalMemorySpace::ManagedMemory => DLDevice {
            device_type: KDLCUDA_MANAGED,
            device_id: 0,
        },
    }
}

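/// Deleter installed on exported capsules: reclaims the capsule allocation
/// and the `ExportedDLPackTensor` state behind `manager_ctx`, releasing the
/// tensor and its shape/stride arrays.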
unsafe extern "C" fn exported_dlpack_tensor_deleter(managed: *mut DLManagedTensorVersioned) {
    if managed.is_null() {
        return;
    }

    let managed = Box::from_raw(managed);
    if !managed.manager_ctx.is_null() {
        drop(Box::from_raw(
            managed.manager_ctx as *mut ExportedDLPackTensor,
        ));
    }
}

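/// Exports an f64 tensor as a versioned DLPack capsule (DLPack 1.0).
///
/// Consumes the handle: ownership of the tensor moves into the capsule, and
/// the caller must release it exactly once by invoking the capsule's
/// `deleter`. On failure, returns null and stores a diagnostic retrievable
/// through the status machinery.
///
/// # Safety
///
/// `tensor` must be a valid, exclusively owned handle previously returned by
/// this library, and `status`, if non-null, must point to writable memory.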
#[no_mangle]
pub unsafe extern "C" fn tfe_tensor_f64_to_dlpack(
    tensor: *mut TfeTensorF64,
    status: *mut tfe_status_t,
) -> *mut DLManagedTensorVersioned {
    let result = catch_unwind(AssertUnwindSafe(
        || -> Result<*mut DLManagedTensorVersioned, tfe_status_t> {
            if tensor.is_null() {
                set_last_error("to_dlpack: tensor pointer is null");
                return Err(TFE_INVALID_ARGUMENT);
            }

            // Take ownership of the tensor out of the C handle; it lives on
            // inside the capsule's manager context until the deleter runs.
            let tensor = handle_take(tensor);
            let shape = tensor
                .dims()
                .iter()
                .map(|&dim| i64::try_from(dim))
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|_| {
                    set_last_error("to_dlpack: shape dimension does not fit into i64");
                    TFE_INVALID_ARGUMENT
                })?
                .into_boxed_slice();
            let strides = tensor
                .strides()
                .iter()
                .map(|&stride| i64::try_from(stride))
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|_| {
                    set_last_error("to_dlpack: stride does not fit into i64");
                    TFE_INVALID_ARGUMENT
                })?
                .into_boxed_slice();

            let data_ptr = if let Some(ptr) = tensor.buffer().as_ptr() {
                ptr as *mut c_void
            } else if let Some(ptr) = tensor.buffer().as_device_ptr() {
                ptr as *mut c_void
            } else {
                set_last_error("to_dlpack: tensor buffer has no exportable pointer");
                return Err(TFE_INTERNAL_ERROR);
            };

            let offset = tensor.offset();
            if offset < 0 {
                set_last_error(
                    "to_dlpack: negative tensor offset is not supported for DLPack export",
                );
                return Err(TFE_INVALID_ARGUMENT);
            }
            // DLPack expresses the offset in bytes; widen to i128 so the
            // multiplication cannot overflow before the range check.
            let byte_offset = u64::try_from(
                (offset as i128) * (std::mem::size_of::<f64>() as i128),
            )
            .map_err(|_| {
                set_last_error("to_dlpack: byte offset overflow");
                TFE_INVALID_ARGUMENT
            })?;

            let mut ctx = Box::new(ExportedDLPackTensor {
                tensor,
                shape,
                strides,
            });

            let shape_ptr = if ctx.shape.is_empty() {
                std::ptr::null_mut()
            } else {
                ctx.shape.as_mut_ptr()
            };
            let strides_ptr = if ctx.strides.is_empty() {
                std::ptr::null_mut()
            } else {
                ctx.strides.as_mut_ptr()
            };
            let ndim = i32::try_from(ctx.shape.len()).map_err(|_| {
                set_last_error("to_dlpack: rank does not fit into i32");
                TFE_INVALID_ARGUMENT
            })?;
            let device = logical_memory_space_to_dl_device(ctx.tensor.logical_memory_space());

            let managed = Box::new(DLManagedTensorVersioned {
                version: DLPackVersion { major: 1, minor: 0 },
                manager_ctx: Box::into_raw(ctx) as *mut c_void,
                deleter: Some(exported_dlpack_tensor_deleter),
                flags: 0,
                dl_tensor: DLTensor {
                    data: data_ptr,
                    device,
                    ndim,
                    dtype: DLDataType {
                        code: KDLFLOAT,
                        bits: 64,
                        lanes: 1,
                    },
                    shape: shape_ptr,
                    strides: strides_ptr,
                    byte_offset,
                },
            });

            Ok(Box::into_raw(managed))
        },
    ));
    finalize_ptr(result, status)
}

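/// Imports a versioned DLPack capsule as an f64 tensor handle.
///
/// Takes ownership of the capsule: its deleter is invoked either immediately
/// on failure or later when the returned tensor is dropped. Only
/// `kDLCPU`/`device_id == 0` float64 capsules are accepted. On failure,
/// returns null and stores a diagnostic retrievable through the status
/// machinery.
///
/// # Safety
///
/// `managed` must point to a valid `DLManagedTensorVersioned` whose data,
/// shape, and stride pointers uphold the DLPack contract, and `status`, if
/// non-null, must point to writable memory.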
#[no_mangle]
pub unsafe extern "C" fn tfe_tensor_f64_from_dlpack(
    managed: *mut DLManagedTensorVersioned,
    status: *mut tfe_status_t,
) -> *mut TfeTensorF64 {
    let result = catch_unwind(AssertUnwindSafe(
        || -> Result<*mut TfeTensorF64, tfe_status_t> {
            if managed.is_null() {
                set_last_error("from_dlpack: managed tensor pointer is null");
                return Err(TFE_INVALID_ARGUMENT);
            }

            // From here on the guard owns the capsule: any early return
            // routes through its Drop impl and releases the producer state.
            let mut guard = ImportedDLPackGuard::new(managed);
            let managed_ref = &*guard.as_ptr();
            if managed_ref.version.major != 1 {
                set_last_error("from_dlpack: unsupported DLPack major version");
                return Err(TFE_INVALID_ARGUMENT);
            }

            let dl = &managed_ref.dl_tensor;
            if dl.device.device_type != KDLCPU || dl.device.device_id != 0 {
                set_last_error("from_dlpack: only kDLCPU device_id=0 is supported");
                return Err(TFE_INVALID_ARGUMENT);
            }
            if dl.dtype.code != KDLFLOAT || dl.dtype.bits != 64 || dl.dtype.lanes != 1 {
                set_last_error("from_dlpack: only float64 tensors are supported");
                return Err(TFE_INVALID_ARGUMENT);
            }
            if dl.ndim < 0 {
                set_last_error("from_dlpack: negative ndim is invalid");
                return Err(TFE_INVALID_ARGUMENT);
            }

            let ndim = dl.ndim as usize;
            if ndim > 0 && dl.shape.is_null() {
                set_last_error("from_dlpack: shape pointer is null for non-scalar tensor");
                return Err(TFE_INVALID_ARGUMENT);
            }

            let dims_i64 = if ndim == 0 {
                &[][..]
            } else {
                std::slice::from_raw_parts(dl.shape, ndim)
            };
            let dims = dims_i64
                .iter()
                .map(|&dim| usize::try_from(dim))
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|_| {
                    set_last_error("from_dlpack: shape contains negative or oversized dimension");
                    TFE_INVALID_ARGUMENT
                })?;

            // A null strides pointer means compact row-major per the DLPack
            // specification, so infer the strides from the shape.
            let strides = if dl.strides.is_null() {
                row_major_strides(&dims).map_err(|err| {
                    set_last_error(&format!("from_dlpack: {err}"));
                    TFE_INVALID_ARGUMENT
                })?
            } else {
                std::slice::from_raw_parts(dl.strides, ndim)
                    .iter()
                    .map(|&stride| isize::try_from(stride))
                    .collect::<std::result::Result<Vec<_>, _>>()
                    .map_err(|_| {
                        set_last_error("from_dlpack: stride does not fit into isize");
                        TFE_INVALID_ARGUMENT
                    })?
            };

            let elem_size = std::mem::size_of::<f64>() as u64;
            if dl.byte_offset % elem_size != 0 {
                set_last_error("from_dlpack: byte offset is not aligned to f64 elements");
                return Err(TFE_INVALID_ARGUMENT);
            }
            let offset = isize::try_from(dl.byte_offset / elem_size).map_err(|_| {
                set_last_error("from_dlpack: element offset does not fit into isize");
                TFE_INVALID_ARGUMENT
            })?;
            let len = required_buffer_len(&dims, &strides, offset)
                .map_err(|err| map_device_error(&err))?;

            let data = dl.data as *const f64;
            if data.is_null() && len > 0 {
                set_last_error("from_dlpack: data pointer is null for non-empty tensor");
                return Err(TFE_INVALID_ARGUMENT);
            }

            // Stash the capsule address as a plain integer for the release
            // closure that runs when the imported tensor is dropped.
            let managed_addr = guard.as_ptr() as usize;
            let release = move || unsafe {
                let managed = managed_addr as *mut DLManagedTensorVersioned;
                // A null deleter means the producer retains ownership; the
                // consumer must not free the capsule itself.
                if let Some(deleter) = (*managed).deleter {
                    deleter(managed);
                }
            };
            let tensor = Tensor::from_external_parts(data, len, &dims, &strides, offset, release)
                .map_err(|err| map_device_error(&err))?;
            // The tensor's release hook now owns the capsule; disarm the
            // guard so its Drop impl does not release it a second time.
            guard.disarm();
            Ok(tensor_to_handle(tensor))
        },
    ));
    finalize_ptr(result, status)
}
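
// Illustrative sanity checks for the pure layout helpers above; a minimal
// sketch assuming the crate's usual `cargo test` setup. The chosen shapes,
// strides, and offsets are hypothetical examples, not fixtures from the
// test suite.
#[cfg(test)]
mod tests {
    use super::{required_buffer_len, row_major_strides};

    #[test]
    fn row_major_strides_match_c_order() {
        // A 2x3x4 tensor laid out C-contiguously has strides [12, 4, 1].
        let strides = row_major_strides(&[2, 3, 4]).unwrap();
        assert_eq!(strides, vec![12, 4, 1]);
        // Scalars have no axes and therefore no strides.
        assert!(row_major_strides(&[]).unwrap().is_empty());
    }

    #[test]
    fn buffer_len_covers_extreme_positions() {
        // Contiguous 2x3 at offset 0 touches positions 0..=5, so length 6.
        assert_eq!(required_buffer_len(&[2, 3], &[3, 1], 0).unwrap(), 6);
        // A negative stride walking down from offset 2 stays in bounds.
        assert_eq!(required_buffer_len(&[3], &[-1], 2).unwrap(), 3);
        // Reaching below position zero is rejected.
        assert!(required_buffer_len(&[3], &[-1], 1).is_err());
        // Zero-size tensors need no backing storage.
        assert_eq!(required_buffer_len(&[0, 5], &[5, 1], 0).unwrap(), 0);
    }
}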