mdarray_opteinsum/
convert.rs

1use std::mem::ManuallyDrop;
2
3use mdarray::{Array, DynRank, Shape, View, ViewMut};
4use strided_kernel::copy_into;
5use strided_view::{StridedArray, StridedView, StridedViewMut, row_major_strides};
6
7/// Reverse the index labels in an einsum notation string.
8///
9/// Each operand's labels and the output labels are reversed to convert
10/// between row-major (mdarray) and column-major (strided) conventions.
11///
12/// Example: `"(ij,jk),kl->il"` becomes `"(ji,kj),lk->li"`
13pub fn reverse_notation(notation: &str) -> String {
14    let s: String = notation.chars().filter(|c| !c.is_whitespace()).collect();
15
16    let arrow_pos = s.find("->").expect("missing '->' in einsum notation");
17    let lhs = &s[..arrow_pos];
18    let rhs = &s[arrow_pos + 2..];
19
20    let reversed_lhs = reverse_lhs(lhs);
21    let reversed_rhs: String = rhs.chars().rev().collect();
22
23    format!("{}->{}", reversed_lhs, reversed_rhs)
24}
25
/// Reverse index labels in the LHS of einsum notation, preserving structure.
///
/// Every maximal run of label characters between the delimiters `(`, `)` and
/// `,` is reversed in place. The delimiters themselves are copied through
/// unchanged, so nesting and operand order are kept.
fn reverse_lhs(s: &str) -> String {
    fn is_delimiter(c: char) -> bool {
        matches!(c, '(' | ')' | ',')
    }

    let mut reversed = String::with_capacity(s.len());
    let mut remaining = s;
    while !remaining.is_empty() {
        // Labels run up to the next delimiter (or the end of the string).
        let run_len = remaining.find(is_delimiter).unwrap_or(remaining.len());
        let (labels, rest) = remaining.split_at(run_len);
        reversed.extend(labels.chars().rev());

        // Copy the delimiter (if any) verbatim and continue after it.
        let mut rest_chars = rest.chars();
        if let Some(delimiter) = rest_chars.next() {
            reversed.push(delimiter);
        }
        remaining = rest_chars.as_str();
    }
    reversed
}
52
53/// Wrap an mdarray `Array<T, DynRank>` as a `StridedView` with reversed dims.
54///
55/// Zero-copy: the view borrows the array's data.
56pub fn array_to_strided_view<T>(arr: &Array<T, DynRank>) -> StridedView<'_, T> {
57    let dims: Vec<usize> = arr.dims().iter().rev().copied().collect();
58    let strides = row_major_strides(arr.dims());
59    let reversed_strides: Vec<isize> = strides.iter().rev().copied().collect();
60    let len = arr.len();
61    let data = unsafe { std::slice::from_raw_parts(arr.as_ptr(), len) };
62    StridedView::new(data, &dims, &reversed_strides, 0).expect("valid bounds")
63}
64
65/// Wrap an mdarray `View<T, DynRank>` as a `StridedView` with reversed dims.
66///
67/// Zero-copy: the view borrows the same data.
68pub fn view_to_strided_view<'a, T>(view: &'a View<'a, T, DynRank>) -> StridedView<'a, T> {
69    let dims: Vec<usize> = view.dims().iter().rev().copied().collect();
70    let strides = row_major_strides(view.dims());
71    let reversed_strides: Vec<isize> = strides.iter().rev().copied().collect();
72    let len = view.len();
73    let data = unsafe { std::slice::from_raw_parts(view.as_ptr(), len) };
74    StridedView::new(data, &dims, &reversed_strides, 0).expect("valid bounds")
75}
76
77/// Wrap an mdarray `ViewMut<T, DynRank>` as a `StridedViewMut` with reversed dims.
78///
79/// Zero-copy: the mutable view borrows the same data.
80pub fn view_mut_to_strided_view_mut<'a, T>(
81    view: &'a mut ViewMut<'a, T, DynRank>,
82) -> StridedViewMut<'a, T> {
83    let dims: Vec<usize> = view.dims().iter().rev().copied().collect();
84    let strides = row_major_strides(view.dims());
85    let reversed_strides: Vec<isize> = strides.iter().rev().copied().collect();
86    let len = view.len();
87    let data = unsafe { std::slice::from_raw_parts_mut(view.as_mut_ptr(), len) };
88    StridedViewMut::new(data, &dims, &reversed_strides, 0).expect("valid bounds")
89}
90
91/// Convert a `StridedArray` result into an mdarray `Array<T, DynRank>`.
92///
93/// The result's dims are reversed to match mdarray's row-major convention.
94/// A copy is performed to materialize the data into dense row-major order,
95/// since mdarray only supports dense (implicit-stride) layout.
96pub fn strided_array_to_mdarray<T>(arr: StridedArray<T>) -> Array<T, DynRank>
97where
98    T: Copy + strided_view::ElementOpApply + Send + Sync + num_traits::Zero + Default,
99{
100    let src_view = arr.view();
101    let reversed_dims: Vec<usize> = arr.dims().iter().rev().copied().collect();
102    let reversed_strides: Vec<isize> = arr.strides().iter().rev().copied().collect();
103
104    // Create a dest StridedViewMut with row-major layout (= reversed strides)
105    // and copy from the source view, letting copy_into handle arbitrary strides.
106    let total: usize = reversed_dims.iter().product();
107    let mut buf = vec![T::default(); total];
108    let rm_strides = row_major_strides(&reversed_dims);
109    {
110        let mut dest =
111            StridedViewMut::new(&mut buf, &reversed_dims, &rm_strides, 0).expect("valid dest");
112        // Wrap source data with reversed dims/strides so copy_into can match shapes.
113        let src_reversed: StridedView<'_, T> =
114            StridedView::new(src_view.data(), &reversed_dims, &reversed_strides, 0)
115                .expect("valid source");
116        copy_into(&mut dest, &src_reversed).expect("copy_into failed");
117    }
118
119    let shape: DynRank = Shape::from_dims(&reversed_dims);
120    let mut buf = ManuallyDrop::new(buf);
121    let capacity = buf.capacity();
122    let ptr = buf.as_mut_ptr();
123    unsafe { Array::from_raw_parts(ptr, shape, capacity) }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow the element buffer of a dense mdarray `Array` in memory order.
    fn memory_of(arr: &Array<f64, DynRank>) -> &[f64] {
        // SAFETY: the arrays in these tests are dense (built from a `Vec` or
        // by `strided_array_to_mdarray`), so `len()` elements starting at
        // `as_ptr()` are initialized. The slice borrows `arr`.
        unsafe { std::slice::from_raw_parts(arr.as_ptr(), arr.len()) }
    }

    #[test]
    fn test_reverse_flat() {
        assert_eq!(reverse_notation("ij,jk->ik"), "ji,kj->ki");
    }

    #[test]
    fn test_reverse_nested() {
        assert_eq!(reverse_notation("(ij,jk),kl->il"), "(ji,kj),lk->li");
    }

    #[test]
    fn test_reverse_deep_nested() {
        assert_eq!(
            reverse_notation("((ij,jk),(kl,lm))->im"),
            "((ji,kj),(lk,ml))->mi"
        );
    }

    #[test]
    fn test_reverse_trace() {
        assert_eq!(reverse_notation("ii->"), "ii->");
    }

    #[test]
    fn test_reverse_single_operand() {
        assert_eq!(reverse_notation("ijk->kji"), "kji->ijk");
    }

    #[test]
    fn test_reverse_scalar_output() {
        assert_eq!(reverse_notation("ij,ji->"), "ji,ij->");
    }

    #[test]
    fn test_array_to_strided_view_2d() {
        // Create a row-major 2x3 mdarray: [[1,2,3],[4,5,6]]
        // Memory: [1,2,3,4,5,6]
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape: DynRank = Shape::from_dims(&[2, 3]);
        // SAFETY: ptr/capacity come from a Vec that is never dropped here;
        // its 6 initialized elements match the 2x3 shape.
        let arr: Array<f64, DynRank> = unsafe {
            let mut data = ManuallyDrop::new(data);
            Array::from_raw_parts(data.as_mut_ptr(), shape, data.capacity())
        };

        let view = array_to_strided_view(&arr);
        // Reversed dims: [3, 2]
        assert_eq!(view.dims(), &[3, 2]);
        // Row-major strides for [2,3] are [3,1], reversed: [1,3]
        assert_eq!(view.strides(), &[1, 3]);

        // Verify element access via pointer arithmetic:
        // view[col, row] maps to arr[row, col] (transposed indexing)
        let ptr = view.ptr();
        // SAFETY: all offsets below are < 6, inside the borrowed buffer.
        unsafe {
            // view[0,0] = arr[0,0] = 1.0
            assert_eq!(*ptr, 1.0);
            // view[1,0] = arr[0,1] = 2.0 (stride[0]=1)
            assert_eq!(*ptr.offset(1), 2.0);
            // view[0,1] = arr[1,0] = 4.0 (stride[1]=3)
            assert_eq!(*ptr.offset(3), 4.0);
            // view[2,1] = arr[1,2] = 6.0 (2*1 + 1*3 = 5)
            assert_eq!(*ptr.offset(5), 6.0);
        }
    }

    #[test]
    fn test_strided_array_to_mdarray_roundtrip() {
        // Create a col-major StridedArray [3, 2] (result from einsum)
        // col-major strides: [1, 3]
        // Memory: [a00, a10, a20, a01, a11, a21]
        let data = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let arr = StridedArray::from_parts(data, &[3, 2], &[1, 3], 0).expect("valid strided array");

        let md = strided_array_to_mdarray(arr);
        // Reversed dims: [2, 3]
        assert_eq!(md.dims(), &[2, 3]);
        // md[j, i] == a[i, j]. Col-major [3,2] and row-major [2,3] share the
        // same memory order, so the buffer must come through unchanged:
        // md[0,0] = a00 = 1.0, md[0,1] = a10 = 4.0, ...
        assert_eq!(memory_of(&md), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_strided_array_to_mdarray_row_major_source() {
        // Row-major StridedArray [3, 2]: strides [2, 1], a[i, j] = data[2*i + j]
        // a = [[1,2],[3,4],[5,6]]
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = StridedArray::from_parts(data, &[3, 2], &[2, 1], 0).expect("valid strided array");

        let md = strided_array_to_mdarray(arr);
        assert_eq!(md.dims(), &[2, 3]);
        // md[j, i] == a[i, j]: row j of md is column j of a, so the data is
        // physically transposed: [[1,3,5],[2,4,6]] in row-major order.
        assert_eq!(memory_of(&md), &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }
}