tenferro_tensor/config.rs
/// DotGeneral dimension configuration.
///
/// Describes, by axis index, which dimensions of the lhs and rhs operands are
/// contracted and which are batch dimensions, together with the rank each
/// operand is expected to have.
///
/// The output shape is `[lhs_free..., rhs_free..., batch...]` (col-major
/// batch-trailing convention). Batch dims have the largest stride so that
/// each batch slice occupies a contiguous block of memory.
///
/// All fields are public and building a value performs no checking; use
/// [`validate_dims`](Self::validate_dims) and
/// [`validate_ranks`](Self::validate_ranks) to verify a config.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::DotGeneralConfig;
///
/// let config = DotGeneralConfig {
///     lhs_contracting_dims: vec![1],
///     rhs_contracting_dims: vec![0],
///     lhs_batch_dims: vec![],
///     rhs_batch_dims: vec![],
///     lhs_rank: 2,
///     rhs_rank: 2,
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct DotGeneralConfig {
    /// Axis indices of the lhs operand that are contracted.
    pub lhs_contracting_dims: Vec<usize>,
    /// Axis indices of the rhs operand that are contracted.
    pub rhs_contracting_dims: Vec<usize>,
    /// Axis indices of the lhs operand that are batch dimensions.
    pub lhs_batch_dims: Vec<usize>,
    /// Axis indices of the rhs operand that are batch dimensions.
    pub rhs_batch_dims: Vec<usize>,
    /// Rank the lhs operand is expected to have.
    pub lhs_rank: usize,
    /// Rank the rhs operand is expected to have.
    pub rhs_rank: usize,
}

impl DotGeneralConfig {
    /// Validate that `lhs_rank` and `rhs_rank` match the actual tensor ranks.
    ///
    /// The lhs rank is compared first, so when both ranks are wrong the
    /// returned error describes the lhs mismatch.
    ///
    /// # Errors
    ///
    /// Returns a message naming the offending side, the stored rank and the
    /// actual rank when either stored rank differs from the supplied one.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::DotGeneralConfig;
    ///
    /// let config = DotGeneralConfig {
    ///     lhs_contracting_dims: vec![1],
    ///     rhs_contracting_dims: vec![0],
    ///     lhs_batch_dims: vec![],
    ///     rhs_batch_dims: vec![],
    ///     lhs_rank: 2,
    ///     rhs_rank: 2,
    /// };
    /// config.validate_ranks(2, 2).unwrap();
    /// ```
    pub fn validate_ranks(
        &self,
        actual_lhs_rank: usize,
        actual_rhs_rank: usize,
    ) -> Result<(), String> {
        Self::check_rank("lhs", self.lhs_rank, actual_lhs_rank)?;
        Self::check_rank("rhs", self.rhs_rank, actual_rhs_rank)
    }

