1use std::marker::PhantomData;
7use std::ops::{Index, IndexMut};
8
9use tenferro_tensor::{TensorScalar, TypedTensor as TfTensor};
10
11use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};
12
/// Fixed-rank tensor: a thin wrapper around a tenferro [`TfTensor`] whose rank `N` is part of
/// the type.
///
/// Constructors (`from_tenferro`, and `from_fn`/`from_elem` which go through it) check that
/// the wrapped tensor's shape has exactly `N` axes. Element access uses one index per axis
/// (`tensor[[i, j]]`), and `iter`, `iter_mut` and `to_row_major_vec` visit elements in
/// row-major order (last axis fastest).
#[derive(Debug)]
pub struct Tensor<T: TensorScalar, const N: usize>(TfTensor<T>);
16
/// Iterator over shared references to the elements of a [`Tensor`], in row-major order.
///
/// Created by [`Tensor::iter`].
pub struct TensorIter<'a, T: TensorScalar, const N: usize> {
    /// The borrowed tensor being iterated.
    tensor: &'a TfTensor<T>,
    /// Shape snapshot taken at construction; used to turn linear positions into indices.
    dims: [usize; N],
    /// Linear (row-major) position of the next element to yield.
    next: usize,
    /// Total number of elements; iteration ends when `next == len`.
    len: usize,
}
24
/// Iterator over mutable references to the elements of a [`Tensor`], in row-major order.
///
/// Created by [`Tensor::iter_mut`]. The tensor is held through a raw pointer so that `next`
/// can hand out several `&'a mut T` references to distinct elements at once; `_marker` ties
/// the iterator to the exclusive `'a` borrow taken by `iter_mut`. Because of the raw pointer
/// field, this type is neither `Send` nor `Sync`.
pub struct TensorIterMut<'a, T: TensorScalar, const N: usize> {
    /// Pointer derived from the `&'a mut TfTensor<T>` borrowed in `iter_mut`.
    tensor: *mut TfTensor<T>,
    /// Shape snapshot taken at construction; used to turn linear positions into indices.
    dims: [usize; N],
    /// Linear (row-major) position of the next element to yield.
    next: usize,
    /// Total number of elements; iteration ends when `next == len`.
    len: usize,
    /// Makes the iterator behave like it holds `&'a mut TfTensor<T>` for lifetime and
    /// variance purposes.
    _marker: PhantomData<&'a mut TfTensor<T>>,
}
33
/// Rank-2 tensor (matrix).
pub type Tensor2<T> = Tensor<T, 2>;

/// Rank-3 tensor.
pub type Tensor3<T> = Tensor<T, 3>;

