/// Runs `f` and returns whatever it returns.
///
/// With the `simd` feature enabled the closure is handed to
/// `pulp::Arch::new().dispatch`; without it the closure is simply called.
#[inline(always)]
pub(crate) fn dispatch<R>(f: impl FnOnce() -> R) -> R {
    // Exactly one of the two `run` helpers survives cfg-stripping.
    #[cfg(feature = "simd")]
    #[inline(always)]
    fn run<R>(op: impl FnOnce() -> R) -> R {
        pulp::Arch::new().dispatch(op)
    }

    #[cfg(not(feature = "simd"))]
    #[inline(always)]
    fn run<R>(op: impl FnOnce() -> R) -> R {
        op()
    }

    run(f)
}
12
13#[inline(always)]
14pub(crate) fn dispatch_if_large<R>(len: usize, f: impl FnOnce() -> R) -> R {
15 if len >= 64 {
18 dispatch(f)
19 } else {
20 f()
21 }
22}
23
/// Reinterprets a shared slice of `T` as a slice of `U` with the same length.
///
/// # Safety
///
/// The caller must guarantee that every element of `src` is a valid `U`.
/// `T` and `U` must have the same size and `src` must be aligned for `U`;
/// both conditions are only checked in debug builds.
#[cfg(feature = "simd")]
#[inline(always)]
unsafe fn cast_slice<T, U>(src: &[T]) -> &[U] {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    // `from_raw_parts` requires alignment for `U` as well as matching size.
    debug_assert_eq!(src.as_ptr() as usize % std::mem::align_of::<U>(), 0);

    // SAFETY: pointer and length come from a live slice, so the memory is
    // readable for `src.len()` elements of the (equal) element size; validity
    // of the bytes as `U` is the caller's obligation.
    unsafe { std::slice::from_raw_parts(src.as_ptr().cast::<U>(), src.len()) }
}
30
/// Reinterprets an exclusive slice of `T` as an exclusive slice of `U` with
/// the same length.
///
/// # Safety
///
/// The caller must guarantee that every element of `src` is a valid `U` and
/// that any `U` written through the result is a valid `T`. `T` and `U` must
/// have the same size and `src` must be aligned for `U`; both conditions are
/// only checked in debug builds.
#[cfg(feature = "simd")]
#[inline(always)]
unsafe fn cast_slice_mut<T, U>(src: &mut [T]) -> &mut [U] {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    // `from_raw_parts_mut` requires alignment for `U` as well as matching size.
    debug_assert_eq!(src.as_ptr() as usize % std::mem::align_of::<U>(), 0);

    // SAFETY: pointer and length come from a live exclusive slice, so the
    // memory is writable for `src.len()` elements of the (equal) element size
    // and is not aliased; type validity is the caller's obligation.
    unsafe { std::slice::from_raw_parts_mut(src.as_mut_ptr().cast::<U>(), src.len()) }
}
37
// Generates `fn $mul_into(dst, a, b)` computing `dst[i] = a[i] * b[i]`.
//
// The slices are walked in chunks of `S::$lanes` elements; every chunk,
// including a final shorter one, goes through `$load` / `$mul` / `$store`.
// Equal slice lengths are only checked in debug builds.
#[cfg(feature = "simd")]
macro_rules! impl_simd_mul_partial {
    (
        $mul_into:ident,
        $ty:ty,
        $lanes:ident,
        $load:ident,
        $store:ident,
        $mul:ident
    ) => {
        fn $mul_into(dst: &mut [$ty], a: &[$ty], b: &[$ty]) {
            struct Kernel<'s> {
                out: &'s mut [$ty],
                lhs: &'s [$ty],
                rhs: &'s [$ty],
            }

            impl pulp::WithSimd for Kernel<'_> {
                type Output = ();

                #[inline(always)]
                fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
                    let Kernel { out, lhs, rhs } = self;
                    debug_assert_eq!(out.len(), lhs.len());
                    debug_assert_eq!(out.len(), rhs.len());

                    let width = S::$lanes;
                    let total = out.len();
                    // End of the last chunk that holds a full `width` elements.
                    let full_end = total - total % width;

                    for start in (0..full_end).step_by(width) {
                        let stop = start + width;
                        let x = simd.$load(&lhs[start..stop]);
                        let y = simd.$load(&rhs[start..stop]);
                        let product = simd.$mul(x, y);
                        simd.$store(&mut out[start..stop], product);
                    }

                    // Fewer than `width` elements remain.
                    if full_end < total {
                        let x = simd.$load(&lhs[full_end..]);
                        let y = simd.$load(&rhs[full_end..]);
                        let product = simd.$mul(x, y);
                        simd.$store(&mut out[full_end..], product);
                    }
                }
            }

            pulp::Arch::new().dispatch(Kernel {
                out: dst,
                lhs: a,
                rhs: b,
            });
        }
    };
}
83
// Generates `fn $mul_into(dst, a, b)` computing `dst[i] = a[i] * b[i]`.
//
// Each slice is split by `S::$as_simd` / `S::$as_mut_simd` into a run of
// whole SIMD vectors plus a scalar remainder. The vector run is multiplied
// vector by vector; the remainder, if any, goes through `$load` / `$store`.
// Equal slice lengths are only checked in debug builds.
#[cfg(feature = "simd")]
macro_rules! impl_simd_mul_body_tail {
    (
        $mul_into:ident,
        $ty:ty,
        $as_simd:ident,
        $as_mut_simd:ident,
        $load:ident,
        $store:ident,
        $mul:ident
    ) => {
        fn $mul_into(dst: &mut [$ty], a: &[$ty], b: &[$ty]) {
            struct Kernel<'s> {
                out: &'s mut [$ty],
                lhs: &'s [$ty],
                rhs: &'s [$ty],
            }

            impl pulp::WithSimd for Kernel<'_> {
                type Output = ();

                #[inline(always)]
                fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
                    let Kernel { out, lhs, rhs } = self;
                    debug_assert_eq!(out.len(), lhs.len());
                    debug_assert_eq!(out.len(), rhs.len());

                    let (out_vecs, out_rest) = S::$as_mut_simd(out);
                    let (lhs_vecs, lhs_rest) = S::$as_simd(lhs);
                    let (rhs_vecs, rhs_rest) = S::$as_simd(rhs);
                    debug_assert_eq!(out_vecs.len(), lhs_vecs.len());
                    debug_assert_eq!(out_vecs.len(), rhs_vecs.len());
                    debug_assert_eq!(out_rest.len(), lhs_rest.len());
                    debug_assert_eq!(out_rest.len(), rhs_rest.len());

                    // Whole vectors: one multiply per SIMD register.
                    for (lane, slot) in out_vecs.iter_mut().enumerate() {
                        *slot = simd.$mul(lhs_vecs[lane], rhs_vecs[lane]);
                    }

                    // Scalar remainder shorter than one vector.
                    if !out_rest.is_empty() {
                        let product = simd.$mul(simd.$load(lhs_rest), simd.$load(rhs_rest));
                        simd.$store(out_rest, product);
                    }
                }
            }

            pulp::Arch::new().dispatch(Kernel {
                out: dst,
                lhs: a,
                rhs: b,
            });
        }
    };
}
133
// Defines `simd_mul_f32_into(dst, a, b)`: element-wise `f32` product built
// from the chunked partial-load kernel, `S::F32_LANES` elements per step.
#[cfg(feature = "simd")]
impl_simd_mul_partial!(
    simd_mul_f32_into,
    f32,
    F32_LANES,
    partial_load_f32s,
    partial_store_f32s,
    mul_f32s
);
143
// Defines `simd_mul_f64_into(dst, a, b)`: element-wise `f64` product built
// from the chunked partial-load kernel, `S::F64_LANES` elements per step.
#[cfg(feature = "simd")]
impl_simd_mul_partial!(
    simd_mul_f64_into,
    f64,
    F64_LANES,
    partial_load_f64s,
    partial_store_f64s,
    mul_f64s
);
153
// Defines `simd_mul_c32_into(dst, a, b)`: element-wise `Complex32` product
// built from the vector-run + remainder kernel, multiplying with `mul_e_c32s`.
// NOTE(review): exact rounding behaviour of the `mul_e_*` variant is defined
// by pulp — confirm against its docs if bit-exact results matter.
#[cfg(feature = "simd")]
impl_simd_mul_body_tail!(
    simd_mul_c32_into,
    num_complex::Complex32,
    as_simd_c32s,
    as_mut_simd_c32s,
    partial_load_c32s,
    partial_store_c32s,
    mul_e_c32s
);
164
// Defines `simd_mul_c64_into(dst, a, b)`: element-wise `Complex64` product
// built from the vector-run + remainder kernel, multiplying with `mul_e_c64s`.
// NOTE(review): exact rounding behaviour of the `mul_e_*` variant is defined
// by pulp — confirm against its docs if bit-exact results matter.
#[cfg(feature = "simd")]
impl_simd_mul_body_tail!(
    simd_mul_c64_into,
    num_complex::Complex64,
    as_simd_c64s,
    as_mut_simd_c64s,
    partial_load_c64s,
    partial_store_c64s,
    mul_e_c64s
);
175
/// Tries to compute `dst[i] = a[i] * b[i]` with a SIMD kernel.
///
/// A kernel exists only when `D`, `A` and `B` are all the same type and that
/// type is one of `f64`, `f32`, `Complex64` or `Complex32`. Returns `true` if
/// a kernel ran and `false` (leaving `dst` untouched) otherwise.
///
/// The kernels expect all three slices to have the same length; they only
/// verify this with debug assertions.
#[cfg(feature = "simd")]
#[inline]
pub(crate) fn try_mul_contiguous<D: 'static, A: 'static, B: 'static>(
    dst: &mut [D],
    a: &[A],
    b: &[B],
) -> bool {
    use std::any::TypeId;

    let operand_ids = [TypeId::of::<D>(), TypeId::of::<A>(), TypeId::of::<B>()];

    macro_rules! run_if_all_are {
        ($($elem:ty => $kernel:ident),+ $(,)?) => {$(
            if operand_ids == [TypeId::of::<$elem>(); 3] {
                // SAFETY: the `TypeId` comparison proves `D`, `A` and `B` are
                // all exactly `$elem`, so each cast is an identity cast.
                let (out, lhs, rhs) =
                    unsafe { (cast_slice_mut(dst), cast_slice(a), cast_slice(b)) };
                $kernel(out, lhs, rhs);
                return true;
            }
        )+};
    }

    run_if_all_are!(
        f64 => simd_mul_f64_into,
        f32 => simd_mul_f32_into,
        num_complex::Complex64 => simd_mul_c64_into,
        num_complex::Complex32 => simd_mul_c32_into,
    );

    false
}
204
/// Fallback used when the `simd` feature is disabled.
///
/// Never touches its arguments and always returns `false`, i.e. no SIMD
/// kernel was run.
#[cfg(not(feature = "simd"))]
#[inline]
pub(crate) fn try_mul_contiguous<D, A, B>(_dst: &mut [D], _a: &[A], _b: &[B]) -> bool
where
    D: 'static,
    A: 'static,
    B: 'static,
{
    false
}
214
/// Edge length, in elements, of the square tiles that the
/// `*_source_contiguous` transposing kernels step through.
#[cfg(feature = "parallel")]
const TRANSPOSE_TILE: usize = 8;
217
/// Writes `src * scalar` into a dense `row_len` x `inner_len` row-major `dst`,
/// visiting the elements in `TRANSPOSE_TILE` x `TRANSPOSE_TILE` tiles.
///
/// Output element `(row, inner)` is stored at `dst + row * inner_len + inner`
/// and read from `src + inner * src_fast_stride + row * src_row_stride`
/// (all offsets in elements).
///
/// # Safety
///
/// For every `row < row_len` and `inner < inner_len` the source address above
/// must be valid for reads and the destination address valid for writes, and
/// each intermediate pointer offset must stay inside its allocation.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_rhs_source_contiguous<T>(
    dst: *mut T,
    src: *const T,
    scalar: T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    let mut tile_inner = 0usize;
    while tile_inner < inner_len {
        // The last tile along each axis may be narrower than a full tile.
        let inner_end = tile_inner + TRANSPOSE_TILE.min(inner_len - tile_inner);

        let mut tile_row = 0usize;
        while tile_row < row_len {
            let row_end = tile_row + TRANSPOSE_TILE.min(row_len - tile_row);

            for inner in tile_inner..inner_end {
                // SAFETY: `inner < inner_len`, so this offset is covered by
                // the caller's contract.
                let column = unsafe { src.offset(inner as isize * src_fast_stride) };
                for row in tile_row..row_end {
                    // SAFETY: `row < row_len` and `inner < inner_len`; both
                    // addresses are guaranteed valid by the caller.
                    unsafe {
                        let value = column.offset(row as isize * src_row_stride).read();
                        dst.add(row * inner_len + inner).write(value * scalar);
                    }
                }
            }

            tile_row += TRANSPOSE_TILE;
        }

        tile_inner += TRANSPOSE_TILE;
    }
}
255
/// Writes `src * scalar` into a dense `row_len` x `inner_len` row-major `dst`,
/// filling one destination row after another.
///
/// Output element `(row, inner)` is stored at `dst + row * inner_len + inner`
/// and read from `src + row * src_row_stride + inner * src_fast_stride`
/// (all offsets in elements).
///
/// # Safety
///
/// For every `row < row_len` and `inner < inner_len` the source address above
/// must be valid for reads and the destination address valid for writes, and
/// each intermediate pointer offset must stay inside its allocation.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_rhs_dst_contiguous<T>(
    dst: *mut T,
    src: *const T,
    scalar: T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    for row in 0..row_len {
        // SAFETY: `row < row_len`, so both row starts are covered by the
        // caller's contract.
        let (out_row, src_row) = unsafe {
            (
                dst.add(row * inner_len),
                src.offset(row as isize * src_row_stride),
            )
        };

        for inner in 0..inner_len {
            // SAFETY: `inner < inner_len`; both addresses are guaranteed
            // valid by the caller.
            unsafe {
                let value = src_row.offset(inner as isize * src_fast_stride).read();
                out_row.add(inner).write(value * scalar);
            }
        }
    }
}
276
/// Writes `scalar * src` into a dense `row_len` x `inner_len` row-major `dst`,
/// visiting the elements in `TRANSPOSE_TILE` x `TRANSPOSE_TILE` tiles.
///
/// Output element `(row, inner)` is stored at `dst + row * inner_len + inner`
/// and read from `src + inner * src_fast_stride + row * src_row_stride`
/// (all offsets in elements). The scalar is the left operand of `*`.
///
/// # Safety
///
/// For every `row < row_len` and `inner < inner_len` the source address above
/// must be valid for reads and the destination address valid for writes, and
/// each intermediate pointer offset must stay inside its allocation.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_lhs_source_contiguous<T>(
    dst: *mut T,
    scalar: T,
    src: *const T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    let mut tile_inner = 0usize;
    while tile_inner < inner_len {
        // The last tile along each axis may be narrower than a full tile.
        let inner_end = tile_inner + TRANSPOSE_TILE.min(inner_len - tile_inner);

        let mut tile_row = 0usize;
        while tile_row < row_len {
            let row_end = tile_row + TRANSPOSE_TILE.min(row_len - tile_row);

            for inner in tile_inner..inner_end {
                // SAFETY: `inner < inner_len`, so this offset is covered by
                // the caller's contract.
                let column = unsafe { src.offset(inner as isize * src_fast_stride) };
                for row in tile_row..row_end {
                    // SAFETY: `row < row_len` and `inner < inner_len`; both
                    // addresses are guaranteed valid by the caller.
                    unsafe {
                        let value = column.offset(row as isize * src_row_stride).read();
                        dst.add(row * inner_len + inner).write(scalar * value);
                    }
                }
            }

            tile_row += TRANSPOSE_TILE;
        }

        tile_inner += TRANSPOSE_TILE;
    }
}
314
/// Writes `scalar * src` into a dense `row_len` x `inner_len` row-major `dst`,
/// filling one destination row after another.
///
/// Output element `(row, inner)` is stored at `dst + row * inner_len + inner`
/// and read from `src + row * src_row_stride + inner * src_fast_stride`
/// (all offsets in elements). The scalar is the left operand of `*`.
///
/// # Safety
///
/// For every `row < row_len` and `inner < inner_len` the source address above
/// must be valid for reads and the destination address valid for writes, and
/// each intermediate pointer offset must stay inside its allocation.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_lhs_dst_contiguous<T>(
    dst: *mut T,
    scalar: T,
    src: *const T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    for row in 0..row_len {
        // SAFETY: `row < row_len`, so both row starts are covered by the
        // caller's contract.
        let (out_row, src_row) = unsafe {
            (
                dst.add(row * inner_len),
                src.offset(row as isize * src_row_stride),
            )
        };

        for inner in 0..inner_len {
            // SAFETY: `inner < inner_len`; both addresses are guaranteed
            // valid by the caller.
            unsafe {
                let value = src_row.offset(inner as isize * src_fast_stride).read();
                out_row.add(inner).write(scalar * value);
            }
        }
    }
}
335
/// Computes `dst = src * scalar` for a strided 2-D source and a dense
/// `row_len` x `inner_len` row-major destination.
///
/// Uses the tiled kernel when `row_len > inner_len` and the row-by-row kernel
/// otherwise; both receive the arguments unchanged.
///
/// # Safety
///
/// Same contract as the two kernels: for every `row < row_len` and
/// `inner < inner_len`, `src + inner * src_fast_stride + row * src_row_stride`
/// must be readable and `dst + row * inner_len + inner` writable.
#[cfg(feature = "parallel")]
#[inline(always)]
pub(crate) unsafe fn mul_transposed_scalar_rhs_2d_typed<T>(
    dst: *mut T,
    src: *const T,
    scalar: T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    // SAFETY: both kernels have exactly this function's contract, which the
    // caller upholds; the arguments are forwarded untouched.
    unsafe {
        if row_len <= inner_len {
            mul_transposed_scalar_rhs_dst_contiguous(
                dst,
                src,
                scalar,
                inner_len,
                row_len,
                src_fast_stride,
                src_row_stride,
            );
        } else {
            mul_transposed_scalar_rhs_source_contiguous(
                dst,
                src,
                scalar,
                inner_len,
                row_len,
                src_fast_stride,
                src_row_stride,
            );
        }
    }
}
375
/// Computes `dst = scalar * src` for a strided 2-D source and a dense
/// `row_len` x `inner_len` row-major destination.
///
/// Uses the tiled kernel when `row_len > inner_len` and the row-by-row kernel
/// otherwise; both receive the arguments unchanged.
///
/// # Safety
///
/// Same contract as the two kernels: for every `row < row_len` and
/// `inner < inner_len`, `src + inner * src_fast_stride + row * src_row_stride`
/// must be readable and `dst + row * inner_len + inner` writable.
#[cfg(feature = "parallel")]
#[inline(always)]
pub(crate) unsafe fn mul_transposed_scalar_lhs_2d_typed<T>(
    dst: *mut T,
    scalar: T,
    src: *const T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    // SAFETY: both kernels have exactly this function's contract, which the
    // caller upholds; the arguments are forwarded untouched.
    unsafe {
        if row_len <= inner_len {
            mul_transposed_scalar_lhs_dst_contiguous(
                dst,
                scalar,
                src,
                inner_len,
                row_len,
                src_fast_stride,
                src_row_stride,
            );
        } else {
            mul_transposed_scalar_lhs_source_contiguous(
                dst,
                scalar,
                src,
                inner_len,
                row_len,
                src_fast_stride,
                src_row_stride,
            );
        }
    }
}
415
/// Type-erased entry point for `dst = src * scalar` on a strided 2-D source.
///
/// Runs [`mul_transposed_scalar_rhs_2d_typed`] and returns `true` when `D`,
/// `A` and `B` are all the same type and that type is one of `f64`, `f32`,
/// `Complex64` or `Complex32`. Otherwise nothing is read or written and
/// `false` is returned.
///
/// # Safety
///
/// When a kernel runs, `scalar` must point to a readable value and
/// `dst` / `src` must satisfy the contract of
/// [`mul_transposed_scalar_rhs_2d_typed`] for the given lengths and strides.
#[cfg(feature = "parallel")]
#[inline]
pub(crate) unsafe fn try_mul_transposed_scalar_rhs_2d<D: 'static, A: 'static, B: 'static>(
    dst: *mut D,
    src: *const A,
    scalar: *const B,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) -> bool {
    use std::any::TypeId;

    let operand_ids = [TypeId::of::<D>(), TypeId::of::<A>(), TypeId::of::<B>()];

    macro_rules! run_if_all_are {
        ($($elem:ty),+ $(,)?) => {$(
            if operand_ids == [TypeId::of::<$elem>(); 3] {
                // SAFETY: the `TypeId` comparison proves `D`, `A` and `B` are
                // all exactly `$elem`, so the pointer casts are identity
                // casts; pointer validity is the caller's obligation.
                unsafe {
                    mul_transposed_scalar_rhs_2d_typed::<$elem>(
                        dst.cast::<$elem>(),
                        src.cast::<$elem>(),
                        scalar.cast::<$elem>().read(),
                        inner_len,
                        row_len,
                        src_fast_stride,
                        src_row_stride,
                    );
                }
                return true;
            }
        )+};
    }

    run_if_all_are!(f64, f32, num_complex::Complex64, num_complex::Complex32);

    false
}
458
/// Type-erased entry point for `dst = scalar * src` on a strided 2-D source.
///
/// Runs [`mul_transposed_scalar_lhs_2d_typed`] and returns `true` when `D`,
/// `A` and `B` are all the same type and that type is one of `f64`, `f32`,
/// `Complex64` or `Complex32`. Otherwise nothing is read or written and
/// `false` is returned.
///
/// # Safety
///
/// When a kernel runs, `scalar` must point to a readable value and
/// `dst` / `src` must satisfy the contract of
/// [`mul_transposed_scalar_lhs_2d_typed`] for the given lengths and strides.
#[cfg(feature = "parallel")]
#[inline]
pub(crate) unsafe fn try_mul_transposed_scalar_lhs_2d<D: 'static, A: 'static, B: 'static>(
    dst: *mut D,
    scalar: *const A,
    src: *const B,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) -> bool {
    use std::any::TypeId;

    let operand_ids = [TypeId::of::<D>(), TypeId::of::<A>(), TypeId::of::<B>()];

    macro_rules! run_if_all_are {
        ($($elem:ty),+ $(,)?) => {$(
            if operand_ids == [TypeId::of::<$elem>(); 3] {
                // SAFETY: the `TypeId` comparison proves `D`, `A` and `B` are
                // all exactly `$elem`, so the pointer casts are identity
                // casts; pointer validity is the caller's obligation.
                unsafe {
                    mul_transposed_scalar_lhs_2d_typed::<$elem>(
                        dst.cast::<$elem>(),
                        scalar.cast::<$elem>().read(),
                        src.cast::<$elem>(),
                        inner_len,
                        row_len,
                        src_fast_stride,
                        src_row_stride,
                    );
                }
                return true;
            }
        )+};
    }

    run_if_all_are!(f64, f32, num_complex::Complex64, num_complex::Complex32);

    false
}
501
/// Optional SIMD-accelerated reductions for an element type.
///
/// Both methods default to `None`, meaning "no accelerated implementation is
/// available for this type". Implementors override them to return
/// `Some(result)` when they provide one.
pub trait MaybeSimdOps
where
    Self: Copy + Sized,
{
    /// Sum of all elements of the slice, or `None` if there is no SIMD path.
    fn try_simd_sum(_values: &[Self]) -> Option<Self> {
        None
    }

    /// Dot product of the two slices, or `None` if there is no SIMD path.
    fn try_simd_dot(_lhs: &[Self], _rhs: &[Self]) -> Option<Self> {
        None
    }
}
514
515macro_rules! impl_no_simd {
517 ($($t:ty),*) => {
518 $(impl MaybeSimdOps for $t {})*
519 };
520}
521
522impl_no_simd!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
523
524impl<T: num_traits::Num + Copy + Clone + std::ops::Neg<Output = T>> MaybeSimdOps
525 for num_complex::Complex<T>
526{
527}
528
529#[cfg(not(feature = "simd"))]
531impl MaybeSimdOps for f32 {}
532
533#[cfg(not(feature = "simd"))]
534impl MaybeSimdOps for f64 {}
535
/// SIMD-backed `MaybeSimdOps` implementations for `f32` and `f64`.
///
/// Every reduction here follows the same shape:
/// 1. split the input into whole SIMD vectors plus a scalar remainder,
/// 2. accumulate the vectors into four independent accumulators (four vectors
///    per step, leftover vectors into the first accumulator),
/// 3. combine the accumulators as `(s0 + s1) + (s2 + s3)` and reduce across lanes,
/// 4. fold the scalar remainder in, left to right.
///
/// Because the additions are reordered, results can differ in the last bits
/// from a plain sequential sum.
#[cfg(feature = "simd")]
mod simd_impls {
    use super::MaybeSimdOps;
    use pulp::{Simd, WithSimd};