    /// Validate that all dimension indices are within range for the stored ranks
    /// and that no axis appears in multiple roles.
    ///
    /// Checks run in a fixed order and the first failure is returned:
    ///
    /// 1. every index is `< lhs_rank` / `< rhs_rank` (lhs contracting, rhs
    ///    contracting, lhs batch, rhs batch);
    /// 2. no list contains the same index twice (same list order);
    /// 3. no index is both a contracting and a batch dim (lhs, then rhs);
    /// 4. lhs and rhs have the same number of contracting dims, then the same
    ///    number of batch dims.
    ///
    /// Only the index structure is inspected; tensor shapes are not involved.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first violated rule.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::DotGeneralConfig;
    ///
    /// let config = DotGeneralConfig {
    ///     lhs_contracting_dims: vec![1],
    ///     rhs_contracting_dims: vec![0],
    ///     lhs_batch_dims: vec![],
    ///     rhs_batch_dims: vec![],
    ///     lhs_rank: 2,
    ///     rhs_rank: 2,
    /// };
    /// config.validate_dims().unwrap();
    /// ```
    pub fn validate_dims(&self) -> Result<(), String> {
        // 1. Bounds.
        Self::check_in_bounds(
            &self.lhs_contracting_dims,
            "lhs_contracting_dim",
            "lhs_rank",
            self.lhs_rank,
        )?;
        Self::check_in_bounds(
            &self.rhs_contracting_dims,
            "rhs_contracting_dim",
            "rhs_rank",
            self.rhs_rank,
        )?;
        Self::check_in_bounds(
            &self.lhs_batch_dims,
            "lhs_batch_dim",
            "lhs_rank",
            self.lhs_rank,
        )?;
        Self::check_in_bounds(
            &self.rhs_batch_dims,
            "rhs_batch_dim",
            "rhs_rank",
            self.rhs_rank,
        )?;

        // 2. Uniqueness within each list.
        Self::check_no_duplicates(&self.lhs_contracting_dims, "lhs_contracting_dims")?;
        Self::check_no_duplicates(&self.rhs_contracting_dims, "rhs_contracting_dims")?;
        Self::check_no_duplicates(&self.lhs_batch_dims, "lhs_batch_dims")?;
        Self::check_no_duplicates(&self.rhs_batch_dims, "rhs_batch_dims")?;

        // 3. An axis may be contracting or batch, never both.
        Self::check_disjoint(&self.lhs_contracting_dims, &self.lhs_batch_dims, "lhs")?;
        Self::check_disjoint(&self.rhs_contracting_dims, &self.rhs_batch_dims, "rhs")?;

        // 4. Both operands must list the same number of axes per role.
        Self::check_same_count(
            "contracting",
            self.lhs_contracting_dims.len(),
            self.rhs_contracting_dims.len(),
        )?;
        Self::check_same_count(
            "batch",
            self.lhs_batch_dims.len(),
            self.rhs_batch_dims.len(),
        )
    }

    /// Fail when the rank stored for `side` (`"lhs"` or `"rhs"`) differs from
    /// the rank of the actual tensor.
    fn check_rank(side: &str, stored: usize, actual: usize) -> Result<(), String> {
        if stored == actual {
            return Ok(());
        }
        Err(format!(
            "DotGeneralConfig.{}_rank ({}) does not match actual {} tensor rank ({})",
            side, stored, side, actual
        ))
    }

    /// Fail on the first entry of `dims` that is not a valid axis index for
    /// `rank`. `dim_label` and `rank_label` only shape the error message.
    fn check_in_bounds(
        dims: &[usize],
        dim_label: &str,
        rank_label: &str,
        rank: usize,
    ) -> Result<(), String> {
        match dims.iter().find(|&&d| d >= rank) {
            Some(d) => Err(format!(
                "{} {} out of bounds for {} {}",
                dim_label, d, rank_label, rank
            )),
            None => Ok(()),
        }
    }

    /// Fail on the first entry of `dims` that repeats an earlier entry.
    /// `label` names the list in the error message.
    fn check_no_duplicates(dims: &[usize], label: &str) -> Result<(), String> {
        let mut seen = std::collections::HashSet::new();
        for &d in dims {
            // `insert` returns false when the value was already present.
            if !seen.insert(d) {
                return Err(format!("{} contains duplicate dim {}", label, d));
            }
        }
        Ok(())
    }

    /// Fail on the first contracting dim of `side` that is also listed as a
    /// batch dim.
    fn check_disjoint(contracting: &[usize], batch: &[usize], side: &str) -> Result<(), String> {
        match contracting.iter().find(|&d| batch.contains(d)) {
            Some(d) => Err(format!(
                "{} dim {} appears in both contracting and batch dims",
                side, d
            )),
            None => Ok(()),
        }
    }

