// tensor4all_itensorlike/tensortrain.rs
1//! Main Tensor Train type as a wrapper around TreeTN.
2//!
3//! This module provides the `TensorTrain` type, which represents a Tensor Train
4//! (also known as MPS) with orthogonality tracking, inspired by ITensorMPS.jl.
5//!
6//! Internally, TensorTrain is implemented as a thin wrapper around
7//! `TreeTN<TensorDynLen, usize>` where node names are site indices (0, 1, 2, ...).
8
9use num_complex::Complex64;
10use std::ops::Range;
11use tensor4all_core::{common_inds, hascommoninds, DynIndex, IndexLike};
12use tensor4all_core::{
13 AllowedPairs, AnyScalar, CommonScalar, DirectSumResult, FactorizeError, FactorizeOptions,
14 FactorizeResult, TensorDynLen, TensorElement, TensorIndex, TensorLike,
15};
16use tensor4all_treetn::{CanonicalizationOptions, TreeTN, TruncationOptions};
17
18use crate::error::{Result, TensorTrainError};
19use crate::options::{validate_svd_truncation_options, CanonicalForm, TruncateOptions};
20
21/// Tensor Train with orthogonality tracking.
22///
23/// This type represents a tensor train as a sequence of tensors with tracked
24/// orthogonality limits. It is inspired by ITensorMPS.jl but uses
25/// 0-indexed sites (Rust convention).
26///
27/// Unlike traditional MPS which assumes one physical index per site, this
28/// implementation allows each site to have multiple site indices.
29///
30/// # Orthogonality Tracking
31///
32/// The tensor train tracks orthogonality using `ortho_region` from the underlying TreeTN:
33/// - When `ortho_region` is empty, no orthogonality is assumed
34/// - When `ortho_region` contains a single site, that site is the orthogonality center
35///
36/// # Implementation
37///
38/// Internally wraps `TreeTN<TensorDynLen, usize>` where node names are site indices.
39/// This allows reuse of TreeTN's canonicalization and contraction algorithms.
40///
41/// # Examples
42///
43/// Build a 2-site tensor train and query its properties:
44///
45/// ```
46/// use tensor4all_itensorlike::TensorTrain;
47/// use tensor4all_core::{DynIndex, TensorDynLen, Index};
48/// use tensor4all_core::DynId;
49///
50/// // Site indices and link index
51/// let s0 = Index::new_with_size(DynId(0), 2);
52/// let link = Index::new_with_size(DynId(1), 3);
53/// let s1 = Index::new_with_size(DynId(2), 2);
54///
55/// let t0 = TensorDynLen::from_dense(
56/// vec![s0.clone(), link.clone()],
57/// (0..6).map(|i| i as f64).collect(),
58/// ).unwrap();
59/// let t1 = TensorDynLen::from_dense(
60/// vec![link.clone(), s1.clone()],
61/// (0..6).map(|i| i as f64).collect(),
62/// ).unwrap();
63///
64/// let tt = TensorTrain::new(vec![t0, t1]).unwrap();
65/// assert_eq!(tt.len(), 2);
66/// assert_eq!(tt.maxbonddim(), 3);
67/// assert!(!tt.is_empty());
68/// ```
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// The underlying TreeTN with linear chain topology.
    /// Node names are `usize` site positions (0, 1, 2, ...), so a node's
    /// name doubles as its site index throughout this module.
    pub(crate) treetn: TreeTN<TensorDynLen, usize>,
    /// The canonical form used by the most recent canonicalization (if known).
    /// `None` means no particular form is assumed.
    canonical_form: Option<CanonicalForm>,
}
77
/// A site tensor flattened to a (left, physical, right) 3-index view.
/// Used by the fast norm-squared path to avoid generic tensor contractions.
#[derive(Debug)]
struct PackedSiteTensor<T> {
    // Dimension of the link to the left neighbor (1 at the chain boundary).
    left_dim: usize,
    // Combined dimension of all non-link ("physical") indices at this site.
    physical_dim: usize,
    // Dimension of the link to the right neighbor (1 at the chain boundary).
    right_dim: usize,
    // Flattened elements; length == left_dim * physical_dim * right_dim.
    data: Vec<T>,
}
85
86impl<T: Copy> PackedSiteTensor<T> {
87 fn get(&self, left: usize, physical: usize, right: usize) -> T {
88 debug_assert!(left < self.left_dim);
89 debug_assert!(physical < self.physical_dim);
90 debug_assert!(right < self.right_dim);
91
92 let idx = left + self.left_dim * (physical + self.physical_dim * right);
93 self.data[idx]
94 }
95}
96
/// Scalar types whose accumulated self-inner-product can be read out as a
/// non-negative real number (tiny negative rounding errors are clamped to 0).
trait NormAccumScalar: CommonScalar {
    /// Convert the final accumulator value into a non-negative `f64`.
    fn into_nonnegative_real(self) -> f64;
}
100
impl NormAccumScalar for f64 {
    fn into_nonnegative_real(self) -> f64 {
        // Clamp small negative values caused by floating-point rounding.
        self.max(0.0)
    }
}
106
impl NormAccumScalar for Complex64 {
    fn into_nonnegative_real(self) -> f64 {
        // For <psi|psi> the imaginary part is pure rounding noise; keep the
        // real part and clamp it to be non-negative.
        self.re.max(0.0)
    }
}
112
113impl TensorTrain {
114 /// Create a new tensor train from a vector of tensors.
115 ///
116 /// The tensor train is created with no assumed orthogonality.
117 ///
118 /// # Arguments
119 ///
120 /// * `tensors` - Vector of tensors representing the tensor train
121 ///
122 /// # Returns
123 ///
124 /// A new tensor train with no orthogonality.
125 ///
126 /// # Errors
127 ///
128 /// Returns an error if the tensors have inconsistent bond dimensions
129 /// (i.e., the link indices between adjacent tensors don't match).
130 pub fn new(tensors: Vec<TensorDynLen>) -> Result<Self> {
131 if tensors.is_empty() {
132 // Create an empty TreeTN
133 let treetn = TreeTN::<TensorDynLen, usize>::new();
134 return Ok(Self {
135 treetn,
136 canonical_form: None,
137 });
138 }
139
140 // Validate that adjacent tensors share exactly one common index (the link)
141 for i in 0..tensors.len() - 1 {
142 let left = &tensors[i];
143 let right = &tensors[i + 1];
144
145 let common = common_inds(left.indices(), right.indices());
146 if common.is_empty() {
147 return Err(TensorTrainError::InvalidStructure {
148 message: format!(
149 "No common index between tensors at sites {} and {}",
150 i,
151 i + 1
152 ),
153 });
154 }
155 if common.len() > 1 {
156 return Err(TensorTrainError::InvalidStructure {
157 message: format!(
158 "Multiple common indices ({}) between tensors at sites {} and {}",
159 common.len(),
160 i,
161 i + 1
162 ),
163 });
164 }
165 }
166
167 // Create node names: 0, 1, 2, ..., n-1
168 let node_names: Vec<usize> = (0..tensors.len()).collect();
169
170 // Create TreeTN with from_tensors (auto-connects by shared index IDs)
171 let treetn =
172 TreeTN::<TensorDynLen, usize>::from_tensors(tensors, node_names).map_err(|e| {
173 TensorTrainError::InvalidStructure {
174 message: format!("Failed to create TreeTN: {}", e),
175 }
176 })?;
177
178 let mut tt = Self {
179 treetn,
180 canonical_form: None,
181 };
182 tt.normalize_site_tensor_orders()?;
183 Ok(tt)
184 }
185
186 /// Create a new tensor train with specified orthogonality center.
187 ///
188 /// This is useful when constructing a tensor train that is already in canonical form.
189 ///
190 /// # Arguments
191 ///
192 /// * `tensors` - Vector of tensors representing the tensor train
193 /// * `llim` - Left orthogonality limit (for compatibility; only used to compute center)
194 /// * `rlim` - Right orthogonality limit (for compatibility; only used to compute center)
195 /// * `canonical_form` - The method used for canonicalization (if any)
196 pub fn with_ortho(
197 tensors: Vec<TensorDynLen>,
198 llim: i32,
199 rlim: i32,
200 canonical_form: Option<CanonicalForm>,
201 ) -> Result<Self> {
202 let mut tt = Self::new(tensors)?;
203
204 // Convert llim/rlim to ortho center
205 // When llim + 2 == rlim, ortho center is at llim + 1
206 if llim + 2 == rlim && llim >= -1 && (llim + 1) < tt.len() as i32 {
207 let center = (llim + 1) as usize;
208 tt.treetn.set_canonical_region(vec![center]).map_err(|e| {
209 TensorTrainError::InvalidStructure {
210 message: format!("Failed to set ortho region: {}", e),
211 }
212 })?;
213 }
214
215 tt.canonical_form = canonical_form;
216 Ok(tt)
217 }
218
    /// Create a TensorTrain from an existing TreeTN and canonical form.
    ///
    /// Node names are renumbered to consecutive site indices `0..n`: the
    /// existing names are sorted and the i-th smallest becomes site `i`.
    /// Renaming in ascending order cannot collide, because each target
    /// index is never larger than the (unique, sorted) old name it replaces.
    ///
    /// This is a crate-internal constructor used by `contract` and `linsolve`.
    pub(crate) fn from_inner(
        treetn: TreeTN<TensorDynLen, usize>,
        canonical_form: Option<CanonicalForm>,
    ) -> Result<Self> {
        let mut node_names = treetn.node_names();
        node_names.sort_unstable();
        let mut tt = Self {
            treetn,
            canonical_form,
        };
        for (site, old_name) in node_names.into_iter().enumerate() {
            if old_name != site {
                tt.treetn.rename_node(&old_name, site).map_err(|e| {
                    TensorTrainError::InvalidStructure {
                        message: format!("Failed to renumber TensorTrain sites: {}", e),
                    }
                })?;
            }
        }
        // Index-order normalization assumes at most one link per bond; skip it
        // when adjacent tensors share multiple indices.
        if tt.has_simple_linear_links() {
            tt.normalize_site_tensor_orders()?;
        }
        Ok(tt)
    }
246
    /// Get a reference to the underlying TreeTN.
    ///
    /// This is a crate-internal accessor used by `contract` and `linsolve`.
    pub(crate) fn as_treetn(&self) -> &TreeTN<TensorDynLen, usize> {
        &self.treetn
    }
253
    /// Number of sites (tensors) in the tensor train.
    ///
    /// Equals the node count of the underlying TreeTN.
    #[inline]
    pub fn len(&self) -> usize {
        self.treetn.node_count()
    }
259
260 /// Check if the tensor train is empty.
261 #[inline]
262 pub fn is_empty(&self) -> bool {
263 self.treetn.node_count() == 0
264 }
265
266 /// Left orthogonality limit.
267 ///
268 /// Sites `0..llim` are guaranteed to be left-orthogonal.
269 /// Returns -1 if no sites are left-orthogonal.
270 #[inline]
271 pub fn llim(&self) -> i32 {
272 match self.orthocenter() {
273 Some(center) => center as i32 - 1,
274 None => -1,
275 }
276 }
277
278 /// Right orthogonality limit.
279 ///
280 /// Sites `rlim..len()` are guaranteed to be right-orthogonal.
281 /// Returns `len() + 1` if no sites are right-orthogonal.
282 #[inline]
283 pub fn rlim(&self) -> i32 {
284 match self.orthocenter() {
285 Some(center) => center as i32 + 1,
286 None => self.len() as i32 + 1,
287 }
288 }
289
290 /// Set the left orthogonality limit.
291 #[inline]
292 pub fn set_llim(&mut self, llim: i32) {
293 // Convert to ortho center if possible
294 let rlim = self.rlim();
295 if llim + 2 == rlim && llim >= -1 && (llim + 1) < self.len() as i32 {
296 let center = (llim + 1) as usize;
297 let _ = self.treetn.set_canonical_region(vec![center]);
298 } else {
299 // Clear ortho region if not a single center
300 let _ = self.treetn.set_canonical_region(Vec::<usize>::new());
301 }
302 }
303
304 /// Set the right orthogonality limit.
305 #[inline]
306 pub fn set_rlim(&mut self, rlim: i32) {
307 // Convert to ortho center if possible
308 let llim = self.llim();
309 if llim + 2 == rlim && llim >= -1 && (llim + 1) < self.len() as i32 {
310 let center = (llim + 1) as usize;
311 let _ = self.treetn.set_canonical_region(vec![center]);
312 } else {
313 // Clear ortho region if not a single center
314 let _ = self.treetn.set_canonical_region(Vec::<usize>::new());
315 }
316 }
317
318 /// Get the orthogonality center range.
319 ///
320 /// Returns the range of sites that may not be orthogonal.
321 /// If the tensor train is fully left-orthogonal, returns an empty range at the end.
322 /// If the tensor train is fully right-orthogonal, returns an empty range at the beginning.
323 pub fn ortho_lims(&self) -> Range<usize> {
324 let llim = self.llim();
325 let rlim = self.rlim();
326 let start = (llim + 1).max(0) as usize;
327 let end = rlim.max(0) as usize;
328 start..end.min(self.len())
329 }
330
    /// Check if the tensor train has a single orthogonality center.
    ///
    /// Returns true if there is exactly one site that is not guaranteed to be
    /// orthogonal (i.e. the tracked canonical region contains exactly one node).
    #[inline]
    pub fn isortho(&self) -> bool {
        self.treetn.canonical_region().len() == 1
    }
338
339 /// Get the orthogonality center (0-indexed).
340 ///
341 /// Returns `Some(site)` if the tensor train has a single orthogonality center,
342 /// `None` otherwise.
343 pub fn orthocenter(&self) -> Option<usize> {
344 let region = self.treetn.canonical_region();
345 if region.len() == 1 {
346 // Node name IS the site index since V = usize
347 Some(*region.iter().next().unwrap())
348 } else {
349 None
350 }
351 }
352
    /// Get the canonicalization method used, or `None` if unknown.
    #[inline]
    pub fn canonical_form(&self) -> Option<CanonicalForm> {
        self.canonical_form
    }
358
    /// Set the canonicalization method (does not re-canonicalize the tensors).
    #[inline]
    pub fn set_canonical_form(&mut self, method: Option<CanonicalForm>) {
        self.canonical_form = method;
    }
364
    /// Get a reference to the tensor at the given site.
    ///
    /// # Panics
    ///
    /// Panics if `site >= len()`.
    #[inline]
    pub fn tensor(&self, site: usize) -> &TensorDynLen {
        // Site numbers are node names in the underlying TreeTN.
        let node_idx = self.treetn.node_index(&site).expect("Site out of bounds");
        self.treetn.tensor(node_idx).expect("Tensor not found")
    }
375
376 /// Get a reference to the tensor at the given site.
377 ///
378 /// Returns `Err` if `site >= len()`.
379 pub fn tensor_checked(&self, site: usize) -> Result<&TensorDynLen> {
380 if site >= self.len() {
381 return Err(TensorTrainError::SiteOutOfBounds {
382 site,
383 length: self.len(),
384 });
385 }
386 let node_idx =
387 self.treetn
388 .node_index(&site)
389 .ok_or_else(|| TensorTrainError::SiteOutOfBounds {
390 site,
391 length: self.len(),
392 })?;
393 self.treetn
394 .tensor(node_idx)
395 .ok_or_else(|| TensorTrainError::SiteOutOfBounds {
396 site,
397 length: self.len(),
398 })
399 }
400
    /// Get a mutable reference to the tensor at the given site.
    ///
    /// Note: this does NOT invalidate orthogonality tracking; use
    /// `set_tensor` if the replacement should clear the canonical region.
    ///
    /// # Panics
    ///
    /// Panics if `site >= len()`.
    #[inline]
    pub fn tensor_mut(&mut self, site: usize) -> &mut TensorDynLen {
        let node_idx = self.treetn.node_index(&site).expect("Site out of bounds");
        self.treetn.tensor_mut(node_idx).expect("Tensor not found")
    }
411
412 /// Get a reference to all tensors.
413 #[inline]
414 pub fn tensors(&self) -> Vec<&TensorDynLen> {
415 (0..self.len())
416 .filter_map(|site| {
417 let node_idx = self.treetn.node_index(&site)?;
418 self.treetn.tensor(node_idx)
419 })
420 .collect()
421 }
422
    /// Get simultaneous mutable references to all site tensors, in site order.
    ///
    /// Borrowck cannot prove that repeated `tensor_mut` calls alias distinct
    /// nodes, so raw pointers are collected first and converted back.
    ///
    /// # Panics
    ///
    /// Panics if any site cannot be resolved or has no tensor.
    #[inline]
    pub fn tensors_mut(&mut self) -> Vec<&mut TensorDynLen> {
        let node_indices: Vec<_> = (0..self.len())
            .map(|site| self.treetn.node_index(&site).expect("Site out of bounds"))
            .collect();
        let mut tensor_ptrs = Vec::with_capacity(node_indices.len());
        for node_idx in node_indices {
            let tensor = self.treetn.tensor_mut(node_idx).expect("Tensor not found");
            tensor_ptrs.push(tensor as *mut TensorDynLen);
        }

        // SAFETY: TensorTrain site names are unique, so each site resolves to a
        // distinct TreeTN node. We collect at most one pointer per node and do
        // not mutate the network structure before converting those pointers back
        // into mutable references.
        unsafe { tensor_ptrs.into_iter().map(|tensor| &mut *tensor).collect() }
    }
441
442 /// Get the link index between sites `i` and `i+1`.
443 ///
444 /// Returns `None` if `i >= len() - 1` or if no common index exists.
445 pub fn linkind(&self, i: usize) -> Option<DynIndex> {
446 if i >= self.len().saturating_sub(1) {
447 return None;
448 }
449
450 let left_node = self.treetn.node_index(&i)?;
451 let right_node = self.treetn.node_index(&(i + 1))?;
452 let left = self.treetn.tensor(left_node)?;
453 let right = self.treetn.tensor(right_node)?;
454 let common = common_inds(left.indices(), right.indices());
455 common.into_iter().next()
456 }
457
458 /// Get all link indices.
459 ///
460 /// Returns a vector of length `len() - 1` containing the link indices.
461 pub fn linkinds(&self) -> Vec<DynIndex> {
462 (0..self.len().saturating_sub(1))
463 .filter_map(|i| self.linkind(i))
464 .collect()
465 }
466
467 /// Create a copy with all link indices replaced by new unique IDs.
468 ///
469 /// This is useful for computing inner products where two tensor trains
470 /// share link indices. By simulating (replacing) the link indices in one
471 /// of the tensor trains, they can be contracted over site indices only.
472 pub fn sim_linkinds(&self) -> Self {
473 if self.len() <= 1 {
474 return self.clone();
475 }
476
477 // Build replacement pairs: (old_link, new_link) for each link index
478 let old_links = self.linkinds();
479 let new_links: Vec<_> = old_links.iter().map(|idx| idx.sim()).collect();
480 let replacements: Vec<_> = old_links
481 .iter()
482 .cloned()
483 .zip(new_links.iter().cloned())
484 .collect();
485
486 // Replace link indices in each tensor and rebuild
487 let mut new_tensors = Vec::with_capacity(self.len());
488 for site in 0..self.len() {
489 let tensor = self.tensor(site);
490 let mut new_tensor = tensor.clone();
491 for (old_idx, new_idx) in &replacements {
492 new_tensor = new_tensor.replaceind(old_idx, new_idx);
493 }
494 new_tensors.push(new_tensor);
495 }
496
497 Self::new(new_tensors).expect("sim_linkinds: failed to create new tensor train")
498 }
499
500 fn normalize_site_tensor_orders(&mut self) -> Result<()> {
501 for site in 0..self.len() {
502 self.normalize_site_tensor_order(site)?;
503 }
504 Ok(())
505 }
506
507 fn has_simple_linear_links(&self) -> bool {
508 if self.len() <= 1 {
509 return true;
510 }
511
512 (0..self.len() - 1).all(|site| {
513 let left = self.tensor(site);
514 let right = self.tensor(site + 1);
515 common_inds(left.indices(), right.indices()).len() <= 1
516 })
517 }
518
519 fn can_normalize_site_tensor_order(&self, site: usize) -> bool {
520 let left_ok = if site > 0 {
521 common_inds(self.tensor(site - 1).indices(), self.tensor(site).indices()).len() <= 1
522 } else {
523 true
524 };
525 let right_ok = if site + 1 < self.len() {
526 common_inds(self.tensor(site).indices(), self.tensor(site + 1).indices()).len() <= 1
527 } else {
528 true
529 };
530 left_ok && right_ok
531 }
532
    /// Reorder the indices of the tensor at `site` into the canonical
    /// `(left link, site indices..., right link)` order.
    ///
    /// Sites adjacent to ambiguous (multi-link) bonds are left untouched.
    /// Site indices keep their original relative order.
    fn normalize_site_tensor_order(&mut self, site: usize) -> Result<()> {
        if !self.can_normalize_site_tensor_order(site) {
            return Ok(());
        }

        let tensor = self.tensor_checked(site)?.clone();
        let current = tensor.indices().to_vec();
        // Left link exists for every site but the first; right link for every
        // site but the last.
        let left = if site > 0 {
            self.linkind(site - 1)
        } else {
            None
        };
        let right = if site + 1 < self.len() {
            self.linkind(site)
        } else {
            None
        };

        // Desired order: [left link] ++ site indices ++ [right link].
        let mut desired = Vec::with_capacity(current.len());
        if let Some(ref left_link) = left {
            desired.push(left_link.clone());
        }
        desired.extend(
            current
                .iter()
                .filter(|idx| Some(*idx) != left.as_ref() && Some(*idx) != right.as_ref())
                .cloned(),
        );
        if let Some(ref right_link) = right {
            desired.push(right_link.clone());
        }

        if desired == current {
            // Already canonical; skip the permutation copy.
            return Ok(());
        }

        let normalized =
            tensor
                .permuteinds(&desired)
                .map_err(|e| TensorTrainError::InvalidStructure {
                    message: format!(
                        "Failed to normalize site tensor index order at site {}: {}",
                        site, e
                    ),
                })?;
        self.set_tensor_raw(site, normalized)
    }
580
581 /// Get the site indices (non-link indices) for all sites.
582 ///
583 /// For each site, returns a vector of indices that are not shared with
584 /// adjacent tensors (i.e., the "physical" or "site" indices).
585 pub fn siteinds(&self) -> Vec<Vec<DynIndex>> {
586 if self.is_empty() {
587 return Vec::new();
588 }
589
590 let mut result = Vec::with_capacity(self.len());
591
592 for i in 0..self.len() {
593 let tensor = self.tensor(i);
594 let mut site_inds: Vec<DynIndex> = tensor.indices().to_vec();
595
596 // Remove link to left neighbor
597 if i > 0 {
598 if let Some(link) = self.linkind(i - 1) {
599 site_inds.retain(|idx| idx != &link);
600 }
601 }
602
603 // Remove link to right neighbor
604 if i < self.len() - 1 {
605 if let Some(link) = self.linkind(i) {
606 site_inds.retain(|idx| idx != &link);
607 }
608 }
609
610 result.push(site_inds);
611 }
612
613 result
614 }
615
    /// Get the bond dimension at link `i` (between sites `i` and `i+1`).
    ///
    /// Returns `None` if `i >= len() - 1` (no such bond) or the bond has no
    /// common index.
    pub fn bond_dim(&self, i: usize) -> Option<usize> {
        self.linkind(i).map(|idx| idx.size())
    }
622
623 /// Get all bond dimensions.
624 ///
625 /// Returns a vector of length `len() - 1`.
626 pub fn bond_dims(&self) -> Vec<usize> {
627 self.linkinds().iter().map(|idx| idx.size()).collect()
628 }
629
630 /// Get the maximum bond dimension.
631 pub fn maxbonddim(&self) -> usize {
632 self.bond_dims().into_iter().max().unwrap_or(1)
633 }
634
635 /// Check if two adjacent tensors share an index.
636 pub fn haslink(&self, i: usize) -> bool {
637 if i >= self.len().saturating_sub(1) {
638 return false;
639 }
640 let left_node = self.treetn.node_index(&i);
641 let right_node = self.treetn.node_index(&(i + 1));
642 match (left_node, right_node) {
643 (Some(l), Some(r)) => {
644 let left = self.treetn.tensor(l);
645 let right = self.treetn.tensor(r);
646 match (left, right) {
647 (Some(l), Some(r)) => hascommoninds(l.indices(), r.indices()),
648 _ => false,
649 }
650 }
651 _ => false,
652 }
653 }
654
655 /// Replace the tensor at the given site.
656 ///
657 /// This invalidates orthogonality tracking.
658 fn set_tensor_raw(&mut self, site: usize, tensor: TensorDynLen) -> Result<()> {
659 let node_idx = self.treetn.node_index(&site).expect("Site out of bounds");
660 self.treetn.replace_tensor(node_idx, tensor).map_err(|e| {
661 TensorTrainError::InvalidStructure {
662 message: format!("Failed to replace tensor at site {}: {}", site, e),
663 }
664 })?;
665 Ok(())
666 }
667
668 /// Replace the tensor at the given site.
669 ///
670 /// This invalidates orthogonality tracking.
671 pub fn set_tensor(&mut self, site: usize, tensor: TensorDynLen) {
672 self.set_tensor_raw(site, tensor)
673 .and_then(|()| self.normalize_site_tensor_order(site))
674 .unwrap_or_else(|e| panic!("TensorTrain::set_tensor failed: {}", e));
675 // Invalidate orthogonality
676 let _ = self.treetn.set_canonical_region(Vec::<usize>::new());
677 }
678
    /// Orthogonalize the tensor train to have orthogonality center at the given site.
    ///
    /// This function performs a series of factorizations to make the tensor train
    /// canonical with orthogonality center at `site`. It uses the `Unitary`
    /// (QR-based) canonical form; see `orthogonalize_with` for other forms.
    ///
    /// # Arguments
    ///
    /// * `site` - The target site for the orthogonality center (0-indexed)
    ///
    /// # Errors
    ///
    /// Returns an error if the factorization fails or if the site is out of bounds.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_itensorlike::TensorTrain;
    /// use tensor4all_core::{DynIndex, TensorDynLen, Index, DynId};
    ///
    /// let s0 = Index::new_with_size(DynId(0), 2);
    /// let link = Index::new_with_size(DynId(1), 3);
    /// let s1 = Index::new_with_size(DynId(2), 2);
    ///
    /// let t0 = TensorDynLen::from_dense(
    ///     vec![s0.clone(), link.clone()],
    ///     (0..6).map(|i| i as f64).collect(),
    /// ).unwrap();
    /// let t1 = TensorDynLen::from_dense(
    ///     vec![link.clone(), s1.clone()],
    ///     (0..6).map(|i| i as f64).collect(),
    /// ).unwrap();
    ///
    /// let mut tt = TensorTrain::new(vec![t0, t1]).unwrap();
    /// assert!(!tt.isortho());
    ///
    /// // Orthogonalize to site 0
    /// tt.orthogonalize(0).unwrap();
    /// assert!(tt.isortho());
    /// assert_eq!(tt.orthocenter(), Some(0));
    /// ```
    pub fn orthogonalize(&mut self, site: usize) -> Result<()> {
        self.orthogonalize_with(site, CanonicalForm::Unitary)
    }
722
723 /// Orthogonalize with a specified canonical form.
724 ///
725 /// # Arguments
726 ///
727 /// * `site` - The target site for the orthogonality center (0-indexed)
728 /// * `form` - The canonical form to use:
729 /// - `Unitary`: Uses QR decomposition, each tensor is isometric
730 /// - `LU`: Uses LU decomposition, one factor has unit diagonal
731 /// - `CI`: Uses Cross Interpolation
732 pub fn orthogonalize_with(&mut self, site: usize, form: CanonicalForm) -> Result<()> {
733 if self.is_empty() {
734 return Err(TensorTrainError::Empty);
735 }
736 if site >= self.len() {
737 return Err(TensorTrainError::SiteOutOfBounds {
738 site,
739 length: self.len(),
740 });
741 }
742
743 // Use TreeTN's canonicalize (accepts node names and CanonicalizationOptions)
744 // Since V = usize, node names are site indices
745 let options = CanonicalizationOptions::forced().with_form(form);
746 self.treetn = std::mem::take(&mut self.treetn)
747 .canonicalize(vec![site], options)
748 .map_err(|e| TensorTrainError::InvalidStructure {
749 message: format!("Canonicalize failed: {}", e),
750 })?;
751
752 self.canonical_form = Some(form);
753 Ok(())
754 }
755
756 /// Truncate the tensor train bond dimensions.
757 ///
758 /// This delegates to the TreeTN's truncate_mut method, which performs a
759 /// two-site sweep with Euler tour traversal for optimal truncation.
760 ///
761 /// Note: The `site_range` option in `TruncateOptions` is currently ignored
762 /// as the underlying TreeTN truncation operates on the full network.
763 ///
764 /// # Examples
765 ///
766 /// ```
767 /// use tensor4all_itensorlike::{TensorTrain, TruncateOptions};
768 /// use tensor4all_core::{DynIndex, TensorDynLen, Index, DynId};
769 ///
770 /// // Build a 3-site tensor train with bond dimension 4
771 /// let s0 = Index::new_with_size(DynId(0), 2);
772 /// let l01 = Index::new_with_size(DynId(1), 4);
773 /// let s1 = Index::new_with_size(DynId(2), 2);
774 /// let l12 = Index::new_with_size(DynId(3), 4);
775 /// let s2 = Index::new_with_size(DynId(4), 2);
776 ///
777 /// let t0 = TensorDynLen::from_dense(
778 /// vec![s0.clone(), l01.clone()],
779 /// (0..8).map(|i| i as f64).collect(),
780 /// ).unwrap();
781 /// let t1 = TensorDynLen::from_dense(
782 /// vec![l01.clone(), s1.clone(), l12.clone()],
783 /// (0..32).map(|i| i as f64).collect(),
784 /// ).unwrap();
785 /// let t2 = TensorDynLen::from_dense(
786 /// vec![l12.clone(), s2.clone()],
787 /// (0..8).map(|i| i as f64).collect(),
788 /// ).unwrap();
789 ///
790 /// let mut tt = TensorTrain::new(vec![t0, t1, t2]).unwrap();
791 /// assert_eq!(tt.maxbonddim(), 4);
792 ///
793 /// // Truncate bond dimension to at most 2
794 /// let opts = TruncateOptions::svd().with_max_rank(2);
795 /// tt.truncate(&opts).unwrap();
796 /// assert!(tt.maxbonddim() <= 2);
797 /// ```
798 pub fn truncate(&mut self, options: &TruncateOptions) -> Result<()> {
799 if self.len() <= 1 {
800 return Ok(());
801 }
802
803 validate_svd_truncation_options(options.max_rank(), options.svd_policy())?;
804
805 // Convert TruncateOptions to TruncationOptions
806 let mut treetn_options = TruncationOptions::new();
807 if let Some(policy) = options.svd_policy() {
808 treetn_options = treetn_options.with_svd_policy(policy);
809 }
810 if let Some(max_rank) = options.max_rank() {
811 treetn_options = treetn_options.with_max_rank(max_rank);
812 }
813
814 // Truncate towards the last site (rightmost) as the canonical center
815 // This matches ITensor convention where truncation sweeps left-to-right
816 let center = self.len() - 1;
817
818 self.treetn
819 .truncate_mut([center], treetn_options)
820 .map_err(|e| TensorTrainError::InvalidStructure {
821 message: format!("TreeTN truncation failed: {}", e),
822 })?;
823
824 self.canonical_form = Some(CanonicalForm::Unitary);
825
826 Ok(())
827 }
828
    /// Compute the inner product (dot product) of two tensor trains.
    ///
    /// Computes `<self | other>` = sum over all indices of `conj(self) * other`.
    ///
    /// Both tensor trains must have the same site indices (same IDs).
    /// Link indices may differ between the two tensor trains.
    ///
    /// # Panics
    ///
    /// Panics if the two trains have different lengths, or if any site tensor
    /// cannot be resolved.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_itensorlike::TensorTrain;
    /// use tensor4all_core::{DynIndex, TensorDynLen, Index, DynId, AnyScalar};
    ///
    /// // Single-site tensor train with values [1.0, 0.0]
    /// let s0 = Index::new_with_size(DynId(0), 2);
    /// let t = TensorDynLen::from_dense(
    ///     vec![s0.clone()],
    ///     vec![1.0_f64, 0.0],
    /// ).unwrap();
    ///
    /// let tt = TensorTrain::new(vec![t]).unwrap();
    ///
    /// // <tt | tt> = 1.0^2 + 0.0^2 = 1.0
    /// let result = tt.inner(&tt);
    /// assert!((result.real() - 1.0).abs() < 1e-10);
    /// ```
    pub fn inner(&self, other: &Self) -> AnyScalar {
        assert_eq!(
            self.len(),
            other.len(),
            "Tensor trains must have the same length for inner product"
        );

        if self.is_empty() {
            return AnyScalar::new_real(0.0);
        }

        // Sequential bra-ket contraction along the chain: O(N·D²·d).
        // TreeTN::inner() uses contract_naive which is O(d^N) and OOMs for large N.
        // Replacing other's internal (link) indices with fresh IDs ensures the
        // two trains only share site indices.
        let other_sim = other.treetn.sim_internal_inds();

        let n = self.len();
        let node_idx = |ttn: &TreeTN<TensorDynLen, usize>, site: usize| {
            ttn.node_index(&site).expect("node not found")
        };

        // Start with leftmost tensors - contract over site indices only
        let mut env = {
            let a0_conj = self.tensor(0).conj();
            let b0 = other_sim
                .tensor(node_idx(&other_sim, 0))
                .expect("tensor not found")
                .clone();
            a0_conj.contract(&b0)
        };

        // Sweep through remaining sites, absorbing one bra and one ket tensor
        // per step so the environment stays a (link, link') object.
        for i in 1..n {
            let ai_conj = self.tensor(i).conj();
            let bi = other_sim
                .tensor(node_idx(&other_sim, i))
                .expect("tensor not found");

            // Contract: env * conj(A_i) (over self's link index)
            env = env.contract(&ai_conj);
            // Contract: result * B_i (over other's link index and site indices)
            env = env.contract(bi);
        }

        // Result should be a scalar (0-dimensional tensor)
        env.only()
    }
901
902 /// Compute the squared norm of the tensor train.
903 ///
904 /// Returns `<self | self>` = ||self||^2.
905 ///
906 /// # Note
907 /// For linear tensor trains with one site index per site, this uses a
908 /// specialized chain contraction instead of the generic inner-product path.
909 /// Due to numerical errors, the final scalar can be very slightly negative,
910 /// so the returned value is clamped to be non-negative.
911 pub fn norm_squared(&self) -> f64 {
912 self.norm_squared_fast_path()
913 .unwrap_or_else(|| self.inner(self).real().max(0.0))
914 }
915
    /// Compute the norm of the tensor train.
    ///
    /// Returns ||self|| = sqrt(<self | self>).
    ///
    /// `norm_squared` is clamped to be non-negative, so the square root is
    /// always well-defined.
    pub fn norm(&self) -> f64 {
        self.norm_squared().sqrt()
    }
922
923 fn norm_squared_fast_path(&self) -> Option<f64> {
924 if self.is_empty() {
925 return Some(0.0);
926 }
927 if !self.has_simple_linear_links() {
928 return None;
929 }
930 if self
931 .siteinds()
932 .iter()
933 .any(|site_indices| site_indices.len() != 1)
934 {
935 return None;
936 }
937
938 let mut normalized = self.clone();
939 normalized.normalize_site_tensor_orders().ok()?;
940
941 if let Some(sites) = Self::pack_normalized_sites::<f64>(&normalized) {
942 return Some(Self::norm_squared_from_packed_sites(&sites));
943 }
944 if let Some(sites) = Self::pack_normalized_sites::<Complex64>(&normalized) {
945 return Some(Self::norm_squared_from_packed_sites(&sites));
946 }
947
948 None
949 }
950
    /// Flatten every site tensor of `tt` into a `PackedSiteTensor<T>`.
    ///
    /// Boundary sites get a dummy link dimension of 1; the physical dimension
    /// is derived as `total_size / (left_dim * right_dim)`. Returns `None` if
    /// any link is missing, the dimensions don't divide evenly, or the data
    /// cannot be extracted as `Vec<T>` (wrong element type).
    ///
    /// Assumes `tt` has already been index-order normalized so the flat data
    /// layout matches (left, physical, right) — callers ensure this.
    fn pack_normalized_sites<T: TensorElement>(tt: &Self) -> Option<Vec<PackedSiteTensor<T>>> {
        let mut sites = Vec::with_capacity(tt.len());

        for site in 0..tt.len() {
            let tensor = tt.tensor(site);
            let left_dim = if site == 0 {
                1
            } else {
                tt.linkind(site - 1)?.size()
            };
            let right_dim = if site + 1 == tt.len() {
                1
            } else {
                tt.linkind(site)?.size()
            };
            let total_size: usize = tensor.dims().iter().product();
            let boundary_size = left_dim.checked_mul(right_dim)?;
            // The physical dimension must be an exact factor of the total size.
            if boundary_size == 0 || !total_size.is_multiple_of(boundary_size) {
                return None;
            }

            sites.push(PackedSiteTensor {
                left_dim,
                physical_dim: total_size / boundary_size,
                right_dim,
                data: tensor.to_vec::<T>().ok()?,
            });
        }

        Some(sites)
    }
982
    /// Contract `<psi|psi>` for a packed chain via transfer matrices.
    ///
    /// The environment `current[r * R + r']` accumulates
    /// `sum_phys A[.., phys, r] * conj(A[.., phys, r'])` and is propagated
    /// site by site. Each site's `left_dim` equals the previous site's
    /// `right_dim` (both come from the same link in `pack_normalized_sites`),
    /// so `current` can be re-indexed with the next site's `left_dim`.
    /// The last site has `right_dim == 1`, leaving a 1x1 environment whose
    /// single entry is the (real, non-negative up to rounding) squared norm.
    fn norm_squared_from_packed_sites<T: NormAccumScalar>(sites: &[PackedSiteTensor<T>]) -> f64 {
        if sites.is_empty() {
            return 0.0;
        }

        // Seed the environment from the first site (left boundary dim is 1).
        let first = &sites[0];
        let mut current = vec![T::zero(); first.right_dim * first.right_dim];

        for physical in 0..first.physical_dim {
            for right in 0..first.right_dim {
                let value = first.get(0, physical, right);
                for right_conj in 0..first.right_dim {
                    let idx = right * first.right_dim + right_conj;
                    current[idx] = current[idx] + value * first.get(0, physical, right_conj).conj();
                }
            }
        }

        // Absorb each remaining site's transfer matrix into the environment.
        for site in &sites[1..] {
            let mut next = vec![T::zero(); site.right_dim * site.right_dim];

            for left in 0..site.left_dim {
                for left_conj in 0..site.left_dim {
                    let env = current[left * site.left_dim + left_conj];
                    for physical in 0..site.physical_dim {
                        for right in 0..site.right_dim {
                            let value = site.get(left, physical, right);
                            for right_conj in 0..site.right_dim {
                                let idx = right * site.right_dim + right_conj;
                                next[idx] = next[idx]
                                    + env
                                        * value
                                        * site.get(left_conj, physical, right_conj).conj();
                            }
                        }
                    }
                }
            }

            current = next;
        }

        // 1x1 environment: clamp rounding noise and return the real part.
        current[0].into_nonnegative_real()
    }
1027
1028 /// Convert the tensor train to a single dense tensor.
1029 ///
1030 /// This contracts all tensors in the train along their link indices,
1031 /// producing a single tensor with only site indices.
1032 ///
1033 /// # Warning
1034 /// This operation can be very expensive for large tensor trains,
1035 /// as the result size grows exponentially with the number of sites.
1036 ///
1037 /// # Returns
1038 /// A single tensor containing all site indices, or an error if the
1039 /// tensor train is empty.
1040 ///
1041 /// # Example
1042 /// ```
1043 /// use tensor4all_core::{DynIndex, TensorDynLen};
1044 /// use tensor4all_itensorlike::TensorTrain;
1045 ///
1046 /// # fn main() -> anyhow::Result<()> {
1047 /// let s0 = DynIndex::new_dyn(2);
1048 /// let link = DynIndex::new_dyn(1);
1049 /// let s1 = DynIndex::new_dyn(2);
1050 /// let t0 = TensorDynLen::from_dense(vec![s0.clone(), link.clone()], vec![1.0, 2.0])?;
1051 /// let t1 = TensorDynLen::from_dense(vec![link.clone(), s1.clone()], vec![3.0, 4.0])?;
1052 ///
1053 /// let tt = TensorTrain::new(vec![t0, t1])?;
1054 /// let dense = tt.to_dense()?;
1055 ///
1056 /// assert_eq!(dense.dims(), vec![2, 2]);
1057 /// assert_eq!(dense.to_vec::<f64>()?, vec![3.0, 6.0, 4.0, 8.0]);
1058 /// # Ok(())
1059 /// # }
1060 /// ```
1061 pub fn to_dense(&self) -> Result<TensorDynLen> {
1062 if self.is_empty() {
1063 return Err(TensorTrainError::InvalidStructure {
1064 message: "Cannot convert empty tensor train to dense".to_string(),
1065 });
1066 }
1067
1068 self.treetn
1069 .contract_to_tensor()
1070 .map_err(|e| TensorTrainError::InvalidStructure {
1071 message: format!("Failed to contract to dense: {}", e),
1072 })
1073 }
1074
1075 /// Add two tensor trains using direct-sum construction.
1076 ///
1077 /// This creates a new tensor train where each tensor is the direct sum of the
1078 /// corresponding tensors from self and other, with bond dimensions merged.
1079 /// The result has bond dimensions equal to the sum of the input bond dimensions.
1080 ///
1081 /// # Arguments
1082 /// * `other` - The other tensor train to add
1083 ///
1084 /// # Returns
1085 /// A new tensor train representing the sum.
1086 ///
1087 /// # Errors
1088 /// Returns an error if the tensor trains have incompatible structures.
1089 pub fn add(&self, other: &Self) -> Result<Self> {
1090 if self.is_empty() && other.is_empty() {
1091 return Ok(Self::default());
1092 }
1093
1094 if self.is_empty() {
1095 return Ok(other.clone());
1096 }
1097
1098 if other.is_empty() {
1099 return Ok(self.clone());
1100 }
1101
1102 if self.len() != other.len() {
1103 return Err(TensorTrainError::InvalidStructure {
1104 message: format!(
1105 "Tensor trains must have the same length for addition: {} vs {}",
1106 self.len(),
1107 other.len()
1108 ),
1109 });
1110 }
1111
1112 let result_inner =
1113 self.treetn
1114 .add(&other.treetn)
1115 .map_err(|e| TensorTrainError::InvalidStructure {
1116 message: format!("TT addition failed: {}", e),
1117 })?;
1118
1119 Self::from_inner(result_inner, None)
1120 }
1121
1122 /// Scale the tensor train by a scalar.
1123 ///
1124 /// Only one tensor (the first non-empty site) is scaled to avoid scalar^n scaling.
1125 /// This is correct because the tensor train represents a product of tensors,
1126 /// so scaling one factor scales the entire product.
1127 ///
1128 /// # Arguments
1129 /// * `scalar` - The scalar to multiply by
1130 ///
1131 /// # Returns
1132 /// A new tensor train scaled by the given scalar.
1133 ///
1134 /// # Example
1135 /// ```
1136 /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen};
1137 /// use tensor4all_itensorlike::TensorTrain;
1138 ///
1139 /// # fn main() -> anyhow::Result<()> {
1140 /// let s0 = DynIndex::new_dyn(2);
1141 /// let tt = TensorTrain::new(vec![TensorDynLen::from_dense(
1142 /// vec![s0.clone()],
1143 /// vec![1.0, 2.0],
1144 /// )?])?;
1145 ///
1146 /// let scaled = tt.scale(AnyScalar::new_real(2.0))?;
1147 /// assert_eq!(scaled.to_dense()?.to_vec::<f64>()?, vec![2.0, 4.0]);
1148 /// # Ok(())
1149 /// # }
1150 /// ```
1151 pub fn scale(&self, scalar: AnyScalar) -> Result<Self> {
1152 if self.is_empty() {
1153 return Ok(self.clone());
1154 }
1155
1156 let mut tensors = Vec::with_capacity(self.len());
1157 for site in 0..self.len() {
1158 let tensor = self.tensor(site);
1159 if site == 0 {
1160 // Scale only the first tensor
1161 let scaled =
1162 tensor
1163 .scale(scalar.clone())
1164 .map_err(|e| TensorTrainError::OperationError {
1165 message: format!("Failed to scale tensor at site 0: {}", e),
1166 })?;
1167 tensors.push(scaled);
1168 } else {
1169 tensors.push(tensor.clone());
1170 }
1171 }
1172
1173 Self::new(tensors)
1174 }
1175
1176 /// Compute a linear combination: `a * self + b * other`.
1177 ///
1178 /// This is equivalent to `self.scale(a)?.add(&other.scale(b)?)`.
1179 ///
1180 /// # Arguments
1181 /// * `a` - Scalar coefficient for self
1182 /// * `other` - The other tensor train
1183 /// * `b` - Scalar coefficient for other
1184 ///
1185 /// # Returns
1186 /// A new tensor train representing `a * self + b * other`.
1187 ///
1188 /// # Note
1189 /// The bond dimension of the result is the sum of the bond dimensions
1190 /// of the two input tensor trains (before any truncation).
1191 pub fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
1192 let scaled_self = self.scale(a)?;
1193 let scaled_other = other.scale(b)?;
1194 scaled_self.add(&scaled_other)
1195 }
1196}
1197
1198// Implement Default for TensorTrain to allow std::mem::take
1199impl Default for TensorTrain {
1200 fn default() -> Self {
1201 Self::new(vec![]).expect("Failed to create empty TensorTrain")
1202 }
1203}
1204
1205// ============================================================================
1206// TensorIndex implementation for TensorTrain
1207// ============================================================================
1208
1209impl TensorIndex for TensorTrain {
1210 type Index = DynIndex;
1211
1212 fn external_indices(&self) -> Vec<Self::Index> {
1213 // Delegate to the internal TreeTN's TensorIndex implementation
1214 self.treetn.external_indices()
1215 }
1216
1217 fn num_external_indices(&self) -> usize {
1218 self.treetn.num_external_indices()
1219 }
1220
1221 fn replaceind(&self, old: &Self::Index, new: &Self::Index) -> anyhow::Result<Self> {
1222 // Delegate to the internal TreeTN's replaceind
1223 // After replacement, canonical form may be invalid, so set to None
1224 let treetn = self.treetn.replaceind(old, new)?;
1225 Self::from_inner(treetn, None).map_err(|e| anyhow::anyhow!("{}", e))
1226 }
1227
1228 fn replaceinds(&self, old: &[Self::Index], new: &[Self::Index]) -> anyhow::Result<Self> {
1229 let treetn = self.treetn.replaceinds(old, new)?;
1230 Self::from_inner(treetn, None).map_err(|e| anyhow::anyhow!("{}", e))
1231 }
1232}
1233
1234// ============================================================================
1235// TensorLike implementation for TensorTrain
1236// ============================================================================
1237
1238impl TensorLike for TensorTrain {
1239 // ========================================================================
1240 // GMRES-required methods (fully supported)
1241 // ========================================================================
1242
1243 fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> anyhow::Result<Self> {
1244 TensorTrain::axpby(self, a, other, b).map_err(|e| anyhow::anyhow!("{}", e))
1245 }
1246
1247 fn scale(&self, scalar: AnyScalar) -> anyhow::Result<Self> {
1248 TensorTrain::scale(self, scalar).map_err(|e| anyhow::anyhow!("{}", e))
1249 }
1250
1251 fn inner_product(&self, other: &Self) -> anyhow::Result<AnyScalar> {
1252 Ok(self.inner(other))
1253 }
1254
1255 fn norm_squared(&self) -> f64 {
1256 TensorTrain::norm_squared(self)
1257 }
1258
1259 fn maxabs(&self) -> f64 {
1260 self.to_dense().map(|t| t.maxabs()).unwrap_or(0.0)
1261 }
1262
1263 fn conj(&self) -> Self {
1264 // Clone and conjugate each site tensor
1265 // Note: conj() cannot return Result, so we ensure this never fails
1266 let mut result = self.clone();
1267 for site in 0..result.len() {
1268 let t = result.tensor(site).conj();
1269 result.set_tensor(site, t);
1270 }
1271 result
1272 }
1273
1274 // ========================================================================
1275 // Methods not supported by TensorTrain
1276 // ========================================================================
1277
1278 fn factorize(
1279 &self,
1280 _left_inds: &[Self::Index],
1281 _options: &FactorizeOptions,
1282 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
1283 Err(FactorizeError::UnsupportedStorage(
1284 "TensorTrain does not support factorize; use orthogonalize() instead",
1285 ))
1286 }
1287
1288 fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> anyhow::Result<Self> {
1289 anyhow::bail!("TensorTrain does not support TensorLike::contract; use TensorTrain::contract() method instead")
1290 }
1291
1292 fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> anyhow::Result<Self> {
1293 anyhow::bail!("TensorTrain does not support TensorLike::contract_connected; use TensorTrain::contract() method instead")
1294 }
1295
1296 fn direct_sum(
1297 &self,
1298 _other: &Self,
1299 _pairs: &[(Self::Index, Self::Index)],
1300 ) -> anyhow::Result<DirectSumResult<Self>> {
1301 anyhow::bail!("TensorTrain does not support direct_sum; use add() instead")
1302 }
1303
1304 fn outer_product(&self, _other: &Self) -> anyhow::Result<Self> {
1305 anyhow::bail!("TensorTrain does not support outer_product")
1306 }
1307
1308 fn permuteinds(&self, _new_order: &[Self::Index]) -> anyhow::Result<Self> {
1309 anyhow::bail!("TensorTrain does not support permuteinds")
1310 }
1311
1312 fn diagonal(input: &Self::Index, output: &Self::Index) -> anyhow::Result<Self> {
1313 // Create a single-site TensorTrain with an identity tensor
1314 let delta = TensorDynLen::diagonal(input, output)?;
1315 Self::new(vec![delta]).map_err(|e| anyhow::anyhow!("{}", e))
1316 }
1317
1318 fn scalar_one() -> anyhow::Result<Self> {
1319 // Empty tensor train represents scalar 1
1320 Self::new(vec![]).map_err(|e| anyhow::anyhow!("{}", e))
1321 }
1322
1323 fn ones(indices: &[Self::Index]) -> anyhow::Result<Self> {
1324 let t = TensorDynLen::ones(indices)?;
1325 Self::new(vec![t]).map_err(|e| anyhow::anyhow!("{}", e))
1326 }
1327
1328 fn onehot(index_vals: &[(Self::Index, usize)]) -> anyhow::Result<Self> {
1329 let t = TensorDynLen::onehot(index_vals)?;
1330 Self::new(vec![t]).map_err(|e| anyhow::anyhow!("{}", e))
1331 }
1332}
1333
1334#[cfg(test)]
1335mod tests;