    impl MaybeSimdOps for f32 {
        /// Sums `src` with SIMD; always returns `Some`.
        fn try_simd_sum(src: &[f32]) -> Option<f32> {
            struct SumKernel<'s>(&'s [f32]);

            impl WithSimd for SumKernel<'_> {
                type Output = f32;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    let (vectors, scalars) = S::as_simd_f32s(self.0);

                    let zero = simd.splat_f32s(0.0);
                    let (mut s0, mut s1, mut s2, mut s3) = (zero, zero, zero, zero);

                    let quads = vectors.chunks_exact(4);
                    let leftover = quads.remainder();
                    for quad in quads {
                        s0 = simd.add_f32s(s0, quad[0]);
                        s1 = simd.add_f32s(s1, quad[1]);
                        s2 = simd.add_f32s(s2, quad[2]);
                        s3 = simd.add_f32s(s3, quad[3]);
                    }
                    // Fewer than four vectors remain; they all go to `s0`.
                    for &vector in leftover {
                        s0 = simd.add_f32s(s0, vector);
                    }

                    let combined = simd.add_f32s(simd.add_f32s(s0, s1), simd.add_f32s(s2, s3));
                    scalars
                        .iter()
                        .fold(simd.reduce_sum_f32s(combined), |total, &x| total + x)
                }
            }

            Some(pulp::Arch::new().dispatch(SumKernel(src)))
        }

        /// Dot product of `a` and `b` with SIMD; always returns `Some`.
        ///
        /// Equal lengths are only checked with debug assertions.
        fn try_simd_dot(a: &[f32], b: &[f32]) -> Option<f32> {
            struct DotKernel<'s> {
                lhs: &'s [f32],
                rhs: &'s [f32],
            }

            impl WithSimd for DotKernel<'_> {
                type Output = f32;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    let DotKernel { lhs, rhs } = self;
                    debug_assert_eq!(lhs.len(), rhs.len());

                    let (lhs_vecs, lhs_rest) = S::as_simd_f32s(lhs);
                    let (rhs_vecs, rhs_rest) = S::as_simd_f32s(rhs);
                    debug_assert_eq!(lhs_vecs.len(), rhs_vecs.len());
                    debug_assert_eq!(lhs_rest.len(), rhs_rest.len());

                    let zero = simd.splat_f32s(0.0);
                    let (mut s0, mut s1, mut s2, mut s3) = (zero, zero, zero, zero);

                    // Largest multiple of four not exceeding the vector count.
                    let unrolled = lhs_vecs.len() / 4 * 4;
                    for base in (0..unrolled).step_by(4) {
                        s0 = simd.mul_add_f32s(lhs_vecs[base], rhs_vecs[base], s0);
                        s1 = simd.mul_add_f32s(lhs_vecs[base + 1], rhs_vecs[base + 1], s1);
                        s2 = simd.mul_add_f32s(lhs_vecs[base + 2], rhs_vecs[base + 2], s2);
                        s3 = simd.mul_add_f32s(lhs_vecs[base + 3], rhs_vecs[base + 3], s3);
                    }
                    // Fewer than four vectors remain; they all go to `s0`.
                    for lane in unrolled..lhs_vecs.len() {
                        s0 = simd.mul_add_f32s(lhs_vecs[lane], rhs_vecs[lane], s0);
                    }

                    let combined = simd.add_f32s(simd.add_f32s(s0, s1), simd.add_f32s(s2, s3));
                    lhs_rest
                        .iter()
                        .zip(rhs_rest.iter())
                        .fold(simd.reduce_sum_f32s(combined), |total, (&x, &y)| {
                            total + x * y
                        })
                }
            }

            Some(pulp::Arch::new().dispatch(DotKernel { lhs: a, rhs: b }))
        }
    }

    impl MaybeSimdOps for f64 {
        /// Sums `src` with SIMD; always returns `Some`.
        fn try_simd_sum(src: &[f64]) -> Option<f64> {
            struct SumKernel<'s>(&'s [f64]);

            impl WithSimd for SumKernel<'_> {
                type Output = f64;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    let (vectors, scalars) = S::as_simd_f64s(self.0);

                    let zero = simd.splat_f64s(0.0);
                    let (mut s0, mut s1, mut s2, mut s3) = (zero, zero, zero, zero);

                    let quads = vectors.chunks_exact(4);
                    let leftover = quads.remainder();
                    for quad in quads {
                        s0 = simd.add_f64s(s0, quad[0]);
                        s1 = simd.add_f64s(s1, quad[1]);
                        s2 = simd.add_f64s(s2, quad[2]);
                        s3 = simd.add_f64s(s3, quad[3]);
                    }
                    // Fewer than four vectors remain; they all go to `s0`.
                    for &vector in leftover {
                        s0 = simd.add_f64s(s0, vector);
                    }

                    let combined = simd.add_f64s(simd.add_f64s(s0, s1), simd.add_f64s(s2, s3));
                    scalars
                        .iter()
                        .fold(simd.reduce_sum_f64s(combined), |total, &x| total + x)
                }
            }

            Some(pulp::Arch::new().dispatch(SumKernel(src)))
        }

        /// Dot product of `a` and `b` with SIMD; always returns `Some`.
        ///
        /// Equal lengths are only checked with debug assertions.
        fn try_simd_dot(a: &[f64], b: &[f64]) -> Option<f64> {
            struct DotKernel<'s> {
                lhs: &'s [f64],
                rhs: &'s [f64],
            }

            impl WithSimd for DotKernel<'_> {
                type Output = f64;

                #[inline(always)]
                fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                    let DotKernel { lhs, rhs } = self;
                    debug_assert_eq!(lhs.len(), rhs.len());

                    let (lhs_vecs, lhs_rest) = S::as_simd_f64s(lhs);
                    let (rhs_vecs, rhs_rest) = S::as_simd_f64s(rhs);
                    debug_assert_eq!(lhs_vecs.len(), rhs_vecs.len());
                    debug_assert_eq!(lhs_rest.len(), rhs_rest.len());

                    let zero = simd.splat_f64s(0.0);
                    let (mut s0, mut s1, mut s2, mut s3) = (zero, zero, zero, zero);

                    // Largest multiple of four not exceeding the vector count.
                    let unrolled = lhs_vecs.len() / 4 * 4;
                    for base in (0..unrolled).step_by(4) {
                        s0 = simd.mul_add_f64s(lhs_vecs[base], rhs_vecs[base], s0);
                        s1 = simd.mul_add_f64s(lhs_vecs[base + 1], rhs_vecs[base + 1], s1);
                        s2 = simd.mul_add_f64s(lhs_vecs[base + 2], rhs_vecs[base + 2], s2);
                        s3 = simd.mul_add_f64s(lhs_vecs[base + 3], rhs_vecs[base + 3], s3);
                    }
                    // Fewer than four vectors remain; they all go to `s0`.
                    for lane in unrolled..lhs_vecs.len() {
                        s0 = simd.mul_add_f64s(lhs_vecs[lane], rhs_vecs[lane], s0);
                    }

                    let combined = simd.add_f64s(simd.add_f64s(s0, s1), simd.add_f64s(s2, s3));
                    lhs_rest
                        .iter()
                        .zip(rhs_rest.iter())
                        .fold(simd.reduce_sum_f64s(combined), |total, (&x, &y)| {
                            total + x * y
                        })
                }
            }

            Some(pulp::Arch::new().dispatch(DotKernel { lhs: a, rhs: b }))
        }
    }
}
711
712#[cfg(test)]
713mod tests {
714 #[cfg(feature = "simd")]
715 #[test]
716 fn test_try_mul_contiguous_complex64() {
717 let a = vec![
718 num_complex::Complex64::new(1.0, 2.0),
719 num_complex::Complex64::new(-3.0, 4.0),
720 num_complex::Complex64::new(0.5, -0.25),
721 ];
722 let b = vec![
723 num_complex::Complex64::new(5.0, -1.0),
724 num_complex::Complex64::new(2.0, 0.25),
725 num_complex::Complex64::new(-4.0, 3.0),
726 ];
727 let mut dst = vec![num_complex::Complex64::new(0.0, 0.0); a.len()];
728
729 assert!(super::try_mul_contiguous(&mut dst, &a, &b));
730 for i in 0..a.len() {
731 assert_eq!(dst[i], a[i] * b[i]);
732 }
733 }
734
735 #[cfg(feature = "simd")]
736 #[test]
737 fn test_try_mul_contiguous_complex32() {
738 let a = vec![
739 num_complex::Complex32::new(1.0, 2.0),
740 num_complex::Complex32::new(-3.0, 4.0),
741 num_complex::Complex32::new(0.5, -0.25),
742 ];
743 let b = vec![
744 num_complex::Complex32::new(5.0, -1.0),
745 num_complex::Complex32::new(2.0, 0.25),
746 num_complex::Complex32::new(-4.0, 3.0),
747 ];
748 let mut dst = vec![num_complex::Complex32::new(0.0, 0.0); a.len()];
749
750 assert!(super::try_mul_contiguous(&mut dst, &a, &b));
751 for i in 0..a.len() {
752 assert_eq!(dst[i], a[i] * b[i]);
753 }
754 }
755
756 #[cfg(feature = "parallel")]
757 #[test]
758 fn test_transposed_scalar_rhs_2d_f64_source_contiguous() {
759 let inner_len = 5usize;
760 let row_len = 7usize;
761 let src: Vec<f64> = (0..inner_len * row_len).map(|i| i as f64 + 0.25).collect();
762 let scalar = 2.0f64;
763 let mut dst = vec![0.0f64; inner_len * row_len];
764
765 let used = unsafe {
766 super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f64>(
767 dst.as_mut_ptr(),
768 src.as_ptr(),
769 &scalar,
770 inner_len,
771 row_len,
772 row_len as isize,
773 1,
774 )
775 };
776
777 assert!(used);
778 for row in 0..row_len {
779 for inner in 0..inner_len {
780 assert_eq!(
781 dst[row * inner_len + inner],
782 src[inner * row_len + row] * scalar
783 );
784 }
785 }
786 }
787
788 #[cfg(feature = "parallel")]
789 #[test]
790 fn test_transposed_scalar_lhs_2d_f32_source_contiguous() {
791 let inner_len = 5usize;
792 let row_len = 7usize;
793 let src: Vec<f32> = (0..inner_len * row_len)
794 .map(|i| i as f32 * 0.5 + 1.0)
795 .collect();
796 let scalar = 3.0f32;
797 let mut dst = vec![0.0f32; inner_len * row_len];
798
799 let used = unsafe {
800 super::try_mul_transposed_scalar_lhs_2d::<f32, f32, f32>(
801 dst.as_mut_ptr(),
802 &scalar,
803 src.as_ptr(),
804 inner_len,
805 row_len,
806 row_len as isize,
807 1,
808 )
809 };
810
811 assert!(used);
812 for row in 0..row_len {
813 for inner in 0..inner_len {
814 assert_eq!(
815 dst[row * inner_len + inner],
816 scalar * src[inner * row_len + row]
817 );
818 }
819 }
820 }
821
822 #[cfg(feature = "parallel")]
823 #[test]
824 fn test_transposed_scalar_rhs_2d_handles_short_rows() {
825 let inner_len = 8usize;
826 let row_len = 3usize;
827 let src: Vec<f64> = (0..inner_len * row_len).map(|i| i as f64 + 1.0).collect();
828 let scalar = 2.0f64;
829 let mut dst = vec![0.0f64; inner_len * row_len];
830
831 let used = unsafe {
832 super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f64>(
833 dst.as_mut_ptr(),
834 src.as_ptr(),
835 &scalar,
836 inner_len,
837 row_len,
838 row_len as isize,
839 1,
840 )
841 };
842
843 assert!(used);
844 for row in 0..row_len {
845 for inner in 0..inner_len {
846 assert_eq!(
847 dst[row * inner_len + inner],
848 src[inner * row_len + row] * scalar
849 );
850 }
851 }
852 }
853
854 #[cfg(feature = "parallel")]
855 #[test]
856 fn test_transposed_scalar_2d_handles_small_square_tiles() {
857 let inner_len = 4usize;
858 let row_len = 4usize;
859 let src: Vec<f64> = (0..inner_len * row_len).map(|i| i as f64 + 1.0).collect();
860 let scalar = 2.0f64;
861 let mut rhs_dst = vec![0.0f64; inner_len * row_len];
862 let mut lhs_dst = vec![0.0f64; inner_len * row_len];
863
864 let rhs_used = unsafe {
865 super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f64>(
866 rhs_dst.as_mut_ptr(),
867 src.as_ptr(),
868 &scalar,
869 inner_len,
870 row_len,
871 row_len as isize,
872 1,
873 )
874 };
875 let lhs_used = unsafe {
876 super::try_mul_transposed_scalar_lhs_2d::<f64, f64, f64>(
877 lhs_dst.as_mut_ptr(),
878 &scalar,
879 src.as_ptr(),
880 inner_len,
881 row_len,
882 row_len as isize,
883 1,
884 )
885 };
886
887 assert!(rhs_used);
888 assert!(lhs_used);
889 for row in 0..row_len {
890 for inner in 0..inner_len {
891 assert_eq!(
892 rhs_dst[row * inner_len + inner],
893 src[inner * row_len + row] * scalar
894 );
895 assert_eq!(
896 lhs_dst[row * inner_len + inner],
897 scalar * src[inner * row_len + row]
898 );
899 }
900 }
901 }
902
903 #[cfg(feature = "parallel")]
904 #[test]
905 fn test_transposed_scalar_rhs_2d_complex64_source_contiguous() {
906 let inner_len = 5usize;
907 let row_len = 7usize;
908 let src: Vec<num_complex::Complex64> = (0..inner_len * row_len)
909 .map(|i| num_complex::Complex64::new(i as f64 + 0.25, i as f64 * -0.5))
910 .collect();
911 let scalar = num_complex::Complex64::new(2.0, -0.25);
912 let mut dst = vec![num_complex::Complex64::new(0.0, 0.0); inner_len * row_len];
913
914 let used = unsafe {
915 super::try_mul_transposed_scalar_rhs_2d::<
916 num_complex::Complex64,
917 num_complex::Complex64,
918 num_complex::Complex64,
919 >(
920 dst.as_mut_ptr(),
921 src.as_ptr(),
922 &scalar,
923 inner_len,
924 row_len,
925 row_len as isize,
926 1,
927 )
928 };
929
930 assert!(used);
931 for row in 0..row_len {
932 for inner in 0..inner_len {
933 assert_eq!(
934 dst[row * inner_len + inner],
935 src[inner * row_len + row] * scalar
936 );
937 }
938 }
939 }
940
941 #[cfg(feature = "parallel")]
942 #[test]
943 fn test_transposed_scalar_lhs_2d_complex32_source_contiguous() {
944 let inner_len = 5usize;
945 let row_len = 7usize;
946 let src: Vec<num_complex::Complex32> = (0..inner_len * row_len)
947 .map(|i| num_complex::Complex32::new(i as f32 * 0.5 + 1.0, i as f32 * 0.25))
948 .collect();
949 let scalar = num_complex::Complex32::new(3.0, -0.5);
950 let mut dst = vec![num_complex::Complex32::new(0.0, 0.0); inner_len * row_len];
951
952 let used = unsafe {
953 super::try_mul_transposed_scalar_lhs_2d::<
954 num_complex::Complex32,
955 num_complex::Complex32,
956 num_complex::Complex32,
957 >(
958 dst.as_mut_ptr(),
959 &scalar,
960 src.as_ptr(),
961 inner_len,
962 row_len,
963 row_len as isize,
964 1,
965 )
966 };
967
968 assert!(used);
969 for row in 0..row_len {
970 for inner in 0..inner_len {
971 assert_eq!(
972 dst[row * inner_len + inner],
973 scalar * src[inner * row_len + row]
974 );
975 }
976 }
977 }
978}