    /// Fail when lhs and rhs list a different number of dims for `role`
    /// (`"contracting"` or `"batch"`).
    fn check_same_count(role: &str, lhs_len: usize, rhs_len: usize) -> Result<(), String> {
        if lhs_len == rhs_len {
            return Ok(());
        }
        Err(format!(
            "lhs/rhs {} dim counts differ ({} vs {})",
            role, lhs_len, rhs_len
        ))
    }
}
167
/// Direction of an element-wise comparison.
///
/// Variants are plain tags with no payload; the names follow the usual
/// abbreviations for the relational operators.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::CompareDir;
///
/// let dir = CompareDir::Eq;
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum CompareDir {
    /// Equal.
    Eq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
185
/// Dimension configuration for a StableHLO-style gather.
///
/// The field names mirror attributes of the StableHLO `gather` operation;
/// consult the StableHLO specification for their exact semantics. This type
/// only stores the values and performs no validation.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::GatherConfig;
///
/// let config = GatherConfig {
///     offset_dims: vec![],
///     collapsed_slice_dims: vec![0],
///     start_index_map: vec![0],
///     index_vector_dim: 1,
///     slice_sizes: vec![1],
/// };
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct GatherConfig {
    /// Value for the gather `offset_dims` attribute.
    pub offset_dims: Vec<usize>,
    /// Value for the gather `collapsed_slice_dims` attribute.
    pub collapsed_slice_dims: Vec<usize>,
    /// Value for the gather `start_index_map` attribute.
    pub start_index_map: Vec<usize>,
    /// Value for the gather `index_vector_dim` attribute.
    pub index_vector_dim: usize,
    /// Value for the gather `slice_sizes` attribute.
    pub slice_sizes: Vec<usize>,
}
209
/// Dimension configuration for a StableHLO-style scatter.
///
/// The field names mirror attributes of the StableHLO `scatter` operation;
/// consult the StableHLO specification for their exact semantics. This type
/// only stores the values and performs no validation.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::ScatterConfig;
///
/// let config = ScatterConfig {
///     update_window_dims: vec![],
///     inserted_window_dims: vec![0],
///     scatter_dims_to_operand_dims: vec![0],
///     index_vector_dim: 1,
/// };
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ScatterConfig {
    /// Value for the scatter `update_window_dims` attribute.
    pub update_window_dims: Vec<usize>,
    /// Value for the scatter `inserted_window_dims` attribute.
    pub inserted_window_dims: Vec<usize>,
    /// Value for the scatter `scatter_dims_to_operand_dims` attribute.
    pub scatter_dims_to_operand_dims: Vec<usize>,
    /// Value for the scatter `index_vector_dim` attribute.
    pub index_vector_dim: usize,
}
231
/// Configuration for a slice operation.
///
/// Holds three index lists: `starts`, `limits` and `strides`. In the example
/// below each list has one entry. This type only stores the values and
/// performs no validation.
/// NOTE(review): presumably one entry per tensor axis with `limits`
/// exclusive — confirm against the code that consumes this config.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::SliceConfig;
///
/// let config = SliceConfig {
///     starts: vec![0],
///     limits: vec![2],
///     strides: vec![1],
/// };
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SliceConfig {
    /// Start indices of the slice.
    pub starts: Vec<usize>,
    /// Limit indices of the slice.
    pub limits: Vec<usize>,
    /// Step sizes of the slice.
    pub strides: Vec<usize>,
}
251
/// Padding configuration for a StableHLO-style pad.
///
/// The field names mirror attributes of the StableHLO `pad` operation;
/// consult the StableHLO specification for their exact semantics. All
/// amounts are stored as signed `i64` values. This type only stores the
/// values and performs no validation.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::PadConfig;
///
/// let config = PadConfig {
///     edge_padding_low: vec![1, 1],
///     edge_padding_high: vec![1, 1],
///     interior_padding: vec![0, 0],
/// };
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PadConfig {
    /// Value for the pad `edge_padding_low` attribute.
    pub edge_padding_low: Vec<i64>,
    /// Value for the pad `edge_padding_high` attribute.
    pub edge_padding_high: Vec<i64>,
    /// Value for the pad `interior_padding` attribute.
    pub interior_padding: Vec<i64>,
}
270}