// tenferro_ext_burn/convert.rs
1//! Conversion utilities between Burn tensor primitives and tenferro tensors.
2//!
3//! # Current Limitations
4//!
5//! These conversion functions currently only support `f64` element type.
6//! Support for `f32` and other numeric types will be added in future versions.
7
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::FloatTensor;
10use burn::tensor::{Tensor as BurnTensor, TensorData, TensorPrimitive};
11use tenferro_tensor::{MemoryOrder, Tensor as TfTensor};
12
13use crate::{panic_on_error, Error, Result};
14
/// Compute dense row-major (C-order) strides for the given shape.
///
/// The last axis always strides by 1, and each earlier axis strides by the
/// product of all later dimensions. An empty shape produces an empty stride
/// vector.
fn row_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides: Vec<isize> = Vec::with_capacity(dims.len());
    // Walk the shape from the innermost axis outward, accumulating the
    // running element count as the stride for each axis.
    let mut step = 1isize;
    for &dim in dims.iter().rev() {
        strides.push(step);
        step *= dim as isize;
    }
    strides.reverse();
    strides
}
28
29/// Fallibly convert a Burn backend tensor primitive into a tenferro
30/// `Tensor<f64>`.
31///
32/// Burn tensors are treated as row-major boundary values. The canonical bridge
33/// normalizes them into tenferro's internal column-major tensor layout before
34/// returning.
35///
36/// # Current Limitations
37///
38/// This function currently supports only Burn backends whose float element
39/// type is `f64`. Support for other element types (e.g., `f32`) will be added
40/// in future versions.
41///
42/// # Examples
43///
44/// ```ignore
45/// use burn::backend::NdArray;
46/// use tenferro_ext_burn::convert::try_burn_to_tenferro;
47///
48/// let burn_prim: <NdArray<f64> as burn::tensor::backend::Backend>::FloatTensorPrimitive =
49/// todo!();
50/// let tenferro_t = try_burn_to_tenferro::<NdArray<f64>>(burn_prim).unwrap();
51/// ```
52pub fn try_burn_to_tenferro<B: Backend<FloatElem = f64>>(
53 tensor: FloatTensor<B>,
54) -> Result<TfTensor<f64>> {
55 let data = BurnTensor::<B, 1>::from_primitive(TensorPrimitive::Float(tensor)).into_data();
56 let dims = data.shape.clone();
57 let values = data.into_vec::<f64>().map_err(|_| {
58 Error::InvalidArgument("burn_to_tenferro only supports f64 float tensors".into())
59 })?;
60
61 let tensor =
62 TfTensor::from_vec(values, &dims, &row_major_strides(&dims), 0).map_err(|err| {
63 Error::InvalidArgument(format!("Burn TensorData must be dense row-major: {err}"))
64 })?;
65 Ok(tensor.into_contiguous(MemoryOrder::ColumnMajor))
66}
67
68/// Convert a Burn backend tensor primitive into a tenferro `Tensor<f64>`,
69/// panicking if conversion fails.
70///
71/// # Examples
72///
73/// ```ignore
74/// use burn::backend::NdArray;
75/// use tenferro_ext_burn::convert::burn_to_tenferro;
76///
77/// let burn_prim: <NdArray<f64> as burn::tensor::backend::Backend>::FloatTensorPrimitive =
78/// todo!();
79/// let tenferro_t = burn_to_tenferro::<NdArray<f64>>(burn_prim);
80/// assert_eq!(tenferro_t.dims().len(), 1);
81/// ```
82pub fn burn_to_tenferro<B: Backend<FloatElem = f64>>(tensor: FloatTensor<B>) -> TfTensor<f64> {
83 panic_on_error(try_burn_to_tenferro::<B>(tensor))
84}
85
86/// Convert a tenferro `Tensor<f64>` into a Burn backend tensor primitive.
87///
88/// The `device` parameter specifies which Burn device the resulting tensor
89/// should be placed on. For the `NdArray` backend this is typically
90/// `NdArrayDevice::Cpu`, obtainable via `Default::default()`.
91/// The bridge always materializes a row-major owned buffer at this boundary.
92///
93/// # Current Limitations
94///
95/// This function currently supports only Burn backends whose float element
96/// type is `f64`. Support for other element types will be added in future
97/// versions.
98///
99/// # Examples
100///
101/// ```ignore
102/// use burn::backend::NdArray;
103/// use burn::backend::ndarray::NdArrayDevice;
104/// use tenferro_ext_burn::convert::try_tenferro_to_burn;
105///
106/// let tenferro_t: tenferro_tensor::Tensor<f64> = todo!();
107/// let device = NdArrayDevice::Cpu;
108/// let burn_prim = try_tenferro_to_burn::<NdArray<f64>>(tenferro_t, &device).unwrap();
109/// ```
110pub fn try_tenferro_to_burn<B: Backend<FloatElem = f64>>(
111 tensor: TfTensor<f64>,
112 device: &B::Device,
113) -> Result<FloatTensor<B>> {
114 let row_major = tensor.into_contiguous(MemoryOrder::RowMajor);
115 let dims = row_major.dims().to_vec();
116 let data = row_major
117 .try_into_data_vec()
118 .ok_or(Error::InternalInvariant(
119 "into_contiguous must return a uniquely-owned CPU buffer",
120 ))?;
121
122 Ok(B::float_from_data(TensorData::new(data, dims), device))
123}
124
125/// Convert a tenferro `Tensor<f64>` into a Burn backend tensor primitive,
126/// panicking if conversion fails.
127///
128/// # Examples
129///
130/// ```ignore
131/// use burn::backend::NdArray;
132/// use tenferro_ext_burn::convert::tenferro_to_burn;
133///
134/// let device = Default::default();
135/// let tenferro_t: tenferro_tensor::Tensor<f64> = todo!();
136/// let _burn_prim = tenferro_to_burn::<NdArray<f64>>(tenferro_t, &device);
137/// ```
138pub fn tenferro_to_burn<B: Backend<FloatElem = f64>>(
139 tensor: TfTensor<f64>,
140 device: &B::Device,
141) -> FloatTensor<B> {
142 panic_on_error(try_tenferro_to_burn::<B>(tensor, device))
143}