ndarray_opteinsum/
convert.rs

1use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
2use strided_kernel::copy_into;
3use strided_view::{row_major_strides, StridedArray, StridedView, StridedViewMut};
4
/// Return `(min, max)` element offsets, relative to the element at index
/// `[0, 0, ..., 0]`, that can be reached by walking `shape` with `strides`.
///
/// Each axis contributes `stride * (extent - 1)`: a negative contribution
/// (reversed axis) lowers the minimum, a non-negative one raises the maximum.
/// Axes of extent 0 contribute nothing, and `shape`/`strides` are paired
/// positionally (extra entries in the longer slice are ignored).
fn compute_offset_range(shape: &[usize], strides: &[isize]) -> (isize, isize) {
    shape
        .iter()
        .zip(strides)
        // A zero-length axis has no "last index", so it cannot move either bound.
        .filter(|&(&extent, _)| extent != 0)
        .fold((0isize, 0isize), |(lowest, highest), (&extent, &stride)| {
            // Offset of the last index along this axis.
            let span = stride * (extent as isize - 1);
            if span < 0 {
                (lowest + span, highest)
            } else {
                (lowest, highest + span)
            }
        })
}
25
26/// Wrap an ndarray `ArrayD<T>` as a `StridedView` (zero-copy).
27///
28/// Dims and strides are passed through directly (no reversal).
29pub fn array_to_strided_view<T>(arr: &ArrayD<T>) -> StridedView<'_, T> {
30    let shape = arr.shape();
31    let strides = arr.strides();
32    let (min_off, max_off) = compute_offset_range(shape, strides);
33    let base_ptr = unsafe { arr.as_ptr().offset(min_off) };
34    let data_len = (max_off - min_off + 1) as usize;
35    let data = unsafe { std::slice::from_raw_parts(base_ptr, data_len) };
36    let offset = -min_off;
37    StridedView::new(data, shape, strides, offset).expect("valid bounds")
38}
39
40/// Wrap an ndarray `ArrayViewD<T>` as a `StridedView` (zero-copy).
41pub fn view_to_strided_view<'a, T>(view: &ArrayViewD<'a, T>) -> StridedView<'a, T> {
42    let shape = view.shape();
43    let strides = view.strides();
44    let (min_off, max_off) = compute_offset_range(shape, strides);
45    let base_ptr = unsafe { view.as_ptr().offset(min_off) };
46    let data_len = (max_off - min_off + 1) as usize;
47    let data = unsafe { std::slice::from_raw_parts(base_ptr, data_len) };
48    let offset = -min_off;
49    StridedView::new(data, shape, strides, offset).expect("valid bounds")
50}
51
52/// Wrap an ndarray `ArrayViewMutD<T>` as a `StridedViewMut` (zero-copy).
53pub fn view_mut_to_strided_view_mut<'a, T>(
54    view: &'a mut ArrayViewMutD<'a, T>,
55) -> StridedViewMut<'a, T> {
56    let shape: Vec<usize> = view.shape().to_vec();
57    let strides: Vec<isize> = view.strides().to_vec();
58    let (min_off, max_off) = compute_offset_range(&shape, &strides);
59    let base_ptr = unsafe { view.as_mut_ptr().offset(min_off) };
60    let data_len = (max_off - min_off + 1) as usize;
61    let data = unsafe { std::slice::from_raw_parts_mut(base_ptr, data_len) };
62    let offset = -min_off;
63    StridedViewMut::new(data, &shape, &strides, offset).expect("valid bounds")
64}
65
66/// Convert a `StridedArray` result into an ndarray `ArrayD<T>`.
67///
68/// A copy is performed to materialize data into dense row-major order,
69/// since the `StridedArray` from einsum may have arbitrary strides.
70pub fn strided_array_to_ndarray<T>(arr: StridedArray<T>) -> ArrayD<T>
71where
72    T: Copy + strided_view::ElementOpApply + Send + Sync + num_traits::Zero + Default,
73{
74    let dims = arr.dims().to_vec();
75    let total: usize = dims.iter().product();
76    let mut buf = vec![T::default(); total];
77    let rm_strides = row_major_strides(&dims);
78    {
79        let mut dest = StridedViewMut::new(&mut buf, &dims, &rm_strides, 0).expect("valid dest");
80        copy_into(&mut dest, &arr.view()).expect("copy_into failed");
81    }
82    ArrayD::from_shape_vec(dims, buf).expect("shape matches")
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{s, ArrayD};

    /// A dense row-major array is wrapped with dims and strides untouched,
    /// and the view's pointer addresses element [0, 0].
    #[test]
    fn test_array_to_strided_view_2d() {
        // Logical contents: [[1, 2, 3], [4, 5, 6]] in row-major order.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let source = ArrayD::from_shape_vec(vec![2, 3], values).unwrap();
        let wrapped = array_to_strided_view(&source);

        // No axis reversal: dims stay [2, 3] with row-major strides [3, 1].
        assert_eq!(wrapped.dims(), &[2, 3]);
        assert_eq!(wrapped.strides(), &[3, 1]);

        // (element offset from [0, 0], expected value)
        let probes: [(isize, f64); 4] = [
            (0, 1.0), // [0,0]
            (1, 2.0), // [0,1]
            (3, 4.0), // [1,0]
            (5, 6.0), // [1,2]
        ];
        let origin = wrapped.ptr();
        for &(delta, expected) in probes.iter() {
            // SAFETY: every probed offset lies inside the 6-element buffer.
            let actual = unsafe { *origin.offset(delta) };
            assert_eq!(actual, expected);
        }
    }

    /// A column-major `StridedArray` comes back as the same logical matrix.
    #[test]
    fn test_strided_array_to_ndarray_roundtrip() {
        // Shape [2, 3] stored column-major (strides [1, 2]).
        let column_major = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let strided = StridedArray::from_parts(column_major, &[2, 3], &[1, 2], 0)
            .expect("valid strided array");

        let nd = strided_array_to_ndarray(strided);
        assert_eq!(nd.shape(), &[2, 3]);

        // Expected logical contents, row by row.
        let expected_rows: [[f64; 3]; 2] = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        for (row, expected_row) in expected_rows.iter().enumerate() {
            for (col, &expected) in expected_row.iter().enumerate() {
                assert_eq!(nd[[row, col]], expected);
            }
        }
    }

    /// A reversed (stride -1) view keeps its negative stride, and the view's
    /// pointer addresses the first logical element.
    #[test]
    fn test_negative_stride_view() {
        // [10, 20, 30] reversed reads as [30, 20, 10] with stride -1.
        let forward = ArrayD::from_shape_vec(vec![3], vec![10.0, 20.0, 30.0]).unwrap();
        let backward = forward.slice(s![..;-1]);

        // Convert the dynamic view.
        let dyn_view: ArrayViewD<'_, f64> = backward.into_dimensionality().unwrap();
        let wrapped = view_to_strided_view(&dyn_view);

        assert_eq!(wrapped.dims(), &[3]);
        assert_eq!(wrapped.strides(), &[-1]);

        // Logical index i lives at pointer offset i * (-1).
        let probes: [(isize, f64); 3] = [(0, 30.0), (-1, 20.0), (-2, 10.0)];
        let origin = wrapped.ptr();
        for &(delta, expected) in probes.iter() {
            // SAFETY: offsets 0, -1, -2 from the last element stay inside
            // the 3-element buffer.
            let actual = unsafe { *origin.offset(delta) };
            assert_eq!(actual, expected);
        }
    }
}
151}