use crate::kernel::{
    build_plan_fused, ensure_same_shape, for_each_inner_block_preordered, same_contiguous_layout,
    sequential_contiguous_layout, total_len,
};
use crate::map_view::{map_into, zip_map2_into};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::reduce_view::reduce;
use crate::simd;
use crate::view::{StridedView, StridedViewMut};
use crate::{Result, StridedError};
use num_traits::Zero;
use std::ops::{Add, Mul};
use strided_view::{ElementOp, ElementOpApply};

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{
    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
};
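
/// Strided inner loop for `add`: computes `*dp = *dp + Op::apply(*sp)` over
/// `len` elements; unit strides take the SIMD-dispatched slice path.
///
/// # Safety
/// `dp` and `sp` must be valid for `len` strided reads/writes and must not
/// overlap.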
#[inline(always)]
unsafe fn inner_loop_add<D: Copy + Add<S, Output = D>, S: Copy, Op: ElementOp<S>>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + Op::apply(src[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = *dp + Op::apply(*sp);
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}
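
/// Strided inner loop for `mul`: computes `*dp = *dp * Op::apply(*sp)` over
/// `len` elements; unit strides take the SIMD-dispatched slice path.
///
/// # Safety
/// Same requirements as [`inner_loop_add`].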
#[inline(always)]
unsafe fn inner_loop_mul<D: Copy + Mul<S, Output = D>, S: Copy, Op: ElementOp<S>>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] * Op::apply(src[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = *dp * Op::apply(*sp);
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}
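
/// Strided inner loop for `axpy`: computes
/// `*dp = alpha * Op::apply(*sp) + *dp` over `len` elements.
///
/// # Safety
/// Same requirements as [`inner_loop_add`].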
#[inline(always)]
unsafe fn inner_loop_axpy<
    D: Copy + Add<D, Output = D>,
    S: Copy,
    A: Copy + Mul<S, Output = D>,
    Op: ElementOp<S>,
>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
    alpha: A,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = alpha * Op::apply(src[i]) + dst[i];
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = alpha * Op::apply(*sp) + *dp;
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}
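
/// Strided inner loop for `fma`: computes
/// `*dp = *dp + OpA::apply(*ap) * OpB::apply(*bp)` over `len` elements.
///
/// # Safety
/// All three pointers must be valid for `len` strided accesses, and the
/// destination must not overlap either input.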
#[inline(always)]
unsafe fn inner_loop_fma<
    D: Copy + Add<D, Output = D>,
    A: Copy + Mul<B, Output = D>,
    B: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
) {
    if ds == 1 && a_s == 1 && b_s == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let sa = std::slice::from_raw_parts(ap, len);
        let sb = std::slice::from_raw_parts(bp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            *dp = *dp + OpA::apply(*ap) * OpB::apply(*bp);
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
}
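
/// Strided inner loop for `dot`: folds `OpA::apply(*ap) * OpB::apply(*bp)`
/// into `acc` over `len` elements and returns the updated accumulator.
///
/// # Safety
/// Both pointers must be valid for `len` strided reads.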
#[inline(always)]
unsafe fn inner_loop_dot<
    A: Copy + Mul<B, Output = R>,
    B: Copy,
    R: Copy + Add<R, Output = R>,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
    mut acc: R,
) -> R {
    if a_s == 1 && b_s == 1 {
        let sa = std::slice::from_raw_parts(ap, len);
        let sb = std::slice::from_raw_parts(bp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
    } else {
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            acc = acc + OpA::apply(*ap) * OpB::apply(*bp);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
    acc
}
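
/// Copies `src` into `dest`, applying the element op `Op`. Matching
/// contiguous layouts use `ptr::copy_nonoverlapping` (identity op) or a
/// SIMD-dispatched slice loop; all other layouts fall back to [`map_into`].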
pub fn copy_into<T: Copy + MaybeSendSync, Op: ElementOp<T>>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        if Op::IS_IDENTITY {
            // Identity op on matching contiguous layouts: a raw memcpy,
            // guarded (in debug builds) against overlapping ranges.
            debug_assert!(
                {
                    let nbytes = len
                        .checked_mul(std::mem::size_of::<T>())
                        .expect("copy size must not overflow");
                    let dst_start = dst_ptr as usize;
                    let src_start = src_ptr as usize;
                    let dst_end = dst_start.saturating_add(nbytes);
                    let src_end = src_start.saturating_add(nbytes);
                    dst_end <= src_start || src_end <= dst_start
                },
                "overlapping src/dest is not supported"
            );
            unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
        } else {
            let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
            let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
            simd::dispatch_if_large(len, || {
                for i in 0..len {
                    dst[i] = Op::apply(src[i]);
                }
            });
        }
        return Ok(());
    }

    map_into(dest, src, |x| x)
}
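
/// Copies `src` into `dst` traversing in column-major order, delegating to
/// `strided_perm`.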
pub fn copy_into_col_major<T: Copy + MaybeSendSync>(
    dst: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    strided_perm::copy_into_col_major(dst, src)
}
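
/// Element-wise `dest[i] = dest[i] + Op::apply(src[i])`, with a SIMD fast
/// path for matching contiguous layouts and, under the `parallel` feature, a
/// threaded blocked traversal for large inputs.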
pub fn add<
    D: Copy + Add<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_add::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_add::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}
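
/// Element-wise `dest[i] = dest[i] * Op::apply(src[i])`, with the same fast
/// paths as [`add`].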
pub fn mul<
    D: Copy + Mul<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] * Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_mul::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_mul::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}
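
/// BLAS-style AXPY update: `dest[i] = alpha * Op::apply(src[i]) + dest[i]`,
/// with the same fast paths as [`add`].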
pub fn axpy<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    alpha: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = alpha * Op::apply(src[i]) + dst[i];
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_axpy::<D, S, A, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                    alpha,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_axpy::<D, S, A, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    alpha,
                )
            };
            Ok(())
        },
    )
}
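
/// Fused multiply-add across three same-shape views:
/// `dest[i] = dest[i] + OpA::apply(a[i]) * OpB::apply(b[i])`.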
pub fn fma<D, A, B, OpA, OpB>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
) -> Result<()>
where
    A: Copy + Mul<B, Output = D> + MaybeSendSync,
    B: Copy + MaybeSendSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let a_strides = a.strides();
    let b_strides = b.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_fma::<D, A, B, OpA, OpB>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    a_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    b_send.as_const().offset(offsets[2]),
                                    strides[2],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_fma::<D, A, B, OpA, OpB>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    a_ptr.offset(offsets[1]),
                    strides[1],
                    b_ptr.offset(offsets[2]),
                    strides[2],
                    len,
                )
            };
            Ok(())
        },
    )
}
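
/// Sums a contiguous slice with one SIMD-summed chunk per rayon thread;
/// returns `None` when `T` has no SIMD sum implementation.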
#[cfg(feature = "parallel")]
fn parallel_simd_sum<T: Copy + Zero + Add<Output = T> + simd::MaybeSimdOps + Send + Sync>(
    src: &[T],
) -> Option<T> {
    use rayon::prelude::*;
    // Probe with an empty slice: if `T` has no SIMD sum, bail out early so
    // the caller can fall back to the generic reduction.
    if T::try_simd_sum(&[]).is_none() {
        return None;
    }
    let nthreads = rayon::current_num_threads();
    // Ceiling division, clamped to 1 because `par_chunks` panics on a chunk
    // size of zero (which would otherwise occur for an empty `src`).
    let chunk_size = ((src.len() + nthreads - 1) / nthreads).max(1);
    let result = src
        .par_chunks(chunk_size)
        .map(|chunk| T::try_simd_sum(chunk).unwrap())
        .reduce(|| T::zero(), |a, b| a + b);
    Some(result)
}
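
/// Sums all elements of `src` after applying `Op`. Identity ops over
/// contiguous layouts use (optionally parallel) SIMD summation; everything
/// else goes through the generic [`reduce`].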
pub fn sum<
    T: Copy + Zero + Add<Output = T> + MaybeSendSync + simd::MaybeSimdOps,
    Op: ElementOp<T>,
>(
    src: &StridedView<T, Op>,
) -> Result<T> {
    if Op::IS_IDENTITY {
        // Fast path: an identity op over a contiguous layout can be summed
        // as a flat slice, in parallel if the input is large enough.
        if same_contiguous_layout(src.dims(), &[src.strides()]).is_some() {
            let len = total_len(src.dims());
            let src_slice = unsafe { std::slice::from_raw_parts(src.ptr(), len) };

            #[cfg(feature = "parallel")]
            if len > MINTHREADLENGTH {
                if let Some(result) = parallel_simd_sum(src_slice) {
                    return Ok(result);
                }
            }

            if let Some(result) = T::try_simd_sum(src_slice) {
                return Ok(result);
            }
        }
    }
    reduce(src, |x| x, |a, b| a + b, T::zero())
}
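
/// Dot product `sum_i OpA::apply(a[i]) * OpB::apply(b[i])` of two same-shape
/// views, with a SIMD fast path when both ops are the identity and both
/// element types equal the result type `R`.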
pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R>
where
    A: Copy + Mul<B, Output = R> + MaybeSendSync + 'static,
    B: Copy + MaybeSendSync + 'static,
    R: Copy + Zero + Add<Output = R> + MaybeSendSync + simd::MaybeSimdOps + 'static,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    ensure_same_shape(a.dims(), b.dims())?;

    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let a_strides = a.strides();
    let b_strides = b.strides();
    let a_dims = a.dims();

    if same_contiguous_layout(a_dims, &[a_strides, b_strides]).is_some() {
        let len = total_len(a_dims);

        // With identity ops and both element types equal to `R`, the inputs
        // can be reinterpreted as `&[R]` and handed to the SIMD dot product.
        if OpA::IS_IDENTITY
            && OpB::IS_IDENTITY
            && std::any::TypeId::of::<A>() == std::any::TypeId::of::<R>()
            && std::any::TypeId::of::<B>() == std::any::TypeId::of::<R>()
        {
            let sa = unsafe { std::slice::from_raw_parts(a_ptr as *const R, len) };
            let sb = unsafe { std::slice::from_raw_parts(b_ptr as *const R, len) };
            if let Some(result) = R::try_simd_dot(sa, sb) {
                return Ok(result);
            }
        }

        // Contiguous but mixed-type or non-identity case: scalar loop.
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let mut acc = R::zero();
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(acc);
    }

    let strides_list: [&[isize]; 2] = [a_strides, b_strides];
    let elem_size = std::mem::size_of::<A>()
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<R>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(a_dims, &strides_list, None, elem_size);

    let mut acc = R::zero();
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            acc = unsafe {
                inner_loop_dot::<A, B, R, OpA, OpB>(
                    a_ptr.offset(offsets[0]),
                    strides[0],
                    b_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    acc,
                )
            };
            Ok(())
        },
    )?;

    Ok(acc)
}
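
/// Writes the symmetric part `(src + src^T) * 0.5` of a square 2-D view into
/// `dest`.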
pub fn symmetrize_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
where
    T: Copy
        + Add<Output = T>
        + Mul<Output = T>
        + num_traits::FromPrimitive
        + std::ops::Div<Output = T>
        + MaybeSendSync,
{
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    let rows = src.dims()[0];
    let cols = src.dims()[1];
    if rows != cols {
        return Err(StridedError::NonSquare { rows, cols });
    }

    let src_t = src.permute(&[1, 0])?;
    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;

    zip_map2_into(dest, src, &src_t, |a, b| (a + b) * half)
}
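
/// Writes the Hermitian part `(src + src^H) * 0.5` of a square 2-D view into
/// `dest`, using the conjugate transpose of `src`.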
pub fn symmetrize_conj_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
where
    T: Copy
        + ElementOpApply
        + Add<Output = T>
        + Mul<Output = T>
        + num_traits::FromPrimitive
        + std::ops::Div<Output = T>
        + MaybeSendSync,
{
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    let rows = src.dims()[0];
    let cols = src.dims()[1];
    if rows != cols {
        return Err(StridedError::NonSquare { rows, cols });
    }

    let src_adj = src.adjoint_2d()?;
    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;

    zip_map2_into(dest, src, &src_adj, |a, b| (a + b) * half)
}
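
/// Element-wise scaled copy: `dest[i] = scale * Op::apply(src[i])`.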
pub fn copy_scale<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    scale: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    map_into(dest, src, |x| scale * x)
}
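
/// Copies the element-wise complex conjugate of `src` into `dest`.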
pub fn copy_conj<T: Copy + ElementOpApply + MaybeSendSync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let src_conj = src.conj();
    copy_into(dest, &src_conj)
}
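
/// Copies `scale * src^T` into `dest` for 2-D views.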
pub fn copy_transpose_scale_into<T>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
    scale: T,
) -> Result<()>
where
    T: Copy + ElementOpApply + Mul<Output = T> + MaybeSendSync,
{
    // Report the rank of whichever view actually fails the 2-D requirement.
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    if dest.ndim() != 2 {
        return Err(StridedError::RankMismatch(dest.ndim(), 2));
    }
    let src_t = src.transpose_2d()?;
    map_into(dest, &src_t, |x| scale * x)
}