Skip to main content

tenferro_tensor/
config.rs

/// Dimension-role configuration for a `DotGeneral` contraction.
///
/// Only the *roles* of operand axes are stored here: which axes are
/// contracted and which are batch axes. Every remaining axis is a free axis
/// and is derived, never stored. Operand ranks deliberately live elsewhere:
/// with the enclosing `StdTensorOp::DotGeneral` variant at the
/// trace/StdTensorOp layer, and with `ExecInstruction::output_shapes` at the
/// exec layer. Keeping ranks out of this type makes it structurally
/// impossible for a stored rank to drift from the real tensor rank
/// (issue #664).
///
/// The result is laid out as `[lhs_free..., rhs_free..., batch...]`
/// (column-major, batch-trailing convention). Batch axes get the largest
/// strides, so each batch slice occupies one contiguous block of memory.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::DotGeneralConfig;
///
/// let config = DotGeneralConfig {
///     lhs_contracting_dims: vec![1],
///     rhs_contracting_dims: vec![0],
///     lhs_batch_dims: vec![],
///     rhs_batch_dims: vec![],
/// };
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DotGeneralConfig {
    /// Axes of the left operand that are contracted.
    pub lhs_contracting_dims: Vec<usize>,
    /// Axes of the right operand that are contracted; validation requires the
    /// same length as `lhs_contracting_dims`.
    pub rhs_contracting_dims: Vec<usize>,
    /// Batch axes of the left operand.
    pub lhs_batch_dims: Vec<usize>,
    /// Batch axes of the right operand; validation requires the same length
    /// as `lhs_batch_dims`.
    pub rhs_batch_dims: Vec<usize>,
}

impl DotGeneralConfig {
    /// Reject the first axis that appears a second time in `dims`.
    fn check_no_duplicates(dims: &[usize], label: &str) -> Result<(), String> {
        let mut seen = std::collections::HashSet::with_capacity(dims.len());
        match dims.iter().copied().find(|&d| !seen.insert(d)) {
            Some(dup) => Err(format!("{} contains duplicate dim {}", label, dup)),
            None => Ok(()),
        }
    }

    /// Reject the first axis in `dims` that is not strictly below `rank`.
    fn check_in_bounds(
        dims: &[usize],
        rank: usize,
        dim_label: &str,
        rank_label: &str,
    ) -> Result<(), String> {
        match dims.iter().copied().find(|&d| d >= rank) {
            Some(bad) => Err(format!(
                "{} {} out of bounds for {} {}",
                dim_label, bad, rank_label, rank
            )),
            None => Ok(()),
        }
    }

    /// Reject the first contracting axis that is also listed as a batch axis.
    fn check_disjoint_roles(
        side: &str,
        contracting: &[usize],
        batch: &[usize],
    ) -> Result<(), String> {
        match contracting.iter().copied().find(|d| batch.contains(d)) {
            Some(shared) => Err(format!(
                "{} dim {} appears in both contracting and batch dims",
                side, shared
            )),
            None => Ok(()),
        }
    }

    /// Reject a role whose lhs and rhs axis lists differ in length.
    fn check_same_count(role: &str, lhs: &[usize], rhs: &[usize]) -> Result<(), String> {
        if lhs.len() == rhs.len() {
            return Ok(());
        }
        Err(format!(
            "lhs/rhs {} dim counts differ ({} vs {})",
            role,
            lhs.len(),
            rhs.len()
        ))
    }

