Skip to main content

tensor4all_simplett/
tensor.rs

//! Fixed-rank tensor type backed by `tenferro_tensor::TypedTensor<T>`.
//!
//! This wrapper preserves a compile-time rank while delegating storage and
//! indexing to tenferro's typed dense tensor implementation.
5
6use std::marker::PhantomData;
7use std::ops::{Index, IndexMut};
8
9use tenferro_tensor::{TensorScalar, TypedTensor as TfTensor};
10
11use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};
12
13/// Rank-N tensor backed by `tenferro_tensor::TypedTensor<T>`.
14#[derive(Debug)]
15pub struct Tensor<T: TensorScalar, const N: usize>(TfTensor<T>);
16
17/// Iterator over tensor elements in row-major order.
18pub struct TensorIter<'a, T: TensorScalar, const N: usize> {
19    tensor: &'a TfTensor<T>,
20    dims: [usize; N],
21    next: usize,
22    len: usize,
23}
24
25/// Mutable iterator over tensor elements in row-major order.
26pub struct TensorIterMut<'a, T: TensorScalar, const N: usize> {
27    tensor: *mut TfTensor<T>,
28    dims: [usize; N],
29    next: usize,
30    len: usize,
31    _marker: PhantomData<&'a mut TfTensor<T>>,
32}
33
34/// 2D tensor (matrix).
35pub type Tensor2<T> = Tensor<T, 2>;
36
37/// 3D tensor.
38pub type Tensor3<T> = Tensor<T, 3>;
39
40/// 4D tensor.
41pub type Tensor4<T> = Tensor<T, 4>;
42
/// Decodes a row-major linear offset into one index per axis.
///
/// The last axis varies fastest. Axes are processed from last to first; if a
/// zero-length axis is met, decoding stops there and every axis not yet
/// processed keeps index `0`. Any part of `linear` left over after the first
/// axis has been processed is discarded.
fn row_major_index_from_linear<const N: usize>(mut linear: usize, dims: &[usize; N]) -> [usize; N] {
    let mut coords = [0usize; N];
    for (slot, &extent) in coords.iter_mut().zip(dims.iter()).rev() {
        if extent == 0 {
            // Division by a zero extent is undefined; leave the rest at zero.
            break;
        }
        *slot = linear % extent;
        linear /= extent;
    }
    coords
}
55
56fn row_major_data_to_tensor<T: TensorScalar, const N: usize>(
57    dims: [usize; N],
58    data: Vec<T>,
59) -> Tensor<T, N> {
60    let inner = typed_tensor_from_row_major_slice(&data, &dims);
61    Tensor::from_tenferro(inner)
62}
63
64impl<T: TensorScalar, const N: usize> Clone for Tensor<T, N> {
65    fn clone(&self) -> Self {
66        Self(self.0.clone())
67    }
68}
69
70impl<T: TensorScalar + PartialEq, const N: usize> PartialEq for Tensor<T, N> {
71    fn eq(&self, other: &Self) -> bool {
72        self.dims() == other.dims() && self.iter().eq(other.iter())
73    }
74}
75
76impl<T: TensorScalar + Eq, const N: usize> Eq for Tensor<T, N> {}
77
78impl<'a, T: TensorScalar, const N: usize> Iterator for TensorIter<'a, T, N> {
79    type Item = &'a T;
80
81    fn next(&mut self) -> Option<Self::Item> {
82        if self.next == self.len {
83            return None;
84        }
85
86        let idx = row_major_index_from_linear(self.next, &self.dims);
87        self.next += 1;
88        Some(self.tensor.get(&idx[..]))
89    }
90
91    fn size_hint(&self) -> (usize, Option<usize>) {
92        let remaining = self.len - self.next;
93        (remaining, Some(remaining))
94    }
95}
96
97impl<'a, T: TensorScalar, const N: usize> ExactSizeIterator for TensorIter<'a, T, N> {}
98
99impl<'a, T: TensorScalar, const N: usize> Iterator for TensorIterMut<'a, T, N> {
100    type Item = &'a mut T;
101
102    fn next(&mut self) -> Option<Self::Item> {
103        if self.next == self.len {
104            return None;
105        }
106
107        let idx = row_major_index_from_linear(self.next, &self.dims);
108        self.next += 1;
109
110        // Safety: each logical index is visited at most once, so returned
111        // mutable references never alias each other.
112        let tensor = unsafe { &mut *self.tensor };
113        let elem = tensor.get_mut(&idx[..]);
114        let ptr = elem as *mut T;
115        Some(unsafe { &mut *ptr })
116    }
117
118    fn size_hint(&self) -> (usize, Option<usize>) {
119        let remaining = self.len - self.next;
120        (remaining, Some(remaining))
121    }
122}
123
124impl<'a, T: TensorScalar, const N: usize> ExactSizeIterator for TensorIterMut<'a, T, N> {}
125
126impl<T: TensorScalar, const N: usize> Tensor<T, N> {
127    /// Total number of elements.
128    pub fn len(&self) -> usize {
129        self.0.n_elements()
130    }
131
132    /// Whether the tensor is empty (zero elements).
133    pub fn is_empty(&self) -> bool {
134        self.len() == 0
135    }
136
137    /// Dimension along `axis`.
138    pub fn dim(&self, axis: usize) -> usize {
139        self.dims()[axis]
140    }
141
142    /// All dimensions.
143    pub fn dims(&self) -> &[usize; N] {
144        self.0.shape.as_slice().try_into().unwrap_or_else(|_| {
145            panic!(
146                "tensor rank mismatch: expected rank {N}, got {}",
147                self.0.shape.len()
148            )
149        })
150    }
151
152    /// Export the tensor as a row-major flat vector.
153    pub fn to_row_major_vec(&self) -> Vec<T> {
154        tensor_to_row_major_vec(&self.0)
155    }
156
157    /// Iterate over all elements in row-major order.
158    pub fn iter(&self) -> TensorIter<'_, T, N> {
159        TensorIter {
160            tensor: &self.0,
161            dims: *self.dims(),
162            next: 0,
163            len: self.len(),
164        }
165    }
166
167    /// Iterate mutably over all elements in row-major order.
168    pub fn iter_mut(&mut self) -> TensorIterMut<'_, T, N> {
169        TensorIterMut {
170            tensor: &mut self.0,
171            dims: *self.dims(),
172            next: 0,
173            len: self.len(),
174            _marker: PhantomData,
175        }
176    }
177
178    /// Borrow the wrapped tenferro tensor.
179    pub fn as_inner(&self) -> &TfTensor<T> {
180        &self.0
181    }
182
183    /// Mutably borrow the wrapped tenferro tensor.
184    pub fn as_inner_mut(&mut self) -> &mut TfTensor<T> {
185        &mut self.0
186    }
187
188    /// Consume this wrapper and return the inner tenferro tensor.
189    pub fn into_inner(self) -> TfTensor<T> {
190        self.0
191    }
192
193    /// Wrap an existing tenferro tensor, panicking on rank mismatch.
194    pub fn from_tenferro(tensor: TfTensor<T>) -> Self {
195        if tensor.shape.len() != N {
196            panic!(
197                "tensor rank mismatch: expected rank {N}, got {}",
198                tensor.shape.len()
199            );
200        }
201        Self(tensor)
202    }
203
204    /// Create a tensor by applying `f` to each multi-index (row-major order).
205    pub fn from_fn(dims: [usize; N], mut f: impl FnMut([usize; N]) -> T) -> Self {
206        let total: usize = dims.iter().product();
207        let mut data = Vec::with_capacity(total);
208
209        for linear in 0..total {
210            data.push(f(row_major_index_from_linear(linear, &dims)));
211        }
212
213        row_major_data_to_tensor(dims, data)
214    }
215}
216
217impl<T: TensorScalar, const N: usize> Tensor<T, N> {
218    /// Create a tensor filled with `value`.
219    pub fn from_elem(dims: [usize; N], value: T) -> Self {
220        let total: usize = dims.iter().product();
221        row_major_data_to_tensor(dims, vec![value; total])
222    }
223}
224
225impl<T: TensorScalar, const N: usize> Index<[usize; N]> for Tensor<T, N> {
226    type Output = T;
227
228    fn index(&self, idx: [usize; N]) -> &T {
229        self.0.get(&idx[..])
230    }
231}
232
233impl<T: TensorScalar, const N: usize> IndexMut<[usize; N]> for Tensor<T, N> {
234    fn index_mut(&mut self, idx: [usize; N]) -> &mut T {
235        self.0.get_mut(&idx[..])
236    }
237}
238
#[cfg(test)]
mod tests {
    //! Unit tests for the fixed-rank tensor wrapper: construction, indexing,
    //! iteration (shared and mutable), equality, and tenferro round-trips.

