tensor4all_simplett/
types.rs1use crate::tensor::Tensor;
4pub use crate::tensor::Tensor3;
5use tenferro_tensor::{TensorScalar, TypedTensor as TfTensor};
6
/// Index value for a single site (local) dimension.
///
/// NOTE(review): presumably selects a position along the `site_dim` axis of a
/// [`Tensor3`]; confirm against callers.
pub type LocalIndex = usize;

/// A sequence of [`LocalIndex`] values.
///
/// NOTE(review): presumably holds one entry per site of a tensor train; confirm.
pub type MultiIndex = Vec<LocalIndex>;
12
/// Element access and reshaping for rank-3 tensors whose axes are ordered
/// `(left, site, right)`.
///
/// Implementors report the three axis sizes, give by-reference and by-value
/// access to individual elements, and can copy the data out into flat buffers.
pub trait Tensor3Ops<T: Clone + Default> {
    /// Size of axis 0 (the left axis).
    fn left_dim(&self) -> usize;

    /// Size of axis 1 (the site axis).
    fn site_dim(&self) -> usize;

    /// Size of axis 2 (the right axis).
    fn right_dim(&self) -> usize;

    /// Borrows the element stored at `(left, site, right)`.
    fn get3(&self, left: usize, site: usize, right: usize) -> &T;

    /// Mutably borrows the element stored at `(left, site, right)`.
    fn get3_mut(&mut self, left: usize, site: usize, right: usize) -> &mut T;

    /// Replaces the element stored at `(left, site, right)` with `value`.
    fn set3(&mut self, left: usize, site: usize, right: usize, value: T);

    /// Copies out the `left_dim * right_dim` elements that share site index
    /// `site`, with the right index varying fastest.
    fn slice_site(&self, site: usize) -> Vec<T>;

    /// Copies the tensor into a flat buffer viewed as a matrix whose rows fuse
    /// `(left, site)`.
    ///
    /// Returns `(data, rows, cols)` with `rows = left_dim * site_dim` and
    /// `cols = right_dim`. The `Tensor3` implementation fills `data` in
    /// row-major `(left, site, right)` order.
    fn as_left_matrix(&self) -> (Vec<T>, usize, usize);

    /// Copies the tensor into a flat buffer viewed as a matrix whose columns
    /// fuse `(site, right)`.
    ///
    /// Returns `(data, rows, cols)` with `rows = left_dim` and
    /// `cols = site_dim * right_dim`. The `Tensor3` implementation fills
    /// `data` in row-major `(left, site, right)` order.
    fn as_right_matrix(&self) -> (Vec<T>, usize, usize);
}
62
63impl<T: Clone + Default + TensorScalar> Tensor3Ops<T> for Tensor3<T> {
64 fn left_dim(&self) -> usize {
65 self.dim(0)
66 }
67
68 fn site_dim(&self) -> usize {
69 self.dim(1)
70 }
71
72 fn right_dim(&self) -> usize {
73 self.dim(2)
74 }
75
76 fn get3(&self, l: usize, s: usize, r: usize) -> &T {
77 &self[[l, s, r]]
78 }
79
80 fn get3_mut(&mut self, l: usize, s: usize, r: usize) -> &mut T {
81 &mut self[[l, s, r]]
82 }
83
84 fn set3(&mut self, l: usize, s: usize, r: usize, value: T) {
85 self[[l, s, r]] = value;
86 }
87
88 fn slice_site(&self, s: usize) -> Vec<T> {
89 let left_dim = self.left_dim();
90 let right_dim = self.right_dim();
91 let mut result = Vec::with_capacity(left_dim * right_dim);
92 for l in 0..left_dim {
93 for r in 0..right_dim {
94 result.push(self[[l, s, r]]);
95 }
96 }
97 result
98 }
99
100 fn as_left_matrix(&self) -> (Vec<T>, usize, usize) {
101 let left_dim = self.left_dim();
102 let site_dim = self.site_dim();
103 let right_dim = self.right_dim();
104 let rows = left_dim * site_dim;
105 let cols = right_dim;
106 let mut result = Vec::with_capacity(rows * cols);
107 for l in 0..left_dim {
108 for s in 0..site_dim {
109 for r in 0..right_dim {
110 result.push(self[[l, s, r]]);
111 }
112 }
113 }
114 (result, rows, cols)
115 }
116
117 fn as_right_matrix(&self) -> (Vec<T>, usize, usize) {
118 let left_dim = self.left_dim();
119 let site_dim = self.site_dim();
120 let right_dim = self.right_dim();
121 let rows = left_dim;
122 let cols = site_dim * right_dim;
123 let mut result = Vec::with_capacity(rows * cols);
124 for l in 0..left_dim {
125 for s in 0..site_dim {
126 for r in 0..right_dim {
127 result.push(self[[l, s, r]]);
128 }
129 }
130 }
131 (result, rows, cols)
132 }
133}
134
135pub fn tensor3_zeros<T: Clone + Default + TensorScalar>(
149 left_dim: usize,
150 site_dim: usize,
151 right_dim: usize,
152) -> Tensor3<T> {
153 Tensor::from_elem([left_dim, site_dim, right_dim], T::default())
154}
155
156pub fn tensor3_from_data<T: TensorScalar>(
173 data: Vec<T>,
174 left_dim: usize,
175 site_dim: usize,
176 right_dim: usize,
177) -> Tensor3<T> {
178 assert_eq!(data.len(), left_dim * site_dim * right_dim);
179 let dims = [left_dim, site_dim, right_dim];
180 let inner = TfTensor::from_vec(dims.to_vec(), data);
181 Tensor::from_tenferro(inner)
182}
183
// Unit tests live in an out-of-line `tests` module (its body is in a separate
// source file) and are compiled only for test builds.
#[cfg(test)]
mod tests;