Skip to main content

strided_einsum2/
dot_general.rs

1//! Axis-based general dot product API.
2//!
3//! The dimension-numbering config borrows axis slices so callers can reuse
4//! their existing metadata without allocating a second owned config.
5//!
6//! ```
7//! use strided_einsum2::{dot_general_into, DotGeneralConfig};
8//! use strided_view::StridedArray;
9//!
10//! let a = StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| {
11//!     (idx[0] + 2 * idx[1] + 1) as f64
12//! });
13//! let b = StridedArray::<f64>::from_fn_col_major(&[3, 2], |idx| {
14//!     (idx[0] + 3 * idx[1] + 1) as f64
15//! });
16//! let mut c = StridedArray::<f64>::col_major(&[2, 2]);
17//!
18//! let config = DotGeneralConfig {
19//!     lhs_contracting_dims: &[1],
20//!     rhs_contracting_dims: &[0],
21//!     lhs_batch_dims: &[],
22//!     rhs_batch_dims: &[],
23//! };
24//! dot_general_into(c.view_mut(), &a.view(), &b.view(), &config, 1.0, 0.0).unwrap();
25//! ```
26
27use smallvec::SmallVec;
28use strided_view::{StridedView, StridedViewMut};
29
30use crate::backend::Backend;
31use crate::{einsum2_dispatch, Einsum2Plan, EinsumError, Result, ScalarBase};
32
/// Dimension numbers describing a general dot product between two operands.
///
/// Every field borrows a slice of axis indices; nothing is owned, so the
/// config is `Copy` and cheap to pass around.
///
/// The output shape is `[lhs_free..., rhs_free..., batch...]`, matching
/// tenferro's batch-trailing col-major convention.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DotGeneralConfig<'a> {
    /// Axes of the left operand that are contracted. Entry `i` is paired with
    /// entry `i` of `rhs_contracting_dims`.
    pub lhs_contracting_dims: &'a [usize],
    /// Axes of the right operand that are contracted. Entry `i` is paired with
    /// entry `i` of `lhs_contracting_dims`.
    pub rhs_contracting_dims: &'a [usize],
    /// Batch axes of the left operand. Entry `i` is paired with entry `i` of
    /// `rhs_batch_dims`; batch axes are kept and placed last in the output.
    pub lhs_batch_dims: &'a [usize],
    /// Batch axes of the right operand. Entry `i` is paired with entry `i` of
    /// `lhs_batch_dims`.
    pub rhs_batch_dims: &'a [usize],
}
44
45#[derive(Clone, Debug)]
46pub(crate) struct DotGeneralLabels {
47    pub lhs_labels: SmallVec<[usize; 8]>,
48    pub rhs_labels: SmallVec<[usize; 8]>,
49    pub out_labels: SmallVec<[usize; 8]>,
50}
51
52fn check_no_duplicates(dims: &[usize], label: &str) -> Result<()> {
53    for (i, &dim) in dims.iter().enumerate() {
54        if dims[..i].contains(&dim) {
55            return Err(EinsumError::InvalidDotGeneralConfig(format!(
56                "{label} contains duplicate dim {dim}"
57            )));
58        }
59    }
60    Ok(())
61}
62
63fn check_bounds(dims: &[usize], rank: usize, label: &str) -> Result<()> {
64    for &dim in dims {
65        if dim >= rank {
66            return Err(EinsumError::InvalidDotGeneralConfig(format!(
67                "{label} dim {dim} out of bounds for rank {rank}"
68            )));
69        }
70    }
71    Ok(())
72}
73
74fn free_dims(rank: usize, contracting: &[usize], batch: &[usize]) -> SmallVec<[usize; 8]> {
75    (0..rank)
76        .filter(|dim| !contracting.contains(dim) && !batch.contains(dim))
77        .collect()
78}
79
80impl DotGeneralConfig<'_> {
81    /// Validate dimension indices for explicit operand ranks.
82    pub fn validate_dims_with_ranks(&self, lhs_rank: usize, rhs_rank: usize) -> Result<()> {
83        check_bounds(&self.lhs_contracting_dims, lhs_rank, "lhs_contracting")?;
84        check_bounds(&self.rhs_contracting_dims, rhs_rank, "rhs_contracting")?;
85        check_bounds(&self.lhs_batch_dims, lhs_rank, "lhs_batch")?;
86        check_bounds(&self.rhs_batch_dims, rhs_rank, "rhs_batch")?;
87        check_no_duplicates(&self.lhs_contracting_dims, "lhs_contracting_dims")?;
88        check_no_duplicates(&self.rhs_contracting_dims, "rhs_contracting_dims")?;
89        check_no_duplicates(&self.lhs_batch_dims, "lhs_batch_dims")?;
90        check_no_duplicates(&self.rhs_batch_dims, "rhs_batch_dims")?;
91
92        for &dim in self.lhs_contracting_dims {
93            if self.lhs_batch_dims.contains(&dim) {
94                return Err(EinsumError::InvalidDotGeneralConfig(format!(
95                    "lhs dim {dim} appears in both contracting and batch dims"
96                )));
97            }
98        }
99        for &dim in self.rhs_contracting_dims {
100            if self.rhs_batch_dims.contains(&dim) {
101                return Err(EinsumError::InvalidDotGeneralConfig(format!(
102                    "rhs dim {dim} appears in both contracting and batch dims"
103                )));
104            }
105        }
106        if self.lhs_contracting_dims.len() != self.rhs_contracting_dims.len() {
107            return Err(EinsumError::InvalidDotGeneralConfig(format!(
108                "lhs/rhs contracting dim counts differ ({} vs {})",
109                self.lhs_contracting_dims.len(),
110                self.rhs_contracting_dims.len()
111            )));
112        }
113        if self.lhs_batch_dims.len() != self.rhs_batch_dims.len() {
114            return Err(EinsumError::InvalidDotGeneralConfig(format!(
115                "lhs/rhs batch dim counts differ ({} vs {})",
116                self.lhs_batch_dims.len(),
117                self.rhs_batch_dims.len()
118            )));
119        }
120        Ok(())
121    }
122
123    /// Compute the output shape `[lhs_free..., rhs_free..., batch...]`.
124    pub fn expected_output_shape(
125        &self,
126        lhs_shape: &[usize],
127        rhs_shape: &[usize],
128    ) -> Result<Vec<usize>> {
129        self.validate_dims_with_ranks(lhs_shape.len(), rhs_shape.len())?;
130        self.validate_pair_shapes(lhs_shape, rhs_shape)?;
131
132        let lhs_free = free_dims(
133            lhs_shape.len(),
134            &self.lhs_contracting_dims,
135            &self.lhs_batch_dims,
136        );
137        let rhs_free = free_dims(
138            rhs_shape.len(),
139            &self.rhs_contracting_dims,
140            &self.rhs_batch_dims,
141        );
142
143        let mut out =
144            Vec::with_capacity(lhs_free.len() + rhs_free.len() + self.lhs_batch_dims.len());
145        out.extend(lhs_free.iter().map(|&dim| lhs_shape[dim]));
146        out.extend(rhs_free.iter().map(|&dim| rhs_shape[dim]));
147        out.extend(self.lhs_batch_dims.iter().map(|&dim| lhs_shape[dim]));
148        Ok(out)
149    }
150
151    fn validate_pair_shapes(&self, lhs_shape: &[usize], rhs_shape: &[usize]) -> Result<()> {
152        for (&lhs_dim, &rhs_dim) in self.lhs_batch_dims.iter().zip(self.rhs_batch_dims) {
153            if lhs_shape[lhs_dim] != rhs_shape[rhs_dim] {
154                return Err(EinsumError::DimensionMismatch {
155                    axis: format!("batch lhs {lhs_dim} rhs {rhs_dim}"),
156                    dim_a: lhs_shape[lhs_dim],
157                    dim_b: rhs_shape[rhs_dim],
158                });
159            }
160        }
161        for (&lhs_dim, &rhs_dim) in self
162            .lhs_contracting_dims
163            .iter()
164            .zip(self.rhs_contracting_dims)
165        {
166            if lhs_shape[lhs_dim] != rhs_shape[rhs_dim] {
167                return Err(EinsumError::DimensionMismatch {
168                    axis: format!("contract lhs {lhs_dim} rhs {rhs_dim}"),
169                    dim_a: lhs_shape[lhs_dim],
170                    dim_b: rhs_shape[rhs_dim],
171                });
172            }
173        }
174        Ok(())
175    }
176
177    pub(crate) fn labels_for_shapes(
178        &self,
179        lhs_shape: &[usize],
180        rhs_shape: &[usize],
181        out_shape: &[usize],
182    ) -> Result<DotGeneralLabels> {
183        self.validate_dims_with_ranks(lhs_shape.len(), rhs_shape.len())?;
184        self.validate_pair_shapes(lhs_shape, rhs_shape)?;
185
186        let lhs_free = free_dims(
187            lhs_shape.len(),
188            &self.lhs_contracting_dims,
189            &self.lhs_batch_dims,
190        );
191        let rhs_free = free_dims(
192            rhs_shape.len(),
193            &self.rhs_contracting_dims,
194            &self.rhs_batch_dims,
195        );
196
197        let mut lhs_labels = smallvec::smallvec![usize::MAX; lhs_shape.len()];
198        let mut rhs_labels = smallvec::smallvec![usize::MAX; rhs_shape.len()];
199        let mut next_label = 0usize;
200
201        for (&lhs_dim, &rhs_dim) in self.lhs_batch_dims.iter().zip(self.rhs_batch_dims) {
202            lhs_labels[lhs_dim] = next_label;
203            rhs_labels[rhs_dim] = next_label;
204            next_label += 1;
205        }
206        for &lhs_dim in &lhs_free {
207            lhs_labels[lhs_dim] = next_label;
208            next_label += 1;
209        }
210        for &rhs_dim in &rhs_free {
211            rhs_labels[rhs_dim] = next_label;
212            next_label += 1;
213        }
214        for (&lhs_dim, &rhs_dim) in self
215            .lhs_contracting_dims
216            .iter()
217            .zip(self.rhs_contracting_dims)
218        {
219            lhs_labels[lhs_dim] = next_label;
220            rhs_labels[rhs_dim] = next_label;
221            next_label += 1;
222        }
223
224        let expected_out_shape = self.expected_output_shape(lhs_shape, rhs_shape)?;
225        if expected_out_shape.as_slice() != out_shape {
226            return Err(EinsumError::OutputShapeMismatch {
227                expected: expected_out_shape,
228                got: out_shape.to_vec(),
229            });
230        }
231
232        let mut out_labels = SmallVec::<[usize; 8]>::new();
233        out_labels.extend(lhs_free.iter().map(|&dim| lhs_labels[dim]));
234        out_labels.extend(rhs_free.iter().map(|&dim| rhs_labels[dim]));
235        out_labels.extend(self.lhs_batch_dims.iter().map(|&dim| lhs_labels[dim]));
236
237        Ok(DotGeneralLabels {
238            lhs_labels,
239            rhs_labels,
240            out_labels,
241        })
242    }
243}
244
245/// Compute `C = alpha * dot_general(A, B) + beta * C` with an explicit backend.
246pub fn dot_general_with_backend_into<T, B>(
247    c: StridedViewMut<T>,
248    a: &StridedView<T>,
249    b: &StridedView<T>,
250    config: &DotGeneralConfig<'_>,
251    alpha: T,
252    beta: T,
253) -> Result<()>
254where
255    T: ScalarBase,
256    B: Backend<T>,
257{
258    let labels = config.labels_for_shapes(a.dims(), b.dims(), c.dims())?;
259    let plan = Einsum2Plan::new(
260        labels.lhs_labels.as_slice(),
261        labels.rhs_labels.as_slice(),
262        labels.out_labels.as_slice(),
263    )?;
264    einsum2_dispatch::<T, B, _>(c, a, b, &plan, alpha, beta, false, false, None)
265}
266
/// Compute `C = alpha * dot_general(A, B) + beta * C` with the active backend.
///
/// Compiled when any of the `faer`, `blas` or `blas-inject` features is
/// enabled; forwards to [`dot_general_with_backend_into`] with
/// `crate::backend::ActiveBackend` as the backend.
///
/// # Errors
///
/// Returns whatever [`dot_general_with_backend_into`] returns.
#[cfg(any(feature = "faer", feature = "blas", feature = "blas-inject"))]
pub fn dot_general_into<T: crate::Scalar>(
    c: StridedViewMut<T>,
    a: &StridedView<T>,
    b: &StridedView<T>,
    config: &DotGeneralConfig<'_>,
    alpha: T,
    beta: T,
) -> Result<()>
where
    crate::backend::ActiveBackend: Backend<T>,
{
    type Active = crate::backend::ActiveBackend;
    dot_general_with_backend_into::<T, Active>(c, a, b, config, alpha, beta)
}
282
283/// Compute `C = alpha * dot_general(A, B) + beta * C` with the naive fallback.
284#[cfg(not(any(feature = "faer", feature = "blas", feature = "blas-inject")))]
285pub fn dot_general_into<T: crate::Scalar>(
286    c: StridedViewMut<T>,
287    a: &StridedView<T>,
288    b: &StridedView<T>,
289    config: &DotGeneralConfig<'_>,
290    alpha: T,
291    beta: T,
292) -> Result<()> {
293    let labels = config.labels_for_shapes(a.dims(), b.dims(), c.dims())?;
294    crate::einsum2_naive_into(
295        c,
296        a,
297        b,
298        labels.out_labels.as_slice(),
299        labels.lhs_labels.as_slice(),
300        labels.rhs_labels.as_slice(),
301        alpha,
302        beta,
303        |x| x,
304        |x| x,
305    )
306}