1use std::ops::Range;
8
9use crate::error::{Result, TensorTrainError};
10use crate::tensortrain::TensorTrain;
11use crate::traits::{AbstractTensorTrain, TTScalar};
12use crate::types::{tensor3_zeros, Tensor3, Tensor3Ops};
13use tensor4all_tcicore::matrix::{mat_mul, ncols, nrows, transpose, zeros, Matrix};
14use tensor4all_tcicore::Scalar;
15use tensor4all_tcicore::{rrlu, RrLUOptions};
16
17fn qr_decomp<T: TTScalar + Scalar>(matrix: &Matrix<T>) -> (Matrix<T>, Matrix<T>) {
19 let options = RrLUOptions {
20 max_rank: ncols(matrix).min(nrows(matrix)),
21 rel_tol: 0.0, abs_tol: 0.0,
23 left_orthogonal: true,
24 };
25 let lu = rrlu(matrix, Some(options)).expect("rrlu failed in QR decomposition");
26 (lu.left(true), lu.right(true))
27}
28
29fn lq_decomp<T: TTScalar + Scalar>(matrix: &Matrix<T>) -> (Matrix<T>, Matrix<T>) {
31 let at = transpose(matrix);
32 let (qt, lt) = qr_decomp(&at);
33 (transpose(<), transpose(&qt))
34}
35
36fn tensor3_to_left_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
38 let left_dim = tensor.left_dim();
39 let site_dim = tensor.site_dim();
40 let right_dim = tensor.right_dim();
41 let rows = left_dim * site_dim;
42 let cols = right_dim;
43
44 let mut mat = zeros(rows, cols);
45 for l in 0..left_dim {
46 for s in 0..site_dim {
47 for r in 0..right_dim {
48 mat[[l * site_dim + s, r]] = *tensor.get3(l, s, r);
49 }
50 }
51 }
52 mat
53}
54
55fn tensor3_to_right_matrix<T: TTScalar + Scalar + Default>(tensor: &Tensor3<T>) -> Matrix<T> {
57 let left_dim = tensor.left_dim();
58 let site_dim = tensor.site_dim();
59 let right_dim = tensor.right_dim();
60 let rows = left_dim;
61 let cols = site_dim * right_dim;
62
63 let mut mat = zeros(rows, cols);
64 for l in 0..left_dim {
65 for s in 0..site_dim {
66 for r in 0..right_dim {
67 mat[[l, s * right_dim + r]] = *tensor.get3(l, s, r);
68 }
69 }
70 }
71 mat
72}
73
/// A tensor train stored together with a designated center site.
///
/// `new` validates the chain of tensors and then canonicalizes it around `center`;
/// the `move_center_*` / `set_center` methods shift the center one site at a time.
#[derive(Debug, Clone)]
pub struct SiteTensorTrain<T: TTScalar> {
    /// Site tensors in chain order.
    tensors: Vec<Tensor3<T>>,
    /// Index of the current center site; kept below `tensors.len()` by the
    /// constructor and the center-moving methods.
    center: usize,
    /// Site range attached to this train. `new` initialises it to `0..n`; no method
    /// in this file modifies it afterwards.
    partition: Range<usize>,
}
108
109impl<T: TTScalar + Scalar + Default> SiteTensorTrain<T> {
110 pub fn new(tensors: Vec<Tensor3<T>>, center: usize) -> Result<Self> {
112 let n = tensors.len();
113 if n == 0 {
114 return Err(TensorTrainError::Empty);
115 }
116 if center >= n {
117 return Err(TensorTrainError::InvalidOperation {
118 message: format!("Center {} is out of range for {} tensors", center, n),
119 });
120 }
121
122 for i in 0..n.saturating_sub(1) {
124 if tensors[i].right_dim() != tensors[i + 1].left_dim() {
125 return Err(TensorTrainError::DimensionMismatch { site: i });
126 }
127 }
128
129 let mut result = Self {
130 tensors,
131 center,
132 partition: 0..n,
133 };
134 result.canonicalize();
135 Ok(result)
136 }
137
138 pub fn from_tensor_train(tt: &TensorTrain<T>, center: usize) -> Result<Self> {
140 let tensors = tt.site_tensors().to_vec();
141 Self::new(tensors, center)
142 }
143
    /// Index of the current center site.
    pub fn center(&self) -> usize {
        self.center
    }
148
    /// Borrow the site range stored in this train.
    pub fn partition(&self) -> &Range<usize> {
        &self.partition
    }
153
    /// Mutable access to all site tensors.
    ///
    /// The slice is handed out as-is: nothing here re-validates bond dimensions or
    /// re-canonicalizes after the caller modifies a tensor.
    pub fn site_tensors_mut(&mut self) -> &mut [Tensor3<T>] {
        &mut self.tensors
    }
158
159 fn canonicalize(&mut self) {
161 let n = self.len();
162 if n <= 1 {
163 return;
164 }
165
166 for i in 0..self.center {
168 self.make_left_orthogonal(i);
169 }
170
171 for i in (self.center + 1..n).rev() {
173 self.make_right_orthogonal(i);
174 }
175 }
176
177 fn make_left_orthogonal(&mut self, i: usize) {
179 if i >= self.len() - 1 {
180 return;
181 }
182
183 let left_dim = self.tensors[i].left_dim();
184 let site_dim = self.tensors[i].site_dim();
185
186 let mat = tensor3_to_left_matrix(&self.tensors[i]);
188 let (q, r) = qr_decomp(&mat);
189
190 let new_bond_dim = ncols(&q);
191
192 let mut new_tensor = tensor3_zeros(left_dim, site_dim, new_bond_dim);
194 for l in 0..left_dim {
195 for s in 0..site_dim {
196 for b in 0..new_bond_dim {
197 let row = l * site_dim + s;
198 if row < nrows(&q) && b < ncols(&q) {
199 new_tensor.set3(l, s, b, q[[row, b]]);
200 }
201 }
202 }
203 }
204 self.tensors[i] = new_tensor;
205
206 let next_site_dim = self.tensors[i + 1].site_dim();
208 let next_right_dim = self.tensors[i + 1].right_dim();
209 let next_mat = tensor3_to_right_matrix(&self.tensors[i + 1]);
210
211 let contracted = mat_mul(&r, &next_mat);
213
214 let mut new_next_tensor = tensor3_zeros(new_bond_dim, next_site_dim, next_right_dim);
216 for l in 0..new_bond_dim {
217 for s in 0..next_site_dim {
218 for r_idx in 0..next_right_dim {
219 new_next_tensor.set3(l, s, r_idx, contracted[[l, s * next_right_dim + r_idx]]);
220 }
221 }
222 }
223 self.tensors[i + 1] = new_next_tensor;
224 }
225
226 fn make_right_orthogonal(&mut self, i: usize) {
228 if i == 0 {
229 return;
230 }
231
232 let site_dim = self.tensors[i].site_dim();
233 let right_dim = self.tensors[i].right_dim();
234
235 let mat = tensor3_to_right_matrix(&self.tensors[i]);
237 let (l_mat, q) = lq_decomp(&mat);
238
239 let new_bond_dim = nrows(&q);
240
241 let mut new_tensor = tensor3_zeros(new_bond_dim, site_dim, right_dim);
243 for l in 0..new_bond_dim {
244 for s in 0..site_dim {
245 for r in 0..right_dim {
246 new_tensor.set3(l, s, r, q[[l, s * right_dim + r]]);
247 }
248 }
249 }
250 self.tensors[i] = new_tensor;
251
252 let prev_left_dim = self.tensors[i - 1].left_dim();
254 let prev_site_dim = self.tensors[i - 1].site_dim();
255 let prev_mat = tensor3_to_left_matrix(&self.tensors[i - 1]);
256
257 let contracted = mat_mul(&prev_mat, &l_mat);
259
260 let mut new_prev_tensor = tensor3_zeros(prev_left_dim, prev_site_dim, new_bond_dim);
262 for l in 0..prev_left_dim {
263 for s in 0..prev_site_dim {
264 for r in 0..new_bond_dim {
265 new_prev_tensor.set3(l, s, r, contracted[[l * prev_site_dim + s, r]]);
266 }
267 }
268 }
269 self.tensors[i - 1] = new_prev_tensor;
270 }
271
272 pub fn move_center_right(&mut self) -> Result<()> {
274 if self.center >= self.len() - 1 {
275 return Err(TensorTrainError::InvalidOperation {
276 message: "Cannot move center right: already at rightmost position".to_string(),
277 });
278 }
279
280 self.make_left_orthogonal(self.center);
281 self.center += 1;
282 Ok(())
283 }
284
285 pub fn move_center_left(&mut self) -> Result<()> {
287 if self.center == 0 {
288 return Err(TensorTrainError::InvalidOperation {
289 message: "Cannot move center left: already at leftmost position".to_string(),
290 });
291 }
292
293 self.make_right_orthogonal(self.center);
294 self.center -= 1;
295 Ok(())
296 }
297
298 pub fn set_center(&mut self, new_center: usize) -> Result<()> {
300 if new_center >= self.len() {
301 return Err(TensorTrainError::InvalidOperation {
302 message: format!(
303 "New center {} is out of range for {} tensors",
304 new_center,
305 self.len()
306 ),
307 });
308 }
309
310 while self.center < new_center {
311 self.move_center_right()?;
312 }
313 while self.center > new_center {
314 self.move_center_left()?;
315 }
316 Ok(())
317 }
318
    /// Clone the site tensors into a plain `TensorTrain`.
    ///
    /// Uses `TensorTrain::from_tensors_unchecked`, so no validation is requested here.
    pub fn to_tensor_train(&self) -> TensorTrain<T> {
        TensorTrain::from_tensors_unchecked(self.tensors.clone())
    }
323
    /// Overwrite the tensor at site `i`.
    ///
    /// The tensor is stored as given: bond dimensions are not checked against the
    /// neighbouring sites and the train is not re-canonicalized.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of bounds (plain slice indexing).
    pub fn set_site_tensor(&mut self, i: usize, tensor: Tensor3<T>) {
        self.tensors[i] = tensor;
    }
330
331 pub fn set_two_site_tensors(
333 &mut self,
334 i: usize,
335 tensor1: Tensor3<T>,
336 tensor2: Tensor3<T>,
337 ) -> Result<()> {
338 if i >= self.len() - 1 {
339 return Err(TensorTrainError::InvalidOperation {
340 message: format!(
341 "Cannot set two-site tensors at site {} (max {})",
342 i,
343 self.len() - 2
344 ),
345 });
346 }
347
348 self.tensors[i] = tensor1;
349 self.tensors[i + 1] = tensor2;
350 Ok(())
351 }
352}
353
// Read-only `AbstractTensorTrain` access, backed directly by the stored tensor vector.
impl<T: TTScalar + Scalar + Default> AbstractTensorTrain<T> for SiteTensorTrain<T> {
    /// Number of site tensors in the train.
    fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Borrow the tensor at site `i`; panics if `i` is out of bounds (plain indexing).
    fn site_tensor(&self, i: usize) -> &Tensor3<T> {
        &self.tensors[i]
    }

