1#[cfg(feature = "parallel")]
4use crate::hptt::execute_permute_blocked_par;
5use crate::hptt::{build_permute_plan, execute_permute_blocked};
6use crate::kernel::total_len;
7use strided_view::{Result, StridedError, StridedView, StridedViewMut};
8
/// Reports whether `dst_strides` and `src_strides` describe the *same* densely packed
/// layout of `dims`, either with the first dimension varying fastest or with the last
/// dimension varying fastest.
///
/// Dimensions of extent 0 or 1 are ignored, whatever their stride. For the remaining
/// dimensions both stride lists must start at 1 and grow by the extent of each dimension
/// visited. An empty `dims` counts as contiguous.
fn is_both_contiguous(dims: &[usize], dst_strides: &[isize], src_strides: &[isize]) -> bool {
    /// Walks `(extent, dst_stride, src_stride)` triples, fastest-varying axis first, and
    /// checks that both strides equal the running product of the extents seen so far.
    fn packed(axes: impl Iterator<Item = (usize, isize, isize)>) -> bool {
        let mut next_stride = 1isize;
        for (extent, dst_step, src_step) in axes {
            if extent <= 1 {
                continue;
            }
            if dst_step != next_stride || src_step != next_stride {
                return false;
            }
            // Saturate rather than overflow on absurdly large extents.
            next_stride = next_stride.saturating_mul(extent as isize);
        }
        true
    }

    if dims.is_empty() {
        return true;
    }

    // First dimension fastest.
    let front_to_back = dims
        .iter()
        .zip(dst_strides)
        .zip(src_strides)
        .map(|((&extent, &dst_step), &src_step)| (extent, dst_step, src_step));
    // Last dimension fastest: each slice is reversed on its own before pairing.
    let back_to_front = dims
        .iter()
        .rev()
        .zip(dst_strides.iter().rev())
        .zip(src_strides.iter().rev())
        .map(|((&extent, &dst_step), &src_step)| (extent, dst_step, src_step));

    packed(front_to_back) || packed(back_to_front)
}
51
52pub fn copy_into<T: Copy>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()> {
60 let dst_dims = dest.dims();
61 let src_dims = src.dims();
62 if dst_dims.len() != src_dims.len() {
63 return Err(StridedError::RankMismatch(dst_dims.len(), src_dims.len()));
64 }
65 if dst_dims != src_dims {
66 return Err(StridedError::ShapeMismatch(
67 dst_dims.to_vec(),
68 src_dims.to_vec(),
69 ));
70 }
71
72 let dst_ptr = dest.as_mut_ptr();
73 let src_ptr = src.ptr();
74 let dst_strides = dest.strides();
75 let src_strides = src.strides();
76
77 if is_both_contiguous(dst_dims, dst_strides, src_strides) {
79 let len = total_len(dst_dims);
80 unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
81 return Ok(());
82 }
83
84 let elem_size = std::mem::size_of::<T>();
86 let plan = build_permute_plan(dst_dims, src_strides, dst_strides, elem_size);
87 unsafe {
88 execute_permute_blocked(src_ptr, dst_ptr, &plan);
89 }
90 Ok(())
91}
92
93pub fn copy_into_col_major<T: Copy>(
99 dst: &mut StridedViewMut<T>,
100 src: &StridedView<T>,
101) -> Result<()> {
102 copy_into(dst, src)
103}
104
/// Variant of `copy_into` that runs the strided path through
/// `execute_permute_blocked_par`. Only compiled with the `parallel` feature.
///
/// When both views share the same dense layout the copy is still a single
/// `ptr::copy_nonoverlapping`; only the general strided case uses the `_par` executor.
///
/// # Errors
///
/// * [`StridedError::RankMismatch`] when the views have a different number of dimensions.
/// * [`StridedError::ShapeMismatch`] when the ranks agree but the extents differ.
#[cfg(feature = "parallel")]
pub fn copy_into_par<T: Copy + Send + Sync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let shape = dest.dims();
    let src_shape = src.dims();

    let (dst_rank, src_rank) = (shape.len(), src_shape.len());
    if dst_rank != src_rank {
        return Err(StridedError::RankMismatch(dst_rank, src_rank));
    }
    if shape != src_shape {
        return Err(StridedError::ShapeMismatch(
            shape.to_vec(),
            src_shape.to_vec(),
        ));
    }

    let out = dest.as_mut_ptr();
    let input = src.ptr();
    let out_strides = dest.strides();
    let in_strides = src.strides();

    if !is_both_contiguous(shape, out_strides, in_strides) {
        // General strided case, dispatched to the `_par` executor.
        let plan = build_permute_plan(shape, in_strides, out_strides, std::mem::size_of::<T>());
        // SAFETY: NOTE(review) — relies on the plan describing only offsets that are valid
        // for both views, and on the executor writing each destination element from at
        // most one worker; both contracts live in `crate::hptt` and are not visible here.
        unsafe { execute_permute_blocked_par(input, out, &plan) };
        return Ok(());
    }

    // Identical dense layout on both sides: one flat copy of all elements.
    let count = total_len(shape);
    // SAFETY: NOTE(review) — the contiguity check guarantees matching unit-stride layouts.
    // This additionally assumes both base pointers are valid for `count` elements and
    // that the regions do not overlap (`dest` is a unique borrow) — confirm against the
    // view types.
    unsafe { std::ptr::copy_nonoverlapping(input, out, count) };
    Ok(())
}
145
/// Copies `src` into `dst` by forwarding to [`copy_into_par`]. Only compiled with the
/// `parallel` feature.
///
/// NOTE(review): despite the name, nothing in this function is specific to column-major
/// layouts; arguments, result and errors are exactly those of [`copy_into_par`].
#[cfg(feature = "parallel")]
pub fn copy_into_col_major_par<T: Copy + Send + Sync>(
    dst: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    copy_into_par::<T>(dst, src)
}
154
/// Tries to collapse a group of dimensions into a single `(len, stride)` dimension.
///
/// The group is fusable when its non-singleton dimensions (extent > 1), taken in order
/// of increasing `|stride|`, form a dense chain: each stride magnitude equals the
/// previous magnitude times the previous extent. The order of the dimensions in the
/// input does not matter, so a `3 x 4` group fuses with strides `[4, 1]` as well as
/// with `[1, 3]`.
///
/// # Returns
///
/// * `Some((1, 0))` for an empty group.
/// * `Some((dims[0], strides[0]))` for a one-dimensional group.
/// * `Some((product of dims, stride))` for a fusable group, where `stride` is the stride
///   of the non-singleton dimension with the smallest magnitude. If every extent is
///   0 or 1, `stride` is the smallest-magnitude entry of `strides`.
/// * `None` when
///   - `dims` and `strides` differ in length (for two or more dimensions), or a
///     one-dimensional group comes without a stride,
///   - a non-singleton dimension has stride 0,
///   - the non-singleton strides mix signs (a single stride cannot step forwards along
///     one axis and backwards along another),
///   - the strides do not form a dense chain, or
///   - the extent of the chain overflows `usize`.
pub fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    match dims.len() {
        0 => return Some((1, 0)),
        // `first()` rather than `strides[0]`: a missing stride yields `None` instead of
        // an out-of-bounds panic.
        1 => return strides.first().map(|&stride| (dims[0], stride)),
        _ => {}
    }
    if dims.len() != strides.len() {
        return None;
    }

    // Collect `(|stride|, extent, stride)` for every dimension that actually moves the
    // address. Extents of 0 or 1 never contribute to the layout and are skipped.
    let mut axes: Vec<(usize, usize, isize)> = Vec::with_capacity(dims.len());
    for (&extent, &stride) in dims.iter().zip(strides) {
        if extent <= 1 {
            continue;
        }
        if stride == 0 {
            // A broadcast (zero-stride) axis can never be part of a dense chain.
            return None;
        }
        axes.push((stride.unsigned_abs(), extent, stride));
    }

    if axes.is_empty() {
        // Only singleton/empty dimensions: any stride works, report the smallest one.
        let stride = strides
            .iter()
            .copied()
            .min_by_key(|s| s.unsigned_abs())
            .unwrap_or(0);
        return Some((dims.iter().product(), stride));
    }

    // A valid chain has strictly increasing magnitudes, so after sorting it is enough
    // to verify each axis against the extent accumulated from its predecessors.
    axes.sort_unstable_by_key(|&(magnitude, _, _)| magnitude);
    let (base_magnitude, _, base_stride) = axes[0];
    let mut expected = base_magnitude;
    for &(magnitude, extent, stride) in &axes {
        if magnitude != expected || (stride < 0) != (base_stride < 0) {
            return None;
        }
        // Checked so that an extent exceeding `usize` is rejected instead of wrapping.
        expected = magnitude.checked_mul(extent)?;
    }

    // The product cannot overflow here: it is bounded by `expected`, computed above.
    Some((dims.iter().product(), base_stride))
}
218
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::StridedArray;

    /// Source and destination are both column-major: spot-check the first and last element.
    #[test]
    fn test_copy_into_contiguous() {
        let source =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut target = StridedArray::<f64>::col_major(&[2, 3]);

        copy_into(&mut target.view_mut(), &source.view()).unwrap();

        let expectations: [([usize; 2], f64); 2] = [([0, 0], 0.0), ([1, 2], 12.0)];
        for (index, want) in expectations {
            assert_eq!(target.get(&index), want);
        }
    }

    /// A 3x2 source with strides `[2, 1]` (last index fastest) copied into a column-major
    /// destination must keep every logical element in place.
    #[test]
    fn test_copy_into_transposed() {
        let values: Vec<f64> = (0..6).map(|v| v as f64).collect();
        let source = StridedArray::<f64>::from_parts(values, &[3, 2], &[2, 1], 0).unwrap();
        let mut target = StridedArray::<f64>::col_major(&[3, 2]);

        copy_into(&mut target.view_mut(), &source.view()).unwrap();

        let expectations: [([usize; 2], f64); 4] =
            [([0, 0], 0.0), ([0, 1], 1.0), ([1, 0], 2.0), ([2, 1], 5.0)];
        for (index, want) in expectations {
            assert_eq!(target.get(&index), want);
        }
    }

    /// `copy_into_col_major` on two column-major arrays of the same shape.
    #[test]
    fn test_copy_into_col_major_basic() {
        let source =
            StridedArray::<f64>::from_fn_col_major(&[4, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut target = StridedArray::<f64>::col_major(&[4, 3]);

        copy_into_col_major(&mut target.view_mut(), &source.view()).unwrap();

        let expectations: [([usize; 2], f64); 2] = [([0, 0], 0.0), ([3, 2], 32.0)];
        for (index, want) in expectations {
            assert_eq!(target.get(&index), want);
        }
    }

    /// A 2x3x4 source with strides `[12, 4, 1]` copied into a column-major destination:
    /// every logical index must read back the same value from both arrays.
    #[test]
    fn test_copy_into_col_major_permuted_src() {
        let values: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let source = StridedArray::<f64>::from_parts(values, &[2, 3, 4], &[12, 4, 1], 0).unwrap();
        let mut target = StridedArray::<f64>::col_major(&[2, 3, 4]);

        copy_into_col_major(&mut target.view_mut(), &source.view()).unwrap();

        // Visit indices in the same order as three nested loops over i, j, k.
        for flat in 0..24usize {
            let (i, j, k) = (flat / 12, (flat / 4) % 3, flat % 4);
            assert_eq!(
                target.get(&[i, j, k]),
                source.get(&[i, j, k]),
                "mismatch at [{},{},{}]",
                i,
                j,
                k
            );
        }
    }

    /// An empty group fuses to length 1 with stride 0.
    #[test]
    fn test_try_fuse_group_empty() {
        let fused = try_fuse_group(&[], &[]);
        assert_eq!(fused, Some((1, 0)));
    }

    /// A single dimension is returned as-is.
    #[test]
    fn test_try_fuse_group_single() {
        let fused = try_fuse_group(&[5], &[2]);
        assert_eq!(fused, Some((5, 2)));
    }

    /// 3x4 with the last dimension fastest fuses to 12 elements of stride 1.
    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        let fused = try_fuse_group(&[3, 4], &[4, 1]);
        assert_eq!(fused, Some((12, 1)));
    }

    /// 3x4 with the first dimension fastest fuses to 12 elements of stride 1.
    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        let fused = try_fuse_group(&[3, 4], &[1, 3]);
        assert_eq!(fused, Some((12, 1)));
    }

    /// A gap between rows (stride 8 where 4 is needed) prevents fusion.
    #[test]
    fn test_try_fuse_group_non_contiguous() {
        let fused = try_fuse_group(&[3, 4], &[8, 1]);
        assert_eq!(fused, None);
    }

    /// Same rank but different extents must be rejected.
    #[test]
    fn test_copy_shape_mismatch() {
        let source = StridedArray::<f64>::col_major(&[2, 3]);
        let mut target = StridedArray::<f64>::col_major(&[3, 2]);
        let outcome = copy_into(&mut target.view_mut(), &source.view());
        assert!(outcome.is_err());
    }
}