1use crate::kernel::{
6 build_plan_fused, build_plan_fused_small, ensure_same_shape, for_each_inner_block_preordered,
7 sequential_contiguous_layout, total_len, SMALL_TENSOR_THRESHOLD,
8};
9use crate::maybe_sync::{MaybeSendSync, MaybeSync};
10use crate::simd;
11use crate::view::{StridedView, StridedViewMut};
12use crate::{Result, StridedError};
13use std::ops::Mul;
14use strided_view::ElementOp;
15
16#[cfg(feature = "parallel")]
17use crate::fuse::compute_costs;
18#[cfg(feature = "parallel")]
19use crate::threading::{for_each_inner_block_with_offsets, mapreduce_threaded, MINTHREADLENGTH};
20#[cfg(feature = "parallel")]
21use smallvec::SmallVec;
22
/// Per-axis scratch buffer used by the planning helpers in this file.
///
/// With the `parallel` feature enabled this is a `SmallVec` with room for
/// eight items inline.
#[cfg(feature = "parallel")]
type AxisVec<T> = SmallVec<[T; 8]>;
/// Per-axis scratch buffer used when the `parallel` feature is disabled:
/// a plain `Vec`.
#[cfg(not(feature = "parallel"))]
type AxisVec<T> = Vec<T>;

