// tenferro_prims/families/sort.rs

1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// Planning descriptor for the sort family of primitives.
///
/// Each variant names one operation together with the parameters a backend
/// needs in order to build a plan for it. Every variant operates along a
/// single `axis` of the input.
///
/// # Examples
///
/// ```
/// use tenferro_prims::SortPrimsDescriptor;
///
/// let sort = SortPrimsDescriptor::Sort {
///     axis: 0,
///     descending: false,
///     stable: true,
/// };
/// assert!(matches!(sort, SortPrimsDescriptor::Sort { stable: true, .. }));
///
/// let topk = SortPrimsDescriptor::Topk {
///     axis: 0,
///     k: 3,
///     largest: true,
///     sorted: true,
/// };
/// assert!(matches!(topk, SortPrimsDescriptor::Topk { k: 3, .. }));
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SortPrimsDescriptor {
    /// Sort the elements along `axis`.
    ///
    /// Yields the sorted values together with the permutation indices that
    /// produced them.
    Sort {
        /// Axis to sort along.
        axis: usize,
        /// `true` for descending order, `false` for ascending.
        descending: bool,
        /// `true` requests a stable sort, which keeps equal elements in their
        /// original relative order.
        stable: bool,
    },
    /// Compute the permutation that would sort the elements along `axis`.
    ///
    /// Only the index permutation is produced; the values output is left
    /// unmodified.
    Argsort {
        /// Axis whose sorting permutation is computed.
        axis: usize,
        /// `true` for descending order, `false` for ascending.
        descending: bool,
        /// `true` requests a stable sort.
        stable: bool,
    },
    /// Select `k` extreme elements along `axis`.
    ///
    /// Yields the `k` largest (or smallest) values along with their indices.
    Topk {
        /// Axis to select along.
        axis: usize,
        /// How many elements to select.
        k: usize,
        /// `true` selects the `k` largest elements, `false` the `k` smallest.
        largest: bool,
        /// `true` requests the selected `k` elements in sorted order.
        sorted: bool,
    },
}
69
70/// Sort execution protocol family.
71///
72/// Provides sorting, argsort, and top-k operations on tensors. The `indices`
73/// output tensor uses `i64` elements representing positions along the target
74/// axis (same convention as [`crate::TensorIndexingPrims`]).
75///
76/// Unlike semiring and scalar families, sort does not use `alpha`/`beta`
77/// scaling -- operations are pure data rearrangement.
78///
79/// # Examples
80///
81/// ```ignore
82/// use tenferro_algebra::Standard;
83/// use tenferro_prims::{CpuBackend, CpuContext, SortPrimsDescriptor, TensorSortPrims};
84///
85/// let mut ctx = CpuContext::new(1);
86/// let desc = SortPrimsDescriptor::Sort {
87///     axis: 0,
88///     descending: false,
89///     stable: true,
90/// };
91/// let _plan = <CpuBackend as TensorSortPrims<Standard<f64>>>::plan(
92///     &mut ctx,
93///     &desc,
94///     &[&[5]],
95/// )
96/// .unwrap();
97/// ```
98pub trait TensorSortPrims<Alg: Algebra>: Sized
99where
100    Alg::Scalar: PartialOrd,
101{
102    /// Backend-specific execution plan.
103    type Plan;
104    /// Backend execution context.
105    type Context;
106
107    /// Plan a sort operation for the given input shape.
108    ///
109    /// `shapes` contains: `[input_shape]`.
110    fn plan(
111        ctx: &mut Self::Context,
112        desc: &SortPrimsDescriptor,
113        shapes: &[&[usize]],
114    ) -> Result<Self::Plan>;
115
116    /// Execute a previously planned sort operation.
117    ///
118    /// - For `Sort`: writes sorted values to `values_out` and permutation
119    ///   indices to `indices_out`.
120    /// - For `Argsort`: writes permutation indices to `indices_out`;
121    ///   `values_out` is left unmodified.
122    /// - For `Topk`: writes top-k values to `values_out` and their original
123    ///   indices to `indices_out`.
124    fn execute(
125        ctx: &mut Self::Context,
126        plan: &Self::Plan,
127        input: &Tensor<Alg::Scalar>,
128        values_out: &mut Tensor<Alg::Scalar>,
129        indices_out: &mut Tensor<i64>,
130    ) -> Result<()>;
131
132    /// Report whether the backend advertises support for the given descriptor.
133    fn has_sort_support(desc: &SortPrimsDescriptor) -> bool;
134}
135
136// ============================================================================
137// CUDA stub (when `cuda` feature is NOT enabled)
138// ============================================================================
139
140#[cfg(not(feature = "cuda"))]
141impl<S: Scalar + PartialOrd> TensorSortPrims<Standard<S>> for CudaBackend {
142    type Plan = ();
143    type Context = CudaContext;
144
145    fn plan(
146        _ctx: &mut Self::Context,
147        desc: &SortPrimsDescriptor,
148        _shapes: &[&[usize]],
149    ) -> Result<Self::Plan> {
150        Err(Error::DeviceError(format!(
151            "sort family descriptor {desc:?} is not implemented on CudaBackend"
152        )))
153    }
154
155    fn execute(
156        _ctx: &mut Self::Context,
157        _plan: &Self::Plan,
158        _input: &Tensor<S>,
159        _values_out: &mut Tensor<S>,
160        _indices_out: &mut Tensor<i64>,
161    ) -> Result<()> {
162        Err(Error::DeviceError(
163            "sort family execution is not implemented on CudaBackend".into(),
164        ))
165    }
166
167    fn has_sort_support(_desc: &SortPrimsDescriptor) -> bool {
168        false
169    }
170}
171
172// ============================================================================
173// ROCm stub
174// ============================================================================
175
176impl<S: Scalar + PartialOrd> TensorSortPrims<Standard<S>> for RocmBackend {
177    type Plan = ();
178    type Context = RocmContext;
179
180    fn plan(
181        _ctx: &mut Self::Context,
182        desc: &SortPrimsDescriptor,
183        _shapes: &[&[usize]],
184    ) -> Result<Self::Plan> {
185        Err(Error::DeviceError(format!(
186            "sort family descriptor {desc:?} is not implemented on RocmBackend"
187        )))
188    }
189
190    fn execute(
191        _ctx: &mut Self::Context,
192        _plan: &Self::Plan,
193        _input: &Tensor<S>,
194        _values_out: &mut Tensor<S>,
195        _indices_out: &mut Tensor<i64>,
196    ) -> Result<()> {
197        Err(Error::DeviceError(
198            "sort family execution is not implemented on RocmBackend".into(),
199        ))
200    }
201
202    fn has_sort_support(_desc: &SortPrimsDescriptor) -> bool {
203        false
204    }
205}