1use std::mem::ManuallyDrop;
2
3use mdarray::{Array, DynRank, Shape, View, ViewMut};
4use strided_kernel::copy_into;
5use strided_view::{StridedArray, StridedView, StridedViewMut, row_major_strides};
6
7pub fn reverse_notation(notation: &str) -> String {
14 let s: String = notation.chars().filter(|c| !c.is_whitespace()).collect();
15
16 let arrow_pos = s.find("->").expect("missing '->' in einsum notation");
17 let lhs = &s[..arrow_pos];
18 let rhs = &s[arrow_pos + 2..];
19
20 let reversed_lhs = reverse_lhs(lhs);
21 let reversed_rhs: String = rhs.chars().rev().collect();
22
23 format!("{}->{}", reversed_lhs, reversed_rhs)
24}
25
/// Reverses each maximal run of label characters in `s`, leaving the
/// structural characters `(`, `)` and `,` at their original positions.
///
/// `"(ij,jk),kl"` therefore becomes `"(ji,kj),lk"`.
fn reverse_lhs(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    // Byte offset at which the current (not yet emitted) label run begins.
    let mut run_start = 0;

    for (idx, ch) in s.char_indices() {
        if matches!(ch, '(' | ')' | ',') {
            // Flush the pending labels back-to-front, then the delimiter itself.
            out.extend(s[run_start..idx].chars().rev());
            out.push(ch);
            run_start = idx + ch.len_utf8();
        }
    }

    // Trailing labels after the last delimiter (or the whole string if none).
    out.extend(s[run_start..].chars().rev());
    out
}
52
53pub fn array_to_strided_view<T>(arr: &Array<T, DynRank>) -> StridedView<'_, T> {
57 let dims: Vec<usize> = arr.dims().iter().rev().copied().collect();
58 let strides = row_major_strides(arr.dims());
59 let reversed_strides: Vec<isize> = strides.iter().rev().copied().collect();
60 let len = arr.len();
61 let data = unsafe { std::slice::from_raw_parts(arr.as_ptr(), len) };
62 StridedView::new(data, &dims, &reversed_strides, 0).expect("valid bounds")
63}
64
65pub fn view_to_strided_view<'a, T>(view: &'a View<'a, T, DynRank>) -> StridedView<'a, T> {
69 let dims: Vec<usize> = view.dims().iter().rev().copied().collect();
70 let strides = row_major_strides(view.dims());
71 let reversed_strides: Vec<isize> = strides.iter().rev().copied().collect();
72 let len = view.len();
73 let data = unsafe { std::slice::from_raw_parts(view.as_ptr(), len) };
74 StridedView::new(data, &dims, &reversed_strides, 0).expect("valid bounds")
75}
76
77pub fn view_mut_to_strided_view_mut<'a, T>(
81 view: &'a mut ViewMut<'a, T, DynRank>,
82) -> StridedViewMut<'a, T> {
83 let dims: Vec<usize> = view.dims().iter().rev().copied().collect();
84 let strides = row_major_strides(view.dims());
85 let reversed_strides: Vec<isize> = strides.iter().rev().copied().collect();
86 let len = view.len();
87 let data = unsafe { std::slice::from_raw_parts_mut(view.as_mut_ptr(), len) };
88 StridedViewMut::new(data, &dims, &reversed_strides, 0).expect("valid bounds")
89}
90
91pub fn strided_array_to_mdarray<T>(arr: StridedArray<T>) -> Array<T, DynRank>
97where
98 T: Copy + strided_view::ElementOpApply + Send + Sync + num_traits::Zero + Default,
99{
100 let src_view = arr.view();
101 let reversed_dims: Vec<usize> = arr.dims().iter().rev().copied().collect();
102 let reversed_strides: Vec<isize> = arr.strides().iter().rev().copied().collect();
103
104 let total: usize = reversed_dims.iter().product();
107 let mut buf = vec![T::default(); total];
108 let rm_strides = row_major_strides(&reversed_dims);
109 {
110 let mut dest =
111 StridedViewMut::new(&mut buf, &reversed_dims, &rm_strides, 0).expect("valid dest");
112 let src_reversed: StridedView<'_, T> =
114 StridedView::new(src_view.data(), &reversed_dims, &reversed_strides, 0)
115 .expect("valid source");
116 copy_into(&mut dest, &src_reversed).expect("copy_into failed");
117 }
118
119 let shape: DynRank = Shape::from_dims(&reversed_dims);
120 let mut buf = ManuallyDrop::new(buf);
121 let capacity = buf.capacity();
122 let ptr = buf.as_mut_ptr();
123 unsafe { Array::from_raw_parts(ptr, shape, capacity) }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `reverse_notation(input)` equals `expected`.
    #[track_caller]
    fn assert_reversed(input: &str, expected: &str) {
        assert_eq!(reverse_notation(input), expected);
    }

    /// Two flat operands: every operand and the output are reversed.
    #[test]
    fn test_reverse_flat() {
        assert_reversed("ij,jk->ik", "ji,kj->ki");
    }

    /// One level of parentheses stays in place around the reversed operands.
    #[test]
    fn test_reverse_nested() {
        assert_reversed("(ij,jk),kl->il", "(ji,kj),lk->li");
    }

    /// Nested parentheses on both sides of a comma.
    #[test]
    fn test_reverse_deep_nested() {
        assert_reversed("((ij,jk),(kl,lm))->im", "((ji,kj),(lk,ml))->mi");
    }

    /// A repeated index with an empty output is unchanged by reversal.
    #[test]
    fn test_reverse_trace() {
        assert_reversed("ii->", "ii->");
    }

    /// A single operand: both sides are reversed.
    #[test]
    fn test_reverse_single_operand() {
        assert_reversed("ijk->kji", "kji->ijk");
    }

    /// An empty output stays empty while the operands are reversed.
    #[test]
    fn test_reverse_scalar_output() {
        assert_reversed("ij,ji->", "ji,ij->");
    }

    /// A `[2, 3]` array is exposed as dims `[3, 2]` with strides `[1, 3]`,
    /// and the view's pointer reads the original buffer in place.
    #[test]
    fn test_array_to_strided_view_2d() {
        let shape: DynRank = Shape::from_dims(&[2, 3]);
        // The Vec's destructor is suppressed because the array takes over the
        // allocation.
        let mut storage = ManuallyDrop::new(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let arr: Array<f64, DynRank> =
            unsafe { Array::from_raw_parts(storage.as_mut_ptr(), shape, storage.capacity()) };

        let view = array_to_strided_view(&arr);
        assert_eq!(view.dims(), &[3, 2]);
        assert_eq!(view.strides(), &[1, 3]);

        // Spot-check the underlying buffer at a few linear offsets.
        let base = view.ptr();
        for (offset, expected) in [(0, 1.0), (1, 2.0), (3, 4.0), (5, 6.0)] {
            unsafe {
                assert_eq!(*base.offset(offset), expected);
            }
        }
    }

    /// Converting a dims-`[3, 2]` strided array yields an array with dims
    /// `[2, 3]`.
    #[test]
    fn test_strided_array_to_mdarray_roundtrip() {
        let values = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let source =
            StridedArray::from_parts(values, &[3, 2], &[1, 3], 0).expect("valid strided array");

        let converted = strided_array_to_mdarray(source);
        assert_eq!(converted.dims(), &[2, 3]);
    }
}