Skip to main content

strided_kernel/
map_view.rs

1//! Map operations on dynamic-rank strided views.
2//!
3//! These are the canonical view-based map functions, equivalent to Julia's `Base.map!`.
4
5use crate::kernel::{
6    build_plan_fused, build_plan_fused_small, ensure_same_shape, for_each_inner_block_preordered,
7    sequential_contiguous_layout, total_len, SMALL_TENSOR_THRESHOLD,
8};
9use crate::maybe_sync::{MaybeSendSync, MaybeSync};
10use crate::simd;
11use crate::view::{StridedView, StridedViewMut};
12use crate::{Result, StridedError};
13use std::ops::Mul;
14use strided_view::ElementOp;
15
16#[cfg(feature = "parallel")]
17use crate::fuse::compute_costs;
18#[cfg(feature = "parallel")]
19use crate::threading::{for_each_inner_block_with_offsets, mapreduce_threaded, MINTHREADLENGTH};
20#[cfg(feature = "parallel")]
21use smallvec::SmallVec;
22
// Per-axis scratch list used by the planning code in this file: a `SmallVec`
// with inline room for 8 entries when the `parallel` feature is enabled,
// otherwise a plain `Vec`.
#[cfg(feature = "parallel")]
type AxisVec<T> = SmallVec<[T; 8]>;
#[cfg(not(feature = "parallel"))]
type AxisVec<T> = Vec<T>;
27
/// Element-count cut-off for the contiguous-range multiply path:
/// `try_contiguous_range_mul` declines (returns `false`) for non-empty inputs
/// whose total length is at or below this value (`1 << 15` = 32768).
const CONTIGUOUS_RANGE_MIN_LEN: usize = 1 << 15;
29
30// ============================================================================
31// Stride-specialized inner loop helpers
32//
33// When all inner strides are 1 (contiguous in the innermost dimension),
34// we use slice-based iteration so LLVM can auto-vectorize effectively.
35// This is the Rust equivalent of Julia's @simd on the innermost loop.
36// ============================================================================
37
38/// Unary inner loop: `dest[i] = f(Op::apply(src[i]))` for `len` elements.
39#[inline(always)]
40unsafe fn inner_loop_map1<D: Copy, A: Copy, Op: ElementOp<A>>(
41    dp: *mut D,
42    ds: isize,
43    sp: *const A,
44    ss: isize,
45    len: usize,
46    f: &impl Fn(A) -> D,
47) {
48    if ds == 1 && ss == 1 {
49        let src = std::slice::from_raw_parts(sp, len);
50        let dst = std::slice::from_raw_parts_mut(dp, len);
51        simd::dispatch_if_large(len, || {
52            for (d, s) in dst.iter_mut().zip(src.iter()) {
53                *d = f(Op::apply(*s));
54            }
55        });
56    } else {
57        let mut dp = dp;
58        let mut sp = sp;
59        for _ in 0..len {
60            *dp = f(Op::apply(*sp));
61            dp = dp.offset(ds);
62            sp = sp.offset(ss);
63        }
64    }
65}
66
67/// Binary inner loop: `dest[i] = f(OpA::apply(a[i]), OpB::apply(b[i]))`.
68#[inline(always)]
69unsafe fn inner_loop_map2<D: Copy, A: Copy, B: Copy, OpA: ElementOp<A>, OpB: ElementOp<B>>(
70    dp: *mut D,
71    ds: isize,
72    ap: *const A,
73    a_s: isize,
74    bp: *const B,
75    b_s: isize,
76    len: usize,
77    f: &impl Fn(A, B) -> D,
78) {
79    if ds == 1 && a_s == 1 && b_s == 1 {
80        let src_a = std::slice::from_raw_parts(ap, len);
81        let src_b = std::slice::from_raw_parts(bp, len);
82        let dst = std::slice::from_raw_parts_mut(dp, len);
83        simd::dispatch_if_large(len, || {
84            for i in 0..len {
85                dst[i] = f(OpA::apply(src_a[i]), OpB::apply(src_b[i]));
86            }
87        });
88    } else if ds == 1 && a_s == 1 && b_s == 0 {
89        let src_a = std::slice::from_raw_parts(ap, len);
90        let b = OpB::apply(*bp);
91        let dst = std::slice::from_raw_parts_mut(dp, len);
92        simd::dispatch_if_large(len, || {
93            for i in 0..len {
94                dst[i] = f(OpA::apply(src_a[i]), b);
95            }
96        });
97    } else if ds == 1 && a_s == 0 && b_s == 1 {
98        let a = OpA::apply(*ap);
99        let src_b = std::slice::from_raw_parts(bp, len);
100        let dst = std::slice::from_raw_parts_mut(dp, len);
101        simd::dispatch_if_large(len, || {
102            for i in 0..len {
103                dst[i] = f(a, OpB::apply(src_b[i]));
104            }
105        });
106    } else if ds == 1 && a_s == 0 && b_s == 0 {
107        let a = OpA::apply(*ap);
108        let b = OpB::apply(*bp);
109        let dst = std::slice::from_raw_parts_mut(dp, len);
110        simd::dispatch_if_large(len, || {
111            for d in dst.iter_mut() {
112                *d = f(a, b);
113            }
114        });
115    } else if ds == 1 && b_s == 0 {
116        let b = OpB::apply(*bp);
117        let dst = std::slice::from_raw_parts_mut(dp, len);
118        let mut ap = ap;
119        simd::dispatch_if_large(len, || {
120            for d in dst.iter_mut() {
121                *d = f(OpA::apply(*ap), b);
122                ap = ap.offset(a_s);
123            }
124        });
125    } else if ds == 1 && a_s == 0 {
126        let a = OpA::apply(*ap);
127        let dst = std::slice::from_raw_parts_mut(dp, len);
128        let mut bp = bp;
129        simd::dispatch_if_large(len, || {
130            for d in dst.iter_mut() {
131                *d = f(a, OpB::apply(*bp));
132                bp = bp.offset(b_s);
133            }
134        });
135    } else {
136        let mut dp = dp;
137        let mut ap = ap;
138        let mut bp = bp;
139        for _ in 0..len {
140            *dp = f(OpA::apply(*ap), OpB::apply(*bp));
141            dp = dp.offset(ds);
142            ap = ap.offset(a_s);
143            bp = bp.offset(b_s);
144        }
145    }
146}
147
148/// Binary multiplication inner loop for identity element ops.
149#[inline(always)]
150unsafe fn inner_loop_mul2<
151    D: Copy + 'static,
152    A: Copy + Mul<B, Output = D> + 'static,
153    B: Copy + 'static,
154>(
155    dp: *mut D,
156    ds: isize,
157    ap: *const A,
158    a_s: isize,
159    bp: *const B,
160    b_s: isize,
161    len: usize,
162) {
163    if ds == 1 && a_s == 1 && b_s == 1 {
164        let src_a = std::slice::from_raw_parts(ap, len);
165        let src_b = std::slice::from_raw_parts(bp, len);
166        let dst = std::slice::from_raw_parts_mut(dp, len);
167        if len >= 64 && simd::try_mul_contiguous(dst, src_a, src_b) {
168            return;
169        }
170        for i in 0..len {
171            dst[i] = src_a[i] * src_b[i];
172        }
173    } else if ds == 1 && a_s == 1 && b_s == 0 {
174        let src_a = std::slice::from_raw_parts(ap, len);
175        let b = *bp;
176        let dst = std::slice::from_raw_parts_mut(dp, len);
177        for i in 0..len {
178            dst[i] = src_a[i] * b;
179        }
180    } else if ds == 1 && a_s == 0 && b_s == 1 {
181        let a = *ap;
182        let src_b = std::slice::from_raw_parts(bp, len);
183        let dst = std::slice::from_raw_parts_mut(dp, len);
184        for i in 0..len {
185            dst[i] = a * src_b[i];
186        }
187    } else if ds == 1 && a_s == 0 && b_s == 0 {
188        let a = *ap;
189        let b = *bp;
190        let dst = std::slice::from_raw_parts_mut(dp, len);
191        for d in dst.iter_mut() {
192            *d = a * b;
193        }
194    } else if ds == 1 && b_s == 0 {
195        let b = *bp;
196        let dst = std::slice::from_raw_parts_mut(dp, len);
197        let mut ap = ap;
198        for d in dst.iter_mut() {
199            *d = *ap * b;
200            ap = ap.offset(a_s);
201        }
202    } else if ds == 1 && a_s == 0 {
203        let a = *ap;
204        let dst = std::slice::from_raw_parts_mut(dp, len);
205        let mut bp = bp;
206        for d in dst.iter_mut() {
207            *d = a * *bp;
208            bp = bp.offset(b_s);
209        }
210    } else {
211        let mut dp = dp;
212        let mut ap = ap;
213        let mut bp = bp;
214        for _ in 0..len {
215            *dp = *ap * *bp;
216            dp = dp.offset(ds);
217            ap = ap.offset(a_s);
218            bp = bp.offset(b_s);
219        }
220    }
221}
222
/// Iteration plan produced by `contiguous_mul_range_plan` for a multiply whose
/// destination strides pass `compact_axis_order`.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ContiguousMulRangePlan {
    /// Axis visiting order returned by `compact_axis_order` for the
    /// destination: axes with extent > 1 by ascending destination stride,
    /// followed by the remaining axes.
    axis_order: AxisVec<usize>,
    /// Product of the extents of the leading axes that were fused into one
    /// inner run.
    inner_len: usize,
    /// Number of leading `axis_order` entries covered by the inner run.
    inner_axis_count: usize,
    /// Extent of the first axis with extent > 1 after the inner run, or 1 when
    /// there is none.
    row_len: usize,
    /// Position in `axis_order` just past the row axis (`axis_order.len()`
    /// when there is no row axis).
    outer_axis_start: usize,
    /// First axis of the inner run.
    fast_axis: usize,
    /// Stride of `a` along `fast_axis`.
    a_fast_stride: isize,
    /// Stride of `b` along `fast_axis`.
    b_fast_stride: isize,
    /// Stride of `a` along the row axis, or 0 when there is no row axis.
    a_row_stride: isize,
    /// Stride of `b` along the row axis, or 0 when there is no row axis.
    b_row_stride: isize,
}
236
/// Stride pattern recognized by `transposed_scalar_tile_kind`, naming which
/// operand is the one with all-zero strides.
#[cfg(feature = "parallel")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TransposedScalarTileKind {
    /// `a` has fast stride `row_len` and row stride 1; both `b` strides are 0.
    RhsScalar,
    /// `b` has fast stride `row_len` and row stride 1; both `a` strides are 0.
    LhsScalar,
}
243
/// Classifies a plan whose one operand has `(fast stride, row stride)` equal
/// to `(row_len, 1)` while the other operand has both strides equal to 0.
///
/// Returns `None` for any other stride pattern, or when `row_len` does not fit
/// in an `isize`.
#[cfg(feature = "parallel")]
fn transposed_scalar_tile_kind(plan: &ContiguousMulRangePlan) -> Option<TransposedScalarTileKind> {
    let rows = isize::try_from(plan.row_len).ok()?;

    // (fast stride, row stride) per operand, compared against the two shapes
    // of interest.
    let lhs = (plan.a_fast_stride, plan.a_row_stride);
    let rhs = (plan.b_fast_stride, plan.b_row_stride);
    let transposed = (rows, 1);
    let scalar = (0, 0);

    if lhs == transposed && rhs == scalar {
        Some(TransposedScalarTileKind::RhsScalar)
    } else if rhs == transposed && lhs == scalar {
        Some(TransposedScalarTileKind::LhsScalar)
    } else {
        None
    }
}
265
266fn compact_axis_order(dims: &[usize], strides: &[isize]) -> Option<AxisVec<usize>> {
267    if dims.len() != strides.len() {
268        return None;
269    }
270
271    let mut active = AxisVec::<usize>::new();
272    let mut inactive = AxisVec::<usize>::new();
273    for (axis, (&dim, &stride)) in dims.iter().zip(strides.iter()).enumerate() {
274        if stride < 0 {
275            return None;
276        }
277        if dim > 1 {
278            active.push(axis);
279        } else {
280            inactive.push(axis);
281        }
282    }
283
284    active.sort_by(|&lhs, &rhs| strides[lhs].cmp(&strides[rhs]).then_with(|| lhs.cmp(&rhs)));
285
286    let mut expected = 1isize;
287    for &axis in &active {
288        if strides[axis] != expected {
289            return None;
290        }
291        expected = expected.saturating_mul(dims[axis] as isize);
292    }
293
294    active.extend(inactive);
295    Some(active)
296}
297
/// Reports whether an axis of extent `dim` and stride `prev_stride` can be
/// merged with the following axis of stride `next_stride` into one linear run.
///
/// That is the case when the axis is trivial (`dim <= 1`), when both strides
/// are 0, or when `next_stride == prev_stride * dim`. The product is computed
/// with overflow checks: if `dim` does not fit in an `isize` or the product
/// overflows, the axes are reported as not fusible instead of panicking
/// (debug builds) or comparing against a wrapped value (release builds).
fn can_fuse_contiguous_range_axis(dim: usize, prev_stride: isize, next_stride: isize) -> bool {
    if dim <= 1 || (prev_stride == 0 && next_stride == 0) {
        return true;
    }

    isize::try_from(dim)
        .ok()
        .and_then(|extent| prev_stride.checked_mul(extent))
        == Some(next_stride)
}
301
302fn contiguous_mul_range_plan(
303    dims: &[usize],
304    dst_strides: &[isize],
305    a_strides: &[isize],
306    b_strides: &[isize],
307) -> Option<ContiguousMulRangePlan> {
308    let axis_order = compact_axis_order(dims, dst_strides)?;
309    if dims.is_empty() {
310        return Some(ContiguousMulRangePlan {
311            axis_order,
312            inner_len: 1,
313            inner_axis_count: 0,
314            row_len: 1,
315            outer_axis_start: 0,
316            fast_axis: 0,
317            a_fast_stride: 0,
318            b_fast_stride: 0,
319            a_row_stride: 0,
320            b_row_stride: 0,
321        });
322    }
323
324    let first_pos = axis_order
325        .iter()
326        .position(|&axis| dims[axis] > 1)
327        .unwrap_or(0);
328    let first_axis = axis_order[first_pos];
329    let mut inner_len = dims[first_axis].max(1);
330    let mut inner_axis_count = first_pos + 1;
331    let mut prev_axis = first_axis;
332
333    for &axis in axis_order.iter().skip(first_pos + 1) {
334        if can_fuse_contiguous_range_axis(
335            dims[prev_axis],
336            dst_strides[prev_axis],
337            dst_strides[axis],
338        ) && can_fuse_contiguous_range_axis(
339            dims[prev_axis],
340            a_strides[prev_axis],
341            a_strides[axis],
342        ) && can_fuse_contiguous_range_axis(
343            dims[prev_axis],
344            b_strides[prev_axis],
345            b_strides[axis],
346        ) {
347            inner_len = inner_len.checked_mul(dims[axis].max(1))?;
348            inner_axis_count += 1;
349            prev_axis = axis;
350        } else {
351            break;
352        }
353    }
354
355    let row = axis_order
356        .iter()
357        .enumerate()
358        .skip(inner_axis_count)
359        .find(|&(_, &axis)| dims[axis] > 1);
360    let (row_len, outer_axis_start, a_row_stride, b_row_stride) =
361        if let Some((row_pos, &row_axis)) = row {
362            (
363                dims[row_axis],
364                row_pos + 1,
365                a_strides[row_axis],
366                b_strides[row_axis],
367            )
368        } else {
369            (1, axis_order.len(), 0, 0)
370        };
371
372    Some(ContiguousMulRangePlan {
373        axis_order,
374        inner_len,
375        inner_axis_count,
376        row_len,
377        outer_axis_start,
378        fast_axis: first_axis,
379        a_fast_stride: a_strides[first_axis],
380        b_fast_stride: b_strides[first_axis],
381        a_row_stride,
382        b_row_stride,
383    })
384}
385
386struct ContiguousMulOuterCursor<'a> {
387    dims: &'a [usize],
388    a_strides: &'a [isize],
389    b_strides: &'a [isize],
390    axes: AxisVec<usize>,
391    coords: AxisVec<usize>,
392    a_offset: isize,
393    b_offset: isize,
394}
395
396impl<'a> ContiguousMulOuterCursor<'a> {
397    fn new(
398        dims: &'a [usize],
399        a_strides: &'a [isize],
400        b_strides: &'a [isize],
401        plan: &ContiguousMulRangePlan,
402        outer_group: usize,
403    ) -> Self {
404        let axes: AxisVec<usize> = plan
405            .axis_order
406            .iter()
407            .skip(plan.outer_axis_start)
408            .copied()
409            .collect();
410        let mut coords = AxisVec::<usize>::with_capacity(axes.len());
411        let mut rem = outer_group;
412        let mut a_offset = 0isize;
413        let mut b_offset = 0isize;
414
415        for &axis in &axes {
416            let dim = dims[axis].max(1);
417            let coord = rem % dim;
418            rem /= dim;
419            coords.push(coord);
420            a_offset += coord as isize * a_strides[axis];
421            b_offset += coord as isize * b_strides[axis];
422        }
423
424        Self {
425            dims,
426            a_strides,
427            b_strides,
428            axes,
429            coords,
430            a_offset,
431            b_offset,
432        }
433    }
434
435    fn advance(&mut self) {
436        for (i, &axis) in self.axes.iter().enumerate() {
437            let dim = self.dims[axis].max(1);
438            if dim <= 1 {
439                continue;
440            }
441
442            self.coords[i] += 1;
443            self.a_offset += self.a_strides[axis];
444            self.b_offset += self.b_strides[axis];
445
446            if self.coords[i] < dim {
447                break;
448            }
449
450            self.coords[i] = 0;
451            self.a_offset -= dim as isize * self.a_strides[axis];
452            self.b_offset -= dim as isize * self.b_strides[axis];
453        }
454    }
455}
456
457#[inline(always)]
458unsafe fn run_contiguous_mul_row_block<
459    D: Copy + 'static,
460    A: Copy + Mul<B, Output = D> + 'static,
461    B: Copy + 'static,
462>(
463    dst_ptr: *mut D,
464    a_ptr: *const A,
465    b_ptr: *const B,
466    plan: &ContiguousMulRangePlan,
467    base_index: usize,
468    total: usize,
469    base_a_offset: isize,
470    base_b_offset: isize,
471) {
472    let inner_len = plan.inner_len.max(1);
473    let row_len = plan.row_len.max(1);
474    #[cfg(feature = "parallel")]
475    let block_len = inner_len.saturating_mul(row_len);
476
477    #[cfg(feature = "parallel")]
478    if total.saturating_sub(base_index) >= block_len {
479        match transposed_scalar_tile_kind(plan) {
480            Some(TransposedScalarTileKind::RhsScalar) => {
481                if simd::try_mul_transposed_scalar_rhs_2d::<D, A, B>(
482                    dst_ptr.add(base_index),
483                    a_ptr.offset(base_a_offset),
484                    b_ptr.offset(base_b_offset),
485                    inner_len,
486                    row_len,
487                    plan.a_fast_stride,
488                    plan.a_row_stride,
489                ) {
490                    return;
491                }
492            }
493            Some(TransposedScalarTileKind::LhsScalar) => {
494                if simd::try_mul_transposed_scalar_lhs_2d::<D, A, B>(
495                    dst_ptr.add(base_index),
496                    a_ptr.offset(base_a_offset),
497                    b_ptr.offset(base_b_offset),
498                    inner_len,
499                    row_len,
500                    plan.b_fast_stride,
501                    plan.b_row_stride,
502                ) {
503                    return;
504                }
505            }
506            None => {}
507        }
508    }
509
510    let mut index = base_index;
511    let mut a_offset = base_a_offset;
512    let mut b_offset = base_b_offset;
513
514    for _ in 0..row_len {
515        if index >= total {
516            break;
517        }
518        let len = inner_len.min(total - index);
519        inner_loop_mul2::<D, A, B>(
520            dst_ptr.add(index),
521            1,
522            a_ptr.offset(a_offset),
523            plan.a_fast_stride,
524            b_ptr.offset(b_offset),
525            plan.b_fast_stride,
526            len,
527        );
528        index += inner_len;
529        a_offset += plan.a_row_stride;
530        b_offset += plan.b_row_stride;
531    }
532}
533
/// Converts a flat index (first entry of `axis_order` varying fastest) into an
/// element offset under `strides`.
///
/// Returns 0 as soon as an axis of extent 0 is encountered.
#[cfg(feature = "parallel")]
fn strided_offset_for_contiguous_linear_index(
    dims: &[usize],
    strides: &[isize],
    axis_order: &[usize],
    index: usize,
) -> isize {
    let mut remaining = index;
    let mut offset = 0isize;
    for &axis in axis_order {
        let extent = dims[axis];
        if extent == 0 {
            return 0;
        }
        // Peel off this axis' coordinate and keep the rest for later axes.
        offset += (remaining % extent) as isize * strides[axis];
        remaining /= extent;
    }
    offset
}
553
554fn try_contiguous_range_mul<
555    D: Copy + MaybeSendSync + 'static,
556    A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
557    B: Copy + MaybeSendSync + 'static,
558>(
559    dst_ptr: *mut D,
560    dims: &[usize],
561    dst_strides: &[isize],
562    a_ptr: *const A,
563    a_strides: &[isize],
564    b_ptr: *const B,
565    b_strides: &[isize],
566) -> bool {
567    let total = total_len(dims);
568    if total == 0 {
569        return true;
570    }
571    if total <= CONTIGUOUS_RANGE_MIN_LEN {
572        return false;
573    }
574
575    let Some(plan) = contiguous_mul_range_plan(dims, dst_strides, a_strides, b_strides) else {
576        return false;
577    };
578
579    let inner_len = plan.inner_len.max(1);
580    let row_len = plan.row_len.max(1);
581    let block_len = inner_len.saturating_mul(row_len).max(1);
582    let outer_groups = total.div_ceil(block_len);
583
584    #[cfg(feature = "parallel")]
585    {
586        let nthreads = rayon::current_num_threads();
587        if nthreads > 1 {
588            use crate::threading::SendPtr;
589            use rayon::prelude::*;
590
591            let dst = SendPtr(dst_ptr);
592            let a = SendPtr(a_ptr as *mut A);
593            let b = SendPtr(b_ptr as *mut B);
594
595            if outer_groups < nthreads {
596                let chunk_len = total.div_ceil(nthreads);
597                let nchunks = total.div_ceil(chunk_len);
598
599                (0..nchunks).into_par_iter().for_each(|chunk| {
600                    let start = chunk * chunk_len;
601                    let end = (start + chunk_len).min(total);
602                    let mut index = start;
603
604                    while index < end {
605                        let in_inner = index % inner_len;
606                        let len = (inner_len - in_inner).min(end - index);
607                        let a_offset = strided_offset_for_contiguous_linear_index(
608                            dims,
609                            a_strides,
610                            &plan.axis_order,
611                            index,
612                        );
613                        let b_offset = strided_offset_for_contiguous_linear_index(
614                            dims,
615                            b_strides,
616                            &plan.axis_order,
617                            index,
618                        );
619
620                        unsafe {
621                            inner_loop_mul2::<D, A, B>(
622                                dst.as_ptr().add(index),
623                                1,
624                                a.as_const().offset(a_offset),
625                                plan.a_fast_stride,
626                                b.as_const().offset(b_offset),
627                                plan.b_fast_stride,
628                                len,
629                            );
630                        }
631                        index += len;
632                    }
633                });
634
635                return true;
636            }
637
638            let groups_per_chunk = outer_groups.div_ceil(nthreads);
639            let nchunks = outer_groups.div_ceil(groups_per_chunk);
640
641            (0..nchunks).into_par_iter().for_each(|chunk| {
642                let group_start = chunk * groups_per_chunk;
643                let group_end = (group_start + groups_per_chunk).min(outer_groups);
644                let mut cursor =
645                    ContiguousMulOuterCursor::new(dims, a_strides, b_strides, &plan, group_start);
646
647                for group in group_start..group_end {
648                    let index = group * block_len;
649                    unsafe {
650                        run_contiguous_mul_row_block::<D, A, B>(
651                            dst.as_ptr(),
652                            a.as_const(),
653                            b.as_const(),
654                            &plan,
655                            index,
656                            total,
657                            cursor.a_offset,
658                            cursor.b_offset,
659                        );
660                    }
661                    cursor.advance();
662                }
663            });
664
665            true
666        } else {
667            run_contiguous_range_mul_single_thread(
668                dst_ptr,
669                dims,
670                a_ptr,
671                a_strides,
672                b_ptr,
673                b_strides,
674                &plan,
675                total,
676                block_len,
677                outer_groups,
678            )
679        }
680    }
681
682    #[cfg(not(feature = "parallel"))]
683    {
684        run_contiguous_range_mul_single_thread(
685            dst_ptr,
686            dims,
687            a_ptr,
688            a_strides,
689            b_ptr,
690            b_strides,
691            &plan,
692            total,
693            block_len,
694            outer_groups,
695        )
696    }
697}
698
699fn run_contiguous_range_mul_single_thread<
700    D: Copy + MaybeSendSync + 'static,
701    A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
702    B: Copy + MaybeSendSync + 'static,
703>(
704    dst_ptr: *mut D,
705    dims: &[usize],
706    a_ptr: *const A,
707    a_strides: &[isize],
708    b_ptr: *const B,
709    b_strides: &[isize],
710    plan: &ContiguousMulRangePlan,
711    total: usize,
712    block_len: usize,
713    outer_groups: usize,
714) -> bool {
715    let mut cursor = ContiguousMulOuterCursor::new(dims, a_strides, b_strides, plan, 0);
716    for group in 0..outer_groups {
717        let index = group * block_len;
718        unsafe {
719            run_contiguous_mul_row_block::<D, A, B>(
720                dst_ptr,
721                a_ptr,
722                b_ptr,
723                plan,
724                index,
725                total,
726                cursor.a_offset,
727                cursor.b_offset,
728            );
729        }
730        cursor.advance();
731    }
732    true
733}
734
735/// Ternary inner loop: `dest[i] = f(a[i], b[i], c[i])`.
736#[inline(always)]
737unsafe fn inner_loop_map3<
738    D: Copy,
739    A: Copy,
740    B: Copy,
741    C: Copy,
742    OpA: ElementOp<A>,
743    OpB: ElementOp<B>,
744    OpC: ElementOp<C>,
745>(
746    dp: *mut D,
747    ds: isize,
748    ap: *const A,
749    a_s: isize,
750    bp: *const B,
751    b_s: isize,
752    cp: *const C,
753    c_s: isize,
754    len: usize,
755    f: &impl Fn(A, B, C) -> D,
756) {
757    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 {
758        let src_a = std::slice::from_raw_parts(ap, len);
759        let src_b = std::slice::from_raw_parts(bp, len);
760        let src_c = std::slice::from_raw_parts(cp, len);
761        let dst = std::slice::from_raw_parts_mut(dp, len);
762        simd::dispatch_if_large(len, || {
763            for i in 0..len {
764                dst[i] = f(
765                    OpA::apply(src_a[i]),
766                    OpB::apply(src_b[i]),
767                    OpC::apply(src_c[i]),
768                );
769            }
770        });
771    } else {
772        let mut dp = dp;
773        let mut ap = ap;
774        let mut bp = bp;
775        let mut cp = cp;
776        for _ in 0..len {
777            *dp = f(OpA::apply(*ap), OpB::apply(*bp), OpC::apply(*cp));
778            dp = dp.offset(ds);
779            ap = ap.offset(a_s);
780            bp = bp.offset(b_s);
781            cp = cp.offset(c_s);
782        }
783    }
784}
785
786/// Quaternary inner loop: `dest[i] = f(a[i], b[i], c[i], e[i])`.
787#[inline(always)]
788unsafe fn inner_loop_map4<
789    D: Copy,
790    A: Copy,
791    B: Copy,
792    C: Copy,
793    E: Copy,
794    OpA: ElementOp<A>,
795    OpB: ElementOp<B>,
796    OpC: ElementOp<C>,
797    OpE: ElementOp<E>,
798>(
799    dp: *mut D,
800    ds: isize,
801    ap: *const A,
802    a_s: isize,
803    bp: *const B,
804    b_s: isize,
805    cp: *const C,
806    c_s: isize,
807    ep: *const E,
808    e_s: isize,
809    len: usize,
810    f: &impl Fn(A, B, C, E) -> D,
811) {
812    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 && e_s == 1 {
813        let src_a = std::slice::from_raw_parts(ap, len);
814        let src_b = std::slice::from_raw_parts(bp, len);
815        let src_c = std::slice::from_raw_parts(cp, len);
816        let src_e = std::slice::from_raw_parts(ep, len);
817        let dst = std::slice::from_raw_parts_mut(dp, len);
818        simd::dispatch_if_large(len, || {
819            for i in 0..len {
820                dst[i] = f(
821                    OpA::apply(src_a[i]),
822                    OpB::apply(src_b[i]),
823                    OpC::apply(src_c[i]),
824                    OpE::apply(src_e[i]),
825                );
826            }
827        });
828    } else {
829        let mut dp = dp;
830        let mut ap = ap;
831        let mut bp = bp;
832        let mut cp = cp;
833        let mut ep = ep;
834        for _ in 0..len {
835            *dp = f(
836                OpA::apply(*ap),
837                OpB::apply(*bp),
838                OpC::apply(*cp),
839                OpE::apply(*ep),
840            );
841            dp = dp.offset(ds);
842            ap = ap.offset(a_s);
843            bp = bp.offset(b_s);
844            cp = cp.offset(c_s);
845            ep = ep.offset(e_s);
846        }
847    }
848}
849
850/// Apply a function element-wise from source to destination.
851///
852/// The element operation `Op` is applied lazily when reading from `src`.
853/// Source and destination may have different element types.
854pub fn map_into<D: Copy + MaybeSendSync, A: Copy + MaybeSendSync, Op: ElementOp<A>>(
855    dest: &mut StridedViewMut<D>,
856    src: &StridedView<A, Op>,
857    f: impl Fn(A) -> D + MaybeSync,
858) -> Result<()> {
859    ensure_same_shape(dest.dims(), src.dims())?;
860
861    let dst_ptr = dest.as_mut_ptr();
862    let src_ptr = src.ptr();
863    let dst_dims = dest.dims();
864    let dst_strides = dest.strides();
865    let src_strides = src.strides();
866
867    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
868        let len = total_len(dst_dims);
869        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
870        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
871        simd::dispatch_if_large(len, || {
872            for i in 0..len {
873                dst[i] = f(Op::apply(src[i]));
874            }
875        });
876        return Ok(());
877    }
878
879    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
880    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<A>());
881    let total = total_len(dst_dims);
882
883    // Small tensor fast path: skip compute_order and compute_block_sizes
884    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
885        build_plan_fused_small(dst_dims, &strides_list)
886    } else {
887        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
888    };
889
890    #[cfg(feature = "parallel")]
891    {
892        let total: usize = fused_dims.iter().product();
893        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
894            use crate::threading::SendPtr;
895            let dst_send = SendPtr(dst_ptr);
896            let src_send = SendPtr(src_ptr as *mut A);
897
898            let costs = compute_costs(&ordered_strides);
899            let initial_offsets = vec![0isize; strides_list.len()];
900            let nthreads = rayon::current_num_threads();
901
902            return mapreduce_threaded(
903                &fused_dims,
904                &plan.block,
905                &ordered_strides,
906                &initial_offsets,
907                &costs,
908                nthreads,
909                0,
910                1,
911                &|dims, blocks, strides_list, offsets| {
912                    for_each_inner_block_with_offsets(
913                        dims,
914                        blocks,
915                        strides_list,
916                        offsets,
917                        |offsets, len, strides| {
918                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
919                            let sp = unsafe { src_send.as_const().offset(offsets[1]) };
920                            unsafe {
921                                inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f)
922                            };
923                            Ok(())
924                        },
925                    )
926                },
927            );
928        }
929    }
930
931    let initial_offsets = vec![0isize; ordered_strides.len()];
932    for_each_inner_block_preordered(
933        &fused_dims,
934        &plan.block,
935        &ordered_strides,
936        &initial_offsets,
937        |offsets, len, strides| {
938            let dp = unsafe { dst_ptr.offset(offsets[0]) };
939            let sp = unsafe { src_ptr.offset(offsets[1]) };
940            unsafe { inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f) };
941            Ok(())
942        },
943    )
944}
945
946/// Binary element-wise operation: `dest[i] = f(a[i], b[i])`.
947///
948/// Source operands `a` and `b` may have different element types from each other
949/// and from `dest`. The closure `f` handles per-element type conversion.
950pub fn zip_map2_into<
951    D: Copy + MaybeSendSync,
952    A: Copy + MaybeSendSync,
953    B: Copy + MaybeSendSync,
954    OpA: ElementOp<A>,
955    OpB: ElementOp<B>,
956>(
957    dest: &mut StridedViewMut<D>,
958    a: &StridedView<A, OpA>,
959    b: &StridedView<B, OpB>,
960    f: impl Fn(A, B) -> D + MaybeSync,
961) -> Result<()> {
962    ensure_same_shape(dest.dims(), a.dims())?;
963    ensure_same_shape(dest.dims(), b.dims())?;
964
965    let dst_ptr = dest.as_mut_ptr();
966    let dst_dims = dest.dims();
967    let dst_strides = dest.strides();
968    let a_ptr = a.ptr();
969    let b_ptr = b.ptr();
970
971    let a_strides = a.strides();
972    let b_strides = b.strides();
973
974    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
975        let len = total_len(dst_dims);
976        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
977        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
978        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
979        simd::dispatch_if_large(len, || {
980            for i in 0..len {
981                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]));
982            }
983        });
984        return Ok(());
985    }
986
987    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
988    let elem_size = std::mem::size_of::<D>()
989        .max(std::mem::size_of::<A>())
990        .max(std::mem::size_of::<B>());
991    let total = total_len(dst_dims);
992
993    // Small tensor fast path: skip compute_order and compute_block_sizes
994    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
995        build_plan_fused_small(dst_dims, &strides_list)
996    } else {
997        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
998    };
999
1000    #[cfg(feature = "parallel")]
1001    {
1002        let total: usize = fused_dims.iter().product();
1003        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1004            use crate::threading::SendPtr;
1005            let dst_send = SendPtr(dst_ptr);
1006            let a_send = SendPtr(a_ptr as *mut A);
1007            let b_send = SendPtr(b_ptr as *mut B);
1008
1009            let costs = compute_costs(&ordered_strides);
1010            let initial_offsets = vec![0isize; strides_list.len()];
1011            let nthreads = rayon::current_num_threads();
1012
1013            return mapreduce_threaded(
1014                &fused_dims,
1015                &plan.block,
1016                &ordered_strides,
1017                &initial_offsets,
1018                &costs,
1019                nthreads,
1020                0,
1021                1,
1022                &|dims, blocks, strides_list, offsets| {
1023                    for_each_inner_block_with_offsets(
1024                        dims,
1025                        blocks,
1026                        strides_list,
1027                        offsets,
1028                        |offsets, len, strides| {
1029                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1030                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1031                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1032                            unsafe {
1033                                inner_loop_map2::<D, A, B, OpA, OpB>(
1034                                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
1035                                )
1036                            };
1037                            Ok(())
1038                        },
1039                    )
1040                },
1041            );
1042        }
1043    }
1044
1045    let initial_offsets = vec![0isize; ordered_strides.len()];
1046    for_each_inner_block_preordered(
1047        &fused_dims,
1048        &plan.block,
1049        &ordered_strides,
1050        &initial_offsets,
1051        |offsets, len, strides| {
1052            let dp = unsafe { dst_ptr.offset(offsets[0]) };
1053            let ap = unsafe { a_ptr.offset(offsets[1]) };
1054            let bp = unsafe { b_ptr.offset(offsets[2]) };
1055            unsafe {
1056                inner_loop_map2::<D, A, B, OpA, OpB>(
1057                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
1058                )
1059            };
1060            Ok(())
1061        },
1062    )
1063}
1064
1065fn mul_identity_into_raw<
1066    D: Copy + MaybeSendSync + 'static,
1067    A: Copy + MaybeSendSync + Mul<B, Output = D> + 'static,
1068    B: Copy + MaybeSendSync + 'static,
1069>(
1070    dest: &mut StridedViewMut<D>,
1071    a_ptr: *const A,
1072    a_strides: &[isize],
1073    b_ptr: *const B,
1074    b_strides: &[isize],
1075) -> Result<()> {
1076    let dst_ptr = dest.as_mut_ptr();
1077    let dst_dims = dest.dims();
1078    let dst_strides = dest.strides();
1079    debug_assert_eq!(dst_dims.len(), a_strides.len());
1080    debug_assert_eq!(dst_dims.len(), b_strides.len());
1081
1082    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
1083        let len = total_len(dst_dims);
1084        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
1085        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
1086        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
1087        if simd::try_mul_contiguous(dst, sa, sb) {
1088            return Ok(());
1089        }
1090        for i in 0..len {
1091            dst[i] = sa[i] * sb[i];
1092        }
1093        return Ok(());
1094    }
1095
1096    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
1097    let elem_size = std::mem::size_of::<D>()
1098        .max(std::mem::size_of::<A>())
1099        .max(std::mem::size_of::<B>());
1100    let total = total_len(dst_dims);
1101
1102    if try_contiguous_range_mul(
1103        dst_ptr,
1104        dst_dims,
1105        dst_strides,
1106        a_ptr,
1107        a_strides,
1108        b_ptr,
1109        b_strides,
1110    ) {
1111        return Ok(());
1112    }
1113
1114    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
1115        build_plan_fused_small(dst_dims, &strides_list)
1116    } else {
1117        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
1118    };
1119
1120    #[cfg(feature = "parallel")]
1121    {
1122        let total: usize = fused_dims.iter().product();
1123        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1124            use crate::threading::SendPtr;
1125            let dst_send = SendPtr(dst_ptr);
1126            let a_send = SendPtr(a_ptr as *mut A);
1127            let b_send = SendPtr(b_ptr as *mut B);
1128
1129            let costs = compute_costs(&ordered_strides);
1130            let initial_offsets = vec![0isize; strides_list.len()];
1131            let nthreads = rayon::current_num_threads();
1132
1133            return mapreduce_threaded(
1134                &fused_dims,
1135                &plan.block,
1136                &ordered_strides,
1137                &initial_offsets,
1138                &costs,
1139                nthreads,
1140                0,
1141                1,
1142                &|dims, blocks, strides_list, offsets| {
1143                    for_each_inner_block_with_offsets(
1144                        dims,
1145                        blocks,
1146                        strides_list,
1147                        offsets,
1148                        |offsets, len, strides| {
1149                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1150                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1151                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1152                            unsafe {
1153                                inner_loop_mul2::<D, A, B>(
1154                                    dp, strides[0], ap, strides[1], bp, strides[2], len,
1155                                )
1156                            };
1157                            Ok(())
1158                        },
1159                    )
1160                },
1161            );
1162        }
1163    }
1164
1165    let initial_offsets = vec![0isize; ordered_strides.len()];
1166    for_each_inner_block_preordered(
1167        &fused_dims,
1168        &plan.block,
1169        &ordered_strides,
1170        &initial_offsets,
1171        |offsets, len, strides| {
1172            let dp = unsafe { dst_ptr.offset(offsets[0]) };
1173            let ap = unsafe { a_ptr.offset(offsets[1]) };
1174            let bp = unsafe { b_ptr.offset(offsets[2]) };
1175            unsafe {
1176                inner_loop_mul2::<D, A, B>(dp, strides[0], ap, strides[1], bp, strides[2], len)
1177            };
1178            Ok(())
1179        },
1180    )
1181}
1182
1183fn mul_identity_into<
1184    D: Copy + MaybeSendSync + 'static,
1185    A: Copy + Mul<B, Output = D> + MaybeSendSync + 'static,
1186    B: Copy + MaybeSendSync + 'static,
1187    OpA: ElementOp<A>,
1188    OpB: ElementOp<B>,
1189>(
1190    dest: &mut StridedViewMut<D>,
1191    a: &StridedView<A, OpA>,
1192    b: &StridedView<B, OpB>,
1193) -> Result<()> {
1194    ensure_same_shape(dest.dims(), a.dims())?;
1195    ensure_same_shape(dest.dims(), b.dims())?;
1196    mul_identity_into_raw(dest, a.ptr(), a.strides(), b.ptr(), b.strides())
1197}
1198
1199/// Element-wise multiplication: `dest[i] = a[i] * b[i]`.
1200///
1201/// All views must have the same shape. Broadcast operands should be represented
1202/// as stride-0 views before calling this function.
1203pub fn mul_into<
1204    D: Copy + MaybeSendSync + 'static,
1205    A: Copy + Mul<B, Output = D> + MaybeSendSync + 'static,
1206    B: Copy + MaybeSendSync + 'static,
1207    OpA: ElementOp<A>,
1208    OpB: ElementOp<B>,
1209>(
1210    dest: &mut StridedViewMut<D>,
1211    a: &StridedView<A, OpA>,
1212    b: &StridedView<B, OpB>,
1213) -> Result<()> {
1214    if OpA::IS_IDENTITY && OpB::IS_IDENTITY {
1215        return mul_identity_into(dest, a, b);
1216    }
1217
1218    zip_map2_into(dest, a, b, |x, y| x * y)
1219}
1220
1221fn broadcast_strides_for_axes(
1222    source_dims: &[usize],
1223    source_strides: &[isize],
1224    target_dims: &[usize],
1225    axes: &[usize],
1226) -> Result<AxisVec<isize>> {
1227    if source_dims.len() != axes.len() {
1228        return Err(StridedError::RankMismatch(source_dims.len(), axes.len()));
1229    }
1230    debug_assert_eq!(source_dims.len(), source_strides.len());
1231
1232    let mut seen = AxisVec::<bool>::new();
1233    seen.resize(target_dims.len(), false);
1234    let mut strides = AxisVec::<isize>::new();
1235    strides.resize(target_dims.len(), 0);
1236    for (src_axis, &dst_axis) in axes.iter().enumerate() {
1237        if dst_axis >= target_dims.len() {
1238            return Err(StridedError::InvalidAxis {
1239                axis: dst_axis,
1240                rank: target_dims.len(),
1241            });
1242        }
1243        if seen[dst_axis] {
1244            return Err(StridedError::InvalidAxis {
1245                axis: dst_axis,
1246                rank: target_dims.len(),
1247            });
1248        }
1249        seen[dst_axis] = true;
1250
1251        let source_dim = source_dims[src_axis];
1252        let target_dim = target_dims[dst_axis];
1253        if source_dim != target_dim && source_dim != 1 {
1254            return Err(StridedError::ShapeMismatch(
1255                source_dims.to_vec(),
1256                target_dims.to_vec(),
1257            ));
1258        }
1259        if source_dim == target_dim {
1260            strides[dst_axis] = source_strides[src_axis];
1261        }
1262    }
1263
1264    Ok(strides)
1265}
1266
1267fn broadcast_view_with_strides<'a, T, Op: ElementOp<T>>(
1268    view: &StridedView<'a, T, Op>,
1269    target_dims: &[usize],
1270    strides: &[isize],
1271) -> StridedView<'a, T, Op> {
1272    unsafe { StridedView::new_unchecked(view.data(), target_dims, strides, view.offset()) }
1273}
1274
1275/// Broadcasted element-wise multiplication: `dest[i] = a[i] * b[i]`.
1276///
1277/// `a_axes` and `b_axes` map each source axis to an axis of `dest`. Output axes
1278/// not referenced by a source operand are treated as stride-0 broadcast axes.
1279pub fn broadcast_mul_into<
1280    D: Copy + MaybeSendSync + 'static,
1281    A: Copy + Mul<B, Output = D> + MaybeSendSync + 'static,
1282    B: Copy + MaybeSendSync + 'static,
1283    OpA: ElementOp<A>,
1284    OpB: ElementOp<B>,
1285>(
1286    dest: &mut StridedViewMut<D>,
1287    a: &StridedView<A, OpA>,
1288    a_axes: &[usize],
1289    b: &StridedView<B, OpB>,
1290    b_axes: &[usize],
1291) -> Result<()> {
1292    let a_strides = broadcast_strides_for_axes(a.dims(), a.strides(), dest.dims(), a_axes)?;
1293    let b_strides = broadcast_strides_for_axes(b.dims(), b.strides(), dest.dims(), b_axes)?;
1294
1295    if OpA::IS_IDENTITY && OpB::IS_IDENTITY {
1296        return mul_identity_into_raw(dest, a.ptr(), &a_strides, b.ptr(), &b_strides);
1297    }
1298
1299    let a = broadcast_view_with_strides(a, dest.dims(), &a_strides);
1300    let b = broadcast_view_with_strides(b, dest.dims(), &b_strides);
1301    mul_into(dest, &a, &b)
1302}
1303
1304/// Ternary element-wise operation: `dest[i] = f(a[i], b[i], c[i])`.
1305pub fn zip_map3_into<
1306    D: Copy + MaybeSendSync,
1307    A: Copy + MaybeSendSync,
1308    B: Copy + MaybeSendSync,
1309    C: Copy + MaybeSendSync,
1310    OpA: ElementOp<A>,
1311    OpB: ElementOp<B>,
1312    OpC: ElementOp<C>,
1313>(
1314    dest: &mut StridedViewMut<D>,
1315    a: &StridedView<A, OpA>,
1316    b: &StridedView<B, OpB>,
1317    c: &StridedView<C, OpC>,
1318    f: impl Fn(A, B, C) -> D + MaybeSync,
1319) -> Result<()> {
1320    ensure_same_shape(dest.dims(), a.dims())?;
1321    ensure_same_shape(dest.dims(), b.dims())?;
1322    ensure_same_shape(dest.dims(), c.dims())?;
1323
1324    let dst_ptr = dest.as_mut_ptr();
1325    let a_ptr = a.ptr();
1326    let b_ptr = b.ptr();
1327    let c_ptr = c.ptr();
1328
1329    let dst_dims = dest.dims();
1330    let dst_strides = dest.strides();
1331
1332    if sequential_contiguous_layout(
1333        dst_dims,
1334        &[dst_strides, a.strides(), b.strides(), c.strides()],
1335    )
1336    .is_some()
1337    {
1338        let len = total_len(dst_dims);
1339        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
1340        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
1341        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
1342        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
1343        simd::dispatch_if_large(len, || {
1344            for i in 0..len {
1345                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]), OpC::apply(sc[i]));
1346            }
1347        });
1348        return Ok(());
1349    }
1350
1351    let strides_list: [&[isize]; 4] = [dst_strides, a.strides(), b.strides(), c.strides()];
1352    let elem_size = std::mem::size_of::<D>()
1353        .max(std::mem::size_of::<A>())
1354        .max(std::mem::size_of::<B>())
1355        .max(std::mem::size_of::<C>());
1356    let total = total_len(dst_dims);
1357
1358    // Small tensor fast path: skip compute_order and compute_block_sizes
1359    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
1360        build_plan_fused_small(dst_dims, &strides_list)
1361    } else {
1362        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
1363    };
1364
1365    #[cfg(feature = "parallel")]
1366    {
1367        let total: usize = fused_dims.iter().product();
1368        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1369            use crate::threading::SendPtr;
1370            let dst_send = SendPtr(dst_ptr);
1371            let a_send = SendPtr(a_ptr as *mut A);
1372            let b_send = SendPtr(b_ptr as *mut B);
1373            let c_send = SendPtr(c_ptr as *mut C);
1374
1375            let costs = compute_costs(&ordered_strides);
1376            let initial_offsets = vec![0isize; strides_list.len()];
1377            let nthreads = rayon::current_num_threads();
1378
1379            return mapreduce_threaded(
1380                &fused_dims,
1381                &plan.block,
1382                &ordered_strides,
1383                &initial_offsets,
1384                &costs,
1385                nthreads,
1386                0,
1387                1,
1388                &|dims, blocks, strides_list, offsets| {
1389                    for_each_inner_block_with_offsets(
1390                        dims,
1391                        blocks,
1392                        strides_list,
1393                        offsets,
1394                        |offsets, len, strides| {
1395                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1396                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1397                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1398                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
1399                            unsafe {
1400                                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
1401                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
1402                                    len, &f,
1403                                )
1404                            };
1405                            Ok(())
1406                        },
1407                    )
1408                },
1409            );
1410        }
1411    }
1412
1413    let initial_offsets = vec![0isize; ordered_strides.len()];
1414    for_each_inner_block_preordered(
1415        &fused_dims,
1416        &plan.block,
1417        &ordered_strides,
1418        &initial_offsets,
1419        |offsets, len, strides| {
1420            let dp = unsafe { dst_ptr.offset(offsets[0]) };
1421            let ap = unsafe { a_ptr.offset(offsets[1]) };
1422            let bp = unsafe { b_ptr.offset(offsets[2]) };
1423            let cp = unsafe { c_ptr.offset(offsets[3]) };
1424            unsafe {
1425                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
1426                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], len, &f,
1427                )
1428            };
1429            Ok(())
1430        },
1431    )
1432}
1433
1434/// Quaternary element-wise operation: `dest[i] = f(a[i], b[i], c[i], e[i])`.
1435pub fn zip_map4_into<
1436    D: Copy + MaybeSendSync,
1437    A: Copy + MaybeSendSync,
1438    B: Copy + MaybeSendSync,
1439    C: Copy + MaybeSendSync,
1440    E: Copy + MaybeSendSync,
1441    OpA: ElementOp<A>,
1442    OpB: ElementOp<B>,
1443    OpC: ElementOp<C>,
1444    OpE: ElementOp<E>,
1445>(
1446    dest: &mut StridedViewMut<D>,
1447    a: &StridedView<A, OpA>,
1448    b: &StridedView<B, OpB>,
1449    c: &StridedView<C, OpC>,
1450    e: &StridedView<E, OpE>,
1451    f: impl Fn(A, B, C, E) -> D + MaybeSync,
1452) -> Result<()> {
1453    ensure_same_shape(dest.dims(), a.dims())?;
1454    ensure_same_shape(dest.dims(), b.dims())?;
1455    ensure_same_shape(dest.dims(), c.dims())?;
1456    ensure_same_shape(dest.dims(), e.dims())?;
1457
1458    let dst_ptr = dest.as_mut_ptr();
1459    let a_ptr = a.ptr();
1460    let b_ptr = b.ptr();
1461    let c_ptr = c.ptr();
1462    let e_ptr = e.ptr();
1463
1464    let dst_dims = dest.dims();
1465    let dst_strides = dest.strides();
1466
1467    if sequential_contiguous_layout(
1468        dst_dims,
1469        &[
1470            dst_strides,
1471            a.strides(),
1472            b.strides(),
1473            c.strides(),
1474            e.strides(),
1475        ],
1476    )
1477    .is_some()
1478    {
1479        let len = total_len(dst_dims);
1480        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
1481        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
1482        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
1483        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
1484        let se = unsafe { std::slice::from_raw_parts(e_ptr, len) };
1485        simd::dispatch_if_large(len, || {
1486            for i in 0..len {
1487                dst[i] = f(
1488                    OpA::apply(sa[i]),
1489                    OpB::apply(sb[i]),
1490                    OpC::apply(sc[i]),
1491                    OpE::apply(se[i]),
1492                );
1493            }
1494        });
1495        return Ok(());
1496    }
1497
1498    let strides_list: [&[isize]; 5] = [
1499        dst_strides,
1500        a.strides(),
1501        b.strides(),
1502        c.strides(),
1503        e.strides(),
1504    ];
1505    let elem_size = std::mem::size_of::<D>()
1506        .max(std::mem::size_of::<A>())
1507        .max(std::mem::size_of::<B>())
1508        .max(std::mem::size_of::<C>())
1509        .max(std::mem::size_of::<E>());
1510    let total = total_len(dst_dims);
1511
1512    // Small tensor fast path: skip compute_order and compute_block_sizes
1513    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
1514        build_plan_fused_small(dst_dims, &strides_list)
1515    } else {
1516        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
1517    };
1518
1519    #[cfg(feature = "parallel")]
1520    {
1521        let total: usize = fused_dims.iter().product();
1522        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1523            use crate::threading::SendPtr;
1524            let dst_send = SendPtr(dst_ptr);
1525            let a_send = SendPtr(a_ptr as *mut A);
1526            let b_send = SendPtr(b_ptr as *mut B);
1527            let c_send = SendPtr(c_ptr as *mut C);
1528            let e_send = SendPtr(e_ptr as *mut E);
1529
1530            let costs = compute_costs(&ordered_strides);
1531            let initial_offsets = vec![0isize; strides_list.len()];
1532            let nthreads = rayon::current_num_threads();
1533
1534            return mapreduce_threaded(
1535                &fused_dims,
1536                &plan.block,
1537                &ordered_strides,
1538                &initial_offsets,
1539                &costs,
1540                nthreads,
1541                0,
1542                1,
1543                &|dims, blocks, strides_list, offsets| {
1544                    for_each_inner_block_with_offsets(
1545                        dims,
1546                        blocks,
1547                        strides_list,
1548                        offsets,
1549                        |offsets, len, strides| {
1550                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
1551                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
1552                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
1553                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
1554                            let ep = unsafe { e_send.as_const().offset(offsets[4]) };
1555                            unsafe {
1556                                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
1557                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
1558                                    ep, strides[4], len, &f,
1559                                )
1560                            };
1561                            Ok(())
1562                        },
1563                    )
1564                },
1565            );
1566        }
1567    }
1568
1569    let initial_offsets = vec![0isize; ordered_strides.len()];
1570    for_each_inner_block_preordered(
1571        &fused_dims,
1572        &plan.block,
1573        &ordered_strides,
1574        &initial_offsets,
1575        |offsets, len, strides| {
1576            let dp = unsafe { dst_ptr.offset(offsets[0]) };
1577            let ap = unsafe { a_ptr.offset(offsets[1]) };
1578            let bp = unsafe { b_ptr.offset(offsets[2]) };
1579            let cp = unsafe { c_ptr.offset(offsets[3]) };
1580            let ep = unsafe { e_ptr.offset(offsets[4]) };
1581            unsafe {
1582                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
1583                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], ep, strides[4],
1584                    len, &f,
1585                )
1586            };
1587            Ok(())
1588        },
1589    )
1590}
1591
/// Tests for the stride-specialized inner loops, the error branches of
/// `broadcast_mul_into`, and the contiguous-range multiply helpers.
#[cfg(test)]
mod scalar_branch_tests {
    use super::*;
    use crate::StridedArray;
    use strided_view::Identity;

