Skip to main content

strided_kernel/
outer_product.rs

//! Semantic outer-product API on dynamic-rank strided views.
2
3use std::ops::Mul;
4
5#[cfg(feature = "parallel")]
6use smallvec::SmallVec;
7
8use crate::map_view::broadcast_mul_into;
9use crate::maybe_sync::MaybeSendSync;
10use crate::view::{StridedView, StridedViewMut};
11use crate::{ElementOp, Result, StridedError};
12
// Scratch storage for the per-operand axis maps handed to the broadcast/mul
// planner. Without the `parallel` feature this is a plain `Vec`. With it, the
// `smallvec` import (gated on the same feature) is available, so short axis
// lists use a `SmallVec` with eight inline slots instead.
#[cfg(not(feature = "parallel"))]
type AxisVec<T> = Vec<T>;
#[cfg(feature = "parallel")]
type AxisVec<T> = SmallVec<[T; 8]>;
17
18/// Compute `dest[lhs_free..., rhs_free..., batch...] =
19/// lhs[lhs_free..., batch...] * rhs[rhs_free..., batch...]`.
20///
21/// This is a semantic convenience wrapper over [`broadcast_mul_into`]. The
22/// broadcast/mul planner owns kernel selection, so explicit outer-product calls
23/// and equivalent broadcasted multiplication use the same implementation path.
24pub fn batched_outer_product_into<D, A, B, OpA, OpB>(
25    dest: &mut StridedViewMut<D>,
26    lhs: &StridedView<A, OpA>,
27    rhs: &StridedView<B, OpB>,
28    lhs_free_ndim: usize,
29    rhs_free_ndim: usize,
30) -> Result<()>
31where
32    D: Copy + MaybeSendSync + 'static,
33    A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
34    B: Copy + MaybeSendSync + 'static,
35    OpA: ElementOp<A>,
36    OpB: ElementOp<B>,
37{
38    validate_batched_outer_shape(dest, lhs, rhs, lhs_free_ndim, rhs_free_ndim)?;
39
40    let batch_ndim = lhs.ndim() - lhs_free_ndim;
41    let mut lhs_axes = AxisVec::<usize>::with_capacity(lhs.ndim());
42    let mut rhs_axes = AxisVec::<usize>::with_capacity(rhs.ndim());
43
44    lhs_axes.extend(0..lhs_free_ndim);
45    rhs_axes.extend(lhs_free_ndim..lhs_free_ndim + rhs_free_ndim);
46
47    let batch_axis_start = lhs_free_ndim + rhs_free_ndim;
48    lhs_axes.extend(batch_axis_start..batch_axis_start + batch_ndim);
49    rhs_axes.extend(batch_axis_start..batch_axis_start + batch_ndim);
50
51    broadcast_mul_into(dest, lhs, &lhs_axes, rhs, &rhs_axes)
52}
53
54fn validate_batched_outer_shape<D, A, OpA, B, OpB>(
55    dest: &StridedViewMut<D>,
56    lhs: &StridedView<A, OpA>,
57    rhs: &StridedView<B, OpB>,
58    lhs_free_ndim: usize,
59    rhs_free_ndim: usize,
60) -> Result<()> {
61    if lhs_free_ndim > lhs.ndim() {
62        return Err(StridedError::RankMismatch(lhs_free_ndim, lhs.ndim()));
63    }
64    if rhs_free_ndim > rhs.ndim() {
65        return Err(StridedError::RankMismatch(rhs_free_ndim, rhs.ndim()));
66    }
67
68    let lhs_batch_ndim = lhs.ndim() - lhs_free_ndim;
69    let rhs_batch_ndim = rhs.ndim() - rhs_free_ndim;
70    if lhs_batch_ndim != rhs_batch_ndim {
71        return Err(StridedError::RankMismatch(lhs_batch_ndim, rhs_batch_ndim));
72    }
73
74    let expected_dest_rank = lhs_free_ndim + rhs_free_ndim + lhs_batch_ndim;
75    if dest.ndim() != expected_dest_rank {
76        return Err(StridedError::RankMismatch(dest.ndim(), expected_dest_rank));
77    }
78
79    ensure_dims(&dest.dims()[..lhs_free_ndim], &lhs.dims()[..lhs_free_ndim])?;
80    ensure_dims(
81        &dest.dims()[lhs_free_ndim..lhs_free_ndim + rhs_free_ndim],
82        &rhs.dims()[..rhs_free_ndim],
83    )?;
84    ensure_dims(
85        &dest.dims()[lhs_free_ndim + rhs_free_ndim..],
86        &lhs.dims()[lhs_free_ndim..],
87    )?;
88    ensure_dims(
89        &dest.dims()[lhs_free_ndim + rhs_free_ndim..],
90        &rhs.dims()[rhs_free_ndim..],
91    )?;
92
93    Ok(())
94}
95
96fn ensure_dims(actual: &[usize], expected: &[usize]) -> Result<()> {
97    if actual == expected {
98        Ok(())
99    } else {
100        Err(StridedError::ShapeMismatch(
101            actual.to_vec(),
102            expected.to_vec(),
103        ))
104    }
105}