    /// Borrow all site tensors in chain order.
    fn site_tensors(&self) -> &[Tensor3<T>] {
        &self.tensors
    }
}
367
368pub fn center_canonicalize<T: TTScalar + Scalar + Default>(
393 tensors: &mut [Tensor3<T>],
394 center: usize,
395) {
396 let n = tensors.len();
397 if n <= 1 || center >= n {
398 return;
399 }
400
401 for i in 0..center {
403 let left_dim = tensors[i].left_dim();
404 let site_dim = tensors[i].site_dim();
405
406 let mat = tensor3_to_left_matrix(&tensors[i]);
407 let (q, r) = qr_decomp(&mat);
408
409 let new_bond_dim = ncols(&q);
410
411 let mut new_tensor = tensor3_zeros(left_dim, site_dim, new_bond_dim);
413 for l in 0..left_dim {
414 for s in 0..site_dim {
415 for b in 0..new_bond_dim {
416 let row = l * site_dim + s;
417 if row < nrows(&q) && b < ncols(&q) {
418 new_tensor.set3(l, s, b, q[[row, b]]);
419 }
420 }
421 }
422 }
423 tensors[i] = new_tensor;
424
425 if i + 1 < n {
427 let next_site_dim = tensors[i + 1].site_dim();
428 let next_right_dim = tensors[i + 1].right_dim();
429 let next_mat = tensor3_to_right_matrix(&tensors[i + 1]);
430
431 let contracted = mat_mul(&r, &next_mat);
432
433 let mut new_next_tensor = tensor3_zeros(new_bond_dim, next_site_dim, next_right_dim);
434 for l in 0..new_bond_dim {
435 for s in 0..next_site_dim {
436 for r_idx in 0..next_right_dim {
437 new_next_tensor.set3(
438 l,
439 s,
440 r_idx,
441 contracted[[l, s * next_right_dim + r_idx]],
442 );
443 }
444 }
445 }
446 tensors[i + 1] = new_next_tensor;
447 }
448 }
449
450 for i in (center + 1..n).rev() {
452 let site_dim = tensors[i].site_dim();
453 let right_dim = tensors[i].right_dim();
454
455 let mat = tensor3_to_right_matrix(&tensors[i]);
456 let (l_mat, q) = lq_decomp(&mat);
457
458 let new_bond_dim = nrows(&q);
459
460 let mut new_tensor = tensor3_zeros(new_bond_dim, site_dim, right_dim);
462 for l in 0..new_bond_dim {
463 for s in 0..site_dim {
464 for r in 0..right_dim {
465 new_tensor.set3(l, s, r, q[[l, s * right_dim + r]]);
466 }
467 }
468 }
469 tensors[i] = new_tensor;
470
471 if i > 0 {
473 let prev_left_dim = tensors[i - 1].left_dim();
474 let prev_site_dim = tensors[i - 1].site_dim();
475 let prev_mat = tensor3_to_left_matrix(&tensors[i - 1]);
476
477 let contracted = mat_mul(&prev_mat, &l_mat);
478
479 let mut new_prev_tensor = tensor3_zeros(prev_left_dim, prev_site_dim, new_bond_dim);
480 for l in 0..prev_left_dim {
481 for s in 0..prev_site_dim {
482 for r in 0..new_bond_dim {
483 new_prev_tensor.set3(l, s, r, contracted[[l * prev_site_dim + s, r]]);
484 }
485 }
486 }
487 tensors[i - 1] = new_prev_tensor;
488 }
489 }
490}
491
492#[cfg(test)]
493mod tests;