/// Element-count threshold (`1 << 15` = 32768) for the contiguous-range
/// multiply path: `try_contiguous_range_mul` declines any tensor whose total
/// length is less than or equal to this value.
const CONTIGUOUS_RANGE_MIN_LEN: usize = 1 << 15;
29
30#[inline(always)]
40unsafe fn inner_loop_map1<D: Copy, A: Copy, Op: ElementOp<A>>(
41 dp: *mut D,
42 ds: isize,
43 sp: *const A,
44 ss: isize,
45 len: usize,
46 f: &impl Fn(A) -> D,
47) {
48 if ds == 1 && ss == 1 {
49 let src = std::slice::from_raw_parts(sp, len);
50 let dst = std::slice::from_raw_parts_mut(dp, len);
51 simd::dispatch_if_large(len, || {
52 for (d, s) in dst.iter_mut().zip(src.iter()) {
53 *d = f(Op::apply(*s));
54 }
55 });
56 } else {
57 let mut dp = dp;
58 let mut sp = sp;
59 for _ in 0..len {
60 *dp = f(Op::apply(*sp));
61 dp = dp.offset(ds);
62 sp = sp.offset(ss);
63 }
64 }
65}
66
67#[inline(always)]
69unsafe fn inner_loop_map2<D: Copy, A: Copy, B: Copy, OpA: ElementOp<A>, OpB: ElementOp<B>>(
70 dp: *mut D,
71 ds: isize,
72 ap: *const A,
73 a_s: isize,
74 bp: *const B,
75 b_s: isize,
76 len: usize,
77 f: &impl Fn(A, B) -> D,
78) {
79 if ds == 1 && a_s == 1 && b_s == 1 {
80 let src_a = std::slice::from_raw_parts(ap, len);
81 let src_b = std::slice::from_raw_parts(bp, len);
82 let dst = std::slice::from_raw_parts_mut(dp, len);
83 simd::dispatch_if_large(len, || {
84 for i in 0..len {
85 dst[i] = f(OpA::apply(src_a[i]), OpB::apply(src_b[i]));
86 }
87 });
88 } else if ds == 1 && a_s == 1 && b_s == 0 {
89 let src_a = std::slice::from_raw_parts(ap, len);
90 let b = OpB::apply(*bp);
91 let dst = std::slice::from_raw_parts_mut(dp, len);
92 simd::dispatch_if_large(len, || {
93 for i in 0..len {
94 dst[i] = f(OpA::apply(src_a[i]), b);
95 }
96 });
97 } else if ds == 1 && a_s == 0 && b_s == 1 {
98 let a = OpA::apply(*ap);
99 let src_b = std::slice::from_raw_parts(bp, len);
100 let dst = std::slice::from_raw_parts_mut(dp, len);
101 simd::dispatch_if_large(len, || {
102 for i in 0..len {
103 dst[i] = f(a, OpB::apply(src_b[i]));
104 }
105 });
106 } else if ds == 1 && a_s == 0 && b_s == 0 {
107 let a = OpA::apply(*ap);
108 let b = OpB::apply(*bp);
109 let dst = std::slice::from_raw_parts_mut(dp, len);
110 simd::dispatch_if_large(len, || {
111 for d in dst.iter_mut() {
112 *d = f(a, b);
113 }
114 });
115 } else if ds == 1 && b_s == 0 {
116 let b = OpB::apply(*bp);
117 let dst = std::slice::from_raw_parts_mut(dp, len);
118 let mut ap = ap;
119 simd::dispatch_if_large(len, || {
120 for d in dst.iter_mut() {
121 *d = f(OpA::apply(*ap), b);
122 ap = ap.offset(a_s);
123 }
124 });
125 } else if ds == 1 && a_s == 0 {
126 let a = OpA::apply(*ap);
127 let dst = std::slice::from_raw_parts_mut(dp, len);
128 let mut bp = bp;
129 simd::dispatch_if_large(len, || {
130 for d in dst.iter_mut() {
131 *d = f(a, OpB::apply(*bp));
132 bp = bp.offset(b_s);
133 }
134 });
135 } else {
136 let mut dp = dp;
137 let mut ap = ap;
138 let mut bp = bp;
139 for _ in 0..len {
140 *dp = f(OpA::apply(*ap), OpB::apply(*bp));
141 dp = dp.offset(ds);
142 ap = ap.offset(a_s);
143 bp = bp.offset(b_s);
144 }
145 }
146}
147
148#[inline(always)]
150unsafe fn inner_loop_mul2<
151 D: Copy + 'static,
152 A: Copy + Mul<B, Output = D> + 'static,
153 B: Copy + 'static,
154>(
155 dp: *mut D,
156 ds: isize,
157 ap: *const A,
158 a_s: isize,
159 bp: *const B,
160 b_s: isize,
161 len: usize,
162) {
163 if ds == 1 && a_s == 1 && b_s == 1 {
164 let src_a = std::slice::from_raw_parts(ap, len);
165 let src_b = std::slice::from_raw_parts(bp, len);
166 let dst = std::slice::from_raw_parts_mut(dp, len);
167 if len >= 64 && simd::try_mul_contiguous(dst, src_a, src_b) {
168 return;
169 }
170 for i in 0..len {
171 dst[i] = src_a[i] * src_b[i];
172 }
173 } else if ds == 1 && a_s == 1 && b_s == 0 {
174 let src_a = std::slice::from_raw_parts(ap, len);
175 let b = *bp;
176 let dst = std::slice::from_raw_parts_mut(dp, len);
177 for i in 0..len {
178 dst[i] = src_a[i] * b;
179 }
180 } else if ds == 1 && a_s == 0 && b_s == 1 {
181 let a = *ap;
182 let src_b = std::slice::from_raw_parts(bp, len);
183 let dst = std::slice::from_raw_parts_mut(dp, len);
184 for i in 0..len {
185 dst[i] = a * src_b[i];
186 }
187 } else if ds == 1 && a_s == 0 && b_s == 0 {
188 let a = *ap;
189 let b = *bp;
190 let dst = std::slice::from_raw_parts_mut(dp, len);
191 for d in dst.iter_mut() {
192 *d = a * b;
193 }
194 } else if ds == 1 && b_s == 0 {
195 let b = *bp;
196 let dst = std::slice::from_raw_parts_mut(dp, len);
197 let mut ap = ap;
198 for d in dst.iter_mut() {
199 *d = *ap * b;
200 ap = ap.offset(a_s);
201 }
202 } else if ds == 1 && a_s == 0 {
203 let a = *ap;
204 let dst = std::slice::from_raw_parts_mut(dp, len);
205 let mut bp = bp;
206 for d in dst.iter_mut() {
207 *d = a * *bp;
208 bp = bp.offset(b_s);
209 }
210 } else {
211 let mut dp = dp;
212 let mut ap = ap;
213 let mut bp = bp;
214 for _ in 0..len {
215 *dp = *ap * *bp;
216 dp = dp.offset(ds);
217 ap = ap.offset(a_s);
218 bp = bp.offset(b_s);
219 }
220 }
221}
222
/// Iteration plan for the contiguous-destination multiply path, produced by
/// `contiguous_mul_range_plan`.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ContiguousMulRangePlan {
    /// Axis visiting order from `compact_axis_order`: axes with extent > 1
    /// sorted by ascending destination stride, followed by the remaining axes.
    axis_order: AxisVec<usize>,
    /// Number of elements in one fused inner run (product of the fused axes'
    /// extents, each clamped to at least 1).
    inner_len: usize,
    /// How many leading entries of `axis_order` make up the fused inner run.
    inner_axis_count: usize,
    /// Extent of the "row" axis (the first axis after the inner run with
    /// extent > 1), or 1 when there is none.
    row_len: usize,
    /// Position in `axis_order` of the first axis after the row axis
    /// (`axis_order.len()` when there is no row axis).
    outer_axis_start: usize,
    /// Axis that starts the inner run.
    fast_axis: usize,
    /// Stride of operand `a` along `fast_axis`.
    a_fast_stride: isize,
    /// Stride of operand `b` along `fast_axis`.
    b_fast_stride: isize,
    /// Stride of operand `a` along the row axis (0 when there is no row axis).
    a_row_stride: isize,
    /// Stride of operand `b` along the row axis (0 when there is no row axis).
    b_row_stride: isize,
}
236
/// Which operand is the broadcast scalar in a "transposed operand times
/// scalar" tile, as classified by `transposed_scalar_tile_kind`.
#[cfg(feature = "parallel")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TransposedScalarTileKind {
    /// `b` has zero fast and row strides; `a` is the transposed operand.
    RhsScalar,
    /// `a` has zero fast and row strides; `b` is the transposed operand.
    LhsScalar,
}
243
/// Classifies a plan as a "transposed operand times scalar" tile, if it is one.
///
/// One operand must have fast stride equal to `row_len` and row stride 1 while
/// the other has both strides equal to 0. Returns `None` when neither operand
/// pairing matches, or when `row_len` does not fit in an `isize`.
#[cfg(feature = "parallel")]
fn transposed_scalar_tile_kind(plan: &ContiguousMulRangePlan) -> Option<TransposedScalarTileKind> {
    let rows = isize::try_from(plan.row_len).ok()?;

    let is_transposed = |fast: isize, row: isize| fast == rows && row == 1;
    let is_scalar = |fast: isize, row: isize| fast == 0 && row == 0;

    let a = (plan.a_fast_stride, plan.a_row_stride);
    let b = (plan.b_fast_stride, plan.b_row_stride);

    if is_transposed(a.0, a.1) && is_scalar(b.0, b.1) {
        Some(TransposedScalarTileKind::RhsScalar)
    } else if is_transposed(b.0, b.1) && is_scalar(a.0, a.1) {
        Some(TransposedScalarTileKind::LhsScalar)
    } else {
        None
    }
}
265
266fn compact_axis_order(dims: &[usize], strides: &[isize]) -> Option<AxisVec<usize>> {
267 if dims.len() != strides.len() {
268 return None;
269 }
270
271 let mut active = AxisVec::<usize>::new();
272 let mut inactive = AxisVec::<usize>::new();
273 for (axis, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
274 if stride < 0 {
275 return None;
276 }
277 if dim > 1 {
278 active.push(axis);
279 } else {
280 inactive.push(axis);
281 }
282 }
283
284 active.sort_by(|&lhs, &rhs| strides[lhs].cmp(&strides[rhs]).then_with(|| lhs.cmp(&rhs)));
285
286 let mut expected = 1isize;
287 for &axis in &active {
288 if strides[axis] != expected {
289 return None;
290 }
291 expected = expected.saturating_mul(dims[axis] as isize);
292 }
293
294 active.extend(inactive);
295 Some(active)
296}
297
/// Reports whether an axis can be merged with the axis that follows it in the
/// iteration order, for one operand.
///
/// `dim` and `prev_stride` describe the current axis, `next_stride` the
/// following one. Fusing is possible when the current axis has extent <= 1,
/// when both strides are 0 (the operand is broadcast along both axes), or when
/// `next_stride == prev_stride * dim`.
///
/// The product is computed with checked arithmetic: if `dim` does not fit in an
/// `isize` or the multiplication overflows, the axes are reported as not
/// fusable instead of panicking (debug) or comparing against a wrapped value
/// (release).
fn can_fuse_contiguous_range_axis(dim: usize, prev_stride: isize, next_stride: isize) -> bool {
    if dim <= 1 || (prev_stride == 0 && next_stride == 0) {
        return true;
    }

    let expected_next = isize::try_from(dim)
        .ok()
        .and_then(|extent| prev_stride.checked_mul(extent));
    expected_next == Some(next_stride)
}
301
302fn contiguous_mul_range_plan(
303 dims: &[usize],
304 dst_strides: &[isize],
305 a_strides: &[isize],
306 b_strides: &[isize],
307) -> Option<ContiguousMulRangePlan> {
308 let axis_order = compact_axis_order(dims, dst_strides)?;
309 if dims.is_empty() {
310 return Some(ContiguousMulRangePlan {
311 axis_order,
312 inner_len: 1,
313 inner_axis_count: 0,
314 row_len: 1,
315 outer_axis_start: 0,
316 fast_axis: 0,
317 a_fast_stride: 0,
318 b_fast_stride: 0,
319 a_row_stride: 0,
320 b_row_stride: 0,
321 });
322 }
323
324 let first_pos = axis_order
325 .iter()
326 .position(|&axis| dims[axis] > 1)
327 .unwrap_or(0);
328 let first_axis = axis_order[first_pos];
329 let mut inner_len = dims[first_axis].max(1);
330 let mut inner_axis_count = first_pos + 1;
331 let mut prev_axis = first_axis;
332
333 for &axis in axis_order.iter().skip(first_pos + 1) {
334 if can_fuse_contiguous_range_axis(
335 dims[prev_axis],
336 dst_strides[prev_axis],
337 dst_strides[axis],
338 ) && can_fuse_contiguous_range_axis(
339 dims[prev_axis],
340 a_strides[prev_axis],
341 a_strides[axis],
342 ) && can_fuse_contiguous_range_axis(
343 dims[prev_axis],
344 b_strides[prev_axis],
345 b_strides[axis],
346 ) {
347 inner_len = inner_len.checked_mul(dims[axis].max(1))?;
348 inner_axis_count += 1;
349 prev_axis = axis;
350 } else {
351 break;
352 }
353 }
354
355 let row = axis_order
356 .iter()
357 .enumerate()
358 .skip(inner_axis_count)
359 .find(|&(_, &axis)| dims[axis] > 1);
360 let (row_len, outer_axis_start, a_row_stride, b_row_stride) =
361 if let Some((row_pos, &row_axis)) = row {
362 (
363 dims[row_axis],
364 row_pos + 1,
365 a_strides[row_axis],
366 b_strides[row_axis],
367 )
368 } else {
369 (1, axis_order.len(), 0, 0)
370 };
371
372 Some(ContiguousMulRangePlan {
373 axis_order,
374 inner_len,
375 inner_axis_count,
376 row_len,
377 outer_axis_start,
378 fast_axis: first_axis,
379 a_fast_stride: a_strides[first_axis],
380 b_fast_stride: b_strides[first_axis],
381 a_row_stride,
382 b_row_stride,
383 })
384}
385
/// Odometer over the axes that lie beyond a plan's row axis, tracking the
/// element offsets of operands `a` and `b` for the current outer position.
struct ContiguousMulOuterCursor<'a> {
    /// Extents of all axes, indexed by axis number.
    dims: &'a [usize],
    /// Strides of operand `a`, indexed by axis number.
    a_strides: &'a [isize],
    /// Strides of operand `b`, indexed by axis number.
    b_strides: &'a [isize],
    /// The outer axes, in the order they are stepped (first entry fastest).
    axes: AxisVec<usize>,
    /// Current coordinate along each entry of `axes`.
    coords: AxisVec<usize>,
    /// Element offset into `a` for the current coordinates.
    a_offset: isize,
    /// Element offset into `b` for the current coordinates.
    b_offset: isize,
}
395
396impl<'a> ContiguousMulOuterCursor<'a> {
397 fn new(
398 dims: &'a [usize],
399 a_strides: &'a [isize],
400 b_strides: &'a [isize],
401 plan: &ContiguousMulRangePlan,
402 outer_group: usize,
403 ) -> Self {
404 let axes: AxisVec<usize> = plan
405 .axis_order
406 .iter()
407 .skip(plan.outer_axis_start)
408 .copied()
409 .collect();
410 let mut coords = AxisVec::<usize>::with_capacity(axes.len());
411 let mut rem = outer_group;
412 let mut a_offset = 0isize;
413 let mut b_offset = 0isize;
414
415 for &axis in &axes {
416 let dim = dims[axis].max(1);
417 let coord = rem % dim;
418 rem /= dim;
419 coords.push(coord);
420 a_offset += coord as isize * a_strides[axis];
421 b_offset += coord as isize * b_strides[axis];
422 }
423
424 Self {
425 dims,
426 a_strides,
427 b_strides,
428 axes,
429 coords,
430 a_offset,
431 b_offset,
432 }
433 }
434
435 fn advance(&mut self) {
436 for (i, &axis) in self.axes.iter().enumerate() {
437 let dim = self.dims[axis].max(1);
438 if dim <= 1 {
439 continue;
440 }
441
442 self.coords[i] += 1;
443 self.a_offset += self.a_strides[axis];
444 self.b_offset += self.b_strides[axis];
445
446 if self.coords[i] < dim {
447 break;
448 }
449
450 self.coords[i] = 0;
451 self.a_offset -= dim as isize * self.a_strides[axis];
452 self.b_offset -= dim as isize * self.b_strides[axis];
453 }
454 }
455}
456
457#[inline(always)]
458unsafe fn run_contiguous_mul_row_block<
459 D: Copy + 'static,
460 A: Copy + Mul<B, Output = D> + 'static,
461 B: Copy + 'static,
462>(
463 dst_ptr: *mut D,
464 a_ptr: *const A,
465 b_ptr: *const B,
466 plan: &ContiguousMulRangePlan,
467 base_index: usize,
468 total: usize,
469 base_a_offset: isize,
470 base_b_offset: isize,
471) {
472 let inner_len = plan.inner_len.max(1);
473 let row_len = plan.row_len.max(1);
474 #[cfg(feature = "parallel")]
475 let block_len = inner_len.saturating_mul(row_len);
476
477 #[cfg(feature = "parallel")]
478 if total.saturating_sub(base_index) >= block_len {
479 match transposed_scalar_tile_kind(plan) {
480 Some(TransposedScalarTileKind::RhsScalar) => {
481 if simd::try_mul_transposed_scalar_rhs_2d::<D, A, B>(
482 dst_ptr.add(base_index),
483 a_ptr.offset(base_a_offset),
484 b_ptr.offset(base_b_offset),
485 inner_len,
486 row_len,
487 plan.a_fast_stride,
488 plan.a_row_stride,
489 ) {
490 return;
491 }
492 }
493 Some(TransposedScalarTileKind::LhsScalar) => {
494 if simd::try_mul_transposed_scalar_lhs_2d::<D, A, B>(
495 dst_ptr.add(base_index),
496 a_ptr.offset(base_a_offset),
497 b_ptr.offset(base_b_offset),
498 inner_len,
499 row_len,
500 plan.b_fast_stride,
501 plan.b_row_stride,
502 ) {
503 return;
504 }
505 }
506 None => {}
507 }
508 }
509
510 let mut index = base_index;
511 let mut a_offset = base_a_offset;
512 let mut b_offset = base_b_offset;
513
514 for _ in 0..row_len {
515 if index >= total {
516 break;
517 }
518 let len = inner_len.min(total - index);
519 inner_loop_mul2::<D, A, B>(
520 dst_ptr.add(index),
521 1,
522 a_ptr.offset(a_offset),
523 plan.a_fast_stride,
524 b_ptr.offset(b_offset),
525 plan.b_fast_stride,
526 len,
527 );
528 index += inner_len;
529 a_offset += plan.a_row_stride;
530 b_offset += plan.b_row_stride;
531 }
532}
533
/// Converts a flat index (counted in `axis_order`, first axis fastest) into
/// an element offset for an operand with the given `strides`.
///
/// Returns 0 as soon as an axis with extent 0 is encountered.
#[cfg(feature = "parallel")]
fn strided_offset_for_contiguous_linear_index(
    dims: &[usize],
    strides: &[isize],
    axis_order: &[usize],
    index: usize,
) -> isize {
    let mut remaining = index;
    let mut offset = 0isize;

    for &axis in axis_order {
        let extent = dims[axis];
        if extent == 0 {
            return 0;
        }
        offset += (remaining % extent) as isize * strides[axis];
        remaining /= extent;
    }

    offset
}
553
554fn try_contiguous_range_mul<
555 D: Copy + MaybeSendSync + 'static,
556 A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
557 B: Copy + MaybeSendSync + 'static,
558>(
559 dst_ptr: *mut D,
560 dims: &[usize],
561 dst_strides: &[isize],
562 a_ptr: *const A,
563 a_strides: &[isize],
564 b_ptr: *const B,
565 b_strides: &[isize],
566) -> bool {
567 let total = total_len(dims);
568 if total == 0 {
569 return true;
570 }
571 if total <= CONTIGUOUS_RANGE_MIN_LEN {
572 return false;
573 }
574
575 let Some(plan) = contiguous_mul_range_plan(dims, dst_strides, a_strides, b_strides) else {
576 return false;
577 };
578
579 let inner_len = plan.inner_len.max(1);
580 let row_len = plan.row_len.max(1);
581 let block_len = inner_len.saturating_mul(row_len).max(1);
582 let outer_groups = total.div_ceil(block_len);
583
584 #[cfg(feature = "parallel")]
585 {
586 let nthreads = rayon::current_num_threads();
587 if nthreads > 1 {
588 use crate::threading::SendPtr;
589 use rayon::prelude::*;
590
591 let dst = SendPtr(dst_ptr);
592 let a = SendPtr(a_ptr as *mut A);
593 let b = SendPtr(b_ptr as *mut B);
594
595 if outer_groups < nthreads {
596 let chunk_len = total.div_ceil(nthreads);
597 let nchunks = total.div_ceil(chunk_len);
598
599 (0..nchunks).into_par_iter().for_each(|chunk| {
600 let start = chunk * chunk_len;
601 let end = (start + chunk_len).min(total);
602 let mut index = start;
603
604 while index < end {
605 let in_inner = index % inner_len;
606 let len = (inner_len - in_inner).min(end - index);
607 let a_offset = strided_offset_for_contiguous_linear_index(
608 dims,
609 a_strides,
610 &plan.axis_order,
611 index,
612 );
613 let b_offset = strided_offset_for_contiguous_linear_index(
614 dims,
615 b_strides,
616 &plan.axis_order,
617 index,
618 );
619
620 unsafe {
621 inner_loop_mul2::<D, A, B>(
622 dst.as_ptr().add(index),
623 1,
624 a.as_const().offset(a_offset),
625 plan.a_fast_stride,
626 b.as_const().offset(b_offset),
627 plan.b_fast_stride,
628 len,
629 );
630 }
631 index += len;
632 }
633 });
634
635 return true;
636 }
637
638 let groups_per_chunk = outer_groups.div_ceil(nthreads);
639 let nchunks = outer_groups.div_ceil(groups_per_chunk);
640
641 (0..nchunks).into_par_iter().for_each(|chunk| {
642 let group_start = chunk * groups_per_chunk;
643 let group_end = (group_start + groups_per_chunk).min(outer_groups);
644 let mut cursor =
645 ContiguousMulOuterCursor::new(dims, a_strides, b_strides, &plan, group_start);
646
647 for group in group_start..group_end {
648 let index = group * block_len;
649 unsafe {
650 run_contiguous_mul_row_block::<D, A, B>(
651 dst.as_ptr(),
652 a.as_const(),
653 b.as_const(),
654 &plan,
655 index,
656 total,
657 cursor.a_offset,
658 cursor.b_offset,
659 );
660 }
661 cursor.advance();
662 }
663 });
664
665 true
666 } else {
667 run_contiguous_range_mul_single_thread(
668 dst_ptr,
669 dims,
670 a_ptr,
671 a_strides,
672 b_ptr,
673 b_strides,
674 &plan,
675 total,
676 block_len,
677 outer_groups,
678 )
679 }
680 }
681
682 #[cfg(not(feature = "parallel"))]
683 {
684 run_contiguous_range_mul_single_thread(
685 dst_ptr,
686 dims,
687 a_ptr,
688 a_strides,
689 b_ptr,
690 b_strides,
691 &plan,
692 total,
693 block_len,
694 outer_groups,
695 )
696 }
697}
698
699fn run_contiguous_range_mul_single_thread<
700 D: Copy + MaybeSendSync + 'static,
701 A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
702 B: Copy + MaybeSendSync + 'static,
703>(
704 dst_ptr: *mut D,
705 dims: &[usize],
706 a_ptr: *const A,
707 a_strides: &[isize],
708 b_ptr: *const B,
709 b_strides: &[isize],
710 plan: &ContiguousMulRangePlan,
711 total: usize,
712 block_len: usize,
713 outer_groups: usize,
714) -> bool {
715 let mut cursor = ContiguousMulOuterCursor::new(dims, a_strides, b_strides, plan, 0);
716 for group in 0..outer_groups {
717 let index = group * block_len;
718 unsafe {
719 run_contiguous_mul_row_block::<D, A, B>(
720 dst_ptr,
721 a_ptr,
722 b_ptr,
723 plan,
724 index,
725 total,
726 cursor.a_offset,
727 cursor.b_offset,
728 );
729 }
730 cursor.advance();
731 }
732 true
733}
734
735#[inline(always)]
737unsafe fn inner_loop_map3<
738 D: Copy,
739 A: Copy,
740 B: Copy,
741 C: Copy,
742 OpA: ElementOp<A>,
743 OpB: ElementOp<B>,
744 OpC: ElementOp<C>,
745>(
746 dp: *mut D,
747 ds: isize,
748 ap: *const A,
749 a_s: isize,
750 bp: *const B,
751 b_s: isize,
752 cp: *const C,
753 c_s: isize,
754 len: usize,
755 f: &impl Fn(A, B, C) -> D,
756) {
757 if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 {
758 let src_a = std::slice::from_raw_parts(ap, len);
759 let src_b = std::slice::from_raw_parts(bp, len);
760 let src_c = std::slice::from_raw_parts(cp, len);
761 let dst = std::slice::from_raw_parts_mut(dp, len);
762 simd::dispatch_if_large(len, || {
763 for i in 0..len {
764 dst[i] = f(
765 OpA::apply(src_a[i]),
766 OpB::apply(src_b[i]),
767 OpC::apply(src_c[i]),
768 );
769 }
770 });
771 } else {
772 let mut dp = dp;
773 let mut ap = ap;
774 let mut bp = bp;
775 let mut cp = cp;
776 for _ in 0..len {
777 *dp = f(OpA::apply(*ap), OpB::apply(*bp), OpC::apply(*cp));
778 dp = dp.offset(ds);
779 ap = ap.offset(a_s);
780 bp = bp.offset(b_s);
781 cp = cp.offset(c_s);
782 }
783 }
784}
785
786#[inline(always)]
788unsafe fn inner_loop_map4<
789 D: Copy,
790 A: Copy,
791 B: Copy,
792 C: Copy,
793 E: Copy,
794 OpA: ElementOp<A>,
795 OpB: ElementOp<B>,
796 OpC: ElementOp<C>,
797 OpE: ElementOp<E>,
798>(
799 dp: *mut D,
800 ds: isize,
801 ap: *const A,
802 a_s: isize,
803 bp: *const B,
804 b_s: isize,
805 cp: *const C,
806 c_s: isize,
807 ep: *const E,
808 e_s: isize,
809 len: usize,
810 f: &impl Fn(A, B, C, E) -> D,
811) {
812 if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 && e_s == 1 {
813 let src_a = std::slice::from_raw_parts(ap, len);
814 let src_b = std::slice::from_raw_parts(bp, len);
815 let src_c = std::slice::from_raw_parts(cp, len);
816 let src_e = std::slice::from_raw_parts(ep, len);
817 let dst = std::slice::from_raw_parts_mut(dp, len);
818 simd::dispatch_if_large(len, || {
819 for i in 0..len {
820 dst[i] = f(
821 OpA::apply(src_a[i]),
822 OpB::apply(src_b[i]),
823 OpC::apply(src_c[i]),
824 OpE::apply(src_e[i]),
825 );
826 }
827 });
828 } else {
829 let mut dp = dp;
830 let mut ap = ap;
831 let mut bp = bp;
832 let mut cp = cp;
833 let mut ep = ep;
834 for _ in 0..len {
835 *dp = f(
836 OpA::apply(*ap),
837 OpB::apply(*bp),
838 OpC::apply(*cp),
839 OpE::apply(*ep),
840 );
841 dp = dp.offset(ds);
842 ap = ap.offset(a_s);
843 bp = bp.offset(b_s);
844 cp = cp.offset(c_s);
845 ep = ep.offset(e_s);
846 }
847 }
848}
849
850pub fn map_into<D: Copy + MaybeSendSync, A: Copy + MaybeSendSync, Op: ElementOp<A>>(
855 dest: &mut StridedViewMut<D>,
856 src: &StridedView<A, Op>,
857 f: impl Fn(A) -> D + MaybeSync,
858) -> Result<()> {
859 ensure_same_shape(dest.dims(), src.dims())?;
860
861 let dst_ptr = dest.as_mut_ptr();
862 let src_ptr = src.ptr();
863 let dst_dims = dest.dims();
864 let dst_strides = dest.strides();
865 let src_strides = src.strides();
866
867 if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
868 let len = total_len(dst_dims);
869 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
870 let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
871 simd::dispatch_if_large(len, || {
872 for i in 0..len {
873 dst[i] = f(Op::apply(src[i]));
874 }
875 });
876 return Ok(());
877 }
878
879 let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
880 let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<A>());
881 let total = total_len(dst_dims);
882
883 let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
885 build_plan_fused_small(dst_dims, &strides_list)
886 } else {
887 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
888 };
889
890 #[cfg(feature = "parallel")]
891 {
892 let total: usize = fused_dims.iter().product();
893 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
894 use crate::threading::SendPtr;
895 let dst_send = SendPtr(dst_ptr);
896 let src_send = SendPtr(src_ptr as *mut A);
897
898 let costs = compute_costs(&ordered_strides);
899 let initial_offsets = vec![0isize; strides_list.len()];
900 let nthreads = rayon::current_num_threads();
901
902 return mapreduce_threaded(
903 &fused_dims,
904 &plan.block,
905 &ordered_strides,
906 &initial_offsets,
907 &costs,
908 nthreads,
909 0,
910 1,
911 &|dims, blocks, strides_list, offsets| {
912 for_each_inner_block_with_offsets(
913 dims,
914 blocks,
915 strides_list,
916 offsets,
917 |offsets, len, strides| {
918 let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
919 let sp = unsafe { src_send.as_const().offset(offsets[1]) };
920 unsafe {
921 inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f)
922 };
923 Ok(())
924 },
925 )
926 },
927 );
928 }
929 }
930
931 let initial_offsets = vec![0isize; ordered_strides.len()];
932 for_each_inner_block_preordered(
933 &fused_dims,
934 &plan.block,
935 &ordered_strides,
936 &initial_offsets,
937 |offsets, len, strides| {
938 let dp = unsafe { dst_ptr.offset(offsets[0]) };
939 let sp = unsafe { src_ptr.offset(offsets[1]) };
940 unsafe { inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f) };
941 Ok(())
942 },
943 )
944}
945
946pub fn zip_map2_into<
951 D: Copy + MaybeSendSync,
952 A: Copy + MaybeSendSync,
953 B: Copy + MaybeSendSync,
954 OpA: ElementOp<A>,
955 OpB: ElementOp<B>,
956>(
957 dest: &mut StridedViewMut<D>,
958 a: &StridedView<A, OpA>,
959 b: &StridedView<B, OpB>,
960 f: impl Fn(A, B) -> D + MaybeSync,
961) -> Result<()> {
962 ensure_same_shape(dest.dims(), a.dims())?;
963 ensure_same_shape(dest.dims(), b.dims())?;
964
965 let dst_ptr = dest.as_mut_ptr();
966 let dst_dims = dest.dims();
967 let dst_strides = dest.strides();
968 let a_ptr = a.ptr();
969 let b_ptr = b.ptr();
970
971 let a_strides = a.strides();
972 let b_strides = b.strides();
973
974 if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
975 let len = total_len(dst_dims);
976 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
977 let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
978 let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
979 simd::dispatch_if_large(len, || {
980 for i in 0..len {
981 dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]));
982 }
983 });
984 return Ok(());
985 }
986
987 let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
988 let elem_size = std::mem::size_of::<D>()
989 .max(std::mem::size_of::<A>())
990 .max(std::mem::size_of::<B>());
991 let total = total_len(dst_dims);
992
993 let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
995 build_plan_fused_small(dst_dims, &strides_list)
996 } else {
997 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
998 };
999
1000 #[cfg(feature = "parallel")]
1001 {
1002 let total: usize = fused_dims.iter().product();
1003 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1004 use crate::threading::SendPtr;
1005 let dst_send = SendPtr(dst_ptr);
1006 let a_send = SendPtr(a_ptr as *mut A);
1007 let b_send = SendPtr(b_ptr as *mut B);
1008
1009 let costs = compute_costs(&ordered_strides);
1010 let initial_offsets = vec![0isize; strides_list.len()];
1011 let nthreads = rayon::current_num_threads();
1012
1013 return mapreduce_threaded(
1014 &fused_dims,
1015 &plan.block,
1016 &ordered_strides,
1017 &initial_offsets,
1018 &costs,
1019 nthreads,
1020 0,
1021 1,
1022 &|dims, blocks, strides_list, offsets| {
1023 for_each_inner_block_with_offsets(
1024 dims,
1025 blocks,
1026 strides_list,
1027 offsets,
1028 |offsets, len, strides| {
1029 let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1030 let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1031 let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1032 unsafe {
1033 inner_loop_map2::<D, A, B, OpA, OpB>(
1034 dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
1035 )
1036 };
1037 Ok(())
1038 },
1039 )
1040 },
1041 );
1042 }
1043 }
1044
1045 let initial_offsets = vec![0isize; ordered_strides.len()];
1046 for_each_inner_block_preordered(
1047 &fused_dims,
1048 &plan.block,
1049 &ordered_strides,
1050 &initial_offsets,
1051 |offsets, len, strides| {
1052 let dp = unsafe { dst_ptr.offset(offsets[0]) };
1053 let ap = unsafe { a_ptr.offset(offsets[1]) };
1054 let bp = unsafe { b_ptr.offset(offsets[2]) };
1055 unsafe {
1056 inner_loop_map2::<D, A, B, OpA, OpB>(
1057 dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
1058 )
1059 };
1060 Ok(())
1061 },
1062 )
1063}
1064
1065fn mul_identity_into_raw<
1066 D: Copy + MaybeSendSync + 'static,
1067 A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
1068 B: Copy + MaybeSendSync + 'static,
1069>(
1070 dest: &mut StridedViewMut<D>,
1071 a_ptr: *const A,
1072 a_strides: &[isize],
1073 b_ptr: *const B,
1074 b_strides: &[isize],
1075) -> Result<()> {
1076 let dst_ptr = dest.as_mut_ptr();
1077 let dst_dims = dest.dims();
1078 let dst_strides = dest.strides();
1079 debug_assert_eq!(dst_dims.len(), a_strides.len());
1080 debug_assert_eq!(dst_dims.len(), b_strides.len());
1081
1082 if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
1083 let len = total_len(dst_dims);
1084 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
1085 let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
1086 let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
1087 if simd::try_mul_contiguous(dst, sa, sb) {
1088 return Ok(());
1089 }
1090 for i in 0..len {
1091 dst[i] = sa[i] * sb[i];
1092 }
1093 return Ok(());
1094 }
1095
1096 let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
1097 let elem_size = std::mem::size_of::<D>()
1098 .max(std::mem::size_of::<A>())
1099 .max(std::mem::size_of::<B>());
1100 let total = total_len(dst_dims);
1101
1102 if try_contiguous_range_mul(
1103 dst_ptr,
1104 dst_dims,
1105 dst_strides,
1106 a_ptr,
1107 a_strides,
1108 b_ptr,
1109 b_strides,
1110 ) {
1111 return Ok(());
1112 }
1113
1114 let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
1115 build_plan_fused_small(dst_dims, &strides_list)
1116 } else {
1117 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
1118 };
1119
1120 #[cfg(feature = "parallel")]
1121 {
1122 let total: usize = fused_dims.iter().product();
1123 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1124 use crate::threading::SendPtr;
1125 let dst_send = SendPtr(dst_ptr);
1126 let a_send = SendPtr(a_ptr as *mut A);
1127 let b_send = SendPtr(b_ptr as *mut B);
1128
1129 let costs = compute_costs(&ordered_strides);
1130 let initial_offsets = vec![0isize; strides_list.len()];
1131 let nthreads = rayon::current_num_threads();
1132
1133 return mapreduce_threaded(
1134 &fused_dims,
1135 &plan.block,
1136 &ordered_strides,
1137 &initial_offsets,
1138 &costs,
1139 nthreads,
1140 0,
1141 1,
1142 &|dims, blocks, strides_list, offsets| {
1143 for_each_inner_block_with_offsets(
1144 dims,
1145 blocks,
1146 strides_list,
1147 offsets,
1148 |offsets, len, strides| {
1149 let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1150 let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1151 let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1152 unsafe {
1153 inner_loop_mul2::<D, A, B>(
1154 dp, strides[0], ap, strides[1], bp, strides[2], len,
1155 )
1156 };
1157 Ok(())
1158 },
1159 )
1160 },
1161 );
1162 }
1163 }
1164
1165 let initial_offsets = vec![0isize; ordered_strides.len()];
1166 for_each_inner_block_preordered(
1167 &fused_dims,
1168 &plan.block,
1169 &ordered_strides,
1170 &initial_offsets,
1171 |offsets, len, strides| {
1172 let dp = unsafe { dst_ptr.offset(offsets[0]) };
1173 let ap = unsafe { a_ptr.offset(offsets[1]) };
1174 let bp = unsafe { b_ptr.offset(offsets[2]) };
1175 unsafe {
1176 inner_loop_mul2::<D, A, B>(dp, strides[0], ap, strides[1], bp, strides[2], len)
1177 };
1178 Ok(())
1179 },
1180 )
1181}
1182
1183fn mul_identity_into<
1184 D: Copy + MaybeSendSync + 'static,
1185 A: Copy + Mul<B, Output = D> + MaybeSendSync + 'static,
1186 B: Copy + MaybeSendSync + 'static,
1187 OpA: ElementOp<A>,
1188 OpB: ElementOp<B>,
1189>(
1190 dest: &mut StridedViewMut<D>,
1191 a: &StridedView<A, OpA>,
1192 b: &StridedView<B, OpB>,
1193) -> Result<()> {
1194 ensure_same_shape(dest.dims(), a.dims())?;
1195 ensure_same_shape(dest.dims(), b.dims())?;
1196 mul_identity_into_raw(dest, a.ptr(), a.strides(), b.ptr(), b.strides())
1197}
1198
1199pub fn mul_into<
1204 D: Copy + MaybeSendSync + 'static,
1205 A: Copy + Mul<B, Output = D> + MaybeSendSync + 'static,
1206 B: Copy + MaybeSendSync + 'static,
1207 OpA: ElementOp<A>,
1208 OpB: ElementOp<B>,
1209>(
1210 dest: &mut StridedViewMut<D>,
1211 a: &StridedView<A, OpA>,
1212 b: &StridedView<B, OpB>,
1213) -> Result<()> {
1214 if OpA::IS_IDENTITY && OpB::IS_IDENTITY {
1215 return mul_identity_into(dest, a, b);
1216 }
1217
1218 zip_map2_into(dest, a, b, |x, y| x * y)
1219}
1220
1221fn broadcast_strides_for_axes(
1222 source_dims: &[usize],
1223 source_strides: &[isize],
1224 target_dims: &[usize],
1225 axes: &[usize],
1226) -> Result<AxisVec<isize>> {
1227 if source_dims.len() != axes.len() {
1228 return Err(StridedError::RankMismatch(source_dims.len(), axes.len()));
1229 }
1230 debug_assert_eq!(source_dims.len(), source_strides.len());
1231
1232 let mut seen = AxisVec::<bool>::new();
1233 seen.resize(target_dims.len(), false);
1234 let mut strides = AxisVec::<isize>::new();
1235 strides.resize(target_dims.len(), 0);
1236 for (src_axis, &dst_axis) in axes.iter().enumerate() {
1237 if dst_axis >= target_dims.len() {
1238 return Err(StridedError::InvalidAxis {
1239 axis: dst_axis,
1240 rank: target_dims.len(),
1241 });
1242 }
1243 if seen[dst_axis] {
1244 return Err(StridedError::InvalidAxis {
1245 axis: dst_axis,
1246 rank: target_dims.len(),
1247 });
1248 }
1249 seen[dst_axis] = true;
1250
1251 let source_dim = source_dims[src_axis];
1252 let target_dim = target_dims[dst_axis];
1253 if source_dim != target_dim && source_dim != 1 {
1254 return Err(StridedError::ShapeMismatch(
1255 source_dims.to_vec(),
1256 target_dims.to_vec(),
1257 ));
1258 }
1259 if source_dim == target_dim {
1260 strides[dst_axis] = source_strides[src_axis];
1261 }
1262 }
1263
1264 Ok(strides)
1265}
1266
1267fn broadcast_view_with_strides<'a, T, Op: ElementOp<T>>(
1268 view: &StridedView<'a, T, Op>,
1269 target_dims: &[usize],
1270 strides: &[isize],
1271) -> StridedView<'a, T, Op> {
1272 unsafe { StridedView::new_unchecked(view.data(), target_dims, strides, view.offset()) }
1273}
1274
1275pub fn broadcast_mul_into<
1280 D: Copy + MaybeSendSync + 'static,
1281 A: Copy + Mul<B, Output = D> + MaybeSendSync + 'static,
1282 B: Copy + MaybeSendSync + 'static,
1283 OpA: ElementOp<A>,
1284 OpB: ElementOp<B>,
1285>(
1286 dest: &mut StridedViewMut<D>,
1287 a: &StridedView<A, OpA>,
1288 a_axes: &[usize],
1289 b: &StridedView<B, OpB>,
1290 b_axes: &[usize],
1291) -> Result<()> {
1292 let a_strides = broadcast_strides_for_axes(a.dims(), a.strides(), dest.dims(), a_axes)?;
1293 let b_strides = broadcast_strides_for_axes(b.dims(), b.strides(), dest.dims(), b_axes)?;
1294
1295 if OpA::IS_IDENTITY && OpB::IS_IDENTITY {
1296 return mul_identity_into_raw(dest, a.ptr(), &a_strides, b.ptr(), &b_strides);
1297 }
1298
1299 let a = broadcast_view_with_strides(a, dest.dims(), &a_strides);
1300 let b = broadcast_view_with_strides(b, dest.dims(), &b_strides);
1301 mul_into(dest, &a, &b)
1302}
1303
1304pub fn zip_map3_into<
1306 D: Copy + MaybeSendSync,
1307 A: Copy + MaybeSendSync,
1308 B: Copy + MaybeSendSync,
1309 C: Copy + MaybeSendSync,
1310 OpA: ElementOp<A>,
1311 OpB: ElementOp<B>,
1312 OpC: ElementOp<C>,
1313>(
1314 dest: &mut StridedViewMut<D>,
1315 a: &StridedView<A, OpA>,
1316 b: &StridedView<B, OpB>,
1317 c: &StridedView<C, OpC>,
1318 f: impl Fn(A, B, C) -> D + MaybeSync,
1319) -> Result<()> {
1320 ensure_same_shape(dest.dims(), a.dims())?;
1321 ensure_same_shape(dest.dims(), b.dims())?;
1322 ensure_same_shape(dest.dims(), c.dims())?;
1323
1324 let dst_ptr = dest.as_mut_ptr();
1325 let a_ptr = a.ptr();
1326 let b_ptr = b.ptr();
1327 let c_ptr = c.ptr();
1328
1329 let dst_dims = dest.dims();
1330 let dst_strides = dest.strides();
1331
1332 if sequential_contiguous_layout(
1333 dst_dims,
1334 &[dst_strides, a.strides(), b.strides(), c.strides()],
1335 )
1336 .is_some()
1337 {
1338 let len = total_len(dst_dims);
1339 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
1340 let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
1341 let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
1342 let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
1343 simd::dispatch_if_large(len, || {
1344 for i in 0..len {
1345 dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]), OpC::apply(sc[i]));
1346 }
1347 });
1348 return Ok(());
1349 }
1350
1351 let strides_list: [&[isize]; 4] = [dst_strides, a.strides(), b.strides(), c.strides()];
1352 let elem_size = std::mem::size_of::<D>()
1353 .max(std::mem::size_of::<A>())
1354 .max(std::mem::size_of::<B>())
1355 .max(std::mem::size_of::<C>());
1356 let total = total_len(dst_dims);
1357
1358 let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
1360 build_plan_fused_small(dst_dims, &strides_list)
1361 } else {
1362 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
1363 };
1364
1365 #[cfg(feature = "parallel")]
1366 {
1367 let total: usize = fused_dims.iter().product();
1368 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1369 use crate::threading::SendPtr;
1370 let dst_send = SendPtr(dst_ptr);
1371 let a_send = SendPtr(a_ptr as *mut A);
1372 let b_send = SendPtr(b_ptr as *mut B);
1373 let c_send = SendPtr(c_ptr as *mut C);
1374
1375 let costs = compute_costs(&ordered_strides);
1376 let initial_offsets = vec![0isize; strides_list.len()];
1377 let nthreads = rayon::current_num_threads();
1378
1379 return mapreduce_threaded(
1380 &fused_dims,
1381 &plan.block,
1382 &ordered_strides,
1383 &initial_offsets,
1384 &costs,
1385 nthreads,
1386 0,
1387 1,
1388 &|dims, blocks, strides_list, offsets| {
1389 for_each_inner_block_with_offsets(
1390 dims,
1391 blocks,
1392 strides_list,
1393 offsets,
1394 |offsets, len, strides| {
1395 let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1396 let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1397 let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1398 let cp = unsafe { c_send.as_const().offset(offsets[3]) };
1399 unsafe {
1400 inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
1401 dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
1402 len, &f,
1403 )
1404 };
1405 Ok(())
1406 },
1407 )
1408 },
1409 );
1410 }
1411 }
1412
1413 let initial_offsets = vec![0isize; ordered_strides.len()];
1414 for_each_inner_block_preordered(
1415 &fused_dims,
1416 &plan.block,
1417 &ordered_strides,
1418 &initial_offsets,
1419 |offsets, len, strides| {
1420 let dp = unsafe { dst_ptr.offset(offsets[0]) };
1421 let ap = unsafe { a_ptr.offset(offsets[1]) };
1422 let bp = unsafe { b_ptr.offset(offsets[2]) };
1423 let cp = unsafe { c_ptr.offset(offsets[3]) };
1424 unsafe {
1425 inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
1426 dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], len, &f,
1427 )
1428 };
1429 Ok(())
1430 },
1431 )
1432}
1433
1434pub fn zip_map4_into<
1436 D: Copy + MaybeSendSync,
1437 A: Copy + MaybeSendSync,
1438 B: Copy + MaybeSendSync,
1439 C: Copy + MaybeSendSync,
1440 E: Copy + MaybeSendSync,
1441 OpA: ElementOp<A>,
1442 OpB: ElementOp<B>,
1443 OpC: ElementOp<C>,
1444 OpE: ElementOp<E>,
1445>(
1446 dest: &mut StridedViewMut<D>,
1447 a: &StridedView<A, OpA>,
1448 b: &StridedView<B, OpB>,
1449 c: &StridedView<C, OpC>,
1450 e: &StridedView<E, OpE>,
1451 f: impl Fn(A, B, C, E) -> D + MaybeSync,
1452) -> Result<()> {
1453 ensure_same_shape(dest.dims(), a.dims())?;
1454 ensure_same_shape(dest.dims(), b.dims())?;
1455 ensure_same_shape(dest.dims(), c.dims())?;
1456 ensure_same_shape(dest.dims(), e.dims())?;
1457
1458 let dst_ptr = dest.as_mut_ptr();
1459 let a_ptr = a.ptr();
1460 let b_ptr = b.ptr();
1461 let c_ptr = c.ptr();
1462 let e_ptr = e.ptr();
1463
1464 let dst_dims = dest.dims();
1465 let dst_strides = dest.strides();
1466
1467 if sequential_contiguous_layout(
1468 dst_dims,
1469 &[
1470 dst_strides,
1471 a.strides(),
1472 b.strides(),
1473 c.strides(),
1474 e.strides(),
1475 ],
1476 )
1477 .is_some()
1478 {
1479 let len = total_len(dst_dims);
1480 let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
1481 let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
1482 let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
1483 let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
1484 let se = unsafe { std::slice::from_raw_parts(e_ptr, len) };
1485 simd::dispatch_if_large(len, || {
1486 for i in 0..len {
1487 dst[i] = f(
1488 OpA::apply(sa[i]),
1489 OpB::apply(sb[i]),
1490 OpC::apply(sc[i]),
1491 OpE::apply(se[i]),
1492 );
1493 }
1494 });
1495 return Ok(());
1496 }
1497
1498 let strides_list: [&[isize]; 5] = [
1499 dst_strides,
1500 a.strides(),
1501 b.strides(),
1502 c.strides(),
1503 e.strides(),
1504 ];
1505 let elem_size = std::mem::size_of::<D>()
1506 .max(std::mem::size_of::<A>())
1507 .max(std::mem::size_of::<B>())
1508 .max(std::mem::size_of::<C>())
1509 .max(std::mem::size_of::<E>());
1510 let total = total_len(dst_dims);
1511
1512 let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
1514 build_plan_fused_small(dst_dims, &strides_list)
1515 } else {
1516 build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
1517 };
1518
1519 #[cfg(feature = "parallel")]
1520 {
1521 let total: usize = fused_dims.iter().product();
1522 if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1523 use crate::threading::SendPtr;
1524 let dst_send = SendPtr(dst_ptr);
1525 let a_send = SendPtr(a_ptr as *mut A);
1526 let b_send = SendPtr(b_ptr as *mut B);
1527 let c_send = SendPtr(c_ptr as *mut C);
1528 let e_send = SendPtr(e_ptr as *mut E);
1529
1530 let costs = compute_costs(&ordered_strides);
1531 let initial_offsets = vec![0isize; strides_list.len()];
1532 let nthreads = rayon::current_num_threads();
1533
1534 return mapreduce_threaded(
1535 &fused_dims,
1536 &plan.block,
1537 &ordered_strides,
1538 &initial_offsets,
1539 &costs,
1540 nthreads,
1541 0,
1542 1,
1543 &|dims, blocks, strides_list, offsets| {
1544 for_each_inner_block_with_offsets(
1545 dims,
1546 blocks,
1547 strides_list,
1548 offsets,
1549 |offsets, len, strides| {
1550 let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1551 let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1552 let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1553 let cp = unsafe { c_send.as_const().offset(offsets[3]) };
1554 let ep = unsafe { e_send.as_const().offset(offsets[4]) };
1555 unsafe {
1556 inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
1557 dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
1558 ep, strides[4], len, &f,
1559 )
1560 };
1561 Ok(())
1562 },
1563 )
1564 },
1565 );
1566 }
1567 }
1568
1569 let initial_offsets = vec![0isize; ordered_strides.len()];
1570 for_each_inner_block_preordered(
1571 &fused_dims,
1572 &plan.block,
1573 &ordered_strides,
1574 &initial_offsets,
1575 |offsets, len, strides| {
1576 let dp = unsafe { dst_ptr.offset(offsets[0]) };
1577 let ap = unsafe { a_ptr.offset(offsets[1]) };
1578 let bp = unsafe { b_ptr.offset(offsets[2]) };
1579 let cp = unsafe { c_ptr.offset(offsets[3]) };
1580 let ep = unsafe { e_ptr.offset(offsets[4]) };
1581 unsafe {
1582 inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
1583 dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], ep, strides[4],
1584 len, &f,
1585 )
1586 };
1587 Ok(())
1588 },
1589 )
1590}
1591
/// Tests that run with or without the `parallel` feature: stride combinations
/// of the two-operand inner loops, the error paths of `broadcast_mul_into`,
/// and the contiguous-range multiplication plan.
#[cfg(test)]
mod scalar_branch_tests {
    use super::*;
    use crate::StridedArray;
    use strided_view::Identity;

