1use std::collections::{HashMap, HashSet};
7
8use bnum::types::{U1024, U256, U512};
9
10use crate::einsum_helper::EinsumScalar;
11use crate::einsum_helper::{matrix_times_col_vector, row_vector_times_matrix};
12use crate::error::{Result, TensorTrainError};
13use crate::traits::{AbstractTensorTrain, TTScalar};
14use crate::types::{LocalIndex, MultiIndex, Tensor3, Tensor3Ops};
15
/// Upper bound, in bits, on the flattened index space spanned by `local_dims`.
///
/// Each dimension `d > 1` contributes `ilog2(d) + 1` bits (enough to hold any
/// value `0..d`); dimensions of 0 or 1 contribute nothing. The sum is used by
/// `FlatIndexer::new` to pick the narrowest key width that cannot overflow.
fn compute_total_bits(local_dims: &[usize]) -> u32 {
    let mut bits: u32 = 0;
    for &dim in local_dims {
        if dim > 1 {
            bits += (dim as u64).ilog2() + 1;
        }
    }
    bits
}
23
/// A flattened multi-index stored at the narrowest unsigned width that can
/// hold it (width chosen by `FlatIndexer::new` from the total bit count).
/// Used as a hash key to deduplicate multi-indices cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum IndexKey {
    U64(u64),
    U128(u128),
    U256(U256),
    U512(U512),
    U1024(U1024),
}
33
/// Flattens multi-indices into `IndexKey`s via precomputed mixed-radix stride
/// coefficients; one variant per supported key width. `coeffs[k]` is the
/// product of dims `0..k`, so a flat index is `sum(idx[k] * coeffs[k])`.
enum FlatIndexer {
    U64 { coeffs: Vec<u64> },
    U128 { coeffs: Vec<u128> },
    U256 { coeffs: Vec<U256> },
    U512 { coeffs: Vec<U512> },
    U1024 { coeffs: Vec<U1024> },
}
42
/// Builds mixed-radix stride coefficients (prefix products of `$local_dims`)
/// for a primitive unsigned integer type `$T`.
///
/// `coeffs[k]` holds the product of dims `0..k`. The running product uses
/// `saturating_mul` defensively; `FlatIndexer::new` picks a width wide enough
/// that every *pushed* coefficient fits without saturating.
macro_rules! compute_coeffs_primitive {
    ($local_dims:expr, $T:ty) => {{
        let mut coeffs = Vec::with_capacity($local_dims.len());
        let mut prod: $T = 1;
        for &d in $local_dims {
            coeffs.push(prod);
            prod = prod.saturating_mul(d as $T);
        }
        coeffs
    }};
}
55
/// Same as `compute_coeffs_primitive!` but for a `bnum` wide integer type
/// `$T`, which has no integer literals — uses `<$T>::ONE` and converts each
/// dimension through `u64` before multiplying.
macro_rules! compute_coeffs_bnum {
    ($local_dims:expr, $T:ty) => {{
        let mut coeffs = Vec::with_capacity($local_dims.len());
        let mut prod = <$T>::ONE;
        for &d in $local_dims {
            coeffs.push(prod);
            prod = prod.saturating_mul(<$T>::from(d as u64));
        }
        coeffs
    }};
}
68
/// Flattens `$idx` into the primitive-typed `IndexKey::$Key` variant: the dot
/// product of the index digits with the stride coefficients `$coeffs`.
macro_rules! flat_index_primitive {
    ($idx:expr, $coeffs:expr, $T:ty, $Key:ident) => {{
        let key: $T = $idx.iter().zip($coeffs).map(|(&i, &c)| c * i as $T).sum();
        IndexKey::$Key(key)
    }};
}
76
/// Flattens `$idx` into a bnum-typed `IndexKey::$Key` variant: dot product of
/// the index digits with `$coeffs`, accumulated with an explicit `fold` so no
/// `Sum` implementation is required of the wide integer type.
macro_rules! flat_index_bnum {
    ($idx:expr, $coeffs:expr, $T:ty, $Key:ident) => {{
        let key = $idx
            .iter()
            .zip($coeffs)
            .map(|(&i, &c)| c * <$T>::from(i as u64))
            .fold(<$T>::ZERO, |a, b| a + b);
        IndexKey::$Key(key)
    }};
}
88
89impl FlatIndexer {
90 fn new(local_dims: &[usize]) -> Self {
92 let total_bits = compute_total_bits(local_dims);
93
94 if total_bits <= 64 {
95 Self::U64 {
96 coeffs: compute_coeffs_primitive!(local_dims, u64),
97 }
98 } else if total_bits <= 128 {
99 Self::U128 {
100 coeffs: compute_coeffs_primitive!(local_dims, u128),
101 }
102 } else if total_bits <= 256 {
103 Self::U256 {
104 coeffs: compute_coeffs_bnum!(local_dims, U256),
105 }
106 } else if total_bits <= 512 {
107 Self::U512 {
108 coeffs: compute_coeffs_bnum!(local_dims, U512),
109 }
110 } else {
111 Self::U1024 {
112 coeffs: compute_coeffs_bnum!(local_dims, U1024),
113 }
114 }
115 }
116
117 fn flat_index(&self, idx: &[usize]) -> IndexKey {
119 match self {
120 Self::U64 { coeffs } => flat_index_primitive!(idx, coeffs, u64, U64),
121 Self::U128 { coeffs } => flat_index_primitive!(idx, coeffs, u128, U128),
122 Self::U256 { coeffs } => flat_index_bnum!(idx, coeffs, U256, U256),
123 Self::U512 { coeffs } => flat_index_bnum!(idx, coeffs, U512, U512),
124 Self::U1024 { coeffs } => flat_index_bnum!(idx, coeffs, U1024, U1024),
125 }
126 }
127}
128
/// Interns the left and right parts of a batch of multi-indices split at a
/// fixed position, assigning each distinct part a dense id in first-seen
/// order so shared parts are evaluated only once.
struct IndexMapper {
    /// Flattens left parts into hash keys.
    left_indexer: FlatIndexer,
    /// Flattens right parts into hash keys.
    right_indexer: FlatIndexer,
    /// Unique left-part key -> dense id (0-based, insertion order).
    left_key_to_id: HashMap<IndexKey, usize>,
    /// Unique right-part key -> dense id (0-based, insertion order).
    right_key_to_id: HashMap<IndexKey, usize>,
    /// Per input index: id of its left part.
    idx_to_left: Vec<usize>,
    /// Per input index: id of its right part.
    idx_to_right: Vec<usize>,
    /// Per left id: first input index that produced it (representative).
    left_first_idx: Vec<usize>,
    /// Per right id: first input index that produced it (representative).
    right_first_idx: Vec<usize>,
}
140
141impl IndexMapper {
142 fn new(left_dims: &[usize], right_dims: &[usize], capacity: usize) -> Self {
143 Self {
144 left_indexer: FlatIndexer::new(left_dims),
145 right_indexer: FlatIndexer::new(right_dims),
146 left_key_to_id: HashMap::new(),
147 right_key_to_id: HashMap::new(),
148 idx_to_left: Vec::with_capacity(capacity),
149 idx_to_right: Vec::with_capacity(capacity),
150 left_first_idx: Vec::new(),
151 right_first_idx: Vec::new(),
152 }
153 }
154
155 fn add_index(&mut self, i: usize, left_part: &[usize], right_part: &[usize]) {
156 let left_key = self.left_indexer.flat_index(left_part);
157 let right_key = self.right_indexer.flat_index(right_part);
158
159 let left_id = match self.left_key_to_id.get(&left_key) {
160 Some(&id) => id,
161 None => {
162 let id = self.left_key_to_id.len();
163 self.left_key_to_id.insert(left_key, id);
164 self.left_first_idx.push(i);
165 id
166 }
167 };
168
169 let right_id = match self.right_key_to_id.get(&right_key) {
170 Some(&id) => id,
171 None => {
172 let id = self.right_key_to_id.len();
173 self.right_key_to_id.insert(right_key, id);
174 self.right_first_idx.push(i);
175 id
176 }
177 };
178
179 self.idx_to_left.push(left_id);
180 self.idx_to_right.push(right_id);
181 }
182}
183
/// Counts distinct multi-indices over a fixed dimension profile by flattening
/// each one into an `IndexKey` and collecting the keys in a set.
struct UniqueCounter {
    /// Flattens multi-indices into hashable keys.
    indexer: FlatIndexer,
    /// Set of keys seen so far; its size is the distinct count.
    keys: HashSet<IndexKey>,
}
189
190impl UniqueCounter {
191 fn new(local_dims: &[usize], capacity: usize) -> Self {
192 Self {
193 indexer: FlatIndexer::new(local_dims),
194 keys: HashSet::with_capacity(capacity),
195 }
196 }
197
198 fn insert(&mut self, idx: &[usize]) {
199 let key = self.indexer.flat_index(idx);
200 self.keys.insert(key);
201 }
202
203 fn len(&self) -> usize {
204 self.keys.len()
205 }
206}
207
/// Memoizing evaluator for a tensor train: caches left and right partial
/// contractions ("environments") keyed by the multi-index prefix or suffix
/// that produced them.
#[derive(Debug, Clone)]
pub struct TTCache<T: TTScalar> {
    // Site tensors copied from the source tensor train.
    tensors: Vec<Tensor3<T>>,
    // cache_left[i]: prefix indices[0..=i] -> left environment vector.
    cache_left: Vec<HashMap<MultiIndex, Vec<T>>>,
    // cache_right[i]: suffix starting at site i -> right environment vector.
    cache_right: Vec<HashMap<MultiIndex, Vec<T>>>,
    // Per-site factorization of the physical index; the product of each inner
    // vec equals the corresponding tensor's site dimension.
    site_dims: Vec<Vec<usize>>,
}
243
impl<T: TTScalar + EinsumScalar> TTCache<T> {
    /// Builds a cache over `tt`'s tensors with one unfactored physical index
    /// per site: `site_dims[i] == [tensors[i].site_dim()]`.
    pub fn new<TT: AbstractTensorTrain<T>>(tt: &TT) -> Self {
        let n = tt.len();
        let tensors: Vec<Tensor3<T>> = tt.site_tensors().to_vec();
        let site_dims: Vec<Vec<usize>> = tensors.iter().map(|t| vec![t.site_dim()]).collect();

        Self {
            tensors,
            cache_left: (0..n).map(|_| HashMap::new()).collect(),
            cache_right: (0..n).map(|_| HashMap::new()).collect(),
            site_dims,
        }
    }

