strided_einsum2/contiguous.rs

1//! GEMM-ready operand types and preparation functions for contiguous data.
2//!
3//! These types encapsulate the logic for preparing strided operands for GEMM:
4//! checking fusability, copying to col-major buffers when needed, and managing
5//! the writeback for borrowed output operands.
6
7use crate::ScalarBase;
8use std::any::{Any, TypeId};
9use std::cell::RefCell;
10use std::collections::HashMap;
11use strided_perm::try_fuse_group;
12use strided_view::{StridedArray, StridedView, StridedViewMut};
13
/// GEMM-ready input operand with contiguous data.
pub struct ContiguousOperand<T: Copy + 'static> {
    /// Base pointer to the operand data; points either into the caller's
    /// array or into `_buf` when a staging copy was made.
    ptr: *const T,
    /// Stride of the fused row (lo-group) dimension.
    row_stride: isize,
    /// Stride of the fused column (sum/ro-group) dimension.
    col_stride: isize,
    /// One stride per batch dimension (batch-last canonical order).
    batch_strides: Vec<isize>,
    /// Whether the backend should conjugate this operand during GEMM.
    conj: bool,
    /// Owns the buffer if a copy was made or input was consumed.
    pub(crate) _buf: Option<StridedArray<T>>,
    /// True when `_buf`'s storage came from the thread-local pool and
    /// must be returned to it on drop.
    buf_is_pooled: bool,
}
25
/// GEMM-ready output operand with contiguous data.
pub struct ContiguousOperandMut<T: Copy + 'static> {
    /// Base pointer to the destination data; points either into the caller's
    /// view or into `_buf` when a staging copy was made.
    ptr: *mut T,
    /// Stride of the fused row (lo-group) dimension.
    row_stride: isize,
    /// Stride of the fused column (ro-group) dimension.
    col_stride: isize,
    /// One stride per batch dimension (batch-last canonical order).
    batch_strides: Vec<isize>,
    /// Whether the caller must copy the buffer back to the original destination
    /// after GEMM completes (true only for borrowed non-contiguous C).
    needs_writeback: bool,
    /// Owns the buffer if a copy was made.
    pub(crate) _buf: Option<StridedArray<T>>,
    /// True when `_buf`'s storage came from the thread-local pool and
    /// must be returned to it on drop.
    buf_is_pooled: bool,
}
39
thread_local! {
    // Per-thread pool of reusable backing buffers, keyed by element type.
    // Each entry is a `Box<Vec<Vec<T>>>` stored as `Box<dyn Any>`.
    static BUFFER_POOL: RefCell<HashMap<TypeId, Box<dyn Any>>> = RefCell::new(HashMap::new());
}