    use super::*;

    use crate::einsum_helper::{tensor_to_row_major_vec, typed_tensor_from_row_major_slice};

    #[test]
    fn test_tensor2_from_elem() {
        let t: Tensor2<f64> = Tensor2::from_elem([3, 4], 0.0);
        assert_eq!(t.len(), 12);
        assert_eq!(t.dim(0), 3);
        assert_eq!(t.dim(1), 4);
        assert_eq!(t[[0, 0]], 0.0);
    }

    #[test]
    fn test_tensor2_indexing() {
        let mut t: Tensor2<f64> = Tensor2::from_elem([2, 3], 0.0);
        t[[0, 0]] = 1.0;
        t[[0, 1]] = 2.0;
        t[[0, 2]] = 3.0;
        t[[1, 0]] = 4.0;
        t[[1, 1]] = 5.0;
        t[[1, 2]] = 6.0;
        assert_eq!(t[[0, 0]], 1.0);
        assert_eq!(t[[1, 2]], 6.0);
        assert_eq!(t.to_row_major_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_tensor3_from_fn() {
        let t: Tensor3<f64> =
            Tensor3::from_fn([2, 3, 4], |[i, j, k]| (i * 100 + j * 10 + k) as f64);
        assert_eq!(t[[0, 0, 0]], 0.0);
        assert_eq!(t[[1, 2, 3]], 123.0);
        assert_eq!(t[[0, 1, 2]], 12.0);
        assert_eq!(t.dim(0), 2);
        assert_eq!(t.dim(1), 3);
        assert_eq!(t.dim(2), 4);
        assert_eq!(t.len(), 24);
    }

    #[test]
    fn test_tensor4_from_fn() {
        let t: Tensor4<f64> = Tensor4::from_fn([2, 3, 4, 5], |[i, j, k, l]| {
            (i * 1000 + j * 100 + k * 10 + l) as f64
        });
        assert_eq!(t[[1, 2, 3, 4]], 1234.0);
        assert_eq!(t.dim(0), 2);
        assert_eq!(t.dim(1), 3);
        assert_eq!(t.dim(2), 4);
        assert_eq!(t.dim(3), 5);
        assert_eq!(t.len(), 120);
    }

    #[test]
    fn test_iter() {
        let t: Tensor2<f64> = Tensor2::from_fn([2, 3], |[i, j]| (i * 3 + j) as f64);
        let collected: Vec<f64> = t.iter().copied().collect();
        assert_eq!(collected, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    // `iter_mut` is the only code path that uses `unsafe`, so check both the
    // visiting order and that writes land in the tensor.
    #[test]
    fn test_iter_mut_visits_elements_in_row_major_order() {
        let mut t: Tensor2<f64> = Tensor2::from_elem([2, 3], 0.0);
        for (position, elem) in t.iter_mut().enumerate() {
            *elem = position as f64;
        }
        assert_eq!(t[[0, 2]], 2.0);
        assert_eq!(t[[1, 0]], 3.0);
        assert_eq!(t.to_row_major_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_iter_reports_exact_remaining_length() {
        let t: Tensor2<f64> = Tensor2::from_elem([2, 3], 1.0);
        let mut it = t.iter();
        assert_eq!(it.len(), 6);
        assert_eq!(it.next(), Some(&1.0));
        assert_eq!(it.size_hint(), (5, Some(5)));
        assert_eq!(it.count(), 5);
    }

    #[test]
    fn test_partial_eq_compares_shape_and_elements() {
        let a: Tensor2<f64> = Tensor2::from_fn([2, 3], |[i, j]| (i * 3 + j) as f64);
        let same = a.clone();
        let mut changed = a.clone();
        changed[[1, 1]] = -1.0;
        // Same row-major data as `a`, but a different shape.
        let reshaped: Tensor2<f64> = Tensor2::from_fn([3, 2], |[i, j]| (i * 2 + j) as f64);

        assert!(a == same);
        assert!(a != changed);
        assert!(a != reshaped);
    }

    #[test]
    fn test_dims() {
        let t: Tensor3<f64> = Tensor3::from_elem([2, 3, 4], 1.0);
        assert_eq!(t.dims(), &[2, 3, 4]);
    }

    #[test]
    fn test_from_tenferro_roundtrip() {
        let inner = typed_tensor_from_row_major_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let tensor = Tensor2::from_tenferro(inner);

        assert_eq!(tensor.dims(), &[2, 3]);
        assert_eq!(
            tensor.to_row_major_vec(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );

        let inner_again = tensor.clone().into_inner();
        assert_eq!(
            tensor_to_row_major_vec(&inner_again),
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    #[should_panic(expected = "rank")]
    fn test_from_tenferro_rank_mismatch_panics() {
        let inner = typed_tensor_from_row_major_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let _ = Tensor3::from_tenferro(inner);
    }

    #[test]
    fn test_clone_allows_index_mut() {
        let original: Tensor2<f64> = Tensor2::from_fn([2, 2], |[i, j]| (i * 2 + j) as f64);
        let mut cloned = original.clone();
        cloned[[1, 0]] = 99.0;

        assert_eq!(original[[1, 0]], 2.0);
        assert_eq!(cloned[[1, 0]], 99.0);
    }
}
341}