1use crate::kernel::{
4 build_plan_fused, ensure_same_shape, for_each_inner_block_preordered, same_contiguous_layout,
5 sequential_contiguous_layout, total_len,
6};
7use crate::map_view::{map_into, zip_map2_into};
8use crate::maybe_sync::{MaybeSendSync, MaybeSync};
9use crate::reduce_view::reduce;
10use crate::simd;
11use crate::view::{StridedView, StridedViewMut};
12use crate::{Result, StridedError};
13use num_traits::{One, Zero};
14use std::ops::{Add, Mul};
15use strided_view::{ElementOp, ElementOpApply};
16
17#[cfg(feature = "parallel")]
18use crate::fuse::compute_costs;
19#[cfg(feature = "parallel")]
20use crate::threading::{
21 for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
22};
23#[cfg(feature = "parallel")]
24use rayon::iter::{IntoParallelIterator, ParallelIterator};
25
26#[inline(always)]
36unsafe fn inner_loop_add<D: Copy + Add<S, Output = D>, S: Copy, Op: ElementOp<S>>(
37 dp: *mut D,
38 ds: isize,
39 sp: *const S,
40 ss: isize,
41 len: usize,
42) {
43 if ds == 1 && ss == 1 {
44 let dst = std::slice::from_raw_parts_mut(dp, len);
45 let src = std::slice::from_raw_parts(sp, len);
46 simd::dispatch_if_large(len, || {
47 for i in 0..len {
48 dst[i] = dst[i] + Op::apply(src[i]);
49 }
50 });
51 } else {
52 let mut dp = dp;
53 let mut sp = sp;
54 for _ in 0..len {
55 *dp = *dp + Op::apply(*sp);
56 dp = dp.offset(ds);
57 sp = sp.offset(ss);
58 }
59 }
60}
61
62#[inline(always)]
64unsafe fn inner_loop_mul<D: Copy + Mul<S, Output = D>, S: Copy, Op: ElementOp<S>>(
65 dp: *mut D,
66 ds: isize,
67 sp: *const S,
68 ss: isize,
69 len: usize,
70) {
71 if ds == 1 && ss == 1 {
72 let dst = std::slice::from_raw_parts_mut(dp, len);
73 let src = std::slice::from_raw_parts(sp, len);
74 simd::dispatch_if_large(len, || {
75 for i in 0..len {
76 dst[i] = dst[i] * Op::apply(src[i]);
77 }
78 });
79 } else {
80 let mut dp = dp;
81 let mut sp = sp;
82 for _ in 0..len {
83 *dp = *dp * Op::apply(*sp);
84 dp = dp.offset(ds);
85 sp = sp.offset(ss);
86 }
87 }
88}
89
90#[inline(always)]
92unsafe fn inner_loop_axpy<
93 D: Copy + Add<D, Output = D>,
94 S: Copy,
95 A: Copy + Mul<S, Output = D>,
96 Op: ElementOp<S>,
97>(
98 dp: *mut D,
99 ds: isize,
100 sp: *const S,
101 ss: isize,
102 len: usize,
103 alpha: A,
104) {
105 if ds == 1 && ss == 1 {
106 let dst = std::slice::from_raw_parts_mut(dp, len);
107 let src = std::slice::from_raw_parts(sp, len);
108 simd::dispatch_if_large(len, || {
109 for i in 0..len {
110 dst[i] = alpha * Op::apply(src[i]) + dst[i];
111 }
112 });
113 } else {
114 let mut dp = dp;
115 let mut sp = sp;
116 for _ in 0..len {
117 *dp = alpha * Op::apply(*sp) + *dp;
118 dp = dp.offset(ds);
119 sp = sp.offset(ss);
120 }
121 }
122}
123
124#[inline(always)]
126unsafe fn inner_loop_fma<
127 D: Copy + Add<D, Output = D>,
128 A: Copy + Mul<B, Output = D>,
129 B: Copy,
130 OpA: ElementOp<A>,
131 OpB: ElementOp<B>,
132>(
133 dp: *mut D,
134 ds: isize,
135 ap: *const A,
136 a_s: isize,
137 bp: *const B,
138 b_s: isize,
139 len: usize,
140) {
141 if ds == 1 && a_s == 1 && b_s == 1 {
142 let dst = std::slice::from_raw_parts_mut(dp, len);
143 let sa = std::slice::from_raw_parts(ap, len);
144 let sb = std::slice::from_raw_parts(bp, len);
145 simd::dispatch_if_large(len, || {
146 for i in 0..len {
147 dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
148 }
149 });
150 } else {
151 let mut dp = dp;
152 let mut ap = ap;
153 let mut bp = bp;
154 for _ in 0..len {
155 *dp = *dp + OpA::apply(*ap) * OpB::apply(*bp);
156 dp = dp.offset(ds);
157 ap = ap.offset(a_s);
158 bp = bp.offset(b_s);
159 }
160 }
161}
162
163#[inline(always)]
165unsafe fn inner_loop_dot<
166 A: Copy + Mul<B, Output = R>,
167 B: Copy,
168 R: Copy + Add<R, Output = R>,
169 OpA: ElementOp<A>,
170 OpB: ElementOp<B>,
171>(
172 ap: *const A,
173 a_s: isize,
174 bp: *const B,
175 b_s: isize,
176 len: usize,
177 mut acc: R,
178) -> R {
179 if a_s == 1 && b_s == 1 {
180 let sa = std::slice::from_raw_parts(ap, len);
181 let sb = std::slice::from_raw_parts(bp, len);
182 simd::dispatch_if_large(len, || {
183 for i in 0..len {
184 acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
185 }
186 });
187 } else {
188 let mut ap = ap;
189 let mut bp = bp;
190 for _ in 0..len {
191 acc = acc + OpA::apply(*ap) * OpB::apply(*bp);
192 ap = ap.offset(a_s);
193 bp = bp.offset(b_s);
194 }
195 }
196 acc
197}
198
199pub fn copy_into<T: Copy + MaybeSendSync, Op: ElementOp<T>>(
201 dest: &mut StridedViewMut<T>,
202 src: &StridedView<T, Op>,
203) -> Result<()> {
204 ensure_same_shape(dest.dims(), src.dims())?;
205
206 let dst_ptr = dest.as_mut_ptr();
207 let src_ptr = src.ptr();
208 let dst_dims = dest.dims();
209 let dst_strides = dest.strides();
210 let src_strides = src.strides();
211
212 if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
213 let len = total_len(dst_dims);
214 if Op::IS_IDENTITY {
215 debug_assert!(
216 {
217 let nbytes = len
218 .checked_mul(std::mem::size_of::<T>())
219 .expect("copy size must not overflow");
220 let dst_start = dst_ptr as usize;
221 let src_start = src_ptr as usize;
222 let dst_end = dst_start.saturating_add(nbytes);
223 let src_end = src_start.saturating_add(nbytes);
224 dst_end <= src_start || src_end <= dst_start
225 },
226 "overlapping src/dest is not supported"
227 );
228 unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
229 } else {
230 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
231 let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
232 simd::dispatch_if_large(len, || {
233 for i in 0..len {
234 dst[i] = Op::apply(src[i]);
235 }
236 });
237 }
238 return Ok(());
239 }
240
241 map_into(dest, src, |x| x)
242}
243
/// Copies `src` into `dst` by delegating directly to
/// `strided_perm::copy_into_col_major`.
///
/// This wrapper adds no checks of its own; the result (including any error) is exactly
/// what the delegate returns.
// NOTE(review): the name suggests the delegate is specialised for a column-major
// destination — confirm against the `strided_perm` documentation.
pub fn copy_into_col_major<T: Copy + MaybeSendSync>(
    dst: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    strided_perm::copy_into_col_major(dst, src)
}
253
254pub fn add<
258 D: Copy + Add<S, Output = D> + MaybeSendSync,
259 S: Copy + MaybeSendSync,
260 Op: ElementOp<S>,
261>(
262 dest: &mut StridedViewMut<D>,
263 src: &StridedView<S, Op>,
264) -> Result<()> {
265 ensure_same_shape(dest.dims(), src.dims())?;
266
267 let dst_ptr = dest.as_mut_ptr();
268 let src_ptr = src.ptr();
269 let dst_dims = dest.dims();
270 let dst_strides = dest.strides();
271 let src_strides = src.strides();
272
273 if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
274 let len = total_len(dst_dims);
275 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
276 let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
277 simd::dispatch_if_large(len, || {
278 for i in 0..len {
279 dst[i] = dst[i] + Op::apply(src[i]);
280 }
281 });
282 return Ok(());
283 }
284
285 let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
286 let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());
287
288 let (fused_dims, ordered_strides, plan) =
289 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
290
291 #[cfg(feature = "parallel")]
292 {
293 let total: usize = fused_dims.iter().product();
294 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
295 let dst_send = SendPtr(dst_ptr);
296 let src_send = SendPtr(src_ptr as *mut S);
297
298 let costs = compute_costs(&ordered_strides);
299 let initial_offsets = vec![0isize; strides_list.len()];
300 let nthreads = rayon::current_num_threads();
301
302 return mapreduce_threaded(
303 &fused_dims,
304 &plan.block,
305 &ordered_strides,
306 &initial_offsets,
307 &costs,
308 nthreads,
309 0,
310 1,
311 &|dims, blocks, strides_list, offsets| {
312 for_each_inner_block_with_offsets(
313 dims,
314 blocks,
315 strides_list,
316 offsets,
317 |offsets, len, strides| {
318 unsafe {
319 inner_loop_add::<D, S, Op>(
320 dst_send.as_ptr().offset(offsets[0]),
321 strides[0],
322 src_send.as_const().offset(offsets[1]),
323 strides[1],
324 len,
325 )
326 };
327 Ok(())
328 },
329 )
330 },
331 );
332 }
333 }
334
335 let initial_offsets = vec![0isize; ordered_strides.len()];
336 for_each_inner_block_preordered(
337 &fused_dims,
338 &plan.block,
339 &ordered_strides,
340 &initial_offsets,
341 |offsets, len, strides| {
342 unsafe {
343 inner_loop_add::<D, S, Op>(
344 dst_ptr.offset(offsets[0]),
345 strides[0],
346 src_ptr.offset(offsets[1]),
347 strides[1],
348 len,
349 )
350 };
351 Ok(())
352 },
353 )
354}
355
356pub fn mul<
360 D: Copy + Mul<S, Output = D> + MaybeSendSync,
361 S: Copy + MaybeSendSync,
362 Op: ElementOp<S>,
363>(
364 dest: &mut StridedViewMut<D>,
365 src: &StridedView<S, Op>,
366) -> Result<()> {
367 ensure_same_shape(dest.dims(), src.dims())?;
368
369 let dst_ptr = dest.as_mut_ptr();
370 let src_ptr = src.ptr();
371 let dst_dims = dest.dims();
372 let dst_strides = dest.strides();
373 let src_strides = src.strides();
374
375 if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
376 let len = total_len(dst_dims);
377 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
378 let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
379 simd::dispatch_if_large(len, || {
380 for i in 0..len {
381 dst[i] = dst[i] * Op::apply(src[i]);
382 }
383 });
384 return Ok(());
385 }
386
387 let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
388 let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());
389
390 let (fused_dims, ordered_strides, plan) =
391 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
392
393 #[cfg(feature = "parallel")]
394 {
395 let total: usize = fused_dims.iter().product();
396 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
397 let dst_send = SendPtr(dst_ptr);
398 let src_send = SendPtr(src_ptr as *mut S);
399
400 let costs = compute_costs(&ordered_strides);
401 let initial_offsets = vec![0isize; strides_list.len()];
402 let nthreads = rayon::current_num_threads();
403
404 return mapreduce_threaded(
405 &fused_dims,
406 &plan.block,
407 &ordered_strides,
408 &initial_offsets,
409 &costs,
410 nthreads,
411 0,
412 1,
413 &|dims, blocks, strides_list, offsets| {
414 for_each_inner_block_with_offsets(
415 dims,
416 blocks,
417 strides_list,
418 offsets,
419 |offsets, len, strides| {
420 unsafe {
421 inner_loop_mul::<D, S, Op>(
422 dst_send.as_ptr().offset(offsets[0]),
423 strides[0],
424 src_send.as_const().offset(offsets[1]),
425 strides[1],
426 len,
427 )
428 };
429 Ok(())
430 },
431 )
432 },
433 );
434 }
435 }
436
437 let initial_offsets = vec![0isize; ordered_strides.len()];
438 for_each_inner_block_preordered(
439 &fused_dims,
440 &plan.block,
441 &ordered_strides,
442 &initial_offsets,
443 |offsets, len, strides| {
444 unsafe {
445 inner_loop_mul::<D, S, Op>(
446 dst_ptr.offset(offsets[0]),
447 strides[0],
448 src_ptr.offset(offsets[1]),
449 strides[1],
450 len,
451 )
452 };
453 Ok(())
454 },
455 )
456}
457
458pub fn axpy<D, S, A, Op>(
462 dest: &mut StridedViewMut<D>,
463 src: &StridedView<S, Op>,
464 alpha: A,
465) -> Result<()>
466where
467 A: Copy + Mul<S, Output = D> + MaybeSync,
468 D: Copy + Add<D, Output = D> + MaybeSendSync,
469 S: Copy + MaybeSendSync,
470 Op: ElementOp<S>,
471{
472 ensure_same_shape(dest.dims(), src.dims())?;
473
474 let dst_ptr = dest.as_mut_ptr();
475 let src_ptr = src.ptr();
476 let dst_dims = dest.dims();
477 let dst_strides = dest.strides();
478 let src_strides = src.strides();
479
480 if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
481 let len = total_len(dst_dims);
482 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
483 let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
484 simd::dispatch_if_large(len, || {
485 for i in 0..len {
486 dst[i] = alpha * Op::apply(src[i]) + dst[i];
487 }
488 });
489 return Ok(());
490 }
491
492 let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
493 let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());
494
495 let (fused_dims, ordered_strides, plan) =
496 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
497
498 #[cfg(feature = "parallel")]
499 {
500 let total: usize = fused_dims.iter().product();
501 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
502 let dst_send = SendPtr(dst_ptr);
503 let src_send = SendPtr(src_ptr as *mut S);
504
505 let costs = compute_costs(&ordered_strides);
506 let initial_offsets = vec![0isize; strides_list.len()];
507 let nthreads = rayon::current_num_threads();
508
509 return mapreduce_threaded(
510 &fused_dims,
511 &plan.block,
512 &ordered_strides,
513 &initial_offsets,
514 &costs,
515 nthreads,
516 0,
517 1,
518 &|dims, blocks, strides_list, offsets| {
519 for_each_inner_block_with_offsets(
520 dims,
521 blocks,
522 strides_list,
523 offsets,
524 |offsets, len, strides| {
525 unsafe {
526 inner_loop_axpy::<D, S, A, Op>(
527 dst_send.as_ptr().offset(offsets[0]),
528 strides[0],
529 src_send.as_const().offset(offsets[1]),
530 strides[1],
531 len,
532 alpha,
533 )
534 };
535 Ok(())
536 },
537 )
538 },
539 );
540 }
541 }
542
543 let initial_offsets = vec![0isize; ordered_strides.len()];
544 for_each_inner_block_preordered(
545 &fused_dims,
546 &plan.block,
547 &ordered_strides,
548 &initial_offsets,
549 |offsets, len, strides| {
550 unsafe {
551 inner_loop_axpy::<D, S, A, Op>(
552 dst_ptr.offset(offsets[0]),
553 strides[0],
554 src_ptr.offset(offsets[1]),
555 strides[1],
556 len,
557 alpha,
558 )
559 };
560 Ok(())
561 },
562 )
563}
564
565pub fn fma<D, A, B, OpA, OpB>(
569 dest: &mut StridedViewMut<D>,
570 a: &StridedView<A, OpA>,
571 b: &StridedView<B, OpB>,
572) -> Result<()>
573where
574 A: Copy + Mul<B, Output = D> + MaybeSendSync,
575 B: Copy + MaybeSendSync,
576 D: Copy + Add<D, Output = D> + MaybeSendSync,
577 OpA: ElementOp<A>,
578 OpB: ElementOp<B>,
579{
580 ensure_same_shape(dest.dims(), a.dims())?;
581 ensure_same_shape(dest.dims(), b.dims())?;
582
583 let dst_ptr = dest.as_mut_ptr();
584 let a_ptr = a.ptr();
585 let b_ptr = b.ptr();
586 let dst_dims = dest.dims();
587 let dst_strides = dest.strides();
588 let a_strides = a.strides();
589 let b_strides = b.strides();
590
591 if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
592 let len = total_len(dst_dims);
593 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
594 let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
595 let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
596 simd::dispatch_if_large(len, || {
597 for i in 0..len {
598 dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
599 }
600 });
601 return Ok(());
602 }
603
604 let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
605 let elem_size = std::mem::size_of::<D>()
606 .max(std::mem::size_of::<A>())
607 .max(std::mem::size_of::<B>());
608
609 let (fused_dims, ordered_strides, plan) =
610 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
611
612 #[cfg(feature = "parallel")]
613 {
614 let total: usize = fused_dims.iter().product();
615 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
616 let dst_send = SendPtr(dst_ptr);
617 let a_send = SendPtr(a_ptr as *mut A);
618 let b_send = SendPtr(b_ptr as *mut B);
619
620 let costs = compute_costs(&ordered_strides);
621 let initial_offsets = vec![0isize; strides_list.len()];
622 let nthreads = rayon::current_num_threads();
623
624 return mapreduce_threaded(
625 &fused_dims,
626 &plan.block,
627 &ordered_strides,
628 &initial_offsets,
629 &costs,
630 nthreads,
631 0,
632 1,
633 &|dims, blocks, strides_list, offsets| {
634 for_each_inner_block_with_offsets(
635 dims,
636 blocks,
637 strides_list,
638 offsets,
639 |offsets, len, strides| {
640 unsafe {
641 inner_loop_fma::<D, A, B, OpA, OpB>(
642 dst_send.as_ptr().offset(offsets[0]),
643 strides[0],
644 a_send.as_const().offset(offsets[1]),
645 strides[1],
646 b_send.as_const().offset(offsets[2]),
647 strides[2],
648 len,
649 )
650 };
651 Ok(())
652 },
653 )
654 },
655 );
656 }
657 }
658
659 let initial_offsets = vec![0isize; ordered_strides.len()];
660 for_each_inner_block_preordered(
661 &fused_dims,
662 &plan.block,
663 &ordered_strides,
664 &initial_offsets,
665 |offsets, len, strides| {
666 unsafe {
667 inner_loop_fma::<D, A, B, OpA, OpB>(
668 dst_ptr.offset(offsets[0]),
669 strides[0],
670 a_ptr.offset(offsets[1]),
671 strides[1],
672 b_ptr.offset(offsets[2]),
673 strides[2],
674 len,
675 )
676 };
677 Ok(())
678 },
679 )
680}
681
/// Sums `src` by splitting it into roughly one chunk per rayon thread, reducing every
/// chunk with `T::try_simd_sum`, and adding the partial sums together.
///
/// Returns `None` when `T::try_simd_sum` reports no SIMD support (probed once with an
/// empty slice, and again per chunk), so the caller can fall back to another strategy.
/// An empty `src` yields `Some(T::zero())` when SIMD is supported.
#[cfg(feature = "parallel")]
fn parallel_simd_sum<T: Copy + Zero + Add<Output = T> + simd::MaybeSimdOps + Send + Sync>(
    src: &[T],
) -> Option<T> {
    use rayon::prelude::*;

    // Cheap probe: bail out before spawning any parallel work if SIMD is unavailable.
    T::try_simd_sum(&[])?;

    let nthreads = rayon::current_num_threads().max(1);
    // `par_chunks` panics on a zero chunk size, which the ceiling division would produce
    // for an empty slice; clamp to at least one element per chunk.
    let chunk_size = ((src.len() + nthreads - 1) / nthreads).max(1);

    // `try_reduce` short-circuits to `None` instead of panicking if any chunk cannot be
    // summed with SIMD.
    src.par_chunks(chunk_size)
        .map(|chunk| T::try_simd_sum(chunk))
        .try_reduce(|| T::zero(), |a, b| Some(a + b))
}
699
700pub fn sum<
702 T: Copy + Zero + Add<Output = T> + MaybeSendSync + simd::MaybeSimdOps,
703 Op: ElementOp<T>,
704>(
705 src: &StridedView<T, Op>,
706) -> Result<T> {
707 if Op::IS_IDENTITY {
709 if same_contiguous_layout(src.dims(), &[src.strides()]).is_some() {
710 let len = total_len(src.dims());
711 let src_slice = unsafe { std::slice::from_raw_parts(src.ptr(), len) };
712
713 #[cfg(feature = "parallel")]
714 if len > MINTHREADLENGTH {
715 if let Some(result) = parallel_simd_sum(src_slice) {
716 return Ok(result);
717 }
718 }
719
720 if let Some(result) = T::try_simd_sum(src_slice) {
721 return Ok(result);
722 }
723 }
724 }
725 reduce(src, |x| x, |a, b| a + b, T::zero())
726}
727
728pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R>
733where
734 A: Copy + Mul<B, Output = R> + MaybeSendSync + 'static,
735 B: Copy + MaybeSendSync + 'static,
736 R: Copy + Zero + Add<Output = R> + MaybeSendSync + simd::MaybeSimdOps + 'static,
737 OpA: ElementOp<A>,
738 OpB: ElementOp<B>,
739{
740 ensure_same_shape(a.dims(), b.dims())?;
741
742 let a_ptr = a.ptr();
743 let b_ptr = b.ptr();
744 let a_strides = a.strides();
745 let b_strides = b.strides();
746 let a_dims = a.dims();
747
748 if same_contiguous_layout(a_dims, &[a_strides, b_strides]).is_some() {
749 let len = total_len(a_dims);
750
751 if OpA::IS_IDENTITY
753 && OpB::IS_IDENTITY
754 && std::any::TypeId::of::<A>() == std::any::TypeId::of::<R>()
755 && std::any::TypeId::of::<B>() == std::any::TypeId::of::<R>()
756 {
757 let sa = unsafe { std::slice::from_raw_parts(a_ptr as *const R, len) };
758 let sb = unsafe { std::slice::from_raw_parts(b_ptr as *const R, len) };
759 if let Some(result) = R::try_simd_dot(sa, sb) {
760 return Ok(result);
761 }
762 }
763
764 let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
766 let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
767 let mut acc = R::zero();
768 simd::dispatch_if_large(len, || {
769 for i in 0..len {
770 acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
771 }
772 });
773 return Ok(acc);
774 }
775
776 let strides_list: [&[isize]; 2] = [a_strides, b_strides];
777 let elem_size = std::mem::size_of::<A>()
778 .max(std::mem::size_of::<B>())
779 .max(std::mem::size_of::<R>());
780
781 let (fused_dims, ordered_strides, plan) =
782 build_plan_fused(a_dims, &strides_list, None, elem_size);
783
784 let mut acc = R::zero();
785 let initial_offsets = vec![0isize; ordered_strides.len()];
786 for_each_inner_block_preordered(
787 &fused_dims,
788 &plan.block,
789 &ordered_strides,
790 &initial_offsets,
791 |offsets, len, strides| {
792 acc = unsafe {
793 inner_loop_dot::<A, B, R, OpA, OpB>(
794 a_ptr.offset(offsets[0]),
795 strides[0],
796 b_ptr.offset(offsets[1]),
797 strides[1],
798 len,
799 acc,
800 )
801 };
802 Ok(())
803 },
804 )?;
805
806 Ok(acc)
807}
808
809pub fn symmetrize_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
811where
812 T: Copy
813 + Add<Output = T>
814 + Mul<Output = T>
815 + num_traits::FromPrimitive
816 + std::ops::Div<Output = T>
817 + MaybeSendSync,
818{
819 if src.ndim() != 2 {
820 return Err(StridedError::RankMismatch(src.ndim(), 2));
821 }
822 let rows = src.dims()[0];
823 let cols = src.dims()[1];
824 if rows != cols {
825 return Err(StridedError::NonSquare { rows, cols });
826 }
827
828 let src_t = src.permute(&[1, 0])?;
829 let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;
830
831 zip_map2_into(dest, src, &src_t, |a, b| (a + b) * half)
832}
833
834pub fn symmetrize_conj_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
836where
837 T: Copy
838 + ElementOpApply
839 + Add<Output = T>
840 + Mul<Output = T>
841 + num_traits::FromPrimitive
842 + std::ops::Div<Output = T>
843 + MaybeSendSync,
844{
845 if src.ndim() != 2 {
846 return Err(StridedError::RankMismatch(src.ndim(), 2));
847 }
848 let rows = src.dims()[0];
849 let cols = src.dims()[1];
850 if rows != cols {
851 return Err(StridedError::NonSquare { rows, cols });
852 }
853
854 let src_adj = src.adjoint_2d()?;
856 let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;
857
858 zip_map2_into(dest, src, &src_adj, |a, b| (a + b) * half)
859}
860
/// Writes `scale * x` into `dest` for every element `x` that `map_into` yields from
/// `src`.
///
/// The scalar is the left operand of the multiplication (`A: Mul<S, Output = D>`), so
/// the destination element type `D` may differ from the source element type `S`.
///
/// # Errors
///
/// Returns whatever `map_into` reports.
// NOTE(review): shape validation and application of `Op` are presumably handled inside
// `map_into` — confirm there; this wrapper performs no checks itself.
pub fn copy_scale<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    scale: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    map_into(dest, src, |x| scale * x)
}
877
/// Copies the view returned by `src.conj()` into `dest` using `copy_into`.
///
/// # Errors
///
/// Returns whatever `copy_into` reports (e.g. a shape mismatch).
// NOTE(review): `conj()` presumably yields a lazily conjugated view whose element op is
// applied during the copy — confirm against `StridedView::conj`.
pub fn copy_conj<T: Copy + ElementOpApply + MaybeSendSync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let src_conj = src.conj();
    copy_into(dest, &src_conj)
}
886
887#[inline]
888fn element_transpose_is_identity<T: 'static>() -> bool {
889 use std::any::TypeId;
890
891 macro_rules! matches_type {
892 ($($ty:ty),* $(,)?) => {{
893 let id = TypeId::of::<T>();
894 false $(|| id == TypeId::of::<$ty>())*
895 }};
896 }
897
898 matches_type!(
899 f32,
900 f64,
901 i8,
902 i16,
903 i32,
904 i64,
905 i128,
906 isize,
907 u8,
908 u16,
909 u32,
910 u64,
911 u128,
912 usize,
913 num_complex::Complex32,
914 num_complex::Complex64,
915 )
916}
917
918#[inline]
919fn element_zero_is_all_bits_zero<T: 'static>() -> bool {
920 use std::any::TypeId;
921
922 macro_rules! matches_type {
923 ($($ty:ty),* $(,)?) => {{
924 let id = TypeId::of::<T>();
925 false $(|| id == TypeId::of::<$ty>())*
926 }};
927 }
928
929 matches_type!(
930 f32,
931 f64,
932 i8,
933 i16,
934 i32,
935 i64,
936 i128,
937 isize,
938 u8,
939 u16,
940 u32,
941 u64,
942 u128,
943 usize,
944 num_complex::Complex32,
945 num_complex::Complex64,
946 )
947}
948
949#[inline]
950unsafe fn fill_2d<T: Copy + MaybeSendSync>(
951 dst: *mut T,
952 dim0: usize,
953 dim1: usize,
954 dst_stride0: isize,
955 dst_stride1: isize,
956 value: T,
957) {
958 #[cfg(feature = "parallel")]
959 {
960 let total = dim0.saturating_mul(dim1);
961 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
962 let dst_send = SendPtr(dst);
963 if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
964 (0..dim1).into_par_iter().for_each(|j| {
965 let dst = dst_send.as_ptr();
966 unsafe {
967 let base = j as isize * dst_stride1;
968 for i in 0..dim0 {
969 *dst.offset(base + i as isize * dst_stride0) = value;
970 }
971 }
972 });
973 } else {
974 (0..dim0).into_par_iter().for_each(|i| {
975 let dst = dst_send.as_ptr();
976 unsafe {
977 let base = i as isize * dst_stride0;
978 for j in 0..dim1 {
979 *dst.offset(base + j as isize * dst_stride1) = value;
980 }
981 }
982 });
983 }
984 return;
985 }
986 }
987
988 if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
989 for j in 0..dim1 {
990 let base = j as isize * dst_stride1;
991 for i in 0..dim0 {
992 *dst.offset(base + i as isize * dst_stride0) = value;
993 }
994 }
995 } else {
996 for i in 0..dim0 {
997 let base = i as isize * dst_stride0;
998 for j in 0..dim1 {
999 *dst.offset(base + j as isize * dst_stride1) = value;
1000 }
1001 }
1002 }
1003}
1004
1005#[inline]
1006unsafe fn fill_contiguous<T>(dst: *mut T, len: usize, value: T)
1007where
1008 T: Copy + Zero + PartialEq + MaybeSendSync + 'static,
1009{
1010 if element_zero_is_all_bits_zero::<T>() && value == T::zero() {
1011 std::ptr::write_bytes(dst, 0, len);
1012 return;
1013 }
1014
1015 let dst = std::slice::from_raw_parts_mut(dst, len);
1016 dst.fill(value);
1017}
1018
/// Transposes and scales one 4x4 tile of `f64` values.
///
/// For `r, c in 0..4` this stores
/// `scale * src[(i + r) * src_stride0 + (j + c) * src_stride1]` into
/// `dst[(j + c) * dst_stride0 + (i + r) * dst_stride1]`.
/// All sixteen source values are loaded before the first store is issued.
///
/// # Safety
///
/// Every source element of the tile must be valid for reads and every destination
/// element valid for writes.
#[inline(always)]
unsafe fn transpose_scale_4x4_f64(
    dst: *mut f64,
    dst_stride0: isize,
    dst_stride1: isize,
    src: *const f64,
    src_stride0: isize,
    src_stride1: isize,
    i: usize,
    j: usize,
    scale: f64,
) {
    const EDGE: isize = 4;

    let src_tile = src.offset(i as isize * src_stride0 + j as isize * src_stride1);

    // Load phase: walk the tile column by column (along `src_stride1`), rows within.
    let mut loaded = [[0.0f64; 4]; 4];
    for c in 0..EDGE {
        let column = src_tile.offset(c * src_stride1);
        for r in 0..EDGE {
            loaded[r as usize][c as usize] = *column.offset(r * src_stride0);
        }
    }

    // Store phase: source row `r` becomes the destination line at `r * dst_stride1`.
    let dst_tile = dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1);
    for r in 0..EDGE {
        let line = dst_tile.offset(r * dst_stride1);
        for c in 0..EDGE {
            *line.offset(c * dst_stride0) = scale * loaded[r as usize][c as usize];
        }
    }
}
1080
1081#[inline]
1082unsafe fn copy_transpose_scale_2d_f64_tiled_raw(
1083 dst: *mut f64,
1084 dst_stride0: isize,
1085 dst_stride1: isize,
1086 src: *const f64,
1087 src_stride0: isize,
1088 src_stride1: isize,
1089 src_rows: usize,
1090 src_cols: usize,
1091 scale: f64,
1092) {
1093 const TILE: usize = 4;
1094 let row_full = src_rows / TILE * TILE;
1095 let col_full = src_cols / TILE * TILE;
1096
1097 #[cfg(feature = "parallel")]
1098 {
1099 let total = src_rows.saturating_mul(src_cols);
1100 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1101 let dst_send = SendPtr(dst);
1102 let src_send = SendPtr(src as *mut f64);
1103 let row_tiles = row_full / TILE;
1104 (0..row_tiles).into_par_iter().for_each(|tile_i| {
1105 let i = tile_i * TILE;
1106 let dst = dst_send.as_ptr();
1107 let src = src_send.as_const();
1108 unsafe {
1109 let mut j = 0;
1110 while j < col_full {
1111 transpose_scale_4x4_f64(
1112 dst,
1113 dst_stride0,
1114 dst_stride1,
1115 src,
1116 src_stride0,
1117 src_stride1,
1118 i,
1119 j,
1120 scale,
1121 );
1122 j += TILE;
1123 }
1124 for j in col_full..src_cols {
1125 for ii in i..i + TILE {
1126 *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1127 scale
1128 * *src.offset(
1129 ii as isize * src_stride0 + j as isize * src_stride1,
1130 );
1131 }
1132 }
1133 }
1134 });
1135 for i in row_full..src_rows {
1136 for j in 0..src_cols {
1137 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1138 scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1139 }
1140 }
1141 return;
1142 }
1143 }
1144
1145 let mut i = 0;
1146 while i < row_full {
1147 let mut j = 0;
1148 while j < col_full {
1149 transpose_scale_4x4_f64(
1150 dst,
1151 dst_stride0,
1152 dst_stride1,
1153 src,
1154 src_stride0,
1155 src_stride1,
1156 i,
1157 j,
1158 scale,
1159 );
1160 j += TILE;
1161 }
1162 for j in col_full..src_cols {
1163 for ii in i..i + TILE {
1164 *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1165 scale * *src.offset(ii as isize * src_stride0 + j as isize * src_stride1);
1166 }
1167 }
1168 i += TILE;
1169 }
1170 for i in row_full..src_rows {
1171 for j in 0..src_cols {
1172 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1173 scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1174 }
1175 }
1176}
1177
/// Test-only entry point for the tiled `f64` transpose-and-scale kernel.
///
/// Returns `false` without touching `dest` unless both views are 2-D, `dest`'s shape is
/// the reverse of `src`'s, and both views have stride 1 along axis 0. Otherwise runs
/// `copy_transpose_scale_2d_f64_tiled_raw` and returns `true`.
///
/// # Safety
///
/// The views must describe memory that is valid for the reads/writes performed by the
/// raw kernel.
#[inline]
#[cfg(test)]
unsafe fn try_copy_transpose_scale_2d_f64_tiled(
    dest: &mut StridedViewMut<f64>,
    src: &StridedView<f64>,
    scale: f64,
) -> bool {
    if src.ndim() != 2 || dest.ndim() != 2 {
        return false;
    }

    let (src_rows, src_cols) = (src.dims()[0], src.dims()[1]);
    if dest.dims() != [src_cols, src_rows] {
        return false;
    }

    // Copy the strides out so no borrow of `dest` is held across `as_mut_ptr`.
    let (src_s0, src_s1) = (src.strides()[0], src.strides()[1]);
    let (dst_s0, dst_s1) = (dest.strides()[0], dest.strides()[1]);
    if src_s0 != 1 || dst_s0 != 1 {
        return false;
    }

    copy_transpose_scale_2d_f64_tiled_raw(
        dest.as_mut_ptr(),
        dst_s0,
        dst_s1,
        src.ptr(),
        src_s0,
        src_s1,
        src_rows,
        src_cols,
        scale,
    );
    true
}
1209
1210#[inline]
1211unsafe fn try_copy_transpose_scale_2d_f64_tiled_typed<T>(
1212 dest: &mut StridedViewMut<T>,
1213 src: &StridedView<T>,
1214 scale: T,
1215) -> bool
1216where
1217 T: Copy + 'static,
1218{
1219 if std::any::TypeId::of::<T>() != std::any::TypeId::of::<f64>() {
1220 return false;
1221 }
1222
1223 let scale = *(&scale as *const T).cast::<f64>();
1224 if src.ndim() != 2 || dest.ndim() != 2 {
1225 return false;
1226 }
1227 let src_dims = src.dims();
1228 if dest.dims() != [src_dims[1], src_dims[0]] {
1229 return false;
1230 }
1231 if src.strides()[0] != 1 || dest.strides()[0] != 1 {
1232 return false;
1233 }
1234
1235 copy_transpose_scale_2d_f64_tiled_raw(
1236 dest.as_mut_ptr().cast::<f64>(),
1237 dest.strides()[0],
1238 dest.strides()[1],
1239 src.ptr().cast::<f64>(),
1240 src.strides()[0],
1241 src.strides()[1],
1242 src_dims[0],
1243 src_dims[1],
1244 scale,
1245 );
1246 true
1247}
1248
/// Transposes and scales one 4x4 tile of `T` values.
///
/// For `r, c in 0..4` this stores
/// `scale * src[(i + r) * src_stride0 + (j + c) * src_stride1]` into
/// `dst[(j + c) * dst_stride0 + (i + r) * dst_stride1]`.
/// All sixteen source values are loaded before the first store is issued.
///
/// # Safety
///
/// Every source element of the tile must be valid for reads and every destination
/// element valid for writes.
#[inline(always)]
unsafe fn transpose_scale_4x4_identity<T>(
    dst: *mut T,
    dst_stride0: isize,
    dst_stride1: isize,
    src: *const T,
    src_stride0: isize,
    src_stride1: isize,
    i: usize,
    j: usize,
    scale: T,
) where
    T: Copy + Mul<Output = T>,
{
    const EDGE: isize = 4;

    let src_tile = src.offset(i as isize * src_stride0 + j as isize * src_stride1);

    // Load phase: walk the tile column by column (along `src_stride1`), rows within.
    // The buffer is seeded with the tile's first element, which is read anyway.
    let mut loaded = [[*src_tile; 4]; 4];
    for c in 0..EDGE {
        let column = src_tile.offset(c * src_stride1);
        for r in 0..EDGE {
            loaded[r as usize][c as usize] = *column.offset(r * src_stride0);
        }
    }

    // Store phase: source row `r` becomes the destination line at `r * dst_stride1`.
    let dst_tile = dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1);
    for r in 0..EDGE {
        let line = dst_tile.offset(r * dst_stride1);
        for c in 0..EDGE {
            *line.offset(c * dst_stride0) = scale * loaded[r as usize][c as usize];
        }
    }
}
1312
1313#[inline]
1314unsafe fn copy_transpose_scale_2d_identity_tiled_raw<T>(
1315 dst: *mut T,
1316 dst_stride0: isize,
1317 dst_stride1: isize,
1318 src: *const T,
1319 src_stride0: isize,
1320 src_stride1: isize,
1321 src_rows: usize,
1322 src_cols: usize,
1323 scale: T,
1324) where
1325 T: Copy + Mul<Output = T> + MaybeSendSync,
1326{
1327 const TILE: usize = 4;
1328 let row_full = src_rows / TILE * TILE;
1329 let col_full = src_cols / TILE * TILE;
1330
1331 #[cfg(feature = "parallel")]
1332 {
1333 let total = src_rows.saturating_mul(src_cols);
1334 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1335 let dst_send = SendPtr(dst);
1336 let src_send = SendPtr(src as *mut T);
1337 let row_tiles = row_full / TILE;
1338 (0..row_tiles).into_par_iter().for_each(|tile_i| {
1339 let i = tile_i * TILE;
1340 let dst = dst_send.as_ptr();
1341 let src = src_send.as_const();
1342 unsafe {
1343 let mut j = 0;
1344 while j < col_full {
1345 transpose_scale_4x4_identity(
1346 dst,
1347 dst_stride0,
1348 dst_stride1,
1349 src,
1350 src_stride0,
1351 src_stride1,
1352 i,
1353 j,
1354 scale,
1355 );
1356 j += TILE;
1357 }
1358 for j in col_full..src_cols {
1359 for ii in i..i + TILE {
1360 *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1361 scale
1362 * *src.offset(
1363 ii as isize * src_stride0 + j as isize * src_stride1,
1364 );
1365 }
1366 }
1367 }
1368 });
1369 for i in row_full..src_rows {
1370 for j in 0..src_cols {
1371 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1372 scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1373 }
1374 }
1375 return;
1376 }
1377 }
1378
1379 let mut i = 0;
1380 while i < row_full {
1381 let mut j = 0;
1382 while j < col_full {
1383 transpose_scale_4x4_identity(
1384 dst,
1385 dst_stride0,
1386 dst_stride1,
1387 src,
1388 src_stride0,
1389 src_stride1,
1390 i,
1391 j,
1392 scale,
1393 );
1394 j += TILE;
1395 }
1396 for j in col_full..src_cols {
1397 for ii in i..i + TILE {
1398 *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1399 scale * *src.offset(ii as isize * src_stride0 + j as isize * src_stride1);
1400 }
1401 }
1402 i += TILE;
1403 }
1404 for i in row_full..src_rows {
1405 for j in 0..src_cols {
1406 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1407 scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1408 }
1409 }
1410}
1411
1412#[inline]
1413unsafe fn try_copy_transpose_scale_2d_identity_tiled<T>(
1414 dest: &mut StridedViewMut<T>,
1415 src: &StridedView<T>,
1416 scale: T,
1417) -> bool
1418where
1419 T: Copy + Mul<Output = T> + MaybeSendSync,
1420{
1421 if src.ndim() != 2 || dest.ndim() != 2 {
1422 return false;
1423 }
1424 let src_dims = src.dims();
1425 if dest.dims() != [src_dims[1], src_dims[0]] {
1426 return false;
1427 }
1428 if src.strides()[0] != 1 || dest.strides()[0] != 1 {
1429 return false;
1430 }
1431
1432 copy_transpose_scale_2d_identity_tiled_raw(
1433 dest.as_mut_ptr(),
1434 dest.strides()[0],
1435 dest.strides()[1],
1436 src.ptr(),
1437 src.strides()[0],
1438 src.strides()[1],
1439 src_dims[0],
1440 src_dims[1],
1441 scale,
1442 );
1443 true
1444}
1445
1446#[inline]
1447unsafe fn copy_transpose_scale_2d_loop<T>(
1448 dst: *mut T,
1449 dst_stride0: isize,
1450 dst_stride1: isize,
1451 src: *const T,
1452 src_stride0: isize,
1453 src_stride1: isize,
1454 src_rows: usize,
1455 src_cols: usize,
1456 scale: T,
1457) where
1458 T: Copy + ElementOpApply + Mul<Output = T> + MaybeSendSync,
1459{
1460 #[cfg(feature = "parallel")]
1461 {
1462 let total = src_rows.saturating_mul(src_cols);
1463 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1464 let dst_send = SendPtr(dst);
1465 let src_send = SendPtr(src as *mut T);
1466 if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
1467 (0..src_rows).into_par_iter().for_each(|i| {
1468 let dst = dst_send.as_ptr();
1469 let src = src_send.as_const();
1470 unsafe {
1471 for j in 0..src_cols {
1472 let value = (*src
1473 .offset(i as isize * src_stride0 + j as isize * src_stride1))
1474 .transpose();
1475 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1476 scale * value;
1477 }
1478 }
1479 });
1480 } else {
1481 (0..src_cols).into_par_iter().for_each(|j| {
1482 let dst = dst_send.as_ptr();
1483 let src = src_send.as_const();
1484 unsafe {
1485 for i in 0..src_rows {
1486 let value = (*src
1487 .offset(i as isize * src_stride0 + j as isize * src_stride1))
1488 .transpose();
1489 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1490 scale * value;
1491 }
1492 }
1493 });
1494 }
1495 return;
1496 }
1497 }
1498
1499 if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
1500 for i in 0..src_rows {
1501 for j in 0..src_cols {
1502 let value =
1503 (*src.offset(i as isize * src_stride0 + j as isize * src_stride1)).transpose();
1504 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) = scale * value;
1505 }
1506 }
1507 } else {
1508 for j in 0..src_cols {
1509 for i in 0..src_rows {
1510 let value =
1511 (*src.offset(i as isize * src_stride0 + j as isize * src_stride1)).transpose();
1512 *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) = scale * value;
1513 }
1514 }
1515 }
1516}
1517
1518pub fn copy_transpose_scale_into<T>(
1520 dest: &mut StridedViewMut<T>,
1521 src: &StridedView<T>,
1522 scale: T,
1523) -> Result<()>
1524where
1525 T: Copy + ElementOpApply + Mul<Output = T> + Zero + One + PartialEq + MaybeSendSync + 'static,
1526{
1527 if src.ndim() != 2 || dest.ndim() != 2 {
1528 return Err(StridedError::RankMismatch(src.ndim(), 2));
1529 }
1530
1531 let src_dims = src.dims();
1532 let expected_dims = [src_dims[1], src_dims[0]];
1533 ensure_same_shape(dest.dims(), &expected_dims)?;
1534
1535 if scale == T::zero() {
1536 unsafe {
1537 if same_contiguous_layout(dest.dims(), &[dest.strides()]).is_some() {
1538 fill_contiguous(dest.as_mut_ptr(), total_len(dest.dims()), T::zero());
1539 } else {
1540 fill_2d(
1541 dest.as_mut_ptr(),
1542 dest.dims()[0],
1543 dest.dims()[1],
1544 dest.strides()[0],
1545 dest.strides()[1],
1546 T::zero(),
1547 );
1548 }
1549 }
1550 return Ok(());
1551 }
1552
1553 let transpose_is_identity = element_transpose_is_identity::<T>();
1554
1555 unsafe {
1556 if transpose_is_identity && try_copy_transpose_scale_2d_f64_tiled_typed(dest, src, scale) {
1557 return Ok(());
1558 }
1559 if transpose_is_identity && try_copy_transpose_scale_2d_identity_tiled(dest, src, scale) {
1560 return Ok(());
1561 }
1562 }
1563
1564 if scale == T::one() && transpose_is_identity {
1565 let src_t = src.permute(&[1, 0])?;
1566 #[cfg(feature = "parallel")]
1567 {
1568 if total_len(dest.dims()) > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1569 return strided_perm::copy_into_par(dest, &src_t);
1570 }
1571 }
1572 return strided_perm::copy_into(dest, &src_t);
1573 }
1574
1575 unsafe {
1576 copy_transpose_scale_2d_loop(
1577 dest.as_mut_ptr(),
1578 dest.strides()[0],
1579 dest.strides()[1],
1580 src.ptr(),
1581 src.strides()[0],
1582 src.strides()[1],
1583 src_dims[0],
1584 src_dims[1],
1585 scale,
1586 );
1587 }
1588 Ok(())
1589}
1590
#[cfg(test)]
mod tiled_tests {
    use super::*;
    use crate::view::StridedArray;

