use crate::backend::{ActiveBackend, BackendConfig};
use crate::util::try_fuse_group;
use crate::{Scalar, ScalarBase};
use strided_view::{StridedArray, StridedView, StridedViewMut};

/// A matrix-like input operand ready for the backend: a base pointer plus row,
/// column, and batch strides, together with an optional owned buffer that backs
/// `ptr` when the operand had to be staged or the source array was consumed.
pub struct ContiguousOperand<T> {
    ptr: *const T,
    row_stride: isize,
    col_stride: isize,
    batch_strides: Vec<isize>,
    conj: bool,
    pub(crate) _buf: Option<StridedArray<T>>,
}

/// A mutable output operand. When the destination could not be used directly,
/// the backend writes into the owned staging buffer and `finalize_into` copies
/// the result back.
pub struct ContiguousOperandMut<T> {
    ptr: *mut T,
    row_stride: isize,
    col_stride: isize,
    batch_strides: Vec<isize>,
    needs_writeback: bool,
    pub(crate) _buf: Option<StridedArray<T>>,
}

impl<T> ContiguousOperand<T> {
    #[inline]
    pub fn ptr(&self) -> *const T {
        self.ptr
    }

    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    #[inline]
    pub fn conj(&self) -> bool {
        self.conj
    }

    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }
}

impl<T> ContiguousOperandMut<T> {
    #[inline]
    pub fn ptr(&self) -> *mut T {
        self.ptr
    }

    #[inline]
    pub fn row_stride(&self) -> isize {
        self.row_stride
    }

    #[inline]
    pub fn col_stride(&self) -> isize {
        self.col_stride
    }

    #[inline]
    pub fn batch_strides(&self) -> &[isize] {
        &self.batch_strides
    }

    #[cfg(test)]
    #[inline]
    pub(crate) fn has_buf(&self) -> bool {
        self._buf.is_some()
    }

    #[cfg(test)]
    #[inline]
    pub(crate) fn needs_writeback(&self) -> bool {
        self.needs_writeback
    }
}

impl<T: Copy + Send + Sync> ContiguousOperandMut<T> {
    /// Copies the staged result back into `dest` if this operand was routed
    /// through a temporary buffer; otherwise this is a no-op.
    pub fn finalize_into(self, dest: &mut StridedViewMut<T>) -> crate::Result<()> {
        if self.needs_writeback {
            if let Some(ref buf) = self._buf {
                strided_kernel::copy_into(dest, &buf.view())?;
            }
        }
        Ok(())
    }
}

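/// Allocates an uninitialized buffer with dense column-major strides for `dims`,
/// e.g. dims `[2, 3, 4]` get strides `[1, 2, 6]`. Callers are expected to fully
/// overwrite the contents before reading them.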
pub(crate) fn alloc_col_major_uninit<T: Copy>(dims: &[usize]) -> StridedArray<T> {
    let total: usize = dims.iter().product::<usize>().max(1);
    let mut data = Vec::with_capacity(total);
    // SAFETY: the buffer is scratch space for `T: Copy`; callers overwrite it
    // before any element is read.
    unsafe { data.set_len(total) };

    // Dense column-major layout: each stride is the cumulative product of the
    // preceding dimensions.
    let mut strides = vec![0isize; dims.len()];
    if !dims.is_empty() {
        strides[0] = 1;
        for i in 1..dims.len() {
            strides[i] = strides[i - 1] * dims[i - 1] as isize;
        }
    }

    StridedArray::from_parts(data, dims, &strides, 0).expect("col-major allocation")
}

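// The `prepare_*` helpers below share one strategy: fuse each index group into a
// single (size, stride) pair via `try_fuse_group`; if either group cannot be
// fused, or the backend insists on a unit stride and neither fused stride is
// 0 or 1, stage the operand in a dense column-major buffer instead. Sketch of
// the fused case: dims [2, 3] with strides [1, 2] yield row_stride 1 and
// col_stride 2, and the original memory is used directly.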
/// Prepares a borrowed input view as a backend operand. If the active backend
/// materializes conjugation, a conjugated copy is staged up front.
pub fn prepare_input_view<T: Scalar>(
    view: &StridedView<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    conj: bool,
) -> crate::Result<ContiguousOperand<T>> {
    let dims = view.dims();
    let strides = view.strides();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    if ActiveBackend::MATERIALIZES_CONJ && conj {
        use strided_view::Conj as ConjOp;
        use strided_view::ElementOp;

        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(dims);
        strided_kernel::map_into(&mut buf.view_mut(), view, |x| ConjOp::apply(x))?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        return Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj: false,
            _buf: Some(buf),
        });
    }

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(dims);
        strided_kernel::copy_into(&mut buf.view_mut(), view)?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperand {
            ptr: view.ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            conj,
            _buf: None,
        })
    }
}

/// Owned-array variant of [`prepare_input_view`]; when no copy is needed, the
/// array itself is kept alive inside the returned operand.
pub fn prepare_input_owned<T: Scalar>(
    arr: StridedArray<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    conj: bool,
) -> crate::Result<ContiguousOperand<T>> {
    let dims = arr.dims().to_vec();
    let strides = arr.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    if ActiveBackend::MATERIALIZES_CONJ && conj {
        use strided_view::Conj as ConjOp;
        use strided_view::ElementOp;

        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        strided_kernel::map_into(&mut buf.view_mut(), &arr.view(), |x| ConjOp::apply(x))?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        return Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj: false,
            _buf: Some(buf),
        });
    }

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        strided_kernel::copy_into(&mut buf.view_mut(), &arr.view())?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        let ptr = arr.view().ptr();
        Ok(ContiguousOperand {
            ptr,
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            conj,
            _buf: Some(arr),
        })
    }
}

/// Prepares a mutable output view. When a staging buffer is required, the
/// existing destination contents are copied in only if `beta` is non-zero, and
/// the caller must write the result back with
/// [`ContiguousOperandMut::finalize_into`].
pub fn prepare_output_view<T: Scalar>(
    view: &mut StridedViewMut<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    beta: T,
) -> crate::Result<ContiguousOperandMut<T>> {
    let dims = view.dims().to_vec();
    let strides = view.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        if beta != T::zero() {
            strided_kernel::copy_into(&mut buf.view_mut(), &view.as_view())?;
        }
        let ptr = buf.view_mut().as_mut_ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperandMut {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            needs_writeback: true,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperandMut {
            ptr: view.as_mut_ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            needs_writeback: false,
            _buf: None,
        })
    }
}

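/// Owned-array variant of [`prepare_output_view`]. Unlike the view variant, a
/// staged buffer here is not written back: `needs_writeback` stays `false`
/// (see `test_output_owned_no_writeback`).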
#[allow(dead_code)]
pub fn prepare_output_owned<T: Scalar>(
    arr: &mut StridedArray<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    beta: T,
) -> crate::Result<ContiguousOperandMut<T>> {
    let dims = arr.dims().to_vec();
    let strides = arr.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if ActiveBackend::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        if beta != T::zero() {
            strided_kernel::copy_into(&mut buf.view_mut(), &arr.view())?;
        }
        let ptr = buf.view_mut().as_mut_ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperandMut {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            needs_writeback: false,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperandMut {
            ptr: arr.view_mut().as_mut_ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            needs_writeback: false,
            _buf: None,
        })
    }
}

/// Like [`prepare_input_view`], but parameterized over an explicit backend `B`
/// instead of [`ActiveBackend`]; conjugation is not handled here.
pub fn prepare_input_view_for_backend<T: ScalarBase, B: BackendConfig>(
    view: &StridedView<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
) -> crate::Result<ContiguousOperand<T>> {
    let dims = view.dims();
    let strides = view.strides();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if B::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(dims);
        strided_kernel::copy_into(&mut buf.view_mut(), view)?;
        let ptr = buf.view().ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperand {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            conj: false,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperand {
            ptr: view.ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            conj: false,
            _buf: None,
        })
    }
}

/// Like [`prepare_output_view`], but parameterized over an explicit backend `B`.
pub fn prepare_output_view_for_backend<T: ScalarBase, B: BackendConfig>(
    view: &mut StridedViewMut<T>,
    _n_batch: usize,
    n_group1: usize,
    n_group2: usize,
    beta: T,
) -> crate::Result<ContiguousOperandMut<T>> {
    let dims = view.dims().to_vec();
    let strides = view.strides().to_vec();
    let n_inner = n_group1 + n_group2;

    let group1_dims = &dims[..n_group1];
    let group1_strides = &strides[..n_group1];
    let group2_dims = &dims[n_group1..n_inner];
    let group2_strides = &strides[n_group1..n_inner];

    let fused_g1 = try_fuse_group(group1_dims, group1_strides);
    let fused_g2 = try_fuse_group(group2_dims, group2_strides);

    let mut needs_copy = fused_g1.is_none() || fused_g2.is_none();

    if B::REQUIRES_UNIT_STRIDE && !needs_copy {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        if rs != 0 && rs != 1 && cs != 0 && cs != 1 {
            needs_copy = true;
        }
    }

    if needs_copy {
        let m: usize = group1_dims.iter().product::<usize>().max(1);
        let mut buf = alloc_col_major_uninit(&dims);
        if beta != T::zero() {
            strided_kernel::copy_into(&mut buf.view_mut(), &view.as_view())?;
        }
        let ptr = buf.view_mut().as_mut_ptr();
        let batch_strides = buf.strides()[n_inner..].to_vec();
        let row_stride = if m == 0 { 0 } else { 1isize };
        let col_stride = m as isize;
        Ok(ContiguousOperandMut {
            ptr,
            row_stride,
            col_stride,
            batch_strides,
            needs_writeback: true,
            _buf: Some(buf),
        })
    } else {
        let (_, rs) = fused_g1.unwrap();
        let (_, cs) = fused_g2.unwrap();
        let batch_strides = strides[n_inner..].to_vec();
        Ok(ContiguousOperandMut {
            ptr: view.as_mut_ptr(),
            row_stride: rs,
            col_stride: cs,
            batch_strides,
            needs_writeback: false,
            _buf: None,
        })
    }
}

#[cfg(test)]
mod tests_generic_backend {
    use super::*;
    use crate::backend::NaiveBackend;

    #[test]
    fn test_input_for_backend_contiguous() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();
        let op = prepare_input_view_for_backend::<f64, NaiveBackend>(&view, 0, 1, 1).unwrap();
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    #[test]
    fn test_input_for_backend_non_contiguous() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();
        let op = prepare_input_view_for_backend::<f64, NaiveBackend>(&view, 0, 2, 1).unwrap();
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();
        let op =
            prepare_output_view_for_backend::<f64, NaiveBackend>(&mut view, 0, 1, 1, 0.0).unwrap();
        assert!(!op.needs_writeback);
        assert!(op._buf.is_none());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op =
            prepare_output_view_for_backend::<f64, NaiveBackend>(&mut view, 0, 2, 1, 0.0).unwrap();
        assert!(op.needs_writeback);
        assert!(op._buf.is_some());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_for_backend_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[10] = 40.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();
        let mut view = c.view_mut();
        let op =
            prepare_output_view_for_backend::<f64, NaiveBackend>(&mut view, 0, 2, 1, 1.0).unwrap();
        assert!(op.needs_writeback);
        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[0, 1, 0]), 20.0);
        assert_eq!(buf.get(&[1, 0, 0]), 40.0);
        op.finalize_into(&mut view).unwrap();
    }
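
    // Added sketch (not part of the original tests): checks only the dims and
    // strides produced by `alloc_col_major_uninit`; the contents are
    // intentionally left uninitialized and never read here.
    #[test]
    fn test_alloc_col_major_uninit_layout() {
        let buf = alloc_col_major_uninit::<f64>(&[2, 3, 4]);
        assert_eq!(buf.dims().to_vec(), vec![2, 3, 4]);
        assert_eq!(buf.strides().to_vec(), vec![1, 2, 6]);
    }

    // Added sketch: a contiguous column-major output needs no staging buffer
    // even when `beta` is non-zero, and `finalize_into` is then a no-op.
    #[test]
    fn test_output_for_backend_contiguous_beta_nonzero() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();
        let op =
            prepare_output_view_for_backend::<f64, NaiveBackend>(&mut view, 0, 1, 1, 1.0).unwrap();
        assert!(!op.needs_writeback);
        assert!(op._buf.is_none());
        op.finalize_into(&mut view).unwrap();
    }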
}

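// The remaining tests go through `ActiveBackend` and are only compiled when the
// enabled features select exactly one backend.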
#[cfg(test)]
#[cfg(any(
    all(feature = "faer", not(any(feature = "blas", feature = "blas-inject"))),
    all(
        not(feature = "faer"),
        any(
            all(feature = "blas", not(feature = "blas-inject")),
            all(feature = "blas-inject", not(feature = "blas"))
        )
    )
))]
mod tests {
    use super::*;

    #[test]
    fn test_borrowed_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);
        let view = a.view();

        let op = prepare_input_view(&view, 0, 1, 1, false).unwrap();

        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
        assert!(!op.conj());
    }

    #[test]
    fn test_borrowed_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let view = a.view();

        let op = prepare_input_view(&view, 0, 2, 1, false).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_owned_contiguous_no_copy() {
        let a = StridedArray::<f64>::col_major(&[2, 3]);

        let op = prepare_input_owned(a, 0, 1, 1, false).unwrap();

        // The owned array is retained in `_buf` even though no copy was made.
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    #[test]
    fn test_owned_non_contiguous_copies() {
        let data = vec![0.0f64; 100];
        let a = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();

        let op = prepare_input_owned(a, 0, 2, 1, false).unwrap();

        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_view_contiguous() {
        let mut c = StridedArray::<f64>::col_major(&[2, 3]);
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 0, 1, 1, 0.0).unwrap();

        assert!(!op.needs_writeback());
        assert!(!op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 2);
    }

    #[test]
    fn test_output_view_non_contiguous_beta_zero() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();
        let mut view = c.view_mut();

        let op = prepare_output_view(&mut view, 0, 2, 1, 0.0).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }

    #[test]
    fn test_output_view_non_contiguous_beta_nonzero_and_finalize() {
        let mut data = vec![0.0f64; 30];
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;
        data[10] = 40.0;
        data[11] = 50.0;
        data[12] = 60.0;
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 1], &[10, 1, 1], 0).unwrap();

        assert_eq!(c.get(&[0, 0, 0]), 10.0);
        assert_eq!(c.get(&[1, 1, 0]), 50.0);

        let mut view = c.view_mut();

        let mut op = prepare_output_view(&mut view, 0, 2, 1, 1.0).unwrap();

        assert!(op.needs_writeback());
        assert!(op.has_buf());

        let buf = op._buf.as_ref().unwrap();
        assert_eq!(buf.get(&[0, 0, 0]), 10.0);
        assert_eq!(buf.get(&[1, 1, 0]), 50.0);

        {
            let result_data = vec![100.0f64; 6];
            let result =
                StridedArray::<f64>::from_parts(result_data, &[2, 3, 1], &[3, 1, 1], 0).unwrap();
            strided_kernel::copy_into(&mut op._buf.as_mut().unwrap().view_mut(), &result.view())
                .unwrap();
            op.ptr = op._buf.as_mut().unwrap().view_mut().as_mut_ptr();
        }

        op.finalize_into(&mut view).unwrap();

        assert_eq!(c.get(&[0, 0, 0]), 100.0);
        assert_eq!(c.get(&[0, 1, 0]), 100.0);
        assert_eq!(c.get(&[0, 2, 0]), 100.0);
        assert_eq!(c.get(&[1, 0, 0]), 100.0);
        assert_eq!(c.get(&[1, 1, 0]), 100.0);
        assert_eq!(c.get(&[1, 2, 0]), 100.0);
    }

    #[test]
    fn test_output_owned_no_writeback() {
        let data = vec![0.0f64; 100];
        let mut c = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[20, 4, 1], 0).unwrap();

        let op = prepare_output_owned(&mut c, 0, 2, 1, 0.0).unwrap();

        assert!(!op.needs_writeback());
        assert!(op.has_buf());
        assert_eq!(op.row_stride(), 1);
        assert_eq!(op.col_stride(), 6);
    }
}