tenferro_tensor/structured_tensor/validation.rs

1use std::collections::HashMap;
2
3use tenferro_algebra::Scalar;
4use tenferro_device::{Error, Result};
5
6use crate::Tensor;
7
/// Relabel axis class IDs so classes are numbered `0, 1, 2, ...` in the order
/// in which they first appear in `classes`.
///
/// Axes that share an ID in the input still share an ID in the output; only
/// the concrete ID values change. The output has the same length as `classes`.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::structured_tensor::canonicalize_axis_classes;
///
/// assert_eq!(
///     canonicalize_axis_classes(&[4, 9, 4, 7, 9]),
///     vec![0, 1, 0, 2, 1],
/// );
/// ```
pub fn canonicalize_axis_classes(classes: &[usize]) -> Vec<usize> {
    let mut relabel: HashMap<usize, usize> = HashMap::new();
    let mut canonical = Vec::with_capacity(classes.len());
    for &original_id in classes {
        // The number of IDs seen so far is exactly the next unused canonical ID.
        let fresh_id = relabel.len();
        canonical.push(*relabel.entry(original_id).or_insert(fresh_id));
    }
    canonical
}
37
38/// Validate structured-tensor metadata against a compressed payload.
39///
40/// # Examples
41///
42/// ```ignore
43/// use tenferro_tensor::{structured_tensor::validate_layout, MemoryOrder, Tensor};
44///
45/// let payload =
46///     Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
47/// validate_layout(&[2, 2], &[0, 0], &payload).unwrap();
48/// ```
49pub fn validate_layout<T: Scalar>(
50    logical_dims: &[usize],
51    axis_classes: &[usize],
52    payload: &Tensor<T>,
53) -> Result<()> {
54    if logical_dims.len() != axis_classes.len() {
55        return Err(Error::InvalidArgument(format!(
56            "logical_dims length ({}) must match axis_classes length ({})",
57            logical_dims.len(),
58            axis_classes.len(),
59        )));
60    }
61    if logical_dims.is_empty() && payload.dims().is_empty() {
62        return Ok(());
63    }
64
65    let class_count = axis_classes
66        .iter()
67        .copied()
68        .max()
69        .map(|value| value + 1)
70        .unwrap_or(0);
71    if payload.dims().len() != class_count {
72        return Err(Error::InvalidArgument(format!(
73            "payload rank {} must equal number of classes {}",
74            payload.dims().len(),
75            class_count,
76        )));
77    }
78
79    let mut class_dims = vec![None; class_count];
80    for (&dim, &class_id) in logical_dims.iter().zip(axis_classes.iter()) {
81        if let Some(existing) = class_dims[class_id] {
82            if existing != dim {
83                return Err(Error::InvalidArgument(format!(
84                    "axis class {class_id} has inconsistent logical dims: {existing} vs {dim}",
85                )));
86            }
87        } else {
88            class_dims[class_id] = Some(dim);
89        }
90    }
91
92    for (class_id, maybe_dim) in class_dims.iter().enumerate() {
93        let expected = maybe_dim.unwrap_or(0);
94        let got = payload.dims()[class_id];
95        if expected != got {
96            return Err(Error::InvalidArgument(format!(
97                "payload dim mismatch for class {class_id}: expected {expected}, got {got}",
98            )));
99        }
100    }
101
102    Ok(())
103}
104
105/// Validate that `perm` is a complete permutation of `0..rank`.
106///
107/// # Examples
108///
109/// ```ignore
110/// use tenferro_tensor::structured_tensor::validate_permutation;
111///
112/// validate_permutation(&[1, 0], 2, "example").unwrap();
113/// ```
114pub(crate) fn validate_permutation(
115    perm: &[usize],
116    rank: usize,
117    op_name: &'static str,
118) -> Result<()> {
119    if perm.len() != rank {
120        return Err(Error::InvalidArgument(format!(
121            "{op_name} requires permutation length {rank}, got {}",
122            perm.len()
123        )));
124    }
125
126    let mut seen = vec![false; rank];
127    for &axis in perm {
128        if axis >= rank {
129            return Err(Error::InvalidArgument(format!(
130                "{op_name} permutation index {axis} out of range for rank {rank}",
131            )));
132        }
133        if seen[axis] {
134            return Err(Error::InvalidArgument(format!(
135                "{op_name} permutation contains duplicate axis {axis}",
136            )));
137        }
138        seen[axis] = true;
139    }
140
141    Ok(())
142}