Skip to main content

strided_kernel/
simd.rs

/// Runs `f` and returns its result.
///
/// With the `simd` feature enabled, the call is routed through `pulp::Arch::dispatch` so the
/// closure body is compiled for, and executed with, the best instruction set detected at
/// runtime. Without the feature, `f` is simply invoked.
#[inline(always)]
pub(crate) fn dispatch<R>(f: impl FnOnce() -> R) -> R {
    #[cfg(not(feature = "simd"))]
    {
        f()
    }
    #[cfg(feature = "simd")]
    {
        pulp::Arch::new().dispatch(f)
    }
}
12
13#[inline(always)]
14pub(crate) fn dispatch_if_large<R>(len: usize, f: impl FnOnce() -> R) -> R {
15    // Avoid runtime-dispatch overhead for tiny loops (especially common for small-array cases).
16    // This is a heuristic; correctness does not depend on it.
17    if len >= 64 {
18        dispatch(f)
19    } else {
20        f()
21    }
22}
23
/// Reinterprets `&[T]` as `&[U]` without copying, preserving the element count.
///
/// # Safety
///
/// `T` and `U` must have the same size, the data pointer of `src` must be suitably aligned for
/// `U`, and the bytes of every `T` in `src` must form a valid `U`. In this module the function
/// is only called after a `TypeId` check has established that `T` and `U` are the same type,
/// which satisfies all three requirements.
#[cfg(feature = "simd")]
#[inline(always)]
unsafe fn cast_slice<T, U>(src: &[T]) -> &[U] {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    debug_assert_eq!(src.as_ptr() as usize % std::mem::align_of::<U>(), 0);
    // SAFETY: layout compatibility is guaranteed by the caller (see `# Safety`); equal element
    // sizes mean `src.len()` `U`s cover exactly the same bytes, and the returned borrow
    // inherits the lifetime of `src`.
    unsafe { std::slice::from_raw_parts(src.as_ptr().cast::<U>(), src.len()) }
}
30
/// Reinterprets `&mut [T]` as `&mut [U]` in place, preserving the element count.
///
/// # Safety
///
/// `T` and `U` must have the same size, the data pointer of `src` must be suitably aligned for
/// `U`, and `T` and `U` must accept each other's bit patterns (values written as `U` may later
/// be read back as `T`). In this module the function is only called after a `TypeId` check has
/// established that `T` and `U` are the same type, which satisfies all of these.
#[cfg(feature = "simd")]
#[inline(always)]
unsafe fn cast_slice_mut<T, U>(src: &mut [T]) -> &mut [U] {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    debug_assert_eq!(src.as_ptr() as usize % std::mem::align_of::<U>(), 0);
    // SAFETY: layout compatibility is guaranteed by the caller (see `# Safety`); the exclusive
    // borrow of `src` is handed over to the returned slice, so no aliasing `&mut` exists.
    unsafe { std::slice::from_raw_parts_mut(src.as_mut_ptr().cast::<U>(), src.len()) }
}
37
/// Generates `fn $mul_into(dst, a, b)` computing `dst[i] = a[i] * b[i]` for a real scalar type.
///
/// Every chunk, whether it fills a register or is the trailing remainder, goes through
/// pulp's `partial_load`/`partial_store` primitives. `$lanes` names the `pulp::Simd`
/// associated constant giving the register width in elements.
#[cfg(feature = "simd")]
macro_rules! impl_simd_mul_partial {
    (
        $mul_into:ident,
        $ty:ty,
        $lanes:ident,
        $load:ident,
        $store:ident,
        $mul:ident
    ) => {
        fn $mul_into(dst: &mut [$ty], a: &[$ty], b: &[$ty]) {
            struct Mul<'a> {
                dst: &'a mut [$ty],
                a: &'a [$ty],
                b: &'a [$ty],
            }

            impl<'a> pulp::WithSimd for Mul<'a> {
                type Output = ();

                #[inline(always)]
                fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
                    let Mul { dst, a, b } = self;
                    debug_assert_eq!(dst.len(), a.len());
                    debug_assert_eq!(dst.len(), b.len());

                    let width = S::$lanes;
                    let len = dst.len();
                    // Largest multiple of the register width that does not exceed `len`.
                    let body_end = len - len % width;

                    for start in (0..body_end).step_by(width) {
                        let end = start + width;
                        let lhs = simd.$load(&a[start..end]);
                        let rhs = simd.$load(&b[start..end]);
                        simd.$store(&mut dst[start..end], simd.$mul(lhs, rhs));
                    }

                    // Remaining elements that do not fill a whole register.
                    if body_end != len {
                        let lhs = simd.$load(&a[body_end..]);
                        let rhs = simd.$load(&b[body_end..]);
                        simd.$store(&mut dst[body_end..], simd.$mul(lhs, rhs));
                    }
                }
            }

            pulp::Arch::new().dispatch(Mul { dst, a, b });
        }
    };
}
83
/// Generates `fn $mul_into(dst, a, b)` computing `dst[i] = a[i] * b[i]` element-wise.
///
/// Each slice is split by `$as_simd`/`$as_mut_simd` into whole SIMD registers plus a short
/// scalar remainder. The registers are multiplied directly with `$mul`, and the remainder is
/// handled with a single partial load/store.
#[cfg(feature = "simd")]
macro_rules! impl_simd_mul_body_tail {
    (
        $mul_into:ident,
        $ty:ty,
        $as_simd:ident,
        $as_mut_simd:ident,
        $load:ident,
        $store:ident,
        $mul:ident
    ) => {
        fn $mul_into(dst: &mut [$ty], a: &[$ty], b: &[$ty]) {
            struct Mul<'a> {
                dst: &'a mut [$ty],
                a: &'a [$ty],
                b: &'a [$ty],
            }

            impl<'a> pulp::WithSimd for Mul<'a> {
                type Output = ();

                #[inline(always)]
                fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
                    let Mul { dst, a, b } = self;
                    debug_assert_eq!(dst.len(), a.len());
                    debug_assert_eq!(dst.len(), b.len());

                    let (out_regs, out_rest) = S::$as_mut_simd(dst);
                    let (lhs_regs, lhs_rest) = S::$as_simd(a);
                    let (rhs_regs, rhs_rest) = S::$as_simd(b);
                    debug_assert_eq!(out_regs.len(), lhs_regs.len());
                    debug_assert_eq!(out_regs.len(), rhs_regs.len());
                    debug_assert_eq!(out_rest.len(), lhs_rest.len());
                    debug_assert_eq!(out_rest.len(), rhs_rest.len());

                    // Index the inputs by position so a shorter input panics rather than being
                    // silently truncated.
                    for (k, out) in out_regs.iter_mut().enumerate() {
                        *out = simd.$mul(lhs_regs[k], rhs_regs[k]);
                    }

                    if out_rest.is_empty() {
                        return;
                    }
                    let lhs = simd.$load(lhs_rest);
                    let rhs = simd.$load(rhs_rest);
                    simd.$store(out_rest, simd.$mul(lhs, rhs));
                }
            }

            pulp::Arch::new().dispatch(Mul { dst, a, b });
        }
    };
}
133
// Real kernels: every chunk (full or trailing) uses partial load/store.
#[cfg(feature = "simd")]
impl_simd_mul_partial!(
    simd_mul_f64_into,
    f64,
    F64_LANES,
    partial_load_f64s,
    partial_store_f64s,
    mul_f64s
);

