strided_kernel/
ops_view.rs

1//! High-level operations on dynamic-rank strided views.
2
3use crate::kernel::{
4    build_plan_fused, ensure_same_shape, for_each_inner_block_preordered, same_contiguous_layout,
5    sequential_contiguous_layout, total_len,
6};
7use crate::map_view::{map_into, zip_map2_into};
8use crate::maybe_sync::{MaybeSendSync, MaybeSync};
9use crate::reduce_view::reduce;
10use crate::simd;
11use crate::view::{StridedView, StridedViewMut};
12use crate::{Result, StridedError};
13use num_traits::Zero;
14use std::ops::{Add, Mul};
15use strided_view::{ElementOp, ElementOpApply};
16
17#[cfg(feature = "parallel")]
18use crate::fuse::compute_costs;
19#[cfg(feature = "parallel")]
20use crate::threading::{
21    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
22};
23
24// ============================================================================
25// Stride-specialized inner loop helpers for ops_view
26//
27// When all inner strides are 1 (contiguous in the innermost dimension),
28// we use slice-based iteration so LLVM can auto-vectorize effectively.
29// This mirrors the inner_loop_map* helpers in map_view.rs.
30// ============================================================================
31
32/// Inner loop for add: `dst[i] += Op::apply(src[i])`.
33#[inline(always)]
34unsafe fn inner_loop_add<D: Copy + Add<S, Output = D>, S: Copy, Op: ElementOp<S>>(
35    dp: *mut D,
36    ds: isize,
37    sp: *const S,
38    ss: isize,
39    len: usize,
40) {
41    if ds == 1 && ss == 1 {
42        let dst = std::slice::from_raw_parts_mut(dp, len);
43        let src = std::slice::from_raw_parts(sp, len);
44        simd::dispatch_if_large(len, || {
45            for i in 0..len {
46                dst[i] = dst[i] + Op::apply(src[i]);
47            }
48        });
49    } else {
50        let mut dp = dp;
51        let mut sp = sp;
52        for _ in 0..len {
53            *dp = *dp + Op::apply(*sp);
54            dp = dp.offset(ds);
55            sp = sp.offset(ss);
56        }
57    }
58}
59
60/// Inner loop for mul: `dst[i] *= Op::apply(src[i])`.
61#[inline(always)]
62unsafe fn inner_loop_mul<D: Copy + Mul<S, Output = D>, S: Copy, Op: ElementOp<S>>(
63    dp: *mut D,
64    ds: isize,
65    sp: *const S,
66    ss: isize,
67    len: usize,
68) {
69    if ds == 1 && ss == 1 {
70        let dst = std::slice::from_raw_parts_mut(dp, len);
71        let src = std::slice::from_raw_parts(sp, len);
72        simd::dispatch_if_large(len, || {
73            for i in 0..len {
74                dst[i] = dst[i] * Op::apply(src[i]);
75            }
76        });
77    } else {
78        let mut dp = dp;
79        let mut sp = sp;
80        for _ in 0..len {
81            *dp = *dp * Op::apply(*sp);
82            dp = dp.offset(ds);
83            sp = sp.offset(ss);
84        }
85    }
86}
87
88/// Inner loop for axpy: `dst[i] = alpha * Op::apply(src[i]) + dst[i]`.
89#[inline(always)]
90unsafe fn inner_loop_axpy<
91    D: Copy + Add<D, Output = D>,
92    S: Copy,
93    A: Copy + Mul<S, Output = D>,
94    Op: ElementOp<S>,
95>(
96    dp: *mut D,
97    ds: isize,
98    sp: *const S,
99    ss: isize,
100    len: usize,
101    alpha: A,
102) {
103    if ds == 1 && ss == 1 {
104        let dst = std::slice::from_raw_parts_mut(dp, len);
105        let src = std::slice::from_raw_parts(sp, len);
106        simd::dispatch_if_large(len, || {
107            for i in 0..len {
108                dst[i] = alpha * Op::apply(src[i]) + dst[i];
109            }
110        });
111    } else {
112        let mut dp = dp;
113        let mut sp = sp;
114        for _ in 0..len {
115            *dp = alpha * Op::apply(*sp) + *dp;
116            dp = dp.offset(ds);
117            sp = sp.offset(ss);
118        }
119    }
120}
121
122/// Inner loop for fma: `dst[i] += OpA::apply(a[i]) * OpB::apply(b[i])`.
123#[inline(always)]
124unsafe fn inner_loop_fma<
125    D: Copy + Add<D, Output = D>,
126    A: Copy + Mul<B, Output = D>,
127    B: Copy,
128    OpA: ElementOp<A>,
129    OpB: ElementOp<B>,
130>(
131    dp: *mut D,
132    ds: isize,
133    ap: *const A,
134    a_s: isize,
135    bp: *const B,
136    b_s: isize,
137    len: usize,
138) {
139    if ds == 1 && a_s == 1 && b_s == 1 {
140        let dst = std::slice::from_raw_parts_mut(dp, len);
141        let sa = std::slice::from_raw_parts(ap, len);
142        let sb = std::slice::from_raw_parts(bp, len);
143        simd::dispatch_if_large(len, || {
144            for i in 0..len {
145                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
146            }
147        });
148    } else {
149        let mut dp = dp;
150        let mut ap = ap;
151        let mut bp = bp;
152        for _ in 0..len {
153            *dp = *dp + OpA::apply(*ap) * OpB::apply(*bp);
154            dp = dp.offset(ds);
155            ap = ap.offset(a_s);
156            bp = bp.offset(b_s);
157        }
158    }
159}
160
161/// Inner loop for dot: `acc += OpA::apply(a[i]) * OpB::apply(b[i])`.
162#[inline(always)]
163unsafe fn inner_loop_dot<
164    A: Copy + Mul<B, Output = R>,
165    B: Copy,
166    R: Copy + Add<R, Output = R>,
167    OpA: ElementOp<A>,
168    OpB: ElementOp<B>,
169>(
170    ap: *const A,
171    a_s: isize,
172    bp: *const B,
173    b_s: isize,
174    len: usize,
175    mut acc: R,
176) -> R {
177    if a_s == 1 && b_s == 1 {
178        let sa = std::slice::from_raw_parts(ap, len);
179        let sb = std::slice::from_raw_parts(bp, len);
180        simd::dispatch_if_large(len, || {
181            for i in 0..len {
182                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
183            }
184        });
185    } else {
186        let mut ap = ap;
187        let mut bp = bp;
188        for _ in 0..len {
189            acc = acc + OpA::apply(*ap) * OpB::apply(*bp);
190            ap = ap.offset(a_s);
191            bp = bp.offset(b_s);
192        }
193    }
194    acc
195}
196
/// Copy elements from source to destination: `dest[i] = src[i]`.
///
/// The source view's lazy element op `Op` is applied during the copy.
/// Shapes must match exactly; returns the error from `ensure_same_shape`
/// otherwise. Non-contiguous layouts fall back to the generic `map_into`
/// kernel.
pub fn copy_into<T: Copy + MaybeSendSync, Op: ElementOp<T>>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    // Fast path: both views traverse memory as a single dense sequential run
    // (as determined by `sequential_contiguous_layout`).
    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        if Op::IS_IDENTITY {
            // Identity op on contiguous memory reduces to a memcpy, which is
            // only sound when the ranges are disjoint — checked in debug
            // builds only; release builds trust the caller.
            debug_assert!(
                {
                    let nbytes = len
                        .checked_mul(std::mem::size_of::<T>())
                        .expect("copy size must not overflow");
                    let dst_start = dst_ptr as usize;
                    let src_start = src_ptr as usize;
                    let dst_end = dst_start.saturating_add(nbytes);
                    let src_end = src_start.saturating_add(nbytes);
                    // Disjoint iff one range ends before the other begins.
                    dst_end <= src_start || src_end <= dst_start
                },
                "overlapping src/dest is not supported"
            );
            unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
        } else {
            // Non-identity op: element-wise loop over contiguous slices so
            // LLVM can auto-vectorize the transform.
            let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
            let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
            simd::dispatch_if_large(len, || {
                for i in 0..len {
                    dst[i] = Op::apply(src[i]);
                }
            });
        }
        return Ok(());
    }

    // General strided case: the identity closure is used because `map_into`
    // already applies the view's lazy `Op`.
    map_into(dest, src, |x| x)
}
241
/// Copy elements from `src` to `dst`, optimized for col-major destination.
///
/// Delegates to `strided_perm::copy_into_col_major` for the actual work.
pub fn copy_into_col_major<T: Copy + MaybeSendSync>(
    dst: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    // Thin wrapper: shape validation and the blocked copy both live in
    // the strided_perm implementation.
    strided_perm::copy_into_col_major(dst, src)
}
251
/// Element-wise addition: `dest[i] += src[i]`.
///
/// Source may have a different element type from destination.
/// The source view's lazy element op `Op` is applied before the add.
/// Shapes must match exactly. With the `parallel` feature, large inputs
/// are split across rayon threads.
pub fn add<
    D: Copy + Add<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    // Fast path: both views form one dense sequential run, so a plain
    // slice loop suffices and LLVM can auto-vectorize it.
    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    // General strided case: fuse dimensions and build a blocked traversal
    // plan; `Some(0)` makes the destination (index 0) drive the loop order.
    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    // Cost model uses the widest element involved.
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        // Only thread when the problem is large enough to amortize overhead.
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            // SendPtr lets the raw pointers cross thread boundaries; each
            // thread works on disjoint offset ranges computed by the splitter.
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    // Each thread walks its assigned blocks and applies the
                    // stride-specialized inner loop per innermost run.
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_add::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    // Single-threaded blocked traversal over the fused plan.
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_add::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}
353
/// Element-wise multiplication: `dest[i] *= src[i]`.
///
/// Source may have a different element type from destination.
/// The source view's lazy element op `Op` is applied before the multiply.
/// Shapes must match exactly. With the `parallel` feature, large inputs
/// are split across rayon threads.
pub fn mul<
    D: Copy + Mul<S, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
) -> Result<()> {
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    // Fast path: both views form one dense sequential run — plain slice
    // iteration, auto-vectorizable.
    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] * Op::apply(src[i]);
            }
        });
        return Ok(());
    }

    // General strided case: fuse dimensions and build a blocked traversal
    // plan; `Some(0)` makes the destination layout drive the loop order.
    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    // Cost model uses the widest element involved.
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        // Only thread when the problem is large enough to amortize overhead.
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            // SendPtr lets raw pointers cross thread boundaries; threads
            // receive disjoint offset ranges.
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_mul::<D, S, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    // Single-threaded blocked traversal over the fused plan.
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_mul::<D, S, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                )
            };
            Ok(())
        },
    )
}
455
/// AXPY: `dest[i] = alpha * src[i] + dest[i]`.
///
/// Alpha, source, and destination may have different element types.
/// The source view's lazy element op `Op` is applied before scaling.
/// Shapes must match exactly. With the `parallel` feature, large inputs
/// are split across rayon threads.
pub fn axpy<D, S, A, Op>(
    dest: &mut StridedViewMut<D>,
    src: &StridedView<S, Op>,
    alpha: A,
) -> Result<()>
where
    A: Copy + Mul<S, Output = D> + MaybeSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    S: Copy + MaybeSendSync,
    Op: ElementOp<S>,
{
    ensure_same_shape(dest.dims(), src.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    // Fast path: both views form one dense sequential run — plain slice
    // iteration, auto-vectorizable.
    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = alpha * Op::apply(src[i]) + dst[i];
            }
        });
        return Ok(());
    }

    // General strided case: fuse dimensions and build a blocked traversal
    // plan; `Some(0)` makes the destination layout drive the loop order.
    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
    // Cost model uses the widest element involved.
    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        // Only thread when the problem is large enough to amortize overhead.
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            // SendPtr lets raw pointers cross thread boundaries; threads
            // receive disjoint offset ranges. `alpha` is captured by copy.
            let dst_send = SendPtr(dst_ptr);
            let src_send = SendPtr(src_ptr as *mut S);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_axpy::<D, S, A, Op>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    src_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    len,
                                    alpha,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    // Single-threaded blocked traversal over the fused plan.
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_axpy::<D, S, A, Op>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    src_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    alpha,
                )
            };
            Ok(())
        },
    )
}
562
/// Fused multiply-add: `dest[i] += OpA::apply(a[i]) * OpB::apply(b[i])`.
///
/// Operands may have different element types. Element operations are applied lazily.
/// All three shapes must match exactly. With the `parallel` feature, large
/// inputs are split across rayon threads.
pub fn fma<D, A, B, OpA, OpB>(
    dest: &mut StridedViewMut<D>,
    a: &StridedView<A, OpA>,
    b: &StridedView<B, OpB>,
) -> Result<()>
where
    A: Copy + Mul<B, Output = D> + MaybeSendSync,
    B: Copy + MaybeSendSync,
    D: Copy + Add<D, Output = D> + MaybeSendSync,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    // Both operands must match the destination's shape.
    ensure_same_shape(dest.dims(), a.dims())?;
    ensure_same_shape(dest.dims(), b.dims())?;

    let dst_ptr = dest.as_mut_ptr();
    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let dst_dims = dest.dims();
    let dst_strides = dest.strides();
    let a_strides = a.strides();
    let b_strides = b.strides();

    // Fast path: all three views form one dense sequential run — plain
    // slice iteration, auto-vectorizable.
    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
        let len = total_len(dst_dims);
        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(());
    }

    // General strided case: fuse dimensions and build a blocked traversal
    // plan; `Some(0)` makes the destination layout drive the loop order.
    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
    // Cost model uses the widest element involved.
    let elem_size = std::mem::size_of::<D>()
        .max(std::mem::size_of::<A>())
        .max(std::mem::size_of::<B>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);

    #[cfg(feature = "parallel")]
    {
        // Only thread when the problem is large enough to amortize overhead.
        let total: usize = fused_dims.iter().product();
        if total > MINTHREADLENGTH {
            // SendPtr lets raw pointers cross thread boundaries; threads
            // receive disjoint offset ranges.
            let dst_send = SendPtr(dst_ptr);
            let a_send = SendPtr(a_ptr as *mut A);
            let b_send = SendPtr(b_ptr as *mut B);

            let costs = compute_costs(&ordered_strides);
            let initial_offsets = vec![0isize; strides_list.len()];
            let nthreads = rayon::current_num_threads();

            return mapreduce_threaded(
                &fused_dims,
                &plan.block,
                &ordered_strides,
                &initial_offsets,
                &costs,
                nthreads,
                0,
                1,
                &|dims, blocks, strides_list, offsets| {
                    for_each_inner_block_with_offsets(
                        dims,
                        blocks,
                        strides_list,
                        offsets,
                        |offsets, len, strides| {
                            unsafe {
                                inner_loop_fma::<D, A, B, OpA, OpB>(
                                    dst_send.as_ptr().offset(offsets[0]),
                                    strides[0],
                                    a_send.as_const().offset(offsets[1]),
                                    strides[1],
                                    b_send.as_const().offset(offsets[2]),
                                    strides[2],
                                    len,
                                )
                            };
                            Ok(())
                        },
                    )
                },
            );
        }
    }

    // Single-threaded blocked traversal over the fused plan.
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            unsafe {
                inner_loop_fma::<D, A, B, OpA, OpB>(
                    dst_ptr.offset(offsets[0]),
                    strides[0],
                    a_ptr.offset(offsets[1]),
                    strides[1],
                    b_ptr.offset(offsets[2]),
                    strides[2],
                    len,
                )
            };
            Ok(())
        },
    )
}
679
#[cfg(feature = "parallel")]
/// Sum a contiguous slice in parallel, using the type's SIMD kernel per chunk.
///
/// Returns `None` when `T` has no SIMD support, letting the caller fall back
/// to the scalar reduction. Returns `Some(T::zero())` for an empty slice.
fn parallel_simd_sum<T: Copy + Zero + Add<Output = T> + simd::MaybeSimdOps + Send + Sync>(
    src: &[T],
) -> Option<T> {
    use rayon::prelude::*;
    // Probe SIMD support with an empty slice; cheap capability check.
    if T::try_simd_sum(&[]).is_none() {
        return None;
    }
    // Guard the empty input: the ceiling division below would yield a
    // chunk size of 0, and `par_chunks(0)` panics.
    if src.is_empty() {
        return Some(T::zero());
    }
    let nthreads = rayon::current_num_threads();
    // Ceiling division so every element lands in exactly one chunk.
    let chunk_size = (src.len() + nthreads - 1) / nthreads;
    let result = src
        .par_chunks(chunk_size)
        .map(|chunk| T::try_simd_sum(chunk).expect("SIMD support verified by probe above"))
        .reduce(|| T::zero(), |a, b| a + b);
    Some(result)
}
697
/// Sum all elements: `sum(src)`.
///
/// Takes a SIMD (and, with the `parallel` feature, multi-threaded) fast path
/// when the view is contiguous with an identity element op; otherwise falls
/// back to the generic `reduce` kernel. Note: summation order differs between
/// paths, so floating-point results may differ in the last bits.
pub fn sum<
    T: Copy + Zero + Add<Output = T> + MaybeSendSync + simd::MaybeSimdOps,
    Op: ElementOp<T>,