/// Maximum number of pooled buffers retained per element type.
const MAX_POOL_PER_TYPE: usize = 16;
/// Buffers larger than this many bytes are never pooled (allocated/freed directly).
const MAX_POOLED_BYTES: usize = 64 * 1024 * 1024;
46
47fn take_pooled_vec_uninit<T: Copy + 'static>(len: usize) -> Vec<T> {
48    BUFFER_POOL.with(|pool| {
49        let mut pool = pool.borrow_mut();
50        let entry = pool
51            .entry(TypeId::of::<T>())
52            .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
53        let vecs = entry
54            .downcast_mut::<Vec<Vec<T>>>()
55            .expect("buffer pool type mismatch");
56
57        let mut best_idx = None;
58        let mut best_cap = usize::MAX;
59        for (idx, v) in vecs.iter().enumerate() {
60            let cap = v.capacity();
61            if cap >= len && cap < best_cap {
62                best_idx = Some(idx);
63                best_cap = cap;
64            }
65        }
66
67        let mut data = best_idx
68            .map(|idx| vecs.swap_remove(idx))
69            .unwrap_or_else(|| Vec::with_capacity(len));
70        if data.capacity() < len {
71            data.reserve(len - data.capacity());
72        }
73        unsafe { data.set_len(len) };
74        data
75    })
76}
77
78fn return_pooled_vec<T: Copy + 'static>(mut data: Vec<T>) {
79    let bytes = data.capacity().saturating_mul(std::mem::size_of::<T>());
80    if bytes == 0 || bytes > MAX_POOLED_BYTES {
81        return;
82    }
83    data.clear();
84    BUFFER_POOL.with(|pool| {
85        let mut pool = pool.borrow_mut();
86        let entry = pool
87            .entry(TypeId::of::<T>())
88            .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
89        let vecs = entry
90            .downcast_mut::<Vec<Vec<T>>>()
91            .expect("buffer pool type mismatch");
92        if vecs.len() >= MAX_POOL_PER_TYPE {
93            if let Some((min_idx, min_cap)) = vecs
94                .iter()
95                .enumerate()
96                .map(|(i, v)| (i, v.capacity()))
97                .min_by_key(|(_, cap)| *cap)
98            {
99                if min_cap < data.capacity() {
100                    vecs.swap_remove(min_idx);
101                    vecs.push(data);
102                }
103            }
104        } else {
105            vecs.push(data);
106        }
107    });
108}
109
110fn alloc_col_major_uninit_with_pool<T: Copy + 'static>(dims: &[usize]) -> (StridedArray<T>, bool) {
111    let total: usize = dims.iter().product::<usize>().max(1);
112    let bytes = total.saturating_mul(std::mem::size_of::<T>());
113    if bytes == 0 || bytes > MAX_POOLED_BYTES {
114        return (alloc_col_major_uninit(dims), false);
115    }
116    let data = take_pooled_vec_uninit::<T>(total);
117    let arr = unsafe { StridedArray::col_major_from_buffer_uninit(data, dims) };
118    (arr, true)
119}
120
121/// Allocate a col-major buffer, optionally reusing from the thread-local pool.
122fn alloc_maybe_pooled<T: Copy + 'static>(
123    dims: &[usize],
124    use_pool: bool,
125) -> (StridedArray<T>, bool) {
126    if use_pool {
127        alloc_col_major_uninit_with_pool(dims)
128    } else {
129        (alloc_col_major_uninit(dims), false)
130    }
131}
132
// Test-only helper: number of buffers currently pooled for element type `T`.
#[cfg(test)]
fn pooled_count_for_type<T: 'static>() -> usize {
    BUFFER_POOL.with(|pool| {
        pool.borrow_mut()
            .get_mut(&TypeId::of::<T>())
            .and_then(|entry| entry.downcast_mut::<Vec<Vec<T>>>())
            .map_or(0, |vecs| vecs.len())
    })
}
145
impl<T: Copy + 'static> ContiguousOperand<T> {
    /// Raw const pointer to the operand data at the base offset.
    ///
    /// Points into the caller's data, or into the owned staging buffer
    /// when a copy was made.
    #[inline]
    pub fn ptr(&self) -> *const T {
        self.ptr
    }

    /// Row (lo-group) stride for the fused 2D matrix.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    /// Column (sum/ro-group) stride for the fused 2D matrix.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    /// Batch dimension strides.
    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    /// Whether this operand requires conjugation.
    ///
    /// Always false when conjugation was materialized into the data
    /// during preparation.
    #[inline]
    pub fn conj(&self) -> bool {
        self.conj
    }

    /// Returns `true` if this operand owns a buffer (copy was made or ownership transferred).
    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }
}
184
impl<T: Copy + 'static> ContiguousOperandMut<T> {
    /// Raw mutable pointer to the operand data at the base offset.
    ///
    /// Points into the caller's destination, or into the owned staging
    /// buffer when a copy was made.
    #[inline]
    pub fn ptr(&self) -> *mut T {
        self.ptr
    }

    /// Row (lo-group) stride for the fused 2D matrix.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    /// Column (ro-group) stride for the fused 2D matrix.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    /// Batch dimension strides.
    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    /// Returns `true` if this operand owns a buffer (copy was made).
    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }

    /// Returns `true` if the caller must copy the buffer back to the original
    /// destination after GEMM completes.
    #[cfg(test)]
    #[inline]
    pub(crate) fn needs_writeback(&self) -> bool {
        self.needs_writeback
    }
}
225
226impl<T: Copy + Send + Sync> ContiguousOperandMut<T> {
227    /// After GEMM: copy the internal buffer back to `dest` if needed.
228    ///
229    /// This is a no-op when the GEMM wrote directly to the destination
230    /// (contiguous case or owned output).
231    pub fn finalize_into(self, dest: &mut StridedViewMut<T>) -> crate::Result<()> {
232        if self.needs_writeback {
233            if let Some(ref buf) = self._buf {
234                strided_perm::copy_into(dest, &buf.view())?;
235            }
236        }
237        Ok(())
238    }
239}
240
241impl<T: Copy + 'static> Drop for ContiguousOperand<T> {
242    fn drop(&mut self) {
243        if self.buf_is_pooled {
244            if let Some(arr) = self._buf.take() {
245                return_pooled_vec(arr.into_data());
246            }
247        }
248    }
249}
250
251impl<T: Copy + 'static> Drop for ContiguousOperandMut<T> {
252    fn drop(&mut self) {
253        if self.buf_is_pooled {
254            if let Some(arr) = self._buf.take() {
255                return_pooled_vec(arr.into_data());
256            }
257        }
258    }
259}
260
/// Result of checking whether dimension groups are contiguous enough for GEMM.
struct ContiguityCheck {
    /// `(fused_size, fused_stride)` for group 1, `None` if not fusable.
    fused_g1: Option<(usize, isize)>,
    /// `(fused_size, fused_stride)` for group 2, `None` if not fusable.
    fused_g2: Option<(usize, isize)>,
    /// True when the operand must be staged into a col-major buffer.
    needs_copy: bool,
}
267
268/// Check if two dimension groups are fusable (contiguous) for GEMM.
269///
270/// When `requires_unit_stride` is true (e.g., CBLAS backend), also checks that
271/// at least one of the fused strides is 0 or 1.
272fn check_contiguity(
273    group1_dims: &[usize],
274    group1_strides: &[isize],
275    group2_dims: &[usize],
276    group2_strides: &[isize],
277    requires_unit_stride: bool,
278) -> ContiguityCheck {
279    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
280    let fused_g2 = try_fuse_group(group2_dims, group2_strides);
281
282    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();
283
284    if requires_unit_stride && !needs_copy {
285        let (_, rs) = fused_g1.unwrap();
286        let (_, cs) = fused_g2.unwrap();
287        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
288            needs_copy = true;
289        }
290    }
291
292    ContiguityCheck {
293        fused_g1,
294        fused_g2,
295        needs_copy,
296    }
297}
298
299/// Compute col-major layout parameters from a freshly-allocated col-major buffer.
300///
301/// Returns `(row_stride, col_stride, batch_strides)`.
302fn col_major_layout(
303    buf: &StridedArray<impl Copy>,
304    n_group1: usize,
305    n_inner: usize,
306) -> (isize, isize, Vec<isize>) {
307    let m: usize = buf.dims()[..n_group1].iter().product::<usize>().max(1);
308    let row_stride = if m == 0 { 0 } else { 1isize };
309    let col_stride = m as isize;
310    let batch_strides = buf.strides()[n_inner..].to_vec();
311    (row_stride, col_stride, batch_strides)
312}
313
314/// Allocate a column-major StridedArray with uninitialized data.
315///
316/// With batch-last canonical order `[inner..., batch...]`, pure column-major
317/// naturally gives batch dims the largest strides — each batch slice is a
318/// contiguous column-major matrix.
319pub(crate) fn alloc_col_major_uninit<T: Copy>(dims: &[usize]) -> StridedArray<T> {
320    let total: usize = dims.iter().product::<usize>().max(1);
321    // SAFETY: `T: Copy` guarantees no drop glue, so leaving elements
322    // uninitialised is safe. Every call-site writes all elements before
323    // reading: A and B via `copy_into`, C via `copy_into` (beta != 0)
324    // or GEMM with replace semantics (beta == 0).
325    let mut data = Vec::with_capacity(total);
326    unsafe { data.set_len(total) };
327
328    // Pure column-major: stride 1 for first dim, each subsequent dim
329    // has stride = previous stride * previous dim size.
330    let mut strides = vec![0isize; dims.len()];
331    if !dims.is_empty() {
332        strides[0] = 1;
333        for i in 1..dims.len() {
334            strides[i] = strides[i - 1] * dims[i - 1] as isize;
335        }
336    }
337
338    let arr = StridedArray::from_parts(data, dims, &strides, 0).expect("col-major allocation");
339    arr
340}
341
342/// Prepare a borrowed input view for GEMM.
343///
344/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
345/// Checks if the two inner dimension groups are fusable.
346/// If not, copies to a contiguous col-major buffer.
347///
348/// - `requires_unit_stride`: backend needs at least one unit stride (e.g. CBLAS).
349/// - `use_pool`: reuse thread-local buffers to avoid repeated allocation.
350/// - `materialize_conj_fn`: when `Some(f)` and `conj == true`, applies `f` to each
351///   element during copy (for backends that cannot pass conj flags to GEMM).
352pub fn prepare_input_view<T: ScalarBase + 'static>(
353    view: &StridedView<T>,
354    n_group1: usize,
355    n_group2: usize,
356    conj: bool,
357    requires_unit_stride: bool,
358    use_pool: bool,
359    materialize_conj_fn: Option<fn(T) -> T>,
360) -> crate::Result<ContiguousOperand<T>> {
361    let dims = view.dims();
362    let strides = view.strides();
363    let n_inner = n_group1 + n_group2;
364
365    // For backends that cannot pass conjugation flags to GEMM (e.g., CBLAS),
366    // materialize conj into the data before the GEMM call.
367    if let Some(conj_fn) = materialize_conj_fn {
368        if conj {
369            let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
370            strided_kernel::map_into(&mut buf.view_mut(), view, conj_fn)?;
371            let ptr = buf.view().ptr();
372            let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
373            return Ok(ContiguousOperand {
374                ptr,
375                row_stride,
376                col_stride,
377                batch_strides,
378                conj: false,
379                _buf: Some(buf),
380                buf_is_pooled,
381            });
382        }
383    }
384
385    let check = check_contiguity(
386        &dims[..n_group1],
387        &strides[..n_group1],
388        &dims[n_group1..n_inner],
389        &strides[n_group1..n_inner],
390        requires_unit_stride,
391    );
392
393    if check.needs_copy {
394        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
395        strided_kernel::copy_into_col_major(&mut buf.view_mut(), view)?;
396        let ptr = buf.view().ptr();
397        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
398        Ok(ContiguousOperand {
399            ptr,
400            row_stride,
401            col_stride,
402            batch_strides,
403            conj,
404            _buf: Some(buf),
405            buf_is_pooled,
406        })
407    } else {
408        let (_, rs) = check.fused_g1.unwrap();
409        let (_, cs) = check.fused_g2.unwrap();
410        Ok(ContiguousOperand {
411            ptr: view.ptr(),
412            row_stride: rs,
413            col_stride: cs,
414            batch_strides: strides[n_inner..].to_vec(),
415            conj,
416            _buf: None,
417            buf_is_pooled: false,
418        })
419    }
420}
421
422/// Prepare an owned input array for GEMM.
423///
424/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
425/// If already contiguous after dimension grouping, transfers ownership without copying.
426/// Otherwise, copies to a new col-major buffer.
427///
428/// Parameters are the same as [`prepare_input_view`].
429pub fn prepare_input_owned<T: ScalarBase + 'static>(
430    arr: StridedArray<T>,
431    n_group1: usize,
432    n_group2: usize,
433    conj: bool,
434    requires_unit_stride: bool,
435    use_pool: bool,
436    materialize_conj_fn: Option<fn(T) -> T>,
437) -> crate::Result<ContiguousOperand<T>> {
438    let dims = arr.dims().to_vec();
439    let strides = arr.strides().to_vec();
440    let n_inner = n_group1 + n_group2;
441
442    // For backends that cannot pass conjugation flags to GEMM,
443    // materialize conj into the data before the GEMM call.
444    if let Some(conj_fn) = materialize_conj_fn {
445        if conj {
446            let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
447            strided_kernel::map_into(&mut buf.view_mut(), &arr.view(), conj_fn)?;
448            let ptr = buf.view().ptr();
449            let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
450            return Ok(ContiguousOperand {
451                ptr,
452                row_stride,
453                col_stride,
454                batch_strides,
455                conj: false,
456                _buf: Some(buf),
457                buf_is_pooled,
458            });
459        }
460    }
461
462    let check = check_contiguity(
463        &dims[..n_group1],
464        &strides[..n_group1],
465        &dims[n_group1..n_inner],
466        &strides[n_group1..n_inner],
467        requires_unit_stride,
468    );
469
470    if check.needs_copy {
471        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
472        strided_kernel::copy_into_col_major(&mut buf.view_mut(), &arr.view())?;
473        let ptr = buf.view().ptr();
474        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
475        Ok(ContiguousOperand {
476            ptr,
477            row_stride,
478            col_stride,
479            batch_strides,
480            conj,
481            _buf: Some(buf),
482            buf_is_pooled,
483        })
484    } else {
485        let (_, rs) = check.fused_g1.unwrap();
486        let (_, cs) = check.fused_g2.unwrap();
487        let ptr = arr.view().ptr();
488        Ok(ContiguousOperand {
489            ptr,
490            row_stride: rs,
491            col_stride: cs,
492            batch_strides: strides[n_inner..].to_vec(),
493            conj,
494            _buf: Some(arr),
495            buf_is_pooled: false,
496        })
497    }
498}
499
500/// Prepare a borrowed mutable output view for GEMM.
501///
502/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
503/// Checks if the two inner dimension groups (lo, ro) are fusable.
504/// If not, allocates a col-major buffer and copies the existing data into it
505/// when `beta` is non-zero (so the GEMM accumulation is correct).
506///
507/// After GEMM, call [`ContiguousOperandMut::finalize_into`] with the original
508/// view to copy results back if needed.
509///
510/// # Safety contract
511///
512/// When inner dims are fusable (no copy needed), the returned `ContiguousOperandMut`
513/// holds a raw pointer into `view`'s data. The caller must ensure `view` outlives
514/// the returned operand and that no aliasing mutable references exist during GEMM.
515pub fn prepare_output_view<T: ScalarBase + 'static>(
516    view: &mut StridedViewMut<T>,
517    n_group1: usize,
518    n_group2: usize,
519    beta: T,
520    requires_unit_stride: bool,
521    use_pool: bool,
522) -> crate::Result<ContiguousOperandMut<T>> {
523    let dims = view.dims().to_vec();
524    let strides = view.strides().to_vec();
525    let n_inner = n_group1 + n_group2;
526
527    let check = check_contiguity(
528        &dims[..n_group1],
529        &strides[..n_group1],
530        &dims[n_group1..n_inner],
531        &strides[n_group1..n_inner],
532        requires_unit_stride,
533    );
534
535    if check.needs_copy {
536        let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
537        if beta != T::zero() {
538            strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
539        }
540        let ptr = buf.view_mut().as_mut_ptr();
541        let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
542        Ok(ContiguousOperandMut {
543            ptr,
544            row_stride,
545            col_stride,
546            batch_strides,
547            needs_writeback: true,
548            _buf: Some(buf),
549            buf_is_pooled,
550        })
551    } else {
552        let (_, rs) = check.fused_g1.unwrap();
553        let (_, cs) = check.fused_g2.unwrap();
554        Ok(ContiguousOperandMut {
555            ptr: view.as_mut_ptr(),
556            row_stride: rs,
557            col_stride: cs,
558            batch_strides: strides[n_inner..].to_vec(),
559            needs_writeback: false,
560            _buf: None,
561            buf_is_pooled: false,
562        })
563    }
564}
565
// Tests that pin the prepare_* behavior against the NaiveBackend's
// unit-stride requirement specifically (as opposed to the active backend).
#[cfg(test)]
mod tests_generic_backend {
    use super::*;
    use crate::backend::{Backend, NaiveBackend};