    /// Runs `inner_loop_map2` for three elements over two fixed six-element
    /// inputs with a unit-stride output and returns what was written.
    fn map2_three(a_stride: isize, b_stride: isize, f: impl Fn(f64, f64) -> f64) -> [f64; 3] {
        let a: [f64; 6] = [2.0, 3.0, 5.0, 7.0, 11.0, 13.0];
        let b: [f64; 6] = [17.0, 19.0, 23.0, 29.0, 31.0, 37.0];
        let mut out = [0.0f64; 3];
        // SAFETY: the tests use strides of at most 2, so the furthest element read
        // is index 4 of a six-element input; `out` holds exactly three elements.
        unsafe {
            inner_loop_map2::<f64, f64, f64, Identity, Identity>(
                out.as_mut_ptr(),
                1,
                a.as_ptr(),
                a_stride,
                b.as_ptr(),
                b_stride,
                3,
                &f,
            );
        }
        out
    }

    /// Same fixture as `map2_three`, but driving `inner_loop_mul2`.
    fn mul2_three(a_stride: isize, b_stride: isize) -> [f64; 3] {
        let a: [f64; 6] = [2.0, 3.0, 5.0, 7.0, 11.0, 13.0];
        let b: [f64; 6] = [17.0, 19.0, 23.0, 29.0, 31.0, 37.0];
        let mut out = [0.0f64; 3];
        // SAFETY: the tests use strides of at most 2, so the furthest element read
        // is index 4 of a six-element input; `out` holds exactly three elements.
        unsafe {
            inner_loop_mul2::<f64, f64, f64>(
                out.as_mut_ptr(),
                1,
                a.as_ptr(),
                a_stride,
                b.as_ptr(),
                b_stride,
                3,
            );
        }
        out
    }

