1use std::ops::Add;
2
3use num_traits::Zero;
4
5use crate::config::{GatherConfig, PadConfig, ScatterConfig, SliceConfig};
6use crate::types::{dispatch_binary, dispatch_tensor, flat_to_multi, Tensor, TypedTensor};
7
/// Internal projection from the dynamically-typed [`Tensor`] enum to the
/// concrete [`TypedTensor<T>`] payload for element type `T`.
///
/// Implemented once per dtype variant below; used by the concatenate path to
/// check that every input stores the same element type.
trait TensorAsTyped<T> {
    /// Borrows the typed payload when `self` holds elements of type `T`,
    /// `None` otherwise.
    fn as_typed(&self) -> Option<&TypedTensor<T>>;
}
11
12impl TensorAsTyped<f32> for Tensor {
13 fn as_typed(&self) -> Option<&TypedTensor<f32>> {
14 match self {
15 Tensor::F32(tensor) => Some(tensor),
16 _ => None,
17 }
18 }
19}
20
21impl TensorAsTyped<f64> for Tensor {
22 fn as_typed(&self) -> Option<&TypedTensor<f64>> {
23 match self {
24 Tensor::F64(tensor) => Some(tensor),
25 _ => None,
26 }
27 }
28}
29
30impl TensorAsTyped<num_complex::Complex<f32>> for Tensor {
31 fn as_typed(&self) -> Option<&TypedTensor<num_complex::Complex<f32>>> {
32 match self {
33 Tensor::C32(tensor) => Some(tensor),
34 _ => None,
35 }
36 }
37}
38
39impl TensorAsTyped<num_complex::Complex<f64>> for Tensor {
40 fn as_typed(&self) -> Option<&TypedTensor<num_complex::Complex<f64>>> {
41 match self {
42 Tensor::C64(tensor) => Some(tensor),
43 _ => None,
44 }
45 }
46}
47
48pub fn gather(operand: &Tensor, start_indices: &Tensor, config: &GatherConfig) -> Tensor {
49 let start_indices = index_tensor(start_indices);
50 dispatch_tensor!(operand, t => typed_gather(t, &start_indices, config))
51}
52
53pub fn scatter(
54 operand: &Tensor,
55 scatter_indices: &Tensor,
56 updates: &Tensor,
57 config: &ScatterConfig,
58) -> Tensor {
59 let scatter_indices = index_tensor(scatter_indices);
60 dispatch_binary!(operand, updates, |op, upd| typed_scatter(
61 op,
62 &scatter_indices,
63 upd,
64 config
65 ))
66}
67
68pub fn slice(input: &Tensor, config: &SliceConfig) -> Tensor {
69 try_slice(input, config).expect("slice")
70}
71
72pub fn try_slice(input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor> {
73 match input {
74 Tensor::F32(tensor) => Ok(Tensor::F32(typed_slice(tensor, config)?)),
75 Tensor::F64(tensor) => Ok(Tensor::F64(typed_slice(tensor, config)?)),
76 Tensor::C32(tensor) => Ok(Tensor::C32(typed_slice(tensor, config)?)),
77 Tensor::C64(tensor) => Ok(Tensor::C64(typed_slice(tensor, config)?)),
78 }
79}
80
81pub fn dynamic_slice(input: &Tensor, starts: &Tensor, slice_sizes: &[usize]) -> Tensor {
82 let starts = index_tensor(starts);
83 dispatch_tensor!(input, t => typed_dynamic_slice(t, &starts, slice_sizes))
84}
85
86pub fn pad(input: &Tensor, config: &PadConfig) -> Tensor {
87 try_pad(input, config).expect("pad")
88}
89
90pub fn try_pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
91 match input {
92 Tensor::F32(tensor) => Ok(Tensor::F32(typed_pad(tensor, config)?)),
93 Tensor::F64(tensor) => Ok(Tensor::F64(typed_pad(tensor, config)?)),
94 Tensor::C32(tensor) => Ok(Tensor::C32(typed_pad(tensor, config)?)),
95 Tensor::C64(tensor) => Ok(Tensor::C64(typed_pad(tensor, config)?)),
96 }
97}
98
/// Concatenates `inputs` along `axis`.
///
/// Thin public alias for [`try_concatenate`]; both return a `Result`, so this
/// wrapper adds no panic path.
pub fn concatenate(inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
    try_concatenate(inputs, axis)
}
102
103pub fn try_concatenate(inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
104 let first = inputs
105 .first()
106 .copied()
107 .ok_or_else(|| crate::Error::InvalidConfig {
108 op: "concatenate",
109 message: "concatenate requires at least one input".into(),
110 })?;
111 match first {
112 Tensor::F32(t) => Ok(Tensor::F32(typed_concatenate_from_dyn_inputs(
113 t, inputs, axis,
114 )?)),
115 Tensor::F64(t) => Ok(Tensor::F64(typed_concatenate_from_dyn_inputs(
116 t, inputs, axis,
117 )?)),
118 Tensor::C32(t) => Ok(Tensor::C32(typed_concatenate_from_dyn_inputs(
119 t, inputs, axis,
120 )?)),
121 Tensor::C64(t) => Ok(Tensor::C64(typed_concatenate_from_dyn_inputs(
122 t, inputs, axis,
123 )?)),
124 }
125}
126
127pub fn reverse(input: &Tensor, axes: &[usize]) -> Tensor {
128 dispatch_tensor!(input, tensor => typed_reverse(tensor, axes))
129}
130
131#[allow(clippy::uninit_vec)]
132fn typed_tensor_uninit<T: Clone>(shape: Vec<usize>) -> TypedTensor<T> {
133 let n: usize = shape.iter().product();
134 let mut data = Vec::with_capacity(n);
135 unsafe { data.set_len(n) };
138 TypedTensor::from_vec(shape, data)
139}
140
141fn typed_slice<T: Copy + Clone>(
142 input: &TypedTensor<T>,
143 config: &SliceConfig,
144) -> crate::Result<TypedTensor<T>> {
145 let rank = input.shape.len();
146 if config.starts.len() != rank {
147 return Err(crate::Error::RankMismatch {
148 op: "slice",
149 expected: rank,
150 actual: config.starts.len(),
151 });
152 }
153 if config.limits.len() != rank {
154 return Err(crate::Error::RankMismatch {
155 op: "slice",
156 expected: rank,
157 actual: config.limits.len(),
158 });
159 }
160 if config.strides.len() != rank {
161 return Err(crate::Error::RankMismatch {
162 op: "slice",
163 expected: rank,
164 actual: config.strides.len(),
165 });
166 }
167
168 let out_shape: Vec<usize> = input
169 .shape
170 .iter()
171 .enumerate()
172 .map(|(axis, &dim)| {
173 let start = config.starts[axis];
174 let limit = config.limits[axis];
175 let stride = config.strides[axis];
176 if start > limit {
177 return Err(crate::Error::InvalidConfig {
178 op: "slice",
179 message: format!("start exceeds limit on axis {axis}"),
180 });
181 }
182 if limit > dim {
183 return Err(crate::Error::AxisOutOfBounds {
184 op: "slice",
185 axis,
186 rank,
187 });
188 }
189 if stride == 0 {
190 return Err(crate::Error::InvalidConfig {
191 op: "slice",
192 message: format!("stride must be positive on axis {axis}"),
193 });
194 }
195 let span = limit - start;
196 Ok((span + stride - 1) / stride)
197 })
198 .collect::<crate::Result<Vec<_>>>()?;
199
200 let out_len: usize = out_shape.iter().product();
201 let mut out_data = Vec::with_capacity(out_len);
202 let mut out_idx = vec![0usize; rank];
203 let mut in_idx = vec![0usize; rank];
204
205 for flat in 0..out_len {
206 flat_to_multi(flat, &out_shape, &mut out_idx);
207 for axis in 0..rank {
208 in_idx[axis] = config.starts[axis] + out_idx[axis] * config.strides[axis];
209 }
210 out_data.push(*input.get(&in_idx));
211 }
212
213 Ok(TypedTensor::from_vec(out_shape, out_data))
214}
215
216fn typed_concatenate_from_dyn_inputs<T>(
217 _first: &TypedTensor<T>,
218 inputs: &[&Tensor],
219 axis: usize,
220) -> crate::Result<TypedTensor<T>>
221where
222 T: Copy + Clone,
223 Tensor: TensorAsTyped<T>,
224{
225 let first_dtype = inputs[0].dtype();
226 let typed_inputs = collect_typed_inputs(first_dtype, inputs)?;
227 typed_concatenate(&typed_inputs, axis)
228}
229
230fn collect_typed_inputs<'a, T>(
231 first_dtype: crate::DType,
232 inputs: &[&'a Tensor],
233) -> crate::Result<Vec<&'a TypedTensor<T>>>
234where
235 Tensor: TensorAsTyped<T>,
236{
237 inputs
238 .iter()
239 .map(|tensor| {
240 TensorAsTyped::<T>::as_typed(*tensor).ok_or_else(|| crate::Error::DTypeMismatch {
241 op: "concatenate",
242 lhs: first_dtype,
243 rhs: tensor.dtype(),
244 })
245 })
246 .collect()
247}
248
249fn typed_concatenate<T: Copy + Clone>(
250 inputs: &[&TypedTensor<T>],
251 axis: usize,
252) -> crate::Result<TypedTensor<T>> {
253 let first = inputs[0];
254 let rank = first.shape.len();
255 if axis >= rank {
256 return Err(crate::Error::AxisOutOfBounds {
257 op: "concatenate",
258 axis,
259 rank,
260 });
261 }
262
263 let mut out_shape = first.shape.clone();
264 let mut axis_extent = 0usize;
265 for input in inputs {
266 if input.shape.len() != rank {
267 return Err(crate::Error::RankMismatch {
268 op: "concatenate",
269 expected: rank,
270 actual: input.shape.len(),
271 });
272 }
273 for dim in 0..rank {
274 if dim == axis {
275 axis_extent += input.shape[dim];
276 } else if input.shape[dim] != first.shape[dim] {
277 return Err(crate::Error::ShapeMismatch {
278 op: "concatenate",
279 lhs: first.shape.clone(),
280 rhs: input.shape.clone(),
281 });
282 }
283 }
284 }
285 out_shape[axis] = axis_extent;
286
287 let segment_ends: Vec<usize> = inputs
288 .iter()
289 .scan(0usize, |sum, input| {
290 *sum += input.shape[axis];
291 Some(*sum)
292 })
293 .collect();
294
295 let out_len: usize = out_shape.iter().product();
296 let mut out_data = Vec::with_capacity(out_len);
297 let mut out_idx = vec![0usize; rank];
298 let mut in_idx = vec![0usize; rank];
299
300 for flat in 0..out_len {
301 flat_to_multi(flat, &out_shape, &mut out_idx);
302 let concat_idx = out_idx[axis];
303 let input_pos = segment_ends
304 .iter()
305 .position(|&end| concat_idx < end)
306 .expect("concatenate: output index must map to an input");
307 let axis_base = if input_pos == 0 {
308 0
309 } else {
310 segment_ends[input_pos - 1]
311 };
312
313 in_idx.copy_from_slice(&out_idx);
314 in_idx[axis] -= axis_base;
315 out_data.push(*inputs[input_pos].get(&in_idx));
316 }
317
318 Ok(TypedTensor::from_vec(out_shape, out_data))
319}
320
321fn typed_reverse<T: Copy + Clone>(input: &TypedTensor<T>, axes: &[usize]) -> TypedTensor<T> {
322 let rank = input.shape.len();
323 let mut reverse_axis = vec![false; rank];
324 for &axis in axes {
325 assert!(axis < rank, "reverse: axis out of bounds");
326 reverse_axis[axis] = true;
327 }
328
329 let out_len = input.n_elements();
330 let mut out_data = Vec::with_capacity(out_len);
331 let mut out_idx = vec![0usize; rank];
332 let mut in_idx = vec![0usize; rank];
333
334 for flat in 0..out_len {
335 flat_to_multi(flat, &input.shape, &mut out_idx);
336 for axis in 0..rank {
337 in_idx[axis] = if reverse_axis[axis] {
338 input.shape[axis] - 1 - out_idx[axis]
339 } else {
340 out_idx[axis]
341 };
342 }
343 out_data.push(*input.get(&in_idx));
344 }
345
346 TypedTensor::from_vec(input.shape.clone(), out_data)
347}
348
/// Index values extracted from a tensor, widened to `i64`.
struct IndexTensor {
    // Shape of the source tensor, kept so multi-dimensional index lookups
    // (see `index_component`) can address into the flat `values`.
    shape: Vec<usize>,
    // Flattened index values, truncated from the source floats via `as i64`.
    values: Vec<i64>,
}
353
354fn index_tensor(tensor: &Tensor) -> IndexTensor {
355 match tensor {
356 Tensor::F32(t) => IndexTensor {
357 shape: t.shape.clone(),
358 values: t.host_data().iter().map(|&value| value as i64).collect(),
359 },
360 Tensor::F64(t) => IndexTensor {
361 shape: t.shape.clone(),
362 values: t.host_data().iter().map(|&value| value as i64).collect(),
363 },
364 Tensor::C32(_) | Tensor::C64(_) => panic!("complex index tensors are not supported"),
365 }
366}
367
/// Flattens a multi-dimensional index into a linear offset, with axis 0 as
/// the fastest-varying dimension (stride 1).
fn linear_offset(shape: &[usize], indices: &[usize]) -> usize {
    let mut offset = 0usize;
    let mut stride = 1usize;
    for axis in 0..indices.len() {
        offset += indices[axis] * stride;
        stride *= shape[axis];
    }
    offset
}
377
/// Number of elements implied by `shape` (empty shape → 1, the scalar case).
fn product(shape: &[usize]) -> usize {
    shape.iter().fold(1, |acc, &dim| acc * dim)
}
381
/// Length of the per-element index vector in an index tensor.
///
/// When `index_vector_dim` equals the tensor's rank, index vectors are
/// implicit scalars of size 1; otherwise it is that dimension's extent.
fn index_vector_size(shape: &[usize], index_vector_dim: usize) -> usize {
    if index_vector_dim == shape.len() {
        return 1;
    }
    shape[index_vector_dim]
}
389
/// Shape of the index tensor with the index-vector dimension removed.
///
/// With an implicit (trailing) `index_vector_dim == rank`, every dimension is
/// a batch dimension and the shape is returned unchanged.
fn index_batch_shape(shape: &[usize], index_vector_dim: usize) -> Vec<usize> {
    if index_vector_dim == shape.len() {
        return shape.to_vec();
    }
    let mut batch = Vec::with_capacity(shape.len().saturating_sub(1));
    for (axis, &dim) in shape.iter().enumerate() {
        if axis != index_vector_dim {
            batch.push(dim);
        }
    }
    batch
}
400
401fn index_component(
402 indices: &IndexTensor,
403 batch_idx: &[usize],
404 index_vector_dim: usize,
405 component: usize,
406) -> i64 {
407 if index_vector_dim == indices.shape.len() {
408 assert_eq!(
409 component, 0,
410 "implicit index_vector_dim only supports scalar index vectors"
411 );
412 return indices.values[linear_offset(&indices.shape, batch_idx)];
413 }
414
415 let mut full_idx = vec![0usize; indices.shape.len()];
416 let mut batch_axis = 0usize;
417 for (axis, slot) in full_idx.iter_mut().enumerate() {
418 if axis == index_vector_dim {
419 *slot = component;
420 } else {
421 *slot = batch_idx[batch_axis];
422 batch_axis += 1;
423 }
424 }
425 indices.values[linear_offset(&indices.shape, &full_idx)]
426}
427
/// Clamps a window start so `[start, start + window_size)` fits inside
/// `[0, dim_size)`. Panics if the window is larger than the dimension.
fn clamp_window_start(start: i64, dim_size: usize, window_size: usize) -> usize {
    assert!(
        window_size <= dim_size,
        "window size {window_size} exceeds dimension size {dim_size}"
    );
    // Safe subtraction: the assert above guarantees window_size <= dim_size.
    let max_start = (dim_size - window_size) as i64;
    start.max(0).min(max_start) as usize
}
436
/// Operand dimensions that carry a window: every dim in `0..rank` not listed
/// in `collapsed_or_inserted`, in ascending order.
fn operand_window_dims(rank: usize, collapsed_or_inserted: &[usize]) -> Vec<usize> {
    let mut dims = Vec::with_capacity(rank);
    for dim in 0..rank {
        if !collapsed_or_inserted.contains(&dim) {
            dims.push(dim);
        }
    }
    dims
}
442
/// XLA-style gather kernel.
///
/// Each output element is addressed by a batch position (the output axes not
/// listed in `config.offset_dims`) and a window offset (the axes that are
/// listed). The batch position selects an index vector from `start_indices`;
/// its components, routed through `config.start_index_map` and clamped so the
/// `slice_sizes` window fits, give the window's start inside `operand`.
/// Out-of-range starts are clamped rather than rejected.
fn typed_gather<T: Copy + Clone + Zero>(
    operand: &TypedTensor<T>,
    start_indices: &IndexTensor,
    config: &GatherConfig,
) -> TypedTensor<T> {
    // slice_sizes must name a window extent for every operand dimension.
    assert_eq!(
        config.slice_sizes.len(),
        operand.shape.len(),
        "gather: slice_sizes rank mismatch"
    );

    // One start_index_map entry per component of each index vector.
    let index_size = index_vector_size(&start_indices.shape, config.index_vector_dim);
    assert_eq!(
        index_size,
        config.start_index_map.len(),
        "gather: start_index_map length mismatch"
    );

    // Operand dims that survive collapsing; each pairs with one offset dim.
    let window_dims = operand_window_dims(operand.shape.len(), &config.collapsed_slice_dims);
    assert_eq!(
        config.offset_dims.len(),
        window_dims.len(),
        "gather: offset_dims length mismatch"
    );

    // Output layout: offset_dims[i] is the output axis carrying window dim
    // window_dims[i]; the remaining output axes enumerate batch_shape in order.
    let batch_shape = index_batch_shape(&start_indices.shape, config.index_vector_dim);
    let out_rank = batch_shape.len() + config.offset_dims.len();
    let mut out_shape = vec![0usize; out_rank];
    let mut out_axis_to_operand_dim = vec![None; out_rank];
    for (offset_axis, &out_axis) in config.offset_dims.iter().enumerate() {
        out_axis_to_operand_dim[out_axis] = Some(window_dims[offset_axis]);
    }

    let mut batch_axis = 0usize;
    for out_axis in 0..out_rank {
        if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
            out_shape[out_axis] = config.slice_sizes[operand_dim];
        } else {
            out_shape[out_axis] = batch_shape[batch_axis];
            batch_axis += 1;
        }
    }

    let mut out = typed_tensor_uninit(out_shape.clone());
    let mut out_idx = vec![0usize; out_rank];
    let mut batch_idx = vec![0usize; batch_shape.len()];
    let mut operand_idx = vec![0usize; operand.shape.len()];
    let mut window_offsets = vec![0usize; operand.shape.len()];

    for flat in 0..out.n_elements() {
        flat_to_multi(flat, &out_shape, &mut out_idx);

        // Split the output coordinate into its batch part (selects the index
        // vector) and its window part (offset inside the gathered slice).
        batch_axis = 0;
        window_offsets.fill(0);
        for out_axis in 0..out_rank {
            if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
                window_offsets[operand_dim] = out_idx[out_axis];
            } else {
                batch_idx[batch_axis] = out_idx[out_axis];
                batch_axis += 1;
            }
        }

        // Window start: each index component maps to the operand dimension
        // named by start_index_map, clamped so the slice stays in bounds.
        // Dimensions not named keep start 0.
        operand_idx.fill(0);
        for (component, &operand_dim) in config.start_index_map.iter().enumerate() {
            let start = index_component(
                start_indices,
                &batch_idx,
                config.index_vector_dim,
                component,
            );
            operand_idx[operand_dim] = clamp_window_start(
                start,
                operand.shape[operand_dim],
                config.slice_sizes[operand_dim],
            );
        }

        // Final operand coordinate = window start + offset within the window.
        for axis in 0..operand_idx.len() {
            operand_idx[axis] += window_offsets[axis];
        }

        *out.get_mut(&out_idx) = *operand.get(&operand_idx);
    }

    out
}
530
/// XLA-style scatter kernel with additive combination.
///
/// For each batch position of `scatter_indices`, the matching update window is
/// added element-wise into the output at the window start named by the index
/// vector. Windows with a negative start or that would extend past the
/// operand's bounds are silently skipped.
///
/// NOTE(review): the result starts from zeros and only accumulates `updates` —
/// `operand` contributes its shape, not its values. XLA's scatter seeds the
/// output with the operand; confirm this difference is intended.
fn typed_scatter<T>(
    operand: &TypedTensor<T>,
    scatter_indices: &IndexTensor,
    updates: &TypedTensor<T>,
    config: &ScatterConfig,
) -> TypedTensor<T>
where
    T: Copy + Clone + Zero + Add<Output = T>,
{
    // One scatter_dims_to_operand_dims entry per index-vector component.
    let index_size = index_vector_size(&scatter_indices.shape, config.index_vector_dim);
    assert_eq!(
        index_size,
        config.scatter_dims_to_operand_dims.len(),
        "scatter: scatter_dims_to_operand_dims length mismatch"
    );

    // Operand dims that carry a window; each pairs with one update window dim.
    let batch_shape = index_batch_shape(&scatter_indices.shape, config.index_vector_dim);
    let window_dims = operand_window_dims(operand.shape.len(), &config.inserted_window_dims);
    assert_eq!(
        config.update_window_dims.len(),
        window_dims.len(),
        "scatter: update_window_dims length mismatch"
    );

    // Mark which update axes are window axes; the rest enumerate the batch.
    let update_rank = updates.shape.len();
    let mut is_update_window_dim = vec![false; update_rank];
    for &axis in &config.update_window_dims {
        is_update_window_dim[axis] = true;
    }
    assert_eq!(
        update_rank - config.update_window_dims.len(),
        batch_shape.len(),
        "scatter: updates batch rank mismatch"
    );

    // Window extents, viewed both in operand coordinates (window_shape, with
    // size 1 on inserted dims) and in update coordinates (window_shape_updates).
    let mut window_shape = vec![1usize; operand.shape.len()];
    let mut window_shape_updates = vec![0usize; config.update_window_dims.len()];
    for (pos, &update_axis) in config.update_window_dims.iter().enumerate() {
        let dim = updates.shape[update_axis];
        window_shape_updates[pos] = dim;
        window_shape[window_dims[pos]] = dim;
    }

    // .max(1) so the scalar (rank-0) cases still run one iteration.
    let batch_elems = product(&batch_shape).max(1);
    let window_elems = product(&window_shape_updates).max(1);
    let mut out = TypedTensor::zeros(operand.shape.clone());

    let mut batch_idx = vec![0usize; batch_shape.len()];
    let mut window_idx = vec![0usize; window_shape_updates.len()];
    let mut update_idx = vec![0usize; updates.shape.len()];
    let mut operand_base = vec![0usize; operand.shape.len()];
    let mut operand_idx = vec![0usize; operand.shape.len()];

    for batch_flat in 0..batch_elems {
        if !batch_shape.is_empty() {
            flat_to_multi(batch_flat, &batch_shape, &mut batch_idx);
        }

        // Window start in operand coordinates; unnamed dims keep start 0.
        // A negative index component invalidates the whole window.
        let mut window_fits = true;
        operand_base.fill(0);
        for (component, &operand_dim) in config.scatter_dims_to_operand_dims.iter().enumerate() {
            let start = index_component(
                scatter_indices,
                &batch_idx,
                config.index_vector_dim,
                component,
            );
            if start < 0 {
                window_fits = false;
                break;
            }
            operand_base[operand_dim] = start as usize;
        }
        if !window_fits {
            continue;
        }

        // Skip windows that would extend past the operand's bounds.
        for axis in 0..operand.shape.len() {
            if operand_base[axis] + window_shape[axis] > operand.shape[axis] {
                window_fits = false;
                break;
            }
        }
        if !window_fits {
            continue;
        }

        for window_flat in 0..window_elems {
            if !window_shape_updates.is_empty() {
                flat_to_multi(window_flat, &window_shape_updates, &mut window_idx);
            }

            // Rebuild the full update coordinate by interleaving the batch
            // and window parts according to is_update_window_dim.
            let mut batch_axis = 0usize;
            let mut window_axis = 0usize;
            for axis in 0..updates.shape.len() {
                if is_update_window_dim[axis] {
                    update_idx[axis] = window_idx[window_axis];
                    window_axis += 1;
                } else {
                    update_idx[axis] = batch_idx[batch_axis];
                    batch_axis += 1;
                }
            }

            // Destination = window start + offset within the window.
            operand_idx.copy_from_slice(&operand_base);
            for (window_axis, &operand_axis) in window_dims.iter().enumerate() {
                operand_idx[operand_axis] += window_idx[window_axis];
            }

            // Additive combine: overlapping windows accumulate.
            let value = *updates.get(&update_idx);
            let slot = out.get_mut(&operand_idx);
            *slot = *slot + value;
        }
    }

    out
}
648
649fn typed_dynamic_slice<T: Copy + Clone + Zero>(
650 input: &TypedTensor<T>,
651 starts: &IndexTensor,
652 slice_sizes: &[usize],
653) -> TypedTensor<T> {
654 assert_eq!(
655 slice_sizes.len(),
656 input.shape.len(),
657 "dynamic_slice: slice_sizes rank mismatch"
658 );
659 assert_eq!(
660 starts.shape.len(),
661 1,
662 "dynamic_slice: starts must be a rank-1 tensor"
663 );
664 assert_eq!(
665 starts.values.len(),
666 input.shape.len(),
667 "dynamic_slice: starts length must match input rank"
668 );
669
670 let mut clamped_starts = vec![0usize; input.shape.len()];
671 for axis in 0..input.shape.len() {
672 clamped_starts[axis] =
673 clamp_window_start(starts.values[axis], input.shape[axis], slice_sizes[axis]);
674 }
675
676 let out_shape = slice_sizes.to_vec();
677 let mut out = typed_tensor_uninit(out_shape.clone());
678 let mut out_idx = vec![0usize; out_shape.len()];
679 let mut input_idx = vec![0usize; out_shape.len()];
680
681 for flat in 0..out.n_elements() {
682 flat_to_multi(flat, &out_shape, &mut out_idx);
683 for axis in 0..out_shape.len() {
684 input_idx[axis] = clamped_starts[axis] + out_idx[axis];
685 }
686 *out.get_mut(&out_idx) = *input.get(&input_idx);
687 }
688
689 out
690}
691
/// Typed pad kernel: per-axis edge padding (low/high, possibly negative,
/// which trims) plus interior padding (zeros inserted between neighboring
/// elements). The pad value is `T::zero()`, since the output starts
/// zero-filled and only surviving input elements are written over it.
fn typed_pad<T: Copy + Clone + Zero>(
    input: &TypedTensor<T>,
    config: &PadConfig,
) -> crate::Result<TypedTensor<T>> {
    let rank = input.shape.len();
    // All three padding vectors must describe every axis.
    if config.edge_padding_low.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "pad",
            expected: rank,
            actual: config.edge_padding_low.len(),
        });
    }
    if config.edge_padding_high.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "pad",
            expected: rank,
            actual: config.edge_padding_high.len(),
        });
    }
    if config.interior_padding.len() != rank {
        return Err(crate::Error::RankMismatch {
            op: "pad",
            expected: rank,
            actual: config.interior_padding.len(),
        });
    }

    let mut out_shape = Vec::with_capacity(input.shape.len());
    for axis in 0..input.shape.len() {
        if config.interior_padding[axis] < 0 {
            return Err(crate::Error::InvalidConfig {
                op: "pad",
                message: format!("interior padding must be non-negative on axis {axis}"),
            });
        }
        // Extent after interior padding alone: d elements with `interior`
        // zeros between each neighboring pair = (d - 1) * (interior + 1) + 1.
        let base = if input.shape[axis] == 0 {
            0
        } else {
            (input.shape[axis] as i64 - 1) * (config.interior_padding[axis] + 1) + 1
        };
        // Edge padding may be negative (trimming); a net-negative extent is
        // rejected by the failed usize conversion below.
        let dim = config.edge_padding_low[axis] + config.edge_padding_high[axis] + base;
        out_shape.push(
            usize::try_from(dim).map_err(|_| crate::Error::InvalidConfig {
                op: "pad",
                message: format!("negative output dimension on axis {axis}"),
            })?,
        );
    }

    let mut out = TypedTensor::zeros(out_shape.clone());
    let mut input_idx = vec![0usize; input.shape.len()];
    let mut out_idx = vec![0usize; input.shape.len()];

    // Iterate input elements and project each to its padded position;
    // elements trimmed away by negative edge padding fall outside
    // [0, out_dim) and are dropped.
    for flat in 0..input.n_elements() {
        flat_to_multi(flat, &input.shape, &mut input_idx);
        let mut in_bounds = true;
        for axis in 0..input.shape.len() {
            let out_pos = config.edge_padding_low[axis]
                + input_idx[axis] as i64 * (config.interior_padding[axis] + 1);
            if !(0..out_shape[axis] as i64).contains(&out_pos) {
                in_bounds = false;
                break;
            }
            out_idx[axis] = out_pos as usize;
        }
        if in_bounds {
            *out.get_mut(&out_idx) = *input.get(&input_idx);
        }
    }

    Ok(out)
}