    /// Strides of a compact base-3 layout covering axes `first..first + count`
    /// of an `N`-axis tensor; every other axis gets stride 0.
    fn ternary_strides<const N: usize>(first: usize, count: usize) -> [isize; N] {
        std::array::from_fn(|axis| {
            if (first..first + count).contains(&axis) {
                3isize.pow((axis - first) as u32)
            } else {
                0
            }
        })
    }

    #[test]
    fn test_inner_loop_map2_stride_specializations() {
        let lhs = [2.0f64, 3.0, 5.0, 7.0, 11.0, 13.0];
        let rhs = [17.0f64, 19.0, 23.0, 29.0, 31.0, 37.0];
        let add: fn(f64, f64) -> f64 = |x, y| x + y;
        let mul: fn(f64, f64) -> f64 = |x, y| x * y;

        // Runs three elements with a unit destination stride and the given source strides.
        let run = |a_stride: isize, b_stride: isize, op: fn(f64, f64) -> f64| {
            let mut out = [0.0f64; 3];
            // SAFETY: with 3 elements and strides of at most 2 the furthest
            // source index touched is 4, inside the 6-element inputs; `out`
            // has exactly 3 slots.
            unsafe {
                inner_loop_map2::<f64, f64, f64, Identity, Identity>(
                    out.as_mut_ptr(),
                    1,
                    lhs.as_ptr(),
                    a_stride,
                    rhs.as_ptr(),
                    b_stride,
                    3,
                    &op,
                );
            }
            out
        };

        assert_eq!(run(1, 1, add), [19.0, 22.0, 28.0]);
        assert_eq!(run(1, 0, mul), [34.0, 51.0, 85.0]);
        assert_eq!(run(0, 1, mul), [34.0, 38.0, 46.0]);
        assert_eq!(run(0, 0, add), [19.0, 19.0, 19.0]);
        assert_eq!(run(2, 0, add), [19.0, 22.0, 28.0]);
        assert_eq!(run(0, 2, add), [19.0, 25.0, 33.0]);
    }