    /// Check this configuration against explicit operand ranks.
    ///
    /// The config stores only dim-numbering roles, so call sites pass the
    /// ranks of the tensors they actually hold.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` for the first failing check, in this
    /// order:
    /// 1. any axis out of range for its operand rank;
    /// 2. any axis repeated within a single list;
    /// 3. any axis used as both contracting and batch on the same side;
    /// 4. lhs/rhs contracting counts differ, then lhs/rhs batch counts differ.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor::DotGeneralConfig;
    ///
    /// let config = DotGeneralConfig {
    ///     lhs_contracting_dims: vec![1],
    ///     rhs_contracting_dims: vec![0],
    ///     lhs_batch_dims: vec![],
    ///     rhs_batch_dims: vec![],
    /// };
    /// config.validate_dims_with_ranks(2, 2).unwrap();
    /// ```
    pub fn validate_dims_with_ranks(
        &self,
        lhs_rank: usize,
        rhs_rank: usize,
    ) -> Result<(), String> {
        // The sequence below fixes which error callers see when several
        // problems coexist, so it must not be reordered.
        Self::check_in_bounds(
            &self.lhs_contracting_dims,
            lhs_rank,
            "lhs_contracting_dim",
            "lhs_rank",
        )?;
        Self::check_in_bounds(
            &self.rhs_contracting_dims,
            rhs_rank,
            "rhs_contracting_dim",
            "rhs_rank",
        )?;
        Self::check_in_bounds(&self.lhs_batch_dims, lhs_rank, "lhs_batch_dim", "lhs_rank")?;
        Self::check_in_bounds(&self.rhs_batch_dims, rhs_rank, "rhs_batch_dim", "rhs_rank")?;

        Self::check_no_duplicates(&self.lhs_contracting_dims, "lhs_contracting_dims")?;
        Self::check_no_duplicates(&self.rhs_contracting_dims, "rhs_contracting_dims")?;
        Self::check_no_duplicates(&self.lhs_batch_dims, "lhs_batch_dims")?;
        Self::check_no_duplicates(&self.rhs_batch_dims, "rhs_batch_dims")?;

        Self::check_disjoint_roles("lhs", &self.lhs_contracting_dims, &self.lhs_batch_dims)?;
        Self::check_disjoint_roles("rhs", &self.rhs_contracting_dims, &self.rhs_batch_dims)?;

        Self::check_same_count(
            "contracting",
            &self.lhs_contracting_dims,
            &self.rhs_contracting_dims,
        )?;
        Self::check_same_count("batch", &self.lhs_batch_dims, &self.rhs_batch_dims)
    }
}
133
/// Direction of an element-wise comparison.
///
/// The enum carries no data, so it derives `Copy` and can be passed by value
/// without `.clone()`.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::CompareDir;
///
/// let dir = CompareDir::Eq;
/// let same = dir; // `Copy`: `dir` is still usable afterwards.
/// assert_eq!(dir, same);
/// ```
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum CompareDir {
    /// Equality (`==`).
    Eq,
    /// Strictly less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Le,
    /// Strictly greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Ge,
}
151
/// Dimension numbers and slice sizes for a StableHLO-style `gather`.
///
/// Each field mirrors the StableHLO `gather` attribute of the same name;
/// consult the StableHLO specification for the precise semantics.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::GatherConfig;
///
/// let config = GatherConfig {
///     offset_dims: vec![],
///     collapsed_slice_dims: vec![0],
///     start_index_map: vec![0],
///     index_vector_dim: 1,
///     slice_sizes: vec![1],
/// };
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GatherConfig {
    /// StableHLO `offset_dims`.
    pub offset_dims: Vec<usize>,
    /// StableHLO `collapsed_slice_dims`.
    pub collapsed_slice_dims: Vec<usize>,
    /// StableHLO `start_index_map`.
    pub start_index_map: Vec<usize>,
    /// StableHLO `index_vector_dim`.
    pub index_vector_dim: usize,
    /// StableHLO `slice_sizes`: one slice extent per operand axis.
    pub slice_sizes: Vec<usize>,
}
175
/// Dimension numbers for a StableHLO-style `scatter`.
///
/// Each field mirrors the StableHLO `scatter` attribute of the same name;
/// consult the StableHLO specification for the precise semantics.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::ScatterConfig;
///
/// let config = ScatterConfig {
///     update_window_dims: vec![],
///     inserted_window_dims: vec![0],
///     scatter_dims_to_operand_dims: vec![0],
///     index_vector_dim: 1,
/// };
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ScatterConfig {
    /// StableHLO `update_window_dims`.
    pub update_window_dims: Vec<usize>,
    /// StableHLO `inserted_window_dims`.
    pub inserted_window_dims: Vec<usize>,
    /// StableHLO `scatter_dims_to_operand_dims`.
    pub scatter_dims_to_operand_dims: Vec<usize>,
    /// StableHLO `index_vector_dim`.
    pub index_vector_dim: usize,
}
197
/// Per-axis start, limit and stride for a strided slice.
///
/// All three vectors are indexed by axis. In the example, axis 0 is sliced
/// from index 0 toward limit 2 with step 1.
/// NOTE(review): presumably `limits` is exclusive, as with StableHLO
/// `limit_indices` — confirm against the slice implementation.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::SliceConfig;
///
/// let config = SliceConfig {
///     starts: vec![0],
///     limits: vec![2],
///     strides: vec![1],
/// };
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SliceConfig {
    /// Start index for each axis.
    pub starts: Vec<usize>,
    /// Limit index for each axis.
    pub limits: Vec<usize>,
    /// Step between selected elements for each axis.
    pub strides: Vec<usize>,
}
217
/// Per-axis padding amounts for a StableHLO-style `pad`.
///
/// Each field mirrors the StableHLO `pad` attribute of the same name and
/// holds one entry per axis. The values are signed (`i64`); see the StableHLO
/// specification for how negative edge padding is interpreted.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::PadConfig;
///
/// let config = PadConfig {
///     edge_padding_low: vec![1, 1],
///     edge_padding_high: vec![1, 1],
///     interior_padding: vec![0, 0],
/// };
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PadConfig {
    /// StableHLO `edge_padding_low`: padding before the first element of each axis.
    pub edge_padding_low: Vec<i64>,
    /// StableHLO `edge_padding_high`: padding after the last element of each axis.
    pub edge_padding_high: Vec<i64>,
    /// StableHLO `interior_padding`: padding between adjacent elements of each axis.
    pub interior_padding: Vec<i64>,
}