1use std::ops::Add;
2
3use num_traits::Zero;
4
5use super::indexing_alloc::pooled_uninit_tensor;
6use super::typed_host_data;
7use crate::buffer_pool::{BufferPool, PoolScalar};
8use tenferro_tensor::{GatherConfig, PadConfig, ScatterConfig, SliceConfig};
9use tenferro_tensor::{Tensor, TypedTensor};
10
/// Borrowing conversion from a dynamically typed tensor to a `TypedTensor<T>` view.
///
/// The trait is generic over the element type `T`, so one type can implement it once
/// per element type it is able to expose.
trait TensorAsTyped<T> {
    /// Returns a borrowed `TypedTensor<T>` when `self` can be viewed with element type
    /// `T`, or `None` otherwise.
    fn as_typed(&self) -> Option<&TypedTensor<T>>;
}
20
21macro_rules! impl_tensor_as_typed {
22 ($(($ty:ty, $variant:ident)),+ $(,)?) => {
23 $(
24 impl TensorAsTyped<$ty> for Tensor {
25 fn as_typed(&self) -> Option<&TypedTensor<$ty>> {
26 match self {
27 Tensor::$variant(tensor) => Some(tensor),
28 _ => None,
29 }
30 }
31 }
32 )+
33 };
34}
35
36impl_tensor_as_typed!(
37 (f32, F32),
38 (f64, F64),
39 (i32, I32),
40 (i64, I64),
41 (bool, Bool),
42 (num_complex::Complex<f32>, C32),
43 (num_complex::Complex<f64>, C64),
44);
45
/// Dispatches a fallible, dtype-preserving unary operation over every `Tensor` variant.
///
/// Usage: `dispatch_tensor_unary_result!(input, |t| body)`. `input` is matched against
/// all seven variants; in the matching arm the payload is bound to `t`, `body` is
/// evaluated and unwrapped with `?` (so the macro must be expanded inside a function
/// returning a compatible `Result`), and the value is re-wrapped in the same variant
/// inside `Ok`.
macro_rules! dispatch_tensor_unary_result {
    // Internal rule: emits one match arm per listed variant.
    (@arms $input:expr, $tensor:ident, $body:expr; $($variant:ident),+) => {
        match $input {
            $(Tensor::$variant($tensor) => Ok(Tensor::$variant($body?)),)+
        }
    };
    ($input:expr, |$tensor:ident| $body:expr) => {
        dispatch_tensor_unary_result!(
            @arms $input, $tensor, $body; F32, F64, I32, I64, Bool, C32, C64
        )
    };
}
59
/// Like a plain unary dispatch, but with a dedicated body for the `Bool` variant.
///
/// Usage:
/// `dispatch_tensor_unary_with_bool_special_result!(input, |t| body, bool |b| bool_body)`.
/// Every non-bool variant binds its payload to `t` and evaluates `body`; the `Bool`
/// variant binds its payload to `b` and evaluates `bool_body`. Either body is unwrapped
/// with `?` and the value is re-wrapped in the matching variant inside `Ok`.
macro_rules! dispatch_tensor_unary_with_bool_special_result {
    // Internal rule: the listed variants share `$body`; `Bool` gets its own arm.
    (
        @arms $input:expr, $tensor:ident, $body:expr, $bool_tensor:ident, $bool_body:expr;
        $($variant:ident),+
    ) => {
        match $input {
            $(Tensor::$variant($tensor) => Ok(Tensor::$variant($body?)),)+
            Tensor::Bool($bool_tensor) => Ok(Tensor::Bool($bool_body?)),
        }
    };
    ($input:expr, |$tensor:ident| $body:expr, bool |$bool_tensor:ident| $bool_body:expr) => {
        dispatch_tensor_unary_with_bool_special_result!(
            @arms $input, $tensor, $body, $bool_tensor, $bool_body;
            F32, F64, I32, I64, C32, C64
        )
    };
}
73
/// Dispatches a fallible binary operation whose two operands must share a dtype.
///
/// Usage: `dispatch_same_dtype_result!("op", lhs, rhs, |l, r| body)`. When both
/// operands are the same `Tensor` variant, their payloads are bound to `l` and `r`,
/// `body` is evaluated and unwrapped with `?`, and the value is re-wrapped in that
/// variant inside `Ok`. Any other pairing yields `Err(crate::Error::DTypeMismatch)`
/// carrying the op name and both operands' `dtype()`.
///
/// Note that `lhs` and `rhs` are expanded a second time in the mismatch arm, so they
/// should be side-effect-free expressions.
macro_rules! dispatch_same_dtype_result {
    // Internal rule: emits one same-variant arm per listed variant plus the fallback.
    (
        @arms $op:literal, $lhs:expr, $rhs:expr, $lhs_t:ident, $rhs_t:ident, $body:expr;
        $($variant:ident),+
    ) => {
        match ($lhs, $rhs) {
            $(
                (Tensor::$variant($lhs_t), Tensor::$variant($rhs_t)) => {
                    Ok(Tensor::$variant($body?))
                }
            )+
            _ => Err(crate::Error::DTypeMismatch {
                op: $op,
                lhs: $lhs.dtype(),
                rhs: $rhs.dtype(),
            }),
        }
    };
    ($op:literal, $lhs:expr, $rhs:expr, |$lhs_t:ident, $rhs_t:ident| $body:expr) => {
        dispatch_same_dtype_result!(
            @arms $op, $lhs, $rhs, $lhs_t, $rhs_t, $body;
            F32, F64, I32, I64, Bool, C32, C64
        )
    };
}
92
/// Same-dtype binary dispatch that rejects `Bool` operands.
///
/// Usage:
/// `dispatch_same_dtype_without_bool_result!("op", lhs, rhs, "bool message", |l, r| body)`.
/// Matching non-bool variants bind their payloads to `l` and `r`, evaluate `body`,
/// unwrap it with `?`, and re-wrap the value in that variant inside `Ok`. A `Bool`/`Bool`
/// pairing yields `Err(crate::Error::backend_failure(op, bool message))`; any other
/// pairing yields `Err(crate::Error::DTypeMismatch)` with both operands' `dtype()`.
///
/// `lhs` and `rhs` are expanded again in the mismatch arm, so they should be
/// side-effect-free expressions.
macro_rules! dispatch_same_dtype_without_bool_result {
    // Internal rule: one arm per listed variant, then the bool rejection and fallback.
    (
        @arms $op:literal, $lhs:expr, $rhs:expr, $bool_message:literal,
        $lhs_t:ident, $rhs_t:ident, $body:expr;
        $($variant:ident),+
    ) => {
        match ($lhs, $rhs) {
            $(
                (Tensor::$variant($lhs_t), Tensor::$variant($rhs_t)) => {
                    Ok(Tensor::$variant($body?))
                }
            )+
            (Tensor::Bool(_), Tensor::Bool(_)) => {
                Err(crate::Error::backend_failure($op, $bool_message))
            }
            _ => Err(crate::Error::DTypeMismatch {
                op: $op,
                lhs: $lhs.dtype(),
                rhs: $rhs.dtype(),
            }),
        }
    };
    (
        $op:literal, $lhs:expr, $rhs:expr, $bool_message:literal,
        |$lhs_t:ident, $rhs_t:ident| $body:expr
    ) => {
        dispatch_same_dtype_without_bool_result!(
            @arms $op, $lhs, $rhs, $bool_message, $lhs_t, $rhs_t, $body;
            F32, F64, I32, I64, C32, C64
        )
    };
}
113
114fn with_local_pool<T>(f: impl FnOnce(&mut BufferPool) -> T) -> T {
115 let mut buffers = BufferPool::new();
116 f(&mut buffers)
117}
118
/// Advances `index` to the next position in column-major order (axis 0 varies fastest).
///
/// Axis 0 is incremented first; whenever an axis reaches its extent it wraps to zero and
/// the carry moves on to the next axis. Zero-extent axes are pinned to zero and skipped.
/// After the last position the index wraps back to all zeros.
///
/// `index` and `shape` are expected to have equal length (checked in debug builds).
fn advance_col_major_index(index: &mut [usize], shape: &[usize]) {
    debug_assert_eq!(index.len(), shape.len());
    for (axis, slot) in index.iter_mut().enumerate() {
        let extent = shape[axis];
        if extent == 0 {
            *slot = 0;
            continue;
        }
        *slot += 1;
        if *slot < extent {
            // No carry needed: the increment is complete.
            return;
        }
        *slot = 0;
    }
}
133
134fn pooled_filled_tensor<T>(
135 buffers: &mut BufferPool,
136 shape: Vec<usize>,
137 fill: T,
138) -> crate::Result<TypedTensor<T>>
139where
140 T: Copy + Clone + PoolScalar,
141{
142 let mut out = pooled_uninit_tensor(buffers, shape)?;
144 out.host_data_mut()?.fill(fill);
145 Ok(out)
146}
147
148fn clone_host_tensor_from_pool<T>(
149 buffers: &mut BufferPool,
150 op: &'static str,
151 tensor: &TypedTensor<T>,
152) -> crate::Result<TypedTensor<T>>
153where
154 T: Copy + Clone + PoolScalar,
155{
156 let mut out = pooled_uninit_tensor(buffers, tensor.shape().to_vec())?;
158 out.host_data_mut()?
159 .copy_from_slice(typed_host_data(op, tensor)?);
160 Ok(out)
161}
162
163pub fn gather(
164 operand: &Tensor,
165 start_indices: &Tensor,
166 config: &GatherConfig,
167) -> crate::Result<Tensor> {
168 with_local_pool(|buffers| gather_with_pool(buffers, operand, start_indices, config))
169}
170
171pub(crate) fn gather_with_pool(
172 buffers: &mut BufferPool,
173 operand: &Tensor,
174 start_indices: &Tensor,
175 config: &GatherConfig,
176) -> crate::Result<Tensor> {
177 let start_indices = try_index_tensor(start_indices)?;
178 dispatch_tensor_unary_result!(operand, |t| typed_gather(
179 buffers,
180 t,
181 &start_indices,
182 config
183 ))
184}
185
186pub fn scatter(
187 operand: &Tensor,
188 scatter_indices: &Tensor,
189 updates: &Tensor,
190 config: &ScatterConfig,
191) -> crate::Result<Tensor> {
192 with_local_pool(|buffers| scatter_with_pool(buffers, operand, scatter_indices, updates, config))
193}
194
195pub(crate) fn scatter_with_pool(
196 buffers: &mut BufferPool,
197 operand: &Tensor,
198 scatter_indices: &Tensor,
199 updates: &Tensor,
200 config: &ScatterConfig,
201) -> crate::Result<Tensor> {
202 let scatter_indices = try_index_tensor(scatter_indices)?;
203 dispatch_same_dtype_without_bool_result!(
204 "scatter",
205 operand,
206 updates,
207 "Bool data tensors are not supported by additive scatter",
208 |op, upd| typed_scatter(buffers, op, &scatter_indices, upd, config)
209 )
210}
211
212pub(crate) fn try_slice_with_pool(
213 buffers: &mut BufferPool,
214 input: &Tensor,
215 config: &SliceConfig,
216) -> crate::Result<Tensor> {
217 dispatch_tensor_unary_result!(input, |t| typed_slice(buffers, t, config))
218}
219
220pub fn dynamic_slice(
221 input: &Tensor,
222 starts: &Tensor,
223 slice_sizes: &[usize],
224) -> crate::Result<Tensor> {
225 with_local_pool(|buffers| dynamic_slice_with_pool(buffers, input, starts, slice_sizes))
226}
227
228pub(crate) fn dynamic_slice_with_pool(
229 buffers: &mut BufferPool,
230 input: &Tensor,
231 starts: &Tensor,
232 slice_sizes: &[usize],
233) -> crate::Result<Tensor> {
234 let starts = try_index_tensor(starts)?;
235 dispatch_tensor_unary_result!(input, |t| typed_dynamic_slice(
236 buffers,
237 t,
238 &starts,
239 slice_sizes
240 ))
241}
242
243pub fn dynamic_update_slice(
263 operand: &Tensor,
264 update: &Tensor,
265 starts: &Tensor,
266) -> crate::Result<Tensor> {
267 with_local_pool(|buffers| dynamic_update_slice_with_pool(buffers, operand, update, starts))
268}
269
270pub(crate) fn dynamic_update_slice_with_pool(
271 buffers: &mut BufferPool,
272 operand: &Tensor,
273 update: &Tensor,
274 starts: &Tensor,
275) -> crate::Result<Tensor> {
276 let starts = try_index_tensor(starts)?;
277 dispatch_same_dtype_result!("dynamic_update_slice", operand, update, |op, upd| {
278 typed_dynamic_update_slice(buffers, op, upd, &starts)
279 })
280}
281
/// Public padding entry point; forwards directly to [`try_pad`].
///
/// # Errors
/// Returns whatever error [`try_pad`] reports.
pub fn pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
    try_pad(input, config)
}
285
286fn try_pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
287 with_local_pool(|buffers| try_pad_with_pool(buffers, input, config))
288}
289
290pub(crate) fn try_pad_with_pool(
291 buffers: &mut BufferPool,
292 input: &Tensor,
293 config: &PadConfig,
294) -> crate::Result<Tensor> {
295 dispatch_tensor_unary_with_bool_special_result!(
296 input,
297 |t| typed_pad(buffers, t, config),
298 bool | t | typed_pad_with_fill(buffers, t, config, false)
299 )
300}
301
302pub(crate) fn try_concatenate_with_pool(
303 buffers: &mut BufferPool,
304 inputs: &[&Tensor],
305 axis: usize,
306) -> crate::Result<Tensor> {
307 let first = inputs
308 .first()
309 .copied()
310 .ok_or_else(|| crate::Error::InvalidConfig {
311 op: "concatenate",
312 message: "concatenate requires at least one input".into(),
313 })?;
314 dispatch_tensor_unary_result!(first, |t| typed_concatenate_from_dyn_inputs(
315 buffers, t, inputs, axis
316 ))
317}
318
319pub(crate) fn reverse_with_pool(
320 buffers: &mut BufferPool,
321 input: &Tensor,
322 axes: &[usize],
323) -> crate::Result<Tensor> {
324 dispatch_tensor_unary_result!(input, |t| typed_reverse(buffers, t, axes))
325}
326
327fn typed_slice<T: Copy + Clone + PoolScalar>(
328 buffers: &mut BufferPool,
329 input: &TypedTensor<T>,
330 config: &SliceConfig,
331) -> crate::Result<TypedTensor<T>> {
332 let input_shape = input.shape();
333 let rank = input_shape.len();
334 if config.starts.len() != rank {
335 return Err(crate::Error::RankMismatch {
336 op: "slice",
337 expected: rank,
338 actual: config.starts.len(),
339 });
340 }
341 if config.limits.len() != rank {
342 return Err(crate::Error::RankMismatch {
343 op: "slice",
344 expected: rank,
345 actual: config.limits.len(),
346 });
347 }
348 if config.strides.len() != rank {
349 return Err(crate::Error::RankMismatch {
350 op: "slice",
351 expected: rank,
352 actual: config.strides.len(),
353 });
354 }
355
356 let out_shape: Vec<usize> = input
357 .shape()
358 .iter()
359 .enumerate()
360 .map(|(axis, &dim)| {
361 let start = config.starts[axis];
362 let limit = config.limits[axis];
363 let stride = config.strides[axis];
364 if start > limit {
365 return Err(crate::Error::InvalidConfig {
366 op: "slice",
367 message: format!("start exceeds limit on axis {axis}"),
368 });
369 }
370 if limit > dim {
371 return Err(crate::Error::AxisOutOfBounds {
372 op: "slice",
373 axis,
374 rank,
375 });
376 }
377 if stride == 0 {
378 return Err(crate::Error::InvalidConfig {
379 op: "slice",
380 message: format!("stride must be positive on axis {axis}"),
381 });
382 }
383 let span = limit - start;
384 Ok(span.div_ceil(stride))
385 })
386 .collect::<crate::Result<Vec<_>>>()?;
387
388 let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
390 let mut out_idx = vec![0usize; rank];
391 let mut in_idx = vec![0usize; rank];
392
393 for out_value in out.host_data_mut()?.iter_mut() {
394 for axis in 0..rank {
395 in_idx[axis] = config.starts[axis] + out_idx[axis] * config.strides[axis];
396 }
397 *out_value = *input.get(&in_idx)?;
398 advance_col_major_index(&mut out_idx, &out_shape);
399 }
400
401 Ok(out)
402}
403
404fn typed_concatenate_from_dyn_inputs<T>(
405 buffers: &mut BufferPool,
406 _first: &TypedTensor<T>,
407 inputs: &[&Tensor],
408 axis: usize,
409) -> crate::Result<TypedTensor<T>>
410where
411 T: Copy + Clone + PoolScalar,
412 Tensor: TensorAsTyped<T>,
413{
414 let first_dtype = inputs[0].dtype();
415 let typed_inputs = collect_typed_inputs(first_dtype, inputs)?;
416 typed_concatenate(buffers, &typed_inputs, axis)
417}
418
419fn collect_typed_inputs<'a, T>(
420 first_dtype: crate::DType,
421 inputs: &[&'a Tensor],
422) -> crate::Result<Vec<&'a TypedTensor<T>>>
423where
424 Tensor: TensorAsTyped<T>,
425{
426 inputs
427 .iter()
428 .map(|tensor| {
429 TensorAsTyped::<T>::as_typed(*tensor).ok_or_else(|| crate::Error::DTypeMismatch {
430 op: "concatenate",
431 lhs: first_dtype,
432 rhs: tensor.dtype(),
433 })
434 })
435 .collect()
436}
437
438fn typed_concatenate<T: Copy + Clone + PoolScalar>(
439 buffers: &mut BufferPool,
440 inputs: &[&TypedTensor<T>],
441 axis: usize,
442) -> crate::Result<TypedTensor<T>> {
443 let first = inputs[0];
444 let first_shape = first.shape();
445 let rank = first_shape.len();
446 if axis >= rank {
447 return Err(crate::Error::AxisOutOfBounds {
448 op: "concatenate",
449 axis,
450 rank,
451 });
452 }
453
454 let mut out_shape = first_shape.to_vec();
455 let mut axis_extent = 0usize;
456 for input in inputs {
457 let input_shape = input.shape();
458 if input_shape.len() != rank {
459 return Err(crate::Error::RankMismatch {
460 op: "concatenate",
461 expected: rank,
462 actual: input_shape.len(),
463 });
464 }
465 for dim in 0..rank {
466 if dim == axis {
467 axis_extent = axis_extent.checked_add(input_shape[dim]).ok_or_else(|| {
468 crate::Error::InvalidConfig {
469 op: "concatenate",
470 message: "concatenate axis extent overflows usize".to_string(),
471 }
472 })?;
473 } else if input_shape[dim] != first_shape[dim] {
474 return Err(crate::Error::ShapeMismatch {
475 op: "concatenate",
476 lhs: first_shape.to_vec(),
477 rhs: input_shape.to_vec(),
478 });
479 }
480 }
481 }
482 out_shape[axis] = axis_extent;
483
484 let mut segment_ends = Vec::with_capacity(inputs.len());
485 let mut segment_end = 0usize;
486 for input in inputs {
487 segment_end = segment_end
488 .checked_add(input.shape()[axis])
489 .ok_or_else(|| crate::Error::InvalidConfig {
490 op: "concatenate",
491 message: "concatenate segment offset overflows usize".to_string(),
492 })?;
493 segment_ends.push(segment_end);
494 }
495
496 let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
498 let mut out_idx = vec![0usize; rank];
499 let mut in_idx = vec![0usize; rank];
500
501 for out_value in out.host_data_mut()?.iter_mut() {
502 let concat_idx = out_idx[axis];
503 let input_pos = segment_ends.partition_point(|&end| concat_idx >= end);
504 if input_pos == segment_ends.len() {
505 return Err(crate::Error::InvalidConfig {
506 op: "concatenate",
507 message: "output index must map to an input".to_string(),
508 });
509 }
510 let axis_base = if input_pos == 0 {
511 0
512 } else {
513 segment_ends[input_pos - 1]
514 };
515
516 in_idx.copy_from_slice(&out_idx);
517 in_idx[axis] -= axis_base;
518 *out_value = *inputs[input_pos].get(&in_idx)?;
519 advance_col_major_index(&mut out_idx, &out_shape);
520 }
521
522 Ok(out)
523}
524
525fn typed_reverse<T: Copy + Clone + PoolScalar>(
526 buffers: &mut BufferPool,
527 input: &TypedTensor<T>,
528 axes: &[usize],
529) -> crate::Result<TypedTensor<T>> {
530 let input_shape = input.shape();
531 let rank = input_shape.len();
532 let mut reverse_axis = vec![false; rank];
533 for &axis in axes {
534 if axis >= rank {
535 return Err(crate::Error::AxisOutOfBounds {
536 op: "reverse",
537 axis,
538 rank,
539 });
540 }
541 reverse_axis[axis] = true;
542 }
543
544 let mut out = pooled_uninit_tensor(buffers, input_shape.to_vec())?;
546 let mut out_idx = vec![0usize; rank];
547 let mut in_idx = vec![0usize; rank];
548
549 for out_value in out.host_data_mut()?.iter_mut() {
550 for axis in 0..rank {
551 in_idx[axis] = if reverse_axis[axis] {
552 input_shape[axis] - 1 - out_idx[axis]
553 } else {
554 out_idx[axis]
555 };
556 }
557 *out_value = *input.get(&in_idx)?;
558 advance_col_major_index(&mut out_idx, input_shape);
559 }
560
561 Ok(out)
562}
563
/// Index tensor normalized to `i64` values, as produced by `try_index_tensor`.
struct IndexTensor {
    /// Shape copied from the source tensor.
    shape: Vec<usize>,
    /// Flat element values; addressed through `linear_offset` over `shape`, i.e. with
    /// axis 0 varying fastest.
    values: Vec<i64>,
}
568
/// 2^24: magnitude bound used when accepting `f32` values as indices. Every integer of
/// magnitude up to this bound is exactly representable in `f32`.
const F32_MAX_EXACT_INT: f32 = 16_777_216.0;
/// 2^53: magnitude bound used when accepting `f64` values as indices. Every integer of
/// magnitude up to this bound is exactly representable in `f64`.
const F64_MAX_EXACT_INT: f64 = 9_007_199_254_740_992.0;
573
574fn f32_index_to_i64(value: f32) -> crate::Result<i64> {
575 if !value.is_finite() || value.fract() != 0.0 || value.abs() > F32_MAX_EXACT_INT {
576 return Err(crate::Error::InvalidConfig {
577 op: "index_tensor",
578 message: format!("index value {value} is not an exactly representable i64"),
579 });
580 }
581 Ok(value as i64)
582}
583
584fn f64_index_to_i64(value: f64) -> crate::Result<i64> {
585 if !value.is_finite() || value.fract() != 0.0 || value.abs() > F64_MAX_EXACT_INT {
586 return Err(crate::Error::InvalidConfig {
587 op: "index_tensor",
588 message: format!("index value {value} is not an exactly representable i64"),
589 });
590 }
591 Ok(value as i64)
592}
593
594fn try_index_tensor(tensor: &Tensor) -> crate::Result<IndexTensor> {
595 match tensor {
596 Tensor::I32(t) => Ok(IndexTensor {
597 shape: t.shape().to_vec(),
598 values: typed_host_data("index_tensor", t)?
599 .iter()
600 .map(|&value| value as i64)
601 .collect(),
602 }),
603 Tensor::I64(t) => Ok(IndexTensor {
604 shape: t.shape().to_vec(),
605 values: typed_host_data("index_tensor", t)?.to_vec(),
606 }),
607 Tensor::F32(t) => {
608 let values: crate::Result<Vec<i64>> = typed_host_data("index_tensor", t)?
609 .iter()
610 .map(|&value| f32_index_to_i64(value))
611 .collect();
612 Ok(IndexTensor {
613 shape: t.shape().to_vec(),
614 values: values?,
615 })
616 }
617 Tensor::F64(t) => {
618 let values: crate::Result<Vec<i64>> = typed_host_data("index_tensor", t)?
619 .iter()
620 .map(|&value| f64_index_to_i64(value))
621 .collect();
622 Ok(IndexTensor {
623 shape: t.shape().to_vec(),
624 values: values?,
625 })
626 }
627 Tensor::Bool(_) => Err(crate::Error::InvalidConfig {
628 op: "index_tensor",
629 message: "bool index tensors are not supported".into(),
630 }),
631 Tensor::C32(_) | Tensor::C64(_) => Err(crate::Error::InvalidConfig {
632 op: "index_tensor",
633 message: "complex index tensors are not supported".into(),
634 }),
635 }
636}
637
638fn checked_product(op: &'static str, role: &'static str, shape: &[usize]) -> crate::Result<usize> {
639 shape.iter().try_fold(1usize, |acc, &dim| {
640 acc.checked_mul(dim)
641 .ok_or_else(|| crate::Error::InvalidConfig {
642 op,
643 message: format!("{role} element count overflows usize"),
644 })
645 })
646}
647
648fn linear_offset(op: &'static str, shape: &[usize], indices: &[usize]) -> crate::Result<usize> {
649 let mut offset = 0usize;
650 let mut stride = 1usize;
651 for (axis, &index) in indices.iter().enumerate() {
652 let scaled = index
653 .checked_mul(stride)
654 .ok_or_else(|| crate::Error::InvalidConfig {
655 op,
656 message: format!("linear index component overflows usize on axis {axis}"),
657 })?;
658 offset = offset
659 .checked_add(scaled)
660 .ok_or_else(|| crate::Error::InvalidConfig {
661 op,
662 message: format!("linear offset overflows usize on axis {axis}"),
663 })?;
664 stride = stride
665 .checked_mul(shape[axis])
666 .ok_or_else(|| crate::Error::InvalidConfig {
667 op,
668 message: format!("linear stride overflows usize after axis {axis}"),
669 })?;
670 }
671 Ok(offset)
672}
673
674fn try_index_vector_size(
675 op: &'static str,
676 shape: &[usize],
677 index_vector_dim: usize,
678) -> crate::Result<usize> {
679 if index_vector_dim > shape.len() {
680 return Err(crate::Error::AxisOutOfBounds {
681 op,
682 axis: index_vector_dim,
683 rank: shape.len(),
684 });
685 }
686 Ok(if index_vector_dim == shape.len() {
687 1
688 } else {
689 shape[index_vector_dim]
690 })
691}
692
693fn try_index_batch_shape(
694 op: &'static str,
695 shape: &[usize],
696 index_vector_dim: usize,
697) -> crate::Result<Vec<usize>> {
698 if index_vector_dim > shape.len() {
699 return Err(crate::Error::AxisOutOfBounds {
700 op,
701 axis: index_vector_dim,
702 rank: shape.len(),
703 });
704 }
705 if index_vector_dim == shape.len() {
706 return Ok(shape.to_vec());
707 }
708 Ok(shape
709 .iter()
710 .enumerate()
711 .filter_map(|(axis, &dim)| (axis != index_vector_dim).then_some(dim))
712 .collect())
713}
714
715fn index_component(
716 op: &'static str,
717 indices: &IndexTensor,
718 batch_idx: &[usize],
719 index_vector_dim: usize,
720 component: usize,
721 index_scratch: &mut [usize],
722) -> crate::Result<i64> {
723 if index_vector_dim == indices.shape.len() {
724 if component != 0 {
725 return Err(crate::Error::InvalidConfig {
726 op,
727 message: "implicit index_vector_dim only supports scalar index vectors".into(),
728 });
729 }
730 return Ok(indices.values[linear_offset(op, &indices.shape, batch_idx)?]);
731 }
732
733 debug_assert_eq!(index_scratch.len(), indices.shape.len());
734 let mut batch_axis = 0usize;
735 for (axis, slot) in index_scratch.iter_mut().enumerate() {
736 if axis == index_vector_dim {
737 *slot = component;
738 } else {
739 *slot = batch_idx[batch_axis];
740 batch_axis += 1;
741 }
742 }
743 Ok(indices.values[linear_offset(op, &indices.shape, index_scratch)?])
744}
745
746fn clamp_window_start(
747 op: &'static str,
748 start: i64,
749 dim_size: usize,
750 window_size: usize,
751) -> crate::Result<usize> {
752 if window_size > dim_size {
753 return Err(crate::Error::InvalidConfig {
754 op,
755 message: format!("window size {window_size} exceeds dimension size {dim_size}"),
756 });
757 }
758 let max_start = dim_size.saturating_sub(window_size) as i64;
759 Ok(start.clamp(0, max_start) as usize)
760}
761
/// Lists, in ascending order, the axes in `0..rank` that do not appear in
/// `collapsed_or_inserted`.
fn operand_window_dims(rank: usize, collapsed_or_inserted: &[usize]) -> Vec<usize> {
    let mut window_dims = Vec::new();
    for dim in 0..rank {
        if !collapsed_or_inserted.contains(&dim) {
            window_dims.push(dim);
        }
    }
    window_dims
}
767
768fn typed_gather<T: Copy + Clone + PoolScalar>(
769 buffers: &mut BufferPool,
770 operand: &TypedTensor<T>,
771 start_indices: &IndexTensor,
772 config: &GatherConfig,
773) -> crate::Result<TypedTensor<T>> {
774 let operand_shape = operand.shape();
775 let rank = operand_shape.len();
776 if config.slice_sizes.len() != rank {
777 return Err(crate::Error::RankMismatch {
778 op: "gather",
779 expected: rank,
780 actual: config.slice_sizes.len(),
781 });
782 }
783
784 for &dim in &config.collapsed_slice_dims {
785 if dim >= rank {
786 return Err(crate::Error::AxisOutOfBounds {
787 op: "gather",
788 axis: dim,
789 rank,
790 });
791 }
792 }
793 {
794 let mut seen = vec![false; rank];
795 for &dim in &config.collapsed_slice_dims {
796 if seen[dim] {
797 return Err(crate::Error::DuplicateAxis {
798 op: "gather",
799 axis: dim,
800 role: "collapsed_slice_dims",
801 });
802 }
803 seen[dim] = true;
804 }
805 }
806 for &dim in &config.collapsed_slice_dims {
807 if config.slice_sizes[dim] != 1 {
808 return Err(crate::Error::InvalidConfig {
809 op: "gather",
810 message: format!(
811 "collapsed slice dimension {dim} must have slice_size == 1, got {}",
812 config.slice_sizes[dim]
813 ),
814 });
815 }
816 }
817
818 let index_size =
819 try_index_vector_size("gather", &start_indices.shape, config.index_vector_dim)?;
820 if index_size != config.start_index_map.len() {
821 return Err(crate::Error::InvalidConfig {
822 op: "gather",
823 message: format!(
824 "start_index_map length {} does not match index vector size {}",
825 config.start_index_map.len(),
826 index_size
827 ),
828 });
829 }
830 for &operand_dim in &config.start_index_map {
831 if operand_dim >= rank {
832 return Err(crate::Error::AxisOutOfBounds {
833 op: "gather",
834 axis: operand_dim,
835 rank,
836 });
837 }
838 }
839 {
840 let mut seen = vec![false; rank];
841 for &operand_dim in &config.start_index_map {
842 if seen[operand_dim] {
843 return Err(crate::Error::DuplicateAxis {
844 op: "gather",
845 axis: operand_dim,
846 role: "start_index_map",
847 });
848 }
849 seen[operand_dim] = true;
850 }
851 }
852
853 let window_dims = operand_window_dims(rank, &config.collapsed_slice_dims);
854 if config.offset_dims.len() != window_dims.len() {
855 return Err(crate::Error::InvalidConfig {
856 op: "gather",
857 message: format!(
858 "offset_dims length {} does not match window dims count {}",
859 config.offset_dims.len(),
860 window_dims.len()
861 ),
862 });
863 }
864
865 let batch_shape =
866 try_index_batch_shape("gather", &start_indices.shape, config.index_vector_dim)?;
867 let out_rank = batch_shape.len() + config.offset_dims.len();
868 for &out_axis in &config.offset_dims {
869 if out_axis >= out_rank {
870 return Err(crate::Error::AxisOutOfBounds {
871 op: "gather",
872 axis: out_axis,
873 rank: out_rank,
874 });
875 }
876 }
877 {
878 let mut seen = vec![false; out_rank];
879 for &out_axis in &config.offset_dims {
880 if seen[out_axis] {
881 return Err(crate::Error::DuplicateAxis {
882 op: "gather",
883 axis: out_axis,
884 role: "offset_dims",
885 });
886 }
887 seen[out_axis] = true;
888 }
889 }
890
891 let mut out_axis_to_operand_dim = vec![None; out_rank];
892 for (offset_axis, &out_axis) in config.offset_dims.iter().enumerate() {
893 out_axis_to_operand_dim[out_axis] = Some(window_dims[offset_axis]);
894 }
895
896 let mut out_shape = vec![0usize; out_rank];
897 let mut batch_axis = 0usize;
898 for out_axis in 0..out_rank {
899 if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
900 out_shape[out_axis] = config.slice_sizes[operand_dim];
901 } else {
902 out_shape[out_axis] = batch_shape[batch_axis];
903 batch_axis += 1;
904 }
905 }
906
907 for (component, &operand_dim) in config.start_index_map.iter().enumerate() {
908 let _ = clamp_window_start(
909 "gather",
910 0,
911 operand_shape[operand_dim],
912 config.slice_sizes[operand_dim],
913 )?;
914 let _ = component;
915 }
916
917 let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
919 let mut out_idx = vec![0usize; out_rank];
920 let mut batch_idx = vec![0usize; batch_shape.len()];
921 let mut operand_idx = vec![0usize; rank];
922 let mut window_offsets = vec![0usize; rank];
923 let mut index_scratch = vec![0usize; start_indices.shape.len()];
924
925 for out_value in out.host_data_mut()?.iter_mut() {
926 batch_axis = 0;
927 window_offsets.fill(0);
928 for out_axis in 0..out_rank {
929 if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
930 window_offsets[operand_dim] = out_idx[out_axis];
931 } else {
932 batch_idx[batch_axis] = out_idx[out_axis];
933 batch_axis += 1;
934 }
935 }
936
937 operand_idx.fill(0);
938 for (component, &operand_dim) in config.start_index_map.iter().enumerate() {
939 let start = index_component(
940 "gather",
941 start_indices,
942 &batch_idx,
943 config.index_vector_dim,
944 component,
945 &mut index_scratch,
946 )?;
947 operand_idx[operand_dim] = clamp_window_start(
948 "gather",
949 start,
950 operand_shape[operand_dim],
951 config.slice_sizes[operand_dim],
952 )?;
953 }
954
955 for axis in 0..operand_idx.len() {
956 operand_idx[axis] += window_offsets[axis];
957 }
958
959 *out_value = *operand.get(&operand_idx)?;
960 advance_col_major_index(&mut out_idx, &out_shape);
961 }
962
963 Ok(out)
964}
965
966fn typed_scatter<T>(
967 buffers: &mut BufferPool,
968 operand: &TypedTensor<T>,
969 scatter_indices: &IndexTensor,
970 updates: &TypedTensor<T>,
971 config: &ScatterConfig,
972) -> crate::Result<TypedTensor<T>>
973where
974 T: Copy + Clone + Zero + Add<Output = T> + PoolScalar,
975{
976 let operand_shape = operand.shape();
977 let updates_shape = updates.shape();
978 let op_rank = operand_shape.len();
979 for &dim in &config.inserted_window_dims {
980 if dim >= op_rank {
981 return Err(crate::Error::AxisOutOfBounds {
982 op: "scatter",
983 axis: dim,
984 rank: op_rank,
985 });
986 }
987 }
988 {
989 let mut seen = vec![false; op_rank];
990 for &dim in &config.inserted_window_dims {
991 if seen[dim] {
992 return Err(crate::Error::DuplicateAxis {
993 op: "scatter",
994 axis: dim,
995 role: "inserted_window_dims",
996 });
997 }
998 seen[dim] = true;
999 }
1000 }
1001
1002 let index_size =
1003 try_index_vector_size("scatter", &scatter_indices.shape, config.index_vector_dim)?;
1004 if index_size != config.scatter_dims_to_operand_dims.len() {
1005 return Err(crate::Error::InvalidConfig {
1006 op: "scatter",
1007 message: format!(
1008 "scatter_dims_to_operand_dims length {} does not match index vector size {}",
1009 config.scatter_dims_to_operand_dims.len(),
1010 index_size
1011 ),
1012 });
1013 }
1014 for &operand_dim in &config.scatter_dims_to_operand_dims {
1015 if operand_dim >= op_rank {
1016 return Err(crate::Error::AxisOutOfBounds {
1017 op: "scatter",
1018 axis: operand_dim,
1019 rank: op_rank,
1020 });
1021 }
1022 }
1023 {
1024 let mut seen = vec![false; op_rank];
1025 for &operand_dim in &config.scatter_dims_to_operand_dims {
1026 if seen[operand_dim] {
1027 return Err(crate::Error::DuplicateAxis {
1028 op: "scatter",
1029 axis: operand_dim,
1030 role: "scatter_dims_to_operand_dims",
1031 });
1032 }
1033 seen[operand_dim] = true;
1034 }
1035 }
1036
1037 let batch_shape =
1038 try_index_batch_shape("scatter", &scatter_indices.shape, config.index_vector_dim)?;
1039 let window_dims = operand_window_dims(op_rank, &config.inserted_window_dims);
1040 if config.update_window_dims.len() != window_dims.len() {
1041 return Err(crate::Error::InvalidConfig {
1042 op: "scatter",
1043 message: format!(
1044 "update_window_dims length {} does not match window dims count {}",
1045 config.update_window_dims.len(),
1046 window_dims.len()
1047 ),
1048 });
1049 }
1050
1051 let update_rank = updates_shape.len();
1052 let expected_batch_rank = update_rank
1053 .checked_sub(config.update_window_dims.len())
1054 .ok_or_else(|| crate::Error::InvalidConfig {
1055 op: "scatter",
1056 message: format!(
1057 "update_window_dims length {} exceeds update rank {}",
1058 config.update_window_dims.len(),
1059 update_rank
1060 ),
1061 })?;
1062 if expected_batch_rank != batch_shape.len() {
1063 return Err(crate::Error::InvalidConfig {
1064 op: "scatter",
1065 message: format!(
1066 "updates batch rank {} does not match index batch shape length {}",
1067 expected_batch_rank,
1068 batch_shape.len()
1069 ),
1070 });
1071 }
1072
1073 for &axis in &config.update_window_dims {
1074 if axis >= update_rank {
1075 return Err(crate::Error::AxisOutOfBounds {
1076 op: "scatter",
1077 axis,
1078 rank: update_rank,
1079 });
1080 }
1081 }
1082 {
1083 let mut seen = vec![false; update_rank];
1084 for &axis in &config.update_window_dims {
1085 if seen[axis] {
1086 return Err(crate::Error::DuplicateAxis {
1087 op: "scatter",
1088 axis,
1089 role: "update_window_dims",
1090 });
1091 }
1092 seen[axis] = true;
1093 }
1094 }
1095
1096 let mut is_update_window_dim = vec![false; update_rank];
1097 for &axis in &config.update_window_dims {
1098 is_update_window_dim[axis] = true;
1099 }
1100
1101 {
1102 let mut batch_axis = 0usize;
1103 for axis in 0..update_rank {
1104 if !is_update_window_dim[axis] {
1105 if updates_shape[axis] != batch_shape[batch_axis] {
1106 return Err(crate::Error::InvalidConfig {
1107 op: "scatter",
1108 message: format!(
1109 "updates batch dim {} extent {} does not match index batch dim {} extent {}",
1110 axis,
1111 updates_shape[axis],
1112 batch_axis,
1113 batch_shape[batch_axis]
1114 ),
1115 });
1116 }
1117 batch_axis += 1;
1118 }
1119 }
1120 }
1121
1122 let mut window_shape = vec![1usize; op_rank];
1123 let mut window_shape_updates = vec![0usize; config.update_window_dims.len()];
1124 for (pos, &update_axis) in config.update_window_dims.iter().enumerate() {
1125 let dim = updates_shape[update_axis];
1126 window_shape_updates[pos] = dim;
1127 window_shape[window_dims[pos]] = dim;
1128 }
1129
1130 let batch_elems = checked_product("scatter", "batch shape", &batch_shape)?;
1131 let window_elems = checked_product("scatter", "window update shape", &window_shape_updates)?;
1132 let mut out = clone_host_tensor_from_pool(buffers, "scatter", operand)?;
1133
1134 let mut batch_idx = vec![0usize; batch_shape.len()];
1135 let mut window_idx = vec![0usize; window_shape_updates.len()];
1136 let mut update_idx = vec![0usize; update_rank];
1137 let mut operand_base = vec![0usize; op_rank];
1138 let mut operand_idx = vec![0usize; op_rank];
1139 let mut index_scratch = vec![0usize; scatter_indices.shape.len()];
1140
1141 for _ in 0..batch_elems {
1142 let mut window_fits = true;
1143 operand_base.fill(0);
1144 for (component, &operand_dim) in config.scatter_dims_to_operand_dims.iter().enumerate() {
1145 let start = index_component(
1146 "scatter",
1147 scatter_indices,
1148 &batch_idx,
1149 config.index_vector_dim,
1150 component,
1151 &mut index_scratch,
1152 )?;
1153 if start < 0 {
1154 window_fits = false;
1155 break;
1156 }
1157 operand_base[operand_dim] = start as usize;
1158 }
1159 if !window_fits {
1160 advance_col_major_index(&mut batch_idx, &batch_shape);
1161 continue;
1162 }
1163
1164 for axis in 0..op_rank {
1165 if operand_base[axis] + window_shape[axis] > operand_shape[axis] {
1166 window_fits = false;
1167 break;
1168 }
1169 }
1170 if !window_fits {
1171 advance_col_major_index(&mut batch_idx, &batch_shape);
1172 continue;
1173 }
1174
1175 window_idx.fill(0);
1176 for _ in 0..window_elems {
1177 let mut batch_axis = 0usize;
1178 let mut window_axis = 0usize;
1179 for axis in 0..update_rank {
1180 if is_update_window_dim[axis] {
1181 update_idx[axis] = window_idx[window_axis];
1182 window_axis += 1;
1183 } else {
1184 update_idx[axis] = batch_idx[batch_axis];
1185 batch_axis += 1;
1186 }
1187 }
1188
1189 operand_idx.copy_from_slice(&operand_base);
1190 for (window_axis, &operand_axis) in window_dims.iter().enumerate() {
1191 operand_idx[operand_axis] += window_idx[window_axis];
1192 }
1193
1194 let value = *updates.get(&update_idx)?;
1195 let slot = out.get_mut(&operand_idx)?;
1196 *slot = *slot + value;
1197 advance_col_major_index(&mut window_idx, &window_shape_updates);
1198 }
1199 advance_col_major_index(&mut batch_idx, &batch_shape);
1200 }
1201
1202 Ok(out)
1203}
1204
1205fn typed_dynamic_slice<T: Copy + Clone + PoolScalar>(
1206 buffers: &mut BufferPool,
1207 input: &TypedTensor<T>,
1208 starts: &IndexTensor,
1209 slice_sizes: &[usize],
1210) -> crate::Result<TypedTensor<T>> {
1211 let input_shape = input.shape();
1212 if slice_sizes.len() != input_shape.len() {
1213 return Err(crate::Error::RankMismatch {
1214 op: "dynamic_slice",
1215 expected: input_shape.len(),
1216 actual: slice_sizes.len(),
1217 });
1218 }
1219 if starts.shape.len() != 1 {
1220 return Err(crate::Error::InvalidConfig {
1221 op: "dynamic_slice",
1222 message: "starts must be a rank-1 tensor".into(),
1223 });
1224 }
1225 if starts.values.len() != input_shape.len() {
1226 return Err(crate::Error::InvalidConfig {
1227 op: "dynamic_slice",
1228 message: format!(
1229 "starts length {} must match input rank {}",
1230 starts.values.len(),
1231 input_shape.len()
1232 ),
1233 });
1234 }
1235
1236 let mut clamped_starts = vec![0usize; input_shape.len()];
1237 for axis in 0..input_shape.len() {
1238 clamped_starts[axis] = clamp_window_start(
1239 "dynamic_slice",
1240 starts.values[axis],
1241 input_shape[axis],
1242 slice_sizes[axis],
1243 )?;
1244 }
1245
1246 let out_shape = slice_sizes.to_vec();
1247 let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
1249 let mut out_idx = vec![0usize; out_shape.len()];
1250 let mut input_idx = vec![0usize; out_shape.len()];
1251
1252 for out_value in out.host_data_mut()?.iter_mut() {
1253 for axis in 0..out_shape.len() {
1254 input_idx[axis] = clamped_starts[axis] + out_idx[axis];
1255 }
1256 *out_value = *input.get(&input_idx)?;
1257 advance_col_major_index(&mut out_idx, &out_shape);
1258 }
1259
1260 Ok(out)
1261}
1262
1263fn typed_dynamic_update_slice<T: Copy + Clone + PoolScalar>(
1264 buffers: &mut BufferPool,
1265 operand: &TypedTensor<T>,
1266 update: &TypedTensor<T>,
1267 starts: &IndexTensor,
1268) -> crate::Result<TypedTensor<T>> {
1269 let operand_shape = operand.shape();
1270 let update_shape = update.shape();
1271 if update_shape.len() != operand_shape.len() {
1272 return Err(crate::Error::RankMismatch {
1273 op: "dynamic_update_slice",
1274 expected: operand_shape.len(),
1275 actual: update_shape.len(),
1276 });
1277 }
1278 if starts.shape.len() != 1 {
1279 return Err(crate::Error::InvalidConfig {
1280 op: "dynamic_update_slice",
1281 message: "starts must be a rank-1 tensor".into(),
1282 });
1283 }
1284 if starts.values.len() != operand_shape.len() {
1285 return Err(crate::Error::InvalidConfig {
1286 op: "dynamic_update_slice",
1287 message: format!(
1288 "starts length {} must match operand rank {}",
1289 starts.values.len(),
1290 operand_shape.len()
1291 ),
1292 });
1293 }
1294
1295 let mut clamped_starts = vec![0usize; operand_shape.len()];
1296 for axis in 0..operand_shape.len() {
1297 clamped_starts[axis] = clamp_window_start(
1298 "dynamic_update_slice",
1299 starts.values[axis],
1300 operand_shape[axis],
1301 update_shape[axis],
1302 )?;
1303 }
1304
1305 let mut out = clone_host_tensor_from_pool(buffers, "dynamic_update_slice", operand)?;
1306 let mut update_idx = vec![0usize; update_shape.len()];
1307 let mut operand_idx = vec![0usize; operand_shape.len()];
1308
1309 for update_value in update.as_slice()? {
1310 for axis in 0..update_shape.len() {
1311 operand_idx[axis] = clamped_starts[axis] + update_idx[axis];
1312 }
1313 *out.get_mut(&operand_idx)? = *update_value;
1314 advance_col_major_index(&mut update_idx, update_shape);
1315 }
1316
1317 Ok(out)
1318}
1319
1320fn typed_pad<T: Copy + Clone + Zero + PoolScalar>(
1321 buffers: &mut BufferPool,
1322 input: &TypedTensor<T>,
1323 config: &PadConfig,
1324) -> crate::Result<TypedTensor<T>> {
1325 typed_pad_with_fill(buffers, input, config, T::zero())
1326}
1327
1328fn typed_pad_with_fill<T: Copy + Clone + PoolScalar>(
1329 buffers: &mut BufferPool,
1330 input: &TypedTensor<T>,
1331 config: &PadConfig,
1332 fill: T,
1333) -> crate::Result<TypedTensor<T>> {
1334 let input_shape = input.shape();
1335 let rank = input_shape.len();
1336 if config.edge_padding_low.len() != rank {
1337 return Err(crate::Error::RankMismatch {
1338 op: "pad",
1339 expected: rank,
1340 actual: config.edge_padding_low.len(),
1341 });
1342 }
1343 if config.edge_padding_high.len() != rank {
1344 return Err(crate::Error::RankMismatch {
1345 op: "pad",
1346 expected: rank,
1347 actual: config.edge_padding_high.len(),
1348 });
1349 }
1350 if config.interior_padding.len() != rank {
1351 return Err(crate::Error::RankMismatch {
1352 op: "pad",
1353 expected: rank,
1354 actual: config.interior_padding.len(),
1355 });
1356 }
1357
1358 let mut out_shape = Vec::with_capacity(input_shape.len());
1359 for (axis, &input_extent) in input_shape.iter().enumerate() {
1360 if config.interior_padding[axis] < 0 {
1361 return Err(crate::Error::InvalidConfig {
1362 op: "pad",
1363 message: format!("interior padding must be non-negative on axis {axis}"),
1364 });
1365 }
1366 if config.edge_padding_low[axis] < 0 || config.edge_padding_high[axis] < 0 {
1367 return Err(crate::Error::InvalidConfig {
1368 op: "pad",
1369 message: format!("edge padding must be non-negative on axis {axis}"),
1370 });
1371 }
1372 let input_extent_i64 =
1373 i64::try_from(input_extent).map_err(|_| crate::Error::InvalidConfig {
1374 op: "pad",
1375 message: format!("input extent on axis {axis} does not fit in i64"),
1376 })?;
1377 let spacing = config.interior_padding[axis]
1378 .checked_add(1)
1379 .ok_or_else(|| crate::Error::InvalidConfig {
1380 op: "pad",
1381 message: format!("interior padding overflow on axis {axis}"),
1382 })?;
1383 let base = if input_extent == 0 {
1384 0
1385 } else {
1386 input_extent_i64
1387 .checked_sub(1)
1388 .and_then(|extent| extent.checked_mul(spacing))
1389 .and_then(|extent| extent.checked_add(1))
1390 .ok_or_else(|| crate::Error::InvalidConfig {
1391 op: "pad",
1392 message: format!("padded input extent overflow on axis {axis}"),
1393 })?
1394 };
1395 let dim = config.edge_padding_low[axis]
1396 .checked_add(config.edge_padding_high[axis])
1397 .and_then(|edge| edge.checked_add(base))
1398 .ok_or_else(|| crate::Error::InvalidConfig {
1399 op: "pad",
1400 message: format!("output dimension overflow on axis {axis}"),
1401 })?;
1402 out_shape.push(
1403 usize::try_from(dim).map_err(|_| crate::Error::InvalidConfig {
1404 op: "pad",
1405 message: format!("negative output dimension on axis {axis}"),
1406 })?,
1407 );
1408 }
1409
1410 let mut out = pooled_filled_tensor(buffers, out_shape.clone(), fill)?;
1411 let mut input_idx = vec![0usize; input_shape.len()];
1412 let mut out_idx = vec![0usize; input_shape.len()];
1413
1414 for input_value in input.as_slice()? {
1415 let mut in_bounds = true;
1416 for axis in 0..input_shape.len() {
1417 let out_pos = config.edge_padding_low[axis]
1418 + input_idx[axis] as i64 * (config.interior_padding[axis] + 1);
1419 if !(0..out_shape[axis] as i64).contains(&out_pos) {
1420 in_bounds = false;
1421 break;
1422 }
1423 out_idx[axis] = out_pos as usize;
1424 }
1425 if in_bounds {
1426 *out.get_mut(&out_idx)? = *input_value;
1427 }
1428 advance_col_major_index(&mut input_idx, input_shape);
1429 }
1430
1431 Ok(out)
1432}