    #[test]
    fn test_inner_loop_mul2_stride_specializations() {
        let lhs = [2.0f64, 3.0, 5.0, 7.0, 11.0, 13.0];
        let rhs = [17.0f64, 19.0, 23.0, 29.0, 31.0, 37.0];

        // Runs three elements with a unit destination stride and the given source strides.
        let run = |a_stride, b_stride| {
            let mut out = [0.0f64; 3];
            // SAFETY: with 3 elements and strides of at most 2 the furthest
            // source index touched is 4, inside the 6-element inputs; `out`
            // has exactly 3 slots.
            unsafe {
                inner_loop_mul2::<f64, f64, f64>(
                    out.as_mut_ptr(),
                    1,
                    lhs.as_ptr(),
                    a_stride,
                    rhs.as_ptr(),
                    b_stride,
                    3,
                );
            }
            out
        };

        assert_eq!(run(1, 1), [34.0, 57.0, 115.0]);
        assert_eq!(run(0, 1), [34.0, 38.0, 46.0]);
        assert_eq!(run(0, 0), [34.0, 34.0, 34.0]);
        assert_eq!(run(0, 2), [34.0, 46.0, 62.0]);
    }

    #[test]
    fn test_broadcast_mul_into_error_branches_and_non_identity_ops() {
        let lhs = StridedArray::<f64>::row_major(&[2, 3]);
        let rhs = StridedArray::<f64>::row_major(&[2, 3]);
        let mut out = StridedArray::<f64>::row_major(&[2, 3]);
        let lhs_view = lhs.view();
        let rhs_view = rhs.view();

        // One axis given for a rank-2 operand.
        let outcome =
            broadcast_mul_into(&mut out.view_mut(), &lhs_view, &[0], &rhs_view, &[0, 1]);
        assert!(matches!(outcome, Err(StridedError::RankMismatch(2, 1))));

        // Axis 3 does not exist in a rank-2 destination.
        let outcome =
            broadcast_mul_into(&mut out.view_mut(), &lhs_view, &[0, 3], &rhs_view, &[0, 1]);
        assert!(matches!(
            outcome,
            Err(StridedError::InvalidAxis { axis: 3, rank: 2 })
        ));

        // Axis 0 is used twice.
        let outcome =
            broadcast_mul_into(&mut out.view_mut(), &lhs_view, &[0, 0], &rhs_view, &[0, 1]);
        assert!(matches!(
            outcome,
            Err(StridedError::InvalidAxis { axis: 0, rank: 2 })
        ));

        // Extent 4 can neither match nor broadcast to extent 3.
        let rhs_bad = StridedArray::<f64>::row_major(&[2, 4]);
        let outcome = broadcast_mul_into(
            &mut out.view_mut(),
            &lhs_view,
            &[0, 1],
            &rhs_bad.view(),
            &[0, 1],
        );
        assert!(matches!(outcome, Err(StridedError::ShapeMismatch(..))));

        // A conjugated operand; per the test name this targets the
        // non-identity-op path. Only success is checked.
        let lhs_conj = lhs.view().conj();
        broadcast_mul_into(&mut out.view_mut(), &lhs_conj, &[0, 1], &rhs_view, &[0, 1]).unwrap();
    }

