tenferro_ext_tropical/ad/
convert.rs1use tenferro_device::{Error, Result};
2use tenferro_tensor::{MemoryOrder, Tensor};
3
4use super::TropicalScalar;
5
6pub fn promote_to_tropical<T: TropicalScalar>(tensor: &Tensor<T::Inner>) -> Result<Tensor<T>> {
8 tensor
9 .buffer()
10 .as_slice()
11 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
12 let contiguous = tensor.contiguous(MemoryOrder::ColumnMajor);
13 let data = contiguous.buffer().as_slice().ok_or_else(|| {
14 Error::DeviceError("tensor materialization produced a non-CPU buffer".into())
15 })?;
16 let tropical_data: Vec<T> = data.iter().map(|&v| T::from_inner(v)).collect();
17 Tensor::<T>::from_slice(&tropical_data, tensor.dims(), MemoryOrder::ColumnMajor)
18 .map_err(|e| Error::InvalidArgument(format!("{e}")))
19}
20
21pub fn extract_inner<T: TropicalScalar>(tensor: &Tensor<T>) -> Result<Tensor<T::Inner>> {
23 tensor
24 .buffer()
25 .as_slice()
26 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
27 let contiguous = tensor.contiguous(MemoryOrder::ColumnMajor);
28 let data = contiguous.buffer().as_slice().ok_or_else(|| {
29 Error::DeviceError("tensor materialization produced a non-CPU buffer".into())
30 })?;
31 let inner_data: Vec<T::Inner> = data.iter().map(|value| value.inner()).collect();
32 Tensor::<T::Inner>::from_slice(&inner_data, tensor.dims(), MemoryOrder::ColumnMajor)
33 .map_err(|e| Error::InvalidArgument(format!("{e}")))
34}