use crate::kernel::{
    build_plan_fused, build_plan_fused_small, ensure_same_shape, for_each_inner_block_preordered,
    sequential_contiguous_layout, total_len, SMALL_TENSOR_THRESHOLD,
};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::simd;
use crate::view::{StridedView, StridedViewMut};
use crate::Result;
use strided_view::ElementOp;

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{for_each_inner_block_with_offsets, mapreduce_threaded, MINTHREADLENGTH};

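/// Innermost unary kernel: for `i in 0..len`, reads `*sp.offset(i * ss)`,
/// applies the view's `ElementOp` and then `f`, and stores the result at
/// `*dp.offset(i * ds)`. Unit-stride runs take a slice-based path that is
/// friendlier to vectorization.
///
/// # Safety
/// `dp` and `sp` must be valid for `len` strided accesses, and the
/// destination run must not overlap the source.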
#[inline(always)]
unsafe fn inner_loop_map1<D: Copy, A: Copy, Op: ElementOp<A>>(
    dp: *mut D,
    ds: isize,
    sp: *const A,
    ss: isize,
    len: usize,
    f: &impl Fn(A) -> D,
) {
    if ds == 1 && ss == 1 {
        // Unit strides on both sides: operate on slices so the loop can
        // vectorize; `simd::dispatch_if_large` only engages SIMD dispatch
        // when `len` is large enough to pay off.
        let src = std::slice::from_raw_parts(sp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for (d, s) in dst.iter_mut().zip(src.iter()) {
                *d = f(Op::apply(*s));
            }
        });
    } else {
        // General strides: advance raw pointers element by element.
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = f(Op::apply(*sp));
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

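/// Innermost binary kernel: the two-source analogue of [`inner_loop_map1`],
/// combining strided runs of `a` and `b` with `f`.
///
/// # Safety
/// Same contract as [`inner_loop_map1`], extended to both source pointers.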
#[inline(always)]
unsafe fn inner_loop_map2<D: Copy, A: Copy, B: Copy, OpA: ElementOp<A>, OpB: ElementOp<B>>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
    f: &impl Fn(A, B) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(src_a[i]), OpB::apply(src_b[i]));
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            *dp = f(OpA::apply(*ap), OpB::apply(*bp));
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
}

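/// Innermost ternary kernel: the three-source analogue of
/// [`inner_loop_map1`].
///
/// # Safety
/// Same contract as [`inner_loop_map1`], extended to all source pointers.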
#[inline(always)]
unsafe fn inner_loop_map3<
    D: Copy,
    A: Copy,
    B: Copy,
    C: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    cp: *const C,
    c_s: isize,
    len: usize,
    f: &impl Fn(A, B, C) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let src_c = std::slice::from_raw_parts(cp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(src_a[i]),
                    OpB::apply(src_b[i]),
                    OpC::apply(src_c[i]),
                );
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        let mut cp = cp;
        for _ in 0..len {
            *dp = f(OpA::apply(*ap), OpB::apply(*bp), OpC::apply(*cp));
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
            cp = cp.offset(c_s);
        }
    }
}

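/// Innermost quaternary kernel: the four-source analogue of
/// [`inner_loop_map1`].
///
/// # Safety
/// Same contract as [`inner_loop_map1`], extended to all source pointers.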
#[inline(always)]
unsafe fn inner_loop_map4<
    D: Copy,
    A: Copy,
    B: Copy,
    C: Copy,
    E: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
    OpE: ElementOp<E>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    cp: *const C,
    c_s: isize,
    ep: *const E,
    e_s: isize,
    len: usize,
    f: &impl Fn(A, B, C, E) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 && e_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let src_c = std::slice::from_raw_parts(cp, len);
        let src_e = std::slice::from_raw_parts(ep, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(src_a[i]),
                    OpB::apply(src_b[i]),
                    OpC::apply(src_c[i]),
                    OpE::apply(src_e[i]),
                );
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        let mut cp = cp;
        let mut ep = ep;
        for _ in 0..len {
            *dp = f(
                OpA::apply(*ap),
                OpB::apply(*bp),
                OpC::apply(*cp),
                OpE::apply(*ep),
            );
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
            cp = cp.offset(c_s);
            ep = ep.offset(e_s);
        }
    }
}

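/// Applies `f` (after the source view's `ElementOp`) to every element of
/// `src` and writes the results into `dest`, which must have the same shape.
///
/// Dispatch strategy, in order: a flat slice loop when all operands share a
/// sequential contiguous layout; otherwise a fused blocked traversal, planned
/// cheaply for tensors at or below `SMALL_TENSOR_THRESHOLD` elements; with
/// the `parallel` feature, work exceeding `MINTHREADLENGTH` fused elements is
/// split across the Rayon pool.
///
/// A minimal call-site sketch (assumes `dest: StridedViewMut<f64>` and a
/// matching-shape `src` were constructed elsewhere; constructors are out of
/// scope here, hence `ignore`):
///
/// ```ignore
/// map_into(&mut dest, &src, |x| x * 2.0)?;
/// ```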
pub fn map_into<D: Copy + MaybeSendSync, A: Copy + MaybeSendSync, Op: ElementOp<A>>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<A, Op>,
    f: impl Fn(A) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    // Fast path: all operands are laid out contiguously in the same order,
    // so a single flat loop over `total_len` elements suffices.
    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(Op::apply(src[i]));
            }
        });
        return Ok(());
    }

    // General path: fuse compatible dimensions and pick a blocked traversal
    // order; small tensors use the cheaper planner.
    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<A>());
    let total = total_len(dst_dims);

    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        // Split the fused iteration space across Rayon threads when the work
        // is large enough to amortize the threading overhead.
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut A);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let sp = unsafe { src_send.as_const().offset(offsets[1]) };
                            unsafe {
                                inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f)
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    // Serial fallback: walk the planned blocks in order on the current thread.
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let sp = unsafe { src_ptr.offset(offsets[1]) };
            unsafe { inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f) };
            Ok(())
        },
    )
}

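/// Binary element-wise map: `dest[i] = f(a[i], b[i])` over equally shaped
/// views, with the same fast-path / fused-plan / optional-parallel dispatch
/// as [`map_into`]. A hedged sketch (hypothetical pre-built views of one
/// shape):
///
/// ```ignore
/// zip_map2_into(&mut dest, &a, &b, |x, y| x + y)?;
/// ```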
pub fn zip_map2_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    f: impl Fn(A, B) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();

    let a_strides = a.strides();
    let b_strides = b.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>());
    let total = total_len(dst_dims);

    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            unsafe {
                                inner_loop_map2::<D, A, B, OpA, OpB>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            unsafe {
                inner_loop_map2::<D, A, B, OpA, OpB>(
                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
                )
            };
            Ok(())
        },
    )
}

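/// Ternary element-wise map: `dest[i] = f(a[i], b[i], c[i])`; shape checks
/// and dispatch mirror [`zip_map2_into`].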
pub fn zip_map3_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    C: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    c: &StridedView<C, OpC>,
    f: impl Fn(A, B, C) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;
    ensure_same_shape(dest.dims(), c.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let c_ptr = c.ptr();

    let dst_dims = dest.dims();
    let dst_strides = dest.strides();

    if sequential_contiguous_layout(
        dst_dims,
        &[dst_strides, a.strides(), b.strides(), c.strides()],
    )
    .is_some()
    {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]), OpC::apply(sc[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 4] = [dst_strides, a.strides(), b.strides(), c.strides()];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<C>());
    let total = total_len(dst_dims);

    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);
            let c_send = SendPtr(c_ptr as *mut C);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
                            unsafe {
                                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
                                    len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            let cp = unsafe { c_ptr.offset(offsets[3]) };
            unsafe {
                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], len, &f,
                )
            };
            Ok(())
        },
    )
}

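/// Quaternary element-wise map: `dest[i] = f(a[i], b[i], c[i], e[i])`; shape
/// checks and dispatch mirror [`zip_map2_into`]. The fourth operand is named
/// `E` because `D` is already taken by the destination element type.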
pub fn zip_map4_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    C: Copy + MaybeSendSync,
    E: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
    OpE: ElementOp<E>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    c: &StridedView<C, OpC>,
    e: &StridedView<E, OpE>,
    f: impl Fn(A, B, C, E) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;
    ensure_same_shape(dest.dims(), c.dims())?;
    ensure_same_shape(dest.dims(), e.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let c_ptr = c.ptr();
    let e_ptr = e.ptr();

    let dst_dims = dest.dims();
    let dst_strides = dest.strides();

    if sequential_contiguous_layout(
        dst_dims,
        &[
            dst_strides,
            a.strides(),
            b.strides(),
            c.strides(),
            e.strides(),
        ],
    )
    .is_some()
    {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
        let se = unsafe { std::slice::from_raw_parts(e_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(sa[i]),
                    OpB::apply(sb[i]),
                    OpC::apply(sc[i]),
                    OpE::apply(se[i]),
                );
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 5] = [
        dst_strides,
        a.strides(),
        b.strides(),
        c.strides(),
        e.strides(),
    ];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<C>())
        .max(std::mem::size_of::<E>());
    let total = total_len(dst_dims);

    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);
            let c_send = SendPtr(c_ptr as *mut C);
            let e_send = SendPtr(e_ptr as *mut E);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
                            let ep = unsafe { e_send.as_const().offset(offsets[4]) };
                            unsafe {
                                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
                                    ep, strides[4], len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            let cp = unsafe { c_ptr.offset(offsets[3]) };
            let ep = unsafe { e_ptr.offset(offsets[4]) };
            unsafe {
                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], ep, strides[4],
                    len, &f,
                )
            };
            Ok(())
        },
    )
}