    #[test]
    fn contiguous_mul_range_plan_available_without_parallel_feature() {
        let dims = [3usize; 16];
        // Destination is compact over all 16 axes; each operand covers 8 of them.
        let dst = ternary_strides::<16>(0, 16);
        let lhs = ternary_strides::<16>(0, 8);
        let rhs = ternary_strides::<16>(8, 8);

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert_eq!((plan.inner_len, plan.row_len, plan.fast_axis), (6561, 3, 0));
    }

    #[test]
    fn contiguous_range_mul_single_thread_computes_large_broadcast_mul() {
        let dims = [3usize; 10];
        // Destination is compact over all 10 axes; each operand covers 5 of them.
        let dst = ternary_strides::<10>(0, 10);
        let lhs = ternary_strides::<10>(0, 5);
        let rhs = ternary_strides::<10>(5, 5);
        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        let total = total_len(&dims);
        let block_len = plan.inner_len.max(1).saturating_mul(plan.row_len.max(1));
        let outer_groups = total.div_ceil(block_len);

        // Each operand spans 3^5 = 243 elements.
        let a = vec![2.0; 243];
        let b = vec![3.0; 243];
        let mut out = vec![0.0; total];

        let handled = run_contiguous_range_mul_single_thread::<f64, f64, f64>(
            out.as_mut_ptr(),
            &dims,
            a.as_ptr(),
            &lhs,
            b.as_ptr(),
            &rhs,
            &plan,
            total,
            block_len,
            outer_groups,
        );

        assert!(handled);
        assert!(out.iter().all(|&x| x == 6.0));
    }
}
1829
/// Tests compiled only with the `parallel` feature: contiguous-range plan
/// construction, compact axis-order detection, broadcast stride computation
/// and the outer cursor used by the contiguous multiplication path.
#[cfg(all(test, feature = "parallel"))]
mod tests {
    use super::*;

