strided_kernel/
ops_view.rs

//! High-level operations on dynamic-rank strided views.

use crate::kernel::{
    build_plan_fused, ensure_same_shape, for_each_inner_block_preordered, same_contiguous_layout,
    sequential_contiguous_layout, total_len,
};
use crate::map_view::{map_into, zip_map2_into};
use crate::maybe_sync::{MaybeSendSync, MaybeSync};
use crate::reduce_view::reduce;
use crate::simd;
use crate::view::{StridedView, StridedViewMut};
use crate::{Result, StridedError};
use num_traits::Zero;
use std::ops::{Add, Mul};
use strided_view::{ElementOp, ElementOpApply};

#[cfg(feature = "parallel")]
use crate::fuse::compute_costs;
#[cfg(feature = "parallel")]
use crate::threading::{
    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
};

// ============================================================================
// Stride-specialized inner loop helpers for ops_view
//
// When all inner strides are 1 (contiguous in the innermost dimension),
// we use slice-based iteration so LLVM can auto-vectorize effectively.
// This mirrors the inner_loop_map* helpers in map_view.rs.
// ============================================================================
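
// Layout example (illustrative): a row-major 3x4 matrix has strides [4, 1];
// each row is a contiguous run of 4 elements, so the slice-based path applies,
// roughly:
//
//     // innermost stride 1: iterate as a plain slice
//     for i in 0..len { dst[i] = dst[i] + Op::apply(src[i]); }
//
// A transposed view of the same buffer has strides [1, 4]; the innermost
// stride is then 4, so the pointer-stepping fallback advances pointers with
// `offset(stride)` between elements instead.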

/// Inner loop for add: `dst[i] += Op::apply(src[i])`.
#[inline(always)]
unsafe fn inner_loop_add<D: Copy + Add<S, Output = D>, S: Copy, Op: ElementOp<S>>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + Op::apply(src[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = *dp + Op::apply(*sp);
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

/// Inner loop for mul: `dst[i] *= Op::apply(src[i])`.
#[inline(always)]
unsafe fn inner_loop_mul<D: Copy + Mul<S, Output = D>, S: Copy, Op: ElementOp<S>>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] * Op::apply(src[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = *dp * Op::apply(*sp);
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

/// Inner loop for axpy: `dst[i] = alpha * Op::apply(src[i]) + dst[i]`.
#[inline(always)]
unsafe fn inner_loop_axpy<
    D: Copy + Add<D, Output = D>,
    S: Copy,
    A: Copy + Mul<S, Output = D>,
    Op: ElementOp<S>,
>(
    dp: *mut D,
    ds: isize,
    sp: *const S,
    ss: isize,
    len: usize,
    alpha: A,
) {
    if ds == 1 && ss == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let src = std::slice::from_raw_parts(sp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = alpha * Op::apply(src[i]) + dst[i];
            }
        });
    } else {
        let mut dp = dp;
        let mut sp = sp;
        for _ in 0..len {
            *dp = alpha * Op::apply(*sp) + *dp;
            dp = dp.offset(ds);
            sp = sp.offset(ss);
        }
    }
}

/// Inner loop for fma: `dst[i] += OpA::apply(a[i]) * OpB::apply(b[i])`.
#[inline(always)]
unsafe fn inner_loop_fma<
    D: Copy + Add<D, Output = D>,
    A: Copy + Mul<B, Output = D>,
    B: Copy,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    dp: *mut D,
    ds: isize,
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
) {
    if ds == 1 && a_s == 1 && b_s == 1 {
        let dst = std::slice::from_raw_parts_mut(dp, len);
        let sa = std::slice::from_raw_parts(ap, len);
        let sb = std::slice::from_raw_parts(bp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
    } else {
        let mut dp = dp;
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            *dp = *dp + OpA::apply(*ap) * OpB::apply(*bp);
            dp = dp.offset(ds);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
}

/// Inner loop for dot: `acc += OpA::apply(a[i]) * OpB::apply(b[i])`.
#[inline(always)]
unsafe fn inner_loop_dot<
    A: Copy + Mul<B, Output = R>,
    B: Copy,
    R: Copy + Add<R, Output = R>,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
>(
    ap: *const A,
    a_s: isize,
    bp: *const B,
    b_s: isize,
    len: usize,
    mut acc: R,
) -> R {
    if a_s == 1 && b_s == 1 {
        let sa = std::slice::from_raw_parts(ap, len);
        let sb = std::slice::from_raw_parts(bp, len);
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
    } else {
        let mut ap = ap;
        let mut bp = bp;
        for _ in 0..len {
            acc = acc + OpA::apply(*ap) * OpB::apply(*bp);
            ap = ap.offset(a_s);
            bp = bp.offset(b_s);
        }
    }
    acc
}

/// Copy elements from source to destination: `dest[i] = Op::apply(src[i])`.
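///
/// # Example (sketch)
///
/// A minimal sketch of intended usage. It assumes a `from_slice`-style
/// constructor for views over contiguous storage, which is not defined in
/// this module, so the snippet is illustrative rather than compiled.
///
/// ```ignore
/// let s = vec![1.0f32, 2.0, 3.0, 4.0];
/// let mut d = vec![0.0f32; 4];
/// // Hypothetical constructors; adjust to the crate's real view-building API.
/// let src = StridedView::from_slice(&s, &[4])?;
/// let mut dest = StridedViewMut::from_slice(&mut d, &[4])?;
/// copy_into(&mut dest, &src)?;
/// assert_eq!(d, s);
/// ```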
pub fn copy_into<T: Copy + MaybeSendSync, Op: ElementOp<T>>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        if Op::IS_IDENTITY {
            debug_assert!(
                {
                    let nbytes = len
                        .checked_mul(std::mem::size_of::<T>())
                        .expect("copy size must not overflow");
                    let dst_start = dst_ptr as usize;
                    let src_start = src_ptr as usize;
                    let dst_end = dst_start.saturating_add(nbytes);
                    let src_end = src_start.saturating_add(nbytes);
                    dst_end <= src_start || src_end <= dst_start
                },
                "overlapping src/dest is not supported"
            );
            unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
        } else {
            let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
            let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
            simd::dispatch_if_large(len, || {
                for i in 0..len {
                    dst[i] = Op::apply(src[i]);
                }
            });
        }
        return Ok(());
    }

    map_into(dest, src, |x| x)
}

/// Element-wise addition: `dest[i] += Op::apply(src[i])`.
///
/// The source may have a different element type from the destination.
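///
/// # Example (sketch)
///
/// A sketch assuming the same hypothetical `from_slice` constructors as in
/// [`copy_into`]; only the arithmetic shown is what this function computes.
///
/// ```ignore
/// let mut d = vec![1.0f64, 2.0, 3.0, 4.0];
/// let s = vec![10.0f64, 20.0, 30.0, 40.0];
/// let mut dest = StridedViewMut::from_slice(&mut d, &[2, 2])?; // hypothetical
/// let src = StridedView::from_slice(&s, &[2, 2])?;             // hypothetical
/// add(&mut dest, &src)?;
/// assert_eq!(d, [11.0, 22.0, 33.0, 44.0]);
/// ```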
pub fn add<
    D: Copy + Add<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_add::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_add::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}

/// Element-wise multiplication: `dest[i] *= Op::apply(src[i])`.
///
/// The source may have a different element type from the destination.
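///
/// # Example (sketch)
///
/// Same hypothetical `from_slice` setup as above; illustrative only.
///
/// ```ignore
/// let mut d = vec![1.0f64, 2.0, 3.0];
/// let s = vec![2.0f64, 2.0, 2.0];
/// let mut dest = StridedViewMut::from_slice(&mut d, &[3])?; // hypothetical
/// let src = StridedView::from_slice(&s, &[3])?;             // hypothetical
/// mul(&mut dest, &src)?;
/// assert_eq!(d, [2.0, 4.0, 6.0]);
/// ```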
pub fn mul<
    D: Copy + Mul<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] * Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_mul::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_mul::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}

/// AXPY: `dest[i] = alpha * Op::apply(src[i]) + dest[i]`.
///
/// Alpha, source, and destination may have different element types.
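///
/// # Example (sketch)
///
/// The BLAS-style update `y = alpha * x + y`, sketched with the hypothetical
/// `from_slice` constructors used in the examples above.
///
/// ```ignore
/// let mut y = vec![1.0f64, 1.0, 1.0];
/// let x = vec![1.0f64, 2.0, 3.0];
/// let mut dest = StridedViewMut::from_slice(&mut y, &[3])?; // hypothetical
/// let src = StridedView::from_slice(&x, &[3])?;             // hypothetical
/// axpy(&mut dest, &src, 2.0)?;
/// assert_eq!(y, [3.0, 5.0, 7.0]);
/// ```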
pub fn axpy<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    alpha: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = alpha * Op::apply(src[i]) + dst[i];
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_axpy::<D, S, A, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                    alpha,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_axpy::<D, S, A, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    alpha,
                )
            };
            Ok(())
        },
    )
}

/// Fused multiply-add: `dest[i] += OpA::apply(a[i]) * OpB::apply(b[i])`.
///
/// Operands may have different element types. Element operations are applied lazily.
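///
/// # Example (sketch)
///
/// Accumulates an element-wise product into `dest`; hypothetical `from_slice`
/// constructors as above.
///
/// ```ignore
/// let mut d = vec![1.0f64; 3];
/// let a = vec![1.0f64, 2.0, 3.0];
/// let b = vec![4.0f64, 5.0, 6.0];
/// let mut dest = StridedViewMut::from_slice(&mut d, &[3])?; // hypothetical
/// let va = StridedView::from_slice(&a, &[3])?;              // hypothetical
/// let vb = StridedView::from_slice(&b, &[3])?;              // hypothetical
/// fma(&mut dest, &va, &vb)?;
/// assert_eq!(d, [5.0, 11.0, 19.0]); // 1 + a[i] * b[i]
/// ```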
pub fn fma<D, A, B, OpA, OpB>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
) -> Result<()>
where
    A: Copy + Mul<B, Output = D> + MaybeSendSync,
    B: Copy + MaybeSendSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let a_strides = a.strides();
    let b_strides = b.strides();

    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(());
    }

    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_fma::<D, A, B, OpA, OpB>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    a_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    b_send.as_const().offset(offsets[2]),
                                    strides[2],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_fma::<D, A, B, OpA, OpB>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    a_ptr.offset(offsets[1]),
                    strides[1],
                    b_ptr.offset(offsets[2]),
                    strides[2],
                    len,
                )
            };
            Ok(())
        },
    )
}

#[cfg(feature = "parallel")]
fn parallel_simd_sum<T: Copy + Zero + Add<Output = T> + simd::MaybeSimdOps + Send + Sync>(
    src: &[T],
) -> Option<T> {
    use rayon::prelude::*;
    // Probe with an empty slice: `try_simd_sum` returns `None` when `T` has
    // no SIMD support, in which case the caller falls back to a scalar reduce.
    if T::try_simd_sum(&[]).is_none() {
        return None;
    }
    let nthreads = rayon::current_num_threads();
    let chunk_size = (src.len() + nthreads - 1) / nthreads;
    let result = src
        .par_chunks(chunk_size)
        .map(|chunk| T::try_simd_sum(chunk).unwrap())
        .reduce(|| T::zero(), |a, b| a + b);
    Some(result)
}

/// Sum all elements: `sum(src)`.
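///
/// # Example (sketch)
///
/// Illustrative only; assumes a hypothetical `from_slice` constructor.
///
/// ```ignore
/// let v = vec![1.0f64, 2.0, 3.0, 4.0];
/// let src = StridedView::from_slice(&v, &[2, 2])?; // hypothetical
/// assert_eq!(sum(&src)?, 10.0);
/// ```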
pub fn sum<
    T: Copy + Zero + Add<Output = T> + MaybeSendSync + simd::MaybeSimdOps,
    Op: ElementOp<T>,
>(
    src: &StridedView<T, Op>,
) -> Result<T> {
    // SIMD fast path: contiguous Identity view with SIMD support
    if Op::IS_IDENTITY {
        if same_contiguous_layout(src.dims(), &[src.strides()]).is_some() {
            let len = total_len(src.dims());
            let src_slice = unsafe { std::slice::from_raw_parts(src.ptr(), len) };

            #[cfg(feature = "parallel")]
            if len > MINTHREADLENGTH {
                if let Some(result) = parallel_simd_sum(src_slice) {
                    return Ok(result);
                }
            }

            if let Some(result) = T::try_simd_sum(src_slice) {
                return Ok(result);
            }
        }
    }
    reduce(src, |x| x, |a, b| a + b, T::zero())
}

/// Dot product: `sum(OpA::apply(a[i]) * OpB::apply(b[i]))`.
///
/// Operands may have different element types; the result type `R` is the
/// output type of `A * B`. The SIMD fast path is taken only when `A`, `B`,
/// and `R` are the same type and both element ops are the identity.
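///
/// # Example (sketch)
///
/// Illustrative only; assumes a hypothetical `from_slice` constructor.
///
/// ```ignore
/// let a = vec![1.0f64, 2.0, 3.0];
/// let b = vec![4.0f64, 5.0, 6.0];
/// let va = StridedView::from_slice(&a, &[3])?; // hypothetical
/// let vb = StridedView::from_slice(&b, &[3])?; // hypothetical
/// let r: f64 = dot(&va, &vb)?;
/// assert_eq!(r, 32.0); // 1*4 + 2*5 + 3*6
/// ```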
pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R>
where
    A: Copy + Mul<B, Output = R> + MaybeSendSync + 'static,
    B: Copy + MaybeSendSync + 'static,
    R: Copy + Zero + Add<Output = R> + MaybeSendSync + simd::MaybeSimdOps + 'static,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    ensure_same_shape(a.dims(), b.dims())?;

    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let a_strides = a.strides();
    let b_strides = b.strides();
    let a_dims = a.dims();

    if same_contiguous_layout(a_dims, &[a_strides, b_strides]).is_some() {
        let len = total_len(a_dims);

        // SIMD fast path: both contiguous, both Identity ops, same type
        if OpA::IS_IDENTITY
            && OpB::IS_IDENTITY
            && std::any::TypeId::of::<A>() == std::any::TypeId::of::<R>()
            && std::any::TypeId::of::<B>() == std::any::TypeId::of::<R>()
        {
            let sa = unsafe { std::slice::from_raw_parts(a_ptr as *const R, len) };
            let sb = unsafe { std::slice::from_raw_parts(b_ptr as *const R, len) };
            if let Some(result) = R::try_simd_dot(sa, sb) {
                return Ok(result);
            }
        }

        // Generic contiguous fast path
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let mut acc = R::zero();
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(acc);
    }

    let strides_list: [&[isize]; 2] = [a_strides, b_strides];
    let elem_size = std::mem::size_of::<A>()
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<R>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(a_dims, &strides_list, None, elem_size);

    let mut acc = R::zero();
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            acc = unsafe {
                inner_loop_dot::<A, B, R, OpA, OpB>(
                    a_ptr.offset(offsets[0]),
                    strides[0],
                    b_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    acc,
                )
            };
            Ok(())
        },
    )?;

    Ok(acc)
}

/// Symmetrize a square matrix: `dest = (src + src^T) / 2`.
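///
/// # Example (sketch)
///
/// With hypothetical `from_slice` constructors over row-major storage:
///
/// ```ignore
/// let s = vec![0.0f64, 2.0, 4.0, 6.0]; // [[0, 2], [4, 6]]
/// let mut d = vec![0.0f64; 4];
/// let src = StridedView::from_slice(&s, &[2, 2])?;             // hypothetical
/// let mut dest = StridedViewMut::from_slice(&mut d, &[2, 2])?; // hypothetical
/// symmetrize_into(&mut dest, &src)?;
/// assert_eq!(d, [0.0, 3.0, 3.0, 6.0]); // [[0, 3], [3, 6]]
/// ```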
pub fn symmetrize_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
where
    T: Copy
        + Add<Output = T>
        + Mul<Output = T>
        + num_traits::FromPrimitive
        + std::ops::Div<Output = T>
        + MaybeSendSync,
{
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    let rows = src.dims()[0];
    let cols = src.dims()[1];
    if rows != cols {
        return Err(StridedError::NonSquare { rows, cols });
    }

    let src_t = src.permute(&[1, 0])?;
    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;

    zip_map2_into(dest, src, &src_t, |a, b| (a + b) * half)
}

/// Conjugate-symmetrize a square matrix: `dest = (src + conj(src^T)) / 2`.
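///
/// # Example (sketch)
///
/// For a complex matrix this produces the Hermitian part. The sketch assumes
/// a complex element type with conjugation support (e.g. `num_complex::Complex64`,
/// if the crate integrates with it) and the hypothetical `from_slice`
/// constructors used in the examples above:
///
/// ```ignore
/// // src  = [[1, 2+2i], [0, 3]]
/// // dest = (src + src^H) / 2 = [[1, 1+i], [1-i, 3]]
/// symmetrize_conj_into(&mut dest, &src)?;
/// ```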
pub fn symmetrize_conj_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
where
    T: Copy
        + ElementOpApply
        + Add<Output = T>
        + Mul<Output = T>
        + num_traits::FromPrimitive
        + std::ops::Div<Output = T>
        + MaybeSendSync,
{
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    let rows = src.dims()[0];
    let cols = src.dims()[1];
    if rows != cols {
        return Err(StridedError::NonSquare { rows, cols });
    }

    // adjoint = conj + transpose
    let src_adj = src.adjoint_2d()?;
    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;

    zip_map2_into(dest, src, &src_adj, |a, b| (a + b) * half)
}

/// Copy with scaling: `dest[i] = scale * Op::apply(src[i])`.
///
/// Scale, source, and destination may have different element types.
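///
/// # Example (sketch)
///
/// Illustrative only; hypothetical `from_slice` constructors as above.
///
/// ```ignore
/// let s = vec![2.0f64, 4.0, 6.0];
/// let mut d = vec![0.0f64; 3];
/// let src = StridedView::from_slice(&s, &[3])?;             // hypothetical
/// let mut dest = StridedViewMut::from_slice(&mut d, &[3])?; // hypothetical
/// copy_scale(&mut dest, &src, 0.5)?;
/// assert_eq!(d, [1.0, 2.0, 3.0]);
/// ```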
pub fn copy_scale<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    scale: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    map_into(dest, src, |x| scale * x)
}

/// Copy with complex conjugation: `dest[i] = conj(src[i])`.
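///
/// # Example (sketch)
///
/// Assuming a complex element type implementing [`ElementOpApply`] (e.g.
/// `num_complex::Complex64`, if supported by the crate):
///
/// ```ignore
/// // src = [1+2i, 3-4i]  =>  dest = [1-2i, 3+4i]
/// copy_conj(&mut dest, &src)?;
/// ```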
pub fn copy_conj<T: Copy + ElementOpApply + MaybeSendSync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let src_conj = src.conj();
    copy_into(dest, &src_conj)
}

/// Copy with transpose and scaling: `dest[j,i] = scale * src[i,j]`.
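///
/// # Example (sketch)
///
/// With hypothetical `from_slice` constructors over row-major storage:
///
/// ```ignore
/// let s = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3: [[1, 2, 3], [4, 5, 6]]
/// let mut d = vec![0.0f64; 6];
/// let src = StridedView::from_slice(&s, &[2, 3])?;             // hypothetical
/// let mut dest = StridedViewMut::from_slice(&mut d, &[3, 2])?; // hypothetical
/// copy_transpose_scale_into(&mut dest, &src, 10.0)?;
/// assert_eq!(d, [10.0, 40.0, 20.0, 50.0, 30.0, 60.0]); // 3x2, scaled
/// ```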
pub fn copy_transpose_scale_into<T>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
    scale: T,
) -> Result<()>
where
    T: Copy + ElementOpApply + Mul<Output = T> + MaybeSendSync,
{
    // Report the rank of whichever view is not 2-dimensional.
    if src.ndim() != 2 {
        return Err(StridedError::RankMismatch(src.ndim(), 2));
    }
    if dest.ndim() != 2 {
        return Err(StridedError::RankMismatch(dest.ndim(), 2));
    }
    let src_t = src.transpose_2d()?;
    map_into(dest, &src_t, |x| scale * x)
}