1use crate::ScalarBase;
8use std::any::{Any, TypeId};
9use std::cell::RefCell;
10use std::collections::HashMap;
11use strided_perm::try_fuse_group;
12use strided_view::{StridedArray, StridedView, StridedViewMut};
13
/// Read-only matrix operand whose two leading axis groups have been fused into
/// a single `(row_stride, col_stride)` pair, possibly backed by a temporary
/// contiguous copy of the source data.
pub struct ContiguousOperand<T: Copy + 'static> {
    // Base pointer of the operand data (the caller's view, or `_buf` when a copy was made).
    ptr: *const T,
    // Stride between consecutive rows (fused group-1 axes).
    row_stride: isize,
    // Stride between consecutive columns (fused group-2 axes).
    col_stride: isize,
    // Strides of the remaining (batch) axes beyond the two fused groups.
    batch_strides: Vec<isize>,
    // Whether the consumer must still treat the data as conjugated.
    conj: bool,
    // Owned scratch storage when a contiguous copy was required; `None` when
    // `ptr` borrows directly from the caller's view.
    pub(crate) _buf: Option<StridedArray<T>>,
    // If true, `Drop` returns `_buf`'s allocation to the thread-local pool.
    buf_is_pooled: bool,
}
25
/// Mutable (output) counterpart of `ContiguousOperand`: a fused destination
/// matrix, possibly staged through a contiguous scratch buffer that must be
/// copied back to the caller's view on `finalize_into`.
pub struct ContiguousOperandMut<T: Copy + 'static> {
    // Base pointer of the writable data (the caller's view, or `_buf` when staged).
    ptr: *mut T,
    // Stride between consecutive rows (fused group-1 axes).
    row_stride: isize,
    // Stride between consecutive columns (fused group-2 axes).
    col_stride: isize,
    // Strides of the remaining (batch) axes beyond the two fused groups.
    batch_strides: Vec<isize>,
    // If true, `finalize_into` copies `_buf` back into the destination view.
    needs_writeback: bool,
    // Owned scratch storage when staging was required; `None` for the zero-copy path.
    pub(crate) _buf: Option<StridedArray<T>>,
    // If true, `Drop` returns `_buf`'s allocation to the thread-local pool.
    buf_is_pooled: bool,
}
39
thread_local! {
    // Per-thread cache of reusable allocations, keyed by element `TypeId`.
    // Each entry is a `Vec<Vec<T>>` of cleared buffers, type-erased as `dyn Any`.
    static BUFFER_POOL: RefCell<HashMap<TypeId, Box<dyn Any>>> = RefCell::new(HashMap::new());
}
43
/// Maximum number of cached buffers retained per element type.
const MAX_POOL_PER_TYPE: usize = 16;
/// Buffers whose capacity exceeds this many bytes bypass the pool entirely.
const MAX_POOLED_BYTES: usize = 64 * 1024 * 1024;
46
47fn take_pooled_vec_uninit<T: Copy + 'static>(len: usize) -> Vec<T> {
48 BUFFER_POOL.with(|pool| {
49 let mut pool = pool.borrow_mut();
50 let entry = pool
51 .entry(TypeId::of::<T>())
52 .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
53 let vecs = entry
54 .downcast_mut::<Vec<Vec<T>>>()
55 .expect("buffer pool type mismatch");
56
57 let mut best_idx = None;
58 let mut best_cap = usize::MAX;
59 for (idx, v) in vecs.iter().enumerate() {
60 let cap = v.capacity();
61 if cap >= len && cap < best_cap {
62 best_idx = Some(idx);
63 best_cap = cap;
64 }
65 }
66
67 let mut data = best_idx
68 .map(|idx| vecs.swap_remove(idx))
69 .unwrap_or_else(|| Vec::with_capacity(len));
70 if data.capacity() < len {
71 data.reserve(len - data.capacity());
72 }
73 unsafe { data.set_len(len) };
74 data
75 })
76}
77
78fn return_pooled_vec<T: Copy + 'static>(mut data: Vec<T>) {
79 let bytes = data.capacity().saturating_mul(std::mem::size_of::<T>());
80 if bytes == 0 || bytes > MAX_POOLED_BYTES {
81 return;
82 }
83 data.clear();
84 BUFFER_POOL.with(|pool| {
85 let mut pool = pool.borrow_mut();
86 let entry = pool
87 .entry(TypeId::of::<T>())
88 .or_insert_with(|| Box::new(Vec::<Vec<T>>::new()));
89 let vecs = entry
90 .downcast_mut::<Vec<Vec<T>>>()
91 .expect("buffer pool type mismatch");
92 if vecs.len() >= MAX_POOL_PER_TYPE {
93 if let Some((min_idx, min_cap)) = vecs
94 .iter()
95 .enumerate()
96 .map(|(i, v)| (i, v.capacity()))
97 .min_by_key(|(_, cap)| *cap)
98 {
99 if min_cap < data.capacity() {
100 vecs.swap_remove(min_idx);
101 vecs.push(data);
102 }
103 }
104 } else {
105 vecs.push(data);
106 }
107 });
108}
109
110fn alloc_col_major_uninit_with_pool<T: Copy + 'static>(dims: &[usize]) -> (StridedArray<T>, bool) {
111 let total: usize = dims.iter().product::<usize>().max(1);
112 let bytes = total.saturating_mul(std::mem::size_of::<T>());
113 if bytes == 0 || bytes > MAX_POOLED_BYTES {
114 return (alloc_col_major_uninit(dims), false);
115 }
116 let data = take_pooled_vec_uninit::<T>(total);
117 let arr = unsafe { StridedArray::col_major_from_buffer_uninit(data, dims) };
118 (arr, true)
119}
120
121fn alloc_maybe_pooled<T: Copy + 'static>(
123 dims: &[usize],
124 use_pool: bool,
125) -> (StridedArray<T>, bool) {
126 if use_pool {
127 alloc_col_major_uninit_with_pool(dims)
128 } else {
129 (alloc_col_major_uninit(dims), false)
130 }
131}
132
/// Test helper: number of buffers currently pooled for element type `T`.
#[cfg(test)]
fn pooled_count_for_type<T: 'static>() -> usize {
    BUFFER_POOL.with(|pool| {
        pool.borrow()
            .get(&TypeId::of::<T>())
            .and_then(|entry| entry.downcast_ref::<Vec<Vec<T>>>())
            .map_or(0, Vec::len)
    })
}
145
impl<T: Copy + 'static> ContiguousOperand<T> {
    /// Base pointer to the (possibly copied) operand data.
    #[inline]
    pub fn ptr(&self) -> *const T {
        self.ptr
    }

    /// Stride between consecutive rows of the fused matrix.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    /// Stride between consecutive columns of the fused matrix.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    /// Strides of the remaining (batch) axes.
    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    /// Whether the consumer must still treat the data as conjugated.
    #[inline]
    pub fn conj(&self) -> bool {
        self.conj
    }

    /// Test helper: whether this operand owns a scratch buffer (copy path).
    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }
}
184
impl<T: Copy + 'static> ContiguousOperandMut<T> {
    /// Base pointer to the writable (possibly staged) data.
    #[inline]
    pub fn ptr(&self) -> *mut T {
        self.ptr
    }

    /// Stride between consecutive rows of the fused matrix.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    /// Stride between consecutive columns of the fused matrix.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    /// Strides of the remaining (batch) axes.
    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    /// Test helper: whether this operand owns a scratch buffer (staged path).
    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }

    /// Test helper: whether `finalize_into` will copy the buffer back.
    #[cfg(test)]
    #[inline]
    pub(crate) fn needs_writeback(&self) -> bool {
        self.needs_writeback
    }
}
225
226impl<T: Copy + Send + Sync> ContiguousOperandMut<T> {
227 pub fn finalize_into(self, dest: &mut StridedViewMut<T>) -> crate::Result<()> {
232 if self.needs_writeback {
233 if let Some(ref buf) = self._buf {
234 strided_perm::copy_into(dest, &buf.view())?;
235 }
236 }
237 Ok(())
238 }
239}
240
241impl<T: Copy + 'static> Drop for ContiguousOperand<T> {
242 fn drop(&mut self) {
243 if self.buf_is_pooled {
244 if let Some(arr) = self._buf.take() {
245 return_pooled_vec(arr.into_data());
246 }
247 }
248 }
249}
250
251impl<T: Copy + 'static> Drop for ContiguousOperandMut<T> {
252 fn drop(&mut self) {
253 if self.buf_is_pooled {
254 if let Some(arr) = self._buf.take() {
255 return_pooled_vec(arr.into_data());
256 }
257 }
258 }
259}
260
/// Result of attempting to fuse the two axis groups of an operand.
struct ContiguityCheck {
    // Fused (extent, stride) for group 1, when fusion succeeded.
    // NOTE(review): callers only consume the stride component; the usize is
    // presumably the fused extent — confirm against `try_fuse_group`.
    fused_g1: Option<(usize, isize)>,
    // Fused (extent, stride) for group 2, when fusion succeeded.
    fused_g2: Option<(usize, isize)>,
    // Whether the operand must be copied to a contiguous buffer before use.
    needs_copy: bool,
}
267
268fn check_contiguity(
273 group1_dims: &[usize],
274 group1_strides: &[isize],
275 group2_dims: &[usize],
276 group2_strides: &[isize],
277 requires_unit_stride: bool,
278) -> ContiguityCheck {
279 let fused_g1 = try_fuse_group(group1_dims, group1_strides);
280 let fused_g2 = try_fuse_group(group2_dims, group2_strides);
281
282 let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();
283
284 if requires_unit_stride && !needs_copy {
285 let (_, rs) = fused_g1.unwrap();
286 let (_, cs) = fused_g2.unwrap();
287 if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
288 needs_copy = true;
289 }
290 }
291
292 ContiguityCheck {
293 fused_g1,
294 fused_g2,
295 needs_copy,
296 }
297}
298
299fn col_major_layout(
303 buf: &StridedArray<impl Copy>,
304 n_group1: usize,
305 n_inner: usize,
306) -> (isize, isize, Vec<isize>) {
307 let m: usize = buf.dims()[..n_group1].iter().product::<usize>().max(1);
308 let row_stride = if m == 0 { 0 } else { 1isize };
309 let col_stride = m as isize;
310 let batch_strides = buf.strides()[n_inner..].to_vec();
311 (row_stride, col_stride, batch_strides)
312}
313
314pub(crate) fn alloc_col_major_uninit<T: Copy>(dims: &[usize]) -> StridedArray<T> {
320 let total: usize = dims.iter().product::<usize>().max(1);
321 let mut data = Vec::with_capacity(total);
326 unsafe { data.set_len(total) };
327
328 let mut strides = vec![0isize; dims.len()];
331 if !dims.is_empty() {
332 strides[0] = 1;
333 for i in 1..dims.len() {
334 strides[i] = strides[i - 1] * dims[i - 1] as isize;
335 }
336 }
337
338 let arr = StridedArray::from_parts(data, dims, &strides, 0).expect("col-major allocation");
339 arr
340}
341
342pub fn prepare_input_view<T: ScalarBase + 'static>(
353 view: &StridedView<T>,
354 n_group1: usize,
355 n_group2: usize,
356 conj: bool,
357 requires_unit_stride: bool,
358 use_pool: bool,
359 materialize_conj_fn: Option<fn(T) -> T>,
360) -> crate::Result<ContiguousOperand<T>> {
361 let dims = view.dims();
362 let strides = view.strides();
363 let n_inner = n_group1 + n_group2;
364
365 if let Some(conj_fn) = materialize_conj_fn {
368 if conj {
369 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
370 strided_kernel::map_into(&mut buf.view_mut(), view, conj_fn)?;
371 let ptr = buf.view().ptr();
372 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
373 return Ok(ContiguousOperand {
374 ptr,
375 row_stride,
376 col_stride,
377 batch_strides,
378 conj: false,
379 _buf: Some(buf),
380 buf_is_pooled,
381 });
382 }
383 }
384
385 let check = check_contiguity(
386 &dims[..n_group1],
387 &strides[..n_group1],
388 &dims[n_group1..n_inner],
389 &strides[n_group1..n_inner],
390 requires_unit_stride,
391 );
392
393 if check.needs_copy {
394 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(dims, use_pool);
395 strided_kernel::copy_into_col_major(&mut buf.view_mut(), view)?;
396 let ptr = buf.view().ptr();
397 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
398 Ok(ContiguousOperand {
399 ptr,
400 row_stride,
401 col_stride,
402 batch_strides,
403 conj,
404 _buf: Some(buf),
405 buf_is_pooled,
406 })
407 } else {
408 let (_, rs) = check.fused_g1.unwrap();
409 let (_, cs) = check.fused_g2.unwrap();
410 Ok(ContiguousOperand {
411 ptr: view.ptr(),
412 row_stride: rs,
413 col_stride: cs,
414 batch_strides: strides[n_inner..].to_vec(),
415 conj,
416 _buf: None,
417 buf_is_pooled: false,
418 })
419 }
420}
421
422pub fn prepare_input_owned<T: ScalarBase + 'static>(
430 arr: StridedArray<T>,
431 n_group1: usize,
432 n_group2: usize,
433 conj: bool,
434 requires_unit_stride: bool,
435 use_pool: bool,
436 materialize_conj_fn: Option<fn(T) -> T>,
437) -> crate::Result<ContiguousOperand<T>> {
438 let dims = arr.dims().to_vec();
439 let strides = arr.strides().to_vec();
440 let n_inner = n_group1 + n_group2;
441
442 if let Some(conj_fn) = materialize_conj_fn {
445 if conj {
446 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
447 strided_kernel::map_into(&mut buf.view_mut(), &arr.view(), conj_fn)?;
448 let ptr = buf.view().ptr();
449 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
450 return Ok(ContiguousOperand {
451 ptr,
452 row_stride,
453 col_stride,
454 batch_strides,
455 conj: false,
456 _buf: Some(buf),
457 buf_is_pooled,
458 });
459 }
460 }
461
462 let check = check_contiguity(
463 &dims[..n_group1],
464 &strides[..n_group1],
465 &dims[n_group1..n_inner],
466 &strides[n_group1..n_inner],
467 requires_unit_stride,
468 );
469
470 if check.needs_copy {
471 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
472 strided_kernel::copy_into_col_major(&mut buf.view_mut(), &arr.view())?;
473 let ptr = buf.view().ptr();
474 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
475 Ok(ContiguousOperand {
476 ptr,
477 row_stride,
478 col_stride,
479 batch_strides,
480 conj,
481 _buf: Some(buf),
482 buf_is_pooled,
483 })
484 } else {
485 let (_, rs) = check.fused_g1.unwrap();
486 let (_, cs) = check.fused_g2.unwrap();
487 let ptr = arr.view().ptr();
488 Ok(ContiguousOperand {
489 ptr,
490 row_stride: rs,
491 col_stride: cs,
492 batch_strides: strides[n_inner..].to_vec(),
493 conj,
494 _buf: Some(arr),
495 buf_is_pooled: false,
496 })
497 }
498}
499
500pub fn prepare_output_view<T: ScalarBase + 'static>(
516 view: &mut StridedViewMut<T>,
517 n_group1: usize,
518 n_group2: usize,
519 beta: T,
520 requires_unit_stride: bool,
521 use_pool: bool,
522) -> crate::Result<ContiguousOperandMut<T>> {
523 let dims = view.dims().to_vec();
524 let strides = view.strides().to_vec();
525 let n_inner = n_group1 + n_group2;
526
527 let check = check_contiguity(
528 &dims[..n_group1],
529 &strides[..n_group1],
530 &dims[n_group1..n_inner],
531 &strides[n_group1..n_inner],
532 requires_unit_stride,
533 );
534
535 if check.needs_copy {
536 let (mut buf, buf_is_pooled) = alloc_maybe_pooled(&dims, use_pool);
537 if beta != T::zero() {
538 strided_kernel::copy_into_col_major(&mut buf.view_mut(), &view.as_view())?;
539 }
540 let ptr = buf.view_mut().as_mut_ptr();
541 let (row_stride, col_stride, batch_strides) = col_major_layout(&buf, n_group1, n_inner);
542 Ok(ContiguousOperandMut {
543 ptr,
544 row_stride,
545 col_stride,
546 batch_strides,
547 needs_writeback: true,
548 _buf: Some(buf),
549 buf_is_pooled,
550 })
551 } else {
552 let (_, rs) = check.fused_g1.unwrap();
553 let (_, cs) = check.fused_g2.unwrap();
554 Ok(ContiguousOperandMut {
555 ptr: view.as_mut_ptr(),
556 row_stride: rs,
557 col_stride: cs,
558 batch_strides: strides[n_inner..].to_vec(),
559 needs_writeback: false,
560 _buf: None,
561 buf_is_pooled: false,
562 })
563 }
564}
565
/// Tests against the explicitly-named `NaiveBackend` (pool disabled).
#[cfg(test)]
mod tests_generic_backend {
    use super::*;
    use crate::backend::{Backend, NaiveBackend};

