tenferro_tensor_core/rank.rs
1use crate::{Error, Result, ShapeVec, StrideVec};
2use std::fmt::Debug;
3
4/// Rank contract for tensor metadata shapes and strides.
5///
6/// # Examples
7///
8/// ```rust
9/// use tenferro_tensor_core::{Rank, TensorRank};
10///
11/// let shape = <Rank<2> as TensorRank>::shape_from_vec(vec![2, 3].into())?;
12/// assert_eq!(shape.as_ref(), &[2, 3]);
13/// # Ok::<(), tenferro_tensor_core::Error>(())
14/// ```
15pub trait TensorRank: private::Sealed + Clone + Copy + Debug + Eq + Send + Sync + 'static {
16 /// Static rank when known at compile time.
17 ///
18 /// # Examples
19 ///
20 /// ```rust
21 /// use tenferro_tensor_core::{DynRank, Rank, TensorRank};
22 ///
23 /// assert_eq!(DynRank::RANK, None);
24 /// assert_eq!(Rank::<2>::RANK, Some(2));
25 /// ```
26 const RANK: Option<usize>;
27
28 /// Shape representation for this rank.
29 ///
30 /// # Examples
31 ///
32 /// ```rust
33 /// use tenferro_tensor_core::{DynRank, TensorRank};
34 ///
35 /// let shape: <DynRank as TensorRank>::Shape = vec![2, 3].into();
36 /// assert_eq!(shape.as_ref(), &[2, 3]);
37 /// ```
38 type Shape: Clone + Debug + PartialEq + Eq + AsRef<[usize]>;
39
40 /// Stride representation for this rank.
41 ///
42 /// # Examples
43 ///
44 /// ```rust
45 /// use tenferro_tensor_core::{DynRank, TensorRank};
46 ///
47 /// let strides: <DynRank as TensorRank>::Strides = vec![1, 2].into();
48 /// assert_eq!(strides.as_ref(), &[1, 2]);
49 /// ```
50 type Strides: Clone + Debug + PartialEq + Eq + AsRef<[isize]>;
51
52 /// Convert a dynamic shape vector into this rank's shape representation.
53 ///
54 /// # Examples
55 ///
56 /// ```rust
57 /// use tenferro_tensor_core::{Rank, TensorRank};
58 ///
59 /// let shape = <Rank<1> as TensorRank>::shape_from_vec(vec![4].into())?;
60 /// assert_eq!(shape.as_ref(), &[4]);
61 /// # Ok::<(), tenferro_tensor_core::Error>(())
62 /// ```
63 fn shape_from_vec(shape: ShapeVec) -> Result<Self::Shape>;
64
65 /// Convert this rank's shape representation into a dynamic shape vector.
66 ///
67 /// # Examples
68 ///
69 /// ```rust
70 /// use tenferro_tensor_core::{Rank, TensorRank};
71 ///
72 /// let shape = <Rank<2> as TensorRank>::shape_from_vec(vec![2, 3].into())?;
73 /// assert_eq!(<Rank<2> as TensorRank>::shape_into_vec(shape).as_slice(), &[2, 3]);
74 /// # Ok::<(), tenferro_tensor_core::Error>(())
75 /// ```
76 fn shape_into_vec(shape: Self::Shape) -> ShapeVec;
77
78 /// Convert a dynamic stride vector into this rank's stride representation.
79 ///
80 /// # Examples
81 ///
82 /// ```rust
83 /// use tenferro_tensor_core::{Rank, TensorRank};
84 ///
85 /// let strides = <Rank<2> as TensorRank>::strides_from_vec(vec![1, 2].into())?;
86 /// assert_eq!(strides.as_ref(), &[1, 2]);
87 /// # Ok::<(), tenferro_tensor_core::Error>(())
88 /// ```
89 fn strides_from_vec(strides: StrideVec) -> Result<Self::Strides>;
90
91 /// Convert this rank's stride representation into a dynamic stride vector.
92 ///
93 /// # Examples
94 ///
95 /// ```rust
96 /// use tenferro_tensor_core::{Rank, TensorRank};
97 ///
98 /// let strides = <Rank<2> as TensorRank>::strides_from_vec(vec![1, 2].into())?;
99 /// assert_eq!(<Rank<2> as TensorRank>::strides_into_vec(strides).as_slice(), &[1, 2]);
100 /// # Ok::<(), tenferro_tensor_core::Error>(())
101 /// ```
102 fn strides_into_vec(strides: Self::Strides) -> StrideVec;
103}
104
/// Marker for tensors whose rank is only known at runtime.
///
/// Shapes and strides for this marker live in [`ShapeVec`] / [`StrideVec`],
/// so any number of dimensions is accepted and `TensorRank::RANK` is `None`.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor_core::{DynRank, TensorRank};
///
/// assert_eq!(DynRank::RANK, None);
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct DynRank;
116
/// Marker for tensors whose rank `N` is fixed at compile time.
///
/// Shapes and strides for this marker are stored as `[usize; N]` and
/// `[isize; N]`, and `TensorRank::RANK` is `Some(N)`.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor_core::{Rank, TensorRank};
///
/// assert_eq!(Rank::<3>::RANK, Some(3));
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Rank<const N: usize>;
128
129impl TensorRank for DynRank {
130 const RANK: Option<usize> = None;
131
132 type Shape = ShapeVec;
133 type Strides = StrideVec;
134
135 fn shape_from_vec(shape: ShapeVec) -> Result<Self::Shape> {
136 Ok(shape)
137 }
138
139 fn shape_into_vec(shape: Self::Shape) -> ShapeVec {
140 shape
141 }
142
143 fn strides_from_vec(strides: StrideVec) -> Result<Self::Strides> {
144 Ok(strides)
145 }
146
147 fn strides_into_vec(strides: Self::Strides) -> StrideVec {
148 strides
149 }
150}
151
152impl<const N: usize> TensorRank for Rank<N> {
153 const RANK: Option<usize> = Some(N);
154
155 type Shape = [usize; N];
156 type Strides = [isize; N];
157
158 fn shape_from_vec(shape: ShapeVec) -> Result<Self::Shape> {
159 let actual = shape.len();
160 shape
161 .into_vec()
162 .try_into()
163 .map_err(|_| Error::RankMismatch {
164 expected: N,
165 actual,
166 })
167 }
168
169 fn shape_into_vec(shape: Self::Shape) -> ShapeVec {
170 ShapeVec::from_iter(shape)
171 }
172
173 fn strides_from_vec(strides: StrideVec) -> Result<Self::Strides> {
174 let actual = strides.len();
175 strides
176 .into_vec()
177 .try_into()
178 .map_err(|_| Error::RankMismatch {
179 expected: N,
180 actual,
181 })
182 }
183
184 fn strides_into_vec(strides: Self::Strides) -> StrideVec {
185 StrideVec::from_iter(strides)
186 }
187}
188
189mod private {
190 pub trait Sealed {}
191
192 impl Sealed for super::DynRank {}
193 impl<const N: usize> Sealed for super::Rank<N> {}
194}