1use smallvec::SmallVec;
28use strided_view::{StridedView, StridedViewMut};
29
30use crate::backend::Backend;
31use crate::{einsum2_dispatch, Einsum2Plan, EinsumError, Result, ScalarBase};
32
/// Which operand axes a dot-general contracts and which it batches over.
///
/// Each `lhs_*` list is paired positionally with its `rhs_*` counterpart:
/// element `i` of `lhs_contracting_dims` is contracted against element `i` of
/// `rhs_contracting_dims`, and likewise for the batch lists. Axes that appear
/// in neither list of an operand are that operand's "free" axes.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct DotGeneralConfig<'a> {
    /// Lhs axes summed over; paired positionally with `rhs_contracting_dims`.
    pub lhs_contracting_dims: &'a [usize],
    /// Rhs axes summed over; paired positionally with `lhs_contracting_dims`.
    pub rhs_contracting_dims: &'a [usize],
    /// Lhs batch axes; paired positionally with `rhs_batch_dims`.
    pub lhs_batch_dims: &'a [usize],
    /// Rhs batch axes; paired positionally with `lhs_batch_dims`.
    pub rhs_batch_dims: &'a [usize],
}
44
45#[derive(Clone, Debug)]
46pub(crate) struct DotGeneralLabels {
47 pub lhs_labels: SmallVec<[usize; 8]>,
48 pub rhs_labels: SmallVec<[usize; 8]>,
49 pub out_labels: SmallVec<[usize; 8]>,
50}
51
52fn check_no_duplicates(dims: &[usize], label: &str) -> Result<()> {
53 for (i, &dim) in dims.iter().enumerate() {
54 if dims[..i].contains(&dim) {
55 return Err(EinsumError::InvalidDotGeneralConfig(format!(
56 "{label} contains duplicate dim {dim}"
57 )));
58 }
59 }
60 Ok(())
61}
62
63fn check_bounds(dims: &[usize], rank: usize, label: &str) -> Result<()> {
64 for &dim in dims {
65 if dim >= rank {
66 return Err(EinsumError::InvalidDotGeneralConfig(format!(
67 "{label} dim {dim} out of bounds for rank {rank}"
68 )));
69 }
70 }
71 Ok(())
72}
73
74fn free_dims(rank: usize, contracting: &[usize], batch: &[usize]) -> SmallVec<[usize; 8]> {
75 (0..rank)
76 .filter(|dim| !contracting.contains(dim) && !batch.contains(dim))
77 .collect()
78}
79
80impl DotGeneralConfig<'_> {
81 pub fn validate_dims_with_ranks(&self, lhs_rank: usize, rhs_rank: usize) -> Result<()> {
83 check_bounds(&self.lhs_contracting_dims, lhs_rank, "lhs_contracting")?;
84 check_bounds(&self.rhs_contracting_dims, rhs_rank, "rhs_contracting")?;
85 check_bounds(&self.lhs_batch_dims, lhs_rank, "lhs_batch")?;
86 check_bounds(&self.rhs_batch_dims, rhs_rank, "rhs_batch")?;
87 check_no_duplicates(&self.lhs_contracting_dims, "lhs_contracting_dims")?;
88 check_no_duplicates(&self.rhs_contracting_dims, "rhs_contracting_dims")?;
89 check_no_duplicates(&self.lhs_batch_dims, "lhs_batch_dims")?;
90 check_no_duplicates(&self.rhs_batch_dims, "rhs_batch_dims")?;
91
92 for &dim in self.lhs_contracting_dims {
93 if self.lhs_batch_dims.contains(&dim) {
94 return Err(EinsumError::InvalidDotGeneralConfig(format!(
95 "lhs dim {dim} appears in both contracting and batch dims"
96 )));
97 }
98 }
99 for &dim in self.rhs_contracting_dims {
100 if self.rhs_batch_dims.contains(&dim) {
101 return Err(EinsumError::InvalidDotGeneralConfig(format!(
102 "rhs dim {dim} appears in both contracting and batch dims"
103 )));
104 }
105 }
106 if self.lhs_contracting_dims.len() != self.rhs_contracting_dims.len() {
107 return Err(EinsumError::InvalidDotGeneralConfig(format!(
108 "lhs/rhs contracting dim counts differ ({} vs {})",
109 self.lhs_contracting_dims.len(),
110 self.rhs_contracting_dims.len()
111 )));
112 }
113 if self.lhs_batch_dims.len() != self.rhs_batch_dims.len() {
114 return Err(EinsumError::InvalidDotGeneralConfig(format!(
115 "lhs/rhs batch dim counts differ ({} vs {})",
116 self.lhs_batch_dims.len(),
117 self.rhs_batch_dims.len()
118 )));
119 }
120 Ok(())
121 }
122
123 pub fn expected_output_shape(
125 &self,
126 lhs_shape: &[usize],
127 rhs_shape: &[usize],
128 ) -> Result<Vec<usize>> {
129 self.validate_dims_with_ranks(lhs_shape.len(), rhs_shape.len())?;
130 self.validate_pair_shapes(lhs_shape, rhs_shape)?;
131
132 let lhs_free = free_dims(
133 lhs_shape.len(),
134 &self.lhs_contracting_dims,
135 &self.lhs_batch_dims,
136 );
137 let rhs_free = free_dims(
138 rhs_shape.len(),
139 &self.rhs_contracting_dims,
140 &self.rhs_batch_dims,
141 );
142
143 let mut out =
144 Vec::with_capacity(lhs_free.len() + rhs_free.len() + self.lhs_batch_dims.len());
145 out.extend(lhs_free.iter().map(|&dim| lhs_shape[dim]));
146 out.extend(rhs_free.iter().map(|&dim| rhs_shape[dim]));
147 out.extend(self.lhs_batch_dims.iter().map(|&dim| lhs_shape[dim]));
148 Ok(out)
149 }
150
151 fn validate_pair_shapes(&self, lhs_shape: &[usize], rhs_shape: &[usize]) -> Result<()> {
152 for (&lhs_dim, &rhs_dim) in self.lhs_batch_dims.iter().zip(self.rhs_batch_dims) {
153 if lhs_shape[lhs_dim] != rhs_shape[rhs_dim] {
154 return Err(EinsumError::DimensionMismatch {
155 axis: format!("batch lhs {lhs_dim} rhs {rhs_dim}"),
156 dim_a: lhs_shape[lhs_dim],
157 dim_b: rhs_shape[rhs_dim],
158 });
159 }
160 }
161 for (&lhs_dim, &rhs_dim) in self
162 .lhs_contracting_dims
163 .iter()
164 .zip(self.rhs_contracting_dims)
165 {
166 if lhs_shape[lhs_dim] != rhs_shape[rhs_dim] {
167 return Err(EinsumError::DimensionMismatch {
168 axis: format!("contract lhs {lhs_dim} rhs {rhs_dim}"),
169 dim_a: lhs_shape[lhs_dim],
170 dim_b: rhs_shape[rhs_dim],
171 });
172 }
173 }
174 Ok(())
175 }
176
177 pub(crate) fn labels_for_shapes(
178 &self,
179 lhs_shape: &[usize],
180 rhs_shape: &[usize],
181 out_shape: &[usize],
182 ) -> Result<DotGeneralLabels> {
183 self.validate_dims_with_ranks(lhs_shape.len(), rhs_shape.len())?;
184 self.validate_pair_shapes(lhs_shape, rhs_shape)?;
185
186 let lhs_free = free_dims(
187 lhs_shape.len(),
188 &self.lhs_contracting_dims,
189 &self.lhs_batch_dims,
190 );
191 let rhs_free = free_dims(
192 rhs_shape.len(),
193 &self.rhs_contracting_dims,
194 &self.rhs_batch_dims,
195 );
196
197 let mut lhs_labels = smallvec::smallvec![usize::MAX; lhs_shape.len()];
198 let mut rhs_labels = smallvec::smallvec![usize::MAX; rhs_shape.len()];
199 let mut next_label = 0usize;
200
201 for (&lhs_dim, &rhs_dim) in self.lhs_batch_dims.iter().zip(self.rhs_batch_dims) {
202 lhs_labels[lhs_dim] = next_label;
203 rhs_labels[rhs_dim] = next_label;
204 next_label += 1;
205 }
206 for &lhs_dim in &lhs_free {
207 lhs_labels[lhs_dim] = next_label;
208 next_label += 1;
209 }
210 for &rhs_dim in &rhs_free {
211 rhs_labels[rhs_dim] = next_label;
212 next_label += 1;
213 }
214 for (&lhs_dim, &rhs_dim) in self
215 .lhs_contracting_dims
216 .iter()
217 .zip(self.rhs_contracting_dims)
218 {
219 lhs_labels[lhs_dim] = next_label;
220 rhs_labels[rhs_dim] = next_label;
221 next_label += 1;
222 }
223
224 let expected_out_shape = self.expected_output_shape(lhs_shape, rhs_shape)?;
225 if expected_out_shape.as_slice() != out_shape {
226 return Err(EinsumError::OutputShapeMismatch {
227 expected: expected_out_shape,
228 got: out_shape.to_vec(),
229 });
230 }
231
232 let mut out_labels = SmallVec::<[usize; 8]>::new();
233 out_labels.extend(lhs_free.iter().map(|&dim| lhs_labels[dim]));
234 out_labels.extend(rhs_free.iter().map(|&dim| rhs_labels[dim]));
235 out_labels.extend(self.lhs_batch_dims.iter().map(|&dim| lhs_labels[dim]));
236
237 Ok(DotGeneralLabels {
238 lhs_labels,
239 rhs_labels,
240 out_labels,
241 })
242 }
243}
244
245pub fn dot_general_with_backend_into<T, B>(
247 c: StridedViewMut<T>,
248 a: &StridedView<T>,
249 b: &StridedView<T>,
250 config: &DotGeneralConfig<'_>,
251 alpha: T,
252 beta: T,
253) -> Result<()>
254where
255 T: ScalarBase,
256 B: Backend<T>,
257{
258 let labels = config.labels_for_shapes(a.dims(), b.dims(), c.dims())?;
259 let plan = Einsum2Plan::new(
260 labels.lhs_labels.as_slice(),
261 labels.rhs_labels.as_slice(),
262 labels.out_labels.as_slice(),
263 )?;
264 einsum2_dispatch::<T, B, _>(c, a, b, &plan, alpha, beta, false, false, None)
265}
266
/// Computes a dot-general contraction of `a` and `b` into `c`.
///
/// It uses `crate::backend::ActiveBackend`. This variant is compiled only
/// when one of the `faer`, `blas`, or `blas-inject` features is enabled.
/// It is a thin forwarder to [`dot_general_with_backend_into`].
///
/// # Errors
///
/// Same as [`dot_general_with_backend_into`].
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
pub fn dot_general_into<T>(
    c: StridedViewMut<T>,
    a: &StridedView<T>,
    b: &StridedView<T>,
    config: &DotGeneralConfig<'_>,
    alpha: T,
    beta: T,
) -> Result<()>
where
    T: crate::Scalar,
    crate::backend::ActiveBackend: Backend<T>,
{
    use crate::backend::ActiveBackend;

    dot_general_with_backend_into::<T, ActiveBackend>(c, a, b, config, alpha, beta)
}
282
283#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
285pub fn dot_general_into<T: crate::Scalar>(
286 c: StridedViewMut<T>,
287 a: &StridedView<T>,
288 b: &StridedView<T>,
289 config: &DotGeneralConfig<'_>,
290 alpha: T,
291 beta: T,
292) -> Result<()> {
293 let labels = config.labels_for_shapes(a.dims(), b.dims(), c.dims())?;
294 crate::einsum2_naive_into(
295 c,
296 a,
297 b,
298 labels.out_labels.as_slice(),
299 labels.lhs_labels.as_slice(),
300 labels.rhs_labels.as_slice(),
301 alpha,
302 beta,
303 |x| x,
304 |x| x,
305 )
306}