tenferro_tensor/structured_tensor/views.rs
1use tenferro_algebra::Conjugate;
2
3use super::{validate_permutation, StructuredTensor};
4
5impl<T: tenferro_algebra::Scalar> StructuredTensor<T> {
6 /// Returns the same logical tensor with permuted logical axes.
7 ///
8 /// This permutes both the logical axes and the compressed payload class
9 /// order, then rebuilds the canonical axis-class representation.
10 ///
11 /// # Examples
12 ///
13 /// ```ignore
14 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
15 ///
16 /// let dense = Tensor::<f64>::from_slice(
17 /// &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
18 /// &[2, 3],
19 /// MemoryOrder::ColumnMajor,
20 /// )
21 /// .unwrap();
22 /// let x = StructuredTensor::from_dense(dense);
23 /// let y = x.permute_logical(&[1, 0]).unwrap();
24 /// assert_eq!(y.logical_dims(), &[3, 2]);
25 /// ```
26 pub fn permute_logical(&self, perm: &[usize]) -> tenferro_device::Result<Self> {
27 validate_permutation(
28 perm,
29 self.logical_dims.len(),
30 "StructuredTensor::permute_logical",
31 )?;
32
33 let permuted_dims: Vec<usize> = perm.iter().map(|&axis| self.logical_dims[axis]).collect();
34 let permuted_raw_classes: Vec<usize> =
35 perm.iter().map(|&axis| self.axis_classes[axis]).collect();
36
37 let mut seen_classes = vec![false; self.class_count()];
38 let mut class_order = Vec::with_capacity(self.class_count());
39 for &class_id in &permuted_raw_classes {
40 if !seen_classes[class_id] {
41 seen_classes[class_id] = true;
42 class_order.push(class_id);
43 }
44 }
45
46 let mut remap = vec![usize::MAX; self.class_count()];
47 for (new_class, &old_class) in class_order.iter().enumerate() {
48 remap[old_class] = new_class;
49 }
50 let canonical_classes: Vec<usize> = permuted_raw_classes
51 .iter()
52 .map(|&old_class| remap[old_class])
53 .collect();
54
55 let payload = self.payload.permute(&class_order)?;
56 Self::new(permuted_dims, canonical_classes, payload)
57 }
58
59 /// Returns the same structured tensor with payload conjugation toggled.
60 ///
61 /// # Examples
62 ///
63 /// ```ignore
64 /// use num_complex::Complex64;
65 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
66 ///
67 /// let payload = Tensor::<Complex64>::from_slice(
68 /// &[Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
69 /// &[2],
70 /// MemoryOrder::ColumnMajor,
71 /// )
72 /// .unwrap();
73 /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
74 /// let y = x.conj();
75 /// assert_eq!(y.logical_dims(), x.logical_dims());
76 /// ```
77 pub fn conj(&self) -> Self
78 where
79 T: Conjugate,
80 {
81 Self {
82 payload: self.payload.conj(),
83 logical_dims: self.logical_dims.clone(),
84 axis_classes: self.axis_classes.clone(),
85 }
86 }
87}