Skip to main content

strided_kernel/
ops_view.rs

1//! High-level operations on dynamic-rank strided views.
2
3use crate::kernel::{
4    build_plan_fused, ensure_same_shape, for_each_inner_block_preordered, same_contiguous_layout,
5    sequential_contiguous_layout, total_len,
6};
7use crate::map_view::{map_into, zip_map2_into};
8use crate::maybe_sync::{MaybeSendSync, MaybeSync};
9use crate::reduce_view::reduce;
10use crate::simd;
11use crate::view::{StridedView, StridedViewMut};
12use crate::{Result, StridedError};
13use num_traits::{One, Zero};
14use std::ops::{Add, Mul};
15use strided_view::{ElementOp, ElementOpApply};
16
17#[cfg(feature = "parallel")]
18use crate::fuse::compute_costs;
19#[cfg(feature = "parallel")]
20use crate::threading::{
21    for_each_inner_block_with_offsets, mapreduce_threaded, SendPtr, MINTHREADLENGTH,
22};
23#[cfg(feature = "parallel")]
24use rayon::iter::{IntoParallelIterator, ParallelIterator};
25
26// ============================================================================
27// Stride-specialized inner loop helpers for ops_view
28//
29// When all inner strides are 1 (contiguous in the innermost dimension),
30// we use slice-based iteration so LLVM can auto-vectorize effectively.
31// This mirrors the inner_loop_map* helpers in map_view.rs.
32// ============================================================================
33
34/// Inner loop for add: `dst[i] += Op::apply(src[i])`.
35#[inline(always)]
36unsafe fn inner_loop_add<D: Copy + Add<S, Output = D>, S: Copy, Op: ElementOp<S>>(
37    dp: *mut D,
38    ds: isize,
39    sp: *const S,
40    ss: isize,
41    len: usize,
42) {
43    if ds == 1 && ss == 1 {
44        let dst = std::slice::from_raw_parts_mut(dp, len);
45        let src = std::slice::from_raw_parts(sp, len);
46        simd::dispatch_if_large(len, || {
47            for i in 0..len {
48                dst[i] = dst[i] + Op::apply(src[i]);
49            }
50        });
51    } else {
52        let mut dp = dp;
53        let mut sp = sp;
54        for _ in 0..len {
55            *dp = *dp + Op::apply(*sp);
56            dp = dp.offset(ds);
57            sp = sp.offset(ss);
58        }
59    }
60}
61
62/// Inner loop for mul: `dst[i] *= Op::apply(src[i])`.
63#[inline(always)]
64unsafe fn inner_loop_mul<D: Copy + Mul<S, Output = D>, S: Copy, Op: ElementOp<S>>(
65    dp: *mut D,
66    ds: isize,
67    sp: *const S,
68    ss: isize,
69    len: usize,
70) {
71    if ds == 1 && ss == 1 {
72        let dst = std::slice::from_raw_parts_mut(dp, len);
73        let src = std::slice::from_raw_parts(sp, len);
74        simd::dispatch_if_large(len, || {
75            for i in 0..len {
76                dst[i] = dst[i] * Op::apply(src[i]);
77            }
78        });
79    } else {
80        let mut dp = dp;
81        let mut sp = sp;
82        for _ in 0..len {
83            *dp = *dp * Op::apply(*sp);
84            dp = dp.offset(ds);
85            sp = sp.offset(ss);
86        }
87    }
88}
89
90/// Inner loop for axpy: `dst[i] = alpha * Op::apply(src[i]) + dst[i]`.
91#[inline(always)]
92unsafe fn inner_loop_axpy<
93    D: Copy + Add<D, Output = D>,
94    S: Copy,
95    A: Copy + Mul<S, Output = D>,
96    Op: ElementOp<S>,
97>(
98    dp: *mut D,
99    ds: isize,
100    sp: *const S,
101    ss: isize,
102    len: usize,
103    alpha: A,
104) {
105    if ds == 1 && ss == 1 {
106        let dst = std::slice::from_raw_parts_mut(dp, len);
107        let src = std::slice::from_raw_parts(sp, len);
108        simd::dispatch_if_large(len, || {
109            for i in 0..len {
110                dst[i] = alpha * Op::apply(src[i]) + dst[i];
111            }
112        });
113    } else {
114        let mut dp = dp;
115        let mut sp = sp;
116        for _ in 0..len {
117            *dp = alpha * Op::apply(*sp) + *dp;
118            dp = dp.offset(ds);
119            sp = sp.offset(ss);
120        }
121    }
122}
123
124/// Inner loop for fma: `dst[i] += OpA::apply(a[i]) * OpB::apply(b[i])`.
125#[inline(always)]
126unsafe fn inner_loop_fma<
127    D: Copy + Add<D, Output = D>,
128    A: Copy + Mul<B, Output = D>,
129    B: Copy,
130    OpA: ElementOp<A>,
131    OpB: ElementOp<B>,
132>(
133    dp: *mut D,
134    ds: isize,
135    ap: *const A,
136    a_s: isize,
137    bp: *const B,
138    b_s: isize,
139    len: usize,
140) {
141    if ds == 1 && a_s == 1 && b_s == 1 {
142        let dst = std::slice::from_raw_parts_mut(dp, len);
143        let sa = std::slice::from_raw_parts(ap, len);
144        let sb = std::slice::from_raw_parts(bp, len);
145        simd::dispatch_if_large(len, || {
146            for i in 0..len {
147                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
148            }
149        });
150    } else {
151        let mut dp = dp;
152        let mut ap = ap;
153        let mut bp = bp;
154        for _ in 0..len {
155            *dp = *dp + OpA::apply(*ap) * OpB::apply(*bp);
156            dp = dp.offset(ds);
157            ap = ap.offset(a_s);
158            bp = bp.offset(b_s);
159        }
160    }
161}
162
163/// Inner loop for dot: `acc += OpA::apply(a[i]) * OpB::apply(b[i])`.
164#[inline(always)]
165unsafe fn inner_loop_dot<
166    A: Copy + Mul<B, Output = R>,
167    B: Copy,
168    R: Copy + Add<R, Output = R>,
169    OpA: ElementOp<A>,
170    OpB: ElementOp<B>,
171>(
172    ap: *const A,
173    a_s: isize,
174    bp: *const B,
175    b_s: isize,
176    len: usize,
177    mut acc: R,
178) -> R {
179    if a_s == 1 && b_s == 1 {
180        let sa = std::slice::from_raw_parts(ap, len);
181        let sb = std::slice::from_raw_parts(bp, len);
182        simd::dispatch_if_large(len, || {
183            for i in 0..len {
184                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
185            }
186        });
187    } else {
188        let mut ap = ap;
189        let mut bp = bp;
190        for _ in 0..len {
191            acc = acc + OpA::apply(*ap) * OpB::apply(*bp);
192            ap = ap.offset(a_s);
193            bp = bp.offset(b_s);
194        }
195    }
196    acc
197}
198
199/// Copy elements from source to destination: `dest[i] = src[i]`.
200pub fn copy_into<T: Copy + MaybeSendSync, Op: ElementOp<T>>(
201    dest: &mut StridedViewMut<T>,
202    src: &StridedView<T, Op>,
203) -> Result<()> {
204    ensure_same_shape(dest.dims(), src.dims())?;
205
206    let dst_ptr = dest.as_mut_ptr();
207    let src_ptr = src.ptr();
208    let dst_dims = dest.dims();
209    let dst_strides = dest.strides();
210    let src_strides = src.strides();
211
212    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
213        let len = total_len(dst_dims);
214        if Op::IS_IDENTITY {
215            debug_assert!(
216                {
217                    let nbytes = len
218                        .checked_mul(std::mem::size_of::<T>())
219                        .expect("copy size must not overflow");
220                    let dst_start = dst_ptr as usize;
221                    let src_start = src_ptr as usize;
222                    let dst_end = dst_start.saturating_add(nbytes);
223                    let src_end = src_start.saturating_add(nbytes);
224                    dst_end <= src_start || src_end <= dst_start
225                },
226                "overlapping src/dest is not supported"
227            );
228            unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
229        } else {
230            let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
231            let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
232            simd::dispatch_if_large(len, || {
233                for i in 0..len {
234                    dst[i] = Op::apply(src[i]);
235                }
236            });
237        }
238        return Ok(());
239    }
240
241    map_into(dest, src, |x| x)
242}
243
244/// Copy elements from `src` to `dst`, optimized for col-major destination.
245///
246/// Delegates to `strided_perm::copy_into_col_major` for the actual work.
247pub fn copy_into_col_major<T: Copy + MaybeSendSync>(
248    dst: &mut StridedViewMut<T>,
249    src: &StridedView<T>,
250) -> Result<()> {
251    strided_perm::copy_into_col_major(dst, src)
252}
253
254/// Element-wise addition: `dest[i] += src[i]`.
255///
256/// Source may have a different element type from destination.
257pub fn add<
258    D: Copy + Add<S, Output = D> + MaybeSendSync,
259    S: Copy + MaybeSendSync,
260    Op: ElementOp<S>,
261>(
262    dest: &mut StridedViewMut<D>,
263    src: &StridedView<S, Op>,
264) -> Result<()> {
265    ensure_same_shape(dest.dims(), src.dims())?;
266
267    let dst_ptr = dest.as_mut_ptr();
268    let src_ptr = src.ptr();
269    let dst_dims = dest.dims();
270    let dst_strides = dest.strides();
271    let src_strides = src.strides();
272
273    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
274        let len = total_len(dst_dims);
275        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
276        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
277        simd::dispatch_if_large(len, || {
278            for i in 0..len {
279                dst[i] = dst[i] + Op::apply(src[i]);
280            }
281        });
282        return Ok(());
283    }
284
285    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
286    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());
287
288    let (fused_dims, ordered_strides, plan) =
289        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
290
291    #[cfg(feature = "parallel")]
292    {
293        let total: usize = fused_dims.iter().product();
294        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
295            let dst_send = SendPtr(dst_ptr);
296            let src_send = SendPtr(src_ptr as *mut S);
297
298            let costs = compute_costs(&ordered_strides);
299            let initial_offsets = vec![0isize; strides_list.len()];
300            let nthreads = rayon::current_num_threads();
301
302            return mapreduce_threaded(
303                &fused_dims,
304                &plan.block,
305                &ordered_strides,
306                &initial_offsets,
307                &costs,
308                nthreads,
309                0,
310                1,
311                &|dims, blocks, strides_list, offsets| {
312                    for_each_inner_block_with_offsets(
313                        dims,
314                        blocks,
315                        strides_list,
316                        offsets,
317                        |offsets, len, strides| {
318                            unsafe {
319                                inner_loop_add::<D, S, Op>(
320                                    dst_send.as_ptr().offset(offsets[0]),
321                                    strides[0],
322                                    src_send.as_const().offset(offsets[1]),
323                                    strides[1],
324                                    len,
325                                )
326                            };
327                            Ok(())
328                        },
329                    )
330                },
331            );
332        }
333    }
334
335    let initial_offsets = vec![0isize; ordered_strides.len()];
336    for_each_inner_block_preordered(
337        &fused_dims,
338        &plan.block,
339        &ordered_strides,
340        &initial_offsets,
341        |offsets, len, strides| {
342            unsafe {
343                inner_loop_add::<D, S, Op>(
344                    dst_ptr.offset(offsets[0]),
345                    strides[0],
346                    src_ptr.offset(offsets[1]),
347                    strides[1],
348                    len,
349                )
350            };
351            Ok(())
352        },
353    )
354}
355
356/// Element-wise multiplication: `dest[i] *= src[i]`.
357///
358/// Source may have a different element type from destination.
359pub fn mul<
360    D: Copy + Mul<S, Output = D> + MaybeSendSync,
361    S: Copy + MaybeSendSync,
362    Op: ElementOp<S>,
363>(
364    dest: &mut StridedViewMut<D>,
365    src: &StridedView<S, Op>,
366) -> Result<()> {
367    ensure_same_shape(dest.dims(), src.dims())?;
368
369    let dst_ptr = dest.as_mut_ptr();
370    let src_ptr = src.ptr();
371    let dst_dims = dest.dims();
372    let dst_strides = dest.strides();
373    let src_strides = src.strides();
374
375    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
376        let len = total_len(dst_dims);
377        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
378        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
379        simd::dispatch_if_large(len, || {
380            for i in 0..len {
381                dst[i] = dst[i] * Op::apply(src[i]);
382            }
383        });
384        return Ok(());
385    }
386
387    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
388    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());
389
390    let (fused_dims, ordered_strides, plan) =
391        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
392
393    #[cfg(feature = "parallel")]
394    {
395        let total: usize = fused_dims.iter().product();
396        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
397            let dst_send = SendPtr(dst_ptr);
398            let src_send = SendPtr(src_ptr as *mut S);
399
400            let costs = compute_costs(&ordered_strides);
401            let initial_offsets = vec![0isize; strides_list.len()];
402            let nthreads = rayon::current_num_threads();
403
404            return mapreduce_threaded(
405                &fused_dims,
406                &plan.block,
407                &ordered_strides,
408                &initial_offsets,
409                &costs,
410                nthreads,
411                0,
412                1,
413                &|dims, blocks, strides_list, offsets| {
414                    for_each_inner_block_with_offsets(
415                        dims,
416                        blocks,
417                        strides_list,
418                        offsets,
419                        |offsets, len, strides| {
420                            unsafe {
421                                inner_loop_mul::<D, S, Op>(
422                                    dst_send.as_ptr().offset(offsets[0]),
423                                    strides[0],
424                                    src_send.as_const().offset(offsets[1]),
425                                    strides[1],
426                                    len,
427                                )
428                            };
429                            Ok(())
430                        },
431                    )
432                },
433            );
434        }
435    }
436
437    let initial_offsets = vec![0isize; ordered_strides.len()];
438    for_each_inner_block_preordered(
439        &fused_dims,
440        &plan.block,
441        &ordered_strides,
442        &initial_offsets,
443        |offsets, len, strides| {
444            unsafe {
445                inner_loop_mul::<D, S, Op>(
446                    dst_ptr.offset(offsets[0]),
447                    strides[0],
448                    src_ptr.offset(offsets[1]),
449                    strides[1],
450                    len,
451                )
452            };
453            Ok(())
454        },
455    )
456}
457
458/// AXPY: `dest[i] = alpha * src[i] + dest[i]`.
459///
460/// Alpha, source, and destination may have different element types.
461pub fn axpy<D, S, A, Op>(
462    dest: &mut StridedViewMut<D>,
463    src: &StridedView<S, Op>,
464    alpha: A,
465) -> Result<()>
466where
467    A: Copy + Mul<S, Output = D> + MaybeSync,
468    D: Copy + Add<D, Output = D> + MaybeSendSync,
469    S: Copy + MaybeSendSync,
470    Op: ElementOp<S>,
471{
472    ensure_same_shape(dest.dims(), src.dims())?;
473
474    let dst_ptr = dest.as_mut_ptr();
475    let src_ptr = src.ptr();
476    let dst_dims = dest.dims();
477    let dst_strides = dest.strides();
478    let src_strides = src.strides();
479
480    if sequential_contiguous_layout(dst_dims, &[dst_strides, src_strides]).is_some() {
481        let len = total_len(dst_dims);
482        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
483        let src = unsafe { std::slice::from_raw_parts(src_ptr, len) };
484        simd::dispatch_if_large(len, || {
485            for i in 0..len {
486                dst[i] = alpha * Op::apply(src[i]) + dst[i];
487            }
488        });
489        return Ok(());
490    }
491
492    let strides_list: [&[isize]; 2] = [dst_strides, src_strides];
493    let elem_size = std::mem::size_of::<D>().max(std::mem::size_of::<S>());
494
495    let (fused_dims, ordered_strides, plan) =
496        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
497
498    #[cfg(feature = "parallel")]
499    {
500        let total: usize = fused_dims.iter().product();
501        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
502            let dst_send = SendPtr(dst_ptr);
503            let src_send = SendPtr(src_ptr as *mut S);
504
505            let costs = compute_costs(&ordered_strides);
506            let initial_offsets = vec![0isize; strides_list.len()];
507            let nthreads = rayon::current_num_threads();
508
509            return mapreduce_threaded(
510                &fused_dims,
511                &plan.block,
512                &ordered_strides,
513                &initial_offsets,
514                &costs,
515                nthreads,
516                0,
517                1,
518                &|dims, blocks, strides_list, offsets| {
519                    for_each_inner_block_with_offsets(
520                        dims,
521                        blocks,
522                        strides_list,
523                        offsets,
524                        |offsets, len, strides| {
525                            unsafe {
526                                inner_loop_axpy::<D, S, A, Op>(
527                                    dst_send.as_ptr().offset(offsets[0]),
528                                    strides[0],
529                                    src_send.as_const().offset(offsets[1]),
530                                    strides[1],
531                                    len,
532                                    alpha,
533                                )
534                            };
535                            Ok(())
536                        },
537                    )
538                },
539            );
540        }
541    }
542
543    let initial_offsets = vec![0isize; ordered_strides.len()];
544    for_each_inner_block_preordered(
545        &fused_dims,
546        &plan.block,
547        &ordered_strides,
548        &initial_offsets,
549        |offsets, len, strides| {
550            unsafe {
551                inner_loop_axpy::<D, S, A, Op>(
552                    dst_ptr.offset(offsets[0]),
553                    strides[0],
554                    src_ptr.offset(offsets[1]),
555                    strides[1],
556                    len,
557                    alpha,
558                )
559            };
560            Ok(())
561        },
562    )
563}
564
565/// Fused multiply-add: `dest[i] += OpA::apply(a[i]) * OpB::apply(b[i])`.
566///
567/// Operands may have different element types. Element operations are applied lazily.
568pub fn fma<D, A, B, OpA, OpB>(
569    dest: &mut StridedViewMut<D>,
570    a: &StridedView<A, OpA>,
571    b: &StridedView<B, OpB>,
572) -> Result<()>
573where
574    A: Copy + Mul<B, Output = D> + MaybeSendSync,
575    B: Copy + MaybeSendSync,
576    D: Copy + Add<D, Output = D> + MaybeSendSync,
577    OpA: ElementOp<A>,
578    OpB: ElementOp<B>,
579{
580    ensure_same_shape(dest.dims(), a.dims())?;
581    ensure_same_shape(dest.dims(), b.dims())?;
582
583    let dst_ptr = dest.as_mut_ptr();
584    let a_ptr = a.ptr();
585    let b_ptr = b.ptr();
586    let dst_dims = dest.dims();
587    let dst_strides = dest.strides();
588    let a_strides = a.strides();
589    let b_strides = b.strides();
590
591    if sequential_contiguous_layout(dst_dims, &[dst_strides, a_strides, b_strides]).is_some() {
592        let len = total_len(dst_dims);
593        let dst = unsafe { std::slice::from_raw_parts_mut(dst_ptr, len) };
594        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
595        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
596        simd::dispatch_if_large(len, || {
597            for i in 0..len {
598                dst[i] = dst[i] + OpA::apply(sa[i]) * OpB::apply(sb[i]);
599            }
600        });
601        return Ok(());
602    }
603
604    let strides_list: [&[isize]; 3] = [dst_strides, a_strides, b_strides];
605    let elem_size = std::mem::size_of::<D>()
606        .max(std::mem::size_of::<A>())
607        .max(std::mem::size_of::<B>());
608
609    let (fused_dims, ordered_strides, plan) =
610        build_plan_fused(dst_dims, &strides_list, Some(0), elem_size);
611
612    #[cfg(feature = "parallel")]
613    {
614        let total: usize = fused_dims.iter().product();
615        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
616            let dst_send = SendPtr(dst_ptr);
617            let a_send = SendPtr(a_ptr as *mut A);
618            let b_send = SendPtr(b_ptr as *mut B);
619
620            let costs = compute_costs(&ordered_strides);
621            let initial_offsets = vec![0isize; strides_list.len()];
622            let nthreads = rayon::current_num_threads();
623
624            return mapreduce_threaded(
625                &fused_dims,
626                &plan.block,
627                &ordered_strides,
628                &initial_offsets,
629                &costs,
630                nthreads,
631                0,
632                1,
633                &|dims, blocks, strides_list, offsets| {
634                    for_each_inner_block_with_offsets(
635                        dims,
636                        blocks,
637                        strides_list,
638                        offsets,
639                        |offsets, len, strides| {
640                            unsafe {
641                                inner_loop_fma::<D, A, B, OpA, OpB>(
642                                    dst_send.as_ptr().offset(offsets[0]),
643                                    strides[0],
644                                    a_send.as_const().offset(offsets[1]),
645                                    strides[1],
646                                    b_send.as_const().offset(offsets[2]),
647                                    strides[2],
648                                    len,
649                                )
650                            };
651                            Ok(())
652                        },
653                    )
654                },
655            );
656        }
657    }
658
659    let initial_offsets = vec![0isize; ordered_strides.len()];
660    for_each_inner_block_preordered(
661        &fused_dims,
662        &plan.block,
663        &ordered_strides,
664        &initial_offsets,
665        |offsets, len, strides| {
666            unsafe {
667                inner_loop_fma::<D, A, B, OpA, OpB>(
668                    dst_ptr.offset(offsets[0]),
669                    strides[0],
670                    a_ptr.offset(offsets[1]),
671                    strides[1],
672                    b_ptr.offset(offsets[2]),
673                    strides[2],
674                    len,
675                )
676            };
677            Ok(())
678        },
679    )
680}
681
/// Sum a contiguous slice by splitting it into one chunk per rayon thread,
/// reducing each chunk with `T::try_simd_sum`, and adding the partial sums.
///
/// Returns `None` when `T::try_simd_sum` reports no result for an empty slice,
/// which this function treats as "no SIMD kernel for `T`". An empty `src`
/// yields `Some(T::zero())`.
#[cfg(feature = "parallel")]
fn parallel_simd_sum<T: Copy + Zero + Add<Output = T> + simd::MaybeSimdOps + Send + Sync>(
    src: &[T],
) -> Option<T> {
    use rayon::prelude::*;
    // Probe with an empty slice to find out whether T has SIMD support.
    T::try_simd_sum(&[])?;

    let nthreads = rayon::current_num_threads();
    // Ceiling division gives 0 for an empty slice, and `par_chunks` panics on a
    // zero chunk size, so clamp to at least one element per chunk.
    let chunk_size = ((src.len() + nthreads - 1) / nthreads).max(1);
    let result = src
        .par_chunks(chunk_size)
        .map(|chunk| {
            T::try_simd_sum(chunk).expect("SIMD sum was available for the empty-slice probe")
        })
        .reduce(|| T::zero(), |a, b| a + b);
    Some(result)
}
699
700/// Sum all elements: `sum(src)`.
701pub fn sum<
702    T: Copy + Zero + Add<Output = T> + MaybeSendSync + simd::MaybeSimdOps,
703    Op: ElementOp<T>,
704>(
705    src: &StridedView<T, Op>,
706) -> Result<T> {
707    // SIMD fast path: contiguous Identity view with SIMD support
708    if Op::IS_IDENTITY {
709        if same_contiguous_layout(src.dims(), &[src.strides()]).is_some() {
710            let len = total_len(src.dims());
711            let src_slice = unsafe { std::slice::from_raw_parts(src.ptr(), len) };
712
713            #[cfg(feature = "parallel")]
714            if len > MINTHREADLENGTH {
715                if let Some(result) = parallel_simd_sum(src_slice) {
716                    return Ok(result);
717                }
718            }
719
720            if let Some(result) = T::try_simd_sum(src_slice) {
721                return Ok(result);
722            }
723        }
724    }
725    reduce(src, |x| x, |a, b| a + b, T::zero())
726}
727
728/// Dot product: `sum(OpA::apply(a[i]) * OpB::apply(b[i]))`.
729///
730/// Operands may have different element types. Result type `R` must be `A * B`.
731/// SIMD fast path fires only when `A == B == R` (same type) and both Identity ops.
732pub fn dot<A, B, R, OpA, OpB>(a: &StridedView<A, OpA>, b: &StridedView<B, OpB>) -> Result<R>
733where
734    A: Copy + Mul<B, Output = R> + MaybeSendSync + 'static,
735    B: Copy + MaybeSendSync + 'static,
736    R: Copy + Zero + Add<Output = R> + MaybeSendSync + simd::MaybeSimdOps + 'static,
737    OpA: ElementOp<A>,
738    OpB: ElementOp<B>,
739{
740    ensure_same_shape(a.dims(), b.dims())?;
741
742    let a_ptr = a.ptr();
743    let b_ptr = b.ptr();
744    let a_strides = a.strides();
745    let b_strides = b.strides();
746    let a_dims = a.dims();
747
748    if same_contiguous_layout(a_dims, &[a_strides, b_strides]).is_some() {
749        let len = total_len(a_dims);
750
751        // SIMD fast path: both contiguous, both Identity ops, same type
752        if OpA::IS_IDENTITY
753            && OpB::IS_IDENTITY
754            && std::any::TypeId::of::<A>() == std::any::TypeId::of::<R>()
755            && std::any::TypeId::of::<B>() == std::any::TypeId::of::<R>()
756        {
757            let sa = unsafe { std::slice::from_raw_parts(a_ptr as *const R, len) };
758            let sb = unsafe { std::slice::from_raw_parts(b_ptr as *const R, len) };
759            if let Some(result) = R::try_simd_dot(sa, sb) {
760                return Ok(result);
761            }
762        }
763
764        // Generic contiguous fast path
765        let sa = unsafe { std::slice::from_raw_parts(a_ptr, len) };
766        let sb = unsafe { std::slice::from_raw_parts(b_ptr, len) };
767        let mut acc = R::zero();
768        simd::dispatch_if_large(len, || {
769            for i in 0..len {
770                acc = acc + OpA::apply(sa[i]) * OpB::apply(sb[i]);
771            }
772        });
773        return Ok(acc);
774    }
775
776    let strides_list: [&[isize]; 2] = [a_strides, b_strides];
777    let elem_size = std::mem::size_of::<A>()
778        .max(std::mem::size_of::<B>())
779        .max(std::mem::size_of::<R>());
780
781    let (fused_dims, ordered_strides, plan) =
782        build_plan_fused(a_dims, &strides_list, None, elem_size);
783
784    let mut acc = R::zero();
785    let initial_offsets = vec![0isize; ordered_strides.len()];
786    for_each_inner_block_preordered(
787        &fused_dims,
788        &plan.block,
789        &ordered_strides,
790        &initial_offsets,
791        |offsets, len, strides| {
792            acc = unsafe {
793                inner_loop_dot::<A, B, R, OpA, OpB>(
794                    a_ptr.offset(offsets[0]),
795                    strides[0],
796                    b_ptr.offset(offsets[1]),
797                    strides[1],
798                    len,
799                    acc,
800                )
801            };
802            Ok(())
803        },
804    )?;
805
806    Ok(acc)
807}
808
809/// Symmetrize a square matrix: `dest = (src + src^T) / 2`.
810pub fn symmetrize_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
811where
812    T: Copy
813        + Add<Output = T>
814        + Mul<Output = T>
815        + num_traits::FromPrimitive
816        + std::ops::Div<Output = T>
817        + MaybeSendSync,
818{
819    if src.ndim() != 2 {
820        return Err(StridedError::RankMismatch(src.ndim(), 2));
821    }
822    let rows = src.dims()[0];
823    let cols = src.dims()[1];
824    if rows != cols {
825        return Err(StridedError::NonSquare { rows, cols });
826    }
827
828    let src_t = src.permute(&[1, 0])?;
829    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;
830
831    zip_map2_into(dest, src, &src_t, |a, b| (a + b) * half)
832}
833
834/// Conjugate-symmetrize a square matrix: `dest = (src + conj(src^T)) / 2`.
835pub fn symmetrize_conj_into<T>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()>
836where
837    T: Copy
838        + ElementOpApply
839        + Add<Output = T>
840        + Mul<Output = T>
841        + num_traits::FromPrimitive
842        + std::ops::Div<Output = T>
843        + MaybeSendSync,
844{
845    if src.ndim() != 2 {
846        return Err(StridedError::RankMismatch(src.ndim(), 2));
847    }
848    let rows = src.dims()[0];
849    let cols = src.dims()[1];
850    if rows != cols {
851        return Err(StridedError::NonSquare { rows, cols });
852    }
853
854    // adjoint = conj + transpose
855    let src_adj = src.adjoint_2d()?;
856    let half = T::from_f64(0.5).ok_or(StridedError::ScalarConversion)?;
857
858    zip_map2_into(dest, src, &src_adj, |a, b| (a + b) * half)
859}
860
861/// Copy with scaling: `dest[i] = scale * src[i]`.
862///
863/// Scale, source, and destination may have different element types.
864pub fn copy_scale<D, S, A, Op>(
865    dest: &mut StridedViewMut<D>,
866    src: &StridedView<S, Op>,
867    scale: A,
868) -> Result<()>
869where
870    A: Copy + Mul<S, Output = D> + MaybeSync,
871    D: Copy + MaybeSendSync,
872    S: Copy + MaybeSendSync,
873    Op: ElementOp<S>,
874{
875    map_into(dest, src, |x| scale * x)
876}
877
878/// Copy with complex conjugation: `dest[i] = conj(src[i])`.
879pub fn copy_conj<T: Copy + ElementOpApply + MaybeSendSync>(
880    dest: &mut StridedViewMut<T>,
881    src: &StridedView<T>,
882) -> Result<()> {
883    let src_conj = src.conj();
884    copy_into(dest, &src_conj)
885}
886
887#[inline]
888fn element_transpose_is_identity<T: 'static>() -> bool {
889    use std::any::TypeId;
890
891    macro_rules! matches_type {
892        ($($ty:ty),* $(,)?) => {{
893            let id = TypeId::of::<T>();
894            false $(|| id == TypeId::of::<$ty>())*
895        }};
896    }
897
898    matches_type!(
899        f32,
900        f64,
901        i8,
902        i16,
903        i32,
904        i64,
905        i128,
906        isize,
907        u8,
908        u16,
909        u32,
910        u64,
911        u128,
912        usize,
913        num_complex::Complex32,
914        num_complex::Complex64,
915    )
916}
917
918#[inline]
919fn element_zero_is_all_bits_zero<T: 'static>() -> bool {
920    use std::any::TypeId;
921
922    macro_rules! matches_type {
923        ($($ty:ty),* $(,)?) => {{
924            let id = TypeId::of::<T>();
925            false $(|| id == TypeId::of::<$ty>())*
926        }};
927    }
928
929    matches_type!(
930        f32,
931        f64,
932        i8,
933        i16,
934        i32,
935        i64,
936        i128,
937        isize,
938        u8,
939        u16,
940        u32,
941        u64,
942        u128,
943        usize,
944        num_complex::Complex32,
945        num_complex::Complex64,
946    )
947}
948
949#[inline]
950unsafe fn fill_2d<T: Copy + MaybeSendSync>(
951    dst: *mut T,
952    dim0: usize,
953    dim1: usize,
954    dst_stride0: isize,
955    dst_stride1: isize,
956    value: T,
957) {
958    #[cfg(feature = "parallel")]
959    {
960        let total = dim0.saturating_mul(dim1);
961        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
962            let dst_send = SendPtr(dst);
963            if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
964                (0..dim1).into_par_iter().for_each(|j| {
965                    let dst = dst_send.as_ptr();
966                    unsafe {
967                        let base = j as isize * dst_stride1;
968                        for i in 0..dim0 {
969                            *dst.offset(base + i as isize * dst_stride0) = value;
970                        }
971                    }
972                });
973            } else {
974                (0..dim0).into_par_iter().for_each(|i| {
975                    let dst = dst_send.as_ptr();
976                    unsafe {
977                        let base = i as isize * dst_stride0;
978                        for j in 0..dim1 {
979                            *dst.offset(base + j as isize * dst_stride1) = value;
980                        }
981                    }
982                });
983            }
984            return;
985        }
986    }
987
988    if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
989        for j in 0..dim1 {
990            let base = j as isize * dst_stride1;
991            for i in 0..dim0 {
992                *dst.offset(base + i as isize * dst_stride0) = value;
993            }
994        }
995    } else {
996        for i in 0..dim0 {
997            let base = i as isize * dst_stride0;
998            for j in 0..dim1 {
999                *dst.offset(base + j as isize * dst_stride1) = value;
1000            }
1001        }
1002    }
1003}
1004
1005#[inline]
1006unsafe fn fill_contiguous<T>(dst: *mut T, len: usize, value: T)
1007where
1008    T: Copy + Zero + PartialEq + MaybeSendSync + 'static,
1009{
1010    if element_zero_is_all_bits_zero::<T>() && value == T::zero() {
1011        std::ptr::write_bytes(dst, 0, len);
1012        return;
1013    }
1014
1015    let dst = std::slice::from_raw_parts_mut(dst, len);
1016    dst.fill(value);
1017}
1018
/// Transposes and scales one 4x4 tile of `f64` values: for `r, c` in `0..4`,
/// destination element `(j + c, i + r)` receives `scale * src(i + r, j + c)`.
///
/// All sixteen source values are loaded into a local tile before the first
/// store is issued.
///
/// # Safety
///
/// For every `r, c` in `0..4`, `src.offset((i + r) * src_stride0 + (j + c) * src_stride1)`
/// must be valid for reads and
/// `dst.offset((j + c) * dst_stride0 + (i + r) * dst_stride1)` valid for writes.
#[inline(always)]
unsafe fn transpose_scale_4x4_f64(
    dst: *mut f64,
    dst_stride0: isize,
    dst_stride1: isize,
    src: *const f64,
    src_stride0: isize,
    src_stride1: isize,
    i: usize,
    j: usize,
    scale: f64,
) {
    const N: usize = 4;

    // Gather: walk the source tile one column at a time; tile[r][c] = src(i + r, j + c).
    let src_tile = src.offset(i as isize * src_stride0 + j as isize * src_stride1);
    let mut tile = [[0.0f64; N]; N];
    for c in 0..N {
        let column = src_tile.offset(c as isize * src_stride1);
        for r in 0..N {
            tile[r][c] = *column.offset(r as isize * src_stride0);
        }
    }

    // Scatter: source row `r` becomes the destination line at `r * dst_stride1`.
    let dst_tile = dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1);
    for r in 0..N {
        let line = dst_tile.offset(r as isize * dst_stride1);
        for c in 0..N {
            *line.offset(c as isize * dst_stride0) = scale * tile[r][c];
        }
    }
}
1080
1081#[inline]
1082unsafe fn copy_transpose_scale_2d_f64_tiled_raw(
1083    dst: *mut f64,
1084    dst_stride0: isize,
1085    dst_stride1: isize,
1086    src: *const f64,
1087    src_stride0: isize,
1088    src_stride1: isize,
1089    src_rows: usize,
1090    src_cols: usize,
1091    scale: f64,
1092) {
1093    const TILE: usize = 4;
1094    let row_full = src_rows / TILE * TILE;
1095    let col_full = src_cols / TILE * TILE;
1096
1097    #[cfg(feature = "parallel")]
1098    {
1099        let total = src_rows.saturating_mul(src_cols);
1100        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1101            let dst_send = SendPtr(dst);
1102            let src_send = SendPtr(src as *mut f64);
1103            let row_tiles = row_full / TILE;
1104            (0..row_tiles).into_par_iter().for_each(|tile_i| {
1105                let i = tile_i * TILE;
1106                let dst = dst_send.as_ptr();
1107                let src = src_send.as_const();
1108                unsafe {
1109                    let mut j = 0;
1110                    while j < col_full {
1111                        transpose_scale_4x4_f64(
1112                            dst,
1113                            dst_stride0,
1114                            dst_stride1,
1115                            src,
1116                            src_stride0,
1117                            src_stride1,
1118                            i,
1119                            j,
1120                            scale,
1121                        );
1122                        j += TILE;
1123                    }
1124                    for j in col_full..src_cols {
1125                        for ii in i..i + TILE {
1126                            *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1127                                scale
1128                                    * *src.offset(
1129                                        ii as isize * src_stride0 + j as isize * src_stride1,
1130                                    );
1131                        }
1132                    }
1133                }
1134            });
1135            for i in row_full..src_rows {
1136                for j in 0..src_cols {
1137                    *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1138                        scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1139                }
1140            }
1141            return;
1142        }
1143    }
1144
1145    let mut i = 0;
1146    while i < row_full {
1147        let mut j = 0;
1148        while j < col_full {
1149            transpose_scale_4x4_f64(
1150                dst,
1151                dst_stride0,
1152                dst_stride1,
1153                src,
1154                src_stride0,
1155                src_stride1,
1156                i,
1157                j,
1158                scale,
1159            );
1160            j += TILE;
1161        }
1162        for j in col_full..src_cols {
1163            for ii in i..i + TILE {
1164                *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1165                    scale * *src.offset(ii as isize * src_stride0 + j as isize * src_stride1);
1166            }
1167        }
1168        i += TILE;
1169    }
1170    for i in row_full..src_rows {
1171        for j in 0..src_cols {
1172            *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1173                scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1174        }
1175    }
1176}
1177
/// Test-only entry point for the tiled `f64` transpose-scale kernel.
///
/// Returns `false` without touching `dest` unless both views are rank 2,
/// `dest` has the dimensions of `src` swapped, and both views have stride 1
/// along dimension 0. Otherwise it runs
/// `copy_transpose_scale_2d_f64_tiled_raw` and returns `true`.
///
/// # Safety
///
/// The pointers, dimensions and strides reported by the two views are passed
/// straight to the raw kernel, so the views must describe memory that is
/// valid for the corresponding reads (`src`) and writes (`dest`).
#[inline]
#[cfg(test)]
unsafe fn try_copy_transpose_scale_2d_f64_tiled(
    dest: &mut StridedViewMut<f64>,
    src: &StridedView<f64>,
    scale: f64,
) -> bool {
    if src.ndim() != 2 || dest.ndim() != 2 {
        return false;
    }

    let (src_rows, src_cols) = (src.dims()[0], src.dims()[1]);
    let transposed_shape = [src_cols, src_rows];
    if dest.dims() != transposed_shape {
        return false;
    }

    let (src_stride0, src_stride1) = (src.strides()[0], src.strides()[1]);
    let (dst_stride0, dst_stride1) = (dest.strides()[0], dest.strides()[1]);
    if src_stride0 != 1 || dst_stride0 != 1 {
        return false;
    }

    copy_transpose_scale_2d_f64_tiled_raw(
        dest.as_mut_ptr(),
        dst_stride0,
        dst_stride1,
        src.ptr(),
        src_stride0,
        src_stride1,
        src_rows,
        src_cols,
        scale,
    );
    true
}
1209
1210#[inline]
1211unsafe fn try_copy_transpose_scale_2d_f64_tiled_typed<T>(
1212    dest: &mut StridedViewMut<T>,
1213    src: &StridedView<T>,
1214    scale: T,
1215) -> bool
1216where
1217    T: Copy + 'static,
1218{
1219    if std::any::TypeId::of::<T>() != std::any::TypeId::of::<f64>() {
1220        return false;
1221    }
1222
1223    let scale = *(&scale as *const T).cast::<f64>();
1224    if src.ndim() != 2 || dest.ndim() != 2 {
1225        return false;
1226    }
1227    let src_dims = src.dims();
1228    if dest.dims() != [src_dims[1], src_dims[0]] {
1229        return false;
1230    }
1231    if src.strides()[0] != 1 || dest.strides()[0] != 1 {
1232        return false;
1233    }
1234
1235    copy_transpose_scale_2d_f64_tiled_raw(
1236        dest.as_mut_ptr().cast::<f64>(),
1237        dest.strides()[0],
1238        dest.strides()[1],
1239        src.ptr().cast::<f64>(),
1240        src.strides()[0],
1241        src.strides()[1],
1242        src_dims[0],
1243        src_dims[1],
1244        scale,
1245    );
1246    true
1247}
1248
/// Transposes and scales one 4x4 tile of `T` values without applying any
/// element-level operation: for `r, c` in `0..4`, destination element
/// `(j + c, i + r)` receives `scale * src(i + r, j + c)`.
///
/// All sixteen source values are loaded into a local tile before the first
/// store is issued.
///
/// # Safety
///
/// For every `r, c` in `0..4`, `src.offset((i + r) * src_stride0 + (j + c) * src_stride1)`
/// must be valid for reads and
/// `dst.offset((j + c) * dst_stride0 + (i + r) * dst_stride1)` valid for writes.
#[inline(always)]
unsafe fn transpose_scale_4x4_identity<T>(
    dst: *mut T,
    dst_stride0: isize,
    dst_stride1: isize,
    src: *const T,
    src_stride0: isize,
    src_stride1: isize,
    i: usize,
    j: usize,
    scale: T,
) where
    T: Copy + Mul<Output = T>,
{
    const N: usize = 4;

    let src_tile = src.offset(i as isize * src_stride0 + j as isize * src_stride1);

    // Seed the local tile with the first element (there is no generic default
    // value for `T`), then gather column by column; tile[r][c] = src(i + r, j + c).
    let mut tile = [[*src_tile; N]; N];
    for c in 0..N {
        let column = src_tile.offset(c as isize * src_stride1);
        for r in 0..N {
            tile[r][c] = *column.offset(r as isize * src_stride0);
        }
    }

    // Scatter: source row `r` becomes the destination line at `r * dst_stride1`.
    let dst_tile = dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1);
    for r in 0..N {
        let line = dst_tile.offset(r as isize * dst_stride1);
        for c in 0..N {
            *line.offset(c as isize * dst_stride0) = scale * tile[r][c];
        }
    }
}
1312
1313#[inline]
1314unsafe fn copy_transpose_scale_2d_identity_tiled_raw<T>(
1315    dst: *mut T,
1316    dst_stride0: isize,
1317    dst_stride1: isize,
1318    src: *const T,
1319    src_stride0: isize,
1320    src_stride1: isize,
1321    src_rows: usize,
1322    src_cols: usize,
1323    scale: T,
1324) where
1325    T: Copy + Mul<Output = T> + MaybeSendSync,
1326{
1327    const TILE: usize = 4;
1328    let row_full = src_rows / TILE * TILE;
1329    let col_full = src_cols / TILE * TILE;
1330
1331    #[cfg(feature = "parallel")]
1332    {
1333        let total = src_rows.saturating_mul(src_cols);
1334        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1335            let dst_send = SendPtr(dst);
1336            let src_send = SendPtr(src as *mut T);
1337            let row_tiles = row_full / TILE;
1338            (0..row_tiles).into_par_iter().for_each(|tile_i| {
1339                let i = tile_i * TILE;
1340                let dst = dst_send.as_ptr();
1341                let src = src_send.as_const();
1342                unsafe {
1343                    let mut j = 0;
1344                    while j < col_full {
1345                        transpose_scale_4x4_identity(
1346                            dst,
1347                            dst_stride0,
1348                            dst_stride1,
1349                            src,
1350                            src_stride0,
1351                            src_stride1,
1352                            i,
1353                            j,
1354                            scale,
1355                        );
1356                        j += TILE;
1357                    }
1358                    for j in col_full..src_cols {
1359                        for ii in i..i + TILE {
1360                            *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1361                                scale
1362                                    * *src.offset(
1363                                        ii as isize * src_stride0 + j as isize * src_stride1,
1364                                    );
1365                        }
1366                    }
1367                }
1368            });
1369            for i in row_full..src_rows {
1370                for j in 0..src_cols {
1371                    *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1372                        scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1373                }
1374            }
1375            return;
1376        }
1377    }
1378
1379    let mut i = 0;
1380    while i < row_full {
1381        let mut j = 0;
1382        while j < col_full {
1383            transpose_scale_4x4_identity(
1384                dst,
1385                dst_stride0,
1386                dst_stride1,
1387                src,
1388                src_stride0,
1389                src_stride1,
1390                i,
1391                j,
1392                scale,
1393            );
1394            j += TILE;
1395        }
1396        for j in col_full..src_cols {
1397            for ii in i..i + TILE {
1398                *dst.offset(j as isize * dst_stride0 + ii as isize * dst_stride1) =
1399                    scale * *src.offset(ii as isize * src_stride0 + j as isize * src_stride1);
1400            }
1401        }
1402        i += TILE;
1403    }
1404    for i in row_full..src_rows {
1405        for j in 0..src_cols {
1406            *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1407                scale * *src.offset(i as isize * src_stride0 + j as isize * src_stride1);
1408        }
1409    }
1410}
1411
1412#[inline]
1413unsafe fn try_copy_transpose_scale_2d_identity_tiled<T>(
1414    dest: &mut StridedViewMut<T>,
1415    src: &StridedView<T>,
1416    scale: T,
1417) -> bool
1418where
1419    T: Copy + Mul<Output = T> + MaybeSendSync,
1420{
1421    if src.ndim() != 2 || dest.ndim() != 2 {
1422        return false;
1423    }
1424    let src_dims = src.dims();
1425    if dest.dims() != [src_dims[1], src_dims[0]] {
1426        return false;
1427    }
1428    if src.strides()[0] != 1 || dest.strides()[0] != 1 {
1429        return false;
1430    }
1431
1432    copy_transpose_scale_2d_identity_tiled_raw(
1433        dest.as_mut_ptr(),
1434        dest.strides()[0],
1435        dest.strides()[1],
1436        src.ptr(),
1437        src.strides()[0],
1438        src.strides()[1],
1439        src_dims[0],
1440        src_dims[1],
1441        scale,
1442    );
1443    true
1444}
1445
1446#[inline]
1447unsafe fn copy_transpose_scale_2d_loop<T>(
1448    dst: *mut T,
1449    dst_stride0: isize,
1450    dst_stride1: isize,
1451    src: *const T,
1452    src_stride0: isize,
1453    src_stride1: isize,
1454    src_rows: usize,
1455    src_cols: usize,
1456    scale: T,
1457) where
1458    T: Copy + ElementOpApply + Mul<Output = T> + MaybeSendSync,
1459{
1460    #[cfg(feature = "parallel")]
1461    {
1462        let total = src_rows.saturating_mul(src_cols);
1463        if total > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1464            let dst_send = SendPtr(dst);
1465            let src_send = SendPtr(src as *mut T);
1466            if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
1467                (0..src_rows).into_par_iter().for_each(|i| {
1468                    let dst = dst_send.as_ptr();
1469                    let src = src_send.as_const();
1470                    unsafe {
1471                        for j in 0..src_cols {
1472                            let value = (*src
1473                                .offset(i as isize * src_stride0 + j as isize * src_stride1))
1474                            .transpose();
1475                            *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1476                                scale * value;
1477                        }
1478                    }
1479                });
1480            } else {
1481                (0..src_cols).into_par_iter().for_each(|j| {
1482                    let dst = dst_send.as_ptr();
1483                    let src = src_send.as_const();
1484                    unsafe {
1485                        for i in 0..src_rows {
1486                            let value = (*src
1487                                .offset(i as isize * src_stride0 + j as isize * src_stride1))
1488                            .transpose();
1489                            *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) =
1490                                scale * value;
1491                        }
1492                    }
1493                });
1494            }
1495            return;
1496        }
1497    }
1498
1499    if dst_stride0.unsigned_abs() <= dst_stride1.unsigned_abs() {
1500        for i in 0..src_rows {
1501            for j in 0..src_cols {
1502                let value =
1503                    (*src.offset(i as isize * src_stride0 + j as isize * src_stride1)).transpose();
1504                *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) = scale * value;
1505            }
1506        }
1507    } else {
1508        for j in 0..src_cols {
1509            for i in 0..src_rows {
1510                let value =
1511                    (*src.offset(i as isize * src_stride0 + j as isize * src_stride1)).transpose();
1512                *dst.offset(j as isize * dst_stride0 + i as isize * dst_stride1) = scale * value;
1513            }
1514        }
1515    }
1516}
1517
1518/// Copy with transpose and scaling: `dest[j,i] = scale * src[i,j]`.
1519pub fn copy_transpose_scale_into<T>(
1520    dest: &mut StridedViewMut<T>,
1521    src: &StridedView<T>,
1522    scale: T,
1523) -> Result<()>
1524where
1525    T: Copy + ElementOpApply + Mul<Output = T> + Zero + One + PartialEq + MaybeSendSync + 'static,
1526{
1527    if src.ndim() != 2 || dest.ndim() != 2 {
1528        return Err(StridedError::RankMismatch(src.ndim(), 2));
1529    }
1530
1531    let src_dims = src.dims();
1532    let expected_dims = [src_dims[1], src_dims[0]];
1533    ensure_same_shape(dest.dims(), &expected_dims)?;
1534
1535    if scale == T::zero() {
1536        unsafe {
1537            if same_contiguous_layout(dest.dims(), &[dest.strides()]).is_some() {
1538                fill_contiguous(dest.as_mut_ptr(), total_len(dest.dims()), T::zero());
1539            } else {
1540                fill_2d(
1541                    dest.as_mut_ptr(),
1542                    dest.dims()[0],
1543                    dest.dims()[1],
1544                    dest.strides()[0],
1545                    dest.strides()[1],
1546                    T::zero(),
1547                );
1548            }
1549        }
1550        return Ok(());
1551    }
1552
1553    let transpose_is_identity = element_transpose_is_identity::<T>();
1554
1555    unsafe {
1556        if transpose_is_identity && try_copy_transpose_scale_2d_f64_tiled_typed(dest, src, scale) {
1557            return Ok(());
1558        }
1559        if transpose_is_identity && try_copy_transpose_scale_2d_identity_tiled(dest, src, scale) {
1560            return Ok(());
1561        }
1562    }
1563
1564    if scale == T::one() && transpose_is_identity {
1565        let src_t = src.permute(&[1, 0])?;
1566        #[cfg(feature = "parallel")]
1567        {
1568            if total_len(dest.dims()) > MINTHREADLENGTH && rayon::current_num_threads() > 1 {
1569                return strided_perm::copy_into_par(dest, &src_t);
1570            }
1571        }
1572        return strided_perm::copy_into(dest, &src_t);
1573    }
1574
1575    unsafe {
1576        copy_transpose_scale_2d_loop(
1577            dest.as_mut_ptr(),
1578            dest.strides()[0],
1579            dest.strides()[1],
1580            src.ptr(),
1581            src.strides()[0],
1582            src.strides()[1],
1583            src_dims[0],
1584            src_dims[1],
1585            scale,
1586        );
1587    }
1588    Ok(())
1589}
1590
#[cfg(test)]
mod tiled_tests {
    use super::*;
    use crate::view::StridedArray;

