tenferro_tensor/structured_tensor/
views.rs

1use tenferro_algebra::Conjugate;
2
3use super::{validate_permutation, StructuredTensor};
4
5impl<T: tenferro_algebra::Scalar> StructuredTensor<T> {
6    /// Returns the same logical tensor with permuted logical axes.
7    ///
8    /// This permutes both the logical axes and the compressed payload class
9    /// order, then rebuilds the canonical axis-class representation.
10    ///
11    /// # Examples
12    ///
13    /// ```ignore
14    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
15    ///
16    /// let dense = Tensor::<f64>::from_slice(
17    ///     &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
18    ///     &[2, 3],
19    ///     MemoryOrder::ColumnMajor,
20    /// )
21    /// .unwrap();
22    /// let x = StructuredTensor::from_dense(dense);
23    /// let y = x.permute_logical(&[1, 0]).unwrap();
24    /// assert_eq!(y.logical_dims(), &[3, 2]);
25    /// ```
26    pub fn permute_logical(&self, perm: &[usize]) -> tenferro_device::Result<Self> {
27        validate_permutation(
28            perm,
29            self.logical_dims.len(),
30            "StructuredTensor::permute_logical",
31        )?;
32
33        let permuted_dims: Vec<usize> = perm.iter().map(|&axis| self.logical_dims[axis]).collect();
34        let permuted_raw_classes: Vec<usize> =
35            perm.iter().map(|&axis| self.axis_classes[axis]).collect();
36
37        let mut seen_classes = vec![false; self.class_count()];
38        let mut class_order = Vec::with_capacity(self.class_count());
39        for &class_id in &permuted_raw_classes {
40            if !seen_classes[class_id] {
41                seen_classes[class_id] = true;
42                class_order.push(class_id);
43            }
44        }
45
46        let mut remap = vec![usize::MAX; self.class_count()];
47        for (new_class, &old_class) in class_order.iter().enumerate() {
48            remap[old_class] = new_class;
49        }
50        let canonical_classes: Vec<usize> = permuted_raw_classes
51            .iter()
52            .map(|&old_class| remap[old_class])
53            .collect();
54
55        let payload = self.payload.permute(&class_order)?;
56        Self::new(permuted_dims, canonical_classes, payload)
57    }
58
59    /// Returns the same structured tensor with payload conjugation toggled.
60    ///
61    /// # Examples
62    ///
63    /// ```ignore
64    /// use num_complex::Complex64;
65    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
66    ///
67    /// let payload = Tensor::<Complex64>::from_slice(
68    ///     &[Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
69    ///     &[2],
70    ///     MemoryOrder::ColumnMajor,
71    /// )
72    /// .unwrap();
73    /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
74    /// let y = x.conj();
75    /// assert_eq!(y.logical_dims(), x.logical_dims());
76    /// ```
77    pub fn conj(&self) -> Self
78    where
79        T: Conjugate,
80    {
81        Self {
82            payload: self.payload.conj(),
83            logical_dims: self.logical_dims.clone(),
84            axis_classes: self.axis_classes.clone(),
85        }
86    }
87}