#[cfg(feature = "simd")]
impl_simd_mul_partial!(
    simd_mul_f32_into,
    f32,
    F32_LANES,
    partial_load_f32s,
    partial_store_f32s,
    mul_f32s
);

// Complex kernels: whole registers via `as_simd_*`, then one partial load/store for the tail.
#[cfg(feature = "simd")]
impl_simd_mul_body_tail!(
    simd_mul_c64_into,
    num_complex::Complex64,
    as_simd_c64s,
    as_mut_simd_c64s,
    partial_load_c64s,
    partial_store_c64s,
    mul_e_c64s
);

#[cfg(feature = "simd")]
impl_simd_mul_body_tail!(
    simd_mul_c32_into,
    num_complex::Complex32,
    as_simd_c32s,
    as_mut_simd_c32s,
    partial_load_c32s,
    partial_store_c32s,
    mul_e_c32s
);
175
/// Attempts the element-wise product `dst[i] = a[i] * b[i]` with a SIMD kernel.
///
/// Returns `true` after writing `dst` when `D`, `A` and `B` are all the same one of `f64`,
/// `f32`, `Complex64` or `Complex32`. Otherwise `dst` is left untouched and `false` is
/// returned. The slices are expected to have equal lengths; the kernels check this only with
/// debug assertions.
#[cfg(feature = "simd")]
#[inline]
pub(crate) fn try_mul_contiguous<D: 'static, A: 'static, B: 'static>(
    dst: &mut [D],
    a: &[A],
    b: &[B],
) -> bool {
    use std::any::TypeId;

    let all_are = |target: TypeId| {
        TypeId::of::<D>() == target && TypeId::of::<A>() == target && TypeId::of::<B>() == target
    };

    macro_rules! route {
        ($ty:ty => $kernel:ident) => {
            if all_are(TypeId::of::<$ty>()) {
                // SAFETY: `all_are` proved that D, A and B are exactly `$ty`, so every cast is
                // an identity reinterpretation (same size, alignment and validity).
                unsafe { $kernel(cast_slice_mut(dst), cast_slice(a), cast_slice(b)) };
                return true;
            }
        };
    }

    route!(f64 => simd_mul_f64_into);
    route!(f32 => simd_mul_f32_into);
    route!(num_complex::Complex64 => simd_mul_c64_into);
    route!(num_complex::Complex32 => simd_mul_c32_into);

    false
}
204
/// Fallback used when the `simd` feature is disabled.
///
/// No SIMD kernel is available, so this always reports `false` and never touches `_dst`.
#[cfg(not(feature = "simd"))]
#[inline]
pub(crate) fn try_mul_contiguous<D: 'static, A: 'static, B: 'static>(
    _dst: &mut [D],
    _a: &[A],
    _b: &[B],
) -> bool {
    // Nothing to try without SIMD support.
    false
}
214
/// Edge length, in elements, of the square tiles walked by the tiled transposition kernels.
#[cfg(feature = "parallel")]
const TRANSPOSE_TILE: usize = 8;
217
/// Tiled kernel: for every `row < row_len` and `inner < inner_len`, writes
/// `dst[row * inner_len + inner] = src[inner * src_fast_stride + row * src_row_stride] * scalar`.
///
/// The index space is visited in `TRANSPOSE_TILE` x `TRANSPOSE_TILE` tiles. Inside a tile the
/// row index varies fastest, so consecutive reads advance by `src_row_stride`.
///
/// # Safety
///
/// `dst` must be valid for writes at every element offset `row * inner_len + inner`, and `src`
/// must be valid for reads at every element offset
/// `inner * src_fast_stride + row * src_row_stride`, for all `row < row_len` and
/// `inner < inner_len`. NOTE(review): nothing guards against `dst` overlapping the source
/// elements; overlapping regions would be read after being overwritten.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_rhs_source_contiguous<T>(
    dst: *mut T,
    src: *const T,
    scalar: T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    let mut inner0 = 0usize;

    while inner0 < inner_len {
        let inner_count = TRANSPOSE_TILE.min(inner_len - inner0);
        let mut row0 = 0usize;

        while row0 < row_len {
            let row_count = TRANSPOSE_TILE.min(row_len - row0);

            for inner in 0..inner_count {
                let inner_index = inner0 + inner;
                // SAFETY: `inner_index < inner_len` and `row_len > 0` (we are inside the row
                // loop), so this points at source element (row 0, `inner_index`), which the
                // caller guarantees is readable.
                let src_base = unsafe { src.offset(inner_index as isize * src_fast_stride) };
                for row in 0..row_count {
                    let row_index = row0 + row;
                    let src_offset = row_index as isize * src_row_stride;
                    // SAFETY: `row_index < row_len` and `inner_index < inner_len`, so both the
                    // read and the write stay inside the caller-guaranteed regions.
                    unsafe {
                        *dst.add(row_index * inner_len + inner_index) =
                            *src_base.offset(src_offset) * scalar;
                    }
                }
            }

            row0 += TRANSPOSE_TILE;
        }

        inner0 += TRANSPOSE_TILE;
    }
}
255
/// Row-by-row kernel: for every `row < row_len` and `inner < inner_len`, writes
/// `dst[row * inner_len + inner] = src[row * src_row_stride + inner * src_fast_stride] * scalar`.
/// Each output row is written sequentially.
///
/// # Safety
///
/// `dst` must be valid for writes at every element offset `row * inner_len + inner`, and `src`
/// must be valid for reads at every element offset
/// `row * src_row_stride + inner * src_fast_stride`, for all `row < row_len` and
/// `inner < inner_len`.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_rhs_dst_contiguous<T>(
    dst: *mut T,
    src: *const T,
    scalar: T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    // With no inner elements there is nothing to write. Returning early also avoids forming
    // row pointers (`src.offset(row * src_row_stride)`) into a region that may be empty.
    if inner_len == 0 {
        return;
    }

    for row in 0..row_len {
        // SAFETY: `inner_len > 0`, so element (`row`, 0) exists in both regions and both row
        // base pointers are in bounds.
        let (dst_base, src_base) =
            unsafe { (dst.add(row * inner_len), src.offset(row as isize * src_row_stride)) };
        for inner in 0..inner_len {
            // SAFETY: `row < row_len` and `inner < inner_len`, so both accesses fall inside the
            // caller-guaranteed regions.
            unsafe {
                *dst_base.add(inner) = *src_base.offset(inner as isize * src_fast_stride) * scalar;
            }
        }
    }
}
276
/// Tiled kernel: for every `row < row_len` and `inner < inner_len`, writes
/// `dst[row * inner_len + inner] = scalar * src[inner * src_fast_stride + row * src_row_stride]`.
///
/// The scalar is the left operand, so the multiplication order is preserved for `T` whose
/// `Mul` is not commutative. The traversal matches the right-scalar tiled kernel:
/// `TRANSPOSE_TILE` x `TRANSPOSE_TILE` tiles, with the row index varying fastest.
///
/// # Safety
///
/// `dst` must be valid for writes at every element offset `row * inner_len + inner`, and `src`
/// must be valid for reads at every element offset
/// `inner * src_fast_stride + row * src_row_stride`, for all `row < row_len` and
/// `inner < inner_len`. NOTE(review): nothing guards against `dst` overlapping the source
/// elements; overlapping regions would be read after being overwritten.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_lhs_source_contiguous<T>(
    dst: *mut T,
    scalar: T,
    src: *const T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    let mut inner0 = 0usize;

    while inner0 < inner_len {
        let inner_count = TRANSPOSE_TILE.min(inner_len - inner0);
        let mut row0 = 0usize;

        while row0 < row_len {
            let row_count = TRANSPOSE_TILE.min(row_len - row0);

            for inner in 0..inner_count {
                let inner_index = inner0 + inner;
                // SAFETY: `inner_index < inner_len` and `row_len > 0` (we are inside the row
                // loop), so this points at source element (row 0, `inner_index`), which the
                // caller guarantees is readable.
                let src_base = unsafe { src.offset(inner_index as isize * src_fast_stride) };
                for row in 0..row_count {
                    let row_index = row0 + row;
                    let src_offset = row_index as isize * src_row_stride;
                    // SAFETY: `row_index < row_len` and `inner_index < inner_len`, so both the
                    // read and the write stay inside the caller-guaranteed regions.
                    unsafe {
                        *dst.add(row_index * inner_len + inner_index) =
                            scalar * *src_base.offset(src_offset);
                    }
                }
            }

            row0 += TRANSPOSE_TILE;
        }

        inner0 += TRANSPOSE_TILE;
    }
}
314
/// Row-by-row kernel: for every `row < row_len` and `inner < inner_len`, writes
/// `dst[row * inner_len + inner] = scalar * src[row * src_row_stride + inner * src_fast_stride]`.
/// Each output row is written sequentially, and the scalar stays the left operand.
///
/// # Safety
///
/// `dst` must be valid for writes at every element offset `row * inner_len + inner`, and `src`
/// must be valid for reads at every element offset
/// `row * src_row_stride + inner * src_fast_stride`, for all `row < row_len` and
/// `inner < inner_len`.
#[cfg(feature = "parallel")]
unsafe fn mul_transposed_scalar_lhs_dst_contiguous<T>(
    dst: *mut T,
    scalar: T,
    src: *const T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    // With no inner elements there is nothing to write. Returning early also avoids forming
    // row pointers (`src.offset(row * src_row_stride)`) into a region that may be empty.
    if inner_len == 0 {
        return;
    }

    for row in 0..row_len {
        // SAFETY: `inner_len > 0`, so element (`row`, 0) exists in both regions and both row
        // base pointers are in bounds.
        let (dst_base, src_base) =
            unsafe { (dst.add(row * inner_len), src.offset(row as isize * src_row_stride)) };
        for inner in 0..inner_len {
            // SAFETY: `row < row_len` and `inner < inner_len`, so both accesses fall inside the
            // caller-guaranteed regions.
            unsafe {
                *dst_base.add(inner) = scalar * *src_base.offset(inner as isize * src_fast_stride);
            }
        }
    }
}
335
/// Writes `dst[row * inner_len + inner] = src[inner * src_fast_stride + row * src_row_stride] * scalar`
/// for every `row < row_len` and `inner < inner_len`.
///
/// When there are more rows than inner elements (`row_len > inner_len`) the tiled kernel is
/// used. Otherwise `dst` is filled row by row.
///
/// # Safety
///
/// `dst` must be valid for writes at every element offset `row * inner_len + inner`, and `src`
/// must be valid for reads at every element offset
/// `inner * src_fast_stride + row * src_row_stride`, for all `row < row_len` and
/// `inner < inner_len`.
#[cfg(feature = "parallel")]
#[inline(always)]
pub(crate) unsafe fn mul_transposed_scalar_rhs_2d_typed<T>(
    dst: *mut T,
    src: *const T,
    scalar: T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    if row_len <= inner_len {
        // SAFETY: arguments are forwarded unchanged, and the kernel's contract matches this
        // function's, which the caller upholds.
        unsafe {
            mul_transposed_scalar_rhs_dst_contiguous(
                dst,
                src,
                scalar,
                inner_len,
                row_len,
                src_fast_stride,
                src_row_stride,
            );
        }
        return;
    }

    // SAFETY: arguments are forwarded unchanged, and the tiled kernel's contract matches this
    // function's, which the caller upholds.
    unsafe {
        mul_transposed_scalar_rhs_source_contiguous(
            dst,
            src,
            scalar,
            inner_len,
            row_len,
            src_fast_stride,
            src_row_stride,
        );
    }
}
375
/// Writes `dst[row * inner_len + inner] = scalar * src[inner * src_fast_stride + row * src_row_stride]`
/// for every `row < row_len` and `inner < inner_len`, keeping `scalar` as the left operand.
///
/// When there are more rows than inner elements (`row_len > inner_len`) the tiled kernel is
/// used. Otherwise `dst` is filled row by row.
///
/// # Safety
///
/// `dst` must be valid for writes at every element offset `row * inner_len + inner`, and `src`
/// must be valid for reads at every element offset
/// `inner * src_fast_stride + row * src_row_stride`, for all `row < row_len` and
/// `inner < inner_len`.
#[cfg(feature = "parallel")]
#[inline(always)]
pub(crate) unsafe fn mul_transposed_scalar_lhs_2d_typed<T>(
    dst: *mut T,
    scalar: T,
    src: *const T,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) where
    T: Copy + std::ops::Mul<Output = T>,
{
    if row_len <= inner_len {
        // SAFETY: arguments are forwarded unchanged, and the kernel's contract matches this
        // function's, which the caller upholds.
        unsafe {
            mul_transposed_scalar_lhs_dst_contiguous(
                dst,
                scalar,
                src,
                inner_len,
                row_len,
                src_fast_stride,
                src_row_stride,
            );
        }
        return;
    }

    // SAFETY: arguments are forwarded unchanged, and the tiled kernel's contract matches this
    // function's, which the caller upholds.
    unsafe {
        mul_transposed_scalar_lhs_source_contiguous(
            dst,
            scalar,
            src,
            inner_len,
            row_len,
            src_fast_stride,
            src_row_stride,
        );
    }
}
415
/// Type-erased entry point for `mul_transposed_scalar_rhs_2d_typed` (scalar on the right).
///
/// Returns `true` after writing `dst` when `D`, `A` and `B` are the same one of `f64`, `f32`,
/// `Complex64` or `Complex32`. For any other combination it returns `false` without reading or
/// writing through any pointer.
///
/// # Safety
///
/// When a supported type matches, `scalar` must be valid for a read of one element, and `dst`
/// and `src` must satisfy the contract of `mul_transposed_scalar_rhs_2d_typed`.
#[cfg(feature = "parallel")]
#[inline]
pub(crate) unsafe fn try_mul_transposed_scalar_rhs_2d<D: 'static, A: 'static, B: 'static>(
    dst: *mut D,
    src: *const A,
    scalar: *const B,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) -> bool {
    use std::any::TypeId;

    macro_rules! attempt {
        ($($ty:ty),+) => {
            $(
                let id = TypeId::of::<$ty>();
                if TypeId::of::<D>() == id && TypeId::of::<A>() == id && TypeId::of::<B>() == id {
                    // SAFETY: D, A and B are all exactly `$ty`, so the pointer casts are identity
                    // casts; validity of the pointers is the caller's obligation.
                    unsafe {
                        mul_transposed_scalar_rhs_2d_typed(
                            dst.cast::<$ty>(),
                            src.cast::<$ty>(),
                            *scalar.cast::<$ty>(),
                            inner_len,
                            row_len,
                            src_fast_stride,
                            src_row_stride,
                        );
                    }
                    return true;
                }
            )+
        };
    }

    attempt!(f64, f32, num_complex::Complex64, num_complex::Complex32);

    false
}
458
/// Type-erased entry point for `mul_transposed_scalar_lhs_2d_typed` (scalar on the left).
///
/// Returns `true` after writing `dst` when `D`, `A` and `B` are the same one of `f64`, `f32`,
/// `Complex64` or `Complex32`. For any other combination it returns `false` without reading or
/// writing through any pointer.
///
/// # Safety
///
/// When a supported type matches, `scalar` must be valid for a read of one element, and `dst`
/// and `src` must satisfy the contract of `mul_transposed_scalar_lhs_2d_typed`.
#[cfg(feature = "parallel")]
#[inline]
pub(crate) unsafe fn try_mul_transposed_scalar_lhs_2d<D: 'static, A: 'static, B: 'static>(
    dst: *mut D,
    scalar: *const A,
    src: *const B,
    inner_len: usize,
    row_len: usize,
    src_fast_stride: isize,
    src_row_stride: isize,
) -> bool {
    use std::any::TypeId;

    macro_rules! attempt {
        ($($ty:ty),+) => {
            $(
                let id = TypeId::of::<$ty>();
                if TypeId::of::<D>() == id && TypeId::of::<A>() == id && TypeId::of::<B>() == id {
                    // SAFETY: D, A and B are all exactly `$ty`, so the pointer casts are identity
                    // casts; validity of the pointers is the caller's obligation.
                    unsafe {
                        mul_transposed_scalar_lhs_2d_typed(
                            dst.cast::<$ty>(),
                            *scalar.cast::<$ty>(),
                            src.cast::<$ty>(),
                            inner_len,
                            row_len,
                            src_fast_stride,
                            src_row_stride,
                        );
                    }
                    return true;
                }
            )+
        };
    }

    attempt!(f64, f32, num_complex::Complex64, num_complex::Complex32);

    false
}
501
/// Optional SIMD fast paths for reductions over contiguous slices.
///
/// Both methods default to `None`, meaning no SIMD kernel is available for the type. With the
/// `simd` feature enabled, `f32` and `f64` override them with SIMD kernels.
pub trait MaybeSimdOps: Copy + Sized {
    /// Returns the sum of all elements of `src`, or `None` when no SIMD kernel exists.
    fn try_simd_sum(_src: &[Self]) -> Option<Self> {
        None
    }

