// tenferro_prims/cpu/indexing.rs

1use tenferro_algebra::{Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::{MemoryOrder, Tensor};
4
5use crate::{
6    validate_execute_inputs, validate_shape_count, CpuBackend, CpuContext, IndexingPrimsDescriptor,
7    ScatterReduction, TensorIndexingPrims,
8};
9
/// CPU execution plan for the indexing protocol family.
///
/// A plan is produced by `TensorIndexingPrims::plan` from an
/// `IndexingPrimsDescriptor` and later consumed by `execute`; it records only
/// the operation kind and its parameters, with no per-shape state.
///
/// # Examples
///
/// ```ignore
/// use tenferro_prims::CpuIndexingPlan;
/// let _ = std::mem::size_of::<CpuIndexingPlan>();
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CpuIndexingPlan {
    /// Select slices along an axis by index; the output's axis length equals
    /// the number of indices.
    IndexSelect {
        /// Axis along which to select.
        axis: usize,
    },
    /// Gather elements along an axis using an index tensor; the output takes
    /// the index tensor's shape.
    Gather {
        /// Axis along which to gather.
        axis: usize,
    },
    /// Scatter source values along an axis into the output.
    Scatter {
        /// Axis along which to scatter.
        axis: usize,
        /// Reduction mode applied when writing into the output.
        reduction: ScatterReduction,
    },
    /// Put source values at flat index positions.
    IndexPut {
        /// Whether to accumulate (add) instead of overwriting.
        accumulate: bool,
    },
}
43
44/// Execute `IndexSelect`: for each index i, copy the i-th slice along `axis`.
45fn execute_index_select<T: Scalar>(
46    source: &Tensor<T>,
47    indices: &Tensor<i64>,
48    output: &mut Tensor<T>,
49    axis: usize,
50) -> Result<()> {
51    let src = source.contiguous(MemoryOrder::ColumnMajor);
52    let src_data = src
53        .buffer()
54        .as_slice()
55        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
56    let idx = indices.contiguous(MemoryOrder::ColumnMajor);
57    let idx_data = idx
58        .buffer()
59        .as_slice()
60        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
61
62    let src_shape = src.dims();
63    let ndim = src_shape.len();
64    if axis >= ndim {
65        return Err(Error::InvalidArgument(format!(
66            "index_select axis {axis} >= ndim {ndim}"
67        )));
68    }
69
70    let num_indices = idx_data.len();
71    let axis_size = src_shape[axis];
72
73    // Compute the stride product for dimensions before and after axis
74    let pre_size: usize = src_shape[..axis].iter().product();
75    let post_size: usize = src_shape[axis + 1..].iter().product();
76
77    // Build output shape: replace src_shape[axis] with num_indices
78    let mut out_shape = src_shape.to_vec();
79    out_shape[axis] = num_indices;
80    let out_total: usize = out_shape.iter().product();
81
82    // Allocate output data
83    let mut out_data = vec![T::zero(); out_total];
84
85    // Column-major: axis `axis` has stride = product of dims before it
86    // For each combination of (pre, idx_pos, post):
87    //   out[pre + idx_pos * pre_size + post * pre_size * num_indices]
88    //     = src[pre + indices[idx_pos] * pre_size + post * pre_size * axis_size]
89    for post in 0..post_size {
90        for (idx_pos, &idx_val) in idx_data.iter().enumerate() {
91            let idx_usize = idx_val as usize;
92            if idx_usize >= axis_size {
93                return Err(Error::InvalidArgument(format!(
94                    "index_select: index {idx_val} out of bounds for axis {axis} with size {axis_size}"
95                )));
96            }
97            let src_offset = idx_usize * pre_size + post * pre_size * axis_size;
98            let out_offset = idx_pos * pre_size + post * pre_size * num_indices;
99            out_data[out_offset..out_offset + pre_size]
100                .copy_from_slice(&src_data[src_offset..src_offset + pre_size]);
101        }
102    }
103
104    *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
105        .map_err(|e| Error::InvalidArgument(format!("index_select output: {e}")))?;
106    Ok(())
107}
108
109/// Execute `Gather`: for each position in the output, use the index tensor to
110/// select an element along `axis`.
111fn execute_gather<T: Scalar>(
112    source: &Tensor<T>,
113    indices: &Tensor<i64>,
114    output: &mut Tensor<T>,
115    axis: usize,
116) -> Result<()> {
117    let src = source.contiguous(MemoryOrder::ColumnMajor);
118    let src_data = src
119        .buffer()
120        .as_slice()
121        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
122    let idx = indices.contiguous(MemoryOrder::ColumnMajor);
123    let idx_data = idx
124        .buffer()
125        .as_slice()
126        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
127
128    let src_shape = src.dims();
129    let ndim = src_shape.len();
130    if axis >= ndim {
131        return Err(Error::InvalidArgument(format!(
132            "gather axis {axis} >= ndim {ndim}"
133        )));
134    }
135
136    let out_shape = idx.dims().to_vec();
137    if out_shape.len() != ndim {
138        return Err(Error::InvalidArgument(format!(
139            "gather: index tensor rank {} != source rank {ndim}",
140            out_shape.len()
141        )));
142    }
143
144    let axis_size = src_shape[axis];
145    let total: usize = out_shape.iter().product();
146    let mut out_data = vec![T::zero(); total];
147
148    // Column-major multi-index iteration
149    let mut multi_idx = vec![0usize; ndim];
150    for flat in 0..total {
151        // The index tensor gives us the position along axis
152        let idx_val = idx_data[flat];
153        let idx_usize = idx_val as usize;
154        if idx_usize >= axis_size {
155            return Err(Error::InvalidArgument(format!(
156                "gather: index {idx_val} out of bounds for axis {axis} with size {axis_size}"
157            )));
158        }
159
160        // Compute source flat index: same multi-index but with axis replaced by idx_val
161        let mut src_flat = 0usize;
162        let mut stride = 1usize;
163        for d in 0..ndim {
164            let coord = if d == axis { idx_usize } else { multi_idx[d] };
165            src_flat += coord * stride;
166            stride *= src_shape[d];
167        }
168
169        out_data[flat] = src_data[src_flat];
170
171        // Increment multi-index (column-major order)
172        for d in 0..ndim {
173            multi_idx[d] += 1;
174            if multi_idx[d] < out_shape[d] {
175                break;
176            }
177            multi_idx[d] = 0;
178        }
179    }
180
181    *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
182        .map_err(|e| Error::InvalidArgument(format!("gather output: {e}")))?;
183    Ok(())
184}
185
186/// Execute `Scatter`: place source values into output at indexed positions
187/// along `axis`.
188fn execute_scatter<T: Scalar>(
189    source: &Tensor<T>,
190    indices: &Tensor<i64>,
191    output: &mut Tensor<T>,
192    axis: usize,
193    reduction: ScatterReduction,
194) -> Result<()> {
195    let src = source.contiguous(MemoryOrder::ColumnMajor);
196    let src_data = src
197        .buffer()
198        .as_slice()
199        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
200    let idx = indices.contiguous(MemoryOrder::ColumnMajor);
201    let idx_data = idx
202        .buffer()
203        .as_slice()
204        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
205
206    let out_shape = output.dims().to_vec();
207    let ndim = out_shape.len();
208    if axis >= ndim {
209        return Err(Error::InvalidArgument(format!(
210            "scatter axis {axis} >= ndim {ndim}"
211        )));
212    }
213
214    let src_shape = src.dims();
215    if src_shape.len() != ndim {
216        return Err(Error::InvalidArgument(format!(
217            "scatter: source rank {} != output rank {ndim}",
218            src_shape.len()
219        )));
220    }
221
222    let axis_size = out_shape[axis];
223    let total: usize = src_shape.iter().product();
224
225    // Make output contiguous so we can write into it
226    let out_contig = output.contiguous(MemoryOrder::ColumnMajor);
227    let mut out_data = out_contig
228        .buffer()
229        .as_slice()
230        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?
231        .to_vec();
232
233    // Column-major multi-index iteration over source shape
234    let mut multi_idx = vec![0usize; ndim];
235    for flat in 0..total {
236        let idx_val = idx_data[flat];
237        let idx_usize = idx_val as usize;
238        if idx_usize >= axis_size {
239            return Err(Error::InvalidArgument(format!(
240                "scatter: index {idx_val} out of bounds for axis {axis} with size {axis_size}"
241            )));
242        }
243
244        // Compute output flat index: same multi-index but with axis replaced by idx_val
245        let mut out_flat = 0usize;
246        let mut stride = 1usize;
247        for d in 0..ndim {
248            let coord = if d == axis { idx_usize } else { multi_idx[d] };
249            out_flat += coord * stride;
250            stride *= out_shape[d];
251        }
252
253        match reduction {
254            ScatterReduction::None => {
255                out_data[out_flat] = src_data[flat];
256            }
257            ScatterReduction::Add => {
258                out_data[out_flat] = out_data[out_flat] + src_data[flat];
259            }
260        }
261
262        // Increment multi-index (column-major order)
263        for d in 0..ndim {
264            multi_idx[d] += 1;
265            if multi_idx[d] < src_shape[d] {
266                break;
267            }
268            multi_idx[d] = 0;
269        }
270    }
271
272    *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
273        .map_err(|e| Error::InvalidArgument(format!("scatter output: {e}")))?;
274    Ok(())
275}
276
277/// Execute `IndexPut`: place values at the 1-D index positions.
278fn execute_index_put<T: Scalar>(
279    values: &Tensor<T>,
280    indices: &Tensor<i64>,
281    output: &mut Tensor<T>,
282    accumulate: bool,
283) -> Result<()> {
284    let vals = values.contiguous(MemoryOrder::ColumnMajor);
285    let vals_data = vals
286        .buffer()
287        .as_slice()
288        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
289    let idx = indices.contiguous(MemoryOrder::ColumnMajor);
290    let idx_data = idx
291        .buffer()
292        .as_slice()
293        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
294
295    let out_shape = output.dims().to_vec();
296    let out_total: usize = out_shape.iter().product();
297    let out_contig = output.contiguous(MemoryOrder::ColumnMajor);
298    let mut out_data = out_contig
299        .buffer()
300        .as_slice()
301        .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?
302        .to_vec();
303
304    if idx_data.len() != vals_data.len() {
305        return Err(Error::InvalidArgument(format!(
306            "index_put: indices length {} != values length {}",
307            idx_data.len(),
308            vals_data.len()
309        )));
310    }
311
312    for (i, &idx_val) in idx_data.iter().enumerate() {
313        let idx_usize = idx_val as usize;
314        if idx_usize >= out_total {
315            return Err(Error::InvalidArgument(format!(
316                "index_put: index {idx_val} out of bounds for output of size {out_total}"
317            )));
318        }
319        if accumulate {
320            out_data[idx_usize] = out_data[idx_usize] + vals_data[i];
321        } else {
322            out_data[idx_usize] = vals_data[i];
323        }
324    }
325
326    *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
327        .map_err(|e| Error::InvalidArgument(format!("index_put output: {e}")))?;
328    Ok(())
329}
330
331impl<S: Scalar + 'static> TensorIndexingPrims<Standard<S>> for CpuBackend {
332    type Plan = CpuIndexingPlan;
333    type Context = CpuContext;
334
335    fn plan(
336        _ctx: &mut Self::Context,
337        desc: &IndexingPrimsDescriptor,
338        shapes: &[&[usize]],
339    ) -> Result<Self::Plan> {
340        validate_shape_count(shapes, 3, "IndexingPrims")?;
341        match desc {
342            IndexingPrimsDescriptor::IndexSelect { axis } => {
343                Ok(CpuIndexingPlan::IndexSelect { axis: *axis })
344            }
345            IndexingPrimsDescriptor::Gather { axis } => Ok(CpuIndexingPlan::Gather { axis: *axis }),
346            IndexingPrimsDescriptor::Scatter { axis, reduction } => Ok(CpuIndexingPlan::Scatter {
347                axis: *axis,
348                reduction: *reduction,
349            }),
350            IndexingPrimsDescriptor::IndexPut { accumulate } => Ok(CpuIndexingPlan::IndexPut {
351                accumulate: *accumulate,
352            }),
353        }
354    }
355
356    fn execute(
357        _ctx: &mut Self::Context,
358        plan: &Self::Plan,
359        inputs: &[&Tensor<S>],
360        indices: &Tensor<i64>,
361        output: &mut Tensor<S>,
362    ) -> Result<()> {
363        validate_execute_inputs(inputs, 1, "IndexingPrims")?;
364        match plan {
365            CpuIndexingPlan::IndexSelect { axis } => {
366                execute_index_select(inputs[0], indices, output, *axis)
367            }
368            CpuIndexingPlan::Gather { axis } => execute_gather(inputs[0], indices, output, *axis),
369            CpuIndexingPlan::Scatter { axis, reduction } => {
370                execute_scatter(inputs[0], indices, output, *axis, *reduction)
371            }
372            CpuIndexingPlan::IndexPut { accumulate } => {
373                execute_index_put(inputs[0], indices, output, *accumulate)
374            }
375        }
376    }
377
378    fn has_indexing_support(_desc: IndexingPrimsDescriptor) -> bool {
379        true
380    }
381}