    #[test]
    fn test_inner_loop_map2_stride_specializations() {
        // Both inputs unit-stride.
        assert_eq!(map2_three(1, 1, |x, y| x + y), [19.0, 22.0, 28.0]);
        // Right operand broadcast (stride 0).
        assert_eq!(map2_three(1, 0, |x, y| x * y), [34.0, 51.0, 85.0]);
        // Left operand broadcast (stride 0).
        assert_eq!(map2_three(0, 1, |x, y| x * y), [34.0, 38.0, 46.0]);
        // Both operands broadcast.
        assert_eq!(map2_three(0, 0, |x, y| x + y), [19.0, 19.0, 19.0]);
        // Non-unit stride on the left, broadcast on the right.
        assert_eq!(map2_three(2, 0, |x, y| x + y), [19.0, 22.0, 28.0]);
        // Broadcast on the left, non-unit stride on the right.
        assert_eq!(map2_three(0, 2, |x, y| x + y), [19.0, 25.0, 33.0]);
    }

    #[test]
    fn test_inner_loop_mul2_stride_specializations() {
        assert_eq!(mul2_three(1, 1), [34.0, 57.0, 115.0]);
        assert_eq!(mul2_three(0, 1), [34.0, 38.0, 46.0]);
        assert_eq!(mul2_three(0, 0), [34.0, 34.0, 34.0]);
        assert_eq!(mul2_three(0, 2), [34.0, 46.0, 62.0]);
    }

