//! `tenferro_tensor/config.rs` — configuration types for tensor operations
//! (dot-general, compare, gather, scatter, slice, pad).
/// DotGeneral dimension configuration.
///
/// The output shape is `[lhs_free..., rhs_free..., batch...]` (col-major
/// batch-trailing convention). Batch dims have the largest stride so that
/// each batch slice occupies a contiguous block of memory.
///
/// Use `validate_ranks` / `validate_dims` to check a config against actual
/// tensors before running the contraction.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::DotGeneralConfig;
///
/// let config = DotGeneralConfig {
///     lhs_contracting_dims: vec![1],
///     rhs_contracting_dims: vec![0],
///     lhs_batch_dims: vec![],
///     rhs_batch_dims: vec![],
///     lhs_rank: 2,
///     rhs_rank: 2,
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct DotGeneralConfig {
    /// Axes of the lhs that are contracted (summed) away; `validate_dims`
    /// requires the same length as `rhs_contracting_dims`.
    pub lhs_contracting_dims: Vec<usize>,
    /// Axes of the rhs that are contracted (summed) away; `validate_dims`
    /// requires the same length as `lhs_contracting_dims`.
    pub rhs_contracting_dims: Vec<usize>,
    /// Batch axes of the lhs; same length as `rhs_batch_dims`. An axis may
    /// not be both batch and contracting (checked by `validate_dims`).
    pub lhs_batch_dims: Vec<usize>,
    /// Batch axes of the rhs; same length as `lhs_batch_dims`. An axis may
    /// not be both batch and contracting (checked by `validate_dims`).
    pub rhs_batch_dims: Vec<usize>,
    /// Expected rank of the lhs tensor; compared against the actual rank by
    /// `validate_ranks` and used as the bound in `validate_dims`.
    pub lhs_rank: usize,
    /// Expected rank of the rhs tensor; compared against the actual rank by
    /// `validate_ranks` and used as the bound in `validate_dims`.
    pub rhs_rank: usize,
}
30
31impl DotGeneralConfig {
32    /// Validate that `lhs_rank` and `rhs_rank` match the actual tensor ranks.
33    ///
34    /// # Examples
35    ///
36    /// ```ignore
37    /// use tenferro_tensor::DotGeneralConfig;
38    ///
39    /// let config = DotGeneralConfig {
40    ///     lhs_contracting_dims: vec![1],
41    ///     rhs_contracting_dims: vec![0],
42    ///     lhs_batch_dims: vec![],
43    ///     rhs_batch_dims: vec![],
44    ///     lhs_rank: 2,
45    ///     rhs_rank: 2,
46    /// };
47    /// config.validate_ranks(2, 2).unwrap();
48    /// ```
49    pub fn validate_ranks(
50        &self,
51        actual_lhs_rank: usize,
52        actual_rhs_rank: usize,
53    ) -> Result<(), String> {
54        if self.lhs_rank != actual_lhs_rank {
55            return Err(format!(
56                "DotGeneralConfig.lhs_rank ({}) does not match actual lhs tensor rank ({})",
57                self.lhs_rank, actual_lhs_rank
58            ));
59        }
60        if self.rhs_rank != actual_rhs_rank {
61            return Err(format!(
62                "DotGeneralConfig.rhs_rank ({}) does not match actual rhs tensor rank ({})",
63                self.rhs_rank, actual_rhs_rank
64            ));
65        }
66        Ok(())
67    }
68
69    /// Validate that all dimension indices are within range for the stored ranks
70    /// and that no axis appears in multiple roles.
71    ///
72    /// # Examples
73    ///
74    /// ```ignore
75    /// use tenferro_tensor::DotGeneralConfig;
76    ///
77    /// let config = DotGeneralConfig {
78    ///     lhs_contracting_dims: vec![1],
79    ///     rhs_contracting_dims: vec![0],
80    ///     lhs_batch_dims: vec![],
81    ///     rhs_batch_dims: vec![],
82    ///     lhs_rank: 2,
83    ///     rhs_rank: 2,
84    /// };
85    /// config.validate_dims().unwrap();
86    /// ```
87    fn check_no_duplicates(dims: &[usize], label: &str) -> Result<(), String> {
88        let mut seen = std::collections::HashSet::new();
89        for &d in dims {
90            if !seen.insert(d) {
91                return Err(format!("{} contains duplicate dim {}", label, d));
92            }
93        }
94        Ok(())
95    }
96
97    pub fn validate_dims(&self) -> Result<(), String> {
98        for &d in &self.lhs_contracting_dims {
99            if d >= self.lhs_rank {
100                return Err(format!(
101                    "lhs_contracting_dim {} out of bounds for lhs_rank {}",
102                    d, self.lhs_rank
103                ));
104            }
105        }
106        for &d in &self.rhs_contracting_dims {
107            if d >= self.rhs_rank {
108                return Err(format!(
109                    "rhs_contracting_dim {} out of bounds for rhs_rank {}",
110                    d, self.rhs_rank
111                ));
112            }
113        }
114        for &d in &self.lhs_batch_dims {
115            if d >= self.lhs_rank {
116                return Err(format!(
117                    "lhs_batch_dim {} out of bounds for lhs_rank {}",
118                    d, self.lhs_rank
119                ));
120            }
121        }
122        for &d in &self.rhs_batch_dims {
123            if d >= self.rhs_rank {
124                return Err(format!(
125                    "rhs_batch_dim {} out of bounds for rhs_rank {}",
126                    d, self.rhs_rank
127                ));
128            }
129        }
130        Self::check_no_duplicates(&self.lhs_contracting_dims, "lhs_contracting_dims")?;
131        Self::check_no_duplicates(&self.rhs_contracting_dims, "rhs_contracting_dims")?;
132        Self::check_no_duplicates(&self.lhs_batch_dims, "lhs_batch_dims")?;
133        Self::check_no_duplicates(&self.rhs_batch_dims, "rhs_batch_dims")?;
134        for &d in &self.lhs_contracting_dims {
135            if self.lhs_batch_dims.contains(&d) {
136                return Err(format!(
137                    "lhs dim {} appears in both contracting and batch dims",
138                    d
139                ));
140            }
141        }
142        for &d in &self.rhs_contracting_dims {
143            if self.rhs_batch_dims.contains(&d) {
144                return Err(format!(
145                    "rhs dim {} appears in both contracting and batch dims",
146                    d
147                ));
148            }
149        }
150        if self.lhs_contracting_dims.len() != self.rhs_contracting_dims.len() {
151            return Err(format!(
152                "lhs/rhs contracting dim counts differ ({} vs {})",
153                self.lhs_contracting_dims.len(),
154                self.rhs_contracting_dims.len()
155            ));
156        }
157        if self.lhs_batch_dims.len() != self.rhs_batch_dims.len() {
158            return Err(format!(
159                "lhs/rhs batch dim counts differ ({} vs {})",
160                self.lhs_batch_dims.len(),
161                self.rhs_batch_dims.len()
162            ));
163        }
164        Ok(())
165    }
166}
167
/// Comparison direction.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::CompareDir;
///
/// let dir = CompareDir::Eq;
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum CompareDir {
    /// Equal (`==`).
    Eq,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Ge,
}
185
/// StableHLO gather dimension configuration.
///
/// Field names mirror the attributes of StableHLO's `gather` op; see the
/// StableHLO specification for the precise semantics of each attribute.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::GatherConfig;
///
/// let config = GatherConfig {
///     offset_dims: vec![],
///     collapsed_slice_dims: vec![0],
///     start_index_map: vec![0],
///     index_vector_dim: 1,
///     slice_sizes: vec![1],
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct GatherConfig {
    // Mirrors StableHLO `offset_dims`: output dims that carry the slice.
    pub offset_dims: Vec<usize>,
    // Mirrors StableHLO `collapsed_slice_dims`: operand dims collapsed out
    // of each gathered slice.
    pub collapsed_slice_dims: Vec<usize>,
    // Mirrors StableHLO `start_index_map`: maps index-vector components to
    // operand dims.
    pub start_index_map: Vec<usize>,
    // Mirrors StableHLO `index_vector_dim`: dim of the indices tensor that
    // holds the index vector.
    pub index_vector_dim: usize,
    // Mirrors StableHLO `slice_sizes`: slice extent along each operand dim.
    pub slice_sizes: Vec<usize>,
}
209
/// StableHLO scatter dimension configuration.
///
/// Field names mirror the attributes of StableHLO's `scatter` op; see the
/// StableHLO specification for the precise semantics of each attribute.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::ScatterConfig;
///
/// let config = ScatterConfig {
///     update_window_dims: vec![],
///     inserted_window_dims: vec![0],
///     scatter_dims_to_operand_dims: vec![0],
///     index_vector_dim: 1,
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ScatterConfig {
    // Mirrors StableHLO `update_window_dims`: dims of the updates tensor
    // that form the update window.
    pub update_window_dims: Vec<usize>,
    // Mirrors StableHLO `inserted_window_dims`: operand window dims not
    // present in the updates tensor.
    pub inserted_window_dims: Vec<usize>,
    // Mirrors StableHLO `scatter_dims_to_operand_dims`: maps index-vector
    // components to operand dims.
    pub scatter_dims_to_operand_dims: Vec<usize>,
    // Mirrors StableHLO `index_vector_dim`: dim of the indices tensor that
    // holds the index vector.
    pub index_vector_dim: usize,
}
231
/// Slice configuration.
///
/// One entry per dimension of the sliced tensor.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::SliceConfig;
///
/// let config = SliceConfig {
///     starts: vec![0],
///     limits: vec![2],
///     strides: vec![1],
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct SliceConfig {
    // Start index per dim (inclusive).
    pub starts: Vec<usize>,
    // Limit index per dim — presumably exclusive, as in StableHLO `slice`;
    // confirm against the slice implementation.
    pub limits: Vec<usize>,
    // Step per dim between taken elements.
    pub strides: Vec<usize>,
}
251
/// StableHLO pad configuration.
///
/// Field names mirror the attributes of StableHLO's `pad` op, with one
/// entry per dimension of the padded tensor.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::PadConfig;
///
/// let config = PadConfig {
///     edge_padding_low: vec![1, 1],
///     edge_padding_high: vec![1, 1],
///     interior_padding: vec![0, 0],
/// };
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct PadConfig {
    // Padding before the first element of each dim. Signed (`i64`), which
    // suggests negative values (edge trimming, as StableHLO `pad` allows)
    // are representable — confirm against the pad implementation.
    pub edge_padding_low: Vec<i64>,
    // Padding after the last element of each dim (signed, see above).
    pub edge_padding_high: Vec<i64>,
    // Padding inserted between adjacent elements of each dim.
    pub interior_padding: Vec<i64>,
}