Skip to main content

tensor4all_simplett/
types.rs

1//! Core types for tensor train operations
2
3use crate::tensor::Tensor;
4pub use crate::tensor::Tensor3;
5use tenferro_tensor::{TensorScalar, TypedTensor as TfTensor};
6
/// Position along the physical (site) axis of one tensor-train core.
pub type LocalIndex = usize;

/// A full configuration of the tensor train: one [`LocalIndex`] per site.
pub type MultiIndex = Vec<LocalIndex>;
12
/// Convenience accessors for rank-3 core tensors with shape
/// `(left_bond, site_dim, right_bond)`.
///
/// These methods give named access to dimensions and elements using the
/// tensor train convention where axis 0 is the left bond, axis 1 is the
/// physical (site) index, and axis 2 is the right bond.
///
/// Implementors must provide the three dimension accessors plus
/// [`Tensor3Ops::get3`] and [`Tensor3Ops::get3_mut`]. The remaining methods
/// have default implementations built on those accessors and may be
/// overridden with storage-aware versions.
///
/// # Memory layout of returned buffers
///
/// Every flat `Vec<T>` returned by this trait is **row-major**: the last
/// listed axis varies fastest. This is independent of the order in which
/// constructors such as `tensor3_from_data` expect their input (that function
/// takes column-major data).
///
/// # Panics
///
/// Index-taking methods are expected to panic when an index is out of range
/// for the corresponding axis.
///
/// # Examples
///
/// ```
/// use tensor4all_simplett::{Tensor3Ops, tensor3_zeros};
///
/// let mut t = tensor3_zeros::<f64>(2, 3, 4);
/// assert_eq!(t.left_dim(), 2);
/// assert_eq!(t.site_dim(), 3);
/// assert_eq!(t.right_dim(), 4);
///
/// t.set3(1, 2, 3, 42.0);
/// assert_eq!(*t.get3(1, 2, 3), 42.0);
/// ```
pub trait Tensor3Ops<T: Clone + Default> {
    /// Left (bond) dimension (axis 0).
    fn left_dim(&self) -> usize;

    /// Physical (site) dimension (axis 1).
    fn site_dim(&self) -> usize;

    /// Right (bond) dimension (axis 2).
    fn right_dim(&self) -> usize;

    /// Borrow the element at `(left, site, right)`.
    fn get3(&self, l: usize, s: usize, r: usize) -> &T;

    /// Mutably borrow the element at `(left, site, right)`.
    fn get3_mut(&mut self, l: usize, s: usize, r: usize) -> &mut T;

    /// Set the element at `(left, site, right)` to `value`.
    ///
    /// The default implementation writes through [`Tensor3Ops::get3_mut`].
    fn set3(&mut self, l: usize, s: usize, r: usize, value: T) {
        *self.get3_mut(l, s, r) = value;
    }

    /// Extract the `(left_dim, right_dim)` matrix for a fixed site index `s`
    /// as a flat row-major vector.
    ///
    /// Element `(l, r)` of the slice is stored at offset `l * right_dim + r`.
    fn slice_site(&self, s: usize) -> Vec<T> {
        let (left_dim, right_dim) = (self.left_dim(), self.right_dim());
        let mut out = Vec::with_capacity(left_dim * right_dim);
        for l in 0..left_dim {
            for r in 0..right_dim {
                out.push(self.get3(l, s, r).clone());
            }
        }
        out
    }

    /// Reshape to a `(left_dim * site_dim, right_dim)` matrix.
    ///
    /// Returns `(data, rows, cols)`. `data` is row-major, so element
    /// `(l, s, r)` sits at offset `(l * site_dim + s) * right_dim + r`.
    fn as_left_matrix(&self) -> (Vec<T>, usize, usize) {
        let (left_dim, site_dim, right_dim) = (self.left_dim(), self.site_dim(), self.right_dim());
        let rows = left_dim * site_dim;
        let cols = right_dim;
        let mut data = Vec::with_capacity(rows * cols);
        for l in 0..left_dim {
            for s in 0..site_dim {
                for r in 0..right_dim {
                    data.push(self.get3(l, s, r).clone());
                }
            }
        }
        (data, rows, cols)
    }

