1use tenferro_algebra::Standard;
2use tenferro_prims::{CpuBackend, CpuContext, SemiringCoreDescriptor, TensorSemiringCore};
3use tenferro_tensor::{MemoryOrder, Tensor};
4
/// Opaque FFI handle standing in for a heap-allocated `Tensor<f64>`.
///
/// The private zero-sized field makes this type unconstructible from
/// outside, so foreign callers can only ever hold pointers to it, never
/// values — the standard opaque-struct pattern for C interop.
#[repr(C)]
pub struct TfeTensorF64 {
    // Zero-sized, private: prevents construction and gives the type no layout
    // obligations beyond being a distinct pointee type.
    _private: [u8; 0],
}
24
25pub(crate) fn tensor_to_handle(tensor: Tensor<f64>) -> *mut TfeTensorF64 {
27 Box::into_raw(Box::new(tensor)) as *mut TfeTensorF64
28}
29
/// Returns `tensor` unchanged when it is already column-major contiguous;
/// otherwise materializes a column-major copy using the backend's
/// `MakeContiguous` semiring kernel.
///
/// # Errors
///
/// Propagates any `tenferro_device::Error` from allocating the destination
/// tensor or from planning/executing the copy kernel.
pub(crate) fn ensure_col_major(
    ctx: &mut CpuContext,
    tensor: Tensor<f64>,
) -> std::result::Result<Tensor<f64>, tenferro_device::Error> {
    // Fast path: already in the required layout, hand it back untouched.
    if tensor.is_col_major_contiguous() {
        return Ok(tensor);
    }
    // Destination with identical dims and memory space, forced column-major.
    // NOTE(review): the zero-init may be redundant — the kernel below runs
    // with beta = 0.0 and presumably overwrites every element; confirm the
    // backend's contract before switching to an uninitialized allocation.
    let mut result = Tensor::<f64>::zeros(
        tensor.dims(),
        tensor.logical_memory_space(),
        MemoryOrder::ColumnMajor,
    )?;
    let desc = SemiringCoreDescriptor::MakeContiguous;
    // Shapes for source and destination; both are the same dims here.
    let shapes = [tensor.dims(), result.dims()];
    let plan = <CpuBackend as TensorSemiringCore<Standard<f64>>>::plan(ctx, &desc, &shapes)?;
    // Effectively: result = 1.0 * copy(tensor) + 0.0 * result.
    <CpuBackend as TensorSemiringCore<Standard<f64>>>::execute(
        ctx,
        &plan,
        1.0,
        &[&tensor],
        0.0,
        &mut result,
    )?;
    Ok(result)
}
56
57pub(crate) unsafe fn handle_to_ref<'a>(handle: *const TfeTensorF64) -> &'a Tensor<f64> {
63 &*(handle as *const Tensor<f64>)
64}
65
66pub(crate) unsafe fn handle_take(handle: *mut TfeTensorF64) -> Box<Tensor<f64>> {
73 Box::from_raw(handle as *mut Tensor<f64>)
74}