1use crate::ScalarBase;
8use std::any::{Any, TypeId};
9use std::cell::RefCell;
10use std::collections::HashMap;
11use strided_view::{RawStridedMut, RawStridedRef, StridedArray, StridedView, StridedViewMut};
12
13pub struct ContiguousOperand<T: Copy + 'static> {
15 ptr: *const T,
16 row_stride: isize,
17 col_stride: isize,
18 batch_strides: Vec<isize>,
19 conj: bool,
20 pub(crate) _buf: Option<StridedArray<T>>,
22 buf_is_pooled: bool,
23}
24
25pub struct ContiguousOperandMut<T: Copy + 'static> {
27 ptr: *mut T,
28 row_stride: isize,
29 col_stride: isize,
30 batch_strides: Vec<isize>,
31 needs_writeback: bool,
34 pub(crate) _buf: Option<StridedArray<T>>,
36 buf_is_pooled: bool,
37}
38
thread_local! {
    /// Per-thread cache of spare buffers keyed by element type. The value
    /// stored under `TypeId::of::<T>()` is a boxed `Vec<Vec<T>>`.
    static BUFFER_POOL: RefCell<HashMap<TypeId, Box<dyn Any>>> = RefCell::default();
}

/// Maximum number of spare buffers retained per element type.
const MAX_POOL_PER_TYPE: usize = 16;
/// Largest buffer, in bytes, that may be pooled (64 MiB).
const MAX_POOLED_BYTES: usize = 64 << 20;
45
46fn take_pooled_vec_uninit<T: Copy + 'static>(len: usize) -> Vec<T> {
47 BUFFER_POOL.with(|pool| {
48 let mut pool = pool.borrow_mut();
49 let entry = pool
50 .entry(TypeId::of::<T>())
51 .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
52 let vecs = entry
53 .downcast_mut::<Vec<Vec<T>>>()
54 .expect("buffer pool type mismatch");
55
56 let mut best_idx = None;
57 let mut best_cap = usize::MAX;
58 for (idx, v) in vecs.iter().enumerate() {
59 let cap = v.capacity();
60 if cap >= len && cap < best_cap {
61 best_idx = Some(idx);
62 best_cap = cap;
63 }
64 }
65
66 let mut data = best_idx
67 .map(|idx| vecs.swap_remove(idx))
68 .unwrap_or_else(|| Vec::with_capacity(len));
69 if data.capacity() < len {
70 data.reserve(len - data.capacity());
71 }
72 unsafe { data.set_len(len) };
73 data
74 })
75}
76
77fn return_pooled_vec<T: Copy + 'static>(mut data: Vec<T>) {
78 let bytes = data.capacity().saturating_mul(std::mem::size_of::<T>());
79 if bytes == 0 || bytes > MAX_POOLED_BYTES {
80 return;
81 }
82 data.clear();
83 BUFFER_POOL.with(|pool| {
84 let mut pool = pool.borrow_mut();
85 let entry = pool
86 .entry(TypeId::of::<T>())
87 .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
88 let vecs = entry
89 .downcast_mut::<Vec<Vec<T>>>()
90 .expect("buffer pool type mismatch");
91 if vecs.len() >= MAX_POOL_PER_TYPE {
92 if let Some((min_idx, min_cap)) = vecs
93 .iter()
94 .enumerate()
95 .map(|(i, v)| (i, v.capacity()))
96 .min_by_key(|(_, cap)| *cap)
97 {
98 if min_cap < data.capacity() {
99 vecs.swap_remove(min_idx);
100 vecs.push(data);
101 }
102 }
103 } else {
104 vecs.push(data);
105 }
106 });
107}
108
109fn alloc_col_major_uninit_with_pool<T: Copy + 'static>(dims: &[usize]) -> (StridedArray<T>, bool) {
110 let total: usize = dims.iter().product::<usize>().max(1);
111 let bytes = total.saturating_mul(std::mem::size_of::<T>());
112 if bytes == 0 || bytes > MAX_POOLED_BYTES {
113 return (alloc_col_major_uninit(dims), false);
114 }
115 let data = take_pooled_vec_uninit::<T>(total);
116 let arr = unsafe { StridedArray::col_major_from_buffer_uninit(data, dims) };
117 (arr, true)
118}
119
120fn alloc_maybe_pooled<T: Copy + 'static>(
122 dims: &[usize],
123 use_pool: bool,
124) -> (StridedArray<T>, bool) {
125 if use_pool {
126 alloc_col_major_uninit_with_pool(dims)
127 } else {
128 (alloc_col_major_uninit(dims), false)
129 }
130}
131
/// Test helper: number of spare buffers currently pooled for element type
/// `T` on this thread (0 when nothing is registered for `T`).
#[cfg(test)]
fn pooled_count_for_type<T: 'static>() -> usize {
    BUFFER_POOL.with(|cell| {
        cell.borrow_mut()
            .get_mut(&TypeId::of::<T>())
            .and_then(|slot| slot.downcast_mut::<Vec<Vec<T>>>())
            .map_or(0, |shelf| shelf.len())
    })
}
144
145impl<T: Copy + 'static> ContiguousOperand<T> {
146 #[inline]
148 pub fn ptr(&self) -> *const T {
149 self.ptr
150 }
151
152 #[inline]
154 pub fn row_stride(&self) -> isize {
155 self.row_stride
156 }
157
158 #[inline]
160 pub fn col_stride(&self) -> isize {
161 self.col_stride
162 }
163
164 #[inline]
166 pub fn batch_strides(&self) -> &[isize] {
167 &self.batch_strides
168 }
169
170 #[inline]
172 pub fn conj(&self) -> bool {
173 self.conj
174 }
175
176 #[cfg(test)]
178 #[inline]
179 pub(crate) fn has_buf(&self) -> bool {
180 self._buf.is_some()
181 }
182}
183
184impl<T: Copy + 'static> ContiguousOperandMut<T> {
185 #[inline]
187 pub fn ptr(&self) -> *mut T {
188 self.ptr
189 }
190
191 #[inline]
193 pub fn row_stride(&self) -> isize {
194 self.row_stride
195 }
196
197 #[inline]
199 pub fn col_stride(&self) -> isize {
200 self.col_stride
201 }
202
203 #[inline]
205 pub fn batch_strides(&self) -> &[isize] {
206 &self.batch_strides
207 }
208
209 #[cfg(test)]
211 #[inline]
212 pub(crate) fn has_buf(&self) -> bool {
213 self._buf.is_some()
214 }
215
216 #[cfg(test)]
219 #[inline]
220 pub(crate) fn needs_writeback(&self) -> bool {
221 self.needs_writeback
222 }
223}
224
225impl<T: Copy + Send + Sync> ContiguousOperandMut<T> {
226 pub fn finalize_into(self, dest: &mut StridedViewMut<T>) -> crate::Result<()> {
231 if self.needs_writeback {
232 if let Some(ref buf) = self._buf {
233 strided_perm::copy_into(dest, &buf.view())?;
234 }
235 }
236 Ok(())
237 }
238
239 pub fn finalize_raw_into(self, dest: &mut RawStridedMut<'_, T>) -> crate::Result<()> {
244 if self.needs_writeback {
245 if let Some(ref buf) = self._buf {
246 let mut dest_view = dest.as_view_mut();
247 strided_perm::copy_into(&mut dest_view, &buf.view())?;
248 }
249 }
250 Ok(())
251 }
252}
253
254impl<T: Copy + 'static> Drop for ContiguousOperand<T> {
255 fn drop(&mut self) {
256 if self.buf_is_pooled {
257 if let Some(arr) = self._buf.take() {
258 return_pooled_vec(arr.into_data());
259 }
260 }
261 }
262}
263
264impl<T: Copy + 'static> Drop for ContiguousOperandMut<T> {
265 fn drop(&mut self) {
266 if self.buf_is_pooled {
267 if let Some(arr) = self._buf.take() {
268 return_pooled_vec(arr.into_data());
269 }
270 }
271 }
272}
273
/// Outcome of `check_contiguity` for the two inner dimension groups.
struct ContiguityCheck {
    /// `true` when the operand has to be copied into a fresh buffer.
    needs_copy: bool,
    /// `(total size, stride)` of the first group once fused, or `None` if its
    /// dimensions cannot be fused.
    fused_g1: Option<(usize, isize)>,
    /// `(total size, stride)` of the second group once fused, or `None` if
    /// its dimensions cannot be fused.
    fused_g2: Option<(usize, isize)>,
}
280
/// Tries to collapse a group of dimensions into one `(size, stride)` pair.
///
/// Fusion succeeds when the dimensions with extent greater than one form a
/// column-major chain: each stride equals the previous such stride times the
/// previous extent. Dimensions of extent 0 or 1 never break the chain.
///
/// Returns `None` when
/// * `dims` and `strides` differ in length,
/// * the product of `dims` overflows `usize`,
/// * a dimension with extent greater than one has stride 0,
/// * the chain is broken, or
/// * a stride computation overflows `isize`.
///
/// An empty group yields `Some((1, 0))`. When no dimension has an extent
/// greater than one, the reported stride is the one with the smallest
/// absolute value (the first on ties).
fn try_fuse_col_major_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    if dims.len() != strides.len() {
        return None;
    }

    let mut element_count = 1usize;
    for &extent in dims {
        element_count = element_count.checked_mul(extent)?;
    }
    if dims.is_empty() {
        return Some((1, 0));
    }

    // `(stride of the first non-trivial axis, stride required of the next one)`.
    let mut chain: Option<(isize, isize)> = None;
    for (&extent, &step) in dims.iter().zip(strides) {
        if extent <= 1 {
            continue;
        }
        if step == 0 {
            return None;
        }
        let leading = match chain {
            Some((_, required)) if required != step => return None,
            Some((leading, _)) => leading,
            None => step,
        };
        let extent = isize::try_from(extent).ok()?;
        chain = Some((leading, step.checked_mul(extent)?));
    }

    let fused_stride = match chain {
        Some((leading, _)) => leading,
        // Only trivial axes: fall back to the stride of smallest magnitude.
        None => strides
            .iter()
            .copied()
            .min_by_key(|step| step.unsigned_abs())
            .unwrap_or(0),
    };
    Some((element_count, fused_stride))
}
324
325fn check_contiguity(
331 group1_dims: &[usize],
332 group1_strides: &[isize],
333 group2_dims: &[usize],
334 group2_strides: &[isize],
335 requires_unit_stride: bool,
336) -> ContiguityCheck {
337 let fused_g1 = try_fuse_col_major_group(group1_dims, group1_strides);
338 let fused_g2 = try_fuse_col_major_group(group2_dims, group2_strides);
339
340 let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();
341
342 if requires_unit_stride && !needs_copy {
343 let (_, rs) = fused_g1.unwrap();
344 let (_, cs) = fused_g2.unwrap();
345 if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
346 needs_copy = true;
347 }
348 }
349
350 ContiguityCheck {
351 fused_g1,
352 fused_g2,
353 needs_copy,
354 }
355}
356
357fn col_major_layout(
361 buf: &StridedArray<impl Copy>,
362 n_group1: usize,
363 n_inner: usize,
364) -> (isize, isize, Vec<isize>) {
365 let m: usize = buf.dims()[..n_group1].iter().product::<usize>().max(1);
366 let row_stride = if m == 0 { 0 } else { 1isize };
367 let col_stride = m as isize;
368 let batch_strides = buf.strides()[n_inner..].to_vec();
369 (row_stride, col_stride, batch_strides)
370}
371
372pub(crate) fn alloc_col_major_uninit<T: Copy>(dims: &[usize]) -> StridedArray<T> {
378 let total: usize = dims.iter().product::<usize>().max(1);
379 let mut data = Vec::with_capacity(total);
384 unsafe { data.set_len(total) };
385
386 let mut strides = vec![0isize; dims.len()];
389 if !dims.is_empty() {
390 strides[0] = 1;
391 for i in 1..dims.len() {
392 strides[i] = strides[i - 1] * dims[i - 1] as isize;
393 }
394 }
395
396 let arr = StridedArray::from_parts(data, dims, &strides, 0).expect("col-major allocation");
397 arr
398}
399
400pub fn prepare_input_view<T: ScalarBase + 'static>(
411 view: &StridedView<T>,
412 n_group1: usize,
413 n_group2: usize,
414 conj: bool,
415 requires_unit_stride: bool,
416 use_pool: bool,
417 materialize_conj_fn: Option<fn(T) -> T>,
418) -> crate::Result<ContiguousOperand<T>> {
419 let dims = view.dims();
420 let strides = view.strides();
421 let n_inner = n_group1 + n_group2;
422
423 if let Some(conj_fn) = materialize_conj_fn {
426 if conj {
427 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
428 strided_kernel::map_into(&mut buf.view_mut(), view, conj_fn)?;
429 let ptr = buf.view().ptr();
430 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
431 return Ok(ContiguousOperand {
432 ptr,
433 row_stride,
434 col_stride,
435 batch_strides,
436 conj: false,
437 _buf: Some(buf),
438 buf_is_pooled,
439 });
440 }
441 }
442
443 let check = check_contiguity(
444 &dims[..n_group1],
445 &strides[..n_group1],
446 &dims[n_group1..n_inner],
447 &strides[n_group1..n_inner],
448 requires_unit_stride,
449 );
450
451 if check.needs_copy {
452 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
453 strided_kernel::copy_into_col_major(&mut buf.view_mut(), view)?;
454 let ptr = buf.view().ptr();
455 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
456 Ok(ContiguousOperand {
457 ptr,
458 row_stride,
459 col_stride,
460 batch_strides,
461 conj,
462 _buf: Some(buf),
463 buf_is_pooled,
464 })
465 } else {
466 let (_, rs) = check.fused_g1.unwrap();
467 let (_, cs) = check.fused_g2.unwrap();
468 Ok(ContiguousOperand {
469 ptr: view.ptr(),
470 row_stride: rs,
471 col_stride: cs,
472 batch_strides: strides[n_inner..].to_vec(),
473 conj,
474 _buf: None,
475 buf_is_pooled: false,
476 })
477 }
478}
479
480pub fn prepare_input_raw<T: ScalarBase + 'static>(
486 view: &RawStridedRef<'_, T>,
487 n_group1: usize,
488 n_group2: usize,
489 conj: bool,
490 requires_unit_stride: bool,
491 use_pool: bool,
492 materialize_conj_fn: Option<fn(T) -> T>,
493) -> crate::Result<ContiguousOperand<T>> {
494 let dims = view.dims();
495 let strides = view.strides();
496 let n_inner = n_group1 + n_group2;
497
498 if let Some(conj_fn) = materialize_conj_fn {
499 if conj {
500 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
501 strided_kernel::map_into(&mut buf.view_mut(), &view.as_view(), conj_fn)?;
502 let ptr = buf.view().ptr();
503 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
504 return Ok(ContiguousOperand {
505 ptr,
506 row_stride,
507 col_stride,
508 batch_strides,
509 conj: false,
510 _buf: Some(buf),
511 buf_is_pooled,
512 });
513 }
514 }
515
516 let check = check_contiguity(
517 &dims[..n_group1],
518 &strides[..n_group1],
519 &dims[n_group1..n_inner],
520 &strides[n_group1..n_inner],
521 requires_unit_stride,
522 );
523
524 if check.needs_copy {
525 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
526 strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
527 let ptr = buf.view().ptr();
528 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
529 Ok(ContiguousOperand {
530 ptr,
531 row_stride,
532 col_stride,
533 batch_strides,
534 conj,
535 _buf: Some(buf),
536 buf_is_pooled,
537 })
538 } else {
539 let (_, rs) = check.fused_g1.unwrap();
540 let (_, cs) = check.fused_g2.unwrap();
541 Ok(ContiguousOperand {
542 ptr: view.ptr(),
543 row_stride: rs,
544 col_stride: cs,
545 batch_strides: strides[n_inner..].to_vec(),
546 conj,
547 _buf: None,
548 buf_is_pooled: false,
549 })
550 }
551}
552
553pub fn prepare_input_owned<T: ScalarBase + 'static>(
561 arr: StridedArray<T>,
562 n_group1: usize,
563 n_group2: usize,
564 conj: bool,
565 requires_unit_stride: bool,
566 use_pool: bool,
567 materialize_conj_fn: Option<fn(T) -> T>,
568) -> crate::Result<ContiguousOperand<T>> {
569 let dims = arr.dims().to_vec();
570 let strides = arr.strides().to_vec();
571 let n_inner = n_group1 + n_group2;
572
573 if let Some(conj_fn) = materialize_conj_fn {
576 if conj {
577 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
578 strided_kernel::map_into(&mut buf.view_mut(), &arr.view(), conj_fn)?;
579 let ptr = buf.view().ptr();
580 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
581 return Ok(ContiguousOperand {
582 ptr,
583 row_stride,
584 col_stride,
585 batch_strides,
586 conj: false,
587 _buf: Some(buf),
588 buf_is_pooled,
589 });
590 }
591 }
592
593 let check = check_contiguity(
594 &dims[..n_group1],
595 &strides[..n_group1],
596 &dims[n_group1..n_inner],
597 &strides[n_group1..n_inner],
598 requires_unit_stride,
599 );
600
601 if check.needs_copy {
602 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
603 strided_kernel::copy_into_col_major(&mut buf.view_mut(), &arr.view())?;
604 let ptr = buf.view().ptr();
605 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
606 Ok(ContiguousOperand {
607 ptr,
608 row_stride,
609 col_stride,
610 batch_strides,
611 conj,
612 _buf: Some(buf),
613 buf_is_pooled,
614 })
615 } else {
616 let (_, rs) = check.fused_g1.unwrap();
617 let (_, cs) = check.fused_g2.unwrap();
618 let ptr = arr.view().ptr();
619 Ok(ContiguousOperand {
620 ptr,
621 row_stride: rs,
622 col_stride: cs,
623 batch_strides: strides[n_inner..].to_vec(),
624 conj,
625 _buf: Some(arr),
626 buf_is_pooled: false,
627 })
628 }
629}
630
631pub fn prepare_output_view<T: ScalarBase + 'static>(
647 view: &mut StridedViewMut<T>,
648 n_group1: usize,
649 n_group2: usize,
650 beta: T,
651 requires_unit_stride: bool,
652 use_pool: bool,
653) -> crate::Result<ContiguousOperandMut<T>> {
654 let dims = view.dims().to_vec();
655 let strides = view.strides().to_vec();
656 let n_inner = n_group1 + n_group2;
657
658 let check = check_contiguity(
659 &dims[..n_group1],
660 &strides[..n_group1],
661 &dims[n_group1..n_inner],
662 &strides[n_group1..n_inner],
663 requires_unit_stride,
664 );
665
666 if check.needs_copy {
667 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
668 if beta != T::zero() {
669 strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
670 }
671 let ptr = buf.view_mut().as_mut_ptr();
672 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
673 Ok(ContiguousOperandMut {
674 ptr,
675 row_stride,
676 col_stride,
677 batch_strides,
678 needs_writeback: true,
679 _buf: Some(buf),
680 buf_is_pooled,
681 })
682 } else {
683 let (_, rs) = check.fused_g1.unwrap();
684 let (_, cs) = check.fused_g2.unwrap();
685 Ok(ContiguousOperandMut {
686 ptr: view.as_mut_ptr(),
687 row_stride: rs,
688 col_stride: cs,
689 batch_strides: strides[n_inner..].to_vec(),
690 needs_writeback: false,
691 _buf: None,
692 buf_is_pooled: false,
693 })
694 }
695}
696
697pub fn prepare_output_raw<T: ScalarBase + 'static>(
703 view: &mut RawStridedMut<'_, T>,
704 n_group1: usize,
705 n_group2: usize,
706 beta: T,
707 requires_unit_stride: bool,
708 use_pool: bool,
709) -> crate::Result<ContiguousOperandMut<T>> {
710 let dims = view.dims().to_vec();
711 let strides = view.strides().to_vec();
712 let n_inner = n_group1 + n_group2;
713
714 let check = check_contiguity(
715 &dims[..n_group1],
716 &strides[..n_group1],
717 &dims[n_group1..n_inner],
718 &strides[n_group1..n_inner],
719 requires_unit_stride,
720 );
721
722 if check.needs_copy {
723 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
724 if beta != T::zero() {
725 strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
726 }
727 let ptr = buf.view_mut().as_mut_ptr();
728 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
729 Ok(ContiguousOperandMut {
730 ptr,
731 row_stride,
732 col_stride,
733 batch_strides,
734 needs_writeback: true,
735 _buf: Some(buf),
736 buf_is_pooled,
737 })
738 } else {
739 let (_, rs) = check.fused_g1.unwrap();
740 let (_, cs) = check.fused_g2.unwrap();
741 Ok(ContiguousOperandMut {
742 ptr: view.as_mut_ptr(),
743 row_stride: rs,
744 col_stride: cs,
745 batch_strides: strides[n_inner..].to_vec(),
746 needs_writeback: false,
747 _buf: None,
748 buf_is_pooled: false,
749 })
750 }
751}
752
/// Operand preparation exercised with the stride requirement of the naive
/// backend and with buffer pooling switched off.
#[cfg(test)]
mod tests_generic_backend {
    use super::*;
    use crate::backend::{Backend, NaiveBackend};

