tenferro_tensor/structured_tensor/
validation.rs1use std::collections::HashMap;
2
3use tenferro_algebra::Scalar;
4use tenferro_device::{Error, Result};
5
6use crate::Tensor;
7
/// Relabels axis-class ids so they are numbered `0, 1, 2, ...` in order of first appearance.
///
/// Axes that shared a class id in `classes` still share one in the result, and axes with
/// different ids keep different ids. For example `[5, 3, 5, 7]` becomes `[0, 1, 0, 2]`.
/// An empty slice yields an empty vector.
pub fn canonicalize_axis_classes(classes: &[usize]) -> Vec<usize> {
    // Maps each original id to its new label. Labels are handed out densely, so the
    // number of entries seen so far is always the next unused label.
    let mut first_seen: HashMap<usize, usize> = HashMap::new();
    let mut relabeled = Vec::with_capacity(classes.len());
    for &original in classes {
        let fresh_label = first_seen.len();
        let label = *first_seen.entry(original).or_insert(fresh_label);
        relabeled.push(label);
    }
    relabeled
}
37
38pub fn validate_layout<T: Scalar>(
50 logical_dims: &[usize],
51 axis_classes: &[usize],
52 payload: &Tensor<T>,
53) -> Result<()> {
54 if logical_dims.len() != axis_classes.len() {
55 return Err(Error::InvalidArgument(format!(
56 "logical_dims length ({}) must match axis_classes length ({})",
57 logical_dims.len(),
58 axis_classes.len(),
59 )));
60 }
61 if logical_dims.is_empty() && payload.dims().is_empty() {
62 return Ok(());
63 }
64
65 let class_count = axis_classes
66 .iter()
67 .copied()
68 .max()
69 .map(|value| value + 1)
70 .unwrap_or(0);
71 if payload.dims().len() != class_count {
72 return Err(Error::InvalidArgument(format!(
73 "payload rank {} must equal number of classes {}",
74 payload.dims().len(),
75 class_count,
76 )));
77 }
78
79 let mut class_dims = vec![None; class_count];
80 for (&dim, &class_id) in logical_dims.iter().zip(axis_classes.iter()) {
81 if let Some(existing) = class_dims[class_id] {
82 if existing != dim {
83 return Err(Error::InvalidArgument(format!(
84 "axis class {class_id} has inconsistent logical dims: {existing} vs {dim}",
85 )));
86 }
87 } else {
88 class_dims[class_id] = Some(dim);
89 }
90 }
91
92 for (class_id, maybe_dim) in class_dims.iter().enumerate() {
93 let expected = maybe_dim.unwrap_or(0);
94 let got = payload.dims()[class_id];
95 if expected != got {
96 return Err(Error::InvalidArgument(format!(
97 "payload dim mismatch for class {class_id}: expected {expected}, got {got}",
98 )));
99 }
100 }
101
102 Ok(())
103}
104
105pub(crate) fn validate_permutation(
115 perm: &[usize],
116 rank: usize,
117 op_name: &'static str,
118) -> Result<()> {
119 if perm.len() != rank {
120 return Err(Error::InvalidArgument(format!(
121 "{op_name} requires permutation length {rank}, got {}",
122 perm.len()
123 )));
124 }
125
126 let mut seen = vec![false; rank];
127 for &axis in perm {
128 if axis >= rank {
129 return Err(Error::InvalidArgument(format!(
130 "{op_name} permutation index {axis} out of range for rank {rank}",
131 )));
132 }
133 if seen[axis] {
134 return Err(Error::InvalidArgument(format!(
135 "{op_name} permutation contains duplicate axis {axis}",
136 )));
137 }
138 seen[axis] = true;
139 }
140
141 Ok(())
142}