Skip to main content

tenferro_cpu/
indexing.rs

1use std::ops::Add;
2
3use num_traits::Zero;
4
5use super::indexing_alloc::pooled_uninit_tensor;
6use super::typed_host_data;
7use crate::buffer_pool::{BufferPool, PoolScalar};
8use tenferro_tensor::{GatherConfig, PadConfig, ScatterConfig, SliceConfig};
9use tenferro_tensor::{Tensor, TypedTensor};
10
11// Indexing-family kernels stay as dedicated sequential loops for now. Their
12// per-output gather/scatter/slice/pad/concatenate/reverse index transforms do
13// not currently map to a strided-kernel or backend-native parallel primitive.
14// Backend entrypoints still run these loops inside CpuContext::install, so a
15// future parallel implementation can use the same CPU threading policy.
16
/// Borrows a dynamically typed [`Tensor`] as a `TypedTensor<T>` for one
/// specific element type `T`.
trait TensorAsTyped<T> {
    /// Returns the inner typed tensor when `self` stores elements of type `T`,
    /// and `None` when it stores any other element type.
    fn as_typed(&self) -> Option<&TypedTensor<T>>;
}
20
// Generates one `TensorAsTyped<$ty>` impl for `Tensor` per `(type, variant)`
// pair. Each impl returns `Some` only for the named variant and `None` for
// every other variant.
macro_rules! impl_tensor_as_typed {
    ($(($ty:ty, $variant:ident)),+ $(,)?) => {
        $(
            impl TensorAsTyped<$ty> for Tensor {
                fn as_typed(&self) -> Option<&TypedTensor<$ty>> {
                    match self {
                        Tensor::$variant(tensor) => Some(tensor),
                        _ => None,
                    }
                }
            }
        )+
    };
}