    /// Returns the strides of a dense layout in which `axis_order[0]` varies
    /// fastest, `axis_order[1]` next, and so on.
    fn compact_strides_for_axis_order<const N: usize>(
        dims: [usize; N],
        axis_order: [usize; N],
    ) -> [isize; N] {
        let mut layout = [0isize; N];
        let mut running = 1isize;
        for axis in axis_order {
            layout[axis] = running;
            running *= dims[axis] as isize;
        }
        layout
    }

    #[test]
    fn test_contiguous_mul_range_plan_pure_outer() {
        let dims = [7usize, 11];
        let dst = [1isize, 7];
        let lhs = [1isize, 0];
        let rhs = [0isize, 1];

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert_eq!((plan.inner_len, plan.row_len, plan.fast_axis), (7, 11, 0));
        assert_eq!((plan.a_fast_stride, plan.b_fast_stride), (1, 0));
        assert_eq!((plan.a_row_stride, plan.b_row_stride), (0, 1));
    }

    #[test]
    fn test_compact_axis_order_accepts_all_rank4_axis_permutations() {
        /// Enumerates every permutation of `order[fixed..]` by in-place swaps
        /// and checks that `compact_axis_order` recovers each complete ordering.
        fn check_all_orders(
            dims: [usize; 4],
            order: &mut [usize; 4],
            fixed: usize,
            visited: &mut usize,
        ) {
            if fixed == order.len() {
                let strides = compact_strides_for_axis_order(dims, *order);
                let recovered = compact_axis_order(&dims, &strides).unwrap();
                assert_eq!(&recovered[..], &order[..]);
                *visited += 1;
                return;
            }

            for candidate in fixed..order.len() {
                order.swap(fixed, candidate);
                check_all_orders(dims, order, fixed + 1, visited);
                // Undo the swap so the caller's ordering is restored.
                order.swap(fixed, candidate);
            }
        }

        let dims = [2usize, 3, 5, 7];
        let mut order = [0usize, 1, 2, 3];
        let mut visited = 0usize;
        check_all_orders(dims, &mut order, 0, &mut visited);

        // 4! orderings in total.
        assert_eq!(visited, 24);
    }

