strided_kernel/
outer_product.rs1use std::ops::Mul;
4
5#[cfg(feature = "parallel")]
6use smallvec::SmallVec;
7
8use crate::map_view::broadcast_mul_into;
9use crate::maybe_sync::MaybeSendSync;
10use crate::view::{StridedView, StridedViewMut};
11use crate::{ElementOp, Result, StridedError};
12
/// Container used for the per-operand axis index lists built in this module.
///
/// Without the `parallel` feature this is a plain heap-allocated `Vec`.
#[cfg(not(feature = "parallel"))]
type AxisVec<T> = Vec<T>;
/// Container used for the per-operand axis index lists built in this module.
///
/// With the `parallel` feature this is a `SmallVec` with inline storage for
/// up to eight elements.
#[cfg(feature = "parallel")]
type AxisVec<T> = SmallVec<[T; 8]>;
17
18pub fn batched_outer_product_into<D, A, B, OpA, OpB>(
25 dest: &mut StridedViewMut<D>,
26 lhs: &StridedView<A, OpA>,
27 rhs: &StridedView<B, OpB>,
28 lhs_free_ndim: usize,
29 rhs_free_ndim: usize,
30) -> Result<()>
31where
32 D: Copy + MaybeSendSync + 'static,
33 A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
34 B: Copy + MaybeSendSync + 'static,
35 OpA: ElementOp<A>,
36 OpB: ElementOp<B>,
37{
38 validate_batched_outer_shape(dest, lhs, rhs, lhs_free_ndim, rhs_free_ndim)?;
39
40 let batch_ndim = lhs.ndim() - lhs_free_ndim;
41 let mut lhs_axes = AxisVec::<usize>::with_capacity(lhs.ndim());
42 let mut rhs_axes = AxisVec::<usize>::with_capacity(rhs.ndim());
43
44 lhs_axes.extend(0..lhs_free_ndim);
45 rhs_axes.extend(lhs_free_ndim..lhs_free_ndim + rhs_free_ndim);
46
47 let batch_axis_start = lhs_free_ndim + rhs_free_ndim;
48 lhs_axes.extend(batch_axis_start..batch_axis_start + batch_ndim);
49 rhs_axes.extend(batch_axis_start..batch_axis_start + batch_ndim);
50
51 broadcast_mul_into(dest, lhs, &lhs_axes, rhs, &rhs_axes)
52}
53
54fn validate_batched_outer_shape<D, A, OpA, B, OpB>(
55 dest: &StridedViewMut<D>,
56 lhs: &StridedView<A, OpA>,
57 rhs: &StridedView<B, OpB>,
58 lhs_free_ndim: usize,
59 rhs_free_ndim: usize,
60) -> Result<()> {
61 if lhs_free_ndim > lhs.ndim() {
62 return Err(StridedError::RankMismatch(lhs_free_ndim, lhs.ndim()));
63 }
64 if rhs_free_ndim > rhs.ndim() {
65 return Err(StridedError::RankMismatch(rhs_free_ndim, rhs.ndim()));
66 }
67
68 let lhs_batch_ndim = lhs.ndim() - lhs_free_ndim;
69 let rhs_batch_ndim = rhs.ndim() - rhs_free_ndim;
70 if lhs_batch_ndim != rhs_batch_ndim {
71 return Err(StridedError::RankMismatch(lhs_batch_ndim, rhs_batch_ndim));
72 }
73
74 let expected_dest_rank = lhs_free_ndim + rhs_free_ndim + lhs_batch_ndim;
75 if dest.ndim() != expected_dest_rank {
76 return Err(StridedError::RankMismatch(dest.ndim(), expected_dest_rank));
77 }
78
79 ensure_dims(&dest.dims()[..lhs_free_ndim], &lhs.dims()[..lhs_free_ndim])?;
80 ensure_dims(
81 &dest.dims()[lhs_free_ndim..lhs_free_ndim + rhs_free_ndim],
82 &rhs.dims()[..rhs_free_ndim],
83 )?;
84 ensure_dims(
85 &dest.dims()[lhs_free_ndim + rhs_free_ndim..],
86 &lhs.dims()[lhs_free_ndim..],
87 )?;
88 ensure_dims(
89 &dest.dims()[lhs_free_ndim + rhs_free_ndim..],
90 &rhs.dims()[rhs_free_ndim..],
91 )?;
92
93 Ok(())
94}
95
96fn ensure_dims(actual: &[usize], expected: &[usize]) -> Result<()> {
97 if actual == expected {
98 Ok(())
99 } else {
100 Err(StridedError::ShapeMismatch(
101 actual.to_vec(),
102 expected.to_vec(),
103 ))
104 }
105}