Skip to main content

tenferro_tensor_core/
layout.rs

1use crate::{
2    checked_product, col_major_strides, validate_permutation, DynRank, Error, Result, ShapeVec,
3    SliceSpec, StrideVec, TensorRank,
4};
5use smallvec::SmallVec;
6use std::collections::HashSet;
7
/// Upper bound on the logical element count for which the mutable-overlap check
/// falls back to exact enumeration.
///
/// Layouts above this size are accepted only when the stride-span proof
/// succeeds. The bound keeps the fallback cheap, because it visits every
/// logical element and records each physical offset it produces.
const MUTABLE_NO_OVERLAP_EXACT_ELEMENT_LIMIT: usize = 1 << 12;
14
15pub(crate) fn reachable_offset_range(
16    shape: &[usize],
17    strides: &[isize],
18    offset: isize,
19) -> Result<Option<(isize, isize)>> {
20    if shape.contains(&0) {
21        return Ok(None);
22    }
23
24    let mut min = offset;
25    let mut max = offset;
26    for (&extent, &stride) in shape.iter().zip(strides) {
27        let last = isize::try_from(extent.saturating_sub(1)).map_err(|_| Error::IntegerOverflow)?;
28        let delta = last.checked_mul(stride).ok_or(Error::IntegerOverflow)?;
29        if delta < 0 {
30            min = min.checked_add(delta).ok_or(Error::IntegerOverflow)?;
31        } else {
32            max = max.checked_add(delta).ok_or(Error::IntegerOverflow)?;
33        }
34    }
35    Ok(Some((min, max)))
36}
37
38pub(crate) fn validate_reachable_bounds(
39    shape: &[usize],
40    strides: &[isize],
41    offset: isize,
42    buffer_len: usize,
43) -> Result<()> {
44    if shape.len() != strides.len() {
45        return Err(Error::RankMismatch {
46            expected: shape.len(),
47            actual: strides.len(),
48        });
49    }
50
51    match reachable_offset_range(shape, strides, offset)? {
52        Some((min, max)) => {
53            if min < 0 {
54                return Err(Error::ViewOutOfBounds);
55            }
56            let max = usize::try_from(max).map_err(|_| Error::IntegerOverflow)?;
57            if max < buffer_len {
58                Ok(())
59            } else {
60                Err(Error::ViewOutOfBounds)
61            }
62        }
63        None => {
64            if offset < 0 {
65                return Err(Error::ViewOutOfBounds);
66            }
67            let offset = usize::try_from(offset).map_err(|_| Error::IntegerOverflow)?;
68            if offset <= buffer_len {
69                Ok(())
70            } else {
71                Err(Error::ViewOutOfBounds)
72            }
73        }
74    }
75}
76
77fn layout_from_vecs<R: TensorRank>(
78    shape: ShapeVec,
79    strides: StrideVec,
80    offset: isize,
81    buffer_len: usize,
82) -> Result<TensorLayout<R>> {
83    TensorLayout::from_parts(
84        R::shape_from_vec(shape)?,
85        R::strides_from_vec(strides)?,
86        offset,
87        buffer_len,
88    )
89}
90
91fn positive_ceil_div(numerator: isize, denominator: isize) -> Result<usize> {
92    debug_assert!(numerator >= 0);
93    debug_assert!(denominator > 0);
94    let extent = if numerator == 0 {
95        0
96    } else {
97        1 + (numerator - 1) / denominator
98    };
99    usize::try_from(extent).map_err(|_| Error::IntegerOverflow)
100}
101
102fn normalize_slice(slice: SliceSpec, axis_len: usize) -> Result<(isize, usize)> {
103    if slice.step == 0 {
104        return Err(Error::InvalidSliceStep { step: slice.step });
105    }
106
107    let axis_len = isize::try_from(axis_len).map_err(|_| Error::IntegerOverflow)?;
108    if slice.step > 0 {
109        let start = if slice.start < 0 {
110            slice
111                .start
112                .checked_add(axis_len)
113                .ok_or(Error::IntegerOverflow)?
114        } else {
115            slice.start
116        };
117        let end = if slice.end < 0 {
118            slice
119                .end
120                .checked_add(axis_len)
121                .ok_or(Error::IntegerOverflow)?
122        } else {
123            slice.end
124        };
125        if start < 0 || start > axis_len || end < 0 || end > axis_len {
126            return Err(Error::InvalidSliceBounds {
127                start: slice.start,
128                end: slice.end,
129                axis_len: usize::try_from(axis_len).map_err(|_| Error::IntegerOverflow)?,
130            });
131        }
132        if start >= end {
133            return Ok((start, 0));
134        }
135        return Ok((start, positive_ceil_div(end - start, slice.step)?));
136    }
137
138    let start = if slice.start < 0 {
139        slice
140            .start
141            .checked_add(axis_len)
142            .ok_or(Error::IntegerOverflow)?
143    } else {
144        slice.start
145    };
146    let end = slice.end;
147    if start < 0 || start >= axis_len || end < -1 || end >= axis_len {
148        return Err(Error::InvalidSliceBounds {
149            start: slice.start,
150            end: slice.end,
151            axis_len: usize::try_from(axis_len).map_err(|_| Error::IntegerOverflow)?,
152        });
153    }
154    if start <= end {
155        return Ok((start, 0));
156    }
157    let step = slice.step.checked_neg().ok_or(Error::IntegerOverflow)?;
158    Ok((start, positive_ceil_div(start - end, step)?))
159}
160
/// Storage-neutral tensor layout metadata.
///
/// A layout maps a logical multi-index `index` to the physical element offset
/// `offset + sum(index[i] * strides[i])`. It holds only shape, strides, and
/// offset; it does not store a buffer length. Constructors that take a
/// `buffer_len` (such as `from_parts`) check the reachable offsets against it
/// when the layout is built.
///
/// The rank parameter `R` selects the shape/stride storage types and defaults
/// to [`DynRank`].
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor_core::{Rank, TensorLayout};
///
/// let layout = TensorLayout::<Rank<2>>::compact([2, 3])?;
/// assert_eq!(layout.shape(), &[2, 3]);
/// assert_eq!(layout.strides(), &[1, 2]);
/// # Ok::<(), tenferro_tensor_core::Error>(())
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TensorLayout<R: TensorRank = DynRank> {
    /// Extent of each logical axis.
    shape: R::Shape,
    /// Per-axis step measured in elements. It may be negative (reversed views)
    /// or zero (broadcast axes).
    strides: R::Strides,
    /// Physical element offset of the logical origin (the all-zero index).
    offset: isize,
}
179
180impl<R: TensorRank> TensorLayout<R> {
181    /// Create a compact column-major layout with zero offset.
182    ///
183    /// # Examples
184    ///
185    /// ```rust
186    /// use tenferro_tensor_core::{Rank, TensorLayout};
187    ///
188    /// let layout = TensorLayout::<Rank<2>>::compact([2, 3])?;
189    /// assert_eq!(layout.strides(), &[1, 2]);
190    /// # Ok::<(), tenferro_tensor_core::Error>(())
191    /// ```
192    pub fn compact(shape: R::Shape) -> Result<Self> {
193        let strides = R::strides_from_vec(col_major_strides(shape.as_ref())?)?;
194        Ok(Self {
195            shape,
196            strides,
197            offset: 0,
198        })
199    }
200
201    /// Create a layout from shape, strides, element offset, and backing buffer length.
202    ///
203    /// # Examples
204    ///
205    /// ```rust
206    /// use tenferro_tensor_core::{DynRank, TensorLayout};
207    ///
208    /// let layout = TensorLayout::<DynRank>::from_parts(
209    ///     vec![2, 3].into(),
210    ///     vec![1, 2].into(),
211    ///     0,
212    ///     6,
213    /// )?;
214    /// assert!(layout.is_compact_col_major());
215    /// # Ok::<(), tenferro_tensor_core::Error>(())
216    /// ```
217    pub fn from_parts(
218        shape: R::Shape,
219        strides: R::Strides,
220        offset: isize,
221        buffer_len: usize,
222    ) -> Result<Self> {
223        validate_reachable_bounds(shape.as_ref(), strides.as_ref(), offset, buffer_len)?;
224        Ok(Self {
225            shape,
226            strides,
227            offset,
228        })
229    }
230
231    /// Return the layout shape.
232    ///
233    /// # Examples
234    ///
235    /// ```rust
236    /// use tenferro_tensor_core::{Rank, TensorLayout};
237    ///
238    /// let layout = TensorLayout::<Rank<1>>::compact([4])?;
239    /// assert_eq!(layout.shape(), &[4]);
240    /// # Ok::<(), tenferro_tensor_core::Error>(())
241    /// ```
242    pub fn shape(&self) -> &[usize] {
243        self.shape.as_ref()
244    }
245
246    /// Return the layout strides in element units.
247    ///
248    /// # Examples
249    ///
250    /// ```rust
251    /// use tenferro_tensor_core::{Rank, TensorLayout};
252    ///
253    /// let layout = TensorLayout::<Rank<2>>::compact([2, 3])?;
254    /// assert_eq!(layout.strides(), &[1, 2]);
255    /// # Ok::<(), tenferro_tensor_core::Error>(())
256    /// ```
257    pub fn strides(&self) -> &[isize] {
258        self.strides.as_ref()
259    }
260
261    /// Return the layout element offset.
262    ///
263    /// # Examples
264    ///
265    /// ```rust
266    /// use tenferro_tensor_core::{DynRank, TensorLayout};
267    ///
268    /// let layout = TensorLayout::<DynRank>::from_parts(vec![3].into(), vec![1].into(), 2, 5)?;
269    /// assert_eq!(layout.offset(), 2);
270    /// # Ok::<(), tenferro_tensor_core::Error>(())
271    /// ```
272    pub fn offset(&self) -> isize {
273        self.offset
274    }
275
276    /// Return whether the layout has compact column-major strides.
277    ///
278    /// # Examples
279    ///
280    /// ```rust
281    /// use tenferro_tensor_core::{Rank, TensorLayout};
282    ///
283    /// let layout = TensorLayout::<Rank<2>>::compact([2, 3])?;
284    /// assert!(layout.is_compact_col_major());
285    /// # Ok::<(), tenferro_tensor_core::Error>(())
286    /// ```
287    pub fn is_compact_col_major(&self) -> bool {
288        col_major_strides(self.shape())
289            .map(|strides| strides.as_slice() == self.strides())
290            .unwrap_or(false)
291    }
292
293    /// Validate that the layout can be used for mutable access without aliasing.
294    ///
295    /// Empty logical views are accepted. Non-empty layouts are accepted when a
296    /// conservative stride-span proof succeeds, or when exact enumeration of a
297    /// small bounded view proves that all logical elements map to distinct
298    /// physical offsets.
299    ///
300    /// # Examples
301    ///
302    /// ```rust
303    /// use tenferro_tensor_core::{DynRank, TensorLayout};
304    ///
305    /// let layout = TensorLayout::<DynRank>::from_parts(vec![3].into(), vec![-1].into(), 2, 3)?;
306    /// layout.validate_mutable_no_overlap()?;
307    /// # Ok::<(), tenferro_tensor_core::Error>(())
308    /// ```
309    pub fn validate_mutable_no_overlap(&self) -> Result<()> {
310        if self.shape().contains(&0) {
311            return Ok(());
312        }
313
314        for (&extent, &stride) in self.shape().iter().zip(self.strides()) {
315            if extent > 1 && stride == 0 {
316                return Err(Error::OverlappingMutableLayout);
317            }
318        }
319
320        let element_count = checked_product(self.shape())?;
321
322        let mut axes = self
323            .shape()
324            .iter()
325            .zip(self.strides())
326            .filter(|&(&extent, _)| extent > 1)
327            .map(|(&extent, &stride)| (extent, stride.unsigned_abs()))
328            .collect::<SmallVec<[(usize, usize); 8]>>();
329        axes.sort_by_key(|&(_, stride)| stride);
330
331        let mut span = 0usize;
332        for (extent, stride) in axes {
333            if stride <= span {
334                return self.validate_mutable_no_overlap_exact_or_reject(element_count);
335            }
336            span = span
337                .checked_add(
338                    (extent - 1)
339                        .checked_mul(stride)
340                        .ok_or(Error::IntegerOverflow)?,
341                )
342                .ok_or(Error::IntegerOverflow)?;
343        }
344
345        Ok(())
346    }
347
348    fn validate_mutable_no_overlap_exact_or_reject(&self, element_count: usize) -> Result<()> {
349        if element_count > MUTABLE_NO_OVERLAP_EXACT_ELEMENT_LIMIT {
350            return Err(Error::OverlappingMutableLayout);
351        }
352
353        let mut seen = HashSet::with_capacity(element_count);
354        let rank = self.shape().len();
355        let mut indices = vec![0usize; rank];
356
357        loop {
358            let mut physical_offset = self.offset;
359            for (&index, &stride) in indices.iter().zip(self.strides()) {
360                let index = isize::try_from(index).map_err(|_| Error::IntegerOverflow)?;
361                let delta = index.checked_mul(stride).ok_or(Error::IntegerOverflow)?;
362                physical_offset = physical_offset
363                    .checked_add(delta)
364                    .ok_or(Error::IntegerOverflow)?;
365            }
366
367            if !seen.insert(physical_offset) {
368                return Err(Error::OverlappingMutableLayout);
369            }
370
371            let mut axis = 0;
372            while axis < rank {
373                indices[axis] += 1;
374                if indices[axis] < self.shape()[axis] {
375                    break;
376                }
377                indices[axis] = 0;
378                axis += 1;
379            }
380            if axis == rank {
381                return Ok(());
382            }
383        }
384    }
385
386    /// Return a metadata-only axis permutation of this layout.
387    ///
388    /// # Examples
389    ///
390    /// ```rust
391    /// use tenferro_tensor_core::{Rank, TensorLayout};
392    ///
393    /// let layout = TensorLayout::<Rank<2>>::compact([2, 3])?;
394    /// let transposed = layout.transpose_view([1, 0])?;
395    /// assert_eq!(transposed.shape(), &[3, 2]);
396    /// assert_eq!(transposed.strides(), &[2, 1]);
397    /// # Ok::<(), tenferro_tensor_core::Error>(())
398    /// ```
399    pub fn transpose_view(&self, axes: impl AsRef<[usize]>) -> Result<Self> {
400        let axes = axes.as_ref();
401        validate_permutation(self.shape().len(), axes)?;
402        let shape = axes
403            .iter()
404            .map(|&axis| self.shape()[axis])
405            .collect::<ShapeVec>();
406        let strides = axes
407            .iter()
408            .map(|&axis| self.strides()[axis])
409            .collect::<StrideVec>();
410        Ok(Self {
411            shape: R::shape_from_vec(shape)?,
412            strides: R::strides_from_vec(strides)?,
413            offset: self.offset,
414        })
415    }
416
417    /// Return a metadata-only slice of this layout.
418    ///
419    /// # Examples
420    ///
421    /// ```rust
422    /// use tenferro_tensor_core::{Rank, SliceSpec, TensorLayout};
423    ///
424    /// let layout = TensorLayout::<Rank<1>>::compact([4])?;
425    /// let view = layout.slice_view([SliceSpec { start: 3, end: -1, step: -2 }], 4)?;
426    /// assert_eq!(view.shape(), &[2]);
427    /// assert_eq!(view.strides(), &[-2]);
428    /// # Ok::<(), tenferro_tensor_core::Error>(())
429    /// ```
430    pub fn slice_view(&self, spec: impl AsRef<[SliceSpec]>, buffer_len: usize) -> Result<Self> {
431        let spec = spec.as_ref();
432        if spec.len() != self.shape().len() {
433            return Err(Error::RankMismatch {
434                expected: self.shape().len(),
435                actual: spec.len(),
436            });
437        }
438
439        let mut shape = ShapeVec::new();
440        let mut strides = StrideVec::new();
441        let mut offset = self.offset;
442        for ((&axis_len, &stride), &slice) in self
443            .shape()
444            .iter()
445            .zip(self.strides().iter())
446            .zip(spec.iter())
447        {
448            let (start, extent) = normalize_slice(slice, axis_len)?;
449            let start_offset = start.checked_mul(stride).ok_or(Error::IntegerOverflow)?;
450            offset = offset
451                .checked_add(start_offset)
452                .ok_or(Error::IntegerOverflow)?;
453            shape.push(extent);
454            strides.push(
455                stride
456                    .checked_mul(slice.step)
457                    .ok_or(Error::IntegerOverflow)?,
458            );
459        }
460        layout_from_vecs(shape, strides, offset, buffer_len)
461    }
462
463    /// Return a metadata-only reshape of this compact column-major layout.
464    ///
465    /// # Examples
466    ///
467    /// ```rust
468    /// use tenferro_tensor_core::{Rank, TensorLayout};
469    ///
470    /// let layout = TensorLayout::<Rank<2>>::compact([2, 3])?;
471    /// let reshaped = layout.reshape_view_as::<Rank<1>>([6], 6)?;
472    /// assert_eq!(reshaped.shape(), &[6]);
473    /// assert_eq!(reshaped.strides(), &[1]);
474    /// # Ok::<(), tenferro_tensor_core::Error>(())
475    /// ```
476    pub fn reshape_view_as<R2: TensorRank>(
477        &self,
478        shape: R2::Shape,
479        buffer_len: usize,
480    ) -> Result<TensorLayout<R2>> {
481        if !self.is_compact_col_major() {
482            return Err(Error::NonContiguousViewAsSlice);
483        }
484        let from = checked_product(self.shape())?;
485        let to = checked_product(shape.as_ref())?;
486        if from != to {
487            return Err(Error::ReshapeElementCountMismatch { from, to });
488        }
489        let strides = R2::strides_from_vec(col_major_strides(shape.as_ref())?)?;
490        TensorLayout::from_parts(shape, strides, self.offset, buffer_len)
491    }
492
493    /// Return a metadata-only explicit broadcast of this layout into a target rank.
494    ///
495    /// # Examples
496    ///
497    /// ```rust
498    /// use tenferro_tensor_core::{Rank, TensorLayout};
499    ///
500    /// let layout = TensorLayout::<Rank<1>>::compact([3])?;
501    /// let broadcast = layout.broadcast_in_dim_view::<Rank<2>>([2, 3], [1], 3)?;
502    /// assert_eq!(broadcast.shape(), &[2, 3]);
503    /// assert_eq!(broadcast.strides(), &[0, 1]);
504    /// # Ok::<(), tenferro_tensor_core::Error>(())
505    /// ```
506    pub fn broadcast_in_dim_view<R2: TensorRank>(
507        &self,
508        shape: R2::Shape,
509        broadcast_dims: impl AsRef<[usize]>,
510        buffer_len: usize,
511    ) -> Result<TensorLayout<R2>> {
512        let broadcast_dims = broadcast_dims.as_ref();
513        if broadcast_dims.len() != self.shape().len() {
514            return Err(Error::RankMismatch {
515                expected: self.shape().len(),
516                actual: broadcast_dims.len(),
517            });
518        }
519
520        let output_rank = shape.as_ref().len();
521        let mut seen = vec![false; output_rank];
522        let mut strides = StrideVec::new();
523        strides.resize(output_rank, 0);
524        for (input_axis, &output_axis) in broadcast_dims.iter().enumerate() {
525            if output_axis >= output_rank {
526                return Err(Error::AxisOutOfBounds {
527                    axis: output_axis,
528                    rank: output_rank,
529                });
530            }
531            if seen[output_axis] {
532                return Err(Error::DuplicateAxis { axis: output_axis });
533            }
534            seen[output_axis] = true;
535
536            let input_extent = self.shape()[input_axis];
537            let output_extent = shape.as_ref()[output_axis];
538            if input_extent != output_extent && input_extent != 1 {
539                return Err(Error::ShapeDataLengthMismatch {
540                    expected: input_extent,
541                    actual: output_extent,
542                });
543            }
544            if input_extent == output_extent {
545                strides[output_axis] = self.strides()[input_axis];
546            }
547        }
548
549        TensorLayout::from_parts(
550            shape,
551            R2::strides_from_vec(strides)?,
552            self.offset,
553            buffer_len,
554        )
555    }
556}