//! tenferro_prims/families/indexing.rs — indexing primitive family (index select, gather, scatter, index put).
1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// How a scatter operation combines a source value with the value already
/// present at the destination position.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScatterReduction;
///
/// let mode = ScatterReduction::None;
/// assert_eq!(mode, ScatterReduction::None);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScatterReduction {
    /// Replace the destination value with the source value.
    None,
    /// Accumulate: add the source value onto the destination value.
    Add,
}
26
27/// Descriptor for indexing-family planning.
28///
29/// # Examples
30///
31/// ```
32/// use tenferro_prims::{IndexingPrimsDescriptor, ScatterReduction};
33///
34/// let desc = IndexingPrimsDescriptor::IndexSelect { axis: 0 };
35/// assert!(matches!(desc, IndexingPrimsDescriptor::IndexSelect { .. }));
36///
37/// let desc = IndexingPrimsDescriptor::Scatter {
38/// axis: 1,
39/// reduction: ScatterReduction::Add,
40/// };
41/// assert!(matches!(desc, IndexingPrimsDescriptor::Scatter { .. }));
42/// ```
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub enum IndexingPrimsDescriptor {
45 /// Select slices along `axis` using a 1-D index tensor.
46 ///
47 /// Given input of shape `[d0, ..., d_{axis}, ..., d_{n-1}]` and a 1-D
48 /// index tensor of length `k`, produces output of shape
49 /// `[d0, ..., k, ..., d_{n-1}]`.
50 IndexSelect {
51 /// Axis along which to select slices.
52 axis: usize,
53 },
54 /// Gather elements along `axis` using an index tensor of the same rank.
55 ///
56 /// The index tensor shape must match the output shape.
57 Gather {
58 /// Axis along which to gather.
59 axis: usize,
60 },
61 /// Scatter source values into output positions along `axis`.
62 ///
63 /// Inverse of gather: places values from `source` into `output` at
64 /// positions determined by the index tensor.
65 Scatter {
66 /// Axis along which to scatter.
67 axis: usize,
68 /// Reduction mode: overwrite or accumulate.
69 reduction: ScatterReduction,
70 },
71 /// Put source values at index positions (flat or per-axis).
72 ///
73 /// When `accumulate` is true, adds into the output instead of overwriting.
74 IndexPut {
75 /// Whether to accumulate (add) instead of overwriting.
76 accumulate: bool,
77 },
78}
79
80/// Indexing execution protocol family.
81///
82/// Provides index-based selection, gathering, and scattering of tensor
83/// elements. Unlike the scalar and analytic families, indexing does not use
84/// `alpha`/`beta` scaling — operations are pure data movement.
85///
86/// The `indices` tensor uses `i64` elements representing positions along the
87/// target axis. Negative indices are not supported and will produce errors.
88///
89/// # Examples
90///
91/// ```ignore
92/// use tenferro_algebra::Standard;
93/// use tenferro_prims::{CpuBackend, CpuContext, IndexingPrimsDescriptor, TensorIndexingPrims};
94///
95/// let mut ctx = CpuContext::new(1);
96/// let desc = IndexingPrimsDescriptor::IndexSelect { axis: 0 };
97/// let _plan = <CpuBackend as TensorIndexingPrims<Standard<f64>>>::plan(
98/// &mut ctx,
99/// &desc,
100/// &[&[3, 2], &[2], &[2, 2]],
101/// )
102/// .unwrap();
103/// ```
104pub trait TensorIndexingPrims<Alg: Algebra>: Sized {
105 /// Backend-specific execution plan.
106 type Plan;
107 /// Backend execution context.
108 type Context;
109
110 /// Plan an indexing operation for the given input/index/output shapes.
111 ///
112 /// `shapes` contains: `[input_shape, index_shape, output_shape]`.
113 fn plan(
114 ctx: &mut Self::Context,
115 desc: &IndexingPrimsDescriptor,
116 shapes: &[&[usize]],
117 ) -> Result<Self::Plan>;
118
119 /// Execute a previously planned indexing operation.
120 ///
121 /// - For `IndexSelect` and `Gather`: reads from `inputs[0]` using
122 /// `indices`, writes into `output`.
123 /// - For `Scatter`: reads from `inputs[0]` (source values) using
124 /// `indices`, writes into `output`.
125 /// - For `IndexPut`: reads from `inputs[0]` (values to put) using
126 /// `indices`, writes into `output`.
127 fn execute(
128 ctx: &mut Self::Context,
129 plan: &Self::Plan,
130 inputs: &[&Tensor<Alg::Scalar>],
131 indices: &Tensor<i64>,
132 output: &mut Tensor<Alg::Scalar>,
133 ) -> Result<()>;
134
135 /// Report whether the backend advertises support for the given descriptor.
136 fn has_indexing_support(desc: IndexingPrimsDescriptor) -> bool;
137}
138
139// ============================================================================
140// CUDA stub (when `cuda` feature is NOT enabled)
141// ============================================================================
142
143#[cfg(not(feature = "cuda"))]
144impl<S: Scalar> TensorIndexingPrims<Standard<S>> for CudaBackend {
145 type Plan = ();
146 type Context = CudaContext;
147
148 fn plan(
149 _ctx: &mut Self::Context,
150 desc: &IndexingPrimsDescriptor,
151 _shapes: &[&[usize]],
152 ) -> Result<Self::Plan> {
153 Err(Error::DeviceError(format!(
154 "indexing family descriptor {desc:?} is not implemented on CudaBackend"
155 )))
156 }
157
158 fn execute(
159 _ctx: &mut Self::Context,
160 _plan: &Self::Plan,
161 _inputs: &[&Tensor<S>],
162 _indices: &Tensor<i64>,
163 _output: &mut Tensor<S>,
164 ) -> Result<()> {
165 Err(Error::DeviceError(
166 "indexing family execution is not implemented on CudaBackend".into(),
167 ))
168 }
169
170 fn has_indexing_support(_desc: IndexingPrimsDescriptor) -> bool {
171 false
172 }
173}
174
175// ============================================================================
176// ROCm stub
177// ============================================================================
178
179impl<S: Scalar> TensorIndexingPrims<Standard<S>> for RocmBackend {
180 type Plan = ();
181 type Context = RocmContext;
182
183 fn plan(
184 _ctx: &mut Self::Context,
185 desc: &IndexingPrimsDescriptor,
186 _shapes: &[&[usize]],
187 ) -> Result<Self::Plan> {
188 Err(Error::DeviceError(format!(
189 "indexing family descriptor {desc:?} is not implemented on RocmBackend"
190 )))
191 }
192
193 fn execute(
194 _ctx: &mut Self::Context,
195 _plan: &Self::Plan,
196 _inputs: &[&Tensor<S>],
197 _indices: &Tensor<i64>,
198 _output: &mut Tensor<S>,
199 ) -> Result<()> {
200 Err(Error::DeviceError(
201 "indexing family execution is not implemented on RocmBackend".into(),
202 ))
203 }
204
205 fn has_indexing_support(_desc: IndexingPrimsDescriptor) -> bool {
206 false
207 }
208}