tenferro_tensor/structured_tensor/mod.rs
1//! Structured tensor metadata layered on top of dense [`Tensor`] payloads.
2
3mod conversion;
4mod validation;
5mod views;
6
7use tenferro_algebra::Scalar;
8use tenferro_device::{Error, Result};
9
10use crate::Tensor;
11
12pub(crate) use validation::validate_permutation;
13pub use validation::{canonicalize_axis_classes, validate_layout};
14
15/// Structured tensor payload with logical axis metadata.
16///
17/// This stores logical tensor metadata separately from the compressed payload.
18/// Dense and diagonal tensors are representation cases of the same type.
19///
20/// # Examples
21///
22/// ```ignore
23/// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
24///
25/// let payload =
26/// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
27/// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
28/// assert_eq!(x.logical_dims(), &[2, 2]);
29/// assert!(x.is_diag());
30/// ```
31#[derive(Debug, Clone)]
32pub struct StructuredTensor<T: Scalar> {
33 payload: Tensor<T>,
34 logical_dims: Vec<usize>,
35 axis_classes: Vec<usize>,
36}
37
38impl<T: Scalar> StructuredTensor<T> {
39 /// Construct a structured tensor from logical metadata and compressed payload.
40 ///
41 /// Axis classes are canonicalized to first-appearance order.
42 ///
43 /// # Examples
44 ///
45 /// ```ignore
46 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
47 ///
48 /// let payload =
49 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
50 /// let x = StructuredTensor::new(vec![2, 2], vec![9, 9], payload).unwrap();
51 /// assert_eq!(x.axis_classes(), &[0, 0]);
52 /// ```
53 pub fn new(
54 logical_dims: Vec<usize>,
55 axis_classes: Vec<usize>,
56 payload: Tensor<T>,
57 ) -> Result<Self> {
58 let axis_classes = canonicalize_axis_classes(&axis_classes);
59 validate_layout(&logical_dims, &axis_classes, &payload)?;
60 Ok(Self {
61 payload,
62 logical_dims,
63 axis_classes,
64 })
65 }
66
67 /// Construct a dense structured tensor.
68 ///
69 /// # Examples
70 ///
71 /// ```ignore
72 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
73 ///
74 /// let dense = Tensor::<f64>::from_slice(
75 /// &[1.0, 2.0, 3.0, 4.0],
76 /// &[2, 2],
77 /// MemoryOrder::ColumnMajor,
78 /// )
79 /// .unwrap();
80 /// let x = StructuredTensor::from_dense(dense);
81 /// assert!(x.is_dense());
82 /// ```
83 pub fn from_dense(payload: Tensor<T>) -> Self {
84 let logical_dims = payload.dims().to_vec();
85 let axis_classes = (0..logical_dims.len()).collect();
86 Self {
87 payload,
88 logical_dims,
89 axis_classes,
90 }
91 }
92
93 /// Construct a diagonal structured tensor from a rank-1 payload.
94 ///
95 /// # Examples
96 ///
97 /// ```ignore
98 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
99 ///
100 /// let payload =
101 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
102 /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
103 /// assert!(x.is_diag());
104 /// ```
105 pub fn from_diagonal_vector(payload: Tensor<T>, logical_rank: usize) -> Result<Self> {
106 if payload.dims().len() != 1 {
107 return Err(Error::InvalidArgument(format!(
108 "from_diagonal_vector expects rank-1 payload, got rank {}",
109 payload.dims().len()
110 )));
111 }
112 if logical_rank == 0 {
113 return Err(Error::InvalidArgument(
114 "from_diagonal_vector requires logical_rank >= 1".to_string(),
115 ));
116 }
117 let n = payload.dims()[0];
118 Self::new(vec![n; logical_rank], vec![0; logical_rank], payload)
119 }
120
121 /// Borrow the compressed payload tensor.
122 ///
123 /// # Examples
124 ///
125 /// ```ignore
126 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
127 ///
128 /// let dense =
129 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
130 /// let x = StructuredTensor::from_dense(dense);
131 /// assert_eq!(x.payload().dims(), &[2]);
132 /// ```
133 pub fn payload(&self) -> &Tensor<T> {
134 &self.payload
135 }
136
137 /// Consume the structured tensor and return the compressed payload tensor.
138 ///
139 /// # Examples
140 ///
141 /// ```ignore
142 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
143 ///
144 /// let dense =
145 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
146 /// let x = StructuredTensor::from_dense(dense);
147 /// let payload = x.into_payload();
148 /// assert_eq!(payload.dims(), &[2]);
149 /// ```
150 pub fn into_payload(self) -> Tensor<T> {
151 self.payload
152 }
153
154 /// Returns logical dimensions.
155 ///
156 /// # Examples
157 ///
158 /// ```ignore
159 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
160 ///
161 /// let payload =
162 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
163 /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
164 /// assert_eq!(x.logical_dims(), &[2, 2]);
165 /// ```
166 pub fn logical_dims(&self) -> &[usize] {
167 &self.logical_dims
168 }
169
170 /// Returns axis classes for logical axes.
171 ///
172 /// # Examples
173 ///
174 /// ```ignore
175 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
176 ///
177 /// let payload =
178 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
179 /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
180 /// assert_eq!(x.axis_classes(), &[0, 0]);
181 /// ```
182 pub fn axis_classes(&self) -> &[usize] {
183 &self.axis_classes
184 }
185
186 /// Returns the number of distinct axis classes.
187 ///
188 /// # Examples
189 ///
190 /// ```ignore
191 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
192 ///
193 /// let payload =
194 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
195 /// let x = StructuredTensor::from_diagonal_vector(payload, 3).unwrap();
196 /// assert_eq!(x.class_count(), 1);
197 /// ```
198 pub fn class_count(&self) -> usize {
199 self.payload.dims().len()
200 }
201
202 /// Returns `true` when the layout is dense.
203 ///
204 /// # Examples
205 ///
206 /// ```ignore
207 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
208 ///
209 /// let dense = Tensor::<f64>::from_slice(
210 /// &[1.0, 2.0, 3.0, 4.0],
211 /// &[2, 2],
212 /// MemoryOrder::ColumnMajor,
213 /// )
214 /// .unwrap();
215 /// let x = StructuredTensor::from_dense(dense);
216 /// assert!(x.is_dense());
217 /// ```
218 pub fn is_dense(&self) -> bool {
219 self.axis_classes.len() == self.logical_dims.len()
220 && self.logical_dims == self.payload.dims()
221 && self
222 .axis_classes
223 .iter()
224 .enumerate()
225 .all(|(i, &class_id)| class_id == i)
226 }
227
228 /// Returns `true` when the layout is a pure diagonal.
229 ///
230 /// # Examples
231 ///
232 /// ```ignore
233 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
234 ///
235 /// let payload =
236 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
237 /// let x = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
238 /// assert!(x.is_diag());
239 /// ```
240 pub fn is_diag(&self) -> bool {
241 if self.logical_dims.is_empty() || self.axis_classes.len() != self.logical_dims.len() {
242 return false;
243 }
244 let first_dim = self.logical_dims[0];
245 self.axis_classes.iter().all(|&class_id| class_id == 0)
246 && self.logical_dims.iter().all(|&dim| dim == first_dim)
247 && self.payload.dims().len() == 1
248 && self.payload.dims()[0] == first_dim
249 }
250
251 /// Rebuild the same structured layout with a different payload tensor.
252 ///
253 /// # Examples
254 ///
255 /// ```ignore
256 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
257 ///
258 /// let payload =
259 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
260 /// let layout = StructuredTensor::from_diagonal_vector(payload, 2).unwrap();
261 /// let replacement =
262 /// Tensor::<f64>::from_slice(&[3.0, 4.0], &[2], MemoryOrder::ColumnMajor).unwrap();
263 /// let updated = layout.with_payload_like(replacement).unwrap();
264 /// assert!(updated.is_diag());
265 /// ```
266 pub fn with_payload_like(&self, payload: Tensor<T>) -> Result<Self> {
267 Self::new(
268 self.logical_dims.clone(),
269 self.axis_classes.clone(),
270 payload,
271 )
272 }
273
274 /// Construct a structured tensor without re-validating the metadata.
275 ///
276 /// # Safety
277 ///
278 /// Callers must ensure that `logical_dims`, `axis_classes`, and `payload`
279 /// already satisfy the invariants enforced by [`StructuredTensor::new`].
280 ///
281 /// # Examples
282 ///
283 /// ```ignore
284 /// use tenferro_tensor::{MemoryOrder, StructuredTensor, Tensor};
285 ///
286 /// let payload =
287 /// Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
288 /// let x = StructuredTensor::from_validated_parts(vec![2, 2], vec![0, 0], payload);
289 /// assert!(x.is_diag());
290 /// ```
291 pub fn from_validated_parts(
292 logical_dims: Vec<usize>,
293 axis_classes: Vec<usize>,
294 payload: Tensor<T>,
295 ) -> Self {
296 Self {
297 payload,
298 logical_dims,
299 axis_classes,
300 }
301 }
302}
303
304impl<T: Scalar> From<Tensor<T>> for StructuredTensor<T> {
305 fn from(value: Tensor<T>) -> Self {
306 Self::from_dense(value)
307 }
308}
309
310impl<T: Scalar> AsRef<Tensor<T>> for StructuredTensor<T> {
311 fn as_ref(&self) -> &Tensor<T> {
312 self.payload()
313 }
314}
315
316#[cfg(test)]
317mod tests;