    /// Like [`Self::new`], but with an explicit per-site factorization of each
    /// physical index.
    ///
    /// # Errors
    /// Returns `InvalidOperation` when `site_dims.len()` differs from the
    /// train length, or when the product of a site's dims does not equal that
    /// tensor's site dimension.
    pub fn with_site_dims<TT: AbstractTensorTrain<T>>(
        tt: &TT,
        site_dims: Vec<Vec<usize>>,
    ) -> Result<Self> {
        let n = tt.len();
        if site_dims.len() != n {
            return Err(TensorTrainError::InvalidOperation {
                message: format!(
                    "site_dims length {} doesn't match tensor train length {}",
                    site_dims.len(),
                    n
                ),
            });
        }

        // Validate each factorization against the tensor it describes.
        for (i, (tensor, dims)) in tt.site_tensors().iter().zip(site_dims.iter()).enumerate() {
            let expected: usize = dims.iter().product();
            if expected != tensor.site_dim() {
                return Err(TensorTrainError::InvalidOperation {
                    message: format!(
                        "site_dims product {} doesn't match tensor site dim {} at site {}",
                        expected,
                        tensor.site_dim(),
                        i
                    ),
                });
            }
        }

        let tensors: Vec<Tensor3<T>> = tt.site_tensors().to_vec();

        Ok(Self {
            tensors,
            cache_left: (0..n).map(|_| HashMap::new()).collect(),
            cache_right: (0..n).map(|_| HashMap::new()).collect(),
            site_dims,
        })
    }

