tenferro_tensor/structured_tensor/conversion.rs

1use tenferro_algebra::Scalar;
2use tenferro_device::{checked_batch_count, unflatten_col_major_index_into, Error, Result};
3
4use super::StructuredTensor;
5use crate::{MemoryOrder, Tensor};
6
7impl<T: Scalar> StructuredTensor<T> {
8    /// Materialize this structured tensor into a dense tensor.
9    ///
10    /// # Examples
11    ///
12    /// ```
13    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
14    ///
15    /// let payload =
16    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
17    /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
18    /// let dense = x.to_dense().unwrap();
19    /// assert_eq!(dense.dims(), &[2, 2]);
20    /// assert_eq!(dense.get(&[0, 0]), Some(&1.0));
21    /// assert_eq!(dense.get(&[0, 1]), Some(&0.0));
22    /// ```
23    pub fn to_dense(&self) -> Result<Tensor<T>> {
24        if self.is_dense() {
25            return Ok(self.payload.clone());
26        }
27
28        let mut dense = Tensor::zeros(
29            &self.logical_dims,
30            self.payload.logical_memory_space(),
31            MemoryOrder::ColumnMajor,
32        )?;
33
34        let payload_dims = self.payload.dims();
35        let payload_numel = checked_batch_count(payload_dims)?;
36        let mut payload_index = vec![0usize; payload_dims.len()];
37        let mut logical_index = vec![0usize; self.logical_dims.len()];
38
39        for linear in 0..payload_numel {
40            if !payload_dims.is_empty() {
41                unflatten_col_major_index_into(linear, payload_dims, &mut payload_index)?;
42            }
43
44            for (logical_axis, &class_id) in self.axis_classes.iter().enumerate() {
45                logical_index[logical_axis] = payload_index[class_id];
46            }
47
48            let value = self.payload.get(&payload_index).copied().ok_or_else(|| {
49                Error::InvalidArgument(format!(
50                    "to_dense: could not read payload at index {payload_index:?}"
51                ))
52            })?;
53            dense.set(&logical_index, value)?;
54        }
55
56        Ok(dense)
57    }
58
59    /// Consume this structured tensor and return a dense tensor.
60    ///
61    /// Returns the payload directly when the layout is already dense.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
67    ///
68    /// let payload =
69    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
70    /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
71    /// let dense = x.into_dense().unwrap();
72    /// assert_eq!(dense.dims(), &[2, 2]);
73    /// ```
74    pub fn into_dense(self) -> Result<Tensor<T>> {
75        if self.is_dense() {
76            Ok(self.payload)
77        } else {
78            self.to_dense()
79        }
80    }
81}