    #[test]
    fn test_compact_axis_order_rejects_strided_layout_with_holes() {
        let dims = [2usize, 3, 5];
        let strides = [1isize, 4, 2];

        assert!(compact_axis_order(&dims, &strides).is_none());
    }

    #[test]
    fn test_contiguous_mul_range_plan_uses_permuted_compact_output_for_unrelated_shape() {
        let dims = [2usize, 3, 5, 7, 11];
        let dst = compact_strides_for_axis_order(dims, [2usize, 0, 4, 1, 3]);
        let lhs = [5isize, 0, 1, 0, 10];
        let rhs = [0isize, 1, 0, 3, 0];

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert_eq!(&plan.axis_order[..], &[2, 0, 4, 1, 3]);
        assert_eq!((plan.inner_len, plan.row_len, plan.fast_axis), (110, 3, 2));
        assert_eq!((plan.a_fast_stride, plan.b_fast_stride), (1, 0));
        assert_eq!((plan.a_row_stride, plan.b_row_stride), (0, 1));
        assert!(transposed_scalar_tile_kind(&plan).is_none());
    }

    #[test]
    fn test_contiguous_mul_range_plan_compact_batched_outer() {
        let dims = [3usize, 5, 7, 11];
        let dst = [1isize, 3, 15, 105];
        let lhs = [1isize, 3, 0, 15];
        let rhs = [0isize, 0, 1, 7];

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert_eq!((plan.inner_len, plan.row_len, plan.fast_axis), (15, 7, 0));
        assert_eq!((plan.a_fast_stride, plan.b_fast_stride), (1, 0));
        assert_eq!((plan.a_row_stride, plan.b_row_stride), (0, 1));
    }

