tenferro_tensor/structured_tensor/conversion.rs
1use tenferro_algebra::Scalar;
2use tenferro_device::{checked_batch_count, unflatten_col_major_index_into, Error, Result};
3
4use super::StructuredTensor;
5use crate::{MemoryOrder, Tensor};
6
7impl<T: Scalar> StructuredTensor<T> {
8 /// Materialize this structured tensor into a dense tensor.
9 ///
10 /// # Examples
11 ///
12 /// ```
13 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
14 ///
15 /// let payload =
16 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
17 /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
18 /// let dense = x.to_dense().unwrap();
19 /// assert_eq!(dense.dims(), &[2, 2]);
20 /// assert_eq!(dense.get(&[0, 0]), Some(&1.0));
21 /// assert_eq!(dense.get(&[0, 1]), Some(&0.0));
22 /// ```
23 pub fn to_dense(&self) -> Result<Tensor<T>> {
24 if self.is_dense() {
25 return Ok(self.payload.clone());
26 }
27
28 let mut dense = Tensor::zeros(
29 &self.logical_dims,
30 self.payload.logical_memory_space(),
31 MemoryOrder::ColumnMajor,
32 )?;
33
34 let payload_dims = self.payload.dims();
35 let payload_numel = checked_batch_count(payload_dims)?;
36 let mut payload_index = vec![0usize; payload_dims.len()];
37 let mut logical_index = vec![0usize; self.logical_dims.len()];
38
39 for linear in 0..payload_numel {
40 if !payload_dims.is_empty() {
41 unflatten_col_major_index_into(linear, payload_dims, &mut payload_index)?;
42 }
43
44 for (logical_axis, &class_id) in self.axis_classes.iter().enumerate() {
45 logical_index[logical_axis] = payload_index[class_id];
46 }
47
48 let value = self.payload.get(&payload_index).copied().ok_or_else(|| {
49 Error::InvalidArgument(format!(
50 "to_dense: could not read payload at index {payload_index:?}"
51 ))
52 })?;
53 dense.set(&logical_index, value)?;
54 }
55
56 Ok(dense)
57 }
58
59 /// Consume this structured tensor and return a dense tensor.
60 ///
61 /// Returns the payload directly when the layout is already dense.
62 ///
63 /// # Examples
64 ///
65 /// ```
66 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
67 ///
68 /// let payload =
69 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
70 /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
71 /// let dense = x.into_dense().unwrap();
72 /// assert_eq!(dense.dims(), &[2, 2]);
73 /// ```
74 pub fn into_dense(self) -> Result<Tensor<T>> {
75 if self.is_dense() {
76 Ok(self.payload)
77 } else {
78 self.to_dense()
79 }
80 }
81}