    /// Reshape to a `(left_dim, site_dim * right_dim)` matrix.
    ///
    /// Returns `(data, rows, cols)`. Because `data` is row-major, it holds the
    /// same element sequence as [`Tensor3Ops::as_left_matrix`]; only the
    /// reported `(rows, cols)` differ.
    fn as_right_matrix(&self) -> (Vec<T>, usize, usize) {
        let (data, _, _) = self.as_left_matrix();
        (data, self.left_dim(), self.site_dim() * self.right_dim())
    }
}
62
63impl<T: Clone + Default + TensorScalar> Tensor3Ops<T> for Tensor3<T> {
64    fn left_dim(&self) -> usize {
65        self.dim(0)
66    }
67
68    fn site_dim(&self) -> usize {
69        self.dim(1)
70    }
71
72    fn right_dim(&self) -> usize {
73        self.dim(2)
74    }
75
76    fn get3(&self, l: usize, s: usize, r: usize) -> &T {
77        &self[[l, s, r]]
78    }
79
80    fn get3_mut(&mut self, l: usize, s: usize, r: usize) -> &mut T {
81        &mut self[[l, s, r]]
82    }
83
84    fn set3(&mut self, l: usize, s: usize, r: usize, value: T) {
85        self[[l, s, r]] = value;
86    }
87
88    fn slice_site(&self, s: usize) -> Vec<T> {
89        let left_dim = self.left_dim();
90        let right_dim = self.right_dim();
91        let mut result = Vec::with_capacity(left_dim * right_dim);
92        for l in 0..left_dim {
93            for r in 0..right_dim {
94                result.push(self[[l, s, r]]);
95            }
96        }
97        result
98    }
99
100    fn as_left_matrix(&self) -> (Vec<T>, usize, usize) {
101        let left_dim = self.left_dim();
102        let site_dim = self.site_dim();
103        let right_dim = self.right_dim();
104        let rows = left_dim * site_dim;
105        let cols = right_dim;
106        let mut result = Vec::with_capacity(rows * cols);
107        for l in 0..left_dim {
108            for s in 0..site_dim {
109                for r in 0..right_dim {
110                    result.push(self[[l, s, r]]);
111                }
112            }
113        }
114        (result, rows, cols)
115    }
116
117    fn as_right_matrix(&self) -> (Vec<T>, usize, usize) {
118        let left_dim = self.left_dim();
119        let site_dim = self.site_dim();
120        let right_dim = self.right_dim();
121        let rows = left_dim;
122        let cols = site_dim * right_dim;
123        let mut result = Vec::with_capacity(rows * cols);
124        for l in 0..left_dim {
125            for s in 0..site_dim {
126                for r in 0..right_dim {
127                    result.push(self[[l, s, r]]);
128                }
129            }
130        }
131        (result, rows, cols)
132    }
133}
134
135/// Create a zero-filled rank-3 tensor with shape `(left_dim, site_dim, right_dim)`.
136///
137/// # Examples
138///
139/// ```
140/// use tensor4all_simplett::{tensor3_zeros, Tensor3Ops};
141///
142/// let t = tensor3_zeros::<f64>(2, 3, 4);
143/// assert_eq!(t.left_dim(), 2);
144/// assert_eq!(t.site_dim(), 3);
145/// assert_eq!(t.right_dim(), 4);
146/// assert_eq!(*t.get3(0, 0, 0), 0.0);
147/// ```
148pub fn tensor3_zeros<T: Clone + Default + TensorScalar>(
149    left_dim: usize,
150    site_dim: usize,
151    right_dim: usize,
152) -> Tensor3<T> {
153    Tensor::from_elem([left_dim, site_dim, right_dim], T::default())
154}
155
156/// Create a rank-3 tensor from flat data in **column-major** order.
157///
158/// # Panics
159///
160/// Panics if `data.len() != left_dim * site_dim * right_dim`.
161///
162/// # Examples
163///
164/// ```
165/// use tensor4all_simplett::{tensor3_from_data, Tensor3Ops};
166///
167/// // 1 x 2 x 1 tensor, column-major data: [10.0, 20.0]
168/// let t = tensor3_from_data(vec![10.0, 20.0], 1, 2, 1);
169/// assert_eq!(*t.get3(0, 0, 0), 10.0);
170/// assert_eq!(*t.get3(0, 1, 0), 20.0);
171/// ```
172pub fn tensor3_from_data<T: TensorScalar>(
173    data: Vec<T>,
174    left_dim: usize,
175    site_dim: usize,
176    right_dim: usize,
177) -> Tensor3<T> {
178    assert_eq!(data.len(), left_dim * site_dim * right_dim);
179    let dims = [left_dim, site_dim, right_dim];
180    let inner = TfTensor::from_vec(dims.to_vec(), data);
181    Tensor::from_tenferro(inner)
182}
183
184#[cfg(test)]
185mod tests;