1use tenferro_algebra::{Scalar, Standard};
2use tenferro_device::{Error, Result};
3use tenferro_tensor::{MemoryOrder, Tensor};
4
5use crate::{
6 validate_execute_inputs, validate_shape_count, CpuBackend, CpuContext, IndexingPrimsDescriptor,
7 ScatterReduction, TensorIndexingPrims,
8};
9
/// Execution plan for CPU indexing primitives, produced by
/// `TensorIndexingPrims::plan` and consumed by `execute`.
///
/// Each variant captures the parameters from the matching
/// `IndexingPrimsDescriptor` variant; the CPU backend needs no further
/// precomputation, so the plan is just the descriptor's payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CpuIndexingPlan {
    /// Copy whole slices of the source along `axis`, in index order.
    IndexSelect {
        // Axis of the source tensor along which slices are selected.
        axis: usize,
    },
    /// Element-wise gather along `axis`; output shape follows the index tensor.
    Gather {
        // Axis whose coordinate is replaced by the index value.
        axis: usize,
    },
    /// Element-wise scatter along `axis` into an existing output tensor.
    Scatter {
        // Axis whose coordinate is replaced by the index value.
        axis: usize,
        // How colliding writes combine (overwrite vs. accumulate).
        reduction: ScatterReduction,
    },
    /// Write values at flat (column-major) positions given by the index tensor.
    IndexPut {
        // When true, add to existing elements instead of overwriting.
        accumulate: bool,
    },
}
43
44fn execute_index_select<T: Scalar>(
46 source: &Tensor<T>,
47 indices: &Tensor<i64>,
48 output: &mut Tensor<T>,
49 axis: usize,
50) -> Result<()> {
51 let src = source.contiguous(MemoryOrder::ColumnMajor);
52 let src_data = src
53 .buffer()
54 .as_slice()
55 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
56 let idx = indices.contiguous(MemoryOrder::ColumnMajor);
57 let idx_data = idx
58 .buffer()
59 .as_slice()
60 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
61
62 let src_shape = src.dims();
63 let ndim = src_shape.len();
64 if axis >= ndim {
65 return Err(Error::InvalidArgument(format!(
66 "index_select axis {axis} >= ndim {ndim}"
67 )));
68 }
69
70 let num_indices = idx_data.len();
71 let axis_size = src_shape[axis];
72
73 let pre_size: usize = src_shape[..axis].iter().product();
75 let post_size: usize = src_shape[axis + 1..].iter().product();
76
77 let mut out_shape = src_shape.to_vec();
79 out_shape[axis] = num_indices;
80 let out_total: usize = out_shape.iter().product();
81
82 let mut out_data = vec![T::zero(); out_total];
84
85 for post in 0..post_size {
90 for (idx_pos, &idx_val) in idx_data.iter().enumerate() {
91 let idx_usize = idx_val as usize;
92 if idx_usize >= axis_size {
93 return Err(Error::InvalidArgument(format!(
94 "index_select: index {idx_val} out of bounds for axis {axis} with size {axis_size}"
95 )));
96 }
97 let src_offset = idx_usize * pre_size + post * pre_size * axis_size;
98 let out_offset = idx_pos * pre_size + post * pre_size * num_indices;
99 out_data[out_offset..out_offset + pre_size]
100 .copy_from_slice(&src_data[src_offset..src_offset + pre_size]);
101 }
102 }
103
104 *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
105 .map_err(|e| Error::InvalidArgument(format!("index_select output: {e}")))?;
106 Ok(())
107}
108
109fn execute_gather<T: Scalar>(
112 source: &Tensor<T>,
113 indices: &Tensor<i64>,
114 output: &mut Tensor<T>,
115 axis: usize,
116) -> Result<()> {
117 let src = source.contiguous(MemoryOrder::ColumnMajor);
118 let src_data = src
119 .buffer()
120 .as_slice()
121 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
122 let idx = indices.contiguous(MemoryOrder::ColumnMajor);
123 let idx_data = idx
124 .buffer()
125 .as_slice()
126 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
127
128 let src_shape = src.dims();
129 let ndim = src_shape.len();
130 if axis >= ndim {
131 return Err(Error::InvalidArgument(format!(
132 "gather axis {axis} >= ndim {ndim}"
133 )));
134 }
135
136 let out_shape = idx.dims().to_vec();
137 if out_shape.len() != ndim {
138 return Err(Error::InvalidArgument(format!(
139 "gather: index tensor rank {} != source rank {ndim}",
140 out_shape.len()
141 )));
142 }
143
144 let axis_size = src_shape[axis];
145 let total: usize = out_shape.iter().product();
146 let mut out_data = vec![T::zero(); total];
147
148 let mut multi_idx = vec![0usize; ndim];
150 for flat in 0..total {
151 let idx_val = idx_data[flat];
153 let idx_usize = idx_val as usize;
154 if idx_usize >= axis_size {
155 return Err(Error::InvalidArgument(format!(
156 "gather: index {idx_val} out of bounds for axis {axis} with size {axis_size}"
157 )));
158 }
159
160 let mut src_flat = 0usize;
162 let mut stride = 1usize;
163 for d in 0..ndim {
164 let coord = if d == axis { idx_usize } else { multi_idx[d] };
165 src_flat += coord * stride;
166 stride *= src_shape[d];
167 }
168
169 out_data[flat] = src_data[src_flat];
170
171 for d in 0..ndim {
173 multi_idx[d] += 1;
174 if multi_idx[d] < out_shape[d] {
175 break;
176 }
177 multi_idx[d] = 0;
178 }
179 }
180
181 *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
182 .map_err(|e| Error::InvalidArgument(format!("gather output: {e}")))?;
183 Ok(())
184}
185
186fn execute_scatter<T: Scalar>(
189 source: &Tensor<T>,
190 indices: &Tensor<i64>,
191 output: &mut Tensor<T>,
192 axis: usize,
193 reduction: ScatterReduction,
194) -> Result<()> {
195 let src = source.contiguous(MemoryOrder::ColumnMajor);
196 let src_data = src
197 .buffer()
198 .as_slice()
199 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
200 let idx = indices.contiguous(MemoryOrder::ColumnMajor);
201 let idx_data = idx
202 .buffer()
203 .as_slice()
204 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
205
206 let out_shape = output.dims().to_vec();
207 let ndim = out_shape.len();
208 if axis >= ndim {
209 return Err(Error::InvalidArgument(format!(
210 "scatter axis {axis} >= ndim {ndim}"
211 )));
212 }
213
214 let src_shape = src.dims();
215 if src_shape.len() != ndim {
216 return Err(Error::InvalidArgument(format!(
217 "scatter: source rank {} != output rank {ndim}",
218 src_shape.len()
219 )));
220 }
221
222 let axis_size = out_shape[axis];
223 let total: usize = src_shape.iter().product();
224
225 let out_contig = output.contiguous(MemoryOrder::ColumnMajor);
227 let mut out_data = out_contig
228 .buffer()
229 .as_slice()
230 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?
231 .to_vec();
232
233 let mut multi_idx = vec![0usize; ndim];
235 for flat in 0..total {
236 let idx_val = idx_data[flat];
237 let idx_usize = idx_val as usize;
238 if idx_usize >= axis_size {
239 return Err(Error::InvalidArgument(format!(
240 "scatter: index {idx_val} out of bounds for axis {axis} with size {axis_size}"
241 )));
242 }
243
244 let mut out_flat = 0usize;
246 let mut stride = 1usize;
247 for d in 0..ndim {
248 let coord = if d == axis { idx_usize } else { multi_idx[d] };
249 out_flat += coord * stride;
250 stride *= out_shape[d];
251 }
252
253 match reduction {
254 ScatterReduction::None => {
255 out_data[out_flat] = src_data[flat];
256 }
257 ScatterReduction::Add => {
258 out_data[out_flat] = out_data[out_flat] + src_data[flat];
259 }
260 }
261
262 for d in 0..ndim {
264 multi_idx[d] += 1;
265 if multi_idx[d] < src_shape[d] {
266 break;
267 }
268 multi_idx[d] = 0;
269 }
270 }
271
272 *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
273 .map_err(|e| Error::InvalidArgument(format!("scatter output: {e}")))?;
274 Ok(())
275}
276
277fn execute_index_put<T: Scalar>(
279 values: &Tensor<T>,
280 indices: &Tensor<i64>,
281 output: &mut Tensor<T>,
282 accumulate: bool,
283) -> Result<()> {
284 let vals = values.contiguous(MemoryOrder::ColumnMajor);
285 let vals_data = vals
286 .buffer()
287 .as_slice()
288 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
289 let idx = indices.contiguous(MemoryOrder::ColumnMajor);
290 let idx_data = idx
291 .buffer()
292 .as_slice()
293 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?;
294
295 let out_shape = output.dims().to_vec();
296 let out_total: usize = out_shape.iter().product();
297 let out_contig = output.contiguous(MemoryOrder::ColumnMajor);
298 let mut out_data = out_contig
299 .buffer()
300 .as_slice()
301 .ok_or_else(|| Error::DeviceError("GPU tensor passed to CPU backend".into()))?
302 .to_vec();
303
304 if idx_data.len() != vals_data.len() {
305 return Err(Error::InvalidArgument(format!(
306 "index_put: indices length {} != values length {}",
307 idx_data.len(),
308 vals_data.len()
309 )));
310 }
311
312 for (i, &idx_val) in idx_data.iter().enumerate() {
313 let idx_usize = idx_val as usize;
314 if idx_usize >= out_total {
315 return Err(Error::InvalidArgument(format!(
316 "index_put: index {idx_val} out of bounds for output of size {out_total}"
317 )));
318 }
319 if accumulate {
320 out_data[idx_usize] = out_data[idx_usize] + vals_data[i];
321 } else {
322 out_data[idx_usize] = vals_data[i];
323 }
324 }
325
326 *output = Tensor::from_slice(&out_data, &out_shape, MemoryOrder::ColumnMajor)
327 .map_err(|e| Error::InvalidArgument(format!("index_put output: {e}")))?;
328 Ok(())
329}
330
331impl<S: Scalar + 'static> TensorIndexingPrims<Standard<S>> for CpuBackend {
332 type Plan = CpuIndexingPlan;
333 type Context = CpuContext;
334
335 fn plan(
336 _ctx: &mut Self::Context,
337 desc: &IndexingPrimsDescriptor,
338 shapes: &[&[usize]],
339 ) -> Result<Self::Plan> {
340 validate_shape_count(shapes, 3, "IndexingPrims")?;
341 match desc {
342 IndexingPrimsDescriptor::IndexSelect { axis } => {
343 Ok(CpuIndexingPlan::IndexSelect { axis: *axis })
344 }
345 IndexingPrimsDescriptor::Gather { axis } => Ok(CpuIndexingPlan::Gather { axis: *axis }),
346 IndexingPrimsDescriptor::Scatter { axis, reduction } => Ok(CpuIndexingPlan::Scatter {
347 axis: *axis,
348 reduction: *reduction,
349 }),
350 IndexingPrimsDescriptor::IndexPut { accumulate } => Ok(CpuIndexingPlan::IndexPut {
351 accumulate: *accumulate,
352 }),
353 }
354 }
355
356 fn execute(
357 _ctx: &mut Self::Context,
358 plan: &Self::Plan,
359 inputs: &[&Tensor<S>],
360 indices: &Tensor<i64>,
361 output: &mut Tensor<S>,
362 ) -> Result<()> {
363 validate_execute_inputs(inputs, 1, "IndexingPrims")?;
364 match plan {
365 CpuIndexingPlan::IndexSelect { axis } => {
366 execute_index_select(inputs[0], indices, output, *axis)
367 }
368 CpuIndexingPlan::Gather { axis } => execute_gather(inputs[0], indices, output, *axis),
369 CpuIndexingPlan::Scatter { axis, reduction } => {
370 execute_scatter(inputs[0], indices, output, *axis, *reduction)
371 }
372 CpuIndexingPlan::IndexPut { accumulate } => {
373 execute_index_put(inputs[0], indices, output, *accumulate)
374 }
375 }
376 }
377
378 fn has_indexing_support(_desc: IndexingPrimsDescriptor) -> bool {
379 true
380 }
381}