Skip to main content

strided_perm/
copy.rs

1//! Copy/permutation operations on strided views.
2
3#[cfg(feature = "parallel")]
4use crate::hptt::execute_permute_blocked_par;
5use crate::hptt::{build_permute_plan, execute_permute_blocked};
6use crate::kernel::total_len;
7use strided_view::{Result, StridedError, StridedView, StridedViewMut};
8
/// Report whether destination and source are *both* densely packed in the
/// same order — either both column-major or both row-major.
///
/// "Packed" means the fastest-varying axis has stride 1 and every following
/// axis has a stride equal to the product of the extents before it. Axes of
/// extent 0 or 1 are skipped, so their strides may hold any value. An empty
/// `dims` counts as contiguous.
fn is_both_contiguous(dims: &[usize], dst_strides: &[isize], src_strides: &[isize]) -> bool {
    // Walk `(extent, dst stride, src stride)` triples starting at the axis that
    // should vary fastest, checking both strides against the running product.
    fn packed_from_fastest<I>(axes: I) -> bool
    where
        I: Iterator<Item = (usize, isize, isize)>,
    {
        let mut want = 1isize;
        for (extent, dst, src) in axes {
            if extent <= 1 {
                continue;
            }
            if dst != want || src != want {
                return false;
            }
            // Saturate instead of overflowing on absurdly large extents.
            want = want.saturating_mul(extent as isize);
        }
        true
    }

    if dims.is_empty() {
        return true;
    }

    // Column-major: the first axis varies fastest.
    let first_to_last = dims
        .iter()
        .zip(dst_strides.iter().zip(src_strides.iter()))
        .map(|(&n, (&d, &s))| (n, d, s));
    if packed_from_fastest(first_to_last) {
        return true;
    }

    // Row-major: the last axis varies fastest.
    let last_to_first = dims
        .iter()
        .rev()
        .zip(dst_strides.iter().rev().zip(src_strides.iter().rev()))
        .map(|(&n, (&d, &s))| (n, d, s));
    packed_from_fastest(last_to_first)
}
51
52/// Copy elements from source to destination: `dest[i] = src[i]`.
53///
54/// Uses HPTT-inspired blocked permutation with bilateral dimension fusion,
55/// cache-aware blocking, and optimal loop ordering.
56///
57/// This is a simple copy without ElementOp support. For copies with
58/// element operations (conj, transpose, etc.), use `strided_kernel::copy_into`.
59pub fn copy_into<T: Copy>(dest: &mut StridedViewMut<T>, src: &StridedView<T>) -> Result<()> {
60    let dst_dims = dest.dims();
61    let src_dims = src.dims();
62    if dst_dims.len() != src_dims.len() {
63        return Err(StridedError::RankMismatch(dst_dims.len(), src_dims.len()));
64    }
65    if dst_dims != src_dims {
66        return Err(StridedError::ShapeMismatch(
67            dst_dims.to_vec(),
68            src_dims.to_vec(),
69        ));
70    }
71
72    let dst_ptr = dest.as_mut_ptr();
73    let src_ptr = src.ptr();
74    let dst_strides = dest.strides();
75    let src_strides = src.strides();
76
77    // Fast path: both contiguous
78    if is_both_contiguous(dst_dims, dst_strides, src_strides) {
79        let len = total_len(dst_dims);
80        unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
81        return Ok(());
82    }
83
84    // HPTT-inspired blocked permutation
85    let elem_size = std::mem::size_of::<T>();
86    let plan = build_permute_plan(dst_dims, src_strides, dst_strides, elem_size);
87    unsafe {
88        execute_permute_blocked(src_ptr, dst_ptr, &plan);
89    }
90    Ok(())
91}
92
93/// Copy elements to a col-major destination.
94///
95/// This now delegates to the same HPTT-inspired blocked permutation as
96/// `copy_into`. The HPTT planner automatically handles the col-major
97/// destination case optimally.
98pub fn copy_into_col_major<T: Copy>(
99    dst: &mut StridedViewMut<T>,
100    src: &StridedView<T>,
101) -> Result<()> {
102    copy_into(dst, src)
103}
104
/// Element-wise copy of `src` into `dest` using the Rayon-backed executor.
///
/// Validation and the dense fast path match `copy_into`; a flat memcpy is
/// already bandwidth-limited, so that path is not parallelised. All other
/// layouts are planned with `build_permute_plan` and run through
/// `execute_permute_blocked_par`.
///
/// NOTE(review): the previous documentation states that the executor
/// parallelises the outermost block loop and falls back to single-threaded
/// execution for small tensors; that logic lives in
/// `execute_permute_blocked_par` and is not visible from this function.
///
/// # Errors
/// * `StridedError::RankMismatch` when the two views differ in rank.
/// * `StridedError::ShapeMismatch` when the ranks agree but the dims differ.
#[cfg(feature = "parallel")]
pub fn copy_into_par<T: Copy + Send + Sync>(
    dest: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    let shape = dest.dims();
    let src_shape = src.dims();

    let (dst_rank, src_rank) = (shape.len(), src_shape.len());
    if dst_rank != src_rank {
        return Err(StridedError::RankMismatch(dst_rank, src_rank));
    }
    if shape != src_shape {
        return Err(StridedError::ShapeMismatch(shape.to_vec(), src_shape.to_vec()));
    }

    let out = dest.as_mut_ptr();
    let input = src.ptr();
    let out_strides = dest.strides();
    let in_strides = src.strides();

    if !is_both_contiguous(shape, out_strides, in_strides) {
        let plan = build_permute_plan(shape, in_strides, out_strides, std::mem::size_of::<T>());
        // SAFETY: shapes were verified equal above and the plan is built from
        // these views' own dims and strides.
        // NOTE(review): assumes both views guarantee their pointers are valid
        // for every in-bounds index — confirm against `strided_view`.
        unsafe {
            execute_permute_blocked_par(input, out, &plan);
        }
        return Ok(());
    }

    // Fast path: identical dense layouts, so one flat copy suffices.
    let count = total_len(shape);
    // SAFETY: both views describe `count` densely packed elements.
    // NOTE(review): assumes the two regions do not overlap and that the view
    // pointers are valid for `count` elements — confirm against `strided_view`.
    unsafe { std::ptr::copy_nonoverlapping(input, out, count) };
    Ok(())
}
145
/// Parallel counterpart of `copy_into_col_major`.
///
/// Forwards directly to `copy_into_par`; the result and errors are exactly
/// those of `copy_into_par`.
#[cfg(feature = "parallel")]
pub fn copy_into_col_major_par<T: Copy + Send + Sync>(
    dst: &mut StridedViewMut<T>,
    src: &StridedView<T>,
) -> Result<()> {
    copy_into_par::<T>(dst, src)
}
154
/// Try to collapse a group of dimensions into one `(total_size, stride)` pair.
///
/// A group is fusable when its non-singleton dimensions (extent > 1), taken
/// in order of increasing `|stride|`, form a dense chain: each stride
/// magnitude equals the previous magnitude times the previous extent. The
/// dimensions may be listed in any order, and extents of 0 or 1 are ignored
/// while building the chain.
///
/// # Returns
/// * `Some((1, 0))` for an empty group.
/// * `Some((dims[0], strides[0]))` for a single dimension.
/// * `Some((product of dims, stride with the smallest magnitude))` for a
///   fusable group of two or more dimensions.
/// * `None` when
///   - a single-dimension group has no stride (previously a panic),
///   - `dims` and `strides` differ in length (groups of two or more),
///   - a non-singleton dimension has stride 0,
///   - non-singleton strides have mixed signs: such a group cannot be
///     addressed as `k * stride` for `k` in `0..total_size`, so reporting it
///     as fused would point at the wrong elements,
///   - the strides do not form a dense chain, or
///   - the chain's total extent overflows `usize`.
pub fn try_fuse_group(dims: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    match dims.len() {
        0 => Some((1, 0)),
        // `first()` rather than `strides[0]`: a missing stride is `None`, not a panic.
        1 => strides.first().map(|&s| (dims[0], s)),
        _ => {
            if dims.len() != strides.len() {
                return None;
            }

            // One pass over the non-singleton dimensions: reject zero strides
            // and mixed signs, count them, and remember the first dimension
            // with the smallest stride magnitude as the start of the chain.
            let mut smallest: Option<(usize, usize)> = None; // (index, |stride|)
            let mut negative: Option<bool> = None;
            let mut non_singleton = 0usize;
            for (i, (&d, &s)) in dims.iter().zip(strides.iter()).enumerate() {
                if d <= 1 {
                    continue;
                }
                if s == 0 {
                    return None;
                }
                match negative {
                    None => negative = Some(s < 0),
                    Some(neg) if neg != (s < 0) => return None,
                    Some(_) => {}
                }
                non_singleton += 1;
                let abs = s.unsigned_abs();
                if smallest.map_or(true, |(_, best)| abs < best) {
                    smallest = Some((i, abs));
                }
            }

            // Every extent is 0 or 1: nothing to chain, the product is 0 or 1.
            let Some((base, base_abs)) = smallest else {
                let stride = *strides
                    .iter()
                    .min_by_key(|s| s.unsigned_abs())
                    .unwrap_or(&0);
                return Some((dims.iter().product(), stride));
            };

            // Follow the chain outward: each step needs an unused non-singleton
            // dimension whose |stride| equals the extent covered so far.
            let mut used = vec![false; dims.len()];
            used[base] = true;
            let mut expected_abs = base_abs.checked_mul(dims[base])?;
            for _ in 1..non_singleton {
                let next = (0..dims.len()).find(|&i| {
                    !used[i] && dims[i] > 1 && strides[i].unsigned_abs() == expected_abs
                })?;
                used[next] = true;
                expected_abs = expected_abs.checked_mul(dims[next])?;
            }

            // The product cannot overflow here: it is bounded by `expected_abs`,
            // which was accumulated with checked multiplication.
            Some((dims.iter().product(), strides[base]))
        }
    }
}
218
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::StridedArray;

    // Dense column-major source copied into a dense column-major destination
    // (the flat-copy fast path).
    #[test]
    fn test_copy_into_contiguous() {
        let input =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut output = StridedArray::<f64>::col_major(&[2, 3]);
        copy_into(&mut output.view_mut(), &input.view()).unwrap();

        let first = output.get(&[0, 0]);
        let last = output.get(&[1, 2]);
        assert_eq!(first, 0.0);
        assert_eq!(last, 12.0);
    }

    // Row-major [3,2] source copied into a column-major [3,2] destination.
    #[test]
    fn test_copy_into_transposed() {
        let values: Vec<f64> = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let input = StridedArray::<f64>::from_parts(
            values,
            &[3, 2],
            &[2, 1], // row-major
            0,
        )
        .unwrap();
        let mut output = StridedArray::<f64>::col_major(&[3, 2]);
        copy_into(&mut output.view_mut(), &input.view()).unwrap();

        assert_eq!(output.get(&[0, 0]), 0.0);
        assert_eq!(output.get(&[0, 1]), 1.0);
        assert_eq!(output.get(&[1, 0]), 2.0);
        assert_eq!(output.get(&[2, 1]), 5.0);
    }

    // The column-major entry point on an already column-major source.
    #[test]
    fn test_copy_into_col_major_basic() {
        let input =
            StridedArray::<f64>::from_fn_col_major(&[4, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let mut output = StridedArray::<f64>::col_major(&[4, 3]);
        copy_into_col_major(&mut output.view_mut(), &input.view()).unwrap();

        let first = output.get(&[0, 0]);
        let last = output.get(&[3, 2]);
        assert_eq!(first, 0.0);
        assert_eq!(last, 32.0);
    }

    // Row-major 3-D source into a column-major destination, compared at
    // every index.
    #[test]
    fn test_copy_into_col_major_permuted_src() {
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let src = StridedArray::<f64>::from_parts(
            data,
            &[2, 3, 4],
            &[12, 4, 1], // row-major
            0,
        )
        .unwrap();
        let mut dst = StridedArray::<f64>::col_major(&[2, 3, 4]);
        copy_into_col_major(&mut dst.view_mut(), &src.view()).unwrap();

        let indices = (0..2usize)
            .flat_map(|i| (0..3usize).flat_map(move |j| (0..4usize).map(move |k| (i, j, k))));
        for (i, j, k) in indices {
            assert_eq!(
                dst.get(&[i, j, k]),
                src.get(&[i, j, k]),
                "mismatch at [{},{},{}]",
                i,
                j,
                k
            );
        }
    }

    #[test]
    fn test_try_fuse_group_empty() {
        let fused = try_fuse_group(&[], &[]);
        assert_eq!(fused, Some((1, 0)));
    }

    #[test]
    fn test_try_fuse_group_single() {
        let fused = try_fuse_group(&[5], &[2]);
        assert_eq!(fused, Some((5, 2)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_row_major() {
        let fused = try_fuse_group(&[3, 4], &[4, 1]);
        assert_eq!(fused, Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_contiguous_col_major() {
        let fused = try_fuse_group(&[3, 4], &[1, 3]);
        assert_eq!(fused, Some((12, 1)));
    }

    #[test]
    fn test_try_fuse_group_non_contiguous() {
        let fused = try_fuse_group(&[3, 4], &[8, 1]);
        assert_eq!(fused, None);
    }

    // Same rank, different dims: the copy must be refused.
    #[test]
    fn test_copy_shape_mismatch() {
        let src = StridedArray::<f64>::col_major(&[2, 3]);
        let mut dst = StridedArray::<f64>::col_major(&[3, 2]);
        assert!(copy_into(&mut dst.view_mut(), &src.view()).is_err());
    }
}