tenferro_tensor/structured_tensor/
mod.rs

1//! Structured tensor metadata layered on top of dense [`Tensor`] payloads.
2
3mod conversion;
4mod validation;
5mod views;
6
7use tenferro_algebra::Scalar;
8use tenferro_device::{Error, Result};
9
10use crate::Tensor;
11
12pub(crate) use validation::validate_permutation;
13pub use validation::{canonicalize_axis_classes, validate_layout};
14
15/// Structured tensor payload with logical axis metadata.
16///
17/// This stores logical tensor metadata separately from the compressed payload.
18/// Dense and diagonal tensors are representation cases of the same type.
19///
20/// # Examples
21///
22/// ```ignore
23/// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
24///
25/// let payload =
26///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
27/// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
28/// assert_eq!(x.logical_dims(), &[2, 2]);
29/// assert!(x.is_diag());
30/// ```
31#[derive(Debug, Clone)]
32pub struct StructuredTensor<T: Scalar> {
33    payload: Tensor<T>,
34    logical_dims: Vec<usize>,
35    axis_classes: Vec<usize>,
36}
37
38impl<T: Scalar> StructuredTensor<T> {
39    /// Construct a structured tensor from logical metadata and compressed payload.
40    ///
41    /// Axis classes are canonicalized to first-appearance order.
42    ///
43    /// # Examples
44    ///
45    /// ```ignore
46    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
47    ///
48    /// let payload =
49    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
50    /// let x = StructuredTensor::new(vec![2, 2], vec![9, 9], payload).unwrap();
51    /// assert_eq!(x.axis_classes(), &[0, 0]);
52    /// ```
53    pub fn new(
54        logical_dims: Vec<usize>,
55        axis_classes: Vec<usize>,
56        payload: Tensor<T>,
57    ) -> Result<Self> {
58        let axis_classes = canonicalize_axis_classes(&axis_classes);
59        validate_layout(&logical_dims, &axis_classes, &payload)?;
60        Ok(Self {
61            payload,
62            logical_dims,
63            axis_classes,
64        })
65    }
66
67    /// Construct a dense structured tensor.
68    ///
69    /// # Examples
70    ///
71    /// ```ignore
72    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
73    ///
74    /// let dense = Tensor::<f64>::from_slice(
75    ///     &[1.0, 2.0, 3.0, 4.0],
76    ///     &[2, 2],
77    ///     MemoryOrder::ColumnMajor,
78    /// )
79    /// .unwrap();
80    /// let x = StructuredTensor::from_dense(dense);
81    /// assert!(x.is_dense());
82    /// ```
83    pub fn from_dense(payload: Tensor<T>) -> Self {
84        let logical_dims = payload.dims().to_vec();
85        let axis_classes = (0..logical_dims.len()).collect();
86        Self {
87            payload,
88            logical_dims,
89            axis_classes,
90        }
91    }
92
93    /// Construct a diagonal structured tensor from a rank-1 payload.
94    ///
95    /// # Examples
96    ///
97    /// ```ignore
98    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
99    ///
100    /// let payload =
101    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
102    /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
103    /// assert!(x.is_diag());
104    /// ```
105    pub fn from_diagonal_vector(payload: Tensor<T>, logical_rank: usize) -> Result<Self> {
106        if payload.dims().len() != 1 {
107            return Err(Error::InvalidArgument(format!(
108                "from_diagonal_vector expects rank-1 payload, got rank {}",
109                payload.dims().len()
110            )));
111        }
112        if logical_rank == 0 {
113            return Err(Error::InvalidArgument(
114                "from_diagonal_vector requires logical_rank >= 1".to_string(),
115            ));
116        }
117        let n = payload.dims()[0];
118        Self::new(vec![n; logical_rank], vec![0; logical_rank], payload)
119    }
120
121    /// Borrow the compressed payload tensor.
122    ///
123    /// # Examples
124    ///
125    /// ```ignore
126    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
127    ///
128    /// let dense =
129    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
130    /// let x = StructuredTensor::from_dense(dense);
131    /// assert_eq!(x.payload().dims(), &[2]);
132    /// ```
133    pub fn payload(&self) -> &Tensor<T> {
134        &self.payload
135    }
136
137    /// Consume the structured tensor and return the compressed payload tensor.
138    ///
139    /// # Examples
140    ///
141    /// ```ignore
142    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
143    ///
144    /// let dense =
145    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
146    /// let x = StructuredTensor::from_dense(dense);
147    /// let payload = x.into_payload();
148    /// assert_eq!(payload.dims(), &[2]);
149    /// ```
150    pub fn into_payload(self) -> Tensor<T> {
151        self.payload
152    }
153
154    /// Returns logical dimensions.
155    ///
156    /// # Examples
157    ///
158    /// ```ignore
159    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
160    ///
161    /// let payload =
162    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
163    /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
164    /// assert_eq!(x.logical_dims(), &[2, 2]);
165    /// ```
166    pub fn logical_dims(&self) -> &[usize] {
167        &self.logical_dims
168    }
169
170    /// Returns axis classes for logical axes.
171    ///
172    /// # Examples
173    ///
174    /// ```ignore
175    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
176    ///
177    /// let payload =
178    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
179    /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
180    /// assert_eq!(x.axis_classes(), &[0, 0]);
181    /// ```
182    pub fn axis_classes(&self) -> &[usize] {
183        &self.axis_classes
184    }
185
186    /// Returns the number of distinct axis classes.
187    ///
188    /// # Examples
189    ///
190    /// ```ignore
191    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
192    ///
193    /// let payload =
194    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
195    /// let x = StructuredTensor::from_diagonal_vector(payload, 3).unwrap();
196    /// assert_eq!(x.class_count(), 1);
197    /// ```
198    pub fn class_count(&self) -> usize {
199        self.payload.dims().len()
200    }
201
202    /// Returns `true` when the layout is dense.
203    ///
204    /// # Examples
205    ///
206    /// ```ignore
207    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
208    ///
209    /// let dense = Tensor::<f64>::from_slice(
210    ///     &[1.0, 2.0, 3.0, 4.0],
211    ///     &[2, 2],
212    ///     MemoryOrder::ColumnMajor,
213    /// )
214    /// .unwrap();
215    /// let x = StructuredTensor::from_dense(dense);
216    /// assert!(x.is_dense());
217    /// ```
218    pub fn is_dense(&self) -> bool {
219        self.axis_classes.len() == self.logical_dims.len()
220            && self.logical_dims == self.payload.dims()
221            && self
222                .axis_classes
223                .iter()
224                .enumerate()
225                .all(|(i, &class_id)| class_id == i)
226    }
227
228    /// Returns `true` when the layout is a pure diagonal.
229    ///
230    /// # Examples
231    ///
232    /// ```ignore
233    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
234    ///
235    /// let payload =
236    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
237    /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
238    /// assert!(x.is_diag());
239    /// ```
240    pub fn is_diag(&self) -> bool {
241        if self.logical_dims.is_empty() || self.axis_classes.len() != self.logical_dims.len() {
242            return false;
243        }
244        let first_dim = self.logical_dims[0];
245        self.axis_classes.iter().all(|&class_id| class_id == 0)
246            && self.logical_dims.iter().all(|&dim| dim == first_dim)
247            && self.payload.dims().len() == 1
248            && self.payload.dims()[0] == first_dim
249    }
250
251    /// Rebuild the same structured layout with a different payload tensor.
252    ///
253    /// # Examples
254    ///
255    /// ```ignore
256    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
257    ///
258    /// let payload =
259    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
260    /// let layout = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
261    /// let replacement =
262    ///     Tensor::<f64>::from_slice(&[3.0, 4.0], &[2], MemoryOrder::ColumnMajor).unwrap();
263    /// let updated = layout.with_payload_like(replacement).unwrap();
264    /// assert!(updated.is_diag());
265    /// ```
266    pub fn with_payload_like(&self, payload: Tensor<T>) -> Result<Self> {
267        Self::new(
268            self.logical_dims.clone(),
269            self.axis_classes.clone(),
270            payload,
271        )
272    }
273
274    /// Construct a structured tensor without re-validating the metadata.
275    ///
276    /// # Safety
277    ///
278    /// Callers must ensure that `logical_dims`, `axis_classes`, and `payload`
279    /// already satisfy the invariants enforced by [`StructuredTensor::new`].
280    ///
281    /// # Examples
282    ///
283    /// ```ignore
284    /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
285    ///
286    /// let payload =
287    ///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
288    /// let x = StructuredTensor::from_validated_parts(vec![2, 2], vec![0, 0], payload);
289    /// assert!(x.is_diag());
290    /// ```
291    pub fn from_validated_parts(
292        logical_dims: Vec<usize>,
293        axis_classes: Vec<usize>,
294        payload: Tensor<T>,
295    ) -> Self {
296        Self {
297            payload,
298            logical_dims,
299            axis_classes,
300        }
301    }
302}
303
304impl<T: Scalar> From<Tensor<T>> for StructuredTensor<T> {
305    fn from(value: Tensor<T>) -> Self {
306        Self::from_dense(value)
307    }
308}
309
310impl<T: Scalar> AsRef<Tensor<T>> for StructuredTensor<T> {
311    fn as_ref(&self) -> &Tensor<T> {
312        self.payload()
313    }
314}
315
316#[cfg(test)]
317mod tests;