1use std::marker::PhantomData;
11use std::ops::{Index, IndexMut};
12use std::sync::Arc;
13
14use crate::element_op::{ComposableElementOp, ElementOp, ElementOpApply, Identity};
15use crate::{Result, StridedError};
16
17pub(crate) fn validate_bounds(
23 len: usize,
24 dims: &[usize],
25 strides: &[isize],
26 offset: isize,
27) -> Result<()> {
28 if dims.len() != strides.len() {
29 return Err(StridedError::StrideLengthMismatch);
30 }
31 if dims.iter().any(|&d| d == 0) {
33 return Ok(());
34 }
35 let mut min_offset = offset;
37 let mut max_offset = offset;
38 for (&dim, &stride) in dims.iter().zip(strides.iter()) {
39 if dim > 1 {
40 let end = stride
41 .checked_mul(dim as isize - 1)
42 .ok_or(StridedError::OffsetOverflow)?;
43 if end >= 0 {
44 max_offset = max_offset
45 .checked_add(end)
46 .ok_or(StridedError::OffsetOverflow)?;
47 } else {
48 min_offset = min_offset
49 .checked_add(end)
50 .ok_or(StridedError::OffsetOverflow)?;
51 }
52 }
53 }
54 if min_offset < 0 || max_offset < 0 {
55 return Err(StridedError::OffsetOverflow);
56 }
57 if max_offset as usize >= len {
58 return Err(StridedError::OffsetOverflow);
59 }
60 Ok(())
61}
62
/// Returns the column-major (Fortran-order) strides for `dims`.
///
/// Axis 0 is contiguous (stride 1); each subsequent stride is the product of all
/// preceding extents. The last extent never participates. An empty `dims` yields an
/// empty vector.
pub fn col_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides = Vec::with_capacity(dims.len());
    if let Some((_, leading)) = dims.split_last() {
        let mut step: isize = 1;
        strides.push(step);
        for &extent in leading {
            step *= extent as isize;
            strides.push(step);
        }
    }
    strides
}
75
/// Returns the row-major (C-order) strides for `dims`.
///
/// The last axis is contiguous (stride 1); moving outward, each stride is the product
/// of all following extents. The first extent never participates. An empty `dims`
/// yields an empty vector.
pub fn row_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides = vec![1isize; dims.len()];
    let mut step: isize = 1;
    // Pair every slot except the innermost with the extent of the axis just inside it.
    for (slot, &inner_extent) in strides.iter_mut().rev().skip(1).zip(dims.iter().rev()) {
        step *= inner_extent as isize;
        *slot = step;
    }
    strides
}
88
/// Borrowed, read-only, strided N-dimensional view over a slice of `T`.
///
/// `Op` is a type-level marker (held only as `PhantomData`, defaulting to `Identity`)
/// that selects the element operation `get` applies to each value it reads.
pub struct StridedView<'a, T, Op = Identity> {
    /// Address of the element at multi-index `[0, …, 0]`: `data`'s base shifted by `offset`.
    ptr: *const T,
    /// Backing slice; ties the view to the lifetime `'a`.
    data: &'a [T],
    /// Extent of each axis.
    dims: Arc<[usize]>,
    /// Element stride of each axis (may be 0, e.g. after `broadcast`, or negative).
    strides: Arc<[isize]>,
    /// Element offset of `ptr` from the start of `data`.
    offset: isize,
    _op: PhantomData<Op>,
}
112
113unsafe impl<T: Send, Op: Send> Send for StridedView<'_, T, Op> {}
114unsafe impl<T: Sync, Op: Sync> Sync for StridedView<'_, T, Op> {}
115
116impl<T, Op> Clone for StridedView<'_, T, Op> {
117 fn clone(&self) -> Self {
118 Self {
119 ptr: self.ptr,
120 data: self.data,
121 dims: self.dims.clone(),
122 strides: self.strides.clone(),
123 offset: self.offset,
124 _op: PhantomData,
125 }
126 }
127}
128
129impl<T: std::fmt::Debug, Op> std::fmt::Debug for StridedView<'_, T, Op> {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("StridedView")
132 .field("dims", &self.dims)
133 .field("strides", &self.strides)
134 .field("offset", &self.offset)
135 .finish()
136 }
137}
138
139impl<'a, T, Op> StridedView<'a, T, Op> {
140 pub fn new(data: &'a [T], dims: &[usize], strides: &[isize], offset: isize) -> Result<Self> {
142 validate_bounds(data.len(), dims, strides, offset)?;
143 let ptr = unsafe { data.as_ptr().offset(offset) };
144 Ok(Self {
145 ptr,
146 data,
147 dims: Arc::from(dims),
148 strides: Arc::from(strides),
149 offset,
150 _op: PhantomData,
151 })
152 }
153
154 pub unsafe fn new_unchecked(
159 data: &'a [T],
160 dims: &[usize],
161 strides: &[isize],
162 offset: isize,
163 ) -> Self {
164 let ptr = data.as_ptr().offset(offset);
165 Self {
166 ptr,
167 data,
168 dims: Arc::from(dims),
169 strides: Arc::from(strides),
170 offset,
171 _op: PhantomData,
172 }
173 }
174
175 #[inline]
177 pub fn dims(&self) -> &[usize] {
178 &self.dims
179 }
180
181 #[inline]
183 pub fn strides(&self) -> &[isize] {
184 &self.strides
185 }
186
187 #[inline]
189 pub fn offset(&self) -> isize {
190 self.offset
191 }
192
193 #[inline]
195 pub fn ndim(&self) -> usize {
196 self.dims.len()
197 }
198
199 #[inline]
201 pub fn len(&self) -> usize {
202 self.dims.iter().product()
203 }
204
205 #[inline]
207 pub fn is_empty(&self) -> bool {
208 self.dims.iter().any(|&d| d == 0)
209 }
210
211 #[inline]
213 pub fn data(&self) -> &'a [T] {
214 self.data
215 }
216
217 #[inline]
219 pub fn ptr(&self) -> *const T {
220 self.ptr
221 }
222
223 pub fn permute(&self, perm: &[usize]) -> Result<StridedView<'a, T, Op>> {
225 let rank = self.dims.len();
226 if perm.len() != rank {
227 return Err(StridedError::RankMismatch(perm.len(), rank));
228 }
229 let mut seen = vec![false; rank];
230 for &p in perm {
231 if p >= rank {
232 return Err(StridedError::InvalidAxis { axis: p, rank });
233 }
234 if seen[p] {
235 return Err(StridedError::InvalidAxis { axis: p, rank });
236 }
237 seen[p] = true;
238 }
239 let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
240 let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
241 Ok(StridedView {
242 ptr: self.ptr,
243 data: self.data,
244 dims: Arc::from(new_dims),
245 strides: Arc::from(new_strides),
246 offset: self.offset,
247 _op: PhantomData,
248 })
249 }
250
251 pub fn diagonal_view(&self, axis_pairs: &[(usize, usize)]) -> Result<StridedView<'a, T, Op>> {
262 let ndim = self.ndim();
263 let mut dims: Vec<usize> = self.dims().to_vec();
264 let mut strides: Vec<isize> = self.strides().to_vec();
265
266 let mut axes_to_remove = Vec::new();
267 for &(a, b) in axis_pairs {
268 let (lo, hi) = if a < b { (a, b) } else { (b, a) };
269 if lo >= ndim || hi >= ndim {
270 return Err(StridedError::InvalidAxis {
271 axis: hi,
272 rank: ndim,
273 });
274 }
275 if lo == hi {
276 return Err(StridedError::InvalidAxis {
277 axis: lo,
278 rank: ndim,
279 });
280 }
281 strides[lo] += strides[hi];
282 dims[lo] = dims[lo].min(dims[hi]);
283 axes_to_remove.push(hi);
284 }
285
286 axes_to_remove.sort_unstable();
287 axes_to_remove.dedup();
288 for &ax in axes_to_remove.iter().rev() {
289 dims.remove(ax);
290 strides.remove(ax);
291 }
292
293 unsafe {
294 Ok(StridedView::new_unchecked(
295 self.data(),
296 &dims,
297 &strides,
298 self.offset(),
299 ))
300 }
301 }
302
303 pub fn broadcast(&self, target_dims: &[usize]) -> Result<StridedView<'a, T, Op>> {
307 if self.dims.len() != target_dims.len() {
308 return Err(StridedError::RankMismatch(
309 self.dims.len(),
310 target_dims.len(),
311 ));
312 }
313 let mut new_strides = Vec::with_capacity(self.dims.len());
314 for i in 0..self.dims.len() {
315 if self.dims[i] == target_dims[i] {
316 new_strides.push(self.strides[i]);
317 } else if self.dims[i] == 1 {
318 new_strides.push(0);
319 } else {
320 return Err(StridedError::ShapeMismatch(
321 self.dims.to_vec(),
322 target_dims.to_vec(),
323 ));
324 }
325 }
326 Ok(StridedView {
327 ptr: self.ptr,
328 data: self.data,
329 dims: Arc::from(target_dims),
330 strides: Arc::from(new_strides),
331 offset: self.offset,
332 _op: PhantomData,
333 })
334 }
335}
336
337impl<'a, T: Copy + ElementOpApply, Op: ComposableElementOp<T>> StridedView<'a, T, Op> {
339 pub fn transpose_2d(&self) -> Result<StridedView<'a, T, Op::ComposeTranspose>> {
343 if self.dims.len() != 2 {
344 return Err(StridedError::RankMismatch(self.dims.len(), 2));
345 }
346 Ok(StridedView {
347 ptr: self.ptr,
348 data: self.data,
349 dims: Arc::new([self.dims[1], self.dims[0]]),
350 strides: Arc::new([self.strides[1], self.strides[0]]),
351 offset: self.offset,
352 _op: PhantomData,
353 })
354 }
355
356 pub fn adjoint_2d(&self) -> Result<StridedView<'a, T, Op::ComposeAdjoint>> {
360 if self.dims.len() != 2 {
361 return Err(StridedError::RankMismatch(self.dims.len(), 2));
362 }
363 Ok(StridedView {
364 ptr: self.ptr,
365 data: self.data,
366 dims: Arc::new([self.dims[1], self.dims[0]]),
367 strides: Arc::new([self.strides[1], self.strides[0]]),
368 offset: self.offset,
369 _op: PhantomData,
370 })
371 }
372
373 pub fn conj(&self) -> StridedView<'a, T, Op::ComposeConj> {
377 StridedView {
378 ptr: self.ptr,
379 data: self.data,
380 dims: self.dims.clone(),
381 strides: self.strides.clone(),
382 offset: self.offset,
383 _op: PhantomData,
384 }
385 }
386}
387
388impl<'a, T: Copy, Op: ElementOp<T>> StridedView<'a, T, Op> {
390 pub fn get(&self, indices: &[usize]) -> T {
392 assert_eq!(indices.len(), self.dims.len(), "wrong number of indices");
393 let mut idx = 0isize;
394 for (i, &index) in indices.iter().enumerate() {
395 assert!(
396 index < self.dims[i],
397 "index {} out of bounds for dim {}",
398 index,
399 self.dims[i]
400 );
401 idx += index as isize * self.strides[i];
402 }
403 Op::apply(unsafe { *self.ptr.offset(idx) })
404 }
405
406 #[inline]
411 pub unsafe fn get_unchecked(&self, indices: &[usize]) -> T {
412 let mut idx = 0isize;
413 for (i, &index) in indices.iter().enumerate() {
414 idx += index as isize * self.strides[i];
415 }
416 Op::apply(*self.ptr.offset(idx))
417 }
418}
419
/// Mutable strided N-dimensional view over an exclusively borrowed `&'a mut [T]`.
///
/// NOTE(review): the view stores both `data: &'a mut [T]` and a raw `ptr` derived from
/// it, and writes go through `ptr`. Whether this aliasing pattern is accepted by
/// Stacked/Tree Borrows should be confirmed with Miri.
pub struct StridedViewMut<'a, T> {
    /// Address of the element at multi-index `[0, …, 0]`: `data`'s base shifted by `offset`.
    ptr: *mut T,
    /// Exclusively borrowed backing slice.
    data: &'a mut [T],
    /// Extent of each axis.
    dims: Arc<[usize]>,
    /// Element stride of each axis.
    strides: Arc<[isize]>,
    /// Element offset of `ptr` from the start of `data`.
    offset: isize,
}
435
436unsafe impl<T: Send> Send for StridedViewMut<'_, T> {}
437
438impl<T: std::fmt::Debug> std::fmt::Debug for StridedViewMut<'_, T> {
439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440 f.debug_struct("StridedViewMut")
441 .field("dims", &self.dims)
442 .field("strides", &self.strides)
443 .field("offset", &self.offset)
444 .finish()
445 }
446}
447
448impl<'a, T> StridedViewMut<'a, T> {
449 pub fn new(
451 data: &'a mut [T],
452 dims: &[usize],
453 strides: &[isize],
454 offset: isize,
455 ) -> Result<Self> {
456 validate_bounds(data.len(), dims, strides, offset)?;
457 let ptr = unsafe { data.as_mut_ptr().offset(offset) };
458 Ok(Self {
459 ptr,
460 data,
461 dims: Arc::from(dims),
462 strides: Arc::from(strides),
463 offset,
464 })
465 }
466
467 pub unsafe fn new_unchecked(
472 data: &'a mut [T],
473 dims: &[usize],
474 strides: &[isize],
475 offset: isize,
476 ) -> Self {
477 let ptr = data.as_mut_ptr().offset(offset);
478 Self {
479 ptr,
480 data,
481 dims: Arc::from(dims),
482 strides: Arc::from(strides),
483 offset,
484 }
485 }
486
487 #[inline]
489 pub fn dims(&self) -> &[usize] {
490 &self.dims
491 }
492
493 #[inline]
495 pub fn strides(&self) -> &[isize] {
496 &self.strides
497 }
498
499 #[inline]
501 pub fn offset(&self) -> isize {
502 self.offset
503 }
504
505 #[inline]
507 pub fn ndim(&self) -> usize {
508 self.dims.len()
509 }
510
511 #[inline]
513 pub fn len(&self) -> usize {
514 self.dims.iter().product()
515 }
516
517 #[inline]
519 pub fn is_empty(&self) -> bool {
520 self.dims.iter().any(|&d| d == 0)
521 }
522
523 #[inline]
525 pub fn ptr(&self) -> *const T {
526 self.ptr as *const T
527 }
528
529 #[inline]
531 pub fn as_mut_ptr(&self) -> *mut T {
532 self.ptr
533 }
534
535 #[inline]
537 pub fn data(&self) -> &[T] {
538 &self.data
539 }
540
541 #[inline]
543 pub fn data_mut(&mut self) -> &mut [T] {
544 &mut self.data
545 }
546
547 pub fn permute(self, perm: &[usize]) -> Result<StridedViewMut<'a, T>> {
552 let rank = self.dims.len();
553 if perm.len() != rank {
554 return Err(StridedError::RankMismatch(perm.len(), rank));
555 }
556 let mut seen = vec![false; rank];
557 for &p in perm {
558 if p >= rank {
559 return Err(StridedError::InvalidAxis { axis: p, rank });
560 }
561 if seen[p] {
562 return Err(StridedError::InvalidAxis { axis: p, rank });
563 }
564 seen[p] = true;
565 }
566 let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
567 let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
568 Ok(StridedViewMut {
569 ptr: self.ptr,
570 data: self.data,
571 dims: Arc::from(new_dims),
572 strides: Arc::from(new_strides),
573 offset: self.offset,
574 })
575 }
576
577 pub fn as_view(&self) -> StridedView<'_, T, Identity> {
579 StridedView {
580 ptr: self.ptr as *const T,
581 data: unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.data.len()) },
582 dims: self.dims.clone(),
583 strides: self.strides.clone(),
584 offset: self.offset,
585 _op: PhantomData,
586 }
587 }
588}
589
590impl<'a, T: Copy> StridedViewMut<'a, T> {
591 pub fn get(&self, indices: &[usize]) -> T {
593 assert_eq!(indices.len(), self.dims.len());
594 let mut idx = 0isize;
595 for (i, &index) in indices.iter().enumerate() {
596 assert!(index < self.dims[i]);
597 idx += index as isize * self.strides[i];
598 }
599 unsafe { *self.ptr.offset(idx) }
600 }
601
602 pub fn set(&mut self, indices: &[usize], value: T) {
604 assert_eq!(indices.len(), self.dims.len());
605 let mut idx = 0isize;
606 for (i, &index) in indices.iter().enumerate() {
607 assert!(index < self.dims[i]);
608 idx += index as isize * self.strides[i];
609 }
610 unsafe {
611 *self.ptr.offset(idx) = value;
612 }
613 }
614}
615
/// Owned strided N-dimensional array: a `Vec<T>` buffer plus shape metadata.
pub struct StridedArray<T> {
    /// Backing storage; `data`/`iter`/`iter_mut` expose it in storage order.
    data: Vec<T>,
    /// Extent of each axis (shared with views through `Arc`).
    dims: Arc<[usize]>,
    /// Element stride of each axis.
    strides: Arc<[isize]>,
    /// Element offset of multi-index `[0, …, 0]` within `data`.
    offset: isize,
}
629
630impl<T: std::fmt::Debug> std::fmt::Debug for StridedArray<T> {
631 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632 f.debug_struct("StridedArray")
633 .field("dims", &self.dims)
634 .field("strides", &self.strides)
635 .field("offset", &self.offset)
636 .finish()
637 }
638}
639
640impl<T: Clone> Clone for StridedArray<T> {
641 fn clone(&self) -> Self {
642 Self {
643 data: self.data.clone(),
644 dims: self.dims.clone(),
645 strides: self.strides.clone(),
646 offset: self.offset,
647 }
648 }
649}
650
651impl<T: Clone + Default> StridedArray<T> {
652 pub fn col_major(dims: &[usize]) -> Self {
654 let total: usize = dims.iter().product();
655 let data = vec![T::default(); total];
656 let strides = col_major_strides(dims);
657 Self {
658 data,
659 dims: Arc::from(dims),
660 strides: Arc::from(strides),
661 offset: 0,
662 }
663 }
664
665 pub fn row_major(dims: &[usize]) -> Self {
667 let total: usize = dims.iter().product();
668 let data = vec![T::default(); total];
669 let strides = row_major_strides(dims);
670 Self {
671 data,
672 dims: Arc::from(dims),
673 strides: Arc::from(strides),
674 offset: 0,
675 }
676 }
677
678 pub fn from_fn_col_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
682 let total: usize = dims.iter().product();
683 let strides = col_major_strides(dims);
684 let rank = dims.len();
685 let mut data = Vec::with_capacity(total);
686 let mut idx = vec![0usize; rank];
687 for _ in 0..total {
688 data.push(f(&idx));
689 for d in 0..rank {
690 idx[d] += 1;
691 if idx[d] < dims[d] {
692 break;
693 }
694 idx[d] = 0;
695 }
696 }
697 Self {
698 data,
699 dims: Arc::from(dims),
700 strides: Arc::from(strides),
701 offset: 0,
702 }
703 }
704
705 pub fn from_fn_row_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
709 let total: usize = dims.iter().product();
710 let strides = row_major_strides(dims);
711 let rank = dims.len();
712 let mut data = Vec::with_capacity(total);
713 let mut idx = vec![0usize; rank];
714 for _ in 0..total {
715 data.push(f(&idx));
716 for d in (0..rank).rev() {
717 idx[d] += 1;
718 if idx[d] < dims[d] {
719 break;
720 }
721 idx[d] = 0;
722 }
723 }
724 Self {
725 data,
726 dims: Arc::from(dims),
727 strides: Arc::from(strides),
728 offset: 0,
729 }
730 }
731}
732
733impl<T> StridedArray<T> {
734 pub fn from_parts(
736 data: Vec<T>,
737 dims: &[usize],
738 strides: &[isize],
739 offset: isize,
740 ) -> Result<Self> {
741 validate_bounds(data.len(), dims, strides, offset)?;
742 Ok(Self {
743 data,
744 dims: Arc::from(dims),
745 strides: Arc::from(strides),
746 offset,
747 })
748 }
749
750 #[inline]
752 pub fn dims(&self) -> &[usize] {
753 &self.dims
754 }
755
756 #[inline]
758 pub fn strides(&self) -> &[isize] {
759 &self.strides
760 }
761
762 #[inline]
764 pub fn ndim(&self) -> usize {
765 self.dims.len()
766 }
767
768 #[inline]
770 pub fn len(&self) -> usize {
771 self.dims.iter().product()
772 }
773
774 #[inline]
776 pub fn is_empty(&self) -> bool {
777 self.dims.iter().any(|&d| d == 0)
778 }
779
780 #[inline]
782 pub fn data(&self) -> &[T] {
783 &self.data
784 }
785
786 #[inline]
788 pub fn data_mut(&mut self) -> &mut [T] {
789 &mut self.data
790 }
791
792 pub fn view(&self) -> StridedView<'_, T> {
794 let ptr = unsafe { self.data.as_ptr().offset(self.offset) };
795 StridedView {
796 ptr,
797 data: &self.data,
798 dims: self.dims.clone(),
799 strides: self.strides.clone(),
800 offset: self.offset,
801 _op: PhantomData,
802 }
803 }
804
805 pub fn view_mut(&mut self) -> StridedViewMut<'_, T> {
807 let ptr = unsafe { self.data.as_mut_ptr().offset(self.offset) };
808 StridedViewMut {
809 ptr,
810 data: &mut self.data,
811 dims: self.dims.clone(),
812 strides: self.strides.clone(),
813 offset: self.offset,
814 }
815 }
816
817 pub fn permuted(self, perm: &[usize]) -> Result<Self> {
822 let rank = self.dims.len();
823 if perm.len() != rank {
824 return Err(StridedError::RankMismatch(perm.len(), rank));
825 }
826 let mut seen = vec![false; rank];
827 for &p in perm {
828 if p >= rank {
829 return Err(StridedError::InvalidAxis { axis: p, rank });
830 }
831 if seen[p] {
832 return Err(StridedError::InvalidAxis { axis: p, rank });
833 }
834 seen[p] = true;
835 }
836 let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
837 let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
838 Ok(Self {
839 data: self.data,
840 dims: Arc::from(new_dims),
841 strides: Arc::from(new_strides),
842 offset: self.offset,
843 })
844 }
845
846 pub fn into_data(self) -> Vec<T> {
848 self.data
849 }
850
851 pub fn iter(&self) -> std::slice::Iter<'_, T> {
853 self.data.iter()
854 }
855
856 pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
858 self.data.iter_mut()
859 }
860}
861
862impl<T: Default> StridedArray<T> {
863 pub fn col_major_from_buffer(mut buf: Vec<T>, dims: &[usize]) -> Self {
868 let total: usize = dims.iter().product();
869 if buf.len() >= total {
870 buf.truncate(total);
871 } else {
872 buf.resize_with(total, T::default);
873 }
874 for v in buf.iter_mut() {
876 *v = T::default();
877 }
878 let strides = col_major_strides(dims);
879 Self {
880 data: buf,
881 dims: Arc::from(dims),
882 strides: Arc::from(strides),
883 offset: 0,
884 }
885 }
886}
887
888impl<T: Copy> StridedArray<T> {
889 pub unsafe fn col_major_uninit(dims: &[usize]) -> Self {
894 let total: usize = dims.iter().product();
895 let mut data = Vec::with_capacity(total);
896 data.set_len(total);
897 let strides = col_major_strides(dims);
898 Self {
899 data,
900 dims: Arc::from(dims),
901 strides: Arc::from(strides),
902 offset: 0,
903 }
904 }
905
906 pub unsafe fn col_major_from_buffer_uninit(mut buf: Vec<T>, dims: &[usize]) -> Self {
911 let total: usize = dims.iter().product();
912 if buf.capacity() < total {
913 buf.reserve(total - buf.len());
914 }
915 buf.set_len(total);
916 let strides = col_major_strides(dims);
917 Self {
918 data: buf,
919 dims: Arc::from(dims),
920 strides: Arc::from(strides),
921 offset: 0,
922 }
923 }
924}
925
926impl<T: Copy> StridedArray<T> {
927 pub fn get(&self, indices: &[usize]) -> T {
929 self.view().get(indices)
930 }
931
932 pub fn set(&mut self, indices: &[usize], value: T) {
934 assert_eq!(indices.len(), self.dims.len());
935 let mut idx = self.offset;
936 for (i, &index) in indices.iter().enumerate() {
937 assert!(index < self.dims[i]);
938 idx += index as isize * self.strides[i];
939 }
940 self.data[idx as usize] = value;
941 }
942}
943
944impl<T: Copy> Index<&[usize]> for StridedArray<T> {
945 type Output = T;
946
947 fn index(&self, indices: &[usize]) -> &T {
948 let mut idx = self.offset;
949 for (i, &index) in indices.iter().enumerate() {
950 assert!(index < self.dims[i]);
951 idx += index as isize * self.strides[i];
952 }
953 &self.data[idx as usize]
954 }
955}
956
957impl<T: Copy> IndexMut<&[usize]> for StridedArray<T> {
958 fn index_mut(&mut self, indices: &[usize]) -> &mut T {
959 let mut idx = self.offset;
960 for (i, &index) in indices.iter().enumerate() {
961 assert!(index < self.dims[i]);
962 idx += index as isize * self.strides[i];
963 }
964 &mut self.data[idx as usize]
965 }
966}
967
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    // Column-major: axis 0 is contiguous; each stride is the product of preceding extents.
    #[test]
    fn test_col_major_strides() {
        assert_eq!(col_major_strides(&[3, 4]), vec![1, 3]);
        assert_eq!(col_major_strides(&[2, 3, 4]), vec![1, 2, 6]);
    }

    // Row-major: last axis is contiguous; each stride is the product of following extents.
    #[test]
    fn test_row_major_strides() {
        assert_eq!(row_major_strides(&[3, 4]), vec![4, 1]);
        assert_eq!(row_major_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    #[test]
    fn test_strided_view_new() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.dims(), &[2, 3]);
        assert_eq!(view.strides(), &[3, 1]);
        assert_eq!(view.len(), 6);
    }

    // Row-major 2x3 layout: element [i, j] is data[3 * i + j].
    #[test]
    fn test_strided_view_get() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0);
        assert_eq!(view.get(&[0, 1]), 2.0);
        assert_eq!(view.get(&[0, 2]), 3.0);
        assert_eq!(view.get(&[1, 0]), 4.0);
        assert_eq!(view.get(&[1, 2]), 6.0);
    }

    // Column-major 2x3 layout over the same buffer: element [i, j] is data[i + 2 * j].
    #[test]
    fn test_strided_view_col_major() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[1, 2], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0);
        assert_eq!(view.get(&[1, 0]), 2.0);
        assert_eq!(view.get(&[0, 1]), 3.0);
        assert_eq!(view.get(&[1, 1]), 4.0);
        assert_eq!(view.get(&[0, 2]), 5.0);
        assert_eq!(view.get(&[1, 2]), 6.0);
    }

    // Permuting axes swaps dims and strides; perm[i, j] reads original [j, i].
    #[test]
    fn test_strided_view_permute() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.dims(), &[3, 2]);
        assert_eq!(perm.strides(), &[1, 3]);
        assert_eq!(perm.get(&[0, 0]), 1.0);
        assert_eq!(perm.get(&[1, 0]), 2.0);
        assert_eq!(perm.get(&[0, 1]), 4.0);
    }

    // For real f64 data, transpose_2d behaves like a plain axis swap.
    #[test]
    fn test_strided_view_transpose_2d() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let t = view.transpose_2d().unwrap();
        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.get(&[0, 0]), 1.0);
        assert_eq!(t.get(&[1, 0]), 2.0);
        assert_eq!(t.get(&[0, 1]), 4.0);
    }

    // conj() composes a conjugation that `get` applies on read.
    #[test]
    fn test_strided_view_conj() {
        let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let view = StridedView::<Complex64>::new(&data, &[2], &[1], 0).unwrap();
        let c = view.conj();
        assert_eq!(c.get(&[0]), Complex64::new(1.0, -2.0));
        assert_eq!(c.get(&[1]), Complex64::new(3.0, -4.0));
    }

    // Row-major 2x2 complex matrix; adjoint [i, j] == conj(original [j, i]).
    #[test]
    fn test_strided_view_adjoint_2d() {
        let data = vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];
        let view = StridedView::<Complex64>::new(&data, &[2, 2], &[2, 1], 0).unwrap();
        let adj = view.adjoint_2d().unwrap();
        assert_eq!(adj.dims(), &[2, 2]);
        assert_eq!(adj.get(&[0, 0]), Complex64::new(1.0, -2.0));
        assert_eq!(adj.get(&[1, 0]), Complex64::new(3.0, -4.0));
        assert_eq!(adj.get(&[0, 1]), Complex64::new(5.0, -6.0));
    }

    // A 1x3 row broadcast to 4x3: the extent-1 axis gets stride 0, so every row repeats.
    #[test]
    fn test_strided_view_broadcast() {
        let data = vec![1.0, 2.0, 3.0];
        let view = StridedView::<f64>::new(&data, &[1, 3], &[3, 1], 0).unwrap();
        let broad = view.broadcast(&[4, 3]).unwrap();
        assert_eq!(broad.dims(), &[4, 3]);
        for i in 0..4 {
            assert_eq!(broad.get(&[i, 0]), 1.0);
            assert_eq!(broad.get(&[i, 1]), 2.0);
            assert_eq!(broad.get(&[i, 2]), 3.0);
        }
    }

    // Writes through a StridedViewMut land in the borrowed buffer.
    #[test]
    fn test_strided_view_mut() {
        let mut data = vec![0.0; 6];
        {
            let mut view = StridedViewMut::<f64>::new(&mut data, &[2, 3], &[3, 1], 0).unwrap();
            view.set(&[0, 0], 1.0);
            view.set(&[1, 2], 6.0);
        }
        assert_eq!(data[0], 1.0);
        assert_eq!(data[5], 6.0);
    }

    #[test]
    fn test_strided_view_mut_as_view() {
        let mut data = vec![1.0, 2.0, 3.0];
        let vm = StridedViewMut::<f64>::new(&mut data, &[3], &[1], 0).unwrap();
        let v = vm.as_view();
        assert_eq!(v.get(&[0]), 1.0);
        assert_eq!(v.get(&[2]), 3.0);
    }

    // from_fn_col_major: value depends only on the logical index, not on storage order.
    #[test]
    fn test_strided_tensor_col_major() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[1, 2]);
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    #[test]
    fn test_strided_tensor_row_major() {
        let t = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[3, 1]);
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    #[test]
    fn test_strided_tensor_view() {
        let t =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let v = t.view();
        assert_eq!(v.get(&[0, 0]), 0.0);
        assert_eq!(v.get(&[1, 0]), 10.0);
        assert_eq!(v.get(&[0, 2]), 2.0);
    }

    // A write through view_mut is visible through the owning array afterwards.
    #[test]
    fn test_strided_tensor_view_mut() {
        let mut t = StridedArray::<f64>::col_major(&[2, 3]);
        {
            let mut vm = t.view_mut();
            vm.set(&[1, 2], 42.0);
        }
        assert_eq!(t.get(&[1, 2]), 42.0);
    }

    // Index/IndexMut take a `&[usize]`; the casts pin the index type.
    #[test]
    fn test_strided_tensor_index() {
        let t =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 10 + idx[1]) as f64);
        assert_eq!(t[&[0usize, 0] as &[usize]], 0.0);
        assert_eq!(t[&[2usize, 3] as &[usize]], 23.0);
    }

    #[test]
    fn test_strided_tensor_index_mut() {
        let mut t = StridedArray::<f64>::row_major(&[2, 3]);
        t[&[1usize, 2] as &[usize]] = 99.0;
        assert_eq!(t.get(&[1, 2]), 99.0);
    }

    #[test]
    fn test_validate_bounds_ok() {
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 0).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[1, 2], 0).is_ok());
    }

    // The last reachable offset (5) must be < len; len 5 is one too short.
    #[test]
    fn test_validate_bounds_out_of_range() {
        assert!(validate_bounds(5, &[2, 3], &[3, 1], 0).is_err());
    }

    // Empty layouts are accepted even with an empty buffer.
    #[test]
    fn test_validate_bounds_empty() {
        assert!(validate_bounds(0, &[0, 3], &[3, 1], 0).is_ok());
    }

    // With offset 1 the last reachable element is 6, so len must be at least 7.
    #[test]
    fn test_validate_bounds_with_offset() {
        assert!(validate_bounds(7, &[2, 3], &[3, 1], 1).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 1).is_err());
    }

    #[test]
    fn test_strided_tensor_3d() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3, 4], |idx| {
            (idx[0] * 100 + idx[1] * 10 + idx[2]) as f64
        });
        assert_eq!(t.ndim(), 3);
        assert_eq!(t.strides(), &[1, 2, 6]);
        assert_eq!(t.get(&[0, 0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0, 0]), 100.0);
        assert_eq!(t.get(&[0, 1, 0]), 10.0);
        assert_eq!(t.get(&[0, 0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2, 3]), 123.0);
    }

    // 3x3 row-major matrix 0..9; the diagonal stride is 3 + 1 = 4, giving 0, 4, 8.
    #[test]
    fn test_diagonal_view_2d() {
        let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[3, 3], &[3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[3]);
        assert_eq!(diag.strides(), &[4]);
        assert_eq!(diag.get(&[0]), 0.0);
        assert_eq!(diag.get(&[1]), 4.0);
        assert_eq!(diag.get(&[2]), 8.0);
    }

    // Tying axes 0 and 1 (strides 6 and 3) yields stride 9; axis 2 is unchanged.
    #[test]
    fn test_diagonal_view_3d_adjacent() {
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 2, 3], &[6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[9, 1]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 2]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 9.0);
    }

    // Tying non-adjacent axes 0 and 2 (strides 6 and 1) yields stride 7; axis 1 keeps stride 2.
    #[test]
    fn test_diagonal_view_3d_non_adjacent() {
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2], &[6, 2, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[7, 2]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 1]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 7.0);
        assert_eq!(diag.get(&[1, 2]), 11.0);
    }

    // Two disjoint pairs: strides 18 + 3 = 21 and 6 + 1 = 7.
    #[test]
    fn test_diagonal_view_two_pairs() {
        let data: Vec<f64> = (0..36).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2, 3], &[18, 6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2), (1, 3)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[21, 7]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[1, 1]), 28.0);
    }

    // A user-defined Copy type that does not implement ElementOpApply still works with
    // the Identity-op APIs (from_parts, get, view, permute).
    #[test]
    fn test_custom_type_without_element_op_apply() {
        #[derive(Debug, Clone, Copy, PartialEq)]
        struct MyScalar(f64);

        impl Default for MyScalar {
            fn default() -> Self {
                MyScalar(0.0)
            }
        }

        // Row-major 2x2: [[1, 2], [3, 4]].
        let data = vec![MyScalar(1.0), MyScalar(2.0), MyScalar(3.0), MyScalar(4.0)];
        let arr = StridedArray::from_parts(data, &[2, 2], &[2, 1], 0).unwrap();

        assert_eq!(arr.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(arr.get(&[1, 0]), MyScalar(3.0));

        let view: StridedView<MyScalar> = arr.view();
        assert_eq!(view.get(&[0, 1]), MyScalar(2.0));

        // After swapping axes, [0, 1] reads original [1, 0].
        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(perm.get(&[0, 1]), MyScalar(3.0));
    }
}