    /// Number of sites in the tensor train.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Whether the tensor train has no sites.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Per-site factorizations of the physical indices.
    pub fn site_dims(&self) -> &[Vec<usize>] {
        &self.site_dims
    }

    /// Bond (link) dimensions between consecutive sites; empty when there are
    /// fewer than two sites.
    pub fn link_dims(&self) -> Vec<usize> {
        if self.len() <= 1 {
            return Vec::new();
        }
        (1..self.len())
            .map(|i| self.tensors[i].left_dim())
            .collect()
    }

    /// Bond dimension of link `i` (between sites `i` and `i + 1`).
    /// Panics when `i + 1` is out of range.
    pub fn link_dim(&self, i: usize) -> usize {
        self.tensors[i + 1].left_dim()
    }

    /// Clears every memoized left/right environment, keeping the tensors.
    pub fn clear_cache(&mut self) {
        for cache in &mut self.cache_left {
            cache.clear();
        }
        for cache in &mut self.cache_right {
            cache.clear();
        }
    }

    /// Flattens a factored multi-index at `site` into a single physical index,
    /// row-major (the last entry varies fastest). When `indices` has a single
    /// entry the result is that entry unchanged (stride starts at 1).
    fn multi_to_flat(&self, site: usize, indices: &[LocalIndex]) -> LocalIndex {
        let dims = &self.site_dims[site];
        let mut flat = 0;
        let mut stride = 1;
        for (i, &idx) in indices.iter().rev().enumerate() {
            flat += idx * stride;
            stride *= dims[dims.len() - 1 - i];
        }
        flat
    }

