use crate::kernel::{
    build_plan_fused, ensure_same_shape, for_each_inner_block_preordered,
    sequential_contiguous_layout, total_len,
};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::simd;
use crate::view::{StridedView, StridedViewMut};
use crate::Result;
use strided_view::ElementOp;

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{for_each_inner_block_with_offsets, mapreduce_threaded, MINTHREADLENGTH};

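// Innermost unary kernel. With unit strides on both destination and source, the
// block is reinterpreted as slices and the loop is handed to
// `simd::dispatch_if_large` (which, by its name, only takes a vectorised path
// for sufficiently large blocks); otherwise it falls back to a plain pointer
// walk using the raw strides. The slice fast path assumes the destination block
// does not overlap the source.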
#[inline(always)]
unsafe fn inner_loop_map1<D: Copy, A: Copy, Op: ElementOp<A>>(
    dp: *mut D,
    ds: isize,
    sp: *const A,
    ss: isize,
    len: usize,
    f: &impl Fn(A) -> D,
) {
    if ds == 1 && ss == 1 {
        let src = std::slice::from_raw_parts(sp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for (d, s) in dst.iter_mut().zip(src.iter()) {
                *d = f(Op::apply(*s));
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = f(Op::apply(*sp));
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

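// Binary variant of the inner kernel: same unit-stride slice fast path,
// otherwise a strided pointer walk over both inputs.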
#[inline(always)]
unsafe fn inner_loop_map2<D: Copy, A: Copy, B: Copy, OpA: ElementOp<A>, OpB: ElementOp<B>>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
    f: &impl Fn(A, B) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(src_a[i]), OpB::apply(src_b[i]));
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            *dp = f(OpA::apply(*ap), OpB::apply(*bp));
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
}

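// Ternary variant of the inner kernel; dispatch mirrors `inner_loop_map1`.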
#[inline(always)]
unsafe fn inner_loop_map3<
    D: Copy,
    A: Copy,
    B: Copy,
    C: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    cp: *const C,
    c_s: isize,
    len: usize,
    f: &impl Fn(A, B, C) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let src_c = std::slice::from_raw_parts(cp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(src_a[i]),
                    OpB::apply(src_b[i]),
                    OpC::apply(src_c[i]),
                );
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        let mut cp = cp;
        for _ in 0..len {
            *dp = f(OpA::apply(*ap), OpB::apply(*bp), OpC::apply(*cp));
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
            cp = cp.offset(c_s);
        }
    }
}

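// Four-input variant of the inner kernel; the fourth source type is named `E`
// because `D` is already the destination type.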
#[inline(always)]
unsafe fn inner_loop_map4<
    D: Copy,
    A: Copy,
    B: Copy,
    C: Copy,
    E: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
    OpE: ElementOp<E>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    cp: *const C,
    c_s: isize,
    ep: *const E,
    e_s: isize,
    len: usize,
    f: &impl Fn(A, B, C, E) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 && e_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let src_c = std::slice::from_raw_parts(cp, len);
        let src_e = std::slice::from_raw_parts(ep, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(src_a[i]),
                    OpB::apply(src_b[i]),
                    OpC::apply(src_c[i]),
                    OpE::apply(src_e[i]),
                );
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        let mut cp = cp;
        let mut ep = ep;
        for _ in 0..len {
            *dp = f(
                OpA::apply(*ap),
                OpB::apply(*bp),
                OpC::apply(*cp),
                OpE::apply(*ep),
            );
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
            cp = cp.offset(c_s);
            ep = ep.offset(e_s);
        }
    }
}

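/// Writes `f(Op::apply(x))` for each element `x` of `src` into the matching
/// position of `dest`, e.g. `map_into(&mut out, &x, |v| v * 2.0)?`. Both views
/// must have the same shape.
///
/// Dispatch: a fully contiguous layout is handled by a single flat loop under
/// `simd::dispatch_if_large`; otherwise the dims are fused into a blocked plan
/// that runs through `mapreduce_threaded` (with the `parallel` feature, when the
/// total length exceeds `MINTHREADLENGTH`) or serially through
/// `for_each_inner_block_preordered`.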
pub fn map_into<D: Copy + MaybeSendSync, A: Copy + MaybeSendSync, Op: ElementOp<A>>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<A, Op>,
    f: impl Fn(A) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(Op::apply(src[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<A>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut A);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let sp = unsafe { src_send.as_const().offset(offsets[1]) };
                            unsafe {
                                inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f)
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let sp = unsafe { src_ptr.offset(offsets[1]) };
            unsafe { inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f) };
            Ok(())
        },
    )
}

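/// Element-wise `dest = f(a, b)`, with each source passed through its
/// `ElementOp` first. Both sources must match the destination's shape; the
/// contiguous / threaded / serial dispatch mirrors [`map_into`].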
pub fn zip_map2_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    f: impl Fn(A, B) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();

    let a_strides = a.strides();
    let b_strides = b.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            unsafe {
                                inner_loop_map2::<D, A, B, OpA, OpB>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            unsafe {
                inner_loop_map2::<D, A, B, OpA, OpB>(
                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
                )
            };
            Ok(())
        },
    )
}

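/// Element-wise `dest = f(a, b, c)` over three same-shaped sources, with the
/// same dispatch strategy as [`map_into`].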
pub fn zip_map3_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    C: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    c: &StridedView<C, OpC>,
    f: impl Fn(A, B, C) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;
    ensure_same_shape(dest.dims(), c.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let c_ptr = c.ptr();

    let dst_dims = dest.dims();
    let dst_strides = dest.strides();

    if sequential_contiguous_layout(
        dst_dims,
        &[dst_strides, a.strides(), b.strides(), c.strides()],
    )
    .is_some()
    {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]), OpC::apply(sc[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 4] = [dst_strides, a.strides(), b.strides(), c.strides()];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<C>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);
            let c_send = SendPtr(c_ptr as *mut C);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
                            unsafe {
                                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
                                    len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            let cp = unsafe { c_ptr.offset(offsets[3]) };
            unsafe {
                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], len, &f,
                )
            };
            Ok(())
        },
    )
}

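/// Element-wise `dest = f(a, b, c, e)` over four same-shaped sources, with the
/// same dispatch strategy as [`map_into`].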
pub fn zip_map4_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    C: Copy + MaybeSendSync,
    E: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
    OpE: ElementOp<E>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    c: &StridedView<C, OpC>,
    e: &StridedView<E, OpE>,
    f: impl Fn(A, B, C, E) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;
    ensure_same_shape(dest.dims(), c.dims())?;
    ensure_same_shape(dest.dims(), e.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let c_ptr = c.ptr();
    let e_ptr = e.ptr();

    let dst_dims = dest.dims();
    let dst_strides = dest.strides();

    if sequential_contiguous_layout(
        dst_dims,
        &[
            dst_strides,
            a.strides(),
            b.strides(),
            c.strides(),
            e.strides(),
        ],
    )
    .is_some()
    {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
        let se = unsafe { std::slice::from_raw_parts(e_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(sa[i]),
                    OpB::apply(sb[i]),
                    OpC::apply(sc[i]),
                    OpE::apply(se[i]),
                );
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 5] = [
        dst_strides,
        a.strides(),
        b.strides(),
        c.strides(),
        e.strides(),
    ];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<C>())
        .max(std::mem::size_of::<E>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);
            let c_send = SendPtr(c_ptr as *mut C);
            let e_send = SendPtr(e_ptr as *mut E);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
                            let ep = unsafe { e_send.as_const().offset(offsets[4]) };
                            unsafe {
                                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
                                    ep, strides[4], len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            let cp = unsafe { c_ptr.offset(offsets[3]) };
            let ep = unsafe { e_ptr.offset(offsets[4]) };
            unsafe {
                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], ep, strides[4],
                    len, &f,
                )
            };
            Ok(())
        },
    )
}