tenferro_prims/cpu/sort.rs

use tenferro_algebra::{Scalar, Standard};
use tenferro_device::{Error, Result};
use tenferro_tensor::{MemoryOrder, Tensor};

use crate::{validate_shape_count, CpuBackend, CpuContext, SortPrimsDescriptor, TensorSortPrims};

/// CPU execution plan for the sort protocol family.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::CpuSortPlan;
/// let _ = std::mem::size_of::<CpuSortPlan>();
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CpuSortPlan {
    /// Sort elements along an axis.
    Sort {
        /// Axis along which to sort.
        axis: usize,
        /// If true, sort in descending order.
        descending: bool,
        /// If true, use a stable sort.
        stable: bool,
    },
    /// Compute the sort permutation along an axis.
    Argsort {
        /// Axis along which to compute the permutation.
        axis: usize,
        /// If true, sort in descending order.
        descending: bool,
        /// If true, use a stable sort.
        stable: bool,
    },
    /// Select top-k elements along an axis.
    Topk {
        /// Axis along which to select top-k.
        axis: usize,
        /// Number of elements to select.
        k: usize,
        /// If true, select the k largest.
        largest: bool,
        /// If true, the returned k elements are sorted.
        sorted: bool,
    },
}

/// Sort one axis-slice, returning (sorted_values, original_indices).
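///
/// For example, sorting the slice `[3.0, 1.0, 2.0]` ascending yields values
/// `[1.0, 2.0, 3.0]` and original indices `[1, 2, 0]`; `stable` only affects
/// the relative index order of equal values.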
fn sort_slice<T: Copy + PartialOrd>(
    slice: &[T],
    descending: bool,
    stable: bool,
) -> (Vec<T>, Vec<i64>) {
    let mut indexed: Vec<(T, i64)> = slice.iter().copied().zip(0i64..).collect();
    let cmp = |a: &(T, i64), b: &(T, i64)| {
        let ord = a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal);
        if descending {
            ord.reverse()
        } else {
            ord
        }
    };
    if stable {
        indexed.sort_by(cmp);
    } else {
        indexed.sort_unstable_by(cmp);
    }
    let values: Vec<T> = indexed.iter().map(|(v, _)| *v).collect();
    let indices: Vec<i64> = indexed.iter().map(|(_, i)| *i).collect();
    (values, indices)
}

/// Execute sort or argsort along the given axis.
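///
/// When `write_values` is false (the argsort case), only `indices_out` is
/// populated and `values_out` is left untouched.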
fn execute_sort_along_axis<T: Scalar + PartialOrd>(
    input: &Tensor<T>,
    values_out: &mut Tensor<T>,
    indices_out: &mut Tensor<i64>,
    axis: usize,
    descending: bool,
    stable: bool,
    write_values: bool,
) -> Result<()> {
    let src = input.contiguous(MemoryOrder::ColumnMajor);
    let src_data = src
        .buffer()
        .as_slice()
        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;

    let shape = src.dims();
    let ndim = shape.len();
    if axis >= ndim {
        return Err(Error::InvalidArgument(format!(
            "sort axis {axis} >= ndim {ndim}"
        )));
    }

    let axis_size = shape[axis];
    let pre_size: usize = shape[..axis].iter().product();
    let post_size: usize = shape[axis + 1..].iter().product();

    let total: usize = shape.iter().product();
    let mut val_data = vec![T::zero(); total];
    let mut idx_data = vec![0i64; total];

    // Temporary buffer for one axis slice
    let mut slice_buf = vec![T::zero(); axis_size];

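    // In column-major layout, the axis dimension has stride `pre_size` and the
    // trailing (post) dimensions together have stride `pre_size * axis_size`,
    // so element `a` of the slice at (pre, post) lives at flat index
    // `pre + a * pre_size + post * pre_size * axis_size`.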
    for post in 0..post_size {
        for pre in 0..pre_size {
            // Extract axis slice from column-major data
            for (a, slot) in slice_buf.iter_mut().enumerate().take(axis_size) {
                let flat = pre + a * pre_size + post * pre_size * axis_size;
                *slot = src_data[flat];
            }

            let (sorted_vals, sorted_idx) = sort_slice(&slice_buf, descending, stable);

            // Write back
            for a in 0..axis_size {
                let flat = pre + a * pre_size + post * pre_size * axis_size;
                if write_values {
                    val_data[flat] = sorted_vals[a];
                }
                idx_data[flat] = sorted_idx[a];
            }
        }
    }

    if write_values {
        *values_out = Tensor::from_slice(&val_data, shape, MemoryOrder::ColumnMajor)
            .map_err(|e| Error::InvalidArgument(format!("sort values output: {e}")))?;
    }
    *indices_out = Tensor::from_slice(&idx_data, shape, MemoryOrder::ColumnMajor)
        .map_err(|e| Error::InvalidArgument(format!("sort indices output: {e}")))?;

    Ok(())
}

/// Execute topk along the given axis.
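///
/// The outputs have the input shape with the extent of `axis` replaced by `k`.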
fn execute_topk<T: Scalar + PartialOrd>(
    input: &Tensor<T>,
    values_out: &mut Tensor<T>,
    indices_out: &mut Tensor<i64>,
    axis: usize,
    k: usize,
    largest: bool,
    sorted: bool,
) -> Result<()> {
    let src = input.contiguous(MemoryOrder::ColumnMajor);
    let src_data = src
        .buffer()
        .as_slice()
        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;

    let shape = src.dims();
    let ndim = shape.len();
    if axis >= ndim {
        return Err(Error::InvalidArgument(format!(
            "topk axis {axis} >= ndim {ndim}"
        )));
    }

    let axis_size = shape[axis];
    if k > axis_size {
        return Err(Error::InvalidArgument(format!(
            "topk k={k} > axis size {axis_size}"
        )));
    }

    let pre_size: usize = shape[..axis].iter().product();
    let post_size: usize = shape[axis + 1..].iter().product();

    // Output shape: same as input but with axis_size replaced by k
    let mut out_shape = shape.to_vec();
    out_shape[axis] = k;
    let out_total: usize = out_shape.iter().product();

    let mut val_data = vec![T::zero(); out_total];
    let mut idx_data = vec![0i64; out_total];

    let mut slice_buf = vec![T::zero(); axis_size];

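    // Each axis slice is fully sorted (descending when `largest`), so its
    // first k elements are the requested top-k and they come back in sorted
    // order; the `sorted` flag is forwarded as the stable-sort switch. The
    // output index math mirrors the input's, with `k` as the axis extent.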
    for post in 0..post_size {
        for pre in 0..pre_size {
            // Extract axis slice
            for (a, slot) in slice_buf.iter_mut().enumerate().take(axis_size) {
                let flat = pre + a * pre_size + post * pre_size * axis_size;
                *slot = src_data[flat];
            }

            // Sort descending if largest, ascending if smallest
            let (sorted_vals, sorted_idx) = sort_slice(&slice_buf, largest, sorted);

            // Take first k
            for a in 0..k {
                let out_flat = pre + a * pre_size + post * pre_size * k;
                val_data[out_flat] = sorted_vals[a];
                idx_data[out_flat] = sorted_idx[a];
            }
        }
    }

    *values_out = Tensor::from_slice(&val_data, &out_shape, MemoryOrder::ColumnMajor)
        .map_err(|e| Error::InvalidArgument(format!("topk values output: {e}")))?;
    *indices_out = Tensor::from_slice(&idx_data, &out_shape, MemoryOrder::ColumnMajor)
        .map_err(|e| Error::InvalidArgument(format!("topk indices output: {e}")))?;

    Ok(())
}

impl<S: Scalar + PartialOrd + 'static> TensorSortPrims<Standard<S>> for CpuBackend {
    type Plan = CpuSortPlan;
    type Context = CpuContext;

    fn plan(
        _ctx: &mut Self::Context,
        desc: &SortPrimsDescriptor,
        shapes: &[&[usize]],
    ) -> Result<Self::Plan> {
        validate_shape_count(shapes, 1, "SortPrims")?;
        match desc {
            SortPrimsDescriptor::Sort {
                axis,
                descending,
                stable,
            } => Ok(CpuSortPlan::Sort {
                axis: *axis,
                descending: *descending,
                stable: *stable,
            }),
            SortPrimsDescriptor::Argsort {
                axis,
                descending,
                stable,
            } => Ok(CpuSortPlan::Argsort {
                axis: *axis,
                descending: *descending,
                stable: *stable,
            }),
            SortPrimsDescriptor::Topk {
                axis,
                k,
                largest,
                sorted,
            } => Ok(CpuSortPlan::Topk {
                axis: *axis,
                k: *k,
                largest: *largest,
                sorted: *sorted,
            }),
        }
    }

    fn execute(
        _ctx: &mut Self::Context,
        plan: &Self::Plan,
        input: &Tensor<S>,
        values_out: &mut Tensor<S>,
        indices_out: &mut Tensor<i64>,
    ) -> Result<()> {
        match plan {
            CpuSortPlan::Sort {
                axis,
                descending,
                stable,
            } => execute_sort_along_axis(
                input,
                values_out,
                indices_out,
                *axis,
                *descending,
                *stable,
                true,
            ),
            CpuSortPlan::Argsort {
                axis,
                descending,
                stable,
            } => execute_sort_along_axis(
                input,
                values_out,
                indices_out,
                *axis,
                *descending,
                *stable,
                false,
            ),
            CpuSortPlan::Topk {
                axis,
                k,
                largest,
                sorted,
            } => execute_topk(input, values_out, indices_out, *axis, *k, *largest, *sorted),
        }
    }

    fn has_sort_support(_desc: &SortPrimsDescriptor) -> bool {
        true
    }
}
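
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity checks for the private `sort_slice` helper (a sketch;
    // the full `TensorSortPrims` path is not exercised here because building
    // a `CpuContext` and input `Tensor`s is outside the scope of this file).
    #[test]
    fn sort_slice_ascending_and_descending() {
        let data = [3.0f64, 1.0, 2.0];

        // Ascending: values sorted, indices give the original positions.
        let (vals, idx) = sort_slice(&data, false, true);
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
        assert_eq!(idx, vec![1_i64, 2, 0]);

        // Descending reverses the value order.
        let (vals, idx) = sort_slice(&data, true, true);
        assert_eq!(vals, vec![3.0, 2.0, 1.0]);
        assert_eq!(idx, vec![0_i64, 2, 1]);
    }

    #[test]
    fn sort_slice_stable_keeps_tied_index_order() {
        // With a stable sort, equal values keep their original relative order.
        let data = [1.0f64, 1.0, 0.5];
        let (vals, idx) = sort_slice(&data, false, true);
        assert_eq!(vals, vec![0.5, 1.0, 1.0]);
        assert_eq!(idx, vec![2_i64, 0, 1]);
    }
}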