/// Rank-4 tensor.
pub type Tensor4<T> = Tensor<T, 4>;
42
/// Converts a row-major (last axis fastest) linear position into a per-axis index.
///
/// Axes are peeled off from last to first. If a zero-extent axis is reached the walk stops
/// early and the remaining (leading) coordinates stay `0`; such shapes hold no elements, so
/// the result is never used to address real data.
fn row_major_index_from_linear<const N: usize>(linear: usize, dims: &[usize; N]) -> [usize; N] {
    let mut coords = [0usize; N];
    let mut rest = linear;

    for (slot, &extent) in coords.iter_mut().zip(dims.iter()).rev() {
        if extent == 0 {
            break;
        }
        *slot = rest % extent;
        rest /= extent;
    }

    coords
}
55
56fn row_major_data_to_tensor<T: TensorScalar, const N: usize>(
57 dims: [usize; N],
58 data: Vec<T>,
59) -> Tensor<T, N> {
60 let inner = typed_tensor_from_row_major_slice(&data, &dims);
61 Tensor::from_tenferro(inner)
62}
63
64impl<T: TensorScalar, const N: usize> Clone for Tensor<T, N> {
65 fn clone(&self) -> Self {
66 Self(self.0.clone())
67 }
68}
69
70impl<T: TensorScalar + PartialEq, const N: usize> PartialEq for Tensor<T, N> {
71 fn eq(&self, other: &Self) -> bool {
72 self.dims() == other.dims() && self.iter().eq(other.iter())
73 }
74}
75
/// Tensor equality (same shape, equal elements) is a full equivalence whenever `T: Eq`.
impl<T: TensorScalar + Eq, const N: usize> Eq for Tensor<T, N> {}
77
78impl<'a, T: TensorScalar, const N: usize> Iterator for TensorIter<'a, T, N> {
79 type Item = &'a T;
80
81 fn next(&mut self) -> Option<Self::Item> {
82 if self.next == self.len {
83 return None;
84 }
85
86 let idx = row_major_index_from_linear(self.next, &self.dims);
87 self.next += 1;
88 Some(self.tensor.get(&idx[..]))
89 }
90
91 fn size_hint(&self) -> (usize, Option<usize>) {
92 let remaining = self.len - self.next;
93 (remaining, Some(remaining))
94 }
95}
96
97impl<'a, T: TensorScalar, const N: usize> ExactSizeIterator for TensorIter<'a, T, N> {}
98
99impl<'a, T: TensorScalar, const N: usize> Iterator for TensorIterMut<'a, T, N> {
100 type Item = &'a mut T;
101
102 fn next(&mut self) -> Option<Self::Item> {
103 if self.next == self.len {
104 return None;
105 }
106
107 let idx = row_major_index_from_linear(self.next, &self.dims);
108 self.next += 1;
109
110 let tensor = unsafe { &mut *self.tensor };
113 let elem = tensor.get_mut(&idx[..]);
114 let ptr = elem as *mut T;
115 Some(unsafe { &mut *ptr })
116 }
117
118 fn size_hint(&self) -> (usize, Option<usize>) {
119 let remaining = self.len - self.next;
120 (remaining, Some(remaining))
121 }
122}
123
124impl<'a, T: TensorScalar, const N: usize> ExactSizeIterator for TensorIterMut<'a, T, N> {}
125
126impl<T: TensorScalar, const N: usize> Tensor<T, N> {
127 pub fn len(&self) -> usize {
129 self.0.n_elements()
130 }
131
132 pub fn is_empty(&self) -> bool {
134 self.len() == 0
135 }
136
137 pub fn dim(&self, axis: usize) -> usize {
139 self.dims()[axis]
140 }
141
142 pub fn dims(&self) -> &[usize; N] {
144 self.0.shape.as_slice().try_into().unwrap_or_else(|_| {
145 panic!(
146 "tensor rank mismatch: expected rank {N}, got {}",
147 self.0.shape.len()
148 )
149 })
150 }
151
152 pub fn to_row_major_vec(&self) -> Vec<T> {
154 tensor_to_row_major_vec(&self.0)
155 }
156
157 pub fn iter(&self) -> TensorIter<'_, T, N> {
159 TensorIter {
160 tensor: &self.0,
161 dims: *self.dims(),
162 next: 0,
163 len: self.len(),
164 }
165 }
166
167 pub fn iter_mut(&mut self) -> TensorIterMut<'_, T, N> {
169 TensorIterMut {
170 tensor: &mut self.0,
171 dims: *self.dims(),
172 next: 0,
173 len: self.len(),
174 _marker: PhantomData,
175 }
176 }
177
178 pub fn as_inner(&self) -> &TfTensor<T> {
180 &self.0
181 }
182
183 pub fn as_inner_mut(&mut self) -> &mut TfTensor<T> {
185 &mut self.0
186 }
187
188 pub fn into_inner(self) -> TfTensor<T> {
190 self.0
191 }
192
193 pub fn from_tenferro(tensor: TfTensor<T>) -> Self {
195 if tensor.shape.len() != N {
196 panic!(
197 "tensor rank mismatch: expected rank {N}, got {}",
198 tensor.shape.len()
199 );
200 }
201 Self(tensor)
202 }
203
204 pub fn from_fn(dims: [usize; N], mut f: impl FnMut([usize; N]) -> T) -> Self {
206 let total: usize = dims.iter().product();
207 let mut data = Vec::with_capacity(total);
208
209 for linear in 0..total {
210 data.push(f(row_major_index_from_linear(linear, &dims)));
211 }
212
213 row_major_data_to_tensor(dims, data)
214 }
215}
216
217impl<T: TensorScalar, const N: usize> Tensor<T, N> {
218 pub fn from_elem(dims: [usize; N], value: T) -> Self {
220 let total: usize = dims.iter().product();
221 row_major_data_to_tensor(dims, vec![value; total])
222 }
223}
224
225impl<T: TensorScalar, const N: usize> Index<[usize; N]> for Tensor<T, N> {
226 type Output = T;
227
228 fn index(&self, idx: [usize; N]) -> &T {
229 self.0.get(&idx[..])
230 }
231}
232
233impl<T: TensorScalar, const N: usize> IndexMut<[usize; N]> for Tensor<T, N> {
234 fn index_mut(&mut self, idx: [usize; N]) -> &mut T {
235 self.0.get_mut(&idx[..])
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};

    #[test]
    fn test_tensor2_from_elem() {
        let t: Tensor2<f64> = Tensor2::from_elem([3, 4], 0.0);
        assert_eq!(t.len(), 12);
        assert_eq!(t.dim(0), 3);
        assert_eq!(t.dim(1), 4);
        assert_eq!(t[[0, 0]], 0.0);
    }

    #[test]
    fn test_tensor2_indexing() {
        let mut t: Tensor2<f64> = Tensor2::from_elem([2, 3], 0.0);
        t[[0, 0]] = 1.0;
        t[[0, 1]] = 2.0;
        t[[0, 2]] = 3.0;
        t[[1, 0]] = 4.0;
        t[[1, 1]] = 5.0;
        t[[1, 2]] = 6.0;
        assert_eq!(t[[0, 0]], 1.0);
        assert_eq!(t[[1, 2]], 6.0);
        assert_eq!(t.to_row_major_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_tensor3_from_fn() {
        let t: Tensor3<f64> =
            Tensor3::from_fn([2, 3, 4], |[i, j, k]| (i * 100 + j * 10 + k) as f64);
        assert_eq!(t[[0, 0, 0]], 0.0);
        assert_eq!(t[[1, 2, 3]], 123.0);
        assert_eq!(t[[0, 1, 2]], 12.0);
        assert_eq!(t.dim(0), 2);
        assert_eq!(t.dim(1), 3);
        assert_eq!(t.dim(2), 4);
        assert_eq!(t.len(), 24);
    }

    #[test]
    fn test_tensor4_from_fn() {
        let t: Tensor4<f64> = Tensor4::from_fn([2, 3, 4, 5], |[i, j, k, l]| {
            (i * 1000 + j * 100 + k * 10 + l) as f64
        });
        assert_eq!(t[[1, 2, 3, 4]], 1234.0);
        assert_eq!(t.dim(0), 2);
        assert_eq!(t.dim(1), 3);
        assert_eq!(t.dim(2), 4);
        assert_eq!(t.dim(3), 5);
        assert_eq!(t.len(), 120);
    }

    #[test]
    fn test_iter() {
        let t: Tensor2<f64> = Tensor2::from_fn([2, 3], |[i, j]| (i * 3 + j) as f64);
        let collected: Vec<f64> = t.iter().copied().collect();
        assert_eq!(collected, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_iter_len_is_exact() {
        let t: Tensor3<f64> = Tensor3::from_elem([2, 3, 4], 1.0);
        assert!(!t.is_empty());
        let mut it = t.iter();
        assert_eq!(it.len(), 24);
        it.next();
        assert_eq!(it.len(), 23);
        assert_eq!(it.by_ref().count(), 23);
        assert!(it.next().is_none());
    }

    #[test]
    fn test_iter_mut_writes_in_row_major_order() {
        let mut t: Tensor2<f64> = Tensor2::from_elem([2, 3], 0.0);
        assert_eq!(t.iter_mut().len(), 6);
        for (k, x) in t.iter_mut().enumerate() {
            *x = k as f64;
        }
        assert_eq!(t.to_row_major_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(t[[1, 0]], 3.0);
    }

    #[test]
    fn test_iter_mut_on_clone_leaves_original_untouched() {
        let original: Tensor2<f64> = Tensor2::from_fn([2, 2], |[i, j]| (i * 2 + j) as f64);
        let mut cloned = original.clone();
        cloned.iter_mut().for_each(|x| *x += 10.0);

        assert_eq!(original.to_row_major_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(cloned.to_row_major_vec(), vec![10.0, 11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_partial_eq() {
        let a: Tensor2<f64> = Tensor2::from_fn([2, 3], |[i, j]| (i * 3 + j) as f64);
        let b: Tensor2<f64> = Tensor2::from_fn([2, 3], |[i, j]| (i * 3 + j) as f64);
        assert_eq!(a, b);

        // Same row-major data, different shape.
        let c: Tensor2<f64> = Tensor2::from_fn([3, 2], |[i, j]| (i * 2 + j) as f64);
        assert_ne!(a, c);

        let mut d = b.clone();
        d[[1, 2]] = -1.0;
        assert_ne!(a, d);
    }

    #[test]
    fn test_dims() {
        let t: Tensor3<f64> = Tensor3::from_elem([2, 3, 4], 1.0);
        assert_eq!(t.dims(), &[2, 3, 4]);
    }

    #[test]
    fn test_from_tenferro_roundtrip() {
        let inner = typed_tensor_from_row_major_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let tensor = Tensor2::from_tenferro(inner);

        assert_eq!(tensor.dims(), &[2, 3]);
        assert_eq!(
            tensor.to_row_major_vec(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );

        let inner_again = tensor.clone().into_inner();
        assert_eq!(
            tensor_to_row_major_vec(&inner_again),
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    #[should_panic(expected = "rank")]
    fn test_from_tenferro_rank_mismatch_panics() {
        let inner = typed_tensor_from_row_major_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let _ = Tensor3::from_tenferro(inner);
    }

    #[test]
    fn test_clone_allows_index_mut() {
        let original: Tensor2<f64> = Tensor2::from_fn([2, 2], |[i, j]| (i * 2 + j) as f64);
        let mut cloned = original.clone();
        cloned[[1, 0]] = 99.0;

        assert_eq!(original[[1, 0]], 2.0);
        assert_eq!(cloned[[1, 0]], 99.0);
    }
}