    /// Left environment obtained by contracting sites `0..indices.len()` at
    /// the given (one-per-site) physical indices; `[1]` for an empty prefix.
    ///
    /// Memoized in `cache_left[indices.len() - 1]` keyed by the full prefix;
    /// the recursion reuses the cached environment of the prefix one shorter.
    pub fn evaluate_left(&mut self, indices: &[LocalIndex]) -> Vec<T> {
        let ell = indices.len();
        if ell == 0 {
            return vec![T::one()];
        }

        let key: MultiIndex = indices.to_vec();
        if let Some(cached) = self.cache_left[ell - 1].get(&key) {
            return cached.clone();
        }

        let result = if ell == 1 {
            // Base case: the selected slice of the first tensor.
            let flat_idx = self.multi_to_flat(0, &indices[0..1]);
            let tensor = &self.tensors[0];
            tensor.slice_site(flat_idx)
        } else {
            // Extend the (cached) shorter prefix by one site from the right.
            let left = self.evaluate_left(&indices[0..ell - 1]);
            let flat_idx = self.multi_to_flat(ell - 1, &indices[ell - 1..ell]);
            let tensor = &self.tensors[ell - 1];
            let slice = tensor.slice_site(flat_idx);
            row_vector_times_matrix(&left, &slice, tensor.left_dim(), tensor.right_dim())
        };

        self.cache_left[ell - 1].insert(key, result.clone());
        result
    }

    /// Right environment obtained by contracting the last `indices.len()`
    /// sites at the given physical indices; `[1]` for an empty suffix.
    ///
    /// Memoized in `cache_right[n - indices.len()]` keyed by the suffix; the
    /// recursion reuses the cached environment of the suffix one shorter.
    pub fn evaluate_right(&mut self, indices: &[LocalIndex]) -> Vec<T> {
        let n = self.len();
        let ell = indices.len();
        if ell == 0 {
            return vec![T::one()];
        }

        // Site at which this suffix begins.
        let start = n - ell;

        let key: MultiIndex = indices.to_vec();
        if let Some(cached) = self.cache_right[start].get(&key) {
            return cached.clone();
        }

        let result = if ell == 1 {
            // Base case: the selected slice of the last tensor.
            let flat_idx = self.multi_to_flat(n - 1, &indices[0..1]);
            let tensor = &self.tensors[n - 1];
            tensor.slice_site(flat_idx)
        } else {
            // Extend the (cached) shorter suffix by one site from the left.
            let right = self.evaluate_right(&indices[1..]);
            let flat_idx = self.multi_to_flat(start, &indices[0..1]);
            let tensor = &self.tensors[start];
            let slice = tensor.slice_site(flat_idx);
            matrix_times_col_vector(&slice, tensor.left_dim(), tensor.right_dim(), &right)
        };

        self.cache_right[start].insert(key, result.clone());
        result
    }

    /// Evaluates the tensor train at a full multi-index (one entry per site)
    /// by contracting a left and a right environment that meet at site `n/2`.
    ///
    /// # Errors
    /// `IndexLengthMismatch` when `indices.len() != len()`, `Empty` for a
    /// zero-site train, and `InvalidOperation` if the two environments have
    /// mismatched lengths.
    pub fn evaluate(&mut self, indices: &[LocalIndex]) -> Result<T> {
        let n = self.len();
        if indices.len() != n {
            return Err(TensorTrainError::IndexLengthMismatch {
                expected: n,
                got: indices.len(),
            });
        }

        if n == 0 {
            return Err(TensorTrainError::Empty);
        }

        let mid = n / 2;
        let left = self.evaluate_left(&indices[0..mid]);
        let right = self.evaluate_right(&indices[mid..]);

        if left.len() != right.len() {
            return Err(TensorTrainError::InvalidOperation {
                message: format!(
                    "Left/right dimension mismatch: {} vs {}",
                    left.len(),
                    right.len()
                ),
            });
        }

        // Inner product of the two environments across the middle bond.
        let mut result = T::zero();
        for i in 0..left.len() {
            result = result + left[i] * right[i];
        }

        Ok(result)
    }

