tenferro_prims/families/
indexing.rs

1use tenferro_algebra::{Algebra, Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext};
7use crate::{RocmBackend, RocmContext};
8
/// How a scatter operation combines an incoming source value with the value
/// already present at the destination.
///
/// # Examples
///
/// ```
/// use tenferro_prims::ScatterReduction;
///
/// let mode = ScatterReduction::Add;
/// assert_eq!(mode, ScatterReduction::Add);
/// assert_ne!(mode, ScatterReduction::None);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScatterReduction {
    /// No reduction: the source value replaces the destination value.
    None,
    /// Additive reduction: the source value is added into the destination.
    Add,
}
26
27/// Descriptor for indexing-family planning.
28///
29/// # Examples
30///
31/// ```
32/// use tenferro_prims::{IndexingPrimsDescriptor, ScatterReduction};
33///
34/// let desc = IndexingPrimsDescriptor::IndexSelect { axis: 0 };
35/// assert!(matches!(desc, IndexingPrimsDescriptor::IndexSelect { .. }));
36///
37/// let desc = IndexingPrimsDescriptor::Scatter {
38///     axis: 1,
39///     reduction: ScatterReduction::Add,
40/// };
41/// assert!(matches!(desc, IndexingPrimsDescriptor::Scatter { .. }));
42/// ```
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub enum IndexingPrimsDescriptor {
45    /// Select slices along `axis` using a 1-D index tensor.
46    ///
47    /// Given input of shape `[d0, ..., d_{axis}, ..., d_{n-1}]` and a 1-D
48    /// index tensor of length `k`, produces output of shape
49    /// `[d0, ..., k, ..., d_{n-1}]`.
50    IndexSelect {
51        /// Axis along which to select slices.
52        axis: usize,
53    },
54    /// Gather elements along `axis` using an index tensor of the same rank.
55    ///
56    /// The index tensor shape must match the output shape.
57    Gather {
58        /// Axis along which to gather.
59        axis: usize,
60    },
61    /// Scatter source values into output positions along `axis`.
62    ///
63    /// Inverse of gather: places values from `source` into `output` at
64    /// positions determined by the index tensor.
65    Scatter {
66        /// Axis along which to scatter.
67        axis: usize,
68        /// Reduction mode: overwrite or accumulate.
69        reduction: ScatterReduction,
70    },
71    /// Put source values at index positions (flat or per-axis).
72    ///
73    /// When `accumulate` is true, adds into the output instead of overwriting.
74    IndexPut {
75        /// Whether to accumulate (add) instead of overwriting.
76        accumulate: bool,
77    },
78}
79
80/// Indexing execution protocol family.
81///
82/// Provides index-based selection, gathering, and scattering of tensor
83/// elements. Unlike the scalar and analytic families, indexing does not use
84/// `alpha`/`beta` scaling — operations are pure data movement.
85///
86/// The `indices` tensor uses `i64` elements representing positions along the
87/// target axis. Negative indices are not supported and will produce errors.
88///
89/// # Examples
90///
91/// ```ignore
92/// use tenferro_algebra::Standard;
93/// use tenferro_prims::{CpuBackend, CpuContext, IndexingPrimsDescriptor, TensorIndexingPrims};
94///
95/// let mut ctx = CpuContext::new(1);
96/// let desc = IndexingPrimsDescriptor::IndexSelect { axis: 0 };
97/// let _plan = <CpuBackend as TensorIndexingPrims<Standard<f64>>>::plan(
98///     &mut ctx,
99///     &desc,
100///     &[&[3, 2], &[2], &[2, 2]],
101/// )
102/// .unwrap();
103/// ```
104pub trait TensorIndexingPrims<Alg: Algebra>: Sized {
105    /// Backend-specific execution plan.
106    type Plan;
107    /// Backend execution context.
108    type Context;
109
110    /// Plan an indexing operation for the given input/index/output shapes.
111    ///
112    /// `shapes` contains: `[input_shape, index_shape, output_shape]`.
113    fn plan(
114        ctx: &mut Self::Context,
115        desc: &IndexingPrimsDescriptor,
116        shapes: &[&[usize]],
117    ) -> Result<Self::Plan>;
118
119    /// Execute a previously planned indexing operation.
120    ///
121    /// - For `IndexSelect` and `Gather`: reads from `inputs[0]` using
122    ///   `indices`, writes into `output`.
123    /// - For `Scatter`: reads from `inputs[0]` (source values) using
124    ///   `indices`, writes into `output`.
125    /// - For `IndexPut`: reads from `inputs[0]` (values to put) using
126    ///   `indices`, writes into `output`.
127    fn execute(
128        ctx: &mut Self::Context,
129        plan: &Self::Plan,
130        inputs: &[&Tensor<Alg::Scalar>],
131        indices: &Tensor<i64>,
132        output: &mut Tensor<Alg::Scalar>,
133    ) -> Result<()>;
134
135    /// Report whether the backend advertises support for the given descriptor.
136    fn has_indexing_support(desc: IndexingPrimsDescriptor) -> bool;
137}
138
139// ============================================================================
140// CUDA stub (when `cuda` feature is NOT enabled)
141// ============================================================================
142
143#[cfg(not(feature = "cuda"))]
144impl<S: Scalar> TensorIndexingPrims<Standard<S>> for CudaBackend {
145    type Plan = ();
146    type Context = CudaContext;
147
148    fn plan(
149        _ctx: &mut Self::Context,
150        desc: &IndexingPrimsDescriptor,
151        _shapes: &[&[usize]],
152    ) -> Result<Self::Plan> {
153        Err(Error::DeviceError(format!(
154            "indexing family descriptor {desc:?} is not implemented on CudaBackend"
155        )))
156    }
157
158    fn execute(
159        _ctx: &mut Self::Context,
160        _plan: &Self::Plan,
161        _inputs: &[&Tensor<S>],
162        _indices: &Tensor<i64>,
163        _output: &mut Tensor<S>,
164    ) -> Result<()> {
165        Err(Error::DeviceError(
166            "indexing family execution is not implemented on CudaBackend".into(),
167        ))
168    }
169
170    fn has_indexing_support(_desc: IndexingPrimsDescriptor) -> bool {
171        false
172    }
173}
174
175// ============================================================================
176// ROCm stub
177// ============================================================================
178
179impl<S: Scalar> TensorIndexingPrims<Standard<S>> for RocmBackend {
180    type Plan = ();
181    type Context = RocmContext;
182
183    fn plan(
184        _ctx: &mut Self::Context,
185        desc: &IndexingPrimsDescriptor,
186        _shapes: &[&[usize]],
187    ) -> Result<Self::Plan> {
188        Err(Error::DeviceError(format!(
189            "indexing family descriptor {desc:?} is not implemented on RocmBackend"
190        )))
191    }
192
193    fn execute(
194        _ctx: &mut Self::Context,
195        _plan: &Self::Plan,
196        _inputs: &[&Tensor<S>],
197        _indices: &Tensor<i64>,
198        _output: &mut Tensor<S>,
199    ) -> Result<()> {
200        Err(Error::DeviceError(
201            "indexing family execution is not implemented on RocmBackend".into(),
202        ))
203    }
204
205    fn has_indexing_support(_desc: IndexingPrimsDescriptor) -> bool {
206        false
207    }
208}