// One impl per `Tensor` variant that the dispatch macros in this file match on.
impl_tensor_as_typed!(
    (f32, F32),
    (f64, F64),
    (i32, I32),
    (i64, I64),
    (bool, Bool),
    (num_complex::Complex<f32>, C32),
    (num_complex::Complex<f64>, C64),
);
45
// Dispatches a single-tensor kernel over every `Tensor` variant.
//
// `$input` is matched against each variant, the inner typed tensor is bound to
// `$tensor`, and `$body` is evaluated with that binding. `$body` must be
// fallible: its error is propagated with `?` out of the enclosing function,
// and its success value is re-wrapped in the same variant as the input.
macro_rules! dispatch_tensor_unary_result {
    ($input:expr, |$tensor:ident| $body:expr) => {
        match $input {
            Tensor::F32($tensor) => Ok(Tensor::F32($body?)),
            Tensor::F64($tensor) => Ok(Tensor::F64($body?)),
            Tensor::I32($tensor) => Ok(Tensor::I32($body?)),
            Tensor::I64($tensor) => Ok(Tensor::I64($body?)),
            Tensor::Bool($tensor) => Ok(Tensor::Bool($body?)),
            Tensor::C32($tensor) => Ok(Tensor::C32($body?)),
            Tensor::C64($tensor) => Ok(Tensor::C64($body?)),
        }
    };
}
59
// Same per-variant dispatch as a plain unary dispatch, except that the `Bool`
// variant runs its own `$bool_body` (bound to `$bool_tensor`) instead of the
// shared `$body`. This lets a kernel whose generic body cannot be instantiated
// for `bool` supply a dedicated implementation.
//
// Both bodies must be fallible; errors are propagated with `?` and success
// values are re-wrapped in the variant of the input.
macro_rules! dispatch_tensor_unary_with_bool_special_result {
    ($input:expr, |$tensor:ident| $body:expr, bool |$bool_tensor:ident| $bool_body:expr) => {
        match $input {
            Tensor::F32($tensor) => Ok(Tensor::F32($body?)),
            Tensor::F64($tensor) => Ok(Tensor::F64($body?)),
            Tensor::I32($tensor) => Ok(Tensor::I32($body?)),
            Tensor::I64($tensor) => Ok(Tensor::I64($body?)),
            Tensor::Bool($bool_tensor) => Ok(Tensor::Bool($bool_body?)),
            Tensor::C32($tensor) => Ok(Tensor::C32($body?)),
            Tensor::C64($tensor) => Ok(Tensor::C64($body?)),
        }
    };
}
73
// Dispatches a two-tensor kernel that requires both operands to hold the same
// variant.
//
// When `$lhs` and `$rhs` are the same variant, their inner typed tensors are
// bound to `$lhs_t` / `$rhs_t`, `$body` is evaluated (errors propagated with
// `?`), and the result is wrapped in that variant. Any other combination
// yields `Error::DTypeMismatch` tagged with `$op`.
//
// Note: `$lhs` and `$rhs` are expanded a second time in the mismatch arm (for
// `.dtype()`), so callers should pass plain places rather than expressions
// with side effects.
macro_rules! dispatch_same_dtype_result {
    ($op:literal, $lhs:expr, $rhs:expr, |$lhs_t:ident, $rhs_t:ident| $body:expr) => {
        match ($lhs, $rhs) {
            (Tensor::F32($lhs_t), Tensor::F32($rhs_t)) => Ok(Tensor::F32($body?)),
            (Tensor::F64($lhs_t), Tensor::F64($rhs_t)) => Ok(Tensor::F64($body?)),
            (Tensor::I32($lhs_t), Tensor::I32($rhs_t)) => Ok(Tensor::I32($body?)),
            (Tensor::I64($lhs_t), Tensor::I64($rhs_t)) => Ok(Tensor::I64($body?)),
            (Tensor::Bool($lhs_t), Tensor::Bool($rhs_t)) => Ok(Tensor::Bool($body?)),
            (Tensor::C32($lhs_t), Tensor::C32($rhs_t)) => Ok(Tensor::C32($body?)),
            (Tensor::C64($lhs_t), Tensor::C64($rhs_t)) => Ok(Tensor::C64($body?)),
            _ => Err(crate::Error::DTypeMismatch {
                op: $op,
                lhs: $lhs.dtype(),
                rhs: $rhs.dtype(),
            }),
        }
    };
}
92
// Two-tensor same-variant dispatch for kernels that cannot operate on `bool`
// data.
//
// Matching non-bool variants evaluate `$body` (errors propagated with `?`) and
// wrap the result in that variant. A `(Bool, Bool)` pair is rejected with a
// backend failure carrying `$bool_message`; every other combination yields
// `Error::DTypeMismatch` tagged with `$op`.
//
// Note: `$lhs` and `$rhs` are expanded a second time in the mismatch arm (for
// `.dtype()`), so callers should pass plain places rather than expressions
// with side effects.
macro_rules! dispatch_same_dtype_without_bool_result {
    ($op:literal, $lhs:expr, $rhs:expr, $bool_message:literal, |$lhs_t:ident, $rhs_t:ident| $body:expr) => {
        match ($lhs, $rhs) {
            (Tensor::F32($lhs_t), Tensor::F32($rhs_t)) => Ok(Tensor::F32($body?)),
            (Tensor::F64($lhs_t), Tensor::F64($rhs_t)) => Ok(Tensor::F64($body?)),
            (Tensor::I32($lhs_t), Tensor::I32($rhs_t)) => Ok(Tensor::I32($body?)),
            (Tensor::I64($lhs_t), Tensor::I64($rhs_t)) => Ok(Tensor::I64($body?)),
            (Tensor::C32($lhs_t), Tensor::C32($rhs_t)) => Ok(Tensor::C32($body?)),
            (Tensor::C64($lhs_t), Tensor::C64($rhs_t)) => Ok(Tensor::C64($body?)),
            (Tensor::Bool(_), Tensor::Bool(_)) => {
                Err(crate::Error::backend_failure($op, $bool_message))
            }
            _ => Err(crate::Error::DTypeMismatch {
                op: $op,
                lhs: $lhs.dtype(),
                rhs: $rhs.dtype(),
            }),
        }
    };
}
113
114fn with_local_pool<T>(f: impl FnOnce(&mut BufferPool) -> T) -> T {
115    let mut buffers = BufferPool::new();
116    f(&mut buffers)
117}
118
/// Steps `index` to the next coordinate of `shape` in column-major order
/// (axis 0 varies fastest).
///
/// Axes with extent 0 are pinned to 0 and skipped. When every axis wraps, the
/// index returns to all zeros.
fn advance_col_major_index(index: &mut [usize], shape: &[usize]) {
    debug_assert_eq!(index.len(), shape.len());
    for (axis, coord) in index.iter_mut().enumerate() {
        let extent = shape[axis];
        if extent == 0 {
            // Degenerate axis: nothing to step through, carry into the next axis.
            *coord = 0;
            continue;
        }
        *coord += 1;
        if *coord < extent {
            // No carry needed; the increment is complete.
            return;
        }
        // Wrapped around: reset this axis and carry into the next one.
        *coord = 0;
    }
}
133
134fn pooled_filled_tensor<T>(
135    buffers: &mut BufferPool,
136    shape: Vec<usize>,
137    fill: T,
138) -> crate::Result<TypedTensor<T>>
139where
140    T: Copy + Clone + PoolScalar,
141{
142    // SAFETY: the following fill writes every pooled output element.
143    let mut out = pooled_uninit_tensor(buffers, shape)?;
144    out.host_data_mut()?.fill(fill);
145    Ok(out)
146}
147
148fn clone_host_tensor_from_pool<T>(
149    buffers: &mut BufferPool,
150    op: &'static str,
151    tensor: &TypedTensor<T>,
152) -> crate::Result<TypedTensor<T>>
153where
154    T: Copy + Clone + PoolScalar,
155{
156    // SAFETY: copy_from_slice writes every pooled output element.
157    let mut out = pooled_uninit_tensor(buffers, tensor.shape().to_vec())?;
158    out.host_data_mut()?
159        .copy_from_slice(typed_host_data(op, tensor)?);
160    Ok(out)
161}
162
163pub fn gather(
164    operand: &Tensor,
165    start_indices: &Tensor,
166    config: &GatherConfig,
167) -> crate::Result<Tensor> {
168    with_local_pool(|buffers| gather_with_pool(buffers, operand, start_indices, config))
169}
170
171pub(crate) fn gather_with_pool(
172    buffers: &mut BufferPool,
173    operand: &Tensor,
174    start_indices: &Tensor,
175    config: &GatherConfig,
176) -> crate::Result<Tensor> {
177    let start_indices = try_index_tensor(start_indices)?;
178    dispatch_tensor_unary_result!(operand, |t| typed_gather(
179        buffers,
180        t,
181        &start_indices,
182        config
183    ))
184}
185
186pub fn scatter(
187    operand: &Tensor,
188    scatter_indices: &Tensor,
189    updates: &Tensor,
190    config: &ScatterConfig,
191) -> crate::Result<Tensor> {
192    with_local_pool(|buffers| scatter_with_pool(buffers, operand, scatter_indices, updates, config))
193}
194
195pub(crate) fn scatter_with_pool(
196    buffers: &mut BufferPool,
197    operand: &Tensor,
198    scatter_indices: &Tensor,
199    updates: &Tensor,
200    config: &ScatterConfig,
201) -> crate::Result<Tensor> {
202    let scatter_indices = try_index_tensor(scatter_indices)?;
203    dispatch_same_dtype_without_bool_result!(
204        "scatter",
205        operand,
206        updates,
207        "Bool data tensors are not supported by additive scatter",
208        |op, upd| typed_scatter(buffers, op, &scatter_indices, upd, config)
209    )
210}
211
212pub(crate) fn try_slice_with_pool(
213    buffers: &mut BufferPool,
214    input: &Tensor,
215    config: &SliceConfig,
216) -> crate::Result<Tensor> {
217    dispatch_tensor_unary_result!(input, |t| typed_slice(buffers, t, config))
218}
219
220pub fn dynamic_slice(
221    input: &Tensor,
222    starts: &Tensor,
223    slice_sizes: &[usize],
224) -> crate::Result<Tensor> {
225    with_local_pool(|buffers| dynamic_slice_with_pool(buffers, input, starts, slice_sizes))
226}
227
228pub(crate) fn dynamic_slice_with_pool(
229    buffers: &mut BufferPool,
230    input: &Tensor,
231    starts: &Tensor,
232    slice_sizes: &[usize],
233) -> crate::Result<Tensor> {
234    let starts = try_index_tensor(starts)?;
235    dispatch_tensor_unary_result!(input, |t| typed_dynamic_slice(
236        buffers,
237        t,
238        &starts,
239        slice_sizes
240    ))
241}
242
243/// Return `operand` with `update` written at dynamic `starts`.
244///
245/// Starts are clamped so the whole update window fits inside the operand,
246/// matching `dynamic_slice` behavior.
247///
248/// # Examples
249///
250/// ```
251/// use tenferro_cpu as cpu;
252/// use tenferro_tensor::{Tensor, TypedTensor};
253///
254/// let operand = Tensor::F64(TypedTensor::from_vec_col_major(vec![5], vec![0.0; 5])?);
255/// let update = Tensor::F64(TypedTensor::from_vec_col_major(vec![2], vec![3.0, 4.0])?);
256/// let starts = Tensor::I64(TypedTensor::from_vec_col_major(vec![1], vec![4])?);
257///
258/// let out = cpu::dynamic_update_slice(&operand, &update, &starts).unwrap();
259/// assert_eq!(out.as_slice::<f64>().unwrap(), &[0.0, 0.0, 0.0, 3.0, 4.0]);
260/// # Ok::<(), tenferro_tensor::Error>(())
261/// ```
262pub fn dynamic_update_slice(
263    operand: &Tensor,
264    update: &Tensor,
265    starts: &Tensor,
266) -> crate::Result<Tensor> {
267    with_local_pool(|buffers| dynamic_update_slice_with_pool(buffers, operand, update, starts))
268}
269
270pub(crate) fn dynamic_update_slice_with_pool(
271    buffers: &mut BufferPool,
272    operand: &Tensor,
273    update: &Tensor,
274    starts: &Tensor,
275) -> crate::Result<Tensor> {
276    let starts = try_index_tensor(starts)?;
277    dispatch_same_dtype_result!("dynamic_update_slice", operand, update, |op, upd| {
278        typed_dynamic_update_slice(buffers, op, upd, &starts)
279    })
280}
281
/// Pads `input` as described by `config`.
///
/// This is the public entry point; it forwards directly to `try_pad`.
///
/// # Errors
///
/// Propagates any error returned by `try_pad`.
pub fn pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
    try_pad(input, config)
}
285
286fn try_pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
287    with_local_pool(|buffers| try_pad_with_pool(buffers, input, config))
288}
289
290pub(crate) fn try_pad_with_pool(
291    buffers: &mut BufferPool,
292    input: &Tensor,
293    config: &PadConfig,
294) -> crate::Result<Tensor> {
295    dispatch_tensor_unary_with_bool_special_result!(
296        input,
297        |t| typed_pad(buffers, t, config),
298        bool | t | typed_pad_with_fill(buffers, t, config, false)
299    )
300}
301
302pub(crate) fn try_concatenate_with_pool(
303    buffers: &mut BufferPool,
304    inputs: &[&Tensor],
305    axis: usize,
306) -> crate::Result<Tensor> {
307    let first = inputs
308        .first()
309        .copied()
310        .ok_or_else(|| crate::Error::InvalidConfig {
311            op: "concatenate",
312            message: "concatenate requires at least one input".into(),
313        })?;
314    dispatch_tensor_unary_result!(first, |t| typed_concatenate_from_dyn_inputs(
315        buffers, t, inputs, axis
316    ))
317}
318
319pub(crate) fn reverse_with_pool(
320    buffers: &mut BufferPool,
321    input: &Tensor,
322    axes: &[usize],
323) -> crate::Result<Tensor> {
324    dispatch_tensor_unary_result!(input, |t| typed_reverse(buffers, t, axes))
325}
326
327fn typed_slice<T: Copy + Clone + PoolScalar>(
328    buffers: &mut BufferPool,
329    input: &TypedTensor<T>,
330    config: &SliceConfig,
331) -> crate::Result<TypedTensor<T>> {
332    let input_shape = input.shape();
333    let rank = input_shape.len();
334    if config.starts.len() != rank {
335        return Err(crate::Error::RankMismatch {
336            op: "slice",
337            expected: rank,
338            actual: config.starts.len(),
339        });
340    }
341    if config.limits.len() != rank {
342        return Err(crate::Error::RankMismatch {
343            op: "slice",
344            expected: rank,
345            actual: config.limits.len(),
346        });
347    }
348    if config.strides.len() != rank {
349        return Err(crate::Error::RankMismatch {
350            op: "slice",
351            expected: rank,
352            actual: config.strides.len(),
353        });
354    }
355
356    let out_shape: Vec<usize> = input
357        .shape()
358        .iter()
359        .enumerate()
360        .map(|(axis, &dim)| {
361            let start = config.starts[axis];
362            let limit = config.limits[axis];
363            let stride = config.strides[axis];
364            if start > limit {
365                return Err(crate::Error::InvalidConfig {
366                    op: "slice",
367                    message: format!("start exceeds limit on axis {axis}"),
368                });
369            }
370            if limit > dim {
371                return Err(crate::Error::AxisOutOfBounds {
372                    op: "slice",
373                    axis,
374                    rank,
375                });
376            }
377            if stride == 0 {
378                return Err(crate::Error::InvalidConfig {
379                    op: "slice",
380                    message: format!("stride must be positive on axis {axis}"),
381                });
382            }
383            let span = limit - start;
384            Ok(span.div_ceil(stride))
385        })
386        .collect::<crate::Result<Vec<_>>>()?;
387
388    // SAFETY: the slice loop below assigns every output coordinate exactly once.
389    let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
390    let mut out_idx = vec![0usize; rank];
391    let mut in_idx = vec![0usize; rank];
392
393    for out_value in out.host_data_mut()?.iter_mut() {
394        for axis in 0..rank {
395            in_idx[axis] = config.starts[axis] + out_idx[axis] * config.strides[axis];
396        }
397        *out_value = *input.get(&in_idx)?;
398        advance_col_major_index(&mut out_idx, &out_shape);
399    }
400
401    Ok(out)
402}
403
404fn typed_concatenate_from_dyn_inputs<T>(
405    buffers: &mut BufferPool,
406    _first: &TypedTensor<T>,
407    inputs: &[&Tensor],
408    axis: usize,
409) -> crate::Result<TypedTensor<T>>
410where
411    T: Copy + Clone + PoolScalar,
412    Tensor: TensorAsTyped<T>,
413{
414    let first_dtype = inputs[0].dtype();
415    let typed_inputs = collect_typed_inputs(first_dtype, inputs)?;
416    typed_concatenate(buffers, &typed_inputs, axis)
417}
418
419fn collect_typed_inputs<'a, T>(
420    first_dtype: crate::DType,
421    inputs: &[&'a Tensor],
422) -> crate::Result<Vec<&'a TypedTensor<T>>>
423where
424    Tensor: TensorAsTyped<T>,
425{
426    inputs
427        .iter()
428        .map(|tensor| {
429            TensorAsTyped::<T>::as_typed(*tensor).ok_or_else(|| crate::Error::DTypeMismatch {
430                op: "concatenate",
431                lhs: first_dtype,
432                rhs: tensor.dtype(),
433            })
434        })
435        .collect()
436}
437
438fn typed_concatenate<T: Copy + Clone + PoolScalar>(
439    buffers: &mut BufferPool,
440    inputs: &[&TypedTensor<T>],
441    axis: usize,
442) -> crate::Result<TypedTensor<T>> {
443    let first = inputs[0];
444    let first_shape = first.shape();
445    let rank = first_shape.len();
446    if axis >= rank {
447        return Err(crate::Error::AxisOutOfBounds {
448            op: "concatenate",
449            axis,
450            rank,
451        });
452    }
453
454    let mut out_shape = first_shape.to_vec();
455    let mut axis_extent = 0usize;
456    for input in inputs {
457        let input_shape = input.shape();
458        if input_shape.len() != rank {
459            return Err(crate::Error::RankMismatch {
460                op: "concatenate",
461                expected: rank,
462                actual: input_shape.len(),
463            });
464        }
465        for dim in 0..rank {
466            if dim == axis {
467                axis_extent = axis_extent.checked_add(input_shape[dim]).ok_or_else(|| {
468                    crate::Error::InvalidConfig {
469                        op: "concatenate",
470                        message: "concatenate axis extent overflows usize".to_string(),
471                    }
472                })?;
473            } else if input_shape[dim] != first_shape[dim] {
474                return Err(crate::Error::ShapeMismatch {
475                    op: "concatenate",
476                    lhs: first_shape.to_vec(),
477                    rhs: input_shape.to_vec(),
478                });
479            }
480        }
481    }
482    out_shape[axis] = axis_extent;
483
484    let mut segment_ends = Vec::with_capacity(inputs.len());
485    let mut segment_end = 0usize;
486    for input in inputs {
487        segment_end = segment_end
488            .checked_add(input.shape()[axis])
489            .ok_or_else(|| crate::Error::InvalidConfig {
490                op: "concatenate",
491                message: "concatenate segment offset overflows usize".to_string(),
492            })?;
493        segment_ends.push(segment_end);
494    }
495
496    // SAFETY: the concatenate loop below assigns every output coordinate exactly once.
497    let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
498    let mut out_idx = vec![0usize; rank];
499    let mut in_idx = vec![0usize; rank];
500
501    for out_value in out.host_data_mut()?.iter_mut() {
502        let concat_idx = out_idx[axis];
503        let input_pos = segment_ends.partition_point(|&end| concat_idx >= end);
504        if input_pos == segment_ends.len() {
505            return Err(crate::Error::InvalidConfig {
506                op: "concatenate",
507                message: "output index must map to an input".to_string(),
508            });
509        }
510        let axis_base = if input_pos == 0 {
511            0
512        } else {
513            segment_ends[input_pos - 1]
514        };
515
516        in_idx.copy_from_slice(&out_idx);
517        in_idx[axis] -= axis_base;
518        *out_value = *inputs[input_pos].get(&in_idx)?;
519        advance_col_major_index(&mut out_idx, &out_shape);
520    }
521
522    Ok(out)
523}
524
525fn typed_reverse<T: Copy + Clone + PoolScalar>(
526    buffers: &mut BufferPool,
527    input: &TypedTensor<T>,
528    axes: &[usize],
529) -> crate::Result<TypedTensor<T>> {
530    let input_shape = input.shape();
531    let rank = input_shape.len();
532    let mut reverse_axis = vec![false; rank];
533    for &axis in axes {
534        if axis >= rank {
535            return Err(crate::Error::AxisOutOfBounds {
536                op: "reverse",
537                axis,
538                rank,
539            });
540        }
541        reverse_axis[axis] = true;
542    }
543
544    // SAFETY: the reverse loop below assigns every output coordinate exactly once.
545    let mut out = pooled_uninit_tensor(buffers, input_shape.to_vec())?;
546    let mut out_idx = vec![0usize; rank];
547    let mut in_idx = vec![0usize; rank];
548
549    for out_value in out.host_data_mut()?.iter_mut() {
550        for axis in 0..rank {
551            in_idx[axis] = if reverse_axis[axis] {
552                input_shape[axis] - 1 - out_idx[axis]
553            } else {
554                out_idx[axis]
555            };
556        }
557        *out_value = *input.get(&in_idx)?;
558        advance_col_major_index(&mut out_idx, input_shape);
559    }
560
561    Ok(out)
562}
563
/// Index data normalized to `i64`, together with the shape of the tensor it
/// was read from.
struct IndexTensor {
    /// Shape of the source index tensor.
    shape: Vec<usize>,
    /// Index values converted to `i64`, in the source tensor's host-data order.
    values: Vec<i64>,
}
568
/// Largest magnitude accepted for an `f32` index value (2^24). Every integer
/// up to this bound is exactly representable in `f32`.
const F32_MAX_EXACT_INT: f32 = 16_777_216.0;
/// Largest magnitude accepted for an `f64` index value (2^53). Every integer
/// up to this bound is exactly representable in `f64`.
const F64_MAX_EXACT_INT: f64 = 9_007_199_254_740_992.0;
573
574fn f32_index_to_i64(value: f32) -> crate::Result<i64> {
575    if !value.is_finite() || value.fract() != 0.0 || value.abs() > F32_MAX_EXACT_INT {
576        return Err(crate::Error::InvalidConfig {
577            op: "index_tensor",
578            message: format!("index value {value} is not an exactly representable i64"),
579        });
580    }
581    Ok(value as i64)
582}
583
584fn f64_index_to_i64(value: f64) -> crate::Result<i64> {
585    if !value.is_finite() || value.fract() != 0.0 || value.abs() > F64_MAX_EXACT_INT {
586        return Err(crate::Error::InvalidConfig {
587            op: "index_tensor",
588            message: format!("index value {value} is not an exactly representable i64"),
589        });
590    }
591    Ok(value as i64)
592}
593
594fn try_index_tensor(tensor: &Tensor) -> crate::Result<IndexTensor> {
595    match tensor {
596        Tensor::I32(t) => Ok(IndexTensor {
597            shape: t.shape().to_vec(),
598            values: typed_host_data("index_tensor", t)?
599                .iter()
600                .map(|&value| value as i64)
601                .collect(),
602        }),
603        Tensor::I64(t) => Ok(IndexTensor {
604            shape: t.shape().to_vec(),
605            values: typed_host_data("index_tensor", t)?.to_vec(),
606        }),
607        Tensor::F32(t) => {
608            let values: crate::Result<Vec<i64>> = typed_host_data("index_tensor", t)?
609                .iter()
610                .map(|&value| f32_index_to_i64(value))
611                .collect();
612            Ok(IndexTensor {
613                shape: t.shape().to_vec(),
614                values: values?,
615            })
616        }
617        Tensor::F64(t) => {
618            let values: crate::Result<Vec<i64>> = typed_host_data("index_tensor", t)?
619                .iter()
620                .map(|&value| f64_index_to_i64(value))
621                .collect();
622            Ok(IndexTensor {
623                shape: t.shape().to_vec(),
624                values: values?,
625            })
626        }
627        Tensor::Bool(_) => Err(crate::Error::InvalidConfig {
628            op: "index_tensor",
629            message: "bool index tensors are not supported".into(),
630        }),
631        Tensor::C32(_) | Tensor::C64(_) => Err(crate::Error::InvalidConfig {
632            op: "index_tensor",
633            message: "complex index tensors are not supported".into(),
634        }),
635    }
636}
637
638fn checked_product(op: &'static str, role: &'static str, shape: &[usize]) -> crate::Result<usize> {
639    shape.iter().try_fold(1usize, |acc, &dim| {
640        acc.checked_mul(dim)
641            .ok_or_else(|| crate::Error::InvalidConfig {
642                op,
643                message: format!("{role} element count overflows usize"),
644            })
645    })
646}
647
648fn linear_offset(op: &'static str, shape: &[usize], indices: &[usize]) -> crate::Result<usize> {
649    let mut offset = 0usize;
650    let mut stride = 1usize;
651    for (axis, &index) in indices.iter().enumerate() {
652        let scaled = index
653            .checked_mul(stride)
654            .ok_or_else(|| crate::Error::InvalidConfig {
655                op,
656                message: format!("linear index component overflows usize on axis {axis}"),
657            })?;
658        offset = offset
659            .checked_add(scaled)
660            .ok_or_else(|| crate::Error::InvalidConfig {
661                op,
662                message: format!("linear offset overflows usize on axis {axis}"),
663            })?;
664        stride = stride
665            .checked_mul(shape[axis])
666            .ok_or_else(|| crate::Error::InvalidConfig {
667                op,
668                message: format!("linear stride overflows usize after axis {axis}"),
669            })?;
670    }
671    Ok(offset)
672}
673
674fn try_index_vector_size(
675    op: &'static str,
676    shape: &[usize],
677    index_vector_dim: usize,
678) -> crate::Result<usize> {
679    if index_vector_dim > shape.len() {
680        return Err(crate::Error::AxisOutOfBounds {
681            op,
682            axis: index_vector_dim,
683            rank: shape.len(),
684        });
685    }
686    Ok(if index_vector_dim == shape.len() {
687        1
688    } else {
689        shape[index_vector_dim]
690    })
691}
692
693fn try_index_batch_shape(
694    op: &'static str,
695    shape: &[usize],
696    index_vector_dim: usize,
697) -> crate::Result<Vec<usize>> {
698    if index_vector_dim > shape.len() {
699        return Err(crate::Error::AxisOutOfBounds {
700            op,
701            axis: index_vector_dim,
702            rank: shape.len(),
703        });
704    }
705    if index_vector_dim == shape.len() {
706        return Ok(shape.to_vec());
707    }
708    Ok(shape
709        .iter()
710        .enumerate()
711        .filter_map(|(axis, &dim)| (axis != index_vector_dim).then_some(dim))
712        .collect())
713}
714
715fn index_component(
716    op: &'static str,
717    indices: &IndexTensor,
718    batch_idx: &[usize],
719    index_vector_dim: usize,
720    component: usize,
721    index_scratch: &mut [usize],
722) -> crate::Result<i64> {
723    if index_vector_dim == indices.shape.len() {
724        if component != 0 {
725            return Err(crate::Error::InvalidConfig {
726                op,
727                message: "implicit index_vector_dim only supports scalar index vectors".into(),
728            });
729        }
730        return Ok(indices.values[linear_offset(op, &indices.shape, batch_idx)?]);
731    }
732
733    debug_assert_eq!(index_scratch.len(), indices.shape.len());
734    let mut batch_axis = 0usize;
735    for (axis, slot) in index_scratch.iter_mut().enumerate() {
736        if axis == index_vector_dim {
737            *slot = component;
738        } else {
739            *slot = batch_idx[batch_axis];
740            batch_axis += 1;
741        }
742    }
743    Ok(indices.values[linear_offset(op, &indices.shape, index_scratch)?])
744}
745
746fn clamp_window_start(
747    op: &'static str,
748    start: i64,
749    dim_size: usize,
750    window_size: usize,
751) -> crate::Result<usize> {
752    if window_size > dim_size {
753        return Err(crate::Error::InvalidConfig {
754            op,
755            message: format!("window size {window_size} exceeds dimension size {dim_size}"),
756        });
757    }
758    let max_start = dim_size.saturating_sub(window_size) as i64;
759    Ok(start.clamp(0, max_start) as usize)
760}
761
/// Lists, in ascending order, the operand axes in `0..rank` that do not appear
/// in `collapsed_or_inserted`.
fn operand_window_dims(rank: usize, collapsed_or_inserted: &[usize]) -> Vec<usize> {
    let mut window_dims = Vec::new();
    for dim in 0..rank {
        if collapsed_or_inserted.contains(&dim) {
            continue;
        }
        window_dims.push(dim);
    }
    window_dims
}
767
768fn typed_gather<T: Copy + Clone + PoolScalar>(
769    buffers: &mut BufferPool,
770    operand: &TypedTensor<T>,
771    start_indices: &IndexTensor,
772    config: &GatherConfig,
773) -> crate::Result<TypedTensor<T>> {
774    let operand_shape = operand.shape();
775    let rank = operand_shape.len();
776    if config.slice_sizes.len() != rank {
777        return Err(crate::Error::RankMismatch {
778            op: "gather",
779            expected: rank,
780            actual: config.slice_sizes.len(),
781        });
782    }
783
784    for &dim in &config.collapsed_slice_dims {
785        if dim >= rank {
786            return Err(crate::Error::AxisOutOfBounds {
787                op: "gather",
788                axis: dim,
789                rank,
790            });
791        }
792    }
793    {
794        let mut seen = vec![false; rank];
795        for &dim in &config.collapsed_slice_dims {
796            if seen[dim] {
797                return Err(crate::Error::DuplicateAxis {
798                    op: "gather",
799                    axis: dim,
800                    role: "collapsed_slice_dims",
801                });
802            }
803            seen[dim] = true;
804        }
805    }
806    for &dim in &config.collapsed_slice_dims {
807        if config.slice_sizes[dim] != 1 {
808            return Err(crate::Error::InvalidConfig {
809                op: "gather",
810                message: format!(
811                    "collapsed slice dimension {dim} must have slice_size == 1, got {}",
812                    config.slice_sizes[dim]
813                ),
814            });
815        }
816    }
817
818    let index_size =
819        try_index_vector_size("gather", &start_indices.shape, config.index_vector_dim)?;
820    if index_size != config.start_index_map.len() {
821        return Err(crate::Error::InvalidConfig {
822            op: "gather",
823            message: format!(
824                "start_index_map length {} does not match index vector size {}",
825                config.start_index_map.len(),
826                index_size
827            ),
828        });
829    }
830    for &operand_dim in &config.start_index_map {
831        if operand_dim >= rank {
832            return Err(crate::Error::AxisOutOfBounds {
833                op: "gather",
834                axis: operand_dim,
835                rank,
836            });
837        }
838    }
839    {
840        let mut seen = vec![false; rank];
841        for &operand_dim in &config.start_index_map {
842            if seen[operand_dim] {
843                return Err(crate::Error::DuplicateAxis {
844                    op: "gather",
845                    axis: operand_dim,
846                    role: "start_index_map",
847                });
848            }
849            seen[operand_dim] = true;
850        }
851    }
852
853    let window_dims = operand_window_dims(rank, &config.collapsed_slice_dims);
854    if config.offset_dims.len() != window_dims.len() {
855        return Err(crate::Error::InvalidConfig {
856            op: "gather",
857            message: format!(
858                "offset_dims length {} does not match window dims count {}",
859                config.offset_dims.len(),
860                window_dims.len()
861            ),
862        });
863    }
864
865    let batch_shape =
866        try_index_batch_shape("gather", &start_indices.shape, config.index_vector_dim)?;
867    let out_rank = batch_shape.len() + config.offset_dims.len();
868    for &out_axis in &config.offset_dims {
869        if out_axis >= out_rank {
870            return Err(crate::Error::AxisOutOfBounds {
871                op: "gather",
872                axis: out_axis,
873                rank: out_rank,
874            });
875        }
876    }
877    {
878        let mut seen = vec![false; out_rank];
879        for &out_axis in &config.offset_dims {
880            if seen[out_axis] {
881                return Err(crate::Error::DuplicateAxis {
882                    op: "gather",
883                    axis: out_axis,
884                    role: "offset_dims",
885                });
886            }
887            seen[out_axis] = true;
888        }
889    }
890
891    let mut out_axis_to_operand_dim = vec![None; out_rank];
892    for (offset_axis, &out_axis) in config.offset_dims.iter().enumerate() {
893        out_axis_to_operand_dim[out_axis] = Some(window_dims[offset_axis]);
894    }
895
896    let mut out_shape = vec![0usize; out_rank];
897    let mut batch_axis = 0usize;
898    for out_axis in 0..out_rank {
899        if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
900            out_shape[out_axis] = config.slice_sizes[operand_dim];
901        } else {
902            out_shape[out_axis] = batch_shape[batch_axis];
903            batch_axis += 1;
904        }
905    }
906
907    for (component, &operand_dim) in config.start_index_map.iter().enumerate() {
908        let _ = clamp_window_start(
909            "gather",
910            0,
911            operand_shape[operand_dim],
912            config.slice_sizes[operand_dim],
913        )?;
914        let _ = component;
915    }
916
917    // SAFETY: the gather loop below assigns every output coordinate exactly once.
918    let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
919    let mut out_idx = vec![0usize; out_rank];
920    let mut batch_idx = vec![0usize; batch_shape.len()];
921    let mut operand_idx = vec![0usize; rank];
922    let mut window_offsets = vec![0usize; rank];
923    let mut index_scratch = vec![0usize; start_indices.shape.len()];
924
925    for out_value in out.host_data_mut()?.iter_mut() {
926        batch_axis = 0;
927        window_offsets.fill(0);
928        for out_axis in 0..out_rank {
929            if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
930                window_offsets[operand_dim] = out_idx[out_axis];
931            } else {
932                batch_idx[batch_axis] = out_idx[out_axis];
933                batch_axis += 1;
934            }
935        }
936
937        operand_idx.fill(0);
938        for (component, &operand_dim) in config.start_index_map.iter().enumerate() {
939            let start = index_component(
940                "gather",
941                start_indices,
942                &batch_idx,
943                config.index_vector_dim,
944                component,
945                &mut index_scratch,
946            )?;
947            operand_idx[operand_dim] = clamp_window_start(
948                "gather",
949                start,
950                operand_shape[operand_dim],
951                config.slice_sizes[operand_dim],
952            )?;
953        }
954
955        for axis in 0..operand_idx.len() {
956            operand_idx[axis] += window_offsets[axis];
957        }
958
959        *out_value = *operand.get(&operand_idx)?;
960        advance_col_major_index(&mut out_idx, &out_shape);
961    }
962
963    Ok(out)
964}
965
966fn typed_scatter<T>(
967    buffers: &mut BufferPool,
968    operand: &TypedTensor<T>,
969    scatter_indices: &IndexTensor,
970    updates: &TypedTensor<T>,
971    config: &ScatterConfig,
972) -> crate::Result<TypedTensor<T>>
973where
974    T: Copy + Clone + Zero + Add<Output = T> + PoolScalar,
975{
976    let operand_shape = operand.shape();
977    let updates_shape = updates.shape();
978    let op_rank = operand_shape.len();
979    for &dim in &config.inserted_window_dims {
980        if dim >= op_rank {
981            return Err(crate::Error::AxisOutOfBounds {
982                op: "scatter",
983                axis: dim,
984                rank: op_rank,
985            });
986        }
987    }
988    {
989        let mut seen = vec![false; op_rank];
990        for &dim in &config.inserted_window_dims {
991            if seen[dim] {
992                return Err(crate::Error::DuplicateAxis {
993                    op: "scatter",
994                    axis: dim,
995                    role: "inserted_window_dims",
996                });
997            }
998            seen[dim] = true;
999        }
1000    }
1001
1002    let index_size =
1003        try_index_vector_size("scatter", &scatter_indices.shape, config.index_vector_dim)?;
1004    if index_size != config.scatter_dims_to_operand_dims.len() {
1005        return Err(crate::Error::InvalidConfig {
1006            op: "scatter",
1007            message: format!(
1008                "scatter_dims_to_operand_dims length {} does not match index vector size {}",
1009                config.scatter_dims_to_operand_dims.len(),
1010                index_size
1011            ),
1012        });
1013    }
1014    for &operand_dim in &config.scatter_dims_to_operand_dims {
1015        if operand_dim >= op_rank {
1016            return Err(crate::Error::AxisOutOfBounds {
1017                op: "scatter",
1018                axis: operand_dim,
1019                rank: op_rank,
1020            });
1021        }
1022    }
1023    {
1024        let mut seen = vec![false; op_rank];
1025        for &operand_dim in &config.scatter_dims_to_operand_dims {
1026            if seen[operand_dim] {
1027                return Err(crate::Error::DuplicateAxis {
1028                    op: "scatter",
1029                    axis: operand_dim,
1030                    role: "scatter_dims_to_operand_dims",
1031                });
1032            }
1033            seen[operand_dim] = true;
1034        }
1035    }
1036
1037    let batch_shape =
1038        try_index_batch_shape("scatter", &scatter_indices.shape, config.index_vector_dim)?;
1039    let window_dims = operand_window_dims(op_rank, &config.inserted_window_dims);
1040    if config.update_window_dims.len() != window_dims.len() {
1041        return Err(crate::Error::InvalidConfig {
1042            op: "scatter",
1043            message: format!(
1044                "update_window_dims length {} does not match window dims count {}",
1045                config.update_window_dims.len(),
1046                window_dims.len()
1047            ),
1048        });
1049    }
1050
1051    let update_rank = updates_shape.len();
1052    let expected_batch_rank = update_rank
1053        .checked_sub(config.update_window_dims.len())
1054        .ok_or_else(|| crate::Error::InvalidConfig {
1055            op: "scatter",
1056            message: format!(
1057                "update_window_dims length {} exceeds update rank {}",
1058                config.update_window_dims.len(),
1059                update_rank
1060            ),
1061        })?;
1062    if expected_batch_rank != batch_shape.len() {
1063        return Err(crate::Error::InvalidConfig {
1064            op: "scatter",
1065            message: format!(
1066                "updates batch rank {} does not match index batch shape length {}",
1067                expected_batch_rank,
1068                batch_shape.len()
1069            ),
1070        });
1071    }
1072
1073    for &axis in &config.update_window_dims {
1074        if axis >= update_rank {
1075            return Err(crate::Error::AxisOutOfBounds {
1076                op: "scatter",
1077                axis,
1078                rank: update_rank,
1079            });
1080        }
1081    }
1082    {
1083        let mut seen = vec![false; update_rank];
1084        for &axis in &config.update_window_dims {
1085            if seen[axis] {
1086                return Err(crate::Error::DuplicateAxis {
1087                    op: "scatter",
1088                    axis,
1089                    role: "update_window_dims",
1090                });
1091            }
1092            seen[axis] = true;
1093        }
1094    }
1095
1096    let mut is_update_window_dim = vec![false; update_rank];
1097    for &axis in &config.update_window_dims {
1098        is_update_window_dim[axis] = true;
1099    }
1100
1101    {
1102        let mut batch_axis = 0usize;
1103        for axis in 0..update_rank {
1104            if !is_update_window_dim[axis] {
1105                if updates_shape[axis] != batch_shape[batch_axis] {
1106                    return Err(crate::Error::InvalidConfig {
1107                        op: "scatter",
1108                        message: format!(
1109                            "updates batch dim {} extent {} does not match index batch dim {} extent {}",
1110                            axis,
1111                            updates_shape[axis],
1112                            batch_axis,
1113                            batch_shape[batch_axis]
1114                        ),
1115                    });
1116                }
1117                batch_axis += 1;
1118            }
1119        }
1120    }
1121
1122    let mut window_shape = vec![1usize; op_rank];
1123    let mut window_shape_updates = vec![0usize; config.update_window_dims.len()];
1124    for (pos, &update_axis) in config.update_window_dims.iter().enumerate() {
1125        let dim = updates_shape[update_axis];
1126        window_shape_updates[pos] = dim;
1127        window_shape[window_dims[pos]] = dim;
1128    }
1129
1130    let batch_elems = checked_product("scatter", "batch shape", &batch_shape)?;
1131    let window_elems = checked_product("scatter", "window update shape", &window_shape_updates)?;
1132    let mut out = clone_host_tensor_from_pool(buffers, "scatter", operand)?;
1133
1134    let mut batch_idx = vec![0usize; batch_shape.len()];
1135    let mut window_idx = vec![0usize; window_shape_updates.len()];
1136    let mut update_idx = vec![0usize; update_rank];
1137    let mut operand_base = vec![0usize; op_rank];
1138    let mut operand_idx = vec![0usize; op_rank];
1139    let mut index_scratch = vec![0usize; scatter_indices.shape.len()];
1140
1141    for _ in 0..batch_elems {
1142        let mut window_fits = true;
1143        operand_base.fill(0);
1144        for (component, &operand_dim) in config.scatter_dims_to_operand_dims.iter().enumerate() {
1145            let start = index_component(
1146                "scatter",
1147                scatter_indices,
1148                &batch_idx,
1149                config.index_vector_dim,
1150                component,
1151                &mut index_scratch,
1152            )?;
1153            if start < 0 {
1154                window_fits = false;
1155                break;
1156            }
1157            operand_base[operand_dim] = start as usize;
1158        }
1159        if !window_fits {
1160            advance_col_major_index(&mut batch_idx, &batch_shape);
1161            continue;
1162        }
1163
1164        for axis in 0..op_rank {
1165            if operand_base[axis] + window_shape[axis] > operand_shape[axis] {
1166                window_fits = false;
1167                break;
1168            }
1169        }
1170        if !window_fits {
1171            advance_col_major_index(&mut batch_idx, &batch_shape);
1172            continue;
1173        }
1174
1175        window_idx.fill(0);
1176        for _ in 0..window_elems {
1177            let mut batch_axis = 0usize;
1178            let mut window_axis = 0usize;
1179            for axis in 0..update_rank {
1180                if is_update_window_dim[axis] {
1181                    update_idx[axis] = window_idx[window_axis];
1182                    window_axis += 1;
1183                } else {
1184                    update_idx[axis] = batch_idx[batch_axis];
1185                    batch_axis += 1;
1186                }
1187            }
1188
1189            operand_idx.copy_from_slice(&operand_base);
1190            for (window_axis, &operand_axis) in window_dims.iter().enumerate() {
1191                operand_idx[operand_axis] += window_idx[window_axis];
1192            }
1193
1194            let value = *updates.get(&update_idx)?;
1195            let slot = out.get_mut(&operand_idx)?;
1196            *slot = *slot + value;
1197            advance_col_major_index(&mut window_idx, &window_shape_updates);
1198        }
1199        advance_col_major_index(&mut batch_idx, &batch_shape);
1200    }
1201
1202    Ok(out)
1203}
1204
1205fn typed_dynamic_slice<T: Copy + Clone + PoolScalar>(
1206    buffers: &mut BufferPool,
1207    input: &TypedTensor<T>,
1208    starts: &IndexTensor,
1209    slice_sizes: &[usize],
1210) -> crate::Result<TypedTensor<T>> {
1211    let input_shape = input.shape();
1212    if slice_sizes.len() != input_shape.len() {
1213        return Err(crate::Error::RankMismatch {
1214            op: "dynamic_slice",
1215            expected: input_shape.len(),
1216            actual: slice_sizes.len(),
1217        });
1218    }
1219    if starts.shape.len() != 1 {
1220        return Err(crate::Error::InvalidConfig {
1221            op: "dynamic_slice",
1222            message: "starts must be a rank-1 tensor".into(),
1223        });
1224    }
1225    if starts.values.len() != input_shape.len() {
1226        return Err(crate::Error::InvalidConfig {
1227            op: "dynamic_slice",
1228            message: format!(
1229                "starts length {} must match input rank {}",
1230                starts.values.len(),
1231                input_shape.len()
1232            ),
1233        });
1234    }
1235
1236    let mut clamped_starts = vec![0usize; input_shape.len()];
1237    for axis in 0..input_shape.len() {
1238        clamped_starts[axis] = clamp_window_start(
1239            "dynamic_slice",
1240            starts.values[axis],
1241            input_shape[axis],
1242            slice_sizes[axis],
1243        )?;
1244    }
1245
1246    let out_shape = slice_sizes.to_vec();
1247    // SAFETY: the dynamic-slice loop below assigns every output coordinate exactly once.
1248    let mut out = pooled_uninit_tensor(buffers, out_shape.clone())?;
1249    let mut out_idx = vec![0usize; out_shape.len()];
1250    let mut input_idx = vec![0usize; out_shape.len()];
1251
1252    for out_value in out.host_data_mut()?.iter_mut() {
1253        for axis in 0..out_shape.len() {
1254            input_idx[axis] = clamped_starts[axis] + out_idx[axis];
1255        }
1256        *out_value = *input.get(&input_idx)?;
1257        advance_col_major_index(&mut out_idx, &out_shape);
1258    }
1259
1260    Ok(out)
1261}
1262
1263fn typed_dynamic_update_slice<T: Copy + Clone + PoolScalar>(
1264    buffers: &mut BufferPool,
1265    operand: &TypedTensor<T>,
1266    update: &TypedTensor<T>,
1267    starts: &IndexTensor,
1268) -> crate::Result<TypedTensor<T>> {
1269    let operand_shape = operand.shape();
1270    let update_shape = update.shape();
1271    if update_shape.len() != operand_shape.len() {
1272        return Err(crate::Error::RankMismatch {
1273            op: "dynamic_update_slice",
1274            expected: operand_shape.len(),
1275            actual: update_shape.len(),
1276        });
1277    }
1278    if starts.shape.len() != 1 {
1279        return Err(crate::Error::InvalidConfig {
1280            op: "dynamic_update_slice",
1281            message: "starts must be a rank-1 tensor".into(),
1282        });
1283    }
1284    if starts.values.len() != operand_shape.len() {
1285        return Err(crate::Error::InvalidConfig {
1286            op: "dynamic_update_slice",
1287            message: format!(
1288                "starts length {} must match operand rank {}",
1289                starts.values.len(),
1290                operand_shape.len()
1291            ),
1292        });
1293    }
1294
1295    let mut clamped_starts = vec![0usize; operand_shape.len()];
1296    for axis in 0..operand_shape.len() {
1297        clamped_starts[axis] = clamp_window_start(
1298            "dynamic_update_slice",
1299            starts.values[axis],
1300            operand_shape[axis],
1301            update_shape[axis],
1302        )?;
1303    }
1304
1305    let mut out = clone_host_tensor_from_pool(buffers, "dynamic_update_slice", operand)?;
1306    let mut update_idx = vec![0usize; update_shape.len()];
1307    let mut operand_idx = vec![0usize; operand_shape.len()];
1308
1309    for update_value in update.as_slice()? {
1310        for axis in 0..update_shape.len() {
1311            operand_idx[axis] = clamped_starts[axis] + update_idx[axis];
1312        }
1313        *out.get_mut(&operand_idx)? = *update_value;
1314        advance_col_major_index(&mut update_idx, update_shape);
1315    }
1316
1317    Ok(out)
1318}
1319
1320fn typed_pad<T: Copy + Clone + Zero + PoolScalar>(
1321    buffers: &mut BufferPool,
1322    input: &TypedTensor<T>,
1323    config: &PadConfig,
1324) -> crate::Result<TypedTensor<T>> {
1325    typed_pad_with_fill(buffers, input, config, T::zero())
1326}
1327
1328fn typed_pad_with_fill<T: Copy + Clone + PoolScalar>(
1329    buffers: &mut BufferPool,
1330    input: &TypedTensor<T>,
1331    config: &PadConfig,
1332    fill: T,
1333) -> crate::Result<TypedTensor<T>> {
1334    let input_shape = input.shape();
1335    let rank = input_shape.len();
1336    if config.edge_padding_low.len() != rank {
1337        return Err(crate::Error::RankMismatch {
1338            op: "pad",
1339            expected: rank,
1340            actual: config.edge_padding_low.len(),
1341        });
1342    }
1343    if config.edge_padding_high.len() != rank {
1344        return Err(crate::Error::RankMismatch {
1345            op: "pad",
1346            expected: rank,
1347            actual: config.edge_padding_high.len(),
1348        });
1349    }
1350    if config.interior_padding.len() != rank {
1351        return Err(crate::Error::RankMismatch {
1352            op: "pad",
1353            expected: rank,
1354            actual: config.interior_padding.len(),
1355        });
1356    }
1357
1358    let mut out_shape = Vec::with_capacity(input_shape.len());
1359    for (axis, &input_extent) in input_shape.iter().enumerate() {
1360        if config.interior_padding[axis] < 0 {
1361            return Err(crate::Error::InvalidConfig {
1362                op: "pad",
1363                message: format!("interior padding must be non-negative on axis {axis}"),
1364            });
1365        }
1366        if config.edge_padding_low[axis] < 0 || config.edge_padding_high[axis] < 0 {
1367            return Err(crate::Error::InvalidConfig {
1368                op: "pad",
1369                message: format!("edge padding must be non-negative on axis {axis}"),
1370            });
1371        }
1372        let input_extent_i64 =
1373            i64::try_from(input_extent).map_err(|_| crate::Error::InvalidConfig {
1374                op: "pad",
1375                message: format!("input extent on axis {axis} does not fit in i64"),
1376            })?;
1377        let spacing = config.interior_padding[axis]
1378            .checked_add(1)
1379            .ok_or_else(|| crate::Error::InvalidConfig {
1380                op: "pad",
1381                message: format!("interior padding overflow on axis {axis}"),
1382            })?;
1383        let base = if input_extent == 0 {
1384            0
1385        } else {
1386            input_extent_i64
1387                .checked_sub(1)
1388                .and_then(|extent| extent.checked_mul(spacing))
1389                .and_then(|extent| extent.checked_add(1))
1390                .ok_or_else(|| crate::Error::InvalidConfig {
1391                    op: "pad",
1392                    message: format!("padded input extent overflow on axis {axis}"),
1393                })?
1394        };
1395        let dim = config.edge_padding_low[axis]
1396            .checked_add(config.edge_padding_high[axis])
1397            .and_then(|edge| edge.checked_add(base))
1398            .ok_or_else(|| crate::Error::InvalidConfig {
1399                op: "pad",
1400                message: format!("output dimension overflow on axis {axis}"),
1401            })?;
1402        out_shape.push(
1403            usize::try_from(dim).map_err(|_| crate::Error::InvalidConfig {
1404                op: "pad",
1405                message: format!("negative output dimension on axis {axis}"),
1406            })?,
1407        );
1408    }
1409
1410    let mut out = pooled_filled_tensor(buffers, out_shape.clone(), fill)?;
1411    let mut input_idx = vec![0usize; input_shape.len()];
1412    let mut out_idx = vec![0usize; input_shape.len()];
1413
1414    for input_value in input.as_slice()? {
1415        let mut in_bounds = true;
1416        for axis in 0..input_shape.len() {
1417            let out_pos = config.edge_padding_low[axis]
1418                + input_idx[axis] as i64 * (config.interior_padding[axis] + 1);
1419            if !(0..out_shape[axis] as i64).contains(&out_pos) {
1420                in_bounds = false;
1421                break;
1422            }
1423            out_idx[axis] = out_pos as usize;
1424        }
1425        if in_bounds {
1426            *out.get_mut(&out_idx)? = *input_value;
1427        }
1428        advance_col_major_index(&mut input_idx, input_shape);
1429    }
1430
1431    Ok(out)
1432}