>(
    src: &StridedView<T, Op>,
) -> Result<T> {
    // SIMD fast path: contiguous Identity view with SIMD support
    if Op::IS_IDENTITY {
        if same_contiguous_layout(src.dims(), &[src.strides()]).is_some() {
            let len = total_len(src.dims());
            let src_slice = unsafe { std::slice::from_raw_parts(src.ptr(), len) };

            // Threaded SIMD sum for large inputs; falls through to the
            // single-threaded SIMD path if T has no SIMD kernel.
            #[cfg(feature = "parallel")]
            if len > MINTHREADLENGTH {
                if let Some(result) = parallel_simd_sum(src_slice) {
                    return Ok(result);
                }
            }

            // Single-threaded SIMD sum; `None` means no SIMD support for T.
            if let Some(result) = T::try_simd_sum(src_slice) {
                return Ok(result);
            }
        }
    }
    // Generic strided reduction (scalar, handles any layout and op).
    reduce(src, |x| x, |a, b| a + b, T::zero())
}
725
/// Dot product: `sum(OpA::apply(a[i]) * OpB::apply(b[i]))`.
///
/// Operands may have different element types. Result type `R` must be `A * B`.
/// SIMD fast path fires only when `A == B == R` (same type) and both Identity ops.
/// Note: accumulation order differs between paths, so floating-point results
/// may differ in the last bits.
pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R>
where
    A: Copy + Mul<B, Output = R> + MaybeSendSync + 'static,
    B: Copy + MaybeSendSync + 'static,
    R: Copy + Zero + Add<Output = R> + MaybeSendSync + simd::MaybeSimdOps + 'static,
    OpA: ElementOp<A>,
    OpB: ElementOp<B>,
{
    ensure_same_shape(a.dims(), b.dims())?;

    let a_ptr = a.ptr();
    let b_ptr = b.ptr();
    let a_strides = a.strides();
    let b_strides = b.strides();
    let a_dims = a.dims();

    // Contiguous fast paths: both operands laid out as one dense run.
    if same_contiguous_layout(a_dims, &[a_strides, b_strides]).is_some() {
        let len = total_len(a_dims);

        // SIMD fast path: both contiguous, both Identity ops, same type
        if OpA::IS_IDENTITY
            && OpB::IS_IDENTITY
            && std::any::TypeId::of::<A>() == std::any::TypeId::of::<R>()
            && std::any::TypeId::of::<B>() == std::any::TypeId::of::<R>()
        {
            // The TypeId equalities above make the pointer casts to R sound.
            let sa = unsafe { std::slice::from_raw_parts(a_ptr as *const R, len) };
            let sb = unsafe { std::slice::from_raw_parts(b_ptr as *const R, len) };
            if let Some(result) = R::try_simd_dot(sa, sb) {
                return Ok(result);
            }
        }

        // Generic contiguous fast path
        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
        let mut acc = R::zero();
        simd::dispatch_if_large(len, || {
            for i in 0..len {
                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
            }
        });
        return Ok(acc);
    }

    // General strided case: fuse dimensions and build a blocked traversal
    // plan; `None` means no operand's layout is preferred for ordering.
    let strides_list: [&[isize]; 2] = [a_strides, b_strides];
    // Cost model uses the widest element involved.
    let elem_size = std::mem::size_of::<A>()
        .max(std::mem::size_of::<B>())
        .max(std::mem::size_of::<R>());

    let (fused_dims, ordered_strides, plan) =
        build_plan_fused(a_dims, &strides_list, None, elem_size);

    // Sequential accumulation across inner blocks (no threading here).
    let mut acc = R::zero();
    let initial_offsets = vec![0isize; ordered_strides.len()];
    for_each_inner_block_preordered(
        &fused_dims,
        &plan.block,
        &ordered_strides,
        &initial_offsets,
        |offsets, len, strides| {
            // Thread the accumulator through each stride-specialized inner loop.
            acc = unsafe {
                inner_loop_dot::<A, B, R, OpA, OpB>(
                    a_ptr.offset(offsets[0]),
                    strides[0],
                    b_ptr.offset(offsets[1]),
                    strides[1],
                    len,
                    acc,
                )
            };
            Ok(())
        },
    )?;

    Ok(acc)
}
806
807/// Symmetrize a square matrix: `dest = (src + src^T) / 2`.
808pub fn symmetrize_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
809where
810    T: Copy
811        + Add<Output = T>
812        + Mul<Output = T>
813        + num_traits::FromPrimitive
814        + std::ops::Div<Output = T>
815        + MaybeSendSync,
816{
817    if src.ndim() != 2 {
818        return Err(StridedError::RankMismatch(src.ndim(), 2));
819    }
820    let rows = src.dims()[0];
821    let cols = src.dims()[1];
822    if rows != cols {
823        return Err(StridedError::NonSquare { rows, cols });
824    }
825
826    let src_t = src.permute(&[1, 0])?;
827    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;
828
829    zip_map2_into(dest, src, &src_t, |a, b| (a + b) * half)
830}
831
832/// Conjugate-symmetrize a square matrix: `dest = (src + conj(src^T)) / 2`.
833pub fn symmetrize_conj_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
834where
835    T: Copy
836        + ElementOpApply
837        + Add<Output = T>
838        + Mul<Output = T>
839        + num_traits::FromPrimitive
840        + std::ops::Div<Output = T>
841        + MaybeSendSync,
842{
843    if src.ndim() != 2 {
844        return Err(StridedError::RankMismatch(src.ndim(), 2));
845    }
846    let rows = src.dims()[0];
847    let cols = src.dims()[1];
848    if rows != cols {
849        return Err(StridedError::NonSquare { rows, cols });
850    }
851
852    // adjoint = conj + transpose
853    let src_adj = src.adjoint_2d()?;
854    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;
855
856    zip_map2_into(dest, src, &src_adj, |a, b| (a + b) * half)
857}
858
859/// Copy with scaling: `dest[i] = scale * src[i]`.
860///
861/// Scale, source, and destination may have different element types.
862pub fn copy_scale<D, S, A, Op>(
863    dest: &mut StridedViewMut<D>,
864    src: &StridedView<S, Op>,
865    scale: A,
866) -> Result<()>
867where
868    A: Copy + Mul<S, Output = D> + MaybeSync,
869    D: Copy + MaybeSendSync,
870    S: Copy + MaybeSendSync,
871    Op: ElementOp<S>,
872{
873    map_into(dest, src, |x| scale * x)
874}
875
876/// Copy with complex conjugation: `dest[i] = conj(src[i])`.
877pub fn copy_conj<T: Copy + ElementOpApply + MaybeSendSync>(
878    dest: &mut StridedViewMut<T>,
879    src: &StridedView<T>,
880) -> Result<()> {
881    let src_conj = src.conj();
882    copy_into(dest, &src_conj)
883}
884
885/// Copy with transpose and scaling: `dest[j,i] = scale * src[i,j]`.
886pub fn copy_transpose_scale_into<T>(
887    dest: &mut StridedViewMut<T>,
888    src: &StridedView<T>,
889    scale: T,
890) -> Result<()>
891where
892    T: Copy + ElementOpApply + Mul<Output = T> + MaybeSendSync,
893{
894    if src.ndim() != 2 || dest.ndim() != 2 {
895        return Err(StridedError::RankMismatch(src.ndim(), 2));
896    }
897    let src_t = src.transpose_2d()?;
898    map_into(dest, &src_t, |x| scale * x)
899}