    // Col-major input needs no copy; fused strides come from the view itself.
    #[test]
    fn test_input_for_backend_contiguous() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();
        let op = prepare_input_view(
            &view,
            1,
            1,
            false,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
            None,
        )
        .unwrap();
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    // Row-major-ish strides force a packing copy; resulting layout is
    // col-major over the fused (2*3, 4) matrix.
    #[test]
    fn test_input_for_backend_non_contiguous() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();
        let op = prepare_input_view(
            &view,
            2,
            1,
            false,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
            None,
        )
        .unwrap();
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Col-major output is written in place: no scratch buffer, no write-back.
    #[test]
    fn test_output_for_backend_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();
        let op = prepare_output_view(
            &mut view,
            1,
            1,
            0.0,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
        )
        .unwrap();
        assert!(!op.needs_writeback);
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    // Non-contiguous output with beta == 0: staged buffer, but its contents
    // are not primed (old values irrelevant).
    #[test]
    fn test_output_for_backend_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op = prepare_output_view(
            &mut view,
            2,
            1,
            0.0,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
        )
        .unwrap();
        assert!(op.needs_writeback);
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // beta != 0 primes the staged buffer with the destination's old contents;
    // finalize_into then copies the buffer back.
    #[test]
    fn test_output_for_backend_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[10] = 40.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op = prepare_output_view(
            &mut view,
            2,
            1,
            1.0,
            <NaiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE,
            false,
        )
        .unwrap();
        assert!(op.needs_writeback);
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[0, 1, 0]), 20.0);
        assert_eq!(buf.get(&[1, 0, 0]), 40.0);
        op.finalize_into(&mut view).unwrap();
    }
}
675
/// Tests against the crate's `ActiveBackend` with the buffer pool enabled.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{ActiveBackend, Backend};

    // Unit-stride requirement of whichever backend is active in this build.
    const UNIT_STRIDE: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    // Col-major borrowed input fuses in place: no scratch buffer.
    #[test]
    fn test_borrowed_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();

        let op = prepare_input_view(&view, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    // Non-fusable borrowed input is packed into a col-major scratch buffer.
    #[test]
    fn test_borrowed_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        let op = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Owned contiguous input: no packing copy, but the array itself is moved
    // into `_buf` to keep the data alive — hence has_buf() is true.
    #[test]
    fn test_owned_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);

        let op = prepare_input_owned(a, 1, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    // Owned non-fusable input is packed into a fresh col-major buffer.
    #[test]
    fn test_owned_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();

        let op = prepare_input_owned(a, 2, 1, false, UNIT_STRIDE, true, None).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // Col-major output is written directly: no staging, no write-back.
    #[test]
    fn test_output_view_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 1, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(!op.needs_writeback());
        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    // Non-contiguous output always stages and requires a write-back, even
    // when beta == 0 (buffer just isn't primed in that case).
    #[test]
    fn test_output_view_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 2, 1, 0.0, UNIT_STRIDE, true).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    // End-to-end: beta != 0 primes the staged buffer with the old contents,
    // the "backend" overwrites the buffer, and finalize_into copies the
    // results back into the strided destination.
    #[test]
    fn test_output_view_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;
        data[10] = 40.0;
        data[11] = 50.0;
        data[12] = 60.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();

        assert_eq!(c.get(&[0, 0, 0]), 10.0);
        assert_eq!(c.get(&[1, 1, 0]), 50.0);

        let mut view = c.view_mut();

        let mut op = prepare_output_view(&mut view, 2, 1, 1.0, UNIT_STRIDE, true).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());

        // Staged buffer carries the primed (old) destination values.
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[1, 1, 0]), 50.0);

        // Simulate the backend overwriting the staged buffer with results.
        {
            let result_data = vec![100.0f64; 6];
            let result =
                StridedArray::<f64>::from_parts(result_data, &[2, 3, 1], &[3, 1, 1], 0).unwrap();
            strided_kernel::copy_into(&mut op._buf.as_mut().unwrap().view_mut(), &result.view())
                .unwrap();
            op.ptr = op._buf.as_mut().unwrap().view_mut().as_mut_ptr();
        }

        op.finalize_into(&mut view).unwrap();

        assert_eq!(c.get(&[0, 0, 0]), 100.0);
        assert_eq!(c.get(&[0, 1, 0]), 100.0);
        assert_eq!(c.get(&[0, 2, 0]), 100.0);
        assert_eq!(c.get(&[1, 0, 0]), 100.0);
        assert_eq!(c.get(&[1, 1, 0]), 100.0);
        assert_eq!(c.get(&[1, 2, 0]), 100.0);
    }

    // A copy triggered with use_pool=true should return its scratch buffer to
    // the thread-local pool when the operand is dropped.
    #[test]
    fn test_prepare_input_view_temp_buffer_is_recycled() {
        let before = pooled_count_for_type::<f64>();
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        {
            let op = prepare_input_view(&view, 2, 1, false, UNIT_STRIDE, true, None).unwrap();
            assert!(op.has_buf());
        }

        let after = pooled_count_for_type::<f64>();
        assert!(after >= before.saturating_add(1));
    }
}