    #[test]
    fn test_contiguous_mul_range_plan_noncompact_batched_outer() {
        let dims = [5usize, 5, 7, 11];
        let dst = [1isize, 5, 25, 175];
        let lhs = [5isize, 1, 0, 25];
        let rhs = [0isize, 0, 1, 7];

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert_eq!((plan.inner_len, plan.row_len, plan.fast_axis), (5, 5, 0));
        assert_eq!((plan.a_fast_stride, plan.b_fast_stride), (5, 0));
        assert_eq!((plan.a_row_stride, plan.b_row_stride), (1, 0));
    }

    #[test]
    fn test_contiguous_mul_range_plan_noncompact_row_major_output() {
        let dims = [5usize, 5, 7, 11];
        let dst = [5isize, 1, 25, 175];
        let lhs = [5isize, 1, 0, 25];
        let rhs = [0isize, 0, 1, 7];

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert_eq!((plan.inner_len, plan.row_len, plan.fast_axis), (25, 7, 1));
        assert_eq!((plan.a_fast_stride, plan.b_fast_stride), (1, 0));
        assert_eq!((plan.a_row_stride, plan.b_row_stride), (0, 1));
        assert!(transposed_scalar_tile_kind(&plan).is_none());
    }

    #[test]
    fn test_broadcast_strides_for_axes_batched_outer() {
        let target_dims = [3usize, 5, 7, 11];

        // Left operand occupies target axes 0, 1 and 3.
        let lhs = broadcast_strides_for_axes(&[3, 5, 11], &[3, 1, 15], &target_dims, &[0, 1, 3])
            .unwrap();
        // Right operand occupies target axes 2 and 3.
        let rhs = broadcast_strides_for_axes(&[7, 11], &[1, 7], &target_dims, &[2, 3]).unwrap();

        assert_eq!(&lhs[..], &[3, 1, 0, 15]);
        assert_eq!(&rhs[..], &[0, 0, 1, 7]);
    }

    #[test]
    fn test_broadcast_strides_for_axes_uses_zero_stride_for_size_one_source_dim() {
        let target_dims = [8usize, 4];
        let source_dims = [1usize, 4];
        let source_strides = [1isize, 1];

        let broadcast =
            broadcast_strides_for_axes(&source_dims, &source_strides, &target_dims, &[0, 1])
                .unwrap();

        // The extent-1 axis is stretched over 8 entries, so its stride becomes 0.
        assert_eq!(&broadcast[..], &[0, 1]);
    }

    #[test]
    fn test_transposed_scalar_tile_kind_detects_noncompact_rhs_scalar() {
        let dims = [5usize, 5, 7, 11];
        let dst = [1isize, 5, 25, 175];
        let lhs = [5isize, 1, 0, 25];
        let rhs = [0isize, 0, 1, 7];

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert!(matches!(
            transposed_scalar_tile_kind(&plan),
            Some(TransposedScalarTileKind::RhsScalar)
        ));
    }

    #[test]
    fn test_contiguous_mul_outer_cursor_matches_linear_offsets() {
        let dims = [16usize, 16, 64, 64];
        let dst = [1isize, 16, 256, 16_384];
        let lhs = [16isize, 1, 0, 256];
        let rhs = [0isize, 0, 1, 64];
        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        let first_group = 13;
        let mut cursor = ContiguousMulOuterCursor::new(&dims, &lhs, &rhs, &plan, first_group);
        let block_len = plan.inner_len * plan.row_len;

        // After each `advance` the cursor offsets must equal the offsets
        // computed directly from the linear index of the group's first element.
        for group in first_group..80 {
            let linear = group * block_len;
            let expected_a =
                strided_offset_for_contiguous_linear_index(&dims, &lhs, &plan.axis_order, linear);
            let expected_b =
                strided_offset_for_contiguous_linear_index(&dims, &rhs, &plan.axis_order, linear);
            assert_eq!(cursor.a_offset, expected_a);
            assert_eq!(cursor.b_offset, expected_b);
            cursor.advance();
        }
    }
}