use crate::kernel::{
    build_plan_fused, ensure_same_shape, for_each_inner_block_preordered, same_contiguous_layout,
    sequential_contiguous_layout, total_len,
};
use crate::map_view::{map_into, zip_map2_into};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::reduce_view::reduce;
use crate::simd;
use crate::view::{StridedView, StridedViewMut};
use crate::{Result, StridedError};
use num_traits::Zero;
use std::ops::{Add, Mul};
use strided_view::{ElementOp, ElementOpApply};

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{
    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
};

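/// Inner kernel for `add`: `*dp = *dp + Op::apply(*sp)` over `len` elements,
/// with a slice-based SIMD-dispatched path when both strides are 1.
///
/// # Safety
/// `dp` and `sp` must be valid for `len` elements at strides `ds`/`ss`, and
/// the destination range must not overlap the source.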
#[inline(always)]
unsafe fn inner_loop_add<D: Copy + Add<S, Output = D>, S: Copy, Op: ElementOp<S>>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + Op::apply(src[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = *dp + Op::apply(*sp);
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

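/// Inner kernel for `mul`: `*dp = *dp * Op::apply(*sp)` over `len` elements.
///
/// # Safety
/// Same contract as [`inner_loop_add`]: both pointers valid for `len`
/// elements at their strides, destination not overlapping the source.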
#[inline(always)]
unsafe fn inner_loop_mul<D: Copy + Mul<S, Output = D>, S: Copy, Op: ElementOp<S>>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] * Op::apply(src[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = *dp * Op::apply(*sp);
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

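/// Inner kernel for `axpy`: `*dp = alpha * Op::apply(*sp) + *dp`.
///
/// # Safety
/// Same contract as [`inner_loop_add`].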
#[inline(always)]
unsafe fn inner_loop_axpy<
    D: Copy + Add<D, Output = D>,
    S: Copy,
    A: Copy + Mul<S, Output = D>,
    Op: ElementOp<S>,
>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
    alpha: A,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = alpha * Op::apply(src[i]) + dst[i];
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = alpha * Op::apply(*sp) + *dp;
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

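/// Inner kernel for `fma`: `*dp = *dp + OpA::apply(*ap) * OpB::apply(*bp)`.
///
/// # Safety
/// All three pointers must be valid for `len` elements at their strides, and
/// the destination must not overlap either input.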
#[inline(always)]
unsafe fn inner_loop_fma<
    D: Copy + Add<D, Output = D>,
    A: Copy + Mul<B, Output = D>,
    B: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
) {
    if ds == 1 && a_s == 1 && b_s == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let sa = std::slice::from_raw_parts(ap, len);
        let sb = std::slice::from_raw_parts(bp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            *dp = *dp + OpA::apply(*ap) * OpB::apply(*bp);
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
}

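/// Inner kernel for `dot`: folds `OpA::apply(*ap) * OpB::apply(*bp)` into `acc`.
///
/// # Safety
/// Both pointers must be valid for `len` reads at their strides.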
#[inline(always)]
unsafe fn inner_loop_dot<
    A: Copy + Mul<B, Output = R>,
    B: Copy,
    R: Copy + Add<R, Output = R>,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
    mut acc: R,
) -> R {
    if a_s == 1 && b_s == 1 {
        let sa = std::slice::from_raw_parts(ap, len);
        let sb = std::slice::from_raw_parts(bp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
    } else {
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            acc = acc + OpA::apply(*ap) * OpB::apply(*bp);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
    acc
}

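/// Copies `src` into `dest` element-wise, applying `Op` to each source element.
///
/// When both views share a sequential contiguous layout, this lowers to
/// `ptr::copy_nonoverlapping` (identity `Op`) or a flat SIMD-dispatched loop;
/// otherwise it falls back to the generic `map_into` traversal.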
pub fn copy_into<T: Copy + MaybeSendSync, Op: ElementOp<T>>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        if Op::IS_IDENTITY {
            debug_assert!(
                {
                    let nbytes = len
                        .checked_mul(std::mem::size_of::<T>())
                        .expect("copy size must not overflow");
                    let dst_start = dst_ptr as usize;
                    let src_start = src_ptr as usize;
                    let dst_end = dst_start.saturating_add(nbytes);
                    let src_end = src_start.saturating_add(nbytes);
                    dst_end <= src_start || src_end <= dst_start
                },
                "overlapping src/dest is not supported"
            );
            unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
        } else {
            let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
            let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
            simd::dispatch_if_large(len, || {
                for i in 0..len {
                    dst[i] = Op::apply(src[i]);
                }
            });
        }
        return Ok(());
    }

    map_into(dest, src, |x| x)
}

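/// Element-wise accumulation: `dest = dest + Op::apply(src)`.
///
/// Contiguous inputs take a flat SIMD-dispatched loop; strided inputs go
/// through the fused block plan, threaded with rayon when the `parallel`
/// feature is enabled and the problem exceeds `MINTHREADLENGTH`.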
pub fn add<
    D: Copy + Add<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_add::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_add::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}

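/// Element-wise scaling: `dest = dest * Op::apply(src)`.
///
/// Uses the same contiguous fast path / fused block plan / optional rayon
/// threading strategy as [`add`].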
pub fn mul<
    D: Copy + Mul<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] * Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_mul::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_mul::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}

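/// BLAS-style axpy: `dest = alpha * Op::apply(src) + dest`.
///
/// Follows the same dispatch strategy as [`add`], passing `alpha` down to the
/// inner kernel.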
pub fn axpy<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    alpha: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = alpha * Op::apply(src[i]) + dst[i];
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_axpy::<D, S, A, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                    alpha,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_axpy::<D, S, A, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    alpha,
                )
            };
            Ok(())
        },
    )
}

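/// Fused multiply-add of two sources: `dest = dest + OpA::apply(a) * OpB::apply(b)`.
///
/// All three views must have the same shape; dispatch mirrors [`add`] but
/// walks three stride sets instead of two.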
pub fn fma<D, A, B, OpA, OpB>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
) -> Result<()>
where
    A: Copy + Mul<B, Output = D> + MaybeSendSync,
    B: Copy + MaybeSendSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let a_strides = a.strides();
    let b_strides = b.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_fma::<D, A, B, OpA, OpB>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    a_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    b_send.as_const().offset(offsets[2]),
                                    strides[2],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_fma::<D, A, B, OpA, OpB>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    a_ptr.offset(offsets[1]),
                    strides[1],
                    b_ptr.offset(offsets[2]),
                    strides[2],
                    len,
                )
            };
            Ok(())
        },
    )
}

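/// Sums a contiguous slice with rayon, computing one SIMD partial sum per chunk.
///
/// Returns `None` when the `T::try_simd_sum` probe reports no SIMD support,
/// letting the caller fall back to the scalar reduction.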
#[cfg(feature = "parallel")]
fn parallel_simd_sum<T: Copy + Zero + Add<Output = T> + simd::MaybeSimdOps + Send + Sync>(
    src: &[T],
) -> Option<T> {
    use rayon::prelude::*;
    if T::try_simd_sum(&[]).is_none() {
        return None;
    }
    let nthreads = rayon::current_num_threads();
    let chunk_size = (src.len() + nthreads - 1) / nthreads;
    let result = src
        .par_chunks(chunk_size)
        .map(|chunk| T::try_simd_sum(chunk).unwrap())
        .reduce(|| T::zero(), |a, b| a + b);
    Some(result)
}

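/// Sums all elements of `src` after applying `Op`.
///
/// Identity-`Op`, contiguous views use the SIMD sum (threaded via
/// `parallel_simd_sum` for large inputs); everything else goes through the
/// generic `reduce`.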
pub fn sum<
    T: Copy + Zero + Add<Output = T> + MaybeSendSync + simd::MaybeSimdOps,
    Op: ElementOp<T>,
>(
    src: &StridedView<T, Op>,
) -> Result<T> {
    if Op::IS_IDENTITY {
        if same_contiguous_layout(src.dims(), &[src.strides()]).is_some() {
            let len = total_len(src.dims());
            let src_slice = unsafe { std::slice::from_raw_parts(src.ptr(), len) };

            #[cfg(feature = "parallel")]
            if len > MINTHREADLENGTH {
                if let Some(result) = parallel_simd_sum(src_slice) {
                    return Ok(result);
                }
            }

            if let Some(result) = T::try_simd_sum(src_slice) {
                return Ok(result);
            }
        }
    }
    reduce(src, |x| x, |a, b| a + b, T::zero())
}

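/// Dot product of two views with the same shape:
/// `sum(OpA::apply(a[i]) * OpB::apply(b[i]))`.
///
/// Contiguous identity-op inputs whose element types match `R` use
/// `try_simd_dot`; otherwise the fused block plan accumulates via
/// `inner_loop_dot`.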
pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R>
where
    A: Copy + Mul<B, Output = R> + MaybeSendSync + 'static,
    B: Copy + MaybeSendSync + 'static,
    R: Copy + Zero + Add<Output = R> + MaybeSendSync + simd::MaybeSimdOps + 'static,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    ensure_same_shape(a.dims(), b.dims())?;

    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let a_strides = a.strides();
    let b_strides = b.strides();
    let a_dims = a.dims();

    if same_contiguous_layout(a_dims, &[a_strides, b_strides]).is_some() {
        let len = total_len(a_dims);

        if OpA::IS_IDENTITY
            && OpB::IS_IDENTITY
            && std::any::TypeId::of::<A>() == std::any::TypeId::of::<R>()
            && std::any::TypeId::of::<B>() == std::any::TypeId::of::<R>()
        {
            let sa = unsafe { std::slice::from_raw_parts(a_ptr as *const R, len) };
            let sb = unsafe { std::slice::from_raw_parts(b_ptr as *const R, len) };
            if let Some(result) = R::try_simd_dot(sa, sb) {
                return Ok(result);
            }
        }

        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let mut acc = R::zero();
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(acc);
    }

    let strides_list: [&[isize]; 2] = [a_strides, b_strides];
    let elem_size = std::mem::size_of::<A>()
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<R>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(a_dims, &strides_list, None, elem_size);

    let mut acc = R::zero();
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            acc = unsafe {
                inner_loop_dot::<A, B, R, OpA, OpB>(
                    a_ptr.offset(offsets[0]),
                    strides[0],
                    b_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    acc,
                )
            };
            Ok(())
        },
    )?;

    Ok(acc)
}

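/// Writes the symmetric part of a square matrix into `dest`:
/// `dest = (src + src^T) / 2`.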
pub fn symmetrize_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
where
    T: Copy
        + Add<Output = T>
        + Mul<Output = T>
        + num_traits::FromPrimitive
        + std::ops::Div<Output = T>
        + MaybeSendSync,
{
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    let rows = src.dims()[0];
    let cols = src.dims()[1];
    if rows != cols {
        return Err(StridedError::NonSquare { rows, cols });
    }

    let src_t = src.permute(&[1, 0])?;
    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;

    zip_map2_into(dest, src, &src_t, |a, b| (a + b) * half)
}

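/// Writes the Hermitian part of a square matrix into `dest`:
/// `dest = (src + src^H) / 2`, where `^H` is the conjugate transpose.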
pub fn symmetrize_conj_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
where
    T: Copy
        + ElementOpApply
        + Add<Output = T>
        + Mul<Output = T>
        + num_traits::FromPrimitive
        + std::ops::Div<Output = T>
        + MaybeSendSync,
{
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    let rows = src.dims()[0];
    let cols = src.dims()[1];
    if rows != cols {
        return Err(StridedError::NonSquare { rows, cols });
    }

    let src_adj = src.adjoint_2d()?;
    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;

    zip_map2_into(dest, src, &src_adj, |a, b| (a + b) * half)
}

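/// Scaled copy: `dest = scale * Op::apply(src)`.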
pub fn copy_scale<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    scale: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    map_into(dest, src, |x| scale * x)
}

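/// Conjugating copy: `dest = conj(src)`.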
pub fn copy_conj<T: Copy + ElementOpApply + MaybeSendSync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let src_conj = src.conj();
    copy_into(dest, &src_conj)
}

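/// Scaled 2-D transpose copy: `dest = scale * src^T`.
///
/// Both views must be rank 2, and `dest` must have the transposed shape of `src`.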
pub fn copy_transpose_scale_into<T>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
    scale: T,
) -> Result<()>
where
    T: Copy + ElementOpApply + Mul<Output = T> + MaybeSendSync,
{
    if src.ndim() != 2 || dest.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    let src_t = src.transpose_2d()?;
    map_into(dest, &src_t, |x| scale * x)
}
889}