    /// The f64 tiled path must accept a column-major 7x9 input (neither extent is a multiple
    /// of four) and produce the scaled transpose for every element.
    #[test]
    fn test_f64_tiled_transpose_scale_handles_remainders() {
        let (n_rows, n_cols) = (7, 9);
        let input = StridedArray::<f64>::from_fn_col_major(&[n_rows, n_cols], |idx| {
            (idx[0] * 100 + idx[1]) as f64
        });
        let mut output = StridedArray::<f64>::col_major(&[n_cols, n_rows]);

        let took_fast_path = {
            let input_view = input.view();
            let mut output_view = output.view_mut();
            unsafe { try_copy_transpose_scale_2d_f64_tiled(&mut output_view, &input_view, 3.0) }
        };
        assert!(took_fast_path);

        for r in 0..n_rows {
            for c in 0..n_cols {
                assert_eq!(output.get(&[c, r]), 3.0 * input.get(&[r, c]));
            }
        }
    }

    /// The generic tiled path must accept a column-major 6x5 integer input, which leaves both
    /// leftover rows and leftover columns around the single full 4x4 tile.
    #[test]
    fn test_identity_tiled_transpose_scale_handles_integer_remainders() {
        let (n_rows, n_cols) = (6, 5);
        let input = StridedArray::<u64>::from_fn_col_major(&[n_rows, n_cols], |idx| {
            (idx[0] * 100 + idx[1]) as u64
        });
        let mut output = StridedArray::<u64>::col_major(&[n_cols, n_rows]);

        let took_fast_path = {
            let input_view = input.view();
            let mut output_view = output.view_mut();
            unsafe { try_copy_transpose_scale_2d_identity_tiled(&mut output_view, &input_view, 2) }
        };
        assert!(took_fast_path);

        for r in 0..n_rows {
            for c in 0..n_cols {
                assert_eq!(output.get(&[c, r]), 2 * input.get(&[r, c]));
            }
        }
    }

    /// A zero scale must overwrite every element of an axis-permuted destination view; the
    /// backing array is pre-filled with 99 so stale values would be detected.
    #[test]
    fn test_zero_scale_fills_non_contiguous_destination() {
        let (n_rows, n_cols) = (3, 4);
        let input = StridedArray::<u64>::from_fn_col_major(&[n_rows, n_cols], |idx| {
            (idx[0] * 10 + idx[1] + 1) as u64
        });
        let mut backing = StridedArray::<u64>::from_fn_col_major(&[n_rows, n_cols], |_| 99);
        let mut permuted_out = backing.view_mut().permute(&[1, 0]).unwrap();

        copy_transpose_scale_into(&mut permuted_out, &input.view(), 0).unwrap();

        for c in 0..n_cols {
            for r in 0..n_rows {
                assert_eq!(permuted_out.get(&[c, r]), 0);
            }
        }
    }
}