    /// Stride requirement of the naive backend for `f64`.
    const NAIVE_UNIT_STRIDE: bool = <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    /// 2x3x4 array of zeros with strides `[20, 4, 1]`, whose first two
    /// dimensions cannot be fused.
    fn gapped_zeros() -> StridedArray<f64> {
        StridedArray::<f64>::from_parts(vec![0.0f64; 100], &[2, 3, 4], &[20, 4, 1], 0).unwrap()
    }

    #[test]
    fn test_input_for_backend_contiguous() {
        let source = StridedArray::<f64>::col_major(&[2, 3]);
        let view = source.view();

        let operand =
            prepare_input_view(&view, 1, 1, false, NAIVE_UNIT_STRIDE, false, None).unwrap();

        assert!(operand._buf.is_none());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 2);
        assert!(!operand.conj());
    }

    #[test]
    fn test_input_for_backend_non_contiguous() {
        let source = gapped_zeros();
        let view = source.view();

        let operand =
            prepare_input_view(&view, 2, 1, false, NAIVE_UNIT_STRIDE, false, None).unwrap();

        assert!(operand._buf.is_some());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_contiguous() {
        let mut target = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = target.view_mut();

        let operand =
            prepare_output_view(&mut view, 1, 1, 0.0, NAIVE_UNIT_STRIDE, false).unwrap();

        assert!(!operand.needs_writeback);
        assert!(operand._buf.is_none());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 2);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_zero() {
        let mut target = gapped_zeros();
        let mut view = target.view_mut();

        let operand =
            prepare_output_view(&mut view, 2, 1, 0.0, NAIVE_UNIT_STRIDE, false).unwrap();

        assert!(operand.needs_writeback);
        assert!(operand._buf.is_some());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_nonzero_and_finalize() {
        // Seed a few recognizable values at known offsets.
        let mut storage = vec![0.0f64; 30];
        for (offset, value) in [(0usize, 10.0f64), (1, 20.0), (10, 40.0)] {
            storage[offset] = value;
        }
        let mut target =
            StridedArray::<f64>::from_parts(storage, &[2, 3, 1], &[10, 1, 1], 0).unwrap();
        let mut view = target.view_mut();

        let operand =
            prepare_output_view(&mut view, 2, 1, 1.0, NAIVE_UNIT_STRIDE, false).unwrap();

        assert!(operand.needs_writeback);
        // A non-zero beta means the existing values were staged.
        let staged = operand._buf.as_ref().unwrap();
        assert_eq!(staged.get(&[0, 0, 0]), 10.0);
        assert_eq!(staged.get(&[0, 1, 0]), 20.0);
        assert_eq!(staged.get(&[1, 0, 0]), 40.0);
        operand.finalize_into(&mut view).unwrap();
    }
}
862
/// Operand preparation exercised with the stride requirement of the active
/// backend and with buffer pooling switched on.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{ActiveBackend, Backend};

    /// Stride requirement of the active backend for `f64`.
    const UNIT_STRIDE: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    /// 2x3x4 array of zeros with strides `[20, 4, 1]`, whose first two
    /// dimensions cannot be fused.
    fn gapped_zeros() -> StridedArray<f64> {
        StridedArray::<f64>::from_parts(vec![0.0f64; 100], &[2, 3, 4], &[20, 4, 1], 0).unwrap()
    }

    #[test]
    fn test_borrowed_contiguous_no_copy() {
        let source = StridedArray::<f64>::col_major(&[2, 3]);
        let view = source.view();

        let operand = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!operand.has_buf());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 2);
        assert!(!operand.conj());
    }

    #[test]
    fn test_borrowed_transposed_matrix_no_copy() {
        let transposed =
            StridedArray::<f64>::from_parts(vec![0.0f64; 6], &[2, 3], &[3, 1], 0).unwrap();
        let view = transposed.view();

        let operand = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!operand.has_buf());
        assert_eq!(operand.row_stride(), 3);
        assert_eq!(operand.col_stride(), 1);
    }

    #[test]
    fn test_borrowed_batched_transposed_matrix_no_copy() {
        let transposed =
            StridedArray::<f64>::from_parts(vec![0.0f64; 2 * 3 * 5], &[2, 3, 5], &[3, 1, 6], 0)
                .unwrap();
        let view = transposed.view();

        let operand = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!operand.has_buf());
        assert_eq!(operand.row_stride(), 3);
        assert_eq!(operand.col_stride(), 1);
        assert_eq!(operand.batch_strides(), &[6]);
    }

    #[test]
    fn test_borrowed_non_contiguous_copies() {
        let source = gapped_zeros();
        let view = source.view();

        let operand = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(operand.has_buf());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 6);
    }

    #[test]
    fn test_owned_contiguous_no_copy() {
        let source = StridedArray::<f64>::col_major(&[2, 3]);

        let operand = prepare_input_owned(source, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        // The owned array itself is kept alive inside the operand.
        assert!(operand.has_buf());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 2);
    }

    #[test]
    fn test_owned_non_contiguous_copies() {
        let source = gapped_zeros();

        let operand = prepare_input_owned(source, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(operand.has_buf());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 6);
    }

    #[test]
    fn test_output_view_contiguous() {
        let mut target = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = target.view_mut();

        let operand = prepare_output_view(&mut view, 1, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(!operand.needs_writeback());
        assert!(!operand.has_buf());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 2);
    }

    #[test]
    fn test_output_view_non_contiguous_beta_zero() {
        let mut target = gapped_zeros();
        let mut view = target.view_mut();

        let operand = prepare_output_view(&mut view, 2, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(operand.needs_writeback());
        assert!(operand.has_buf());
        assert_eq!(operand.row_stride(), 1);
        assert_eq!(operand.col_stride(), 6);
    }

    #[test]
    fn test_output_view_non_contiguous_beta_nonzero_and_finalize() {
        // Seed both rows with recognizable values at known offsets.
        let mut storage = vec![0.0f64; 30];
        for (offset, value) in [
            (0usize, 10.0f64),
            (1, 20.0),
            (2, 30.0),
            (10, 40.0),
            (11, 50.0),
            (12, 60.0),
        ] {
            storage[offset] = value;
        }
        let mut target =
            StridedArray::<f64>::from_parts(storage, &[2, 3, 1], &[10, 1, 1], 0).unwrap();

        assert_eq!(target.get(&[0, 0, 0]), 10.0);
        assert_eq!(target.get(&[1, 1, 0]), 50.0);

        let mut view = target.view_mut();

        let mut operand = prepare_output_view(&mut view, 2, 1, 1.0, UNIT_STRIDE, true).unwrap();

        assert!(operand.needs_writeback());
        assert!(operand.has_buf());

        // A non-zero beta means the existing values were staged.
        let staged = operand._buf.as_ref().unwrap();
        assert_eq!(staged.get(&[0, 0, 0]), 10.0);
        assert_eq!(staged.get(&[1, 1, 0]), 50.0);

        // Stand in for the backend: overwrite the staging buffer with 100s.
        {
            let computed =
                StridedArray::<f64>::from_parts(vec![100.0f64; 6], &[2, 3, 1], &[3, 1, 1], 0)
                    .unwrap();
            strided_kernel::copy_into(
                &mut operand._buf.as_mut().unwrap().view_mut(),
                &computed.view(),
            )
            .unwrap();
            operand.ptr = operand._buf.as_mut().unwrap().view_mut().as_mut_ptr();
        }

        operand.finalize_into(&mut view).unwrap();

        // Every destination element must now carry the written-back value.
        for row in 0..2 {
            for col in 0..3 {
                assert_eq!(target.get(&[row, col, 0]), 100.0);
            }
        }
    }

    #[test]
    fn test_prepare_input_view_temp_buffer_is_recycled() {
        let pooled_before = pooled_count_for_type::<f64>();
        let source = gapped_zeros();
        let view = source.view();

        // Dropping the operand at the end of this scope should pool its buffer.
        {
            let operand =
                prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();
            assert!(operand.has_buf());
        }

        let pooled_after = pooled_count_for_type::<f64>();
        assert!(pooled_after >= pooled_before.saturating_add(1));
    }
}