tenferro_ext_tropical/ad/
convert.rs

1use tenferro_device::{Error, Result};
2use tenferro_tensor::{MemoryOrder, Tensor};
3
4use super::TropicalScalar;
5
6/// Promote a standard-real tensor to a tropical scalar tensor.
7pub fn promote_to_tropical<T: TropicalScalar>(tensor: &Tensor<T::Inner>) -> Result<Tensor<T>> {
8    tensor
9        .buffer()
10        .as_slice()
11        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
12    let contiguous = tensor.contiguous(MemoryOrder::ColumnMajor);
13    let data = contiguous.buffer().as_slice().ok_or_else(|| {
14        Error::DeviceError("tensor materialization produced a non-CPU buffer".into())
15    })?;
16    let tropical_data: Vec<T> = data.iter().map(|&v| T::from_inner(v)).collect();
17    Tensor::<T>::from_slice(&tropical_data, tensor.dims(), MemoryOrder::ColumnMajor)
18        .map_err(|e| Error::InvalidArgument(format!("{e}")))
19}
20
21/// Extract the inner real values from a tropical tensor.
22pub fn extract_inner<T: TropicalScalar>(tensor: &Tensor<T>) -> Result<Tensor<T::Inner>> {
23    tensor
24        .buffer()
25        .as_slice()
26        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
27    let contiguous = tensor.contiguous(MemoryOrder::ColumnMajor);
28    let data = contiguous.buffer().as_slice().ok_or_else(|| {
29        Error::DeviceError("tensor materialization produced a non-CPU buffer".into())
30    })?;
31    let inner_data: Vec<T::Inner> = data.iter().map(|value| value.inner()).collect();
32    Tensor::<T::Inner>::from_slice(&inner_data, tensor.dims(), MemoryOrder::ColumnMajor)
33        .map_err(|e| Error::InvalidArgument(format!("{e}")))
34}