1use std::marker::PhantomData;
11use std::ops::{Index, IndexMut};
12use std::sync::Arc;
13
14use crate::element_op::{ComposableElementOp, ElementOp, ElementOpApply, Identity};
15use crate::{Result, StridedError};
16
17fn validate_bounds(len: usize, dims: &[usize], strides: &[isize], offset: isize) -> Result<()> {
23 if dims.len() != strides.len() {
24 return Err(StridedError::StrideLengthMismatch);
25 }
26 if dims.iter().any(|&d| d == 0) {
28 return Ok(());
29 }
30 let mut min_offset = offset;
32 let mut max_offset = offset;
33 for (&dim, &stride) in dims.iter().zip(strides.iter()) {
34 if dim > 1 {
35 let end = stride
36 .checked_mul(dim as isize - 1)
37 .ok_or(StridedError::OffsetOverflow)?;
38 if end >= 0 {
39 max_offset = max_offset
40 .checked_add(end)
41 .ok_or(StridedError::OffsetOverflow)?;
42 } else {
43 min_offset = min_offset
44 .checked_add(end)
45 .ok_or(StridedError::OffsetOverflow)?;
46 }
47 }
48 }
49 if min_offset < 0 || max_offset < 0 {
50 return Err(StridedError::OffsetOverflow);
51 }
52 if max_offset as usize >= len {
53 return Err(StridedError::OffsetOverflow);
54 }
55 Ok(())
56}
57
/// Returns the strides of a densely packed column-major layout for `dims`.
///
/// The first axis has stride 1 and every later axis has the stride of the
/// previous axis multiplied by the previous extent. An empty `dims` yields an
/// empty vector.
pub fn col_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides: Vec<isize> = Vec::with_capacity(dims.len());
    // The last extent never contributes to a stride, so only the leading
    // extents are folded into the running product.
    if let Some((_, leading)) = dims.split_last() {
        let mut step = 1isize;
        strides.push(step);
        for &extent in leading {
            step = step * extent as isize;
            strides.push(step);
        }
    }
    strides
}
70
/// Returns the strides of a densely packed row-major layout for `dims`.
///
/// The last axis has stride 1 and every earlier axis has the stride of the
/// following axis multiplied by the following extent. An empty `dims` yields
/// an empty vector.
pub fn row_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides = vec![1isize; dims.len()];
    // `trailing[j]` is `dims[j + 1]`, i.e. the extent that scales `strides[j]`.
    if let Some((_, trailing)) = dims.split_first() {
        let mut step = 1isize;
        for (axis, &extent) in trailing.iter().enumerate().rev() {
            step = step * extent as isize;
            strides[axis] = step;
        }
    }
    strides
}
83
/// Read-only strided view over a borrowed slice.
///
/// `Op` is carried only as a type-level marker (`PhantomData`); the element
/// accessors apply `Op::apply` to every value they read.
pub struct StridedView<'a, T, Op = Identity> {
    /// Base pointer used for element access: `data.as_ptr()` advanced by `offset`.
    ptr: *const T,
    /// The whole borrowed buffer the view was created from.
    data: &'a [T],
    /// Extent of each axis.
    dims: Arc<[usize]>,
    /// Step, in elements, between consecutive indices along each axis.
    strides: Arc<[isize]>,
    /// Position of the view's origin inside `data`, in elements.
    offset: isize,
    /// Zero-sized marker tying the view to its element operation.
    _op: PhantomData<Op>,
}
107
108unsafe impl<T: Send, Op: Send> Send for StridedView<'_, T, Op> {}
109unsafe impl<T: Sync, Op: Sync> Sync for StridedView<'_, T, Op> {}
110
111impl<T, Op> Clone for StridedView<'_, T, Op> {
112 fn clone(&self) -> Self {
113 Self {
114 ptr: self.ptr,
115 data: self.data,
116 dims: self.dims.clone(),
117 strides: self.strides.clone(),
118 offset: self.offset,
119 _op: PhantomData,
120 }
121 }
122}
123
124impl<T: std::fmt::Debug, Op> std::fmt::Debug for StridedView<'_, T, Op> {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("StridedView")
127 .field("dims", &self.dims)
128 .field("strides", &self.strides)
129 .field("offset", &self.offset)
130 .finish()
131 }
132}
133
134impl<'a, T, Op> StridedView<'a, T, Op> {
135 pub fn new(data: &'a [T], dims: &[usize], strides: &[isize], offset: isize) -> Result<Self> {
137 validate_bounds(data.len(), dims, strides, offset)?;
138 let ptr = unsafe { data.as_ptr().offset(offset) };
139 Ok(Self {
140 ptr,
141 data,
142 dims: Arc::from(dims),
143 strides: Arc::from(strides),
144 offset,
145 _op: PhantomData,
146 })
147 }
148
149 pub unsafe fn new_unchecked(
154 data: &'a [T],
155 dims: &[usize],
156 strides: &[isize],
157 offset: isize,
158 ) -> Self {
159 let ptr = data.as_ptr().offset(offset);
160 Self {
161 ptr,
162 data,
163 dims: Arc::from(dims),
164 strides: Arc::from(strides),
165 offset,
166 _op: PhantomData,
167 }
168 }
169
170 #[inline]
172 pub fn dims(&self) -> &[usize] {
173 &self.dims
174 }
175
176 #[inline]
178 pub fn strides(&self) -> &[isize] {
179 &self.strides
180 }
181
182 #[inline]
184 pub fn offset(&self) -> isize {
185 self.offset
186 }
187
188 #[inline]
190 pub fn ndim(&self) -> usize {
191 self.dims.len()
192 }
193
194 #[inline]
196 pub fn len(&self) -> usize {
197 self.dims.iter().product()
198 }
199
200 #[inline]
202 pub fn is_empty(&self) -> bool {
203 self.dims.iter().any(|&d| d == 0)
204 }
205
206 #[inline]
208 pub fn data(&self) -> &'a [T] {
209 self.data
210 }
211
212 #[inline]
214 pub fn ptr(&self) -> *const T {
215 self.ptr
216 }
217
218 pub fn permute(&self, perm: &[usize]) -> Result<StridedView<'a, T, Op>> {
220 let rank = self.dims.len();
221 if perm.len() != rank {
222 return Err(StridedError::RankMismatch(perm.len(), rank));
223 }
224 let mut seen = vec![false; rank];
225 for &p in perm {
226 if p >= rank {
227 return Err(StridedError::InvalidAxis { axis: p, rank });
228 }
229 if seen[p] {
230 return Err(StridedError::InvalidAxis { axis: p, rank });
231 }
232 seen[p] = true;
233 }
234 let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
235 let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
236 Ok(StridedView {
237 ptr: self.ptr,
238 data: self.data,
239 dims: Arc::from(new_dims),
240 strides: Arc::from(new_strides),
241 offset: self.offset,
242 _op: PhantomData,
243 })
244 }
245
246 pub fn diagonal_view(&self, axis_pairs: &[(usize, usize)]) -> Result<StridedView<'a, T, Op>> {
257 let ndim = self.ndim();
258 let mut dims: Vec<usize> = self.dims().to_vec();
259 let mut strides: Vec<isize> = self.strides().to_vec();
260
261 let mut axes_to_remove = Vec::new();
262 for &(a, b) in axis_pairs {
263 let (lo, hi) = if a < b { (a, b) } else { (b, a) };
264 if lo >= ndim || hi >= ndim {
265 return Err(StridedError::InvalidAxis {
266 axis: hi,
267 rank: ndim,
268 });
269 }
270 if lo == hi {
271 return Err(StridedError::InvalidAxis {
272 axis: lo,
273 rank: ndim,
274 });
275 }
276 strides[lo] += strides[hi];
277 dims[lo] = dims[lo].min(dims[hi]);
278 axes_to_remove.push(hi);
279 }
280
281 axes_to_remove.sort_unstable();
282 axes_to_remove.dedup();
283 for &ax in axes_to_remove.iter().rev() {
284 dims.remove(ax);
285 strides.remove(ax);
286 }
287
288 unsafe {
289 Ok(StridedView::new_unchecked(
290 self.data(),
291 &dims,
292 &strides,
293 self.offset(),
294 ))
295 }
296 }
297
298 pub fn broadcast(&self, target_dims: &[usize]) -> Result<StridedView<'a, T, Op>> {
302 if self.dims.len() != target_dims.len() {
303 return Err(StridedError::RankMismatch(
304 self.dims.len(),
305 target_dims.len(),
306 ));
307 }
308 let mut new_strides = Vec::with_capacity(self.dims.len());
309 for i in 0..self.dims.len() {
310 if self.dims[i] == target_dims[i] {
311 new_strides.push(self.strides[i]);
312 } else if self.dims[i] == 1 {
313 new_strides.push(0);
314 } else {
315 return Err(StridedError::ShapeMismatch(
316 self.dims.to_vec(),
317 target_dims.to_vec(),
318 ));
319 }
320 }
321 Ok(StridedView {
322 ptr: self.ptr,
323 data: self.data,
324 dims: Arc::from(target_dims),
325 strides: Arc::from(new_strides),
326 offset: self.offset,
327 _op: PhantomData,
328 })
329 }
330}
331
332impl<'a, T: Copy + ElementOpApply, Op: ComposableElementOp<T>> StridedView<'a, T, Op> {
334 pub fn transpose_2d(&self) -> Result<StridedView<'a, T, Op::ComposeTranspose>> {
338 if self.dims.len() != 2 {
339 return Err(StridedError::RankMismatch(self.dims.len(), 2));
340 }
341 Ok(StridedView {
342 ptr: self.ptr,
343 data: self.data,
344 dims: Arc::new([self.dims[1], self.dims[0]]),
345 strides: Arc::new([self.strides[1], self.strides[0]]),
346 offset: self.offset,
347 _op: PhantomData,
348 })
349 }
350
351 pub fn adjoint_2d(&self) -> Result<StridedView<'a, T, Op::ComposeAdjoint>> {
355 if self.dims.len() != 2 {
356 return Err(StridedError::RankMismatch(self.dims.len(), 2));
357 }
358 Ok(StridedView {
359 ptr: self.ptr,
360 data: self.data,
361 dims: Arc::new([self.dims[1], self.dims[0]]),
362 strides: Arc::new([self.strides[1], self.strides[0]]),
363 offset: self.offset,
364 _op: PhantomData,
365 })
366 }
367
368 pub fn conj(&self) -> StridedView<'a, T, Op::ComposeConj> {
372 StridedView {
373 ptr: self.ptr,
374 data: self.data,
375 dims: self.dims.clone(),
376 strides: self.strides.clone(),
377 offset: self.offset,
378 _op: PhantomData,
379 }
380 }
381}
382
383impl<'a, T: Copy, Op: ElementOp<T>> StridedView<'a, T, Op> {
385 pub fn get(&self, indices: &[usize]) -> T {
387 assert_eq!(indices.len(), self.dims.len(), "wrong number of indices");
388 let mut idx = 0isize;
389 for (i, &index) in indices.iter().enumerate() {
390 assert!(
391 index < self.dims[i],
392 "index {} out of bounds for dim {}",
393 index,
394 self.dims[i]
395 );
396 idx += index as isize * self.strides[i];
397 }
398 Op::apply(unsafe { *self.ptr.offset(idx) })
399 }
400
401 #[inline]
406 pub unsafe fn get_unchecked(&self, indices: &[usize]) -> T {
407 let mut idx = 0isize;
408 for (i, &index) in indices.iter().enumerate() {
409 idx += index as isize * self.strides[i];
410 }
411 Op::apply(*self.ptr.offset(idx))
412 }
413}
414
/// Mutable strided view over an exclusively borrowed slice.
pub struct StridedViewMut<'a, T> {
    /// Base pointer used for element access: `data.as_mut_ptr()` advanced by `offset`.
    ptr: *mut T,
    /// The whole exclusively borrowed buffer the view was created from.
    data: &'a mut [T],
    /// Extent of each axis.
    dims: Arc<[usize]>,
    /// Step, in elements, between consecutive indices along each axis.
    strides: Arc<[isize]>,
    /// Position of the view's origin inside `data`, in elements.
    offset: isize,
}
430
// SAFETY: the view holds an exclusive borrow of its buffer (`&'a mut [T]`) and
// `ptr` points into that same buffer, so moving the view to another thread
// moves exclusive access to the `T`s with it — the same condition under which
// `&mut [T]` is `Send`.
unsafe impl<T: Send> Send for StridedViewMut<'_, T> {}
432
433impl<T: std::fmt::Debug> std::fmt::Debug for StridedViewMut<'_, T> {
434 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435 f.debug_struct("StridedViewMut")
436 .field("dims", &self.dims)
437 .field("strides", &self.strides)
438 .field("offset", &self.offset)
439 .finish()
440 }
441}
442
443impl<'a, T> StridedViewMut<'a, T> {
444 pub fn new(
446 data: &'a mut [T],
447 dims: &[usize],
448 strides: &[isize],
449 offset: isize,
450 ) -> Result<Self> {
451 validate_bounds(data.len(), dims, strides, offset)?;
452 let ptr = unsafe { data.as_mut_ptr().offset(offset) };
453 Ok(Self {
454 ptr,
455 data,
456 dims: Arc::from(dims),
457 strides: Arc::from(strides),
458 offset,
459 })
460 }
461
462 pub unsafe fn new_unchecked(
467 data: &'a mut [T],
468 dims: &[usize],
469 strides: &[isize],
470 offset: isize,
471 ) -> Self {
472 let ptr = data.as_mut_ptr().offset(offset);
473 Self {
474 ptr,
475 data,
476 dims: Arc::from(dims),
477 strides: Arc::from(strides),
478 offset,
479 }
480 }
481
482 #[inline]
484 pub fn dims(&self) -> &[usize] {
485 &self.dims
486 }
487
488 #[inline]
490 pub fn strides(&self) -> &[isize] {
491 &self.strides
492 }
493
494 #[inline]
496 pub fn offset(&self) -> isize {
497 self.offset
498 }
499
500 #[inline]
502 pub fn ndim(&self) -> usize {
503 self.dims.len()
504 }
505
506 #[inline]
508 pub fn len(&self) -> usize {
509 self.dims.iter().product()
510 }
511
512 #[inline]
514 pub fn is_empty(&self) -> bool {
515 self.dims.iter().any(|&d| d == 0)
516 }
517
518 #[inline]
520 pub fn ptr(&self) -> *const T {
521 self.ptr as *const T
522 }
523
524 #[inline]
526 pub fn as_mut_ptr(&self) -> *mut T {
527 self.ptr
528 }
529
530 pub fn permute(self, perm: &[usize]) -> Result<StridedViewMut<'a, T>> {
535 let rank = self.dims.len();
536 if perm.len() != rank {
537 return Err(StridedError::RankMismatch(perm.len(), rank));
538 }
539 let mut seen = vec![false; rank];
540 for &p in perm {
541 if p >= rank {
542 return Err(StridedError::InvalidAxis { axis: p, rank });
543 }
544 if seen[p] {
545 return Err(StridedError::InvalidAxis { axis: p, rank });
546 }
547 seen[p] = true;
548 }
549 let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
550 let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
551 Ok(StridedViewMut {
552 ptr: self.ptr,
553 data: self.data,
554 dims: Arc::from(new_dims),
555 strides: Arc::from(new_strides),
556 offset: self.offset,
557 })
558 }
559
560 pub fn as_view(&self) -> StridedView<'_, T, Identity> {
562 StridedView {
563 ptr: self.ptr as *const T,
564 data: unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.data.len()) },
565 dims: self.dims.clone(),
566 strides: self.strides.clone(),
567 offset: self.offset,
568 _op: PhantomData,
569 }
570 }
571}
572
573impl<'a, T: Copy> StridedViewMut<'a, T> {
574 pub fn get(&self, indices: &[usize]) -> T {
576 assert_eq!(indices.len(), self.dims.len());
577 let mut idx = 0isize;
578 for (i, &index) in indices.iter().enumerate() {
579 assert!(index < self.dims[i]);
580 idx += index as isize * self.strides[i];
581 }
582 unsafe { *self.ptr.offset(idx) }
583 }
584
585 pub fn set(&mut self, indices: &[usize], value: T) {
587 assert_eq!(indices.len(), self.dims.len());
588 let mut idx = 0isize;
589 for (i, &index) in indices.iter().enumerate() {
590 assert!(index < self.dims[i]);
591 idx += index as isize * self.strides[i];
592 }
593 unsafe {
594 *self.ptr.offset(idx) = value;
595 }
596 }
597}
598
/// Owned multi-dimensional array: a flat buffer plus a strided layout.
pub struct StridedArray<T> {
    /// Flat element storage.
    data: Vec<T>,
    /// Extent of each axis.
    dims: Arc<[usize]>,
    /// Step, in elements, between consecutive indices along each axis.
    strides: Arc<[isize]>,
    /// Position of the array's origin inside `data`, in elements.
    offset: isize,
}
612
613impl<T: std::fmt::Debug> std::fmt::Debug for StridedArray<T> {
614 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
615 f.debug_struct("StridedArray")
616 .field("dims", &self.dims)
617 .field("strides", &self.strides)
618 .field("offset", &self.offset)
619 .finish()
620 }
621}
622
623impl<T: Clone> Clone for StridedArray<T> {
624 fn clone(&self) -> Self {
625 Self {
626 data: self.data.clone(),
627 dims: self.dims.clone(),
628 strides: self.strides.clone(),
629 offset: self.offset,
630 }
631 }
632}
633
634impl<T: Clone + Default> StridedArray<T> {
635 pub fn col_major(dims: &[usize]) -> Self {
637 let total: usize = dims.iter().product();
638 let data = vec![T::default(); total];
639 let strides = col_major_strides(dims);
640 Self {
641 data,
642 dims: Arc::from(dims),
643 strides: Arc::from(strides),
644 offset: 0,
645 }
646 }
647
648 pub fn row_major(dims: &[usize]) -> Self {
650 let total: usize = dims.iter().product();
651 let data = vec![T::default(); total];
652 let strides = row_major_strides(dims);
653 Self {
654 data,
655 dims: Arc::from(dims),
656 strides: Arc::from(strides),
657 offset: 0,
658 }
659 }
660
661 pub fn from_fn_col_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
665 let total: usize = dims.iter().product();
666 let strides = col_major_strides(dims);
667 let rank = dims.len();
668 let mut data = Vec::with_capacity(total);
669 let mut idx = vec![0usize; rank];
670 for _ in 0..total {
671 data.push(f(&idx));
672 for d in 0..rank {
673 idx[d] += 1;
674 if idx[d] < dims[d] {
675 break;
676 }
677 idx[d] = 0;
678 }
679 }
680 Self {
681 data,
682 dims: Arc::from(dims),
683 strides: Arc::from(strides),
684 offset: 0,
685 }
686 }
687
688 pub fn from_fn_row_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
692 let total: usize = dims.iter().product();
693 let strides = row_major_strides(dims);
694 let rank = dims.len();
695 let mut data = Vec::with_capacity(total);
696 let mut idx = vec![0usize; rank];
697 for _ in 0..total {
698 data.push(f(&idx));
699 for d in (0..rank).rev() {
700 idx[d] += 1;
701 if idx[d] < dims[d] {
702 break;
703 }
704 idx[d] = 0;
705 }
706 }
707 Self {
708 data,
709 dims: Arc::from(dims),
710 strides: Arc::from(strides),
711 offset: 0,
712 }
713 }
714}
715
716impl<T> StridedArray<T> {
717 pub fn from_parts(
719 data: Vec<T>,
720 dims: &[usize],
721 strides: &[isize],
722 offset: isize,
723 ) -> Result<Self> {
724 validate_bounds(data.len(), dims, strides, offset)?;
725 Ok(Self {
726 data,
727 dims: Arc::from(dims),
728 strides: Arc::from(strides),
729 offset,
730 })
731 }
732
733 #[inline]
735 pub fn dims(&self) -> &[usize] {
736 &self.dims
737 }
738
739 #[inline]
741 pub fn strides(&self) -> &[isize] {
742 &self.strides
743 }
744
745 #[inline]
747 pub fn ndim(&self) -> usize {
748 self.dims.len()
749 }
750
751 #[inline]
753 pub fn len(&self) -> usize {
754 self.dims.iter().product()
755 }
756
757 #[inline]
759 pub fn is_empty(&self) -> bool {
760 self.dims.iter().any(|&d| d == 0)
761 }
762
763 #[inline]
765 pub fn data(&self) -> &[T] {
766 &self.data
767 }
768
769 #[inline]
771 pub fn data_mut(&mut self) -> &mut [T] {
772 &mut self.data
773 }
774
775 pub fn view(&self) -> StridedView<'_, T> {
777 let ptr = unsafe { self.data.as_ptr().offset(self.offset) };
778 StridedView {
779 ptr,
780 data: &self.data,
781 dims: self.dims.clone(),
782 strides: self.strides.clone(),
783 offset: self.offset,
784 _op: PhantomData,
785 }
786 }
787
788 pub fn view_mut(&mut self) -> StridedViewMut<'_, T> {
790 let ptr = unsafe { self.data.as_mut_ptr().offset(self.offset) };
791 StridedViewMut {
792 ptr,
793 data: &mut self.data,
794 dims: self.dims.clone(),
795 strides: self.strides.clone(),
796 offset: self.offset,
797 }
798 }
799
800 pub fn permuted(self, perm: &[usize]) -> Result<Self> {
805 let rank = self.dims.len();
806 if perm.len() != rank {
807 return Err(StridedError::RankMismatch(perm.len(), rank));
808 }
809 let mut seen = vec![false; rank];
810 for &p in perm {
811 if p >= rank {
812 return Err(StridedError::InvalidAxis { axis: p, rank });
813 }
814 if seen[p] {
815 return Err(StridedError::InvalidAxis { axis: p, rank });
816 }
817 seen[p] = true;
818 }
819 let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
820 let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
821 Ok(Self {
822 data: self.data,
823 dims: Arc::from(new_dims),
824 strides: Arc::from(new_strides),
825 offset: self.offset,
826 })
827 }
828
829 pub fn into_data(self) -> Vec<T> {
831 self.data
832 }
833
834 pub fn iter(&self) -> std::slice::Iter<'_, T> {
836 self.data.iter()
837 }
838
839 pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
841 self.data.iter_mut()
842 }
843}
844
845impl<T: Default> StridedArray<T> {
846 pub fn col_major_from_buffer(mut buf: Vec<T>, dims: &[usize]) -> Self {
851 let total: usize = dims.iter().product();
852 if buf.len() >= total {
853 buf.truncate(total);
854 } else {
855 buf.resize_with(total, T::default);
856 }
857 for v in buf.iter_mut() {
859 *v = T::default();
860 }
861 let strides = col_major_strides(dims);
862 Self {
863 data: buf,
864 dims: Arc::from(dims),
865 strides: Arc::from(strides),
866 offset: 0,
867 }
868 }
869}
870
871impl<T: Copy> StridedArray<T> {
872 pub unsafe fn col_major_uninit(dims: &[usize]) -> Self {
877 let total: usize = dims.iter().product();
878 let mut data = Vec::with_capacity(total);
879 data.set_len(total);
880 let strides = col_major_strides(dims);
881 Self {
882 data,
883 dims: Arc::from(dims),
884 strides: Arc::from(strides),
885 offset: 0,
886 }
887 }
888
889 pub unsafe fn col_major_from_buffer_uninit(mut buf: Vec<T>, dims: &[usize]) -> Self {
894 let total: usize = dims.iter().product();
895 if buf.capacity() < total {
896 buf.reserve(total - buf.len());
897 }
898 buf.set_len(total);
899 let strides = col_major_strides(dims);
900 Self {
901 data: buf,
902 dims: Arc::from(dims),
903 strides: Arc::from(strides),
904 offset: 0,
905 }
906 }
907}
908
909impl<T: Copy> StridedArray<T> {
910 pub fn get(&self, indices: &[usize]) -> T {
912 self.view().get(indices)
913 }
914
915 pub fn set(&mut self, indices: &[usize], value: T) {
917 assert_eq!(indices.len(), self.dims.len());
918 let mut idx = self.offset;
919 for (i, &index) in indices.iter().enumerate() {
920 assert!(index < self.dims[i]);
921 idx += index as isize * self.strides[i];
922 }
923 self.data[idx as usize] = value;
924 }
925}
926
927impl<T: Copy> Index<&[usize]> for StridedArray<T> {
928 type Output = T;
929
930 fn index(&self, indices: &[usize]) -> &T {
931 let mut idx = self.offset;
932 for (i, &index) in indices.iter().enumerate() {
933 assert!(index < self.dims[i]);
934 idx += index as isize * self.strides[i];
935 }
936 &self.data[idx as usize]
937 }
938}
939
940impl<T: Copy> IndexMut<&[usize]> for StridedArray<T> {
941 fn index_mut(&mut self, indices: &[usize]) -> &mut T {
942 let mut idx = self.offset;
943 for (i, &index) in indices.iter().enumerate() {
944 assert!(index < self.dims[i]);
945 idx += index as isize * self.strides[i];
946 }
947 &mut self.data[idx as usize]
948 }
949}
950
// Unit tests for the stride helpers, the borrowed views and the owned array.
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    // Column-major: the first axis is contiguous.
    #[test]
    fn test_col_major_strides() {
        assert_eq!(col_major_strides(&[3, 4]), vec![1, 3]);
        assert_eq!(col_major_strides(&[2, 3, 4]), vec![1, 2, 6]);
    }

    // Row-major: the last axis is contiguous.
    #[test]
    fn test_row_major_strides() {
        assert_eq!(row_major_strides(&[3, 4]), vec![4, 1]);
        assert_eq!(row_major_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    // Construction keeps the layout that was passed in.
    #[test]
    fn test_strided_view_new() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.dims(), &[2, 3]);
        assert_eq!(view.strides(), &[3, 1]);
        assert_eq!(view.len(), 6);
    }

    // Element access through a row-major 2x3 layout.
    #[test]
    fn test_strided_view_get() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0);
        assert_eq!(view.get(&[0, 1]), 2.0);
        assert_eq!(view.get(&[0, 2]), 3.0);
        assert_eq!(view.get(&[1, 0]), 4.0);
        assert_eq!(view.get(&[1, 2]), 6.0);
    }

    // Element access through a column-major 2x3 layout over the same buffer.
    #[test]
    fn test_strided_view_col_major() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[1, 2], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0);
        assert_eq!(view.get(&[1, 0]), 2.0);
        assert_eq!(view.get(&[0, 1]), 3.0);
        assert_eq!(view.get(&[1, 1]), 4.0);
        assert_eq!(view.get(&[0, 2]), 5.0);
        assert_eq!(view.get(&[1, 2]), 6.0);
    }

    // Permuting axes swaps dims and strides but reads the same buffer.
    #[test]
    fn test_strided_view_permute() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.dims(), &[3, 2]);
        assert_eq!(perm.strides(), &[1, 3]);
        assert_eq!(perm.get(&[0, 0]), 1.0);
        assert_eq!(perm.get(&[1, 0]), 2.0);
        assert_eq!(perm.get(&[0, 1]), 4.0);
    }

    // For real elements the transpose only swaps the axes.
    #[test]
    fn test_strided_view_transpose_2d() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let t = view.transpose_2d().unwrap();
        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.get(&[0, 0]), 1.0);
        assert_eq!(t.get(&[1, 0]), 2.0);
        assert_eq!(t.get(&[0, 1]), 4.0);
    }

    // `conj` negates the imaginary part on read.
    #[test]
    fn test_strided_view_conj() {
        let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let view = StridedView::<Complex64>::new(&data, &[2], &[1], 0).unwrap();
        let c = view.conj();
        assert_eq!(c.get(&[0]), Complex64::new(1.0, -2.0));
        assert_eq!(c.get(&[1]), Complex64::new(3.0, -4.0));
    }

    // The adjoint swaps the axes and conjugates each element on read.
    #[test]
    fn test_strided_view_adjoint_2d() {
        let data = vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];
        let view = StridedView::<Complex64>::new(&data, &[2, 2], &[2, 1], 0).unwrap();
        let adj = view.adjoint_2d().unwrap();
        assert_eq!(adj.dims(), &[2, 2]);
        assert_eq!(adj.get(&[0, 0]), Complex64::new(1.0, -2.0));
        assert_eq!(adj.get(&[1, 0]), Complex64::new(3.0, -4.0));
        assert_eq!(adj.get(&[0, 1]), Complex64::new(5.0, -6.0));
    }

    // Broadcasting an extent-1 axis repeats the same row for every index.
    #[test]
    fn test_strided_view_broadcast() {
        let data = vec![1.0, 2.0, 3.0];
        let view = StridedView::<f64>::new(&data, &[1, 3], &[3, 1], 0).unwrap();
        let broad = view.broadcast(&[4, 3]).unwrap();
        assert_eq!(broad.dims(), &[4, 3]);
        for i in 0..4 {
            assert_eq!(broad.get(&[i, 0]), 1.0);
            assert_eq!(broad.get(&[i, 1]), 2.0);
            assert_eq!(broad.get(&[i, 2]), 3.0);
        }
    }

    // Writes through the mutable view land in the underlying buffer.
    #[test]
    fn test_strided_view_mut() {
        let mut data = vec![0.0; 6];
        {
            let mut view = StridedViewMut::<f64>::new(&mut data, &[2, 3], &[3, 1], 0).unwrap();
            view.set(&[0, 0], 1.0);
            view.set(&[1, 2], 6.0);
        }
        assert_eq!(data[0], 1.0);
        assert_eq!(data[5], 6.0);
    }

    // A mutable view can be reborrowed as a read-only view.
    #[test]
    fn test_strided_view_mut_as_view() {
        let mut data = vec![1.0, 2.0, 3.0];
        let vm = StridedViewMut::<f64>::new(&mut data, &[3], &[1], 0).unwrap();
        let v = vm.as_view();
        assert_eq!(v.get(&[0]), 1.0);
        assert_eq!(v.get(&[2]), 3.0);
    }

    // `from_fn_col_major` passes logical indices to the closure.
    #[test]
    fn test_strided_tensor_col_major() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[1, 2]);
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    // `from_fn_row_major` passes logical indices to the closure.
    #[test]
    fn test_strided_tensor_row_major() {
        let t = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[3, 1]);
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    // A view borrowed from an owned array reads the same elements.
    #[test]
    fn test_strided_tensor_view() {
        let t =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let v = t.view();
        assert_eq!(v.get(&[0, 0]), 0.0);
        assert_eq!(v.get(&[1, 0]), 10.0);
        assert_eq!(v.get(&[0, 2]), 2.0);
    }

    // Writes through a mutable view are visible on the owning array.
    #[test]
    fn test_strided_tensor_view_mut() {
        let mut t = StridedArray::<f64>::col_major(&[2, 3]);
        {
            let mut vm = t.view_mut();
            vm.set(&[1, 2], 42.0);
        }
        assert_eq!(t.get(&[1, 2]), 42.0);
    }

    // `Index` with a multi-index slice.
    #[test]
    fn test_strided_tensor_index() {
        let t =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 10 + idx[1]) as f64);
        assert_eq!(t[&[0usize, 0] as &[usize]], 0.0);
        assert_eq!(t[&[2usize, 3] as &[usize]], 23.0);
    }

    // `IndexMut` with a multi-index slice.
    #[test]
    fn test_strided_tensor_index_mut() {
        let mut t = StridedArray::<f64>::row_major(&[2, 3]);
        t[&[1usize, 2] as &[usize]] = 99.0;
        assert_eq!(t.get(&[1, 2]), 99.0);
    }

    // Layouts that exactly fill the buffer are accepted.
    #[test]
    fn test_validate_bounds_ok() {
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 0).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[1, 2], 0).is_ok());
    }

    // A buffer one element too short is rejected.
    #[test]
    fn test_validate_bounds_out_of_range() {
        assert!(validate_bounds(5, &[2, 3], &[3, 1], 0).is_err());
    }

    // A shape with a zero extent is valid even over an empty buffer.
    #[test]
    fn test_validate_bounds_empty() {
        assert!(validate_bounds(0, &[0, 3], &[3, 1], 0).is_ok());
    }

    // The offset shifts the largest reachable position.
    #[test]
    fn test_validate_bounds_with_offset() {
        assert!(validate_bounds(7, &[2, 3], &[3, 1], 1).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 1).is_err());
    }

    // Rank-3 column-major array.
    #[test]
    fn test_strided_tensor_3d() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3, 4], |idx| {
            (idx[0] * 100 + idx[1] * 10 + idx[2]) as f64
        });
        assert_eq!(t.ndim(), 3);
        assert_eq!(t.strides(), &[1, 2, 6]);
        assert_eq!(t.get(&[0, 0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0, 0]), 100.0);
        assert_eq!(t.get(&[0, 1, 0]), 10.0);
        assert_eq!(t.get(&[0, 0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2, 3]), 123.0);
    }

    // Main diagonal of a 3x3 row-major matrix: stride 3 + 1.
    #[test]
    fn test_diagonal_view_2d() {
        let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[3, 3], &[3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[3]);
        assert_eq!(diag.strides(), &[4]);
        assert_eq!(diag.get(&[0]), 0.0);
        assert_eq!(diag.get(&[1]), 4.0);
        assert_eq!(diag.get(&[2]), 8.0);
    }

    // Diagonal over two adjacent axes of a rank-3 view.
    #[test]
    fn test_diagonal_view_3d_adjacent() {
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 2, 3], &[6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[9, 1]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 2]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 9.0);
    }

    // Diagonal over the first and last axes of a rank-3 view.
    #[test]
    fn test_diagonal_view_3d_non_adjacent() {
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2], &[6, 2, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[7, 2]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 1]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 7.0);
        assert_eq!(diag.get(&[1, 2]), 11.0);
    }

    // Two independent axis pairs collapse a rank-4 view to rank 2.
    #[test]
    fn test_diagonal_view_two_pairs() {
        let data: Vec<f64> = (0..36).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2, 3], &[18, 6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2), (1, 3)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[21, 7]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[1, 1]), 28.0);
    }

    // A plain `Copy` element type defined locally works with the array, the
    // default view and `permute`.
    #[test]
    fn test_custom_type_without_element_op_apply() {
        #[derive(Debug, Clone, Copy, PartialEq)]
        struct MyScalar(f64);

        impl Default for MyScalar {
            fn default() -> Self {
                MyScalar(0.0)
            }
        }

        let data = vec![MyScalar(1.0), MyScalar(2.0), MyScalar(3.0), MyScalar(4.0)];
        let arr = StridedArray::from_parts(data, &[2, 2], &[2, 1], 0).unwrap();

        assert_eq!(arr.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(arr.get(&[1, 0]), MyScalar(3.0));

        let view: StridedView<MyScalar> = arr.view();
        assert_eq!(view.get(&[0, 1]), MyScalar(2.0));

        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(perm.get(&[0, 1]), MyScalar(3.0));
    }
}