1use tenferro_algebra::{Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::{MemoryOrder, Tensor};
4
5use crate::{validate_shape_count, CpuBackend, CpuContext, SortPrimsDescriptor, TensorSortPrims};
6
/// Execution plan for CPU sort primitives.
///
/// Produced by [`TensorSortPrims::plan`] from a `SortPrimsDescriptor`; each
/// variant is a field-for-field copy of the corresponding descriptor variant,
/// fixing the operation and its parameters at planning time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CpuSortPlan {
    /// Full sort of values (with source indices) along one axis.
    Sort {
        /// Axis to sort along.
        axis: usize,
        /// Sort in descending order when `true`.
        descending: bool,
        /// Use a stable sort (ties keep their original relative order) when `true`.
        stable: bool,
    },
    /// Index-only sort: only the permutation indices are produced.
    Argsort {
        /// Axis to sort along.
        axis: usize,
        /// Sort in descending order when `true`.
        descending: bool,
        /// Use a stable sort (ties keep their original relative order) when `true`.
        stable: bool,
    },
    /// Select the top `k` elements along one axis.
    Topk {
        /// Axis to select along.
        axis: usize,
        /// Number of elements to keep per lane.
        k: usize,
        /// Take the largest elements when `true`, smallest when `false`.
        largest: bool,
        /// Whether the caller requires the selected elements in sorted order.
        sorted: bool,
    },
}
47
/// Sorts a single 1-D lane, returning the sorted values together with the
/// original position of each element (as `i64`).
///
/// Incomparable values (e.g. NaN for floats) compare as equal, so the sort
/// never panics; their resulting order is unspecified.
fn sort_slice<T: Copy + PartialOrd>(
    slice: &[T],
    descending: bool,
    stable: bool,
) -> (Vec<T>, Vec<i64>) {
    // Pair every element with its original index so the permutation survives
    // the sort.
    let mut pairs: Vec<(T, i64)> = slice
        .iter()
        .enumerate()
        .map(|(i, &v)| (v, i as i64))
        .collect();

    let compare = move |lhs: &(T, i64), rhs: &(T, i64)| {
        let order = lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal);
        if descending { order.reverse() } else { order }
    };

    if stable {
        pairs.sort_by(compare);
    } else {
        pairs.sort_unstable_by(compare);
    }

    // Split the (value, index) pairs back into two parallel vectors.
    pairs.into_iter().unzip()
}
72
73fn execute_sort_along_axis<T: Scalar + PartialOrd>(
75 input: &Tensor<T>,
76 values_out: &mut Tensor<T>,
77 indices_out: &mut Tensor<i64>,
78 axis: usize,
79 descending: bool,
80 stable: bool,
81 write_values: bool,
82) -> Result<()> {
83 let src = input.contiguous(MemoryOrder::ColumnMajor);
84 let src_data = src
85 .buffer()
86 .as_slice()
87 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
88
89 let shape = src.dims();
90 let ndim = shape.len();
91 if axis >= ndim {
92 return Err(Error::InvalidArgument(format!(
93 "sort axis {axis} >= ndim {ndim}"
94 )));
95 }
96
97 let axis_size = shape[axis];
98 let pre_size: usize = shape[..axis].iter().product();
99 let post_size: usize = shape[axis + 1..].iter().product();
100
101 let total: usize = shape.iter().product();
102 let mut val_data = vec![T::zero(); total];
103 let mut idx_data = vec![0i64; total];
104
105 let mut slice_buf = vec![T::zero(); axis_size];
107
108 for post in 0..post_size {
109 for pre in 0..pre_size {
110 for (a, slot) in slice_buf.iter_mut().enumerate().take(axis_size) {
112 let flat = pre + a * pre_size + post * pre_size * axis_size;
113 *slot = src_data[flat];
114 }
115
116 let (sorted_vals, sorted_idx) = sort_slice(&slice_buf, descending, stable);
117
118 for a in 0..axis_size {
120 let flat = pre + a * pre_size + post * pre_size * axis_size;
121 if write_values {
122 val_data[flat] = sorted_vals[a];
123 }
124 idx_data[flat] = sorted_idx[a];
125 }
126 }
127 }
128
129 if write_values {
130 *values_out = Tensor::from_slice(&val_data, shape, MemoryOrder::ColumnMajor)
131 .map_err(|e| Error::InvalidArgument(format!("sort values output: {e}")))?;
132 }
133 *indices_out = Tensor::from_slice(&idx_data, shape, MemoryOrder::ColumnMajor)
134 .map_err(|e| Error::InvalidArgument(format!("sort indices output: {e}")))?;
135
136 Ok(())
137}
138
/// Selects the top `k` elements of `input` along `axis`, writing them to
/// `values_out` and their source positions to `indices_out`.
///
/// The input is made column-major contiguous, each 1-D lane along `axis` is
/// fully sorted (descending when `largest`), and the first `k` entries of the
/// sorted lane are kept. The output shape equals the input shape with the
/// `axis` dimension replaced by `k`.
///
/// # Errors
/// * `InvalidArgument` if `axis >= ndim`, `k > axis_size`, or output
///   construction fails.
/// * `DeviceError` if the tensor buffer is not host-accessible.
fn execute_topk<T: Scalar + PartialOrd>(
    input: &Tensor<T>,
    values_out: &mut Tensor<T>,
    indices_out: &mut Tensor<i64>,
    axis: usize,
    k: usize,
    largest: bool,
    sorted: bool,
) -> Result<()> {
    let src = input.contiguous(MemoryOrder::ColumnMajor);
    let src_data = src
        .buffer()
        .as_slice()
        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;

    let shape = src.dims();
    let ndim = shape.len();
    if axis >= ndim {
        return Err(Error::InvalidArgument(format!(
            "topk axis {axis} >= ndim {ndim}"
        )));
    }

    let axis_size = shape[axis];
    // Cannot select more elements than the axis holds.
    if k > axis_size {
        return Err(Error::InvalidArgument(format!(
            "topk k={k} > axis size {axis_size}"
        )));
    }

    // Flattened sizes of the dimensions before and after `axis`; with
    // column-major layout, `pre_size` is also the stride of `axis`.
    let pre_size: usize = shape[..axis].iter().product();
    let post_size: usize = shape[axis + 1..].iter().product();

    // Output shape: same as input, but the selected axis shrinks to `k`.
    let mut out_shape = shape.to_vec();
    out_shape[axis] = k;
    let out_total: usize = out_shape.iter().product();

    let mut val_data = vec![T::zero(); out_total];
    let mut idx_data = vec![0i64; out_total];

    // Scratch buffer, reused for every 1-D lane along `axis`.
    let mut slice_buf = vec![T::zero(); axis_size];

    for post in 0..post_size {
        for pre in 0..pre_size {
            // Gather the lane (pre, :, post) into the scratch buffer.
            for (a, slot) in slice_buf.iter_mut().enumerate().take(axis_size) {
                let flat = pre + a * pre_size + post * pre_size * axis_size;
                *slot = src_data[flat];
            }

            // NOTE(review): `sorted` is passed as sort_slice's *stability*
            // flag, not as "return sorted output" — the kept prefix comes
            // back sorted either way, since the whole lane is sorted. With
            // `sorted == false` an unstable sort may reorder tie indices.
            // Presumably intentional (unsorted topk permits any tie order)
            // — confirm against the SortPrimsDescriptor contract.
            let (sorted_vals, sorted_idx) = sort_slice(&slice_buf, largest, sorted);

            // Keep only the first k entries of the sorted lane; the output
            // axis stride uses `k` in place of `axis_size`.
            for a in 0..k {
                let out_flat = pre + a * pre_size + post * pre_size * k;
                val_data[out_flat] = sorted_vals[a];
                idx_data[out_flat] = sorted_idx[a];
            }
        }
    }

    *values_out = Tensor::from_slice(&val_data, &out_shape, MemoryOrder::ColumnMajor)
        .map_err(|e| Error::InvalidArgument(format!("topk values output: {e}")))?;
    *indices_out = Tensor::from_slice(&idx_data, &out_shape, MemoryOrder::ColumnMajor)
        .map_err(|e| Error::InvalidArgument(format!("topk indices output: {e}")))?;

    Ok(())
}
210
211impl<S: Scalar + PartialOrd + 'static> TensorSortPrims<Standard<S>> for CpuBackend {
212 type Plan = CpuSortPlan;
213 type Context = CpuContext;
214
215 fn plan(
216 _ctx: &mut Self::Context,
217 desc: &SortPrimsDescriptor,
218 shapes: &[&[usize]],
219 ) -> Result<Self::Plan> {
220 validate_shape_count(shapes, 1, "SortPrims")?;
221 match desc {
222 SortPrimsDescriptor::Sort {
223 axis,
224 descending,
225 stable,
226 } => Ok(CpuSortPlan::Sort {
227 axis: *axis,
228 descending: *descending,
229 stable: *stable,
230 }),
231 SortPrimsDescriptor::Argsort {
232 axis,
233 descending,
234 stable,
235 } => Ok(CpuSortPlan::Argsort {
236 axis: *axis,
237 descending: *descending,
238 stable: *stable,
239 }),
240 SortPrimsDescriptor::Topk {
241 axis,
242 k,
243 largest,
244 sorted,
245 } => Ok(CpuSortPlan::Topk {
246 axis: *axis,
247 k: *k,
248 largest: *largest,
249 sorted: *sorted,
250 }),
251 }
252 }
253
254 fn execute(
255 _ctx: &mut Self::Context,
256 plan: &Self::Plan,
257 input: &Tensor<S>,
258 values_out: &mut Tensor<S>,
259 indices_out: &mut Tensor<i64>,
260 ) -> Result<()> {
261 match plan {
262 CpuSortPlan::Sort {
263 axis,
264 descending,
265 stable,
266 } => execute_sort_along_axis(
267 input,
268 values_out,
269 indices_out,
270 *axis,
271 *descending,
272 *stable,
273 true,
274 ),
275 CpuSortPlan::Argsort {
276 axis,
277 descending,
278 stable,
279 } => execute_sort_along_axis(
280 input,
281 values_out,
282 indices_out,
283 *axis,
284 *descending,
285 *stable,
286 false,
287 ),
288 CpuSortPlan::Topk {
289 axis,
290 k,
291 largest,
292 sorted,
293 } => execute_topk(input, values_out, indices_out, *axis, *k, *largest, *sorted),
294 }
295 }
296
297 fn has_sort_support(_desc: &SortPrimsDescriptor) -> bool {
298 true
299 }
300}