    /// Returns the dot product `sum(a[i] * b[i])`, or `None` when no SIMD kernel exists.
    fn try_simd_dot(_a: &[Self], _b: &[Self]) -> Option<Self> {
        None
    }
}
514
515// Default (no-op) impls for integer types and Complex
516macro_rules! impl_no_simd {
517    ($($t:ty),*) => {
518        $(impl MaybeSimdOps for $t {})*
519    };
520}
521
522impl_no_simd!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize);
523
524impl<T: num_traits::Num + Copy + Clone + std::ops::Neg<Output = T>> MaybeSimdOps
525    for num_complex::Complex<T>
526{
527}
528
529// f32/f64: SIMD-accelerated when feature enabled, no-op otherwise
530#[cfg(not(feature = "simd"))]
531impl MaybeSimdOps for f32 {}
532
533#[cfg(not(feature = "simd"))]
534impl MaybeSimdOps for f64 {}
535
/// SIMD-accelerated `MaybeSimdOps` implementations for `f32` and `f64`, executed through
/// `pulp`'s runtime instruction-set dispatch.
#[cfg(feature = "simd")]
mod simd_impls {
    use super::MaybeSimdOps;
    use pulp::{Simd, WithSimd};

    /// Generates a `MaybeSimdOps` impl for one float type from the names of the matching
    /// `pulp::Simd` methods.
    ///
    /// Each kernel proceeds in four steps:
    /// 1. Split the input(s) into whole SIMD registers plus a scalar tail.
    /// 2. Feed groups of four registers into four independent vector accumulators; any
    ///    leftover registers go into the first accumulator.
    /// 3. Combine the accumulators pairwise and reduce them to a scalar.
    /// 4. Add the tail elements one at a time.
    macro_rules! impl_float_simd_ops {
        (
            $float:ty,
            as_simd = $as_simd:ident,
            splat = $splat:ident,
            add = $add:ident,
            mul_add = $mul_add:ident,
            reduce_sum = $reduce_sum:ident $(,)?
        ) => {
            impl MaybeSimdOps for $float {
                fn try_simd_sum(src: &[$float]) -> Option<$float> {
                    struct Sum<'a>(&'a [$float]);

                    impl<'a> WithSimd for Sum<'a> {
                        type Output = $float;

                        #[inline(always)]
                        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                            let (vectors, scalars) = S::$as_simd(self.0);

                            let mut acc0 = simd.$splat(0.0);
                            let mut acc1 = simd.$splat(0.0);
                            let mut acc2 = simd.$splat(0.0);
                            let mut acc3 = simd.$splat(0.0);

                            let mut groups = vectors.chunks_exact(4);
                            for group in &mut groups {
                                acc0 = simd.$add(acc0, group[0]);
                                acc1 = simd.$add(acc1, group[1]);
                                acc2 = simd.$add(acc2, group[2]);
                                acc3 = simd.$add(acc3, group[3]);
                            }
                            for &leftover in groups.remainder() {
                                acc0 = simd.$add(acc0, leftover);
                            }

                            let combined =
                                simd.$add(simd.$add(acc0, acc1), simd.$add(acc2, acc3));
                            scalars
                                .iter()
                                .fold(simd.$reduce_sum(combined), |total, &x| total + x)
                        }
                    }

                    Some(pulp::Arch::new().dispatch(Sum(src)))
                }

                fn try_simd_dot(a: &[$float], b: &[$float]) -> Option<$float> {
                    struct Dot<'a> {
                        a: &'a [$float],
                        b: &'a [$float],
                    }

                    impl<'a> WithSimd for Dot<'a> {
                        type Output = $float;

                        #[inline(always)]
                        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
                            debug_assert_eq!(self.a.len(), self.b.len());
                            let (a_vectors, a_scalars) = S::$as_simd(self.a);
                            let (b_vectors, b_scalars) = S::$as_simd(self.b);
                            debug_assert_eq!(a_vectors.len(), b_vectors.len());
                            debug_assert_eq!(a_scalars.len(), b_scalars.len());

                            let mut acc0 = simd.$splat(0.0);
                            let mut acc1 = simd.$splat(0.0);
                            let mut acc2 = simd.$splat(0.0);
                            let mut acc3 = simd.$splat(0.0);

                            // Both register slices are indexed by the same position (rather than
                            // zipped), so a shorter `b` panics instead of being truncated.
                            let grouped = a_vectors.len() - a_vectors.len() % 4;
                            for base in (0..grouped).step_by(4) {
                                acc0 = simd.$mul_add(a_vectors[base], b_vectors[base], acc0);
                                acc1 = simd.$mul_add(
                                    a_vectors[base + 1],
                                    b_vectors[base + 1],
                                    acc1,
                                );
                                acc2 = simd.$mul_add(
                                    a_vectors[base + 2],
                                    b_vectors[base + 2],
                                    acc2,
                                );
                                acc3 = simd.$mul_add(
                                    a_vectors[base + 3],
                                    b_vectors[base + 3],
                                    acc3,
                                );
                            }
                            for k in grouped..a_vectors.len() {
                                acc0 = simd.$mul_add(a_vectors[k], b_vectors[k], acc0);
                            }

                            let combined =
                                simd.$add(simd.$add(acc0, acc1), simd.$add(acc2, acc3));
                            a_scalars.iter().zip(b_scalars.iter()).fold(
                                simd.$reduce_sum(combined),
                                |total, (&x, &y)| total + x * y,
                            )
                        }
                    }