    #[test]
    fn test_broadcast_mul_into_error_branches_and_non_identity_ops() {
        let lhs = StridedArray::<f64>::row_major(&[2, 3]);
        let rhs = StridedArray::<f64>::row_major(&[2, 3]);
        let mut out = StridedArray::<f64>::row_major(&[2, 3]);

        // Fewer axes than the source has dims.
        let outcome =
            broadcast_mul_into(&mut out.view_mut(), &lhs.view(), &[0], &rhs.view(), &[0, 1]);
        assert!(matches!(outcome, Err(StridedError::RankMismatch(2, 1))));

        // Axis index beyond the destination rank.
        let outcome = broadcast_mul_into(
            &mut out.view_mut(),
            &lhs.view(),
            &[0, 3],
            &rhs.view(),
            &[0, 1],
        );
        assert!(matches!(
            outcome,
            Err(StridedError::InvalidAxis { axis: 3, rank: 2 })
        ));

        // Same destination axis listed twice.
        let outcome = broadcast_mul_into(
            &mut out.view_mut(),
            &lhs.view(),
            &[0, 0],
            &rhs.view(),
            &[0, 1],
        );
        assert!(matches!(
            outcome,
            Err(StridedError::InvalidAxis { axis: 0, rank: 2 })
        ));

        // Source extent that neither matches the destination nor equals 1.
        let rhs_bad = StridedArray::<f64>::row_major(&[2, 4]);
        let outcome = broadcast_mul_into(
            &mut out.view_mut(),
            &lhs.view(),
            &[0, 1],
            &rhs_bad.view(),
            &[0, 1],
        );
        assert!(matches!(outcome, Err(StridedError::ShapeMismatch(_, _))));

        // A conjugated operand goes through the non-identity branch and must succeed.
        let lhs_conj = lhs.view().conj();
        broadcast_mul_into(
            &mut out.view_mut(),
            &lhs_conj,
            &[0, 1],
            &rhs.view(),
            &[0, 1],
        )
        .unwrap();
    }

