1use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
2use strided_kernel::copy_into;
3use strided_view::{row_major_strides, StridedArray, StridedView, StridedViewMut};
4
/// Computes the lowest and highest element offsets, relative to the first
/// logical element, that a layout with the given `shape` and `strides` spans.
///
/// Each axis of length `d` with stride `s` extends the range by
/// `s * (d - 1)`: negative extents lower the minimum, non-negative extents
/// raise the maximum. Axes of length 0 contribute nothing. With no axes (or
/// only zero-length axes) the result is `(0, 0)`.
fn compute_offset_range(shape: &[usize], strides: &[isize]) -> (isize, isize) {
    shape
        .iter()
        .zip(strides)
        .filter(|&(&len, _)| len != 0)
        .map(|(&len, &stride)| stride * (len as isize - 1))
        .fold((0, 0), |(lo, hi), span| {
            if span < 0 {
                (lo + span, hi)
            } else {
                (lo, hi + span)
            }
        })
}
25
26pub fn array_to_strided_view<T>(arr: &ArrayD<T>) -> StridedView<'_, T> {
30 let shape = arr.shape();
31 let strides = arr.strides();
32 let (min_off, max_off) = compute_offset_range(shape, strides);
33 let base_ptr = unsafe { arr.as_ptr().offset(min_off) };
34 let data_len = (max_off - min_off + 1) as usize;
35 let data = unsafe { std::slice::from_raw_parts(base_ptr, data_len) };
36 let offset = -min_off;
37 StridedView::new(data, shape, strides, offset).expect("valid bounds")
38}
39
40pub fn view_to_strided_view<'a, T>(view: &ArrayViewD<'a, T>) -> StridedView<'a, T> {
42 let shape = view.shape();
43 let strides = view.strides();
44 let (min_off, max_off) = compute_offset_range(shape, strides);
45 let base_ptr = unsafe { view.as_ptr().offset(min_off) };
46 let data_len = (max_off - min_off + 1) as usize;
47 let data = unsafe { std::slice::from_raw_parts(base_ptr, data_len) };
48 let offset = -min_off;
49 StridedView::new(data, shape, strides, offset).expect("valid bounds")
50}
51
52pub fn view_mut_to_strided_view_mut<'a, T>(
54 view: &'a mut ArrayViewMutD<'a, T>,
55) -> StridedViewMut<'a, T> {
56 let shape: Vec<usize> = view.shape().to_vec();
57 let strides: Vec<isize> = view.strides().to_vec();
58 let (min_off, max_off) = compute_offset_range(&shape, &strides);
59 let base_ptr = unsafe { view.as_mut_ptr().offset(min_off) };
60 let data_len = (max_off - min_off + 1) as usize;
61 let data = unsafe { std::slice::from_raw_parts_mut(base_ptr, data_len) };
62 let offset = -min_off;
63 StridedViewMut::new(data, &shape, &strides, offset).expect("valid bounds")
64}
65
66pub fn strided_array_to_ndarray<T>(arr: StridedArray<T>) -> ArrayD<T>
71where
72 T: Copy + strided_view::ElementOpApply + Send + Sync + num_traits::Zero + Default,
73{
74 let dims = arr.dims().to_vec();
75 let total: usize = dims.iter().product();
76 let mut buf = vec![T::default(); total];
77 let rm_strides = row_major_strides(&dims);
78 {
79 let mut dest = StridedViewMut::new(&mut buf, &dims, &rm_strides, 0).expect("valid dest");
80 copy_into(&mut dest, &arr.view()).expect("copy_into failed");
81 }
82 ArrayD::from_shape_vec(dims, buf).expect("shape matches")
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{s, ArrayD};

    #[test]
    fn test_array_to_strided_view_2d() {
        let arr = ArrayD::from_shape_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let sv = array_to_strided_view(&arr);

        assert_eq!(sv.dims(), &[2, 3]);
        // Row-major: moving one row skips three elements.
        assert_eq!(sv.strides(), &[3, 1]);

        // (offset from the first element, expected value)
        let ptr = sv.ptr();
        for &(delta, expected) in &[(0_isize, 1.0_f64), (1, 2.0), (3, 4.0), (5, 6.0)] {
            assert_eq!(unsafe { *ptr.offset(delta) }, expected);
        }
    }

    #[test]
    fn test_strided_array_to_ndarray_roundtrip() {
        // Column-major storage of [[1, 2, 3], [4, 5, 6]].
        let col_major = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let arr = StridedArray::from_parts(col_major, &[2, 3], &[1, 2], 0)
            .expect("valid strided array");

        let nd = strided_array_to_ndarray(arr);
        assert_eq!(nd.shape(), &[2, 3]);

        let expected = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(nd[[i, j]], value);
            }
        }
    }

    #[test]
    fn test_negative_stride_view() {
        let arr = ArrayD::from_shape_vec(vec![3], vec![10.0, 20.0, 30.0]).unwrap();
        let reversed: ArrayViewD<'_, f64> = arr.slice(s![..;-1]).into_dimensionality().unwrap();
        let sv = view_to_strided_view(&reversed);

        assert_eq!(sv.dims(), &[3]);
        assert_eq!(sv.strides(), &[-1]);

        // The pointer sits on the last stored element. Stepping backwards walks
        // towards the start of the buffer.
        let ptr = sv.ptr();
        for &(delta, expected) in &[(0_isize, 30.0_f64), (-1, 20.0), (-2, 10.0)] {
            assert_eq!(unsafe { *ptr.offset(delta) }, expected);
        }
    }
}