    /// 7x9 is not a multiple of the 4x4 tile in either direction, so the
    /// tiled `f64` kernel has to handle both leftover rows and leftover columns.
    #[test]
    fn test_f64_tiled_transpose_scale_handles_remainders() {
        let (rows, cols) = (7, 9);
        let input = StridedArray::<f64>::from_fn_col_major(&[rows, cols], |idx| {
            (idx[0] * 100 + idx[1]) as f64
        });
        let mut output = StridedArray::<f64>::col_major(&[cols, rows]);

        let took_fast_path = {
            let src_view = input.view();
            let mut dst_view = output.view_mut();
            unsafe { try_copy_transpose_scale_2d_f64_tiled(&mut dst_view, &src_view, 3.0) }
        };
        assert!(took_fast_path);

        for (i, j) in (0..rows).flat_map(|i| (0..cols).map(move |j| (i, j))) {
            assert_eq!(output.get(&[j, i]), 3.0 * input.get(&[i, j]));
        }
    }

    /// Same remainder handling, exercised through the kernel that applies no
    /// element-level operation, with an integer element type.
    #[test]
    fn test_identity_tiled_transpose_scale_handles_integer_remainders() {
        let (rows, cols) = (6, 5);
        let input = StridedArray::<u64>::from_fn_col_major(&[rows, cols], |idx| {
            (idx[0] * 100 + idx[1]) as u64
        });
        let mut output = StridedArray::<u64>::col_major(&[cols, rows]);

        let took_fast_path = {
            let src_view = input.view();
            let mut dst_view = output.view_mut();
            unsafe { try_copy_transpose_scale_2d_identity_tiled(&mut dst_view, &src_view, 2) }
        };
        assert!(took_fast_path);

        for (i, j) in (0..rows).flat_map(|i| (0..cols).map(move |j| (i, j))) {
            assert_eq!(output.get(&[j, i]), 2 * input.get(&[i, j]));
        }
    }

    /// A zero scale must overwrite every destination element, including when
    /// the destination is a permuted view of its backing array.
    #[test]
    fn test_zero_scale_fills_non_contiguous_destination() {
        let (rows, cols) = (3, 4);
        let input = StridedArray::<u64>::from_fn_col_major(&[rows, cols], |idx| {
            (idx[0] * 10 + idx[1] + 1) as u64
        });
        // Pre-fill with a sentinel so untouched elements would be detected.
        let mut backing = StridedArray::<u64>::from_fn_col_major(&[rows, cols], |_| 99);
        let mut transposed_dest = backing.view_mut().permute(&[1, 0]).unwrap();

        copy_transpose_scale_into(&mut transposed_dest, &input.view(), 0).unwrap();

        for (i, j) in (0..cols).flat_map(|i| (0..rows).map(move |j| (i, j))) {
            assert_eq!(transposed_dest.get(&[i, j]), 0);
        }
    }
}