tenferro_tensor/config.rs
/// DotGeneral dimension configuration.
///
/// Records only the dim-numbering roles (contracting / batch; free is derived).
/// Rank info travels with the enclosing `StdTensorOp::DotGeneral` variant at
/// the trace/StdTensorOp layer, and with `ExecInstruction::output_shapes` at
/// the exec layer. This separation makes it structurally impossible for
/// stored ranks to drift from actual tensor ranks (issue #664).
///
/// Any operand axis that appears in neither the contracting nor the batch
/// list of its side is a *free* axis; free axes are never stored here.
///
/// The output shape is `[lhs_free..., rhs_free..., batch...]` (col-major
/// batch-trailing convention). Batch dims have the largest stride so that
/// each batch slice occupies a contiguous block of memory.
///
/// Because the config holds no ranks, it cannot validate itself in
/// isolation; use [`DotGeneralConfig::validate_dims_with_ranks`] with the
/// actual operand ranks to check bounds, duplicates, role overlap and
/// lhs/rhs length agreement.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::DotGeneralConfig;
///
/// // Contract lhs axis 1 with rhs axis 0, no batch axes
/// // (for two rank-2 operands this is an ordinary matrix product).
/// let config = DotGeneralConfig {
///     lhs_contracting_dims: vec![1],
///     rhs_contracting_dims: vec![0],
///     lhs_batch_dims: vec![],
///     rhs_batch_dims: vec![],
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct DotGeneralConfig {
    /// Axes of the lhs operand that are contracted against
    /// `rhs_contracting_dims`. Validation requires the same length as
    /// `rhs_contracting_dims`.
    pub lhs_contracting_dims: Vec<usize>,
    /// Axes of the rhs operand that are contracted against
    /// `lhs_contracting_dims`.
    pub rhs_contracting_dims: Vec<usize>,
    /// Axes of the lhs operand treated as batch axes. Validation requires the
    /// same length as `rhs_batch_dims` and no overlap with
    /// `lhs_contracting_dims`.
    pub lhs_batch_dims: Vec<usize>,
    /// Axes of the rhs operand treated as batch axes. Validation requires no
    /// overlap with `rhs_contracting_dims`.
    pub rhs_batch_dims: Vec<usize>,
}
32
33impl DotGeneralConfig {
34 fn check_no_duplicates(dims: &[usize], label: &str) -> Result<(), String> {
35 let mut seen = std::collections::HashSet::new();
36 for &d in dims {
37 if !seen.insert(d) {
38 return Err(format!("{} contains duplicate dim {}", label, d));
39 }
40 }
41 Ok(())
42 }
43
44 /// Validate that all dimension indices are within range for the given
45 /// explicit ranks and that no axis appears in multiple roles.
46 ///
47 /// Call sites supply the actual operand ranks (from the tensor shapes they
48 /// have in hand). The config itself carries only the dim-numbering roles.
49 ///
50 /// # Examples
51 ///
52 /// ```rust
53 /// use tenferro_tensor::DotGeneralConfig;
54 ///
55 /// let config = DotGeneralConfig {
56 /// lhs_contracting_dims: vec![1],
57 /// rhs_contracting_dims: vec![0],
58 /// lhs_batch_dims: vec![],
59 /// rhs_batch_dims: vec![],
60 /// };
61 /// config.validate_dims_with_ranks(2, 2).unwrap();
62 /// ```
63 pub fn validate_dims_with_ranks(&self, lhs_rank: usize, rhs_rank: usize) -> Result<(), String> {
64 for &d in &self.lhs_contracting_dims {
65 if d >= lhs_rank {
66 return Err(format!(
67 "lhs_contracting_dim {} out of bounds for lhs_rank {}",
68 d, lhs_rank
69 ));
70 }
71 }
72 for &d in &self.rhs_contracting_dims {
73 if d >= rhs_rank {
74 return Err(format!(
75 "rhs_contracting_dim {} out of bounds for rhs_rank {}",
76 d, rhs_rank
77 ));
78 }
79 }
80 for &d in &self.lhs_batch_dims {
81 if d >= lhs_rank {
82 return Err(format!(
83 "lhs_batch_dim {} out of bounds for lhs_rank {}",
84 d, lhs_rank
85 ));
86 }
87 }
88 for &d in &self.rhs_batch_dims {
89 if d >= rhs_rank {
90 return Err(format!(
91 "rhs_batch_dim {} out of bounds for rhs_rank {}",
92 d, rhs_rank
93 ));
94 }
95 }
96 Self::check_no_duplicates(&self.lhs_contracting_dims, "lhs_contracting_dims")?;
97 Self::check_no_duplicates(&self.rhs_contracting_dims, "rhs_contracting_dims")?;
98 Self::check_no_duplicates(&self.lhs_batch_dims, "lhs_batch_dims")?;
99 Self::check_no_duplicates(&self.rhs_batch_dims, "rhs_batch_dims")?;
100 for &d in &self.lhs_contracting_dims {
101 if self.lhs_batch_dims.contains(&d) {
102 return Err(format!(
103 "lhs dim {} appears in both contracting and batch dims",
104 d
105 ));
106 }
107 }
108 for &d in &self.rhs_contracting_dims {
109 if self.rhs_batch_dims.contains(&d) {
110 return Err(format!(
111 "rhs dim {} appears in both contracting and batch dims",
112 d
113 ));
114 }
115 }
116 if self.lhs_contracting_dims.len() != self.rhs_contracting_dims.len() {
117 return Err(format!(
118 "lhs/rhs contracting dim counts differ ({} vs {})",
119 self.lhs_contracting_dims.len(),
120 self.rhs_contracting_dims.len()
121 ));
122 }
123 if self.lhs_batch_dims.len() != self.rhs_batch_dims.len() {
124 return Err(format!(
125 "lhs/rhs batch dim counts differ ({} vs {})",
126 self.lhs_batch_dims.len(),
127 self.rhs_batch_dims.len()
128 ));
129 }
130 Ok(())
131 }
132}
133
/// Comparison direction.
///
/// Selects the relational operator a comparison applies. Only these five
/// directions exist. There is no not-equal (`Ne`) variant, so callers needing
/// `!=` must derive it (e.g. by negating an `Eq` result).
///
/// NOTE(review): the operand order is presumably `lhs <op> rhs`. Confirm this
/// at the evaluation site.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::CompareDir;
///
/// let dir = CompareDir::Eq;
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CompareDir {
    /// Equal.
    Eq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
151
/// StableHLO gather dimension configuration.
///
/// The field names follow the dimension attributes of StableHLO `gather`.
/// StableHLO's `operand_batching_dims`, `start_indices_batching_dims` and
/// `indices_are_sorted` attributes have no counterpart in this struct.
///
/// NOTE(review): the per-field notes below paraphrase the StableHLO spec.
/// Confirm that the executor interprets them identically.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::GatherConfig;
///
/// let config = GatherConfig {
///     offset_dims: vec![],
///     collapsed_slice_dims: vec![0],
///     start_index_map: vec![0],
///     index_vector_dim: 1,
///     slice_sizes: vec![1],
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct GatherConfig {
    /// Output dims that index within each gathered slice.
    pub offset_dims: Vec<usize>,
    /// Operand dims whose slice extent is dropped from the output. StableHLO
    /// requires the matching `slice_sizes` entry to be 1.
    pub collapsed_slice_dims: Vec<usize>,
    /// For each component of an index vector, the operand dim whose start
    /// offset it supplies.
    pub start_index_map: Vec<usize>,
    /// Dim of the start-indices tensor that holds the index vectors.
    pub index_vector_dim: usize,
    /// Extent of the slice gathered along each operand dim.
    pub slice_sizes: Vec<usize>,
}
175
/// StableHLO scatter dimension configuration.
///
/// The field names follow the dimension attributes of StableHLO `scatter`.
/// Only the dimension numbering is recorded here. This struct does not carry
/// the update computation (how colliding updates are combined), the batching
/// dims (`input_batching_dims` / `scatter_indices_batching_dims`), or the
/// `indices_are_sorted` / `unique_indices` hints.
///
/// NOTE(review): the per-field notes below paraphrase the StableHLO spec.
/// Confirm that the executor interprets them identically.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::ScatterConfig;
///
/// let config = ScatterConfig {
///     update_window_dims: vec![],
///     inserted_window_dims: vec![0],
///     scatter_dims_to_operand_dims: vec![0],
///     index_vector_dim: 1,
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ScatterConfig {
    /// Dims of the updates tensor that form the window written into the
    /// operand.
    pub update_window_dims: Vec<usize>,
    /// Operand dims that have an implicit window extent of 1 and no
    /// corresponding dim in the updates tensor.
    pub inserted_window_dims: Vec<usize>,
    /// For each component of an index vector, the operand dim it addresses.
    pub scatter_dims_to_operand_dims: Vec<usize>,
    /// Dim of the scatter-indices tensor that holds the index vectors.
    pub index_vector_dim: usize,
}
197
/// Slice configuration.
///
/// NOTE(review): the field names mirror StableHLO `slice`
/// (`start_indices` / `limit_indices` / `strides`). In StableHLO each list has
/// one entry per operand dim, limits are exclusive, and strides are positive.
/// Confirm that this crate follows the same conventions.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::SliceConfig;
///
/// let config = SliceConfig {
///     starts: vec![0],
///     limits: vec![2],
///     strides: vec![1],
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct SliceConfig {
    /// First index taken along each dim.
    pub starts: Vec<usize>,
    /// End bound along each dim (presumably exclusive; see note above).
    pub limits: Vec<usize>,
    /// Step between consecutive taken indices along each dim.
    pub strides: Vec<usize>,
}
217
/// StableHLO pad configuration.
///
/// Each field presumably holds one entry per operand dim, as in StableHLO
/// `pad`. The element type is signed (`i64`) to match StableHLO. There, edge
/// padding may be negative, which removes elements from that edge, while
/// interior padding must be non-negative.
///
/// NOTE(review): confirm that the executor accepts negative edge padding.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::PadConfig;
///
/// let config = PadConfig {
///     edge_padding_low: vec![1, 1],
///     edge_padding_high: vec![1, 1],
///     interior_padding: vec![0, 0],
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct PadConfig {
    /// Padding added before the first element of each dim.
    pub edge_padding_low: Vec<i64>,
    /// Padding added after the last element of each dim.
    pub edge_padding_high: Vec<i64>,
    /// Padding inserted between each pair of adjacent elements in each dim.
    pub interior_padding: Vec<i64>,
}