    #[test]
    fn contiguous_mul_range_plan_available_without_parallel_feature() {
        let dims = [3usize; 16];
        let dst = [
            1, 3, 9, 27, 81, 243, 729, 2187, 6561, 19683, 59049, 177147, 531441, 1594323, 4782969,
            14348907,
        ];
        let lhs = [1isize, 3, 9, 27, 81, 243, 729, 2187, 0, 0, 0, 0, 0, 0, 0, 0];
        let rhs = [0isize, 0, 0, 0, 0, 0, 0, 0, 1, 3, 9, 27, 81, 243, 729, 2187];

        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        assert_eq!(plan.inner_len, 6561);
        assert_eq!(plan.row_len, 3);
        assert_eq!(plan.fast_axis, 0);
    }

    #[test]
    fn contiguous_range_mul_single_thread_computes_large_broadcast_mul() {
        let dims = [3usize; 10];
        let dst = [1isize, 3, 9, 27, 81, 243, 729, 2187, 6561, 19683];
        let lhs = [1isize, 3, 9, 27, 81, 0, 0, 0, 0, 0];
        let rhs = [0isize, 0, 0, 0, 0, 1, 3, 9, 27, 81];
        let plan = contiguous_mul_range_plan(&dims, &dst, &lhs, &rhs).unwrap();

        let total = total_len(&dims);
        let block_len = plan.inner_len.max(1).saturating_mul(plan.row_len.max(1));
        let outer_groups = total.div_ceil(block_len);

        // Each operand covers five axes of extent 3, i.e. 243 elements.
        let a = vec![2.0; 243];
        let b = vec![3.0; 243];
        let mut out = vec![0.0; total];

        let handled = run_contiguous_range_mul_single_thread::<f64, f64, f64>(
            out.as_mut_ptr(),
            &dims,
            a.as_ptr(),
            &lhs,
            b.as_ptr(),
            &rhs,
            &plan,
            total,
            block_len,
            outer_groups,
        );
        assert!(handled);
        assert!(out.iter().all(|&x| x == 6.0));
    }
}
1829
1830#[cfg(all(test, feature = "parallel"))]
1831mod tests {
1832    use super::*;
1833
/// Computes hole-free strides for `dims` when the axes are laid out in
/// memory in the order given by `axis_order` (fastest-varying axis first).
///
/// The first axis in `axis_order` gets stride 1; each following axis gets the
/// product of the extents of all axes listed before it. Axes that never
/// appear in `axis_order` keep a stride of 0.
fn compact_strides_for_axis_order<const N: usize>(
    dims: [usize; N],
    axis_order: [usize; N],
) -> [isize; N] {
    // Thread the partially filled stride table and the running step through
    // a fold instead of mutating two locals.
    let (strides, _next_step) = axis_order.iter().fold(
        ([0isize; N], 1isize),
        |(mut table, step), &axis| {
            table[axis] = step;
            (table, step * dims[axis] as isize)
        },
    );
    strides
}
1846
/// Plans a 7x11 outer-product style multiply: the left operand varies only
/// along axis 0 and the right operand only along axis 1, writing into a
/// compact destination. Checks every planned length and stride.
#[test]
fn test_contiguous_mul_range_plan_pure_outer() {
    let shape = [7usize, 11];
    let dst_strides = [1isize, 7];
    let lhs_strides = [1isize, 0];
    let rhs_strides = [0isize, 1];

    let range_plan =
        contiguous_mul_range_plan(&shape, &dst_strides, &lhs_strides, &rhs_strides).unwrap();

    // Block geometry.
    assert_eq!(range_plan.inner_len, 7);
    assert_eq!(range_plan.row_len, 11);
    assert_eq!(range_plan.fast_axis, 0);
    // Operand strides along the fast axis.
    assert_eq!(range_plan.a_fast_stride, 1);
    assert_eq!(range_plan.b_fast_stride, 0);
    // Operand strides along the row axis.
    assert_eq!(range_plan.a_row_stride, 0);
    assert_eq!(range_plan.b_row_stride, 1);
}
1864
/// Enumerates all 24 orderings of four axes, builds compact strides for each
/// ordering, and checks that `compact_axis_order` recovers exactly that
/// ordering from the strides.
#[test]
fn test_compact_axis_order_accepts_all_rank4_axis_permutations() {
    /// Recursively permutes `perm[fixed..]` by swapping; once every position
    /// is fixed, round-trips the permutation through the stride helpers and
    /// bumps `seen`.
    fn check_permutations(
        dims: [usize; 4],
        perm: &mut [usize; 4],
        fixed: usize,
        seen: &mut usize,
    ) {
        if fixed != perm.len() {
            for candidate in fixed..perm.len() {
                perm.swap(fixed, candidate);
                check_permutations(dims, perm, fixed + 1, seen);
                // Undo the swap so the next candidate starts from the same state.
                perm.swap(fixed, candidate);
            }
            return;
        }

        let strides = compact_strides_for_axis_order(dims, *perm);
        let recovered = compact_axis_order(&dims, &strides).unwrap();
        assert_eq!(&recovered[..], &perm[..]);
        *seen += 1;
    }

    let shape = [2usize, 3, 5, 7];
    let mut permutation = [0usize, 1, 2, 3];
    let mut visited = 0usize;
    check_permutations(shape, &mut permutation, 0, &mut visited);

    // 4! orderings must have been checked.
    assert_eq!(visited, 24);
}
1890
/// Checks that `compact_axis_order` returns `None` for a 2x3x5 shape whose
/// strides (1, 4, 2) do not form a gap-free layout.
#[test]
fn test_compact_axis_order_rejects_strided_layout_with_holes() {
    let shape = [2usize, 3, 5];
    let holey_strides = [1isize, 4, 2];

    let detected_order = compact_axis_order(&shape, &holey_strides);

    assert_eq!(detected_order, None);
}
1898
/// Plans a rank-5 multiply whose destination is compact in the permuted axis
/// order `[2, 0, 4, 1, 3]`, then checks the recovered axis order, the block
/// geometry, the operand strides, and that no transposed scalar tile kind is
/// reported for the plan.
#[test]
fn test_contiguous_mul_range_plan_uses_permuted_compact_output_for_unrelated_shape() {
    let shape = [2usize, 3, 5, 7, 11];
    let dst_strides = compact_strides_for_axis_order(shape, [2usize, 0, 4, 1, 3]);
    let lhs_strides = [5isize, 0, 1, 0, 10];
    let rhs_strides = [0isize, 1, 0, 3, 0];

    let range_plan =
        contiguous_mul_range_plan(&shape, &dst_strides, &lhs_strides, &rhs_strides).unwrap();

    // The plan must follow the destination's memory order.
    assert_eq!(&range_plan.axis_order[..], &[2, 0, 4, 1, 3]);
    // Block geometry.
    assert_eq!(range_plan.inner_len, 110);
    assert_eq!(range_plan.row_len, 3);
    assert_eq!(range_plan.fast_axis, 2);
    // Operand strides along the fast and row axes.
    assert_eq!(range_plan.a_fast_stride, 1);
    assert_eq!(range_plan.b_fast_stride, 0);
    assert_eq!(range_plan.a_row_stride, 0);
    assert_eq!(range_plan.b_row_stride, 1);

    let tile_kind = transposed_scalar_tile_kind(&range_plan);
    assert_eq!(tile_kind, None);
}
1918
/// Plans a rank-4 batched multiply with a compact destination, where the left
/// operand is compact over axes 0, 1 and 3 and the right operand over axes 2
/// and 3. Checks every planned length and stride.
#[test]
fn test_contiguous_mul_range_plan_compact_batched_outer() {
    let shape = [3usize, 5, 7, 11];
    let dst_strides = [1isize, 3, 15, 105];
    let lhs_strides = [1isize, 3, 0, 15];
    let rhs_strides = [0isize, 0, 1, 7];

    let range_plan =
        contiguous_mul_range_plan(&shape, &dst_strides, &lhs_strides, &rhs_strides).unwrap();

    // Block geometry.
    assert_eq!(range_plan.inner_len, 15);
    assert_eq!(range_plan.row_len, 7);
    assert_eq!(range_plan.fast_axis, 0);
    // Operand strides along the fast axis.
    assert_eq!(range_plan.a_fast_stride, 1);
    assert_eq!(range_plan.b_fast_stride, 0);
    // Operand strides along the row axis.
    assert_eq!(range_plan.a_row_stride, 0);
    assert_eq!(range_plan.b_row_stride, 1);
}
1936
/// Plans a rank-4 batched multiply where the left operand's first two axes
/// are swapped relative to the compact destination (strides 5 and 1 versus
/// 1 and 5). Checks every planned length and stride.
#[test]
fn test_contiguous_mul_range_plan_noncompact_batched_outer() {
    let shape = [5usize, 5, 7, 11];
    let dst_strides = [1isize, 5, 25, 175];
    let lhs_strides = [5isize, 1, 0, 25];
    let rhs_strides = [0isize, 0, 1, 7];

    let range_plan =
        contiguous_mul_range_plan(&shape, &dst_strides, &lhs_strides, &rhs_strides).unwrap();

    // Block geometry.
    assert_eq!(range_plan.inner_len, 5);
    assert_eq!(range_plan.row_len, 5);
    assert_eq!(range_plan.fast_axis, 0);
    // Operand strides along the fast axis.
    assert_eq!(range_plan.a_fast_stride, 5);
    assert_eq!(range_plan.b_fast_stride, 0);
    // Operand strides along the row axis.
    assert_eq!(range_plan.a_row_stride, 1);
    assert_eq!(range_plan.b_row_stride, 0);
}
1954
/// Plans a rank-4 batched multiply where the destination shares the left
/// operand's swapped layout on the first two axes (strides 5 and 1). Checks
/// the planned lengths and strides, and that no transposed scalar tile kind
/// is reported for the plan.
#[test]
fn test_contiguous_mul_range_plan_noncompact_row_major_output() {
    let shape = [5usize, 5, 7, 11];
    let dst_strides = [5isize, 1, 25, 175];
    let lhs_strides = [5isize, 1, 0, 25];
    let rhs_strides = [0isize, 0, 1, 7];

    let range_plan =
        contiguous_mul_range_plan(&shape, &dst_strides, &lhs_strides, &rhs_strides).unwrap();

    // Block geometry.
    assert_eq!(range_plan.inner_len, 25);
    assert_eq!(range_plan.row_len, 7);
    assert_eq!(range_plan.fast_axis, 1);
    // Operand strides along the fast and row axes.
    assert_eq!(range_plan.a_fast_stride, 1);
    assert_eq!(range_plan.b_fast_stride, 0);
    assert_eq!(range_plan.a_row_stride, 0);
    assert_eq!(range_plan.b_row_stride, 1);

    let tile_kind = transposed_scalar_tile_kind(&range_plan);
    assert_eq!(tile_kind, None);
}
1973
/// Maps two lower-rank operands onto a rank-4 target shape and checks the
/// resulting strides: target axes an operand does not cover get stride 0,
/// and covered axes keep the operand's own stride.
#[test]
fn test_broadcast_strides_for_axes_batched_outer() {
    let target_shape = [3usize, 5, 7, 11];

    // Left operand occupies target axes 0, 1 and 3.
    let lhs_shape = [3usize, 5, 11];
    let lhs_source_strides = [3isize, 1, 15];
    let lhs_broadcast =
        broadcast_strides_for_axes(&lhs_shape, &lhs_source_strides, &target_shape, &[0, 1, 3])
            .unwrap();

    // Right operand occupies target axes 2 and 3.
    let rhs_shape = [7usize, 11];
    let rhs_source_strides = [1isize, 7];
    let rhs_broadcast =
        broadcast_strides_for_axes(&rhs_shape, &rhs_source_strides, &target_shape, &[2, 3])
            .unwrap();

    assert_eq!(&lhs_broadcast[..], &[3, 1, 0, 15]);
    assert_eq!(&rhs_broadcast[..], &[0, 0, 1, 7]);
}
1990
/// Checks that a source dimension of extent 1 mapped onto a target dimension
/// of extent 8 is given stride 0, even though its stored stride is 1.
#[test]
fn test_broadcast_strides_for_axes_uses_zero_stride_for_size_one_source_dim() {
    let target_shape = [8usize, 4];
    let source_shape = [1usize, 4];
    let stored_strides = [1isize, 1];

    let broadcast =
        broadcast_strides_for_axes(&source_shape, &stored_strides, &target_shape, &[0, 1])
            .unwrap();

    assert_eq!(&broadcast[..], &[0, 1]);
}
2003
/// Builds the plan for a left operand whose first two axes are swapped
/// relative to the compact destination and checks that the plan is
/// classified as `TransposedScalarTileKind::RhsScalar`.
#[test]
fn test_transposed_scalar_tile_kind_detects_noncompact_rhs_scalar() {
    let shape = [5usize, 5, 7, 11];
    let dst_strides = [1isize, 5, 25, 175];
    let lhs_strides = [5isize, 1, 0, 25];
    let rhs_strides = [0isize, 0, 1, 7];

    let range_plan =
        contiguous_mul_range_plan(&shape, &dst_strides, &lhs_strides, &rhs_strides).unwrap();
    let tile_kind = transposed_scalar_tile_kind(&range_plan);

    assert_eq!(tile_kind, Some(TransposedScalarTileKind::RhsScalar));
}
2018
/// Starts a `ContiguousMulOuterCursor` at outer group 13 and advances it up
/// to (but not including) group 80. Before each advance, both operand offsets
/// must equal the offsets computed directly from the linear index of that
/// group's first element.
#[test]
fn test_contiguous_mul_outer_cursor_matches_linear_offsets() {
    let shape = [16usize, 16, 64, 64];
    let dst_strides = [1isize, 16, 256, 16_384];
    let lhs_strides = [16isize, 1, 0, 256];
    let rhs_strides = [0isize, 0, 1, 64];
    let range_plan =
        contiguous_mul_range_plan(&shape, &dst_strides, &lhs_strides, &rhs_strides).unwrap();
    let mut outer_cursor =
        ContiguousMulOuterCursor::new(&shape, &lhs_strides, &rhs_strides, &range_plan, 13);
    let block_len = range_plan.inner_len * range_plan.row_len;

    // Each outer group begins `block_len` elements after the previous one.
    for linear_index in (13..80).map(|group| group * block_len) {
        let expected_a = strided_offset_for_contiguous_linear_index(
            &shape,
            &lhs_strides,
            &range_plan.axis_order,
            linear_index,
        );
        assert_eq!(outer_cursor.a_offset, expected_a);

        let expected_b = strided_offset_for_contiguous_linear_index(
            &shape,
            &rhs_strides,
            &range_plan.axis_order,
            linear_index,
        );
        assert_eq!(outer_cursor.b_offset, expected_b);

        outer_cursor.advance();
    }
}
2042}