//! CPU indexing kernels for `tenferro_tensor` (`tenferro_tensor/cpu/indexing.rs`):
//! gather, scatter, slice, dynamic_slice, pad, concatenate, and reverse.
1use std::ops::Add;
2
3use num_traits::Zero;
4
5use crate::config::{GatherConfig, PadConfig, ScatterConfig, SliceConfig};
6use crate::types::{dispatch_binary, dispatch_tensor, flat_to_multi, Tensor, TypedTensor};
7
/// Downcasts a dynamically-typed `Tensor` to the `TypedTensor<T>` payload
/// for a specific element type `T`.
trait TensorAsTyped<T> {
    /// Returns `Some(&TypedTensor<T>)` when this tensor's variant holds
    /// elements of type `T`, and `None` otherwise.
    fn as_typed(&self) -> Option<&TypedTensor<T>>;
}
11
12impl TensorAsTyped<f32> for Tensor {
13    fn as_typed(&self) -> Option<&TypedTensor<f32>> {
14        match self {
15            Tensor::F32(tensor) => Some(tensor),
16            _ => None,
17        }
18    }
19}
20
21impl TensorAsTyped<f64> for Tensor {
22    fn as_typed(&self) -> Option<&TypedTensor<f64>> {
23        match self {
24            Tensor::F64(tensor) => Some(tensor),
25            _ => None,
26        }
27    }
28}
29
30impl TensorAsTyped<num_complex::Complex<f32>> for Tensor {
31    fn as_typed(&self) -> Option<&TypedTensor<num_complex::Complex<f32>>> {
32        match self {
33            Tensor::C32(tensor) => Some(tensor),
34            _ => None,
35        }
36    }
37}
38
39impl TensorAsTyped<num_complex::Complex<f64>> for Tensor {
40    fn as_typed(&self) -> Option<&TypedTensor<num_complex::Complex<f64>>> {
41        match self {
42            Tensor::C64(tensor) => Some(tensor),
43            _ => None,
44        }
45    }
46}
47
/// Gathers windows of `operand` at positions read from `start_indices`
/// (HLO-style gather configuration), dispatching on the operand's dtype.
///
/// `start_indices` is first converted to an `i64` index tensor; complex
/// index tensors panic inside `index_tensor`.
pub fn gather(operand: &Tensor, start_indices: &Tensor, config: &GatherConfig) -> Tensor {
    let start_indices = index_tensor(start_indices);
    dispatch_tensor!(operand, t => typed_gather(t, &start_indices, config))
}
52
/// Scatters `updates` into an output shaped like `operand` at positions read
/// from `scatter_indices` (HLO-style scatter configuration).
///
/// `dispatch_binary!` presumably requires `operand` and `updates` to share a
/// dtype — macro defined elsewhere, verify there.
///
/// NOTE(review): `typed_scatter` accumulates into a zero-initialized output
/// and never copies `operand`'s values (operand is used for shape only) —
/// XLA scatter starts from the operand's contents; confirm this is intended.
pub fn scatter(
    operand: &Tensor,
    scatter_indices: &Tensor,
    updates: &Tensor,
    config: &ScatterConfig,
) -> Tensor {
    let scatter_indices = index_tensor(scatter_indices);
    dispatch_binary!(operand, updates, |op, upd| typed_scatter(
        op,
        &scatter_indices,
        upd,
        config
    ))
}
67
/// Strided slice of `input`; panics on an invalid `config`.
/// See [`try_slice`] for the fallible variant.
pub fn slice(input: &Tensor, config: &SliceConfig) -> Tensor {
    try_slice(input, config).expect("slice")
}
71
/// Strided slice of `input` per `config` (per-axis starts/limits/strides),
/// preserving the input's element type. Errors bubble up from `typed_slice`
/// (rank mismatches, out-of-bounds limits, zero strides).
pub fn try_slice(input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(tensor) => Ok(Tensor::F32(typed_slice(tensor, config)?)),
        Tensor::F64(tensor) => Ok(Tensor::F64(typed_slice(tensor, config)?)),
        Tensor::C32(tensor) => Ok(Tensor::C32(typed_slice(tensor, config)?)),
        Tensor::C64(tensor) => Ok(Tensor::C64(typed_slice(tensor, config)?)),
    }
}
80
/// Extracts a window of extent `slice_sizes` from `input`, with the window
/// origin given by the runtime tensor `starts`. Starts are clamped so the
/// window stays in bounds (see `clamp_window_start`); dispatches on dtype.
pub fn dynamic_slice(input: &Tensor, starts: &Tensor, slice_sizes: &[usize]) -> Tensor {
    let starts = index_tensor(starts);
    dispatch_tensor!(input, t => typed_dynamic_slice(t, &starts, slice_sizes))
}
85
/// Pads `input` per `config`; panics on an invalid configuration.
/// See [`try_pad`] for the fallible variant.
pub fn pad(input: &Tensor, config: &PadConfig) -> Tensor {
    try_pad(input, config).expect("pad")
}
89
/// Pads `input` with edge and interior padding per `config`, preserving the
/// input's element type. Errors bubble up from `typed_pad` (rank mismatches,
/// negative interior padding, negative output dimensions).
pub fn try_pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
    match input {
        Tensor::F32(tensor) => Ok(Tensor::F32(typed_pad(tensor, config)?)),
        Tensor::F64(tensor) => Ok(Tensor::F64(typed_pad(tensor, config)?)),
        Tensor::C32(tensor) => Ok(Tensor::C32(typed_pad(tensor, config)?)),
        Tensor::C64(tensor) => Ok(Tensor::C64(typed_pad(tensor, config)?)),
    }
}
98
/// Concatenates `inputs` along `axis`. Thin alias for [`try_concatenate`],
/// kept for API symmetry with the other ops.
pub fn concatenate(inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
    try_concatenate(inputs, axis)
}
102
/// Concatenates `inputs` along `axis`, dispatching on the dtype of the first
/// input. All inputs must share that dtype (checked in
/// `collect_typed_inputs`) and agree on every non-`axis` dimension (checked
/// in `typed_concatenate`). Errors if `inputs` is empty.
pub fn try_concatenate(inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
    let first = inputs
        .first()
        .copied()
        .ok_or_else(|| crate::Error::InvalidConfig {
            op: "concatenate",
            message: "concatenate requires at least one input".into(),
        })?;
    match first {
        Tensor::F32(t) => Ok(Tensor::F32(typed_concatenate_from_dyn_inputs(
            t, inputs, axis,
        )?)),
        Tensor::F64(t) => Ok(Tensor::F64(typed_concatenate_from_dyn_inputs(
            t, inputs, axis,
        )?)),
        Tensor::C32(t) => Ok(Tensor::C32(typed_concatenate_from_dyn_inputs(
            t, inputs, axis,
        )?)),
        Tensor::C64(t) => Ok(Tensor::C64(typed_concatenate_from_dyn_inputs(
            t, inputs, axis,
        )?)),
    }
}
126
/// Reverses `input` along each axis listed in `axes`, dispatching on dtype.
/// Panics (in `typed_reverse`) if any axis is out of bounds.
pub fn reverse(input: &Tensor, axes: &[usize]) -> Tensor {
    dispatch_tensor!(input, tensor => typed_reverse(tensor, axes))
}
130
131#[allow(clippy::uninit_vec)]
132fn typed_tensor_uninit<T: Clone>(shape: Vec<usize>) -> TypedTensor<T> {
133    let n: usize = shape.iter().product();
134    let mut data = Vec::with_capacity(n);
135    // SAFETY: callers only use this helper for outputs that are fully written
136    // before any read.
137    unsafe { data.set_len(n) };
138    TypedTensor::from_vec(shape, data)
139}
140
/// Strided slice of a typed tensor.
///
/// Validates that `starts`, `limits`, and `strides` each have the input's
/// rank, that `start <= limit <= dim` on every axis, and that every stride
/// is nonzero; then copies `input[start..limit : stride]` on each axis into
/// a freshly allocated tensor.
fn typed_slice<T: Copy + Clone>(
    input: &TypedTensor<T>,
    config: &SliceConfig,
) -> crate::Result<TypedTensor<T>> {
    let rank = input.shape.len();
    if config.starts.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "slice",
            expected: rank,
            actual: config.starts.len(),
        });
    }
    if config.limits.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "slice",
            expected: rank,
            actual: config.limits.len(),
        });
    }
    if config.strides.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "slice",
            expected: rank,
            actual: config.strides.len(),
        });
    }

    // Per-axis output extent: ceil((limit - start) / stride), after
    // validating the slice bounds for that axis.
    let out_shape: Vec<usize> = input
        .shape
        .iter()
        .enumerate()
        .map(|(axis, &dim)| {
            let start = config.starts[axis];
            let limit = config.limits[axis];
            let stride = config.strides[axis];
            if start > limit {
                return Err(crate::Error::InvalidConfig {
                    op: "slice",
                    message: format!("start exceeds limit on axis {axis}"),
                });
            }
            if limit > dim {
                return Err(crate::Error::AxisOutOfBounds {
                    op: "slice",
                    axis,
                    rank,
                });
            }
            if stride == 0 {
                return Err(crate::Error::InvalidConfig {
                    op: "slice",
                    message: format!("stride must be positive on axis {axis}"),
                });
            }
            let span = limit - start;
            // Ceiling division; stride is known nonzero here.
            Ok((span + stride - 1) / stride)
        })
        .collect::<crate::Result<Vec<_>>>()?;

    let out_len: usize = out_shape.iter().product();
    let mut out_data = Vec::with_capacity(out_len);
    let mut out_idx = vec![0usize; rank];
    let mut in_idx = vec![0usize; rank];

    // Walk output positions in flat order and pull the matching input
    // element: in = start + out * stride on every axis.
    for flat in 0..out_len {
        flat_to_multi(flat, &out_shape, &mut out_idx);
        for axis in 0..rank {
            in_idx[axis] = config.starts[axis] + out_idx[axis] * config.strides[axis];
        }
        out_data.push(*input.get(&in_idx));
    }

    Ok(TypedTensor::from_vec(out_shape, out_data))
}
215
/// Bridges dtype-erased concatenation inputs into the typed implementation.
///
/// `_first` only pins the element type `T` for inference; the caller
/// (`try_concatenate`) guarantees `inputs` is non-empty, so `inputs[0]` is
/// safe here.
fn typed_concatenate_from_dyn_inputs<T>(
    _first: &TypedTensor<T>,
    inputs: &[&Tensor],
    axis: usize,
) -> crate::Result<TypedTensor<T>>
where
    T: Copy + Clone,
    Tensor: TensorAsTyped<T>,
{
    let first_dtype = inputs[0].dtype();
    let typed_inputs = collect_typed_inputs(first_dtype, inputs)?;
    typed_concatenate(&typed_inputs, axis)
}
229
/// Downcasts every input tensor to `&TypedTensor<T>`, failing with a
/// `DTypeMismatch` (reporting `first_dtype` vs. the offending tensor's
/// dtype) on the first input whose element type is not `T`.
fn collect_typed_inputs<'a, T>(
    first_dtype: crate::DType,
    inputs: &[&'a Tensor],
) -> crate::Result<Vec<&'a TypedTensor<T>>>
where
    Tensor: TensorAsTyped<T>,
{
    inputs
        .iter()
        .map(|tensor| {
            TensorAsTyped::<T>::as_typed(*tensor).ok_or_else(|| crate::Error::DTypeMismatch {
                op: "concatenate",
                lhs: first_dtype,
                rhs: tensor.dtype(),
            })
        })
        .collect()
}
248
/// Concatenates typed tensors along `axis`.
///
/// Every input must have the first input's rank and agree with it on all
/// non-`axis` dimensions; the output's `axis` extent is the sum of the
/// inputs' `axis` extents. Callers guarantee `inputs` is non-empty.
fn typed_concatenate<T: Copy + Clone>(
    inputs: &[&TypedTensor<T>],
    axis: usize,
) -> crate::Result<TypedTensor<T>> {
    let first = inputs[0];
    let rank = first.shape.len();
    if axis >= rank {
        return Err(crate::Error::AxisOutOfBounds {
            op: "concatenate",
            axis,
            rank,
        });
    }

    // Validate shapes and accumulate the total extent along `axis`.
    let mut out_shape = first.shape.clone();
    let mut axis_extent = 0usize;
    for input in inputs {
        if input.shape.len() != rank {
            return Err(crate::Error::RankMismatch {
                op: "concatenate",
                expected: rank,
                actual: input.shape.len(),
            });
        }
        for dim in 0..rank {
            if dim == axis {
                axis_extent += input.shape[dim];
            } else if input.shape[dim] != first.shape[dim] {
                return Err(crate::Error::ShapeMismatch {
                    op: "concatenate",
                    lhs: first.shape.clone(),
                    rhs: input.shape.clone(),
                });
            }
        }
    }
    out_shape[axis] = axis_extent;

    // segment_ends[i] is the exclusive end of input i's range along `axis`
    // in the output, i.e. the running sum of axis extents.
    let segment_ends: Vec<usize> = inputs
        .iter()
        .scan(0usize, |sum, input| {
            *sum += input.shape[axis];
            Some(*sum)
        })
        .collect();

    let out_len: usize = out_shape.iter().product();
    let mut out_data = Vec::with_capacity(out_len);
    let mut out_idx = vec![0usize; rank];
    let mut in_idx = vec![0usize; rank];

    for flat in 0..out_len {
        flat_to_multi(flat, &out_shape, &mut out_idx);
        // Map the output's axis coordinate to the input that owns it.
        let concat_idx = out_idx[axis];
        let input_pos = segment_ends
            .iter()
            .position(|&end| concat_idx < end)
            .expect("concatenate: output index must map to an input");
        let axis_base = if input_pos == 0 {
            0
        } else {
            segment_ends[input_pos - 1]
        };

        // Same multi-index, with the axis coordinate rebased into that input.
        in_idx.copy_from_slice(&out_idx);
        in_idx[axis] -= axis_base;
        out_data.push(*inputs[input_pos].get(&in_idx));
    }

    Ok(TypedTensor::from_vec(out_shape, out_data))
}
320
321fn typed_reverse<T: Copy + Clone>(input: &TypedTensor<T>, axes: &[usize]) -> TypedTensor<T> {
322    let rank = input.shape.len();
323    let mut reverse_axis = vec![false; rank];
324    for &axis in axes {
325        assert!(axis < rank, "reverse: axis out of bounds");
326        reverse_axis[axis] = true;
327    }
328
329    let out_len = input.n_elements();
330    let mut out_data = Vec::with_capacity(out_len);
331    let mut out_idx = vec![0usize; rank];
332    let mut in_idx = vec![0usize; rank];
333
334    for flat in 0..out_len {
335        flat_to_multi(flat, &input.shape, &mut out_idx);
336        for axis in 0..rank {
337            in_idx[axis] = if reverse_axis[axis] {
338                input.shape[axis] - 1 - out_idx[axis]
339            } else {
340                out_idx[axis]
341            };
342        }
343        out_data.push(*input.get(&in_idx));
344    }
345
346    TypedTensor::from_vec(input.shape.clone(), out_data)
347}
348
/// A dtype-erased index tensor: the original shape plus all values converted
/// to `i64`, in the same flat layout that `linear_offset` assumes.
struct IndexTensor {
    // Extents of the source index tensor.
    shape: Vec<usize>,
    // Flattened index values, cast to i64 (truncated toward zero).
    values: Vec<i64>,
}
353
354fn index_tensor(tensor: &Tensor) -> IndexTensor {
355    match tensor {
356        Tensor::F32(t) => IndexTensor {
357            shape: t.shape.clone(),
358            values: t.host_data().iter().map(|&value| value as i64).collect(),
359        },
360        Tensor::F64(t) => IndexTensor {
361            shape: t.shape.clone(),
362            values: t.host_data().iter().map(|&value| value as i64).collect(),
363        },
364        Tensor::C32(_) | Tensor::C64(_) => panic!("complex index tensors are not supported"),
365    }
366}
367
/// Flattens a multi-dimensional index into a linear offset. Axis 0 has
/// stride 1; each subsequent axis' stride is the product of the preceding
/// extents. Panics if `indices` is longer than `shape`.
fn linear_offset(shape: &[usize], indices: &[usize]) -> usize {
    let mut offset = 0usize;
    let mut stride = 1usize;
    for axis in 0..indices.len() {
        offset += indices[axis] * stride;
        stride *= shape[axis];
    }
    offset
}
377
/// Product of all extents; an empty shape yields 1 (a scalar holds one element).
fn product(shape: &[usize]) -> usize {
    shape.iter().fold(1usize, |acc, &dim| acc * dim)
}
381
/// Length of one index vector inside an index tensor. An `index_vector_dim`
/// equal to the tensor's rank means the index vectors are implicit scalars
/// of length 1; otherwise it names the axis holding the components.
fn index_vector_size(shape: &[usize], index_vector_dim: usize) -> usize {
    if index_vector_dim == shape.len() {
        return 1;
    }
    shape[index_vector_dim]
}
389
/// Shape of the batch axes of an index tensor: every axis except
/// `index_vector_dim`. When `index_vector_dim` equals the rank (implicit
/// scalar index vectors), all axes are batch axes.
fn index_batch_shape(shape: &[usize], index_vector_dim: usize) -> Vec<usize> {
    if index_vector_dim == shape.len() {
        return shape.to_vec();
    }
    let mut batch = Vec::with_capacity(shape.len().saturating_sub(1));
    for (axis, &dim) in shape.iter().enumerate() {
        if axis != index_vector_dim {
            batch.push(dim);
        }
    }
    batch
}
400
/// Reads one component of the index vector addressed by `batch_idx`.
///
/// When `index_vector_dim` equals the indices' rank, index vectors are
/// implicit scalars: only `component == 0` is valid and the value at
/// `batch_idx` itself is returned. Otherwise `batch_idx` is expanded into a
/// full multi-index by inserting `component` at axis `index_vector_dim`.
fn index_component(
    indices: &IndexTensor,
    batch_idx: &[usize],
    index_vector_dim: usize,
    component: usize,
) -> i64 {
    if index_vector_dim == indices.shape.len() {
        assert_eq!(
            component, 0,
            "implicit index_vector_dim only supports scalar index vectors"
        );
        return indices.values[linear_offset(&indices.shape, batch_idx)];
    }

    // Interleave: batch coordinates fill every axis except the index-vector
    // axis, which carries the component number.
    let mut full_idx = vec![0usize; indices.shape.len()];
    let mut batch_axis = 0usize;
    for (axis, slot) in full_idx.iter_mut().enumerate() {
        if axis == index_vector_dim {
            *slot = component;
        } else {
            *slot = batch_idx[batch_axis];
            batch_axis += 1;
        }
    }
    indices.values[linear_offset(&indices.shape, &full_idx)]
}
427
/// Clamps a window's start so that `start + window_size <= dim_size` and
/// `start >= 0`. Panics if the window cannot fit in the dimension at all.
fn clamp_window_start(start: i64, dim_size: usize, window_size: usize) -> usize {
    assert!(
        window_size <= dim_size,
        "window size {window_size} exceeds dimension size {dim_size}"
    );
    // The assert above guarantees this subtraction cannot underflow.
    let max_start = (dim_size - window_size) as i64;
    start.max(0).min(max_start) as usize
}
436
/// Operand dimensions that carry a window, i.e. every dimension in
/// `0..rank` not listed in `collapsed_or_inserted`, in ascending order.
fn operand_window_dims(rank: usize, collapsed_or_inserted: &[usize]) -> Vec<usize> {
    let mut dims = Vec::new();
    for dim in 0..rank {
        if !collapsed_or_inserted.contains(&dim) {
            dims.push(dim);
        }
    }
    dims
}
442
/// HLO-style gather over a typed tensor.
///
/// For each batch position in `start_indices`, reads a window of extent
/// `config.slice_sizes` from `operand`. The window origin components come
/// from `start_index_map` and are clamped so the window stays in bounds
/// (matching XLA's clamping behavior — confirm against the project spec).
/// Output axes listed in `offset_dims` carry the window coordinates; the
/// remaining output axes carry the batch coordinates, in order.
fn typed_gather<T: Copy + Clone + Zero>(
    operand: &TypedTensor<T>,
    start_indices: &IndexTensor,
    config: &GatherConfig,
) -> TypedTensor<T> {
    assert_eq!(
        config.slice_sizes.len(),
        operand.shape.len(),
        "gather: slice_sizes rank mismatch"
    );

    let index_size = index_vector_size(&start_indices.shape, config.index_vector_dim);
    assert_eq!(
        index_size,
        config.start_index_map.len(),
        "gather: start_index_map length mismatch"
    );

    // Operand dims that survive into the output (not collapsed).
    let window_dims = operand_window_dims(operand.shape.len(), &config.collapsed_slice_dims);
    assert_eq!(
        config.offset_dims.len(),
        window_dims.len(),
        "gather: offset_dims length mismatch"
    );

    // Output shape: offset axes take the slice size of their operand dim,
    // the rest take the batch extents in order.
    let batch_shape = index_batch_shape(&start_indices.shape, config.index_vector_dim);
    let out_rank = batch_shape.len() + config.offset_dims.len();
    let mut out_shape = vec![0usize; out_rank];
    let mut out_axis_to_operand_dim = vec![None; out_rank];
    for (offset_axis, &out_axis) in config.offset_dims.iter().enumerate() {
        out_axis_to_operand_dim[out_axis] = Some(window_dims[offset_axis]);
    }

    let mut batch_axis = 0usize;
    for out_axis in 0..out_rank {
        if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
            out_shape[out_axis] = config.slice_sizes[operand_dim];
        } else {
            out_shape[out_axis] = batch_shape[batch_axis];
            batch_axis += 1;
        }
    }

    let mut out = typed_tensor_uninit(out_shape.clone());
    let mut out_idx = vec![0usize; out_rank];
    let mut batch_idx = vec![0usize; batch_shape.len()];
    let mut operand_idx = vec![0usize; operand.shape.len()];
    let mut window_offsets = vec![0usize; operand.shape.len()];

    for flat in 0..out.n_elements() {
        flat_to_multi(flat, &out_shape, &mut out_idx);

        // Split the output coordinate into window offsets and batch index.
        batch_axis = 0;
        window_offsets.fill(0);
        for out_axis in 0..out_rank {
            if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
                window_offsets[operand_dim] = out_idx[out_axis];
            } else {
                batch_idx[batch_axis] = out_idx[out_axis];
                batch_axis += 1;
            }
        }

        // Resolve the (clamped) window origin for this batch position.
        operand_idx.fill(0);
        for (component, &operand_dim) in config.start_index_map.iter().enumerate() {
            let start = index_component(
                start_indices,
                &batch_idx,
                config.index_vector_dim,
                component,
            );
            operand_idx[operand_dim] = clamp_window_start(
                start,
                operand.shape[operand_dim],
                config.slice_sizes[operand_dim],
            );
        }

        // Final operand coordinate = window origin + in-window offset.
        for axis in 0..operand_idx.len() {
            operand_idx[axis] += window_offsets[axis];
        }

        *out.get_mut(&out_idx) = *operand.get(&operand_idx);
    }

    out
}
530
/// HLO-style scatter with an additive combiner.
///
/// For each batch position in `scatter_indices`, adds the corresponding
/// window of `updates` into the output at the window origin given by
/// `scatter_dims_to_operand_dims`. Windows with a negative start or that do
/// not fit entirely inside the operand shape are silently skipped (no
/// clamping, unlike `typed_gather`). Overlapping windows accumulate via `+`.
///
/// NOTE(review): the output starts from `TypedTensor::zeros`, so `operand`'s
/// values are discarded (it is used for its shape only). XLA scatter starts
/// from the operand's contents — confirm whether "updates only" is the
/// intended semantics here (e.g. for gradient accumulation).
fn typed_scatter<T>(
    operand: &TypedTensor<T>,
    scatter_indices: &IndexTensor,
    updates: &TypedTensor<T>,
    config: &ScatterConfig,
) -> TypedTensor<T>
where
    T: Copy + Clone + Zero + Add<Output = T>,
{
    let index_size = index_vector_size(&scatter_indices.shape, config.index_vector_dim);
    assert_eq!(
        index_size,
        config.scatter_dims_to_operand_dims.len(),
        "scatter: scatter_dims_to_operand_dims length mismatch"
    );

    let batch_shape = index_batch_shape(&scatter_indices.shape, config.index_vector_dim);
    // Operand dims that receive a window (not inserted).
    let window_dims = operand_window_dims(operand.shape.len(), &config.inserted_window_dims);
    assert_eq!(
        config.update_window_dims.len(),
        window_dims.len(),
        "scatter: update_window_dims length mismatch"
    );

    // Classify each updates axis as window-carrying or batch-carrying.
    let update_rank = updates.shape.len();
    let mut is_update_window_dim = vec![false; update_rank];
    for &axis in &config.update_window_dims {
        is_update_window_dim[axis] = true;
    }
    assert_eq!(
        update_rank - config.update_window_dims.len(),
        batch_shape.len(),
        "scatter: updates batch rank mismatch"
    );

    // Window extent expressed both in operand axes (`window_shape`, with 1s
    // on inserted dims) and in updates axes (`window_shape_updates`).
    let mut window_shape = vec![1usize; operand.shape.len()];
    let mut window_shape_updates = vec![0usize; config.update_window_dims.len()];
    for (pos, &update_axis) in config.update_window_dims.iter().enumerate() {
        let dim = updates.shape[update_axis];
        window_shape_updates[pos] = dim;
        window_shape[window_dims[pos]] = dim;
    }

    // .max(1) lets the loops run once for rank-0 batch/window shapes.
    let batch_elems = product(&batch_shape).max(1);
    let window_elems = product(&window_shape_updates).max(1);
    let mut out = TypedTensor::zeros(operand.shape.clone());

    let mut batch_idx = vec![0usize; batch_shape.len()];
    let mut window_idx = vec![0usize; window_shape_updates.len()];
    let mut update_idx = vec![0usize; updates.shape.len()];
    let mut operand_base = vec![0usize; operand.shape.len()];
    let mut operand_idx = vec![0usize; operand.shape.len()];

    for batch_flat in 0..batch_elems {
        if !batch_shape.is_empty() {
            flat_to_multi(batch_flat, &batch_shape, &mut batch_idx);
        }

        // Resolve this batch's window origin; drop it on a negative start.
        let mut window_fits = true;
        operand_base.fill(0);
        for (component, &operand_dim) in config.scatter_dims_to_operand_dims.iter().enumerate() {
            let start = index_component(
                scatter_indices,
                &batch_idx,
                config.index_vector_dim,
                component,
            );
            if start < 0 {
                window_fits = false;
                break;
            }
            operand_base[operand_dim] = start as usize;
        }
        if !window_fits {
            continue;
        }

        // Drop the window if it would run past the operand's extent.
        for axis in 0..operand.shape.len() {
            if operand_base[axis] + window_shape[axis] > operand.shape[axis] {
                window_fits = false;
                break;
            }
        }
        if !window_fits {
            continue;
        }

        for window_flat in 0..window_elems {
            if !window_shape_updates.is_empty() {
                flat_to_multi(window_flat, &window_shape_updates, &mut window_idx);
            }

            // Rebuild the updates coordinate by interleaving batch and
            // window components per axis classification.
            let mut batch_axis = 0usize;
            let mut window_axis = 0usize;
            for axis in 0..updates.shape.len() {
                if is_update_window_dim[axis] {
                    update_idx[axis] = window_idx[window_axis];
                    window_axis += 1;
                } else {
                    update_idx[axis] = batch_idx[batch_axis];
                    batch_axis += 1;
                }
            }

            // Destination = window origin + in-window offset.
            operand_idx.copy_from_slice(&operand_base);
            for (window_axis, &operand_axis) in window_dims.iter().enumerate() {
                operand_idx[operand_axis] += window_idx[window_axis];
            }

            let value = *updates.get(&update_idx);
            let slot = out.get_mut(&operand_idx);
            *slot = *slot + value;
        }
    }

    out
}
648
/// Extracts a window of extent `slice_sizes` whose origin comes from the
/// rank-1 `starts` tensor (one start per input axis). Starts are clamped so
/// the window always fits; panics on rank/length mismatches or if a window
/// exceeds its dimension (inside `clamp_window_start`).
fn typed_dynamic_slice<T: Copy + Clone + Zero>(
    input: &TypedTensor<T>,
    starts: &IndexTensor,
    slice_sizes: &[usize],
) -> TypedTensor<T> {
    assert_eq!(
        slice_sizes.len(),
        input.shape.len(),
        "dynamic_slice: slice_sizes rank mismatch"
    );
    assert_eq!(
        starts.shape.len(),
        1,
        "dynamic_slice: starts must be a rank-1 tensor"
    );
    assert_eq!(
        starts.values.len(),
        input.shape.len(),
        "dynamic_slice: starts length must match input rank"
    );

    // Clamp every start so start + slice_size <= dim (XLA-style behavior —
    // confirm against the project spec).
    let mut clamped_starts = vec![0usize; input.shape.len()];
    for axis in 0..input.shape.len() {
        clamped_starts[axis] =
            clamp_window_start(starts.values[axis], input.shape[axis], slice_sizes[axis]);
    }

    let out_shape = slice_sizes.to_vec();
    let mut out = typed_tensor_uninit(out_shape.clone());
    let mut out_idx = vec![0usize; out_shape.len()];
    let mut input_idx = vec![0usize; out_shape.len()];

    for flat in 0..out.n_elements() {
        flat_to_multi(flat, &out_shape, &mut out_idx);
        for axis in 0..out_shape.len() {
            input_idx[axis] = clamped_starts[axis] + out_idx[axis];
        }
        *out.get_mut(&out_idx) = *input.get(&input_idx);
    }

    out
}
691
/// Pads a typed tensor with edge padding (possibly negative, which crops)
/// and non-negative interior padding. The fill value is zero: the output is
/// allocated with `TypedTensor::zeros` and only the surviving input elements
/// are written over it (`PadConfig` carries no explicit pad value —
/// presumably zero-fill is the intended semantics; confirm).
fn typed_pad<T: Copy + Clone + Zero>(
    input: &TypedTensor<T>,
    config: &PadConfig,
) -> crate::Result<TypedTensor<T>> {
    let rank = input.shape.len();
    if config.edge_padding_low.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "pad",
            expected: rank,
            actual: config.edge_padding_low.len(),
        });
    }
    if config.edge_padding_high.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "pad",
            expected: rank,
            actual: config.edge_padding_high.len(),
        });
    }
    if config.interior_padding.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "pad",
            expected: rank,
            actual: config.interior_padding.len(),
        });
    }

    // Output extent per axis: low + high + (dim-1)*(interior+1) + 1, with
    // empty axes contributing 0 base. Negative results are rejected.
    let mut out_shape = Vec::with_capacity(input.shape.len());
    for axis in 0..input.shape.len() {
        if config.interior_padding[axis] < 0 {
            return Err(crate::Error::InvalidConfig {
                op: "pad",
                message: format!("interior padding must be non-negative on axis {axis}"),
            });
        }
        let base = if input.shape[axis] == 0 {
            0
        } else {
            (input.shape[axis] as i64 - 1) * (config.interior_padding[axis] + 1) + 1
        };
        let dim = config.edge_padding_low[axis] + config.edge_padding_high[axis] + base;
        out_shape.push(
            usize::try_from(dim).map_err(|_| crate::Error::InvalidConfig {
                op: "pad",
                message: format!("negative output dimension on axis {axis}"),
            })?,
        );
    }

    let mut out = TypedTensor::zeros(out_shape.clone());
    let mut input_idx = vec![0usize; input.shape.len()];
    let mut out_idx = vec![0usize; input.shape.len()];

    // Walk input elements and place each at its padded position; elements
    // pushed outside the output by negative edge padding are dropped.
    for flat in 0..input.n_elements() {
        flat_to_multi(flat, &input.shape, &mut input_idx);
        let mut in_bounds = true;
        for axis in 0..input.shape.len() {
            let out_pos = config.edge_padding_low[axis]
                + input_idx[axis] as i64 * (config.interior_padding[axis] + 1);
            if !(0..out_shape[axis] as i64).contains(&out_pos) {
                in_bounds = false;
                break;
            }
            out_idx[axis] = out_pos as usize;
        }
        if in_bounds {
            *out.get_mut(&out_idx) = *input.get(&input_idx);
        }
    }

    Ok(out)
}