// tenferro_capi/handle.rs

1use tenferro_algebra::Standard;
2use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
3use tenferro_tensor::{MemoryOrder, Tensor};
4
/// Opaque handle wrapping a `Tensor<f64>`.
///
/// Host languages hold a pointer to this type and pass it to all
/// `tfe_*` functions. The internal layout is private; only the C-API
/// functions can access the inner tensor.
///
/// The type is zero-sized and is never constructed directly: a handle is a
/// `Box<Tensor<f64>>` allocation whose pointer has been cast to
/// `*mut TfeTensorF64`. The `_marker` field follows the Rustonomicon's
/// opaque-type recipe so that the type is `!Send`, `!Sync` and `!Unpin`. This
/// stops Rust code from assuming thread-safety or movability properties that
/// the hidden tensor does not promise.
///
/// # Examples (C)
///
/// ```c
/// tfe_status_t status;
/// size_t shape[] = {2, 3};
/// double data[] = {1, 2, 3, 4, 5, 6};
/// tfe_tensor_f64 *t = tfe_tensor_f64_from_data(data, 6, shape, 2, &status);
/// tfe_tensor_f64_release(t);
/// ```
#[repr(C)]
pub struct TfeTensorF64 {
    _private: [u8; 0],
    // Zero-sized marker: opts out of the `Send`/`Sync` (via `*mut u8`) and
    // `Unpin` (via `PhantomPinned`) auto traits without changing size/alignment.
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
24
25/// Convert a `Tensor<f64>` into an opaque handle.
26pub(crate) fn tensor_to_handle(tensor: Tensor<f64>) -> *mut TfeTensorF64 {
27    Box::into_raw(Box::new(tensor)) as *mut TfeTensorF64
28}
29
30/// Ensure a tensor has column-major contiguous data layout.
31pub(crate) fn ensure_col_major(
32    ctx: &mut CpuContext,
33    tensor: Tensor<f64>,
34) -> std::result::Result<Tensor<f64>, tenferro_device::Error> {
35    if tensor.is_col_major_contiguous() {
36        return Ok(tensor);
37    }
38    let mut result = Tensor::<f64>::zeros(
39        tensor.dims(),
40        tensor.logical_memory_space(),
41        MemoryOrder::ColumnMajor,
42    )?;
43    let desc = SemiringCoreDescriptor::MakeContiguous;
44    let shapes = [tensor.dims(), result.dims()];
45    let plan = <CpuBackend as TensorSemiringCore<Standard<f64>>>::plan(ctx, &desc, &shapes)?;
46    <CpuBackend as TensorSemiringCore<Standard<f64>>>::execute(
47        ctx,
48        &plan,
49        1.0,
50        &[&tensor],
51        0.0,
52        &mut result,
53    )?;
54    Ok(result)
55}
56
57/// Borrow the tensor behind an opaque handle.
58///
59/// # Safety
60///
61/// `handle` must be a valid, non-null pointer returned by `tensor_to_handle`.
62pub(crate) unsafe fn handle_to_ref<'a>(handle: *const TfeTensorF64) -> &'a Tensor<f64> {
63    &*(handle as *const Tensor<f64>)
64}
65
66/// Take ownership of the tensor behind an opaque handle.
67///
68/// # Safety
69///
70/// `handle` must be a valid, non-null pointer returned by `tensor_to_handle`.
71/// Must not be used after this call.
72pub(crate) unsafe fn handle_take(handle: *mut TfeTensorF64) -> Box<Tensor<f64>> {
73    Box::from_raw(handle as *mut Tensor<f64>)
74}