1#[cfg(feature = "parallel")]
4use crate::hptt::execute_permute_blocked_par;
5use crate::hptt::{build_permute_plan, execute_permute_blocked};
6use strided_view::{Result, StridedError, StridedView, StridedViewMut};
7
/// Returns how many elements a shape addresses: the product of its extents.
///
/// A rank-0 shape (empty `dims`) yields 1, and any zero extent yields 0.
fn total_len(dims: &[usize]) -> usize {
    dims.iter().fold(1, |count, &extent| count * extent)
}
11
/// Reports whether `dst_strides` and `src_strides` both describe the *same* dense layout
/// for the shape `dims`.
///
/// Column-major order (first axis fastest) is tried first, then row-major order
/// (last axis fastest). Axes with extent 0 or 1 are skipped: a singleton axis only ever
/// uses index 0, and a zero extent means there are no elements at all. A rank-0 shape is
/// trivially contiguous.
fn is_both_contiguous(dims: &[usize], dst_strides: &[isize], src_strides: &[isize]) -> bool {
    if dims.is_empty() {
        return true;
    }

    // Carries the running dense stride (1, d0, d0*d1, ...) across the non-singleton axes
    // in iteration order. It yields `None` as soon as either view's stride disagrees with it.
    fn dense_along<'a, I>(axes: I) -> bool
    where
        I: Iterator<Item = (&'a usize, (&'a isize, &'a isize))>,
    {
        axes.filter(|&(&extent, _)| extent > 1)
            .try_fold(1isize, |dense, (&extent, (&dst, &src))| {
                (dst == dense && src == dense).then(|| dense.saturating_mul(extent as isize))
            })
            .is_some()
    }

    let forward = dims.iter().zip(dst_strides.iter().zip(src_strides.iter()));
    if dense_along(forward) {
        return true;
    }

    let backward = dims
        .iter()
        .rev()
        .zip(dst_strides.iter().rev().zip(src_strides.iter().rev()));
    dense_along(backward)
}
54
55pub fn copy_into<T: Copy>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()> {
63 let dst_dims = dest.dims();
64 let src_dims = src.dims();
65 if dst_dims.len() != src_dims.len() {
66 return Err(StridedError::RankMismatch(dst_dims.len(), src_dims.len()));
67 }
68 if dst_dims != src_dims {
69 return Err(StridedError::ShapeMismatch(
70 dst_dims.to_vec(),
71 src_dims.to_vec(),
72 ));
73 }
74
75 let dst_ptr = dest.as_mut_ptr();
76 let src_ptr = src.ptr();
77 let dst_strides = dest.strides();
78 let src_strides = src.strides();
79
80 if is_both_contiguous(dst_dims, dst_strides, src_strides) {
82 let len = total_len(dst_dims);
83 unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
84 return Ok(());
85 }
86
87 let elem_size = std::mem::size_of::<T>();
89 let plan = build_permute_plan(dst_dims, src_strides, dst_strides, elem_size);
90 unsafe {
91 execute_permute_blocked(src_ptr, dst_ptr, &plan);
92 }
93 Ok(())
94}
95
96pub fn copy_into_col_major<T: Copy>(
102 dst: &mut StridedViewMut<T>,
103 src: &StridedView<T>,
104) -> Result<()> {
105 copy_into(dst, src)
106}
107
/// Parallel counterpart of `copy_into`: copies every element of `src` into `dest`,
/// honoring each view's strides.
///
/// This is available with the `parallel` feature. Validation, the zero-element early
/// return and the dense-layout `copy_nonoverlapping` fast path are the same as in
/// `copy_into`. The strided fallback runs `execute_permute_blocked_par` instead of the
/// serial kernel.
///
/// # Errors
///
/// * [`StridedError::RankMismatch`] if the two views have different ranks.
/// * [`StridedError::ShapeMismatch`] if the ranks agree but the extents differ.
#[cfg(feature = "parallel")]
pub fn copy_into_par<T: Copy + Send + Sync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let dst_dims = dest.dims();
    let src_dims = src.dims();
    if dst_dims.len() != src_dims.len() {
        return Err(StridedError::RankMismatch(dst_dims.len(), src_dims.len()));
    }
    if dst_dims != src_dims {
        return Err(StridedError::ShapeMismatch(
            dst_dims.to_vec(),
            src_dims.to_vec(),
        ));
    }

    // Nothing to copy when any extent is zero. Skip planning and avoid handing an empty
    // iteration space to the parallel kernel. (Rank 0 gives len == 1.)
    let len = total_len(dst_dims);
    if len == 0 {
        return Ok(());
    }

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    if is_both_contiguous(dst_dims, dst_strides, src_strides) {
        // SAFETY: equal shapes plus an identical dense layout mean each base pointer
        // addresses exactly `len` consecutive elements. This assumes the view types keep
        // those elements in bounds and prevent the `&mut` destination from aliasing the
        // shared source, as `copy_nonoverlapping` requires.
        unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
        return Ok(());
    }

    let elem_size = std::mem::size_of::<T>();
    // Argument order is (shape, source strides, destination strides, element size).
    let plan = build_permute_plan(dst_dims, src_strides, dst_strides, elem_size);
    // SAFETY: the plan comes from the validated shape and each view's own strides. This
    // assumes, per the `hptt` parallel kernel's contract (TODO confirm), that all visited
    // offsets are in bounds and that worker threads write disjoint destination elements.
    unsafe {
        execute_permute_blocked_par(src_ptr, dst_ptr, &plan);
    }
    Ok(())
}
148
/// Parallel copy of `src` into `dst` that delegates to `copy_into_par`.
///
/// This is available with the `parallel` feature. Errors from `copy_into_par` are
/// propagated unchanged.
#[cfg(feature = "parallel")]
pub fn copy_into_col_major_par<T: Copy + Send + Sync>(
    dst: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    copy_into_par(dst, src)?;
    Ok(())
}
157
/// Tries to collapse a group of axes into a single `(extent, stride)` pair.
///
/// Returns `Some((product_of_dims, stride))` when the addresses reached by the
/// non-singleton axes form one arithmetic progression that starts at the base pointer.
/// This holds when:
/// * taking the axes in order of increasing `|stride|`, each stride magnitude equals the
///   previous magnitude times its extent, and
/// * all non-singleton strides share the same sign.
///
/// The returned stride is the smallest-magnitude non-singleton stride, with its sign.
///
/// Returns `None` when:
/// * the group cannot be fused,
/// * `dims` and `strides` differ in length, or
/// * a non-singleton axis has stride 0.
#[cfg(test)]
fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    // Malformed input cannot be fused. This also keeps `strides[0]` below in bounds.
    if dims.len() != strides.len() {
        return None;
    }
    match dims.len() {
        0 => Some((1, 0)),
        1 => Some((dims[0], strides[0])),
        _ => {
            // A broadcast (stride-0) axis with extent > 1 revisits addresses, so the
            // group cannot be a single progression.
            if dims.iter().zip(strides).any(|(&d, &s)| d > 1 && s == 0) {
                return None;
            }

            // The fastest-moving axis of a fused run is the non-singleton axis with the
            // smallest stride magnitude.
            let mut base_idx: Option<usize> = None;
            let mut base_abs = usize::MAX;
            for (i, (&d, &s)) in dims.iter().zip(strides).enumerate() {
                if d <= 1 {
                    continue;
                }
                let abs = s.unsigned_abs();
                if abs < base_abs {
                    base_abs = abs;
                    base_idx = Some(i);
                }
            }

            let Some(base) = base_idx else {
                // Every axis has extent 0 or 1, so no stride is ever applied. Report the
                // smallest-magnitude stride.
                let stride = *strides
                    .iter()
                    .min_by_key(|s| s.unsigned_abs())
                    .unwrap_or(&0);
                return Some((dims.iter().product(), stride));
            };

            // With mixed signs the covered addresses are still contiguous, but they do not
            // start at the base pointer. For example, dims [2, 3] with strides [1, -2]
            // covers offsets -4..=1. A single (extent, stride) pair cannot express that.
            let sign = strides[base].signum();
            if dims
                .iter()
                .zip(strides)
                .any(|(&d, &s)| d > 1 && s.signum() != sign)
            {
                return None;
            }

            let mut used = vec![false; dims.len()];
            used[base] = true;
            let mut expected_abs = base_abs.checked_mul(dims[base])?;

            // Chain the remaining non-singleton axes. Each axis must step by exactly the
            // span already covered.
            let non_singleton = dims.iter().filter(|&&d| d > 1).count();
            for _ in 1..non_singleton {
                let i = (0..dims.len()).find(|&i| {
                    !used[i] && dims[i] > 1 && strides[i].unsigned_abs() == expected_abs
                })?;
                used[i] = true;
                expected_abs = expected_abs.checked_mul(dims[i])?;
            }

            Some((dims.iter().product(), strides[base]))
        }
    }
}
222
// Unit tests for `copy_into`, `copy_into_col_major`, shape validation and the
// `try_fuse_group` helper.
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::StridedArray;

    // Source and destination are both built with the column-major constructors, so this
    // presumably exercises the contiguous fast path. Values encode `10 * row + col`.
    #[test]
    fn test_copy_into_contiguous() {
        let src =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut dst = StridedArray::<f64>::col_major(&[2, 3]);
        copy_into(&mut dst.view_mut(), &src.view()).unwrap();
        assert_eq!(dst.get(&[0, 0]), 0.0);
        assert_eq!(dst.get(&[1, 2]), 12.0);
    }

    // 3x2 source with strides [2, 1] (row-major) copied into a column-major destination,
    // so the layouts differ and the strided permute path is used. Element [i, j] of the
    // source is data[2 * i + j]. Assumes the trailing `0` passed to `from_parts` is the
    // starting offset.
    #[test]
    fn test_copy_into_transposed() {
        let src = StridedArray::<f64>::from_parts(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            &[3, 2],
            &[2, 1],
            0,
        )
        .unwrap();
        let mut dst = StridedArray::<f64>::col_major(&[3, 2]);
        copy_into(&mut dst.view_mut(), &src.view()).unwrap();
        assert_eq!(dst.get(&[0, 0]), 0.0);
        assert_eq!(dst.get(&[0, 1]), 1.0);
        assert_eq!(dst.get(&[1, 0]), 2.0);
        assert_eq!(dst.get(&[2, 1]), 5.0);
    }

    // The column-major wrapper on two column-major arrays of the same shape.
    #[test]
    fn test_copy_into_col_major_basic() {
        let src =
            StridedArray::<f64>::from_fn_col_major(&[4, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut dst = StridedArray::<f64>::col_major(&[4, 3]);
        copy_into_col_major(&mut dst.view_mut(), &src.view()).unwrap();
        assert_eq!(dst.get(&[0, 0]), 0.0);
        assert_eq!(dst.get(&[3, 2]), 32.0);
    }

    // Rank-3 source with strides [12, 4, 1] (row-major) copied into a column-major
    // destination. Every element is compared against the source via indexed `get`.
    #[test]
    fn test_copy_into_col_major_permuted_src() {
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let src = StridedArray::<f64>::from_parts(
            data,
            &[2, 3, 4],
            &[12, 4, 1],
            0,
        )
        .unwrap();
        let mut dst = StridedArray::<f64>::col_major(&[2, 3, 4]);
        copy_into_col_major(&mut dst.view_mut(), &src.view()).unwrap();
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(
                        dst.get(&[i, j, k]),
                        src.get(&[i, j, k]),
                        "mismatch at [{},{},{}]",
                        i,
                        j,
                        k
                    );
                }
            }
        }
    }

    // Rank-12 tensor with every extent equal to 2, viewed through an axis permutation.
    // Each source value packs its coordinates as bits (`coord << axis`). The destination
    // is column-major with all extents 2, so bit k of a flat destination index is the
    // coordinate on axis k. The check decodes those bits, applies the permuted view's
    // strides to find the source offset, and compares against the raw source buffer.
    // Assumes `src.data()` is the dense buffer that those strides index from offset 0.
    #[test]
    fn test_copy_into_grouped_high_rank_binary_permutation() {
        let dims = vec![2usize; 12];
        let total: usize = dims.iter().product();
        let src = StridedArray::<u32>::from_fn_col_major(&dims, |idx| {
            idx.iter()
                .enumerate()
                .map(|(axis, &coord)| (coord as u32) << axis)
                .sum::<u32>()
        });
        let perm = vec![2, 4, 5, 6, 0, 3, 7, 1, 8, 9, 10, 11];
        let src_view = src.view().permute(&perm).unwrap();
        let mut dst = StridedArray::<u32>::col_major(&dims);

        copy_into(&mut dst.view_mut(), &src_view).unwrap();

        for flat in 0..total {
            let mut rem = flat;
            let mut src_offset = 0usize;
            for &stride in src_view.strides() {
                // Peel off the next coordinate (0 or 1) from the low bit.
                let coord = rem & 1;
                rem >>= 1;
                src_offset += coord * stride as usize;
            }
            assert_eq!(dst.data()[flat], src.data()[src_offset], "flat={flat}");
        }
    }

    // try_fuse_group: empty group fuses to extent 1, stride 0.
    #[test]
    fn test_try_fuse_group_empty() {
        assert_eq!(try_fuse_group(&[], &[]), Some((1, 0)));
    }

    // try_fuse_group: a single axis is returned unchanged.
    #[test]
    fn test_try_fuse_group_single() {
        assert_eq!(try_fuse_group(&[5], &[2]), Some((5, 2)));
    }

    // try_fuse_group: dense row-major 3x4 fuses to 12 elements with stride 1.
    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        assert_eq!(try_fuse_group(&[3, 4], &[4, 1]), Some((12, 1)));
    }

    // try_fuse_group: dense column-major 3x4 fuses to 12 elements with stride 1.
    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        assert_eq!(try_fuse_group(&[3, 4], &[1, 3]), Some((12, 1)));
    }

    // try_fuse_group: the gap between rows (stride 8 for 4-wide rows) prevents fusion.
    #[test]
    fn test_try_fuse_group_non_contiguous() {
        assert_eq!(try_fuse_group(&[3, 4], &[8, 1]), None);
    }

    // Equal rank but different extents ([2, 3] vs [3, 2]) must be rejected.
    #[test]
    fn test_copy_shape_mismatch() {
        let src = StridedArray::<f64>::col_major(&[2, 3]);
        let mut dst = StridedArray::<f64>::col_major(&[3, 2]);
        let result = copy_into(&mut dst.view_mut(), &src.view());
        assert!(result.is_err());
    }
}