strided_kernel/map_view.rs

//! Map operations on dynamic-rank strided views.
//!
//! These are the canonical view-based map functions, equivalent to Julia's `Base.map!`.
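//!
//! # Example
//!
//! A minimal sketch of the intended call shape, marked `ignore` because the view
//! constructors live outside this module; `from_parts` below is a hypothetical
//! stand-in for whatever constructor the crate actually provides:
//!
//! ```ignore
//! // src: StridedView<f32>, dst: StridedViewMut<f32>, both with the same dims.
//! let src = StridedView::from_parts(data.as_ptr(), &dims, &strides);           // hypothetical
//! let mut dst = StridedViewMut::from_parts(out.as_mut_ptr(), &dims, &strides); // hypothetical
//! map_into(&mut dst, &src, |x| x * 2.0)?; // dst[i] = 2 * src[i]
//! ```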

use crate::kernel::{
    build_plan_fused, build_plan_fused_small, ensure_same_shape, for_each_inner_block_preordered,
    sequential_contiguous_layout, total_len, SMALL_TENSOR_THRESHOLD,
};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::simd;
use crate::view::{StridedView, StridedViewMut};
use crate::Result;
use strided_view::ElementOp;

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{for_each_inner_block_with_offsets, mapreduce_threaded, MINTHREADLENGTH};

// ============================================================================
// Stride-specialized inner loop helpers
//
// When all inner strides are 1 (contiguous in the innermost dimension),
// we use slice-based iteration so LLVM can auto-vectorize effectively.
// This is the Rust equivalent of Julia's @simd on the innermost loop.
// ============================================================================

/// Unary inner loop: `dest[i] = f(Op::apply(src[i]))` for `len` elements.
#[inline(always)]
unsafe fn inner_loop_map1<D: Copy, A: Copy, Op: ElementOp<A>>(
    dp: *mut D,
    ds: isize,
    sp: *const A,
    ss: isize,
    len: usize,
    f: &impl Fn(A) -> D,
) {
    if ds == 1 && ss == 1 {
        let src = std::slice::from_raw_parts(sp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for (d, s) in dst.iter_mut().zip(src.iter()) {
                *d = f(Op::apply(*s));
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = f(Op::apply(*sp));
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

/// Binary inner loop: `dest[i] = f(OpA::apply(a[i]), OpB::apply(b[i]))`.
#[inline(always)]
unsafe fn inner_loop_map2<D: Copy, A: Copy, B: Copy, OpA: ElementOp<A>, OpB: ElementOp<B>>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
    f: &impl Fn(A, B) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(src_a[i]), OpB::apply(src_b[i]));
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            *dp = f(OpA::apply(*ap), OpB::apply(*bp));
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
}

/// Ternary inner loop: `dest[i] = f(OpA::apply(a[i]), OpB::apply(b[i]), OpC::apply(c[i]))`.
#[inline(always)]
unsafe fn inner_loop_map3<
    D: Copy,
    A: Copy,
    B: Copy,
    C: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    cp: *const C,
    c_s: isize,
    len: usize,
    f: &impl Fn(A, B, C) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let src_c = std::slice::from_raw_parts(cp, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(src_a[i]),
                    OpB::apply(src_b[i]),
                    OpC::apply(src_c[i]),
                );
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        let mut cp = cp;
        for _ in 0..len {
            *dp = f(OpA::apply(*ap), OpB::apply(*bp), OpC::apply(*cp));
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
            cp = cp.offset(c_s);
        }
    }
}

/// Quaternary inner loop: `dest[i] = f(OpA::apply(a[i]), OpB::apply(b[i]), OpC::apply(c[i]), OpE::apply(e[i]))`.
#[inline(always)]
unsafe fn inner_loop_map4<
    D: Copy,
    A: Copy,
    B: Copy,
    C: Copy,
    E: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
    OpE: ElementOp<E>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    cp: *const C,
    c_s: isize,
    ep: *const E,
    e_s: isize,
    len: usize,
    f: &impl Fn(A, B, C, E) -> D,
) {
    if ds == 1 && a_s == 1 && b_s == 1 && c_s == 1 && e_s == 1 {
        let src_a = std::slice::from_raw_parts(ap, len);
        let src_b = std::slice::from_raw_parts(bp, len);
        let src_c = std::slice::from_raw_parts(cp, len);
        let src_e = std::slice::from_raw_parts(ep, len);
        let dst = std::slice::from_raw_parts_mut(dp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(src_a[i]),
                    OpB::apply(src_b[i]),
                    OpC::apply(src_c[i]),
                    OpE::apply(src_e[i]),
                );
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        let mut cp = cp;
        let mut ep = ep;
        for _ in 0..len {
            *dp = f(
                OpA::apply(*ap),
                OpB::apply(*bp),
                OpC::apply(*cp),
                OpE::apply(*ep),
            );
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
            cp = cp.offset(c_s);
            ep = ep.offset(e_s);
        }
    }
}

/// Apply a function element-wise from source to destination.
///
/// The element operation `Op` is applied lazily when reading from `src`.
/// Source and destination may have different element types.
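///
/// # Example
///
/// A minimal sketch of the call shape, marked `ignore` because view construction
/// is outside this module; the closure performs the `f32 -> f64` widening:
///
/// ```ignore
/// // dest: StridedViewMut<f64>, src: StridedView<f32>, same dims.
/// map_into(&mut dest, &src, |x| f64::from(x) * 0.5)?;
/// ```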
pub fn map_into<D: Copy + MaybeSendSync, A: Copy + MaybeSendSync, Op: ElementOp<A>>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<A, Op>,
    f: impl Fn(A) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(Op::apply(src[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<A>());
    let total = total_len(dst_dims);

    // Small tensor fast path: skip compute_order and compute_block_sizes
    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut A);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let sp = unsafe { src_send.as_const().offset(offsets[1]) };
                            unsafe {
                                inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f)
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let sp = unsafe { src_ptr.offset(offsets[1]) };
            unsafe { inner_loop_map1::<D, A, Op>(dp, strides[0], sp, strides[1], len, &f) };
            Ok(())
        },
    )
}

/// Binary element-wise operation: `dest[i] = f(a[i], b[i])`.
///
/// Source operands `a` and `b` may have different element types from each other
/// and from `dest`. The closure `f` handles per-element type conversion.
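///
/// # Example
///
/// A minimal sketch (marked `ignore`; the views are assumed to be constructed
/// elsewhere with matching dims). The mixed `f64`/`f32` inputs are reconciled
/// inside the closure:
///
/// ```ignore
/// // dest: StridedViewMut<f64>, a: StridedView<f64>, b: StridedView<f32>.
/// zip_map2_into(&mut dest, &a, &b, |x, y| x + f64::from(y))?;
/// ```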
pub fn zip_map2_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    f: impl Fn(A, B) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();

    let a_strides = a.strides();
    let b_strides = b.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>());
    let total = total_len(dst_dims);

    // Small tensor fast path: skip compute_order and compute_block_sizes
    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            unsafe {
                                inner_loop_map2::<D, A, B, OpA, OpB>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            unsafe {
                inner_loop_map2::<D, A, B, OpA, OpB>(
                    dp, strides[0], ap, strides[1], bp, strides[2], len, &f,
                )
            };
            Ok(())
        },
    )
}

/// Ternary element-wise operation: `dest[i] = f(a[i], b[i], c[i])`.
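///
/// # Example
///
/// A minimal sketch of a fused multiply-add, `dest = a * b + c` (marked `ignore`;
/// the three same-shape views are assumed to be constructed elsewhere):
///
/// ```ignore
/// zip_map3_into(&mut dest, &a, &b, &c, |x, y, z| x * y + z)?;
/// ```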
pub fn zip_map3_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    C: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    c: &StridedView<C, OpC>,
    f: impl Fn(A, B, C) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;
    ensure_same_shape(dest.dims(), c.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let c_ptr = c.ptr();

    let dst_dims = dest.dims();
    let dst_strides = dest.strides();

    if sequential_contiguous_layout(
        dst_dims,
        &[dst_strides, a.strides(), b.strides(), c.strides()],
    )
    .is_some()
    {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(OpA::apply(sa[i]), OpB::apply(sb[i]), OpC::apply(sc[i]));
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 4] = [dst_strides, a.strides(), b.strides(), c.strides()];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<C>());
    let total = total_len(dst_dims);

    // Small tensor fast path: skip compute_order and compute_block_sizes
    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);
            let c_send = SendPtr(c_ptr as *mut C);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
                            unsafe {
                                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
                                    len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            let cp = unsafe { c_ptr.offset(offsets[3]) };
            unsafe {
                inner_loop_map3::<D, A, B, C, OpA, OpB, OpC>(
                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], len, &f,
                )
            };
            Ok(())
        },
    )
}

/// Quaternary element-wise operation: `dest[i] = f(a[i], b[i], c[i], e[i])`.
pub fn zip_map4_into<
    D: Copy + MaybeSendSync,
    A: Copy + MaybeSendSync,
    B: Copy + MaybeSendSync,
    C: Copy + MaybeSendSync,
    E: Copy + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
    OpC: ElementOp<C>,
    OpE: ElementOp<E>,
>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
    c: &StridedView<C, OpC>,
    e: &StridedView<E, OpE>,
    f: impl Fn(A, B, C, E) -> D + MaybeSync,
) -> Result<()> {
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;
    ensure_same_shape(dest.dims(), c.dims())?;
    ensure_same_shape(dest.dims(), e.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let c_ptr = c.ptr();
    let e_ptr = e.ptr();

    let dst_dims = dest.dims();
    let dst_strides = dest.strides();

    if sequential_contiguous_layout(
        dst_dims,
        &[
            dst_strides,
            a.strides(),
            b.strides(),
            c.strides(),
            e.strides(),
        ],
    )
    .is_some()
    {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let sc = unsafe { std::slice::from_raw_parts(c_ptr, len) };
        let se = unsafe { std::slice::from_raw_parts(e_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = f(
                    OpA::apply(sa[i]),
                    OpB::apply(sb[i]),
                    OpC::apply(sc[i]),
                    OpE::apply(se[i]),
                );
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 5] = [
        dst_strides,
        a.strides(),
        b.strides(),
        c.strides(),
        e.strides(),
    ];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<C>())
        .max(std::mem::size_of::<E>());
    let total = total_len(dst_dims);

    // Small tensor fast path: skip compute_order and compute_block_sizes
    let (fused_dims, ordered_strides, plan) = if total <= SMALL_TENSOR_THRESHOLD {
        build_plan_fused_small(dst_dims, &strides_list)
    } else {
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size)
    };

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            use crate::threading::SendPtr;
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);
            let c_send = SendPtr(c_ptr as *mut C);
            let e_send = SendPtr(e_ptr as *mut E);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            let dp = unsafe { dst_send.as_ptr().offset(offsets[0]) };
                            let ap = unsafe { a_send.as_const().offset(offsets[1]) };
                            let bp = unsafe { b_send.as_const().offset(offsets[2]) };
                            let cp = unsafe { c_send.as_const().offset(offsets[3]) };
                            let ep = unsafe { e_send.as_const().offset(offsets[4]) };
                            unsafe {
                                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
                                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3],
                                    ep, strides[4], len, &f,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            let dp = unsafe { dst_ptr.offset(offsets[0]) };
            let ap = unsafe { a_ptr.offset(offsets[1]) };
            let bp = unsafe { b_ptr.offset(offsets[2]) };
            let cp = unsafe { c_ptr.offset(offsets[3]) };
            let ep = unsafe { e_ptr.offset(offsets[4]) };
            unsafe {
                inner_loop_map4::<D, A, B, C, E, OpA, OpB, OpC, OpE>(
                    dp, strides[0], ap, strides[1], bp, strides[2], cp, strides[3], ep, strides[4],
                    len, &f,
                )
            };
            Ok(())
        },
    )
}