Skip to main content

tenferro_tensor_core/
rank.rs

1use crate::{Error, Result, ShapeVec, StrideVec};
2use std::fmt::Debug;
3
/// Rank contract for tensor metadata shapes and strides.
///
/// A rank marker selects the concrete shape and stride representations and
/// provides conversions between them and the dynamic `ShapeVec` /
/// `StrideVec` forms.
///
/// The trait lists `private::Sealed` as a supertrait; in this file it is
/// implemented for [`DynRank`] and [`Rank`].
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor_core::{Rank, TensorRank};
///
/// let shape = <Rank<2> as TensorRank>::shape_from_vec(vec![2, 3].into())?;
/// assert_eq!(shape.as_ref(), &[2, 3]);
/// # Ok::<(), tenferro_tensor_core::Error>(())
/// ```
pub trait TensorRank: private::Sealed + Clone + Copy + Debug + Eq + Send + Sync + 'static {
    /// Static rank when known at compile time.
    ///
    /// This is `None` for [`DynRank`] and `Some(N)` for [`Rank<N>`](Rank).
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor_core::{DynRank, Rank, TensorRank};
    ///
    /// assert_eq!(DynRank::RANK, None);
    /// assert_eq!(Rank::<2>::RANK, Some(2));
    /// ```
    const RANK: Option<usize>;

    /// Shape representation for this rank.
    ///
    /// Whatever the concrete type, it can be borrowed as a `[usize]` slice
    /// through `AsRef`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor_core::{DynRank, TensorRank};
    ///
    /// let shape: <DynRank as TensorRank>::Shape = vec![2, 3].into();
    /// assert_eq!(shape.as_ref(), &[2, 3]);
    /// ```
    type Shape: Clone + Debug + PartialEq + Eq + AsRef<[usize]>;

    /// Stride representation for this rank.
    ///
    /// Whatever the concrete type, it can be borrowed as an `[isize]` slice
    /// through `AsRef`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor_core::{DynRank, TensorRank};
    ///
    /// let strides: <DynRank as TensorRank>::Strides = vec![1, 2].into();
    /// assert_eq!(strides.as_ref(), &[1, 2]);
    /// ```
    type Strides: Clone + Debug + PartialEq + Eq + AsRef<[isize]>;

    /// Convert a dynamic shape vector into this rank's shape representation.
    ///
    /// # Errors
    ///
    /// The [`Rank<N>`](Rank) implementation returns `Error::RankMismatch`
    /// when `shape` does not contain exactly `N` entries. The [`DynRank`]
    /// implementation always succeeds.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor_core::{Rank, TensorRank};
    ///
    /// let shape = <Rank<1> as TensorRank>::shape_from_vec(vec![4].into())?;
    /// assert_eq!(shape.as_ref(), &[4]);
    /// # Ok::<(), tenferro_tensor_core::Error>(())
    /// ```
    fn shape_from_vec(shape: ShapeVec) -> Result<Self::Shape>;

    /// Convert this rank's shape representation into a dynamic shape vector.
    ///
    /// This direction is infallible.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor_core::{Rank, TensorRank};
    ///
    /// let shape = <Rank<2> as TensorRank>::shape_from_vec(vec![2, 3].into())?;
    /// assert_eq!(<Rank<2> as TensorRank>::shape_into_vec(shape).as_slice(), &[2, 3]);
    /// # Ok::<(), tenferro_tensor_core::Error>(())
    /// ```
    fn shape_into_vec(shape: Self::Shape) -> ShapeVec;

    /// Convert a dynamic stride vector into this rank's stride representation.
    ///
    /// # Errors
    ///
    /// The [`Rank<N>`](Rank) implementation returns `Error::RankMismatch`
    /// when `strides` does not contain exactly `N` entries. The [`DynRank`]
    /// implementation always succeeds.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor_core::{Rank, TensorRank};
    ///
    /// let strides = <Rank<2> as TensorRank>::strides_from_vec(vec![1, 2].into())?;
    /// assert_eq!(strides.as_ref(), &[1, 2]);
    /// # Ok::<(), tenferro_tensor_core::Error>(())
    /// ```
    fn strides_from_vec(strides: StrideVec) -> Result<Self::Strides>;

    /// Convert this rank's stride representation into a dynamic stride vector.
    ///
    /// This direction is infallible.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_tensor_core::{Rank, TensorRank};
    ///
    /// let strides = <Rank<2> as TensorRank>::strides_from_vec(vec![1, 2].into())?;
    /// assert_eq!(<Rank<2> as TensorRank>::strides_into_vec(strides).as_slice(), &[1, 2]);
    /// # Ok::<(), tenferro_tensor_core::Error>(())
    /// ```
    fn strides_into_vec(strides: Self::Strides) -> StrideVec;
}
104
/// Dynamic tensor rank marker.
///
/// A zero-sized marker type. Its [`TensorRank`] implementation reports
/// `RANK` as `None` and uses `ShapeVec` / `StrideVec` directly as the shape
/// and stride representations.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor_core::{DynRank, TensorRank};
///
/// assert_eq!(DynRank::RANK, None);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DynRank;
116
/// Static tensor rank marker.
///
/// A zero-sized marker type parameterized by the rank `N`. Its
/// [`TensorRank`] implementation reports `RANK` as `Some(N)` and uses
/// `[usize; N]` / `[isize; N]` as the shape and stride representations.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor_core::{Rank, TensorRank};
///
/// assert_eq!(Rank::<3>::RANK, Some(3));
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Rank<const N: usize>;
128
129impl TensorRank for DynRank {
130    const RANK: Option<usize> = None;
131
132    type Shape = ShapeVec;
133    type Strides = StrideVec;
134
135    fn shape_from_vec(shape: ShapeVec) -> Result<Self::Shape> {
136        Ok(shape)
137    }
138
139    fn shape_into_vec(shape: Self::Shape) -> ShapeVec {
140        shape
141    }
142
143    fn strides_from_vec(strides: StrideVec) -> Result<Self::Strides> {
144        Ok(strides)
145    }
146
147    fn strides_into_vec(strides: Self::Strides) -> StrideVec {
148        strides
149    }
150}
151
152impl<const N: usize> TensorRank for Rank<N> {
153    const RANK: Option<usize> = Some(N);
154
155    type Shape = [usize; N];
156    type Strides = [isize; N];
157
158    fn shape_from_vec(shape: ShapeVec) -> Result<Self::Shape> {
159        let actual = shape.len();
160        shape
161            .into_vec()
162            .try_into()
163            .map_err(|_| Error::RankMismatch {
164                expected: N,
165                actual,
166            })
167    }
168
169    fn shape_into_vec(shape: Self::Shape) -> ShapeVec {
170        ShapeVec::from_iter(shape)
171    }
172
173    fn strides_from_vec(strides: StrideVec) -> Result<Self::Strides> {
174        let actual = strides.len();
175        strides
176            .into_vec()
177            .try_into()
178            .map_err(|_| Error::RankMismatch {
179                expected: N,
180                actual,
181            })
182    }
183
184    fn strides_into_vec(strides: Self::Strides) -> StrideVec {
185        StrideVec::from_iter(strides)
186    }
187}
188
// Private module holding the supertrait that seals `TensorRank`.
//
// `Sealed` is `pub` so it can appear in the bounds of the public
// `TensorRank` trait, but the enclosing module is not `pub`, so it is only
// nameable where this module is visible.
mod private {
    /// Marker supertrait of `TensorRank`.
    pub trait Sealed {}

    // The only rank markers given a `Sealed` impl in this file.
    impl Sealed for super::DynRank {}
    impl<const N: usize> Sealed for super::Rank<N> {}
}