tenferro_prims/families/sort.rs
1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// Planning descriptor for the sort family of primitives.
///
/// Each variant selects one operation (full sort, argsort, or top-k) and
/// carries the parameters that configure it along a single axis.
///
/// # Examples
///
/// ```
/// use tenferro_prims::SortPrimsDescriptor;
///
/// let sort = SortPrimsDescriptor::Sort {
///     axis: 0,
///     descending: false,
///     stable: true,
/// };
/// assert!(matches!(sort, SortPrimsDescriptor::Sort { .. }));
///
/// let topk = SortPrimsDescriptor::Topk {
///     axis: 0,
///     k: 3,
///     largest: true,
///     sorted: true,
/// };
/// assert!(matches!(topk, SortPrimsDescriptor::Topk { .. }));
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SortPrimsDescriptor {
    /// Order the elements along `axis`.
    ///
    /// Yields the sorted values together with the permutation indices that
    /// produce them.
    Sort {
        /// Axis to sort along.
        axis: usize,
        /// Sort in descending order when `true`.
        descending: bool,
        /// Use a stable sort when `true`, i.e. equal elements keep their
        /// relative order.
        stable: bool,
    },
    /// Determine the permutation that would order the elements along `axis`.
    ///
    /// Only the index permutation is produced; the values output is left
    /// unmodified.
    Argsort {
        /// Axis along which the sort permutation is computed.
        axis: usize,
        /// Sort in descending order when `true`.
        descending: bool,
        /// Use a stable sort when `true`.
        stable: bool,
    },
    /// Pick the top-k elements along `axis`.
    ///
    /// Yields the k largest (or smallest) values along with their indices.
    Topk {
        /// Axis along which the top-k selection is made.
        axis: usize,
        /// How many elements to select.
        k: usize,
        /// Select the k largest when `true`, the k smallest when `false`.
        largest: bool,
        /// When `true`, the k returned elements are sorted.
        sorted: bool,
    },
}
69
70/// Sort execution protocol family.
71///
72/// Provides sorting, argsort, and top-k operations on tensors. The `indices`
73/// output tensor uses `i64` elements representing positions along the target
74/// axis (same convention as [`crate::TensorIndexingPrims`]).
75///
76/// Unlike semiring and scalar families, sort does not use `alpha`/`beta`
77/// scaling -- operations are pure data rearrangement.
78///
79/// # Examples
80///
81/// ```ignore
82/// use tenferro_algebra::Standard;
83/// use tenferro_prims::{CpuBackend, CpuContext, SortPrimsDescriptor, TensorSortPrims};
84///
85/// let mut ctx = CpuContext::new(1);
86/// let desc = SortPrimsDescriptor::Sort {
87/// axis: 0,
88/// descending: false,
89/// stable: true,
90/// };
91/// let _plan = <CpuBackend as TensorSortPrims<Standard<f64>>>::plan(
92/// &mut ctx,
93/// &desc,
94/// &[&[5]],
95/// )
96/// .unwrap();
97/// ```
98pub trait TensorSortPrims<Alg: Algebra>: Sized
99where
100 Alg::Scalar: PartialOrd,
101{
102 /// Backend-specific execution plan.
103 type Plan;
104 /// Backend execution context.
105 type Context;
106
107 /// Plan a sort operation for the given input shape.
108 ///
109 /// `shapes` contains: `[input_shape]`.
110 fn plan(
111 ctx: &mut Self::Context,
112 desc: &SortPrimsDescriptor,
113 shapes: &[&[usize]],
114 ) -> Result<Self::Plan>;
115
116 /// Execute a previously planned sort operation.
117 ///
118 /// - For `Sort`: writes sorted values to `values_out` and permutation
119 /// indices to `indices_out`.
120 /// - For `Argsort`: writes permutation indices to `indices_out`;
121 /// `values_out` is left unmodified.
122 /// - For `Topk`: writes top-k values to `values_out` and their original
123 /// indices to `indices_out`.
124 fn execute(
125 ctx: &mut Self::Context,
126 plan: &Self::Plan,
127 input: &Tensor<Alg::Scalar>,
128 values_out: &mut Tensor<Alg::Scalar>,
129 indices_out: &mut Tensor<i64>,
130 ) -> Result<()>;
131
132 /// Report whether the backend advertises support for the given descriptor.
133 fn has_sort_support(desc: &SortPrimsDescriptor) -> bool;
134}
135
136// ============================================================================
137// CUDA stub (when `cuda` feature is NOT enabled)
138// ============================================================================
139
140#[cfg(not(feature = "cuda"))]
141impl<S: Scalar + PartialOrd> TensorSortPrims<Standard<S>> for CudaBackend {
142 type Plan = ();
143 type Context = CudaContext;
144
145 fn plan(
146 _ctx: &mut Self::Context,
147 desc: &SortPrimsDescriptor,
148 _shapes: &[&[usize]],
149 ) -> Result<Self::Plan> {
150 Err(Error::DeviceError(format!(
151 "sort family descriptor {desc:?} is not implemented on CudaBackend"
152 )))
153 }
154
155 fn execute(
156 _ctx: &mut Self::Context,
157 _plan: &Self::Plan,
158 _input: &Tensor<S>,
159 _values_out: &mut Tensor<S>,
160 _indices_out: &mut Tensor<i64>,
161 ) -> Result<()> {
162 Err(Error::DeviceError(
163 "sort family execution is not implemented on CudaBackend".into(),
164 ))
165 }
166
167 fn has_sort_support(_desc: &SortPrimsDescriptor) -> bool {
168 false
169 }
170}
171
172// ============================================================================
173// ROCm stub
174// ============================================================================
175
176impl<S: Scalar + PartialOrd> TensorSortPrims<Standard<S>> for RocmBackend {
177 type Plan = ();
178 type Context = RocmContext;
179
180 fn plan(
181 _ctx: &mut Self::Context,
182 desc: &SortPrimsDescriptor,
183 _shapes: &[&[usize]],
184 ) -> Result<Self::Plan> {
185 Err(Error::DeviceError(format!(
186 "sort family descriptor {desc:?} is not implemented on RocmBackend"
187 )))
188 }
189
190 fn execute(
191 _ctx: &mut Self::Context,
192 _plan: &Self::Plan,
193 _input: &Tensor<S>,
194 _values_out: &mut Tensor<S>,
195 _indices_out: &mut Tensor<i64>,
196 ) -> Result<()> {
197 Err(Error::DeviceError(
198 "sort family execution is not implemented on RocmBackend".into(),
199 ))
200 }
201
202 fn has_sort_support(_desc: &SortPrimsDescriptor) -> bool {
203 false
204 }
205}