    /// Evaluates the train at many multi-indices, splitting each at `split`
    /// (chosen heuristically when `None`) and deduplicating shared prefixes
    /// and suffixes so each unique environment is computed once.
    ///
    /// # Errors
    /// `Empty` for a zero-site train, `InvalidOperation` for a split of 0 or
    /// greater than the number of sites.
    pub fn evaluate_many(
        &mut self,
        indices: &[MultiIndex],
        split: Option<usize>,
    ) -> Result<Vec<T>> {
        if indices.is_empty() {
            return Ok(Vec::new());
        }

        let n = self.len();
        if n == 0 {
            return Err(TensorTrainError::Empty);
        }

        let split = match split {
            Some(s) => s,
            None => self.find_split_heuristic(indices),
        };

        if split == 0 || split > n {
            return Err(TensorTrainError::InvalidOperation {
                message: format!("Invalid split position: {} (n_sites={})", split, n),
            });
        }

        // Flat per-site dimension, collapsing any factored site index.
        let local_dims: Vec<usize> = self.site_dims.iter().map(|d| d.iter().product()).collect();

        // Intern the left/right parts of every requested multi-index.
        let mut mapper =
            IndexMapper::new(&local_dims[..split], &local_dims[split..], indices.len());

        for (i, idx) in indices.iter().enumerate() {
            mapper.add_index(i, &idx[..split], &idx[split..]);
        }

        // One representative multi-index per unique left/right part.
        let unique_left: Vec<MultiIndex> = mapper
            .left_first_idx
            .iter()
            .map(|&i| indices[i][..split].to_vec())
            .collect();

        let unique_right: Vec<MultiIndex> = mapper
            .right_first_idx
            .iter()
            .map(|&i| indices[i][split..].to_vec())
            .collect();

        // Environments are computed once per unique part (and memoized).
        let left_envs: Vec<Vec<T>> = unique_left.iter().map(|l| self.evaluate_left(l)).collect();

        let right_envs: Vec<Vec<T>> = unique_right
            .iter()
            .map(|r| self.evaluate_right(r))
            .collect();

        // Each requested index is the inner product of its two environments.
        let results: Vec<T> = mapper
            .idx_to_left
            .iter()
            .zip(&mapper.idx_to_right)
            .map(|(&il, &ir)| {
                let left_env = &left_envs[il];
                let right_env = &right_envs[ir];
                left_env
                    .iter()
                    .zip(right_env.iter())
                    .fold(T::zero(), |acc, (&l, &r)| acc + l * r)
            })
            .collect();

        Ok(results)
    }

    /// Picks a split position among `{n/4, n/2, 3n/4}` minimizing the total
    /// number of unique left parts plus unique right parts of `indices` —
    /// a proxy for the number of environments `evaluate_many` must compute.
    fn find_split_heuristic(&self, indices: &[MultiIndex]) -> usize {
        let n = self.len();
        if n <= 1 {
            return n.max(1);
        }

        let local_dims: Vec<usize> = self.site_dims.iter().map(|d| d.iter().product()).collect();

        // Cost of a candidate split: distinct prefixes + distinct suffixes.
        let compute_cost = |split: usize| -> usize {
            if split == 0 || split >= n {
                return usize::MAX;
            }

            let mut left_counter = UniqueCounter::new(&local_dims[..split], indices.len());
            let mut right_counter = UniqueCounter::new(&local_dims[split..], indices.len());

            for idx in indices {
                left_counter.insert(&idx[..split]);
                right_counter.insert(&idx[split..]);
            }

            left_counter.len() + right_counter.len()
        };

        let candidates = [n / 4, n / 2, 3 * n / 4];
        let costs: Vec<(usize, usize)> = candidates
            .iter()
            .filter(|&&p| p >= 1 && p < n)
            .map(|&p| (p, compute_cost(p)))
            .collect();

        costs
            .into_iter()
            .min_by_key(|&(_, c)| c)
            .map(|(p, _)| p)
            .unwrap_or(n / 2)
    }
}
599
// Unit tests live in the sibling `tests` module file.
#[cfg(test)]
mod tests;