    // Col-major [2, 3] is already fusable: no staging buffer, strides (1, 2).
    #[test]
    fn test_input_for_backend_contiguous() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();
        let op = prepare_input_view(
            &view,
            1,
            1,
            false,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
            None,
        )
        .unwrap();
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    // Row-major-ish strides [20, 4, 1] don't fuse for groups (2|1):
    // a col-major copy is made, giving fused strides (1, m) with m = 2*3.
    #[test]
    fn test_input_for_backend_non_contiguous() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();
        let op = prepare_input_view(
            &view,
            2,
            1,
            false,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
            None,
        )
        .unwrap();
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Contiguous output: GEMM writes in place, no writeback required.
    #[test]
    fn test_output_for_backend_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();
        let op = prepare_output_view(
            &mut view,
            1,
            1,
            0.0,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
        )
        .unwrap();
        assert!(!op.needs_writeback);
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    // Non-contiguous output with beta == 0: a buffer is allocated but the
    // existing contents are NOT copied in (GEMM overwrites everything).
    #[test]
    fn test_output_for_backend_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op = prepare_output_view(
            &mut view,
            2,
            1,
            0.0,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
        )
        .unwrap();
        assert!(op.needs_writeback);
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Non-contiguous output with beta != 0: the staging buffer must be
    // seeded with the existing destination values so accumulation is correct.
    #[test]
    fn test_output_for_backend_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[10] = 40.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op = prepare_output_view(
            &mut view,
            2,
            1,
            1.0,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
        )
        .unwrap();
        assert!(op.needs_writeback);
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[0, 1, 0]), 20.0);
        assert_eq!(buf.get(&[1, 0, 0]), 40.0);
        op.finalize_into(&mut view).unwrap();
    }
}
675
// Tests against the active backend's unit-stride setting, with the
// thread-local buffer pool enabled (use_pool = true).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{ActiveBackend, Backend};

    // Helper to construct prepare_input_view params matching the active backend.
    const UNIT_STRIDE: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    // Fusable borrowed input: no buffer, strides read straight from the view.
    #[test]
    fn test_borrowed_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();

        let op = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    // Non-fusable borrowed input: staged into a col-major copy.
    #[test]
    fn test_borrowed_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        let op = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Fusable owned input: ownership transfers into `_buf` without copying,
    // so has_buf() is true even though no copy was made.
    #[test]
    fn test_owned_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);

        let op = prepare_input_owned(a, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    // Non-fusable owned input: copied to a fresh col-major buffer.
    #[test]
    fn test_owned_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();

        let op = prepare_input_owned(a, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Fusable output: written in place, no writeback.
    #[test]
    fn test_output_view_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 1, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(!op.needs_writeback());
        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    // Non-fusable output, beta == 0: buffer allocated, writeback needed.
    #[test]
    fn test_output_view_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 2, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Full round trip: seed check (beta != 0), simulate a GEMM result by
    // overwriting the staging buffer, then finalize back into the original.
    #[test]
    fn test_output_view_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;
        data[10] = 40.0;
        data[11] = 50.0;
        data[12] = 60.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();

        assert_eq!(c.get(&[0, 0, 0]), 10.0);
        assert_eq!(c.get(&[1, 1, 0]), 50.0);

        let mut view = c.view_mut();

        let mut op = prepare_output_view(&mut view, 2, 1, 1.0, UNIT_STRIDE, true).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());

        // beta != 0 must have seeded the buffer with the destination data.
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[1, 1, 0]), 50.0);

        {
            // Stand-in for the GEMM: fill the staging buffer with 100.0.
            let result_data = vec![100.0f64; 6];
            let result =
                StridedArray::<f64>::from_parts(result_data, &[2, 3, 1], &[3, 1, 1], 0).unwrap();
            strided_kernel::copy_into(&mut op._buf.as_mut().unwrap().view_mut(), &result.view())
                .unwrap();
            op.ptr = op._buf.as_mut().unwrap().view_mut().as_mut_ptr();
        }

        op.finalize_into(&mut view).unwrap();

        // Every element of the strided destination got the buffer contents.
        assert_eq!(c.get(&[0, 0, 0]), 100.0);
        assert_eq!(c.get(&[0, 1, 0]), 100.0);
        assert_eq!(c.get(&[0, 2, 0]), 100.0);
        assert_eq!(c.get(&[1, 0, 0]), 100.0);
        assert_eq!(c.get(&[1, 1, 0]), 100.0);
        assert_eq!(c.get(&[1, 2, 0]), 100.0);
    }

    // Dropping an operand whose buffer came from the pool must return the
    // backing storage to the pool for reuse.
    #[test]
    fn test_prepare_input_view_temp_buffer_is_recycled() {
        let before = pooled_count_for_type::<f64>();
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        {
            let op = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();
            assert!(op.has_buf());
        }

        let after = pooled_count_for_type::<f64>();
        assert!(after >= before.saturating_add(1));
    }
}
819}