Skip to main content

strided_perm/
copy.rs

1//! Copy/permutation operations on strided views.
2
3#[cfg(feature = "parallel")]
4use crate::hptt::execute_permute_blocked_par;
5use crate::hptt::{build_permute_plan, execute_permute_blocked};
6use strided_view::{Result, StridedError, StridedView, StridedViewMut};
7
/// Number of elements addressed by a tensor with extents `dims`.
///
/// This is the product of all extents: `1` for a rank-0 shape and `0` as soon as
/// any extent is zero.
fn total_len(dims: &[usize]) -> usize {
    dims.iter().fold(1, |acc, &extent| acc * extent)
}
11
/// Report whether `dst_strides` and `src_strides` both describe the *same* dense
/// layout over `dims`: either both packed column-major or both packed row-major.
///
/// Axes with extent 0 or 1 are skipped because their stride never changes which
/// element is addressed. A rank-0 shape is trivially dense. When this returns
/// `true`, the `k`-th element in packed order sits at linear offset `k` in both
/// views, so a copy between them can be a single `memcpy`.
fn is_both_contiguous(dims: &[usize], dst_strides: &[isize], src_strides: &[isize]) -> bool {
    /// Walk axes from fastest- to slowest-varying. Every non-trivial axis must
    /// carry exactly the stride a packed layout would assign it, in both views.
    fn packed(axes: impl Iterator<Item = ((usize, isize), isize)>) -> bool {
        let mut want = 1isize;
        for ((extent, dst), src) in axes {
            if extent <= 1 {
                continue;
            }
            if dst != want || src != want {
                return false;
            }
            want = want.saturating_mul(extent as isize);
        }
        true
    }

    let col_major = dims
        .iter()
        .copied()
        .zip(dst_strides.iter().copied())
        .zip(src_strides.iter().copied());
    let row_major = dims
        .iter()
        .rev()
        .copied()
        .zip(dst_strides.iter().rev().copied())
        .zip(src_strides.iter().rev().copied());

    packed(col_major) || packed(row_major)
}
54
55/// Copy elements from source to destination: `dest[i] = src[i]`.
56///
57/// Uses HPTT-inspired blocked permutation with bilateral dimension fusion,
58/// cache-aware blocking, and optimal loop ordering.
59///
60/// This is a simple copy without ElementOp support. For copies with
61/// element operations (conj, transpose, etc.), use `strided_kernel::copy_into`.
62pub fn copy_into<T: Copy>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()> {
63    let dst_dims = dest.dims();
64    let src_dims = src.dims();
65    if dst_dims.len() != src_dims.len() {
66        return Err(StridedError::RankMismatch(dst_dims.len(), src_dims.len()));
67    }
68    if dst_dims != src_dims {
69        return Err(StridedError::ShapeMismatch(
70            dst_dims.to_vec(),
71            src_dims.to_vec(),
72        ));
73    }
74
75    let dst_ptr = dest.as_mut_ptr();
76    let src_ptr = src.ptr();
77    let dst_strides = dest.strides();
78    let src_strides = src.strides();
79
80    // Fast path: both contiguous
81    if is_both_contiguous(dst_dims, dst_strides, src_strides) {
82        let len = total_len(dst_dims);
83        unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
84        return Ok(());
85    }
86
87    // HPTT-inspired blocked permutation
88    let elem_size = std::mem::size_of::<T>();
89    let plan = build_permute_plan(dst_dims, src_strides, dst_strides, elem_size);
90    unsafe {
91        execute_permute_blocked(src_ptr, dst_ptr, &plan);
92    }
93    Ok(())
94}
95
96/// Copy elements to a col-major destination.
97///
98/// This now delegates to the same HPTT-inspired blocked permutation as
99/// `copy_into`. The HPTT planner automatically handles the col-major
100/// destination case optimally.
101pub fn copy_into_col_major<T: Copy>(
102    dst: &mut StridedViewMut<T>,
103    src: &StridedView<T>,
104) -> Result<()> {
105    copy_into(dst, src)
106}
107
/// Copy elements from source to destination with Rayon parallelism.
///
/// Parallelizes the outermost block loop of the HPTT-inspired permutation.
///
/// When both views share the same dense column- or row-major layout, a single
/// serial `memcpy` is used instead, since such a copy is already bandwidth-bound.
/// Shapes with zero elements return immediately after validation. Any
/// single-threaded fallback for small tensors is decided inside
/// `execute_permute_blocked_par` (NOTE(review): not visible in this file —
/// confirm).
///
/// # Errors
///
/// - [`StridedError::RankMismatch`] if `dest` and `src` have different ranks.
/// - [`StridedError::ShapeMismatch`] if the ranks agree but the extents differ.
#[cfg(feature = "parallel")]
pub fn copy_into_par<T: Copy + Send + Sync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let dst_dims = dest.dims();
    let src_dims = src.dims();
    if dst_dims.len() != src_dims.len() {
        return Err(StridedError::RankMismatch(dst_dims.len(), src_dims.len()));
    }
    if dst_dims != src_dims {
        return Err(StridedError::ShapeMismatch(
            dst_dims.to_vec(),
            src_dims.to_vec(),
        ));
    }

    // A zero-extent tensor holds no elements. Return before building a
    // permutation plan (and spinning up parallel work) for a degenerate shape.
    let len = total_len(dst_dims);
    if len == 0 {
        return Ok(());
    }

    let dst_ptr = dest.as_mut_ptr();
    let src_ptr = src.ptr();
    let dst_strides = dest.strides();
    let src_strides = src.strides();

    // Fast path: both contiguous (memcpy is already bandwidth-limited).
    if is_both_contiguous(dst_dims, dst_strides, src_strides) {
        // SAFETY: `is_both_contiguous` verified that each view addresses exactly
        // `len` consecutive elements starting at its pointer, in the same order.
        // `dest` is uniquely borrowed while `src` is shared, so the ranges cannot
        // overlap for soundly constructed views.
        unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
        return Ok(());
    }

    let elem_size = std::mem::size_of::<T>();
    let plan = build_permute_plan(dst_dims, src_strides, dst_strides, elem_size);
    // SAFETY: the shapes were validated as equal above, and the plan is built
    // from those shared dims plus each view's own strides. This relies on
    // `execute_permute_blocked_par` only touching offsets reachable through those
    // dims/strides, with disjoint destination blocks per worker
    // (NOTE(review): invariant lives in the hptt module).
    unsafe {
        execute_permute_blocked_par(src_ptr, dst_ptr, &plan);
    }
    Ok(())
}
148
/// Parallel counterpart of [`copy_into_col_major`].
///
/// This forwards unchanged to [`copy_into_par`]. It exists so col-major call
/// sites can state their intent, and it does not inspect the destination layout
/// itself.
///
/// # Errors
///
/// Same as [`copy_into_par`]: rank or shape mismatch between `dst` and `src`.
#[cfg(feature = "parallel")]
pub fn copy_into_col_major_par<T: Copy + Send + Sync>(
    dst: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    copy_into_par::<T>(dst, src)
}
157
/// Try to fuse a dimension group into a single `(total_size, innermost_stride)` pair.
///
/// A group is fusable when its non-singleton dimensions can be chained, starting
/// from the one with the smallest `|stride|`, so that each next `|stride|` equals
/// the previous `|stride|` times the previous extent. Together they then walk one
/// arithmetic progression of offsets from the base pointer. Dimensions of extent
/// 0 or 1 are ignored because their stride never affects addressing.
///
/// If every dimension has extent 0 or 1, the result is the product of the
/// extents paired with the stride of smallest magnitude (the first such stride on
/// ties).
///
/// Returns `None` when:
/// - `dims` and `strides` have different lengths;
/// - a dimension with extent > 1 has stride 0 (a broadcast axis);
/// - the non-singleton strides do not all share one sign. Mixed signs, e.g.
///   strides `[1, -3]` over dims `[3, 4]`, cover a contiguous range that starts
///   *below* the base pointer, which a `(size, stride)` pair anchored at the base
///   cannot describe;
/// - the strides do not chain contiguously.
#[cfg(test)]
fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    // Reject mismatched lengths up front. Otherwise `strides[0]` below would
    // panic for a rank-1 group with no strides.
    if dims.len() != strides.len() {
        return None;
    }
    match dims.len() {
        0 => Some((1, 0)),
        1 => Some((dims[0], strides[0])),
        _ => {
            // A broadcast axis revisits the same memory and can never be folded
            // into a single strided run.
            if dims.iter().zip(strides).any(|(&d, &s)| d > 1 && s == 0) {
                return None;
            }

            // The innermost axis of the fused run is the non-singleton axis with
            // the smallest |stride| (first one on ties).
            let mut base_idx: Option<usize> = None;
            let mut base_abs = usize::MAX;
            for (i, (&d, &s)) in dims.iter().zip(strides).enumerate() {
                if d <= 1 {
                    continue;
                }
                let abs = s.unsigned_abs();
                if abs < base_abs {
                    base_abs = abs;
                    base_idx = Some(i);
                }
            }

            let Some(base) = base_idx else {
                // Every axis has extent 0 or 1: any stride addresses the same
                // single (or no) element.
                let stride = *strides
                    .iter()
                    .min_by_key(|s| s.unsigned_abs())
                    .unwrap_or(&0);
                return Some((dims.iter().product(), stride));
            };

            // All participating strides must point the same way as the base;
            // otherwise the offsets do not start at the base pointer.
            let negative = strides[base] < 0;
            if dims
                .iter()
                .zip(strides)
                .any(|(&d, &s)| d > 1 && (s < 0) != negative)
            {
                return None;
            }

            let mut used = vec![false; dims.len()];
            used[base] = true;
            let mut expected_abs = base_abs.checked_mul(dims[base])?;

            let non_singleton = dims.iter().filter(|&&d| d > 1).count();
            for _ in 1..non_singleton {
                let next = (0..dims.len()).find(|&i| {
                    !used[i] && dims[i] > 1 && strides[i].unsigned_abs() == expected_abs
                })?;
                used[next] = true;
                expected_abs = expected_abs.checked_mul(dims[next])?;
            }

            Some((dims.iter().product(), strides[base]))
        }
    }
}
222
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::StridedArray;

    #[test]
    fn test_copy_into_contiguous() {
        let src =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut dst = StridedArray::<f64>::col_major(&[2, 3]);
        copy_into(&mut dst.view_mut(), &src.view()).unwrap();
        assert_eq!(dst.get(&[0, 0]), 0.0);
        assert_eq!(dst.get(&[1, 2]), 12.0);
    }

    #[test]
    fn test_copy_into_both_row_major() {
        // Both views row-major: exercises the row-major branch of the memcpy fast path.
        let src = StridedArray::<f64>::from_parts(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            &[2, 3],
            &[3, 1],
            0,
        )
        .unwrap();
        let mut dst =
            StridedArray::<f64>::from_parts(vec![-1.0; 6], &[2, 3], &[3, 1], 0).unwrap();
        copy_into(&mut dst.view_mut(), &src.view()).unwrap();
        for i in 0..2 {
            for j in 0..3 {
                assert_eq!(dst.get(&[i, j]), (3 * i + j) as f64, "mismatch at [{i},{j}]");
            }
        }
    }

    #[test]
    fn test_copy_into_transposed() {
        // src is row-major [3,2], dst is col-major [3,2]
        let src = StridedArray::<f64>::from_parts(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            &[3, 2],
            &[2, 1], // row-major
            0,
        )
        .unwrap();
        let mut dst = StridedArray::<f64>::col_major(&[3, 2]);
        copy_into(&mut dst.view_mut(), &src.view()).unwrap();
        assert_eq!(dst.get(&[0, 0]), 0.0);
        assert_eq!(dst.get(&[0, 1]), 1.0);
        assert_eq!(dst.get(&[1, 0]), 2.0);
        assert_eq!(dst.get(&[2, 1]), 5.0);
    }

    #[test]
    fn test_copy_into_col_major_basic() {
        let src =
            StridedArray::<f64>::from_fn_col_major(&[4, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut dst = StridedArray::<f64>::col_major(&[4, 3]);
        copy_into_col_major(&mut dst.view_mut(), &src.view()).unwrap();
        assert_eq!(dst.get(&[0, 0]), 0.0);
        assert_eq!(dst.get(&[3, 2]), 32.0);
    }

    #[test]
    fn test_copy_into_col_major_permuted_src() {
        // src is row-major, dst is col-major, so the blocked permutation path is used.
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let src = StridedArray::<f64>::from_parts(
            data,
            &[2, 3, 4],
            &[12, 4, 1], // row-major
            0,
        )
        .unwrap();
        let mut dst = StridedArray::<f64>::col_major(&[2, 3, 4]);
        copy_into_col_major(&mut dst.view_mut(), &src.view()).unwrap();
        // Verify element-by-element
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(
                        dst.get(&[i, j, k]),
                        src.get(&[i, j, k]),
                        "mismatch at [{},{},{}]",
                        i,
                        j,
                        k
                    );
                }
            }
        }
    }

    #[test]
    fn test_copy_into_grouped_high_rank_binary_permutation() {
        let dims = vec![2usize; 12];
        let total: usize = dims.iter().product();
        let src = StridedArray::<u32>::from_fn_col_major(&dims, |idx| {
            idx.iter()
                .enumerate()
                .map(|(axis, &coord)| (coord as u32) << axis)
                .sum::<u32>()
        });
        let perm = vec![2, 4, 5, 6, 0, 3, 7, 1, 8, 9, 10, 11];
        let src_view = src.view().permute(&perm).unwrap();
        let mut dst = StridedArray::<u32>::col_major(&dims);

        copy_into(&mut dst.view_mut(), &src_view).unwrap();

        for flat in 0..total {
            let mut rem = flat;
            let mut src_offset = 0usize;
            for &stride in src_view.strides() {
                let coord = rem & 1;
                rem >>= 1;
                src_offset += coord * stride as usize;
            }
            assert_eq!(dst.data()[flat], src.data()[src_offset], "flat={flat}");
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_copy_into_par_matches_serial() {
        let dims = vec![2usize; 12];
        let src = StridedArray::<u32>::from_fn_col_major(&dims, |idx| {
            idx.iter()
                .enumerate()
                .map(|(axis, &coord)| (coord as u32) << axis)
                .sum::<u32>()
        });
        let perm = vec![2, 4, 5, 6, 0, 3, 7, 1, 8, 9, 10, 11];
        let src_view = src.view().permute(&perm).unwrap();
        let mut serial = StridedArray::<u32>::col_major(&dims);
        let mut parallel = StridedArray::<u32>::col_major(&dims);

        copy_into(&mut serial.view_mut(), &src_view).unwrap();
        copy_into_par(&mut parallel.view_mut(), &src_view).unwrap();

        assert_eq!(serial.data(), parallel.data());
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_copy_into_col_major_par_permuted_src() {
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let src = StridedArray::<f64>::from_parts(data, &[2, 3, 4], &[12, 4, 1], 0).unwrap();
        let mut dst = StridedArray::<f64>::col_major(&[2, 3, 4]);
        copy_into_col_major_par(&mut dst.view_mut(), &src.view()).unwrap();
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(dst.get(&[i, j, k]), src.get(&[i, j, k]), "at [{i},{j},{k}]");
                }
            }
        }
    }

    #[test]
    fn test_is_both_contiguous() {
        assert!(is_both_contiguous(&[], &[], &[]));
        assert!(is_both_contiguous(&[2, 3], &[1, 2], &[1, 2]));
        assert!(is_both_contiguous(&[2, 3], &[3, 1], &[3, 1]));
        assert!(!is_both_contiguous(&[2, 3], &[1, 2], &[3, 1]));
        assert!(is_both_contiguous(&[2, 1, 3], &[1, 99, 2], &[1, 7, 2]));
        assert!(!is_both_contiguous(&[2, 3], &[-1, 2], &[-1, 2]));
    }

    #[test]
    fn test_try_fuse_group_empty() {
        assert_eq!(try_fuse_group(&[], &[]), Some((1, 0)));
    }

    #[test]
    fn test_try_fuse_group_single() {
        assert_eq!(try_fuse_group(&[5], &[2]), Some((5, 2)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        assert_eq!(try_fuse_group(&[3, 4], &[4, 1]), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        assert_eq!(try_fuse_group(&[3, 4], &[1, 3]), Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_all_negative() {
        // Uniformly negative strides still walk one progression from the base pointer.
        assert_eq!(try_fuse_group(&[3, 4], &[-1, -3]), Some((12, -1)));
    }

    #[test]
    fn test_try_fuse_group_non_contiguous() {
        assert_eq!(try_fuse_group(&[3, 4], &[8, 1]), None);
    }

    #[test]
    fn test_copy_shape_mismatch() {
        let src = StridedArray::<f64>::col_major(&[2, 3]);
        let mut dst = StridedArray::<f64>::col_major(&[3, 2]);
        let result = copy_into(&mut dst.view_mut(), &src.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_copy_rank_mismatch() {
        let src = StridedArray::<f64>::col_major(&[2, 3]);
        let mut dst = StridedArray::<f64>::col_major(&[6]);
        let result = copy_into(&mut dst.view_mut(), &src.view());
        assert!(matches!(result, Err(StridedError::RankMismatch(1, 2))));
    }
}