1use crate::{
2 checked_product, col_major_strides, validate_permutation, DynRank, Error, Result, ShapeVec,
3 SliceSpec, StrideVec, TensorRank,
4};
5use smallvec::SmallVec;
6use std::collections::HashSet;
7
/// Largest element count for which `TensorLayout::validate_mutable_no_overlap` enumerates
/// every element offset when its stride-ordering check is inconclusive. Larger layouts that
/// fail that check are rejected as overlapping without enumeration.
const MUTABLE_NO_OVERLAP_EXACT_ELEMENT_LIMIT: usize = 4_096;
14
15pub(crate) fn reachable_offset_range(
16 shape: &[usize],
17 strides: &[isize],
18 offset: isize,
19) -> Result<Option<(isize, isize)>> {
20 if shape.contains(&0) {
21 return Ok(None);
22 }
23
24 let mut min = offset;
25 let mut max = offset;
26 for (&extent, &stride) in shape.iter().zip(strides) {
27 let last = isize::try_from(extent.saturating_sub(1)).map_err(|_| Error::IntegerOverflow)?;
28 let delta = last.checked_mul(stride).ok_or(Error::IntegerOverflow)?;
29 if delta < 0 {
30 min = min.checked_add(delta).ok_or(Error::IntegerOverflow)?;
31 } else {
32 max = max.checked_add(delta).ok_or(Error::IntegerOverflow)?;
33 }
34 }
35 Ok(Some((min, max)))
36}
37
38pub(crate) fn validate_reachable_bounds(
39 shape: &[usize],
40 strides: &[isize],
41 offset: isize,
42 buffer_len: usize,
43) -> Result<()> {
44 if shape.len() != strides.len() {
45 return Err(Error::RankMismatch {
46 expected: shape.len(),
47 actual: strides.len(),
48 });
49 }
50
51 match reachable_offset_range(shape, strides, offset)? {
52 Some((min, max)) => {
53 if min < 0 {
54 return Err(Error::ViewOutOfBounds);
55 }
56 let max = usize::try_from(max).map_err(|_| Error::IntegerOverflow)?;
57 if max < buffer_len {
58 Ok(())
59 } else {
60 Err(Error::ViewOutOfBounds)
61 }
62 }
63 None => {
64 if offset < 0 {
65 return Err(Error::ViewOutOfBounds);
66 }
67 let offset = usize::try_from(offset).map_err(|_| Error::IntegerOverflow)?;
68 if offset <= buffer_len {
69 Ok(())
70 } else {
71 Err(Error::ViewOutOfBounds)
72 }
73 }
74 }
75}
76
77fn layout_from_vecs<R: TensorRank>(
78 shape: ShapeVec,
79 strides: StrideVec,
80 offset: isize,
81 buffer_len: usize,
82) -> Result<TensorLayout<R>> {
83 TensorLayout::from_parts(
84 R::shape_from_vec(shape)?,
85 R::strides_from_vec(strides)?,
86 offset,
87 buffer_len,
88 )
89}
90
91fn positive_ceil_div(numerator: isize, denominator: isize) -> Result<usize> {
92 debug_assert!(numerator >= 0);
93 debug_assert!(denominator > 0);
94 let extent = if numerator == 0 {
95 0
96 } else {
97 1 + (numerator - 1) / denominator
98 };
99 usize::try_from(extent).map_err(|_| Error::IntegerOverflow)
100}
101
102fn normalize_slice(slice: SliceSpec, axis_len: usize) -> Result<(isize, usize)> {
103 if slice.step == 0 {
104 return Err(Error::InvalidSliceStep { step: slice.step });
105 }
106
107 let axis_len = isize::try_from(axis_len).map_err(|_| Error::IntegerOverflow)?;
108 if slice.step > 0 {
109 let start = if slice.start < 0 {
110 slice
111 .start
112 .checked_add(axis_len)
113 .ok_or(Error::IntegerOverflow)?
114 } else {
115 slice.start
116 };
117 let end = if slice.end < 0 {
118 slice
119 .end
120 .checked_add(axis_len)
121 .ok_or(Error::IntegerOverflow)?
122 } else {
123 slice.end
124 };
125 if start < 0 || start > axis_len || end < 0 || end > axis_len {
126 return Err(Error::InvalidSliceBounds {
127 start: slice.start,
128 end: slice.end,
129 axis_len: usize::try_from(axis_len).map_err(|_| Error::IntegerOverflow)?,
130 });
131 }
132 if start >= end {
133 return Ok((start, 0));
134 }
135 return Ok((start, positive_ceil_div(end - start, slice.step)?));
136 }
137
138 let start = if slice.start < 0 {
139 slice
140 .start
141 .checked_add(axis_len)
142 .ok_or(Error::IntegerOverflow)?
143 } else {
144 slice.start
145 };
146 let end = slice.end;
147 if start < 0 || start >= axis_len || end < -1 || end >= axis_len {
148 return Err(Error::InvalidSliceBounds {
149 start: slice.start,
150 end: slice.end,
151 axis_len: usize::try_from(axis_len).map_err(|_| Error::IntegerOverflow)?,
152 });
153 }
154 if start <= end {
155 return Ok((start, 0));
156 }
157 let step = slice.step.checked_neg().ok_or(Error::IntegerOverflow)?;
158 Ok((start, positive_ceil_div(start - end, step)?))
159}
160
161#[derive(Clone, Debug, PartialEq, Eq)]
174pub struct TensorLayout<R: TensorRank = DynRank> {
175 shape: R::Shape,
176 strides: R::Strides,
177 offset: isize,
178}
179
180impl<R: TensorRank> TensorLayout<R> {
181 pub fn compact(shape: R::Shape) -> Result<Self> {
193 let strides = R::strides_from_vec(col_major_strides(shape.as_ref())?)?;
194 Ok(Self {
195 shape,
196 strides,
197 offset: 0,
198 })
199 }
200
201 pub fn from_parts(
218 shape: R::Shape,
219 strides: R::Strides,
220 offset: isize,
221 buffer_len: usize,
222 ) -> Result<Self> {
223 validate_reachable_bounds(shape.as_ref(), strides.as_ref(), offset, buffer_len)?;
224 Ok(Self {
225 shape,
226 strides,
227 offset,
228 })
229 }
230
231 pub fn shape(&self) -> &[usize] {
243 self.shape.as_ref()
244 }
245
246 pub fn strides(&self) -> &[isize] {
258 self.strides.as_ref()
259 }
260
261 pub fn offset(&self) -> isize {
273 self.offset
274 }
275
276 pub fn is_compact_col_major(&self) -> bool {
288 col_major_strides(self.shape())
289 .map(|strides| strides.as_slice() == self.strides())
290 .unwrap_or(false)
291 }
292
293 pub fn validate_mutable_no_overlap(&self) -> Result<()> {
310 if self.shape().contains(&0) {
311 return Ok(());
312 }
313
314 for (&extent, &stride) in self.shape().iter().zip(self.strides()) {
315 if extent > 1 && stride == 0 {
316 return Err(Error::OverlappingMutableLayout);
317 }
318 }
319
320 let element_count = checked_product(self.shape())?;
321
322 let mut axes = self
323 .shape()
324 .iter()
325 .zip(self.strides())
326 .filter(|&(&extent, _)| extent > 1)
327 .map(|(&extent, &stride)| (extent, stride.unsigned_abs()))
328 .collect::<SmallVec<[(usize, usize); 8]>>();
329 axes.sort_by_key(|&(_, stride)| stride);
330
331 let mut span = 0usize;
332 for (extent, stride) in axes {
333 if stride <= span {
334 return self.validate_mutable_no_overlap_exact_or_reject(element_count);
335 }
336 span = span
337 .checked_add(
338 (extent - 1)
339 .checked_mul(stride)
340 .ok_or(Error::IntegerOverflow)?,
341 )
342 .ok_or(Error::IntegerOverflow)?;
343 }
344
345 Ok(())
346 }
347
348 fn validate_mutable_no_overlap_exact_or_reject(&self, element_count: usize) -> Result<()> {
349 if element_count > MUTABLE_NO_OVERLAP_EXACT_ELEMENT_LIMIT {
350 return Err(Error::OverlappingMutableLayout);
351 }
352
353 let mut seen = HashSet::with_capacity(element_count);
354 let rank = self.shape().len();
355 let mut indices = vec![0usize; rank];
356
357 loop {
358 let mut physical_offset = self.offset;
359 for (&index, &stride) in indices.iter().zip(self.strides()) {
360 let index = isize::try_from(index).map_err(|_| Error::IntegerOverflow)?;
361 let delta = index.checked_mul(stride).ok_or(Error::IntegerOverflow)?;
362 physical_offset = physical_offset
363 .checked_add(delta)
364 .ok_or(Error::IntegerOverflow)?;
365 }
366
367 if !seen.insert(physical_offset) {
368 return Err(Error::OverlappingMutableLayout);
369 }
370
371 let mut axis = 0;
372 while axis < rank {
373 indices[axis] += 1;
374 if indices[axis] < self.shape()[axis] {
375 break;
376 }
377 indices[axis] = 0;
378 axis += 1;
379 }
380 if axis == rank {
381 return Ok(());
382 }
383 }
384 }
385
386 pub fn transpose_view(&self, axes: impl AsRef<[usize]>) -> Result<Self> {
400 let axes = axes.as_ref();
401 validate_permutation(self.shape().len(), axes)?;
402 let shape = axes
403 .iter()
404 .map(|&axis| self.shape()[axis])
405 .collect::<ShapeVec>();
406 let strides = axes
407 .iter()
408 .map(|&axis| self.strides()[axis])
409 .collect::<StrideVec>();
410 Ok(Self {
411 shape: R::shape_from_vec(shape)?,
412 strides: R::strides_from_vec(strides)?,
413 offset: self.offset,
414 })
415 }
416
417 pub fn slice_view(&self, spec: impl AsRef<[SliceSpec]>, buffer_len: usize) -> Result<Self> {
431 let spec = spec.as_ref();
432 if spec.len() != self.shape().len() {
433 return Err(Error::RankMismatch {
434 expected: self.shape().len(),
435 actual: spec.len(),
436 });
437 }
438
439 let mut shape = ShapeVec::new();
440 let mut strides = StrideVec::new();
441 let mut offset = self.offset;
442 for ((&axis_len, &stride), &slice) in self
443 .shape()
444 .iter()
445 .zip(self.strides().iter())
446 .zip(spec.iter())
447 {
448 let (start, extent) = normalize_slice(slice, axis_len)?;
449 let start_offset = start.checked_mul(stride).ok_or(Error::IntegerOverflow)?;
450 offset = offset
451 .checked_add(start_offset)
452 .ok_or(Error::IntegerOverflow)?;
453 shape.push(extent);
454 strides.push(
455 stride
456 .checked_mul(slice.step)
457 .ok_or(Error::IntegerOverflow)?,
458 );
459 }
460 layout_from_vecs(shape, strides, offset, buffer_len)
461 }
462
463 pub fn reshape_view_as<R2: TensorRank>(
477 &self,
478 shape: R2::Shape,
479 buffer_len: usize,
480 ) -> Result<TensorLayout<R2>> {
481 if !self.is_compact_col_major() {
482 return Err(Error::NonContiguousViewAsSlice);
483 }
484 let from = checked_product(self.shape())?;
485 let to = checked_product(shape.as_ref())?;
486 if from != to {
487 return Err(Error::ReshapeElementCountMismatch { from, to });
488 }
489 let strides = R2::strides_from_vec(col_major_strides(shape.as_ref())?)?;
490 TensorLayout::from_parts(shape, strides, self.offset, buffer_len)
491 }
492
493 pub fn broadcast_in_dim_view<R2: TensorRank>(
507 &self,
508 shape: R2::Shape,
509 broadcast_dims: impl AsRef<[usize]>,
510 buffer_len: usize,
511 ) -> Result<TensorLayout<R2>> {
512 let broadcast_dims = broadcast_dims.as_ref();
513 if broadcast_dims.len() != self.shape().len() {
514 return Err(Error::RankMismatch {
515 expected: self.shape().len(),
516 actual: broadcast_dims.len(),
517 });
518 }
519
520 let output_rank = shape.as_ref().len();
521 let mut seen = vec![false; output_rank];
522 let mut strides = StrideVec::new();
523 strides.resize(output_rank, 0);
524 for (input_axis, &output_axis) in broadcast_dims.iter().enumerate() {
525 if output_axis >= output_rank {
526 return Err(Error::AxisOutOfBounds {
527 axis: output_axis,
528 rank: output_rank,
529 });
530 }
531 if seen[output_axis] {
532 return Err(Error::DuplicateAxis { axis: output_axis });
533 }
534 seen[output_axis] = true;
535
536 let input_extent = self.shape()[input_axis];
537 let output_extent = shape.as_ref()[output_axis];
538 if input_extent != output_extent && input_extent != 1 {
539 return Err(Error::ShapeDataLengthMismatch {
540 expected: input_extent,
541 actual: output_extent,
542 });
543 }
544 if input_extent == output_extent {
545 strides[output_axis] = self.strides()[input_axis];
546 }
547 }
548
549 TensorLayout::from_parts(
550 shape,
551 R2::strides_from_vec(strides)?,
552 self.offset,
553 buffer_len,
554 )
555 }
556}