tenferro_ext_mdarray/
lib.rs

1//! Bridge between [mdarray](https://docs.rs/mdarray) multidimensional arrays
2//! and [tenferro](https://docs.rs/tenferro-tensor) tensors.
3//!
4//! This crate provides conversion functions between mdarray's
5//! `Array<T, DynRank>` and tenferro's `Tensor<T>`, enabling convenient data
6//! exchange between the two ecosystems.
7//!
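//! Rust's orphan rules forbid implementing [`From`]/[`Into`] between two types
//! that are both defined in other crates, so this crate instead provides the
//! standalone conversion functions [`mdarray_to_tensor`] and
//! [`tensor_to_mdarray`], along with their fallible counterparts
//! [`try_mdarray_to_tensor`] and [`try_tensor_to_mdarray`].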
11//!
//! **Zero-copy is a non-goal.** Both conversion directions copy element data.
//! The purpose of this crate is ergonomic interoperability, not
//! performance-critical data sharing.
15//!
16//! # Examples
17//!
18//! ```ignore
19//! use mdarray::{Array, DynRank};
20//! use tenferro_device::LogicalMemorySpace;
21//! use tenferro_tensor::{MemoryOrder, Tensor};
22//! use tenferro_ext_mdarray::{mdarray_to_tensor, tensor_to_mdarray};
23//!
24//! // mdarray -> tenferro
25//! let md: Array<f64, DynRank> = mdarray::tensor![1.0, 2.0, 3.0, 4.0].into_dyn();
26//! let t: Tensor<f64> = mdarray_to_tensor(md);
27//!
28//! // tenferro -> mdarray
29//! let t2: Tensor<f64> = Tensor::zeros(
30//!     &[2, 3],
31//!     LogicalMemorySpace::MainMemory,
32//!     MemoryOrder::ColumnMajor,
33//! );
34//! let md2: Array<f64, DynRank> = tensor_to_mdarray(t2);
35//! ```
36
37use mdarray::{Array, DynRank};
38use tenferro_algebra::Scalar;
39use tenferro_device::{Error, Result};
40use tenferro_tensor::{MemoryOrder, Tensor};
41
42#[cfg(test)]
43mod tests;
44
/// Computes row-major (C-order) element strides for an array with extents
/// `dims`.
///
/// The last axis gets stride 1, and each earlier axis's stride is the product
/// of all extents that follow it. A rank-0 shape yields an empty vector.
fn row_major_strides(dims: &[usize]) -> Vec<isize> {
    let Some((_, trailing)) = dims.split_first() else {
        return Vec::new();
    };

    // Build the strides innermost-first, then flip them into axis order.
    let mut out = Vec::with_capacity(dims.len());
    let mut stride: isize = 1;
    out.push(stride);
    for &extent in trailing.iter().rev() {
        stride *= extent as isize;
        out.push(stride);
    }
    out.reverse();
    out
}
58
59/// Fallibly converts an mdarray `Array<T, DynRank>` into a tenferro `Tensor<T>`.
60///
61/// This conversion copies all element data from the mdarray array into a
62/// newly allocated tenferro tensor on the CPU. The shape is preserved.
63///
64/// # Examples
65///
66/// ```ignore
67/// use mdarray::{Array, DynRank};
68/// use tenferro_tensor::Tensor;
69/// use tenferro_ext_mdarray::try_mdarray_to_tensor;
70///
71/// let md: Array<f64, DynRank> = mdarray::tensor![1.0, 2.0, 3.0].into_dyn();
72/// let t: Tensor<f64> = try_mdarray_to_tensor(md).unwrap();
73/// ```
74pub fn try_mdarray_to_tensor<T: Scalar>(array: Array<T, DynRank>) -> Result<Tensor<T>> {
75    let dims = array.dims().to_vec();
76    let strides = row_major_strides(&dims);
77    Tensor::from_vec(array.into_vec(), &dims, &strides, 0)
78}
79
80/// Converts an mdarray `Array<T, DynRank>` into a tenferro `Tensor<T>`,
81/// panicking if conversion fails.
82///
83/// # Examples
84///
85/// ```ignore
86/// use mdarray::{Array, DynRank};
87/// use tenferro_ext_mdarray::mdarray_to_tensor;
88///
89/// let md: Array<f64, DynRank> = mdarray::tensor![1.0, 2.0, 3.0].into_dyn();
90/// let t = mdarray_to_tensor(md);
91/// assert_eq!(t.dims(), &[3]);
92/// ```
93pub fn mdarray_to_tensor<T: Scalar>(array: Array<T, DynRank>) -> Tensor<T> {
94    try_mdarray_to_tensor(array).unwrap_or_else(|err| panic!("{err}"))
95}
96
97/// Fallibly converts a tenferro `Tensor<T>` into an mdarray `Array<T, DynRank>`.
98///
99/// This conversion copies all element data from the tenferro tensor into a
100/// newly allocated mdarray array. The shape is preserved.
101///
102/// # Examples
103///
104/// ```ignore
105/// use mdarray::{Array, DynRank};
106/// use tenferro_tensor::Tensor;
107/// use tenferro_ext_mdarray::try_tensor_to_mdarray;
108///
109/// use tenferro_device::LogicalMemorySpace;
110/// use tenferro_tensor::MemoryOrder;
111///
112/// let t: Tensor<f64> = Tensor::zeros(
113///     &[3, 4],
114///     LogicalMemorySpace::MainMemory,
115///     MemoryOrder::ColumnMajor,
116/// );
117/// let md: Array<f64, DynRank> = try_tensor_to_mdarray(t).unwrap();
118/// ```
119pub fn try_tensor_to_mdarray<T: Scalar>(tensor: Tensor<T>) -> Result<Array<T, DynRank>> {
120    let row_major = tensor.into_contiguous(MemoryOrder::RowMajor);
121    let dims = row_major.dims().to_vec();
122    let data = row_major.try_into_data_vec().ok_or(Error::InvalidArgument(
123        "into_contiguous must return an owned CPU buffer".into(),
124    ))?;
125    Ok(Array::from(data).into_shape(dims.as_slice()).into_dyn())
126}
127
128/// Converts a tenferro `Tensor<T>` into an mdarray `Array<T, DynRank>`,
129/// panicking if conversion fails.
130///
131/// # Examples
132///
133/// ```ignore
134/// use tenferro_device::LogicalMemorySpace;
135/// use tenferro_ext_mdarray::tensor_to_mdarray;
136/// use tenferro_tensor::{MemoryOrder, Tensor};
137///
138/// let t: Tensor<f64> = Tensor::zeros(
139///     &[3, 4],
140///     LogicalMemorySpace::MainMemory,
141///     MemoryOrder::ColumnMajor,
142/// );
143/// let md = tensor_to_mdarray(t);
144/// assert_eq!(md.shape().ndim(), 2);
145/// ```
146pub fn tensor_to_mdarray<T: Scalar>(tensor: Tensor<T>) -> Array<T, DynRank> {
147    try_tensor_to_mdarray(tensor).unwrap_or_else(|err| panic!("{err}"))
148}