                    Some(pulp::Arch::new().dispatch(Dot { a, b }))
                }
            }
        };
    }

    impl_float_simd_ops!(
        f32,
        as_simd = as_simd_f32s,
        splat = splat_f32s,
        add = add_f32s,
        mul_add = mul_add_f32s,
        reduce_sum = reduce_sum_f32s,
    );

    impl_float_simd_ops!(
        f64,
        as_simd = as_simd_f64s,
        splat = splat_f64s,
        add = add_f64s,
        mul_add = mul_add_f64s,
        reduce_sum = reduce_sum_f64s,
    );
}
711
#[cfg(test)]
mod tests {
    /// Asserts that `dst[row * inner_len + inner] == expected(src[inner * row_len + row])` for
    /// every `row < row_len` and `inner < inner_len`. This is the layout produced when the
    /// transposition entry points are called with `src_fast_stride = row_len` and
    /// `src_row_stride = 1`.
    #[cfg(feature = "parallel")]
    fn assert_transposed<T>(
        dst: &[T],
        src: &[T],
        inner_len: usize,
        row_len: usize,
        expected: impl Fn(T) -> T,
    ) where
        T: Copy + PartialEq + std::fmt::Debug,
    {
        for row in 0..row_len {
            for inner in 0..inner_len {
                assert_eq!(
                    dst[row * inner_len + inner],
                    expected(src[inner * row_len + row]),
                    "mismatch at row {}, inner {}",
                    row,
                    inner
                );
            }
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_try_mul_contiguous_complex64() {
        let a = vec![
            num_complex::Complex64::new(1.0, 2.0),
            num_complex::Complex64::new(-3.0, 4.0),
            num_complex::Complex64::new(0.5, -0.25),
        ];
        let b = vec![
            num_complex::Complex64::new(5.0, -1.0),
            num_complex::Complex64::new(2.0, 0.25),
            num_complex::Complex64::new(-4.0, 3.0),
        ];
        let mut dst = vec![num_complex::Complex64::new(0.0, 0.0); a.len()];

        assert!(super::try_mul_contiguous(&mut dst, &a, &b));
        for i in 0..a.len() {
            assert_eq!(dst[i], a[i] * b[i]);
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_try_mul_contiguous_complex32() {
        let a = vec![
            num_complex::Complex32::new(1.0, 2.0),
            num_complex::Complex32::new(-3.0, 4.0),
            num_complex::Complex32::new(0.5, -0.25),
        ];
        let b = vec![
            num_complex::Complex32::new(5.0, -1.0),
            num_complex::Complex32::new(2.0, 0.25),
            num_complex::Complex32::new(-4.0, 3.0),
        ];
        let mut dst = vec![num_complex::Complex32::new(0.0, 0.0); a.len()];

        assert!(super::try_mul_contiguous(&mut dst, &a, &b));
        for i in 0..a.len() {
            assert_eq!(dst[i], a[i] * b[i]);
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_try_mul_contiguous_complex64_with_tail() {
        // 19 is odd, so with any multi-lane register width both the whole-register loop and
        // the partial tail run. Small dyadic values keep every product exact.
        let a: Vec<num_complex::Complex64> = (0..19)
            .map(|i| num_complex::Complex64::new(i as f64 - 9.0, 0.5 * i as f64))
            .collect();
        let b: Vec<num_complex::Complex64> = (0..19)
            .map(|i| num_complex::Complex64::new(1.5, (i % 3) as f64 - 1.0))
            .collect();
        let mut dst = vec![num_complex::Complex64::new(0.0, 0.0); a.len()];

        assert!(super::try_mul_contiguous(&mut dst, &a, &b));
        for i in 0..a.len() {
            assert_eq!(dst[i], a[i] * b[i]);
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_try_mul_contiguous_f64_with_tail() {
        // 37 is odd, so with any multi-lane register width both the full-register loop and the
        // partial tail run.
        let a: Vec<f64> = (0..37).map(|i| i as f64 * 0.5 - 3.0).collect();
        let b: Vec<f64> = (0..37).map(|i| (i % 7) as f64 + 0.25).collect();
        let mut dst = vec![0.0f64; a.len()];

        assert!(super::try_mul_contiguous(&mut dst, &a, &b));
        for i in 0..a.len() {
            assert_eq!(dst[i], a[i] * b[i]);
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_try_mul_contiguous_f32_with_tail() {
        let a: Vec<f32> = (0..37).map(|i| i as f32 * 0.5 - 3.0).collect();
        let b: Vec<f32> = (0..37).map(|i| (i % 7) as f32 + 0.25).collect();
        let mut dst = vec![0.0f32; a.len()];

        assert!(super::try_mul_contiguous(&mut dst, &a, &b));
        for i in 0..a.len() {
            assert_eq!(dst[i], a[i] * b[i]);
        }
    }

    #[test]
    fn test_try_mul_contiguous_rejects_unsupported_or_mixed_types() {
        // Integers have no SIMD kernel, and mixed float widths must never be reinterpreted.
        let mut int_dst = vec![0i32; 4];
        assert!(!super::try_mul_contiguous(&mut int_dst, &[1i32, 2, 3, 4], &[5i32, 6, 7, 8]));
        assert_eq!(int_dst, vec![0; 4]);

        let mut mixed_dst = vec![0.0f64; 4];
        assert!(!super::try_mul_contiguous(&mut mixed_dst, &[1.0f32; 4], &[2.0f64; 4]));
        assert_eq!(mixed_dst, vec![0.0; 4]);
    }

    #[test]
    fn test_maybe_simd_ops_f64_sum_and_dot() {
        use super::MaybeSimdOps;
        // Small integer-valued inputs keep every partial result exact, so the comparison does
        // not depend on summation order or FMA use.
        let a: Vec<f64> = (0..103).map(|i| (i % 7) as f64).collect();
        let b: Vec<f64> = (0..103).map(|i| (i % 5) as f64 - 2.0).collect();
        let sum: f64 = a.iter().sum();
        let dot: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        if cfg!(feature = "simd") {
            assert_eq!(f64::try_simd_sum(&a), Some(sum));
            assert_eq!(f64::try_simd_dot(&a, &b), Some(dot));
            assert_eq!(f64::try_simd_sum(&[]), Some(0.0));
        } else {
            assert_eq!(f64::try_simd_sum(&a), None);
            assert_eq!(f64::try_simd_dot(&a, &b), None);
        }
    }

    #[test]
    fn test_maybe_simd_ops_f32_sum_and_dot() {
        use super::MaybeSimdOps;
        let a: Vec<f32> = (0..103).map(|i| (i % 7) as f32).collect();
        let b: Vec<f32> = (0..103).map(|i| (i % 5) as f32 - 2.0).collect();
        let sum: f32 = a.iter().sum();
        let dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        if cfg!(feature = "simd") {
            assert_eq!(f32::try_simd_sum(&a), Some(sum));
            assert_eq!(f32::try_simd_dot(&a, &b), Some(dot));
            assert_eq!(f32::try_simd_sum(&[]), Some(0.0));
        } else {
            assert_eq!(f32::try_simd_sum(&a), None);
            assert_eq!(f32::try_simd_dot(&a, &b), None);
        }
    }

    #[test]
    fn test_maybe_simd_ops_integers_have_no_simd() {
        use super::MaybeSimdOps;
        assert_eq!(i32::try_simd_sum(&[1, 2, 3]), None);
        assert_eq!(u8::try_simd_dot(&[1], &[2]), None);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_rhs_2d_f64_source_contiguous() {
        let inner_len = 5usize;
        let row_len = 7usize;
        let src: Vec<f64> = (0..inner_len * row_len).map(|i| i as f64 + 0.25).collect();
        let scalar = 2.0f64;
        let mut dst = vec![0.0f64; inner_len * row_len];

        let used = unsafe {
            super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f64>(
                dst.as_mut_ptr(),
                src.as_ptr(),
                &scalar,
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(used);
        for row in 0..row_len {
            for inner in 0..inner_len {
                assert_eq!(
                    dst[row * inner_len + inner],
                    src[inner * row_len + row] * scalar
                );
            }
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_lhs_2d_f32_source_contiguous() {
        let inner_len = 5usize;
        let row_len = 7usize;
        let src: Vec<f32> = (0..inner_len * row_len)
            .map(|i| i as f32 * 0.5 + 1.0)
            .collect();
        let scalar = 3.0f32;
        let mut dst = vec![0.0f32; inner_len * row_len];

        let used = unsafe {
            super::try_mul_transposed_scalar_lhs_2d::<f32, f32, f32>(
                dst.as_mut_ptr(),
                &scalar,
                src.as_ptr(),
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(used);
        for row in 0..row_len {
            for inner in 0..inner_len {
                assert_eq!(
                    dst[row * inner_len + inner],
                    scalar * src[inner * row_len + row]
                );
            }
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_rhs_2d_handles_short_rows() {
        let inner_len = 8usize;
        let row_len = 3usize;
        let src: Vec<f64> = (0..inner_len * row_len).map(|i| i as f64 + 1.0).collect();
        let scalar = 2.0f64;
        let mut dst = vec![0.0f64; inner_len * row_len];

        let used = unsafe {
            super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f64>(
                dst.as_mut_ptr(),
                src.as_ptr(),
                &scalar,
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(used);
        for row in 0..row_len {
            for inner in 0..inner_len {
                assert_eq!(
                    dst[row * inner_len + inner],
                    src[inner * row_len + row] * scalar
                );
            }
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_lhs_2d_handles_short_rows() {
        // More inner elements than rows selects the row-by-row kernel.
        let inner_len = 9usize;
        let row_len = 2usize;
        let src: Vec<f32> = (0..inner_len * row_len).map(|i| i as f32 - 4.0).collect();
        let scalar = 0.5f32;
        let mut dst = vec![0.0f32; inner_len * row_len];

        let used = unsafe {
            super::try_mul_transposed_scalar_lhs_2d::<f32, f32, f32>(
                dst.as_mut_ptr(),
                &scalar,
                src.as_ptr(),
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(used);
        assert_transposed(&dst, &src, inner_len, row_len, |x| scalar * x);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_2d_handles_small_square_tiles() {
        let inner_len = 4usize;
        let row_len = 4usize;
        let src: Vec<f64> = (0..inner_len * row_len).map(|i| i as f64 + 1.0).collect();
        let scalar = 2.0f64;
        let mut rhs_dst = vec![0.0f64; inner_len * row_len];
        let mut lhs_dst = vec![0.0f64; inner_len * row_len];

        let rhs_used = unsafe {
            super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f64>(
                rhs_dst.as_mut_ptr(),
                src.as_ptr(),
                &scalar,
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };
        let lhs_used = unsafe {
            super::try_mul_transposed_scalar_lhs_2d::<f64, f64, f64>(
                lhs_dst.as_mut_ptr(),
                &scalar,
                src.as_ptr(),
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(rhs_used);
        assert!(lhs_used);
        for row in 0..row_len {
            for inner in 0..inner_len {
                assert_eq!(
                    rhs_dst[row * inner_len + inner],
                    src[inner * row_len + row] * scalar
                );
                assert_eq!(
                    lhs_dst[row * inner_len + inner],
                    scalar * src[inner * row_len + row]
                );
            }
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_2d_crosses_tile_boundaries() {
        // Both extents exceed TRANSPOSE_TILE (8) and are not multiples of it, and
        // row_len > inner_len selects the tiled kernel. Full and partial tiles are therefore
        // visited in both dimensions.
        let inner_len = 11usize;
        let row_len = 19usize;
        let src: Vec<f64> = (0..inner_len * row_len).map(|i| i as f64 - 7.5).collect();
        let scalar = -1.5f64;
        let mut rhs_dst = vec![0.0f64; inner_len * row_len];
        let mut lhs_dst = vec![0.0f64; inner_len * row_len];

        let rhs_used = unsafe {
            super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f64>(
                rhs_dst.as_mut_ptr(),
                src.as_ptr(),
                &scalar,
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };
        let lhs_used = unsafe {
            super::try_mul_transposed_scalar_lhs_2d::<f64, f64, f64>(
                lhs_dst.as_mut_ptr(),
                &scalar,
                src.as_ptr(),
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(rhs_used);
        assert!(lhs_used);
        assert_transposed(&rhs_dst, &src, inner_len, row_len, |x| x * scalar);
        assert_transposed(&lhs_dst, &src, inner_len, row_len, |x| scalar * x);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_2d_rejects_unsupported_or_mixed_types() {
        let src = vec![1i32, 2, 3, 4];
        let scalar = 3i32;
        let mut dst = vec![0i32; 4];

        let rhs_used = unsafe {
            super::try_mul_transposed_scalar_rhs_2d::<i32, i32, i32>(
                dst.as_mut_ptr(),
                src.as_ptr(),
                &scalar,
                2,
                2,
                2,
                1,
            )
        };
        let lhs_used = unsafe {
            super::try_mul_transposed_scalar_lhs_2d::<i32, i32, i32>(
                dst.as_mut_ptr(),
                &scalar,
                src.as_ptr(),
                2,
                2,
                2,
                1,
            )
        };
        assert!(!rhs_used);
        assert!(!lhs_used);
        assert_eq!(dst, vec![0; 4]);

        // A scalar of a different float width must not be accepted.
        let float_src = vec![1.0f64, 2.0, 3.0, 4.0];
        let narrow_scalar = 2.0f32;
        let mut float_dst = vec![0.0f64; 4];
        let mixed_used = unsafe {
            super::try_mul_transposed_scalar_rhs_2d::<f64, f64, f32>(
                float_dst.as_mut_ptr(),
                float_src.as_ptr(),
                &narrow_scalar,
                2,
                2,
                2,
                1,
            )
        };
        assert!(!mixed_used);
        assert_eq!(float_dst, vec![0.0; 4]);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_rhs_2d_complex64_source_contiguous() {
        let inner_len = 5usize;
        let row_len = 7usize;
        let src: Vec<num_complex::Complex64> = (0..inner_len * row_len)
            .map(|i| num_complex::Complex64::new(i as f64 + 0.25, i as f64 * -0.5))
            .collect();
        let scalar = num_complex::Complex64::new(2.0, -0.25);
        let mut dst = vec![num_complex::Complex64::new(0.0, 0.0); inner_len * row_len];

        let used = unsafe {
            super::try_mul_transposed_scalar_rhs_2d::<
                num_complex::Complex64,
                num_complex::Complex64,
                num_complex::Complex64,
            >(
                dst.as_mut_ptr(),
                src.as_ptr(),
                &scalar,
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(used);
        for row in 0..row_len {
            for inner in 0..inner_len {
                assert_eq!(
                    dst[row * inner_len + inner],
                    src[inner * row_len + row] * scalar
                );
            }
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_transposed_scalar_lhs_2d_complex32_source_contiguous() {
        let inner_len = 5usize;
        let row_len = 7usize;
        let src: Vec<num_complex::Complex32> = (0..inner_len * row_len)
            .map(|i| num_complex::Complex32::new(i as f32 * 0.5 + 1.0, i as f32 * 0.25))
            .collect();
        let scalar = num_complex::Complex32::new(3.0, -0.5);
        let mut dst = vec![num_complex::Complex32::new(0.0, 0.0); inner_len * row_len];

        let used = unsafe {
            super::try_mul_transposed_scalar_lhs_2d::<
                num_complex::Complex32,
                num_complex::Complex32,
                num_complex::Complex32,
            >(
                dst.as_mut_ptr(),
                &scalar,
                src.as_ptr(),
                inner_len,
                row_len,
                row_len as isize,
                1,
            )
        };

        assert!(used);
        for row in 0..row_len {
            for inner in 0..inner_len {
                assert_eq!(
                    dst[row * inner_len + inner],
                    scalar * src[inner * row_len + row]
                );
            }
        }
    }
}