// tensor4all_itensorlike/tensortrain.rs
//! Main Tensor Train type as a wrapper around TreeTN.
//!
//! This module provides the `TensorTrain` type, which represents a Tensor Train
//! (also known as an MPS) with orthogonality tracking, inspired by ITensorMPS.jl.
//!
//! Internally, `TensorTrain` is implemented as a thin wrapper around
//! `TreeTN<TensorDynLen, usize>` where node names are site indices (0, 1, 2, ...).
9use num_complex::Complex64;
10use std::ops::Range;
11use tensor4all_core::{common_inds, hascommoninds, DynIndex, IndexLike};
12use tensor4all_core::{
13    AllowedPairs, AnyScalar, CommonScalar, DirectSumResult, FactorizeError, FactorizeOptions,
14    FactorizeResult, TensorDynLen, TensorElement, TensorIndex, TensorLike,
15};
16use tensor4all_treetn::{CanonicalizationOptions, TreeTN, TruncationOptions};
17
18use crate::error::{Result, TensorTrainError};
19use crate::options::{validate_svd_truncation_options, CanonicalForm, TruncateOptions};
20
/// Tensor Train with orthogonality tracking.
///
/// This type represents a tensor train as a sequence of tensors with tracked
/// orthogonality limits. It is inspired by ITensorMPS.jl but uses
/// 0-indexed sites (Rust convention).
///
/// Unlike traditional MPS which assumes one physical index per site, this
/// implementation allows each site to have multiple site indices.
///
/// # Orthogonality Tracking
///
/// The tensor train tracks orthogonality using `ortho_region` from the underlying TreeTN:
/// - When `ortho_region` is empty, no orthogonality is assumed
/// - When `ortho_region` contains a single site, that site is the orthogonality center
///
/// # Implementation
///
/// Internally wraps `TreeTN<TensorDynLen, usize>` where node names are site indices.
/// This allows reuse of TreeTN's canonicalization and contraction algorithms.
///
/// # Examples
///
/// Build a 2-site tensor train and query its properties:
///
/// ```
/// use tensor4all_itensorlike::TensorTrain;
/// use tensor4all_core::{DynIndex, TensorDynLen, Index};
/// use tensor4all_core::DynId;
///
/// // Site indices and link index
/// let s0 = Index::new_with_size(DynId(0), 2);
/// let link = Index::new_with_size(DynId(1), 3);
/// let s1 = Index::new_with_size(DynId(2), 2);
///
/// let t0 = TensorDynLen::from_dense(
///     vec![s0.clone(), link.clone()],
///     (0..6).map(|i| i as f64).collect(),
/// ).unwrap();
/// let t1 = TensorDynLen::from_dense(
///     vec![link.clone(), s1.clone()],
///     (0..6).map(|i| i as f64).collect(),
/// ).unwrap();
///
/// let tt = TensorTrain::new(vec![t0, t1]).unwrap();
/// assert_eq!(tt.len(), 2);
/// assert_eq!(tt.maxbonddim(), 3);
/// assert!(!tt.is_empty());
/// ```
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// The underlying TreeTN with linear chain topology.
    /// Node names are usize (0, 1, 2, ...) representing site indices.
    /// Crate-visible so `contract`/`linsolve` can operate on it directly.
    pub(crate) treetn: TreeTN<TensorDynLen, usize>,
    /// The canonical form used (if known). This is bookkeeping only; the
    /// actual orthogonality region lives inside `treetn`.
    canonical_form: Option<CanonicalForm>,
}
77
/// Dense 3-axis buffer for a single site tensor, packed with the left bond as
/// the fastest-varying axis (column-major: left, then physical, then right).
#[derive(Debug)]
struct PackedSiteTensor<T> {
    left_dim: usize,
    physical_dim: usize,
    right_dim: usize,
    data: Vec<T>,
}

impl<T: Copy> PackedSiteTensor<T> {
    /// Fetch the element at `(left, physical, right)`.
    ///
    /// Bounds are checked only in debug builds via `debug_assert!`.
    fn get(&self, left: usize, physical: usize, right: usize) -> T {
        debug_assert!(left < self.left_dim);
        debug_assert!(physical < self.physical_dim);
        debug_assert!(right < self.right_dim);

        // Equivalent linearization to left + left_dim * (physical + physical_dim * right).
        let flat = (right * self.physical_dim + physical) * self.left_dim + left;
        self.data[flat]
    }
}
96
97trait NormAccumScalar: CommonScalar {
98    fn into_nonnegative_real(self) -> f64;
99}
100
101impl NormAccumScalar for f64 {
102    fn into_nonnegative_real(self) -> f64 {
103        self.max(0.0)
104    }
105}
106
107impl NormAccumScalar for Complex64 {
108    fn into_nonnegative_real(self) -> f64 {
109        self.re.max(0.0)
110    }
111}
112
113impl TensorTrain {
114    /// Create a new tensor train from a vector of tensors.
115    ///
116    /// The tensor train is created with no assumed orthogonality.
117    ///
118    /// # Arguments
119    ///
120    /// * `tensors` - Vector of tensors representing the tensor train
121    ///
122    /// # Returns
123    ///
124    /// A new tensor train with no orthogonality.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if the tensors have inconsistent bond dimensions
129    /// (i.e., the link indices between adjacent tensors don't match).
130    pub fn new(tensors: Vec<TensorDynLen>) -> Result<Self> {
131        if tensors.is_empty() {
132            // Create an empty TreeTN
133            let treetn = TreeTN::<TensorDynLen, usize>::new();
134            return Ok(Self {
135                treetn,
136                canonical_form: None,
137            });
138        }
139
140        // Validate that adjacent tensors share exactly one common index (the link)
141        for i in 0..tensors.len() - 1 {
142            let left = &tensors[i];
143            let right = &tensors[i + 1];
144
145            let common = common_inds(left.indices(), right.indices());
146            if common.is_empty() {
147                return Err(TensorTrainError::InvalidStructure {
148                    message: format!(
149                        "No common index between tensors at sites {} and {}",
150                        i,
151                        i + 1
152                    ),
153                });
154            }
155            if common.len() > 1 {
156                return Err(TensorTrainError::InvalidStructure {
157                    message: format!(
158                        "Multiple common indices ({}) between tensors at sites {} and {}",
159                        common.len(),
160                        i,
161                        i + 1
162                    ),
163                });
164            }
165        }
166
167        // Create node names: 0, 1, 2, ..., n-1
168        let node_names: Vec<usize> = (0..tensors.len()).collect();
169
170        // Create TreeTN with from_tensors (auto-connects by shared index IDs)
171        let treetn =
172            TreeTN::<TensorDynLen, usize>::from_tensors(tensors, node_names).map_err(|e| {
173                TensorTrainError::InvalidStructure {
174                    message: format!("Failed to create TreeTN: {}", e),
175                }
176            })?;
177
178        let mut tt = Self {
179            treetn,
180            canonical_form: None,
181        };
182        tt.normalize_site_tensor_orders()?;
183        Ok(tt)
184    }
185
186    /// Create a new tensor train with specified orthogonality center.
187    ///
188    /// This is useful when constructing a tensor train that is already in canonical form.
189    ///
190    /// # Arguments
191    ///
192    /// * `tensors` - Vector of tensors representing the tensor train
193    /// * `llim` - Left orthogonality limit (for compatibility; only used to compute center)
194    /// * `rlim` - Right orthogonality limit (for compatibility; only used to compute center)
195    /// * `canonical_form` - The method used for canonicalization (if any)
196    pub fn with_ortho(
197        tensors: Vec<TensorDynLen>,
198        llim: i32,
199        rlim: i32,
200        canonical_form: Option<CanonicalForm>,
201    ) -> Result<Self> {
202        let mut tt = Self::new(tensors)?;
203
204        // Convert llim/rlim to ortho center
205        // When llim + 2 == rlim, ortho center is at llim + 1
206        if llim + 2 == rlim && llim >= -1 && (llim + 1) < tt.len() as i32 {
207            let center = (llim + 1) as usize;
208            tt.treetn.set_canonical_region(vec![center]).map_err(|e| {
209                TensorTrainError::InvalidStructure {
210                    message: format!("Failed to set ortho region: {}", e),
211                }
212            })?;
213        }
214
215        tt.canonical_form = canonical_form;
216        Ok(tt)
217    }
218
    /// Create a TensorTrain from an existing TreeTN and canonical form.
    ///
    /// This is a crate-internal constructor used by `contract` and `linsolve`.
    ///
    /// Node names are renumbered to the contiguous range `0..n`. Because the
    /// names are processed in ascending order and the i-th smallest distinct
    /// usize name is always >= i, each rename target is either the name itself
    /// (skipped) or an already-vacated smaller value, so renames cannot clash.
    pub(crate) fn from_inner(
        treetn: TreeTN<TensorDynLen, usize>,
        canonical_form: Option<CanonicalForm>,
    ) -> Result<Self> {
        let mut node_names = treetn.node_names();
        node_names.sort_unstable();
        let mut tt = Self {
            treetn,
            canonical_form,
        };
        // Renumber so that site s is the node named s.
        for (site, old_name) in node_names.into_iter().enumerate() {
            if old_name != site {
                tt.treetn.rename_node(&old_name, site).map_err(|e| {
                    TensorTrainError::InvalidStructure {
                        message: format!("Failed to renumber TensorTrain sites: {}", e),
                    }
                })?;
            }
        }
        // Only canonicalize index order on a simple chain; when neighbors share
        // more than one index, "the" left/right link is ambiguous.
        if tt.has_simple_linear_links() {
            tt.normalize_site_tensor_orders()?;
        }
        Ok(tt)
    }
246
    /// Get a reference to the underlying TreeTN.
    ///
    /// This is a crate-internal accessor used by `contract` and `linsolve`.
    /// Node names in the returned network are the 0-based site indices.
    pub(crate) fn as_treetn(&self) -> &TreeTN<TensorDynLen, usize> {
        &self.treetn
    }
253
254    /// Number of sites (tensors) in the tensor train.
255    #[inline]
256    pub fn len(&self) -> usize {
257        self.treetn.node_count()
258    }
259
260    /// Check if the tensor train is empty.
261    #[inline]
262    pub fn is_empty(&self) -> bool {
263        self.treetn.node_count() == 0
264    }
265
266    /// Left orthogonality limit.
267    ///
268    /// Sites `0..llim` are guaranteed to be left-orthogonal.
269    /// Returns -1 if no sites are left-orthogonal.
270    #[inline]
271    pub fn llim(&self) -> i32 {
272        match self.orthocenter() {
273            Some(center) => center as i32 - 1,
274            None => -1,
275        }
276    }
277
278    /// Right orthogonality limit.
279    ///
280    /// Sites `rlim..len()` are guaranteed to be right-orthogonal.
281    /// Returns `len() + 1` if no sites are right-orthogonal.
282    #[inline]
283    pub fn rlim(&self) -> i32 {
284        match self.orthocenter() {
285            Some(center) => center as i32 + 1,
286            None => self.len() as i32 + 1,
287        }
288    }
289
290    /// Set the left orthogonality limit.
291    #[inline]
292    pub fn set_llim(&mut self, llim: i32) {
293        // Convert to ortho center if possible
294        let rlim = self.rlim();
295        if llim + 2 == rlim && llim >= -1 && (llim + 1) < self.len() as i32 {
296            let center = (llim + 1) as usize;
297            let _ = self.treetn.set_canonical_region(vec![center]);
298        } else {
299            // Clear ortho region if not a single center
300            let _ = self.treetn.set_canonical_region(Vec::<usize>::new());
301        }
302    }
303
304    /// Set the right orthogonality limit.
305    #[inline]
306    pub fn set_rlim(&mut self, rlim: i32) {
307        // Convert to ortho center if possible
308        let llim = self.llim();
309        if llim + 2 == rlim && llim >= -1 && (llim + 1) < self.len() as i32 {
310            let center = (llim + 1) as usize;
311            let _ = self.treetn.set_canonical_region(vec![center]);
312        } else {
313            // Clear ortho region if not a single center
314            let _ = self.treetn.set_canonical_region(Vec::<usize>::new());
315        }
316    }
317
318    /// Get the orthogonality center range.
319    ///
320    /// Returns the range of sites that may not be orthogonal.
321    /// If the tensor train is fully left-orthogonal, returns an empty range at the end.
322    /// If the tensor train is fully right-orthogonal, returns an empty range at the beginning.
323    pub fn ortho_lims(&self) -> Range<usize> {
324        let llim = self.llim();
325        let rlim = self.rlim();
326        let start = (llim + 1).max(0) as usize;
327        let end = rlim.max(0) as usize;
328        start..end.min(self.len())
329    }
330
331    /// Check if the tensor train has a single orthogonality center.
332    ///
333    /// Returns true if there is exactly one site that is not guaranteed to be orthogonal.
334    #[inline]
335    pub fn isortho(&self) -> bool {
336        self.treetn.canonical_region().len() == 1
337    }
338
339    /// Get the orthogonality center (0-indexed).
340    ///
341    /// Returns `Some(site)` if the tensor train has a single orthogonality center,
342    /// `None` otherwise.
343    pub fn orthocenter(&self) -> Option<usize> {
344        let region = self.treetn.canonical_region();
345        if region.len() == 1 {
346            // Node name IS the site index since V = usize
347            Some(*region.iter().next().unwrap())
348        } else {
349            None
350        }
351    }
352
    /// Get the canonicalization method used.
    ///
    /// `None` means no particular canonical form is currently being tracked.
    #[inline]
    pub fn canonical_form(&self) -> Option<CanonicalForm> {
        self.canonical_form
    }

    /// Set the canonicalization method.
    ///
    /// This only updates the bookkeeping flag; it does not re-canonicalize
    /// the underlying tensors.
    #[inline]
    pub fn set_canonical_form(&mut self, method: Option<CanonicalForm>) {
        self.canonical_form = method;
    }
364
    /// Get a reference to the tensor at the given site.
    ///
    /// # Panics
    ///
    /// Panics if `site >= len()`.
    #[inline]
    pub fn tensor(&self, site: usize) -> &TensorDynLen {
        // Node names equal site indices, so lookup is a direct name query.
        let node_idx = self.treetn.node_index(&site).expect("Site out of bounds");
        self.treetn.tensor(node_idx).expect("Tensor not found")
    }
375
376    /// Get a reference to the tensor at the given site.
377    ///
378    /// Returns `Err` if `site >= len()`.
379    pub fn tensor_checked(&self, site: usize) -> Result<&TensorDynLen> {
380        if site >= self.len() {
381            return Err(TensorTrainError::SiteOutOfBounds {
382                site,
383                length: self.len(),
384            });
385        }
386        let node_idx =
387            self.treetn
388                .node_index(&site)
389                .ok_or_else(|| TensorTrainError::SiteOutOfBounds {
390                    site,
391                    length: self.len(),
392                })?;
393        self.treetn
394            .tensor(node_idx)
395            .ok_or_else(|| TensorTrainError::SiteOutOfBounds {
396                site,
397                length: self.len(),
398            })
399    }
400
    /// Get a mutable reference to the tensor at the given site.
    ///
    /// NOTE(review): mutating through this reference does not reset the
    /// canonical region — the caller is responsible for keeping orthogonality
    /// tracking consistent (see `set_tensor`, which does reset it).
    ///
    /// # Panics
    ///
    /// Panics if `site >= len()`.
    #[inline]
    pub fn tensor_mut(&mut self, site: usize) -> &mut TensorDynLen {
        let node_idx = self.treetn.node_index(&site).expect("Site out of bounds");
        self.treetn.tensor_mut(node_idx).expect("Tensor not found")
    }
411
412    /// Get a reference to all tensors.
413    #[inline]
414    pub fn tensors(&self) -> Vec<&TensorDynLen> {
415        (0..self.len())
416            .filter_map(|site| {
417                let node_idx = self.treetn.node_index(&site)?;
418                self.treetn.tensor(node_idx)
419            })
420            .collect()
421    }
422
    /// Get a mutable reference to all tensors, in site order.
    ///
    /// # Panics
    ///
    /// Panics if any site cannot be resolved to a node or tensor.
    #[inline]
    pub fn tensors_mut(&mut self) -> Vec<&mut TensorDynLen> {
        // First resolve every site to its node index while only borrowing
        // immutably; panics early if a site is missing.
        let node_indices: Vec<_> = (0..self.len())
            .map(|site| self.treetn.node_index(&site).expect("Site out of bounds"))
            .collect();
        // Collect raw pointers because the borrow checker cannot prove that
        // successive `tensor_mut` calls return disjoint references.
        let mut tensor_ptrs = Vec::with_capacity(node_indices.len());
        for node_idx in node_indices {
            let tensor = self.treetn.tensor_mut(node_idx).expect("Tensor not found");
            tensor_ptrs.push(tensor as *mut TensorDynLen);
        }

        // SAFETY: TensorTrain site names are unique, so each site resolves to a
        // distinct TreeTN node. We collect at most one pointer per node and do
        // not mutate the network structure before converting those pointers back
        // into mutable references.
        unsafe { tensor_ptrs.into_iter().map(|tensor| &mut *tensor).collect() }
    }
441
442    /// Get the link index between sites `i` and `i+1`.
443    ///
444    /// Returns `None` if `i >= len() - 1` or if no common index exists.
445    pub fn linkind(&self, i: usize) -> Option<DynIndex> {
446        if i >= self.len().saturating_sub(1) {
447            return None;
448        }
449
450        let left_node = self.treetn.node_index(&i)?;
451        let right_node = self.treetn.node_index(&(i + 1))?;
452        let left = self.treetn.tensor(left_node)?;
453        let right = self.treetn.tensor(right_node)?;
454        let common = common_inds(left.indices(), right.indices());
455        common.into_iter().next()
456    }
457
458    /// Get all link indices.
459    ///
460    /// Returns a vector of length `len() - 1` containing the link indices.
461    pub fn linkinds(&self) -> Vec<DynIndex> {
462        (0..self.len().saturating_sub(1))
463            .filter_map(|i| self.linkind(i))
464            .collect()
465    }
466
467    /// Create a copy with all link indices replaced by new unique IDs.
468    ///
469    /// This is useful for computing inner products where two tensor trains
470    /// share link indices. By simulating (replacing) the link indices in one
471    /// of the tensor trains, they can be contracted over site indices only.
472    pub fn sim_linkinds(&self) -> Self {
473        if self.len() <= 1 {
474            return self.clone();
475        }
476
477        // Build replacement pairs: (old_link, new_link) for each link index
478        let old_links = self.linkinds();
479        let new_links: Vec<_> = old_links.iter().map(|idx| idx.sim()).collect();
480        let replacements: Vec<_> = old_links
481            .iter()
482            .cloned()
483            .zip(new_links.iter().cloned())
484            .collect();
485
486        // Replace link indices in each tensor and rebuild
487        let mut new_tensors = Vec::with_capacity(self.len());
488        for site in 0..self.len() {
489            let tensor = self.tensor(site);
490            let mut new_tensor = tensor.clone();
491            for (old_idx, new_idx) in &replacements {
492                new_tensor = new_tensor.replaceind(old_idx, new_idx);
493            }
494            new_tensors.push(new_tensor);
495        }
496
497        Self::new(new_tensors).expect("sim_linkinds: failed to create new tensor train")
498    }
499
500    fn normalize_site_tensor_orders(&mut self) -> Result<()> {
501        for site in 0..self.len() {
502            self.normalize_site_tensor_order(site)?;
503        }
504        Ok(())
505    }
506
507    fn has_simple_linear_links(&self) -> bool {
508        if self.len() <= 1 {
509            return true;
510        }
511
512        (0..self.len() - 1).all(|site| {
513            let left = self.tensor(site);
514            let right = self.tensor(site + 1);
515            common_inds(left.indices(), right.indices()).len() <= 1
516        })
517    }
518
519    fn can_normalize_site_tensor_order(&self, site: usize) -> bool {
520        let left_ok = if site > 0 {
521            common_inds(self.tensor(site - 1).indices(), self.tensor(site).indices()).len() <= 1
522        } else {
523            true
524        };
525        let right_ok = if site + 1 < self.len() {
526            common_inds(self.tensor(site).indices(), self.tensor(site + 1).indices()).len() <= 1
527        } else {
528            true
529        };
530        left_ok && right_ok
531    }
532
    /// Reorder a site tensor's indices to the canonical layout
    /// `[left link, site indices..., right link]`, when the neighboring links
    /// are unambiguous.
    fn normalize_site_tensor_order(&mut self, site: usize) -> Result<()> {
        // Skip when a neighbor shares more than one index: there is no unique
        // "left link" / "right link" to pin to the ends.
        if !self.can_normalize_site_tensor_order(site) {
            return Ok(());
        }

        let tensor = self.tensor_checked(site)?.clone();
        let current = tensor.indices().to_vec();
        // Link to the left neighbor (None at the first site).
        let left = if site > 0 {
            self.linkind(site - 1)
        } else {
            None
        };
        // Link to the right neighbor (None at the last site).
        let right = if site + 1 < self.len() {
            self.linkind(site)
        } else {
            None
        };

        // Desired order: left link first, then the site (physical) indices in
        // their current relative order, then the right link last.
        let mut desired = Vec::with_capacity(current.len());
        if let Some(ref left_link) = left {
            desired.push(left_link.clone());
        }
        desired.extend(
            current
                .iter()
                .filter(|idx| Some(*idx) != left.as_ref() && Some(*idx) != right.as_ref())
                .cloned(),
        );
        if let Some(ref right_link) = right {
            desired.push(right_link.clone());
        }

        // Already canonical; avoid a needless permutation.
        if desired == current {
            return Ok(());
        }

        let normalized =
            tensor
                .permuteinds(&desired)
                .map_err(|e| TensorTrainError::InvalidStructure {
                    message: format!(
                        "Failed to normalize site tensor index order at site {}: {}",
                        site, e
                    ),
                })?;
        self.set_tensor_raw(site, normalized)
    }
580
581    /// Get the site indices (non-link indices) for all sites.
582    ///
583    /// For each site, returns a vector of indices that are not shared with
584    /// adjacent tensors (i.e., the "physical" or "site" indices).
585    pub fn siteinds(&self) -> Vec<Vec<DynIndex>> {
586        if self.is_empty() {
587            return Vec::new();
588        }
589
590        let mut result = Vec::with_capacity(self.len());
591
592        for i in 0..self.len() {
593            let tensor = self.tensor(i);
594            let mut site_inds: Vec<DynIndex> = tensor.indices().to_vec();
595
596            // Remove link to left neighbor
597            if i > 0 {
598                if let Some(link) = self.linkind(i - 1) {
599                    site_inds.retain(|idx| idx != &link);
600                }
601            }
602
603            // Remove link to right neighbor
604            if i < self.len() - 1 {
605                if let Some(link) = self.linkind(i) {
606                    site_inds.retain(|idx| idx != &link);
607                }
608            }
609
610            result.push(site_inds);
611        }
612
613        result
614    }
615
616    /// Get the bond dimension at link `i` (between sites `i` and `i+1`).
617    ///
618    /// Returns `None` if `i >= len() - 1`.
619    pub fn bond_dim(&self, i: usize) -> Option<usize> {
620        self.linkind(i).map(|idx| idx.size())
621    }
622
623    /// Get all bond dimensions.
624    ///
625    /// Returns a vector of length `len() - 1`.
626    pub fn bond_dims(&self) -> Vec<usize> {
627        self.linkinds().iter().map(|idx| idx.size()).collect()
628    }
629
630    /// Get the maximum bond dimension.
631    pub fn maxbonddim(&self) -> usize {
632        self.bond_dims().into_iter().max().unwrap_or(1)
633    }
634
635    /// Check if two adjacent tensors share an index.
636    pub fn haslink(&self, i: usize) -> bool {
637        if i >= self.len().saturating_sub(1) {
638            return false;
639        }
640        let left_node = self.treetn.node_index(&i);
641        let right_node = self.treetn.node_index(&(i + 1));
642        match (left_node, right_node) {
643            (Some(l), Some(r)) => {
644                let left = self.treetn.tensor(l);
645                let right = self.treetn.tensor(r);
646                match (left, right) {
647                    (Some(l), Some(r)) => hascommoninds(l.indices(), r.indices()),
648                    _ => false,
649                }
650            }
651            _ => false,
652        }
653    }
654
655    /// Replace the tensor at the given site.
656    ///
657    /// This invalidates orthogonality tracking.
658    fn set_tensor_raw(&mut self, site: usize, tensor: TensorDynLen) -> Result<()> {
659        let node_idx = self.treetn.node_index(&site).expect("Site out of bounds");
660        self.treetn.replace_tensor(node_idx, tensor).map_err(|e| {
661            TensorTrainError::InvalidStructure {
662                message: format!("Failed to replace tensor at site {}: {}", site, e),
663            }
664        })?;
665        Ok(())
666    }
667
668    /// Replace the tensor at the given site.
669    ///
670    /// This invalidates orthogonality tracking.
671    pub fn set_tensor(&mut self, site: usize, tensor: TensorDynLen) {
672        self.set_tensor_raw(site, tensor)
673            .and_then(|()| self.normalize_site_tensor_order(site))
674            .unwrap_or_else(|e| panic!("TensorTrain::set_tensor failed: {}", e));
675        // Invalidate orthogonality
676        let _ = self.treetn.set_canonical_region(Vec::<usize>::new());
677    }
678
    /// Orthogonalize the tensor train to have orthogonality center at the given site.
    ///
    /// This function performs a series of factorizations to make the tensor train
    /// canonical with orthogonality center at `site`. It is shorthand for
    /// `orthogonalize_with(site, CanonicalForm::Unitary)`.
    ///
    /// # Arguments
    ///
    /// * `site` - The target site for the orthogonality center (0-indexed)
    ///
    /// # Errors
    ///
    /// Returns an error if the factorization fails or if the site is out of bounds.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_itensorlike::TensorTrain;
    /// use tensor4all_core::{DynIndex, TensorDynLen, Index, DynId};
    ///
    /// let s0 = Index::new_with_size(DynId(0), 2);
    /// let link = Index::new_with_size(DynId(1), 3);
    /// let s1 = Index::new_with_size(DynId(2), 2);
    ///
    /// let t0 = TensorDynLen::from_dense(
    ///     vec![s0.clone(), link.clone()],
    ///     (0..6).map(|i| i as f64).collect(),
    /// ).unwrap();
    /// let t1 = TensorDynLen::from_dense(
    ///     vec![link.clone(), s1.clone()],
    ///     (0..6).map(|i| i as f64).collect(),
    /// ).unwrap();
    ///
    /// let mut tt = TensorTrain::new(vec![t0, t1]).unwrap();
    /// assert!(!tt.isortho());
    ///
    /// // Orthogonalize to site 0
    /// tt.orthogonalize(0).unwrap();
    /// assert!(tt.isortho());
    /// assert_eq!(tt.orthocenter(), Some(0));
    /// ```
    pub fn orthogonalize(&mut self, site: usize) -> Result<()> {
        self.orthogonalize_with(site, CanonicalForm::Unitary)
    }
722
723    /// Orthogonalize with a specified canonical form.
724    ///
725    /// # Arguments
726    ///
727    /// * `site` - The target site for the orthogonality center (0-indexed)
728    /// * `form` - The canonical form to use:
729    ///   - `Unitary`: Uses QR decomposition, each tensor is isometric
730    ///   - `LU`: Uses LU decomposition, one factor has unit diagonal
731    ///   - `CI`: Uses Cross Interpolation
732    pub fn orthogonalize_with(&mut self, site: usize, form: CanonicalForm) -> Result<()> {
733        if self.is_empty() {
734            return Err(TensorTrainError::Empty);
735        }
736        if site >= self.len() {
737            return Err(TensorTrainError::SiteOutOfBounds {
738                site,
739                length: self.len(),
740            });
741        }
742
743        // Use TreeTN's canonicalize (accepts node names and CanonicalizationOptions)
744        // Since V = usize, node names are site indices
745        let options = CanonicalizationOptions::forced().with_form(form);
746        self.treetn = std::mem::take(&mut self.treetn)
747            .canonicalize(vec![site], options)
748            .map_err(|e| TensorTrainError::InvalidStructure {
749                message: format!("Canonicalize failed: {}", e),
750            })?;
751
752        self.canonical_form = Some(form);
753        Ok(())
754    }
755
756    /// Truncate the tensor train bond dimensions.
757    ///
758    /// This delegates to the TreeTN's truncate_mut method, which performs a
759    /// two-site sweep with Euler tour traversal for optimal truncation.
760    ///
761    /// Note: The `site_range` option in `TruncateOptions` is currently ignored
762    /// as the underlying TreeTN truncation operates on the full network.
763    ///
764    /// # Examples
765    ///
766    /// ```
767    /// use tensor4all_itensorlike::{TensorTrain, TruncateOptions};
768    /// use tensor4all_core::{DynIndex, TensorDynLen, Index, DynId};
769    ///
770    /// // Build a 3-site tensor train with bond dimension 4
771    /// let s0 = Index::new_with_size(DynId(0), 2);
772    /// let l01 = Index::new_with_size(DynId(1), 4);
773    /// let s1 = Index::new_with_size(DynId(2), 2);
774    /// let l12 = Index::new_with_size(DynId(3), 4);
775    /// let s2 = Index::new_with_size(DynId(4), 2);
776    ///
777    /// let t0 = TensorDynLen::from_dense(
778    ///     vec![s0.clone(), l01.clone()],
779    ///     (0..8).map(|i| i as f64).collect(),
780    /// ).unwrap();
781    /// let t1 = TensorDynLen::from_dense(
782    ///     vec![l01.clone(), s1.clone(), l12.clone()],
783    ///     (0..32).map(|i| i as f64).collect(),
784    /// ).unwrap();
785    /// let t2 = TensorDynLen::from_dense(
786    ///     vec![l12.clone(), s2.clone()],
787    ///     (0..8).map(|i| i as f64).collect(),
788    /// ).unwrap();
789    ///
790    /// let mut tt = TensorTrain::new(vec![t0, t1, t2]).unwrap();
791    /// assert_eq!(tt.maxbonddim(), 4);
792    ///
793    /// // Truncate bond dimension to at most 2
794    /// let opts = TruncateOptions::svd().with_max_rank(2);
795    /// tt.truncate(&opts).unwrap();
796    /// assert!(tt.maxbonddim() <= 2);
797    /// ```
798    pub fn truncate(&mut self, options: &TruncateOptions) -> Result<()> {
799        if self.len() <= 1 {
800            return Ok(());
801        }
802
803        validate_svd_truncation_options(options.max_rank(), options.svd_policy())?;
804
805        // Convert TruncateOptions to TruncationOptions
806        let mut treetn_options = TruncationOptions::new();
807        if let Some(policy) = options.svd_policy() {
808            treetn_options = treetn_options.with_svd_policy(policy);
809        }
810        if let Some(max_rank) = options.max_rank() {
811            treetn_options = treetn_options.with_max_rank(max_rank);
812        }
813
814        // Truncate towards the last site (rightmost) as the canonical center
815        // This matches ITensor convention where truncation sweeps left-to-right
816        let center = self.len() - 1;
817
818        self.treetn
819            .truncate_mut([center], treetn_options)
820            .map_err(|e| TensorTrainError::InvalidStructure {
821                message: format!("TreeTN truncation failed: {}", e),
822            })?;
823
824        self.canonical_form = Some(CanonicalForm::Unitary);
825
826        Ok(())
827    }
828
    /// Compute the inner product (dot product) of two tensor trains.
    ///
    /// Computes `<self | other>` = sum over all indices of `conj(self) * other`.
    ///
    /// Both tensor trains must have the same site indices (same IDs).
    /// Link indices may differ between the two tensor trains.
    ///
    /// # Panics
    ///
    /// Panics if the two tensor trains have different lengths.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_itensorlike::TensorTrain;
    /// use tensor4all_core::{DynIndex, TensorDynLen, Index, DynId, AnyScalar};
    ///
    /// // Single-site tensor train with values [1.0, 0.0]
    /// let s0 = Index::new_with_size(DynId(0), 2);
    /// let t = TensorDynLen::from_dense(
    ///     vec![s0.clone()],
    ///     vec![1.0_f64, 0.0],
    /// ).unwrap();
    ///
    /// let tt = TensorTrain::new(vec![t]).unwrap();
    ///
    /// // <tt | tt> = 1.0^2 + 0.0^2 = 1.0
    /// let result = tt.inner(&tt);
    /// assert!((result.real() - 1.0).abs() < 1e-10);
    /// ```
    pub fn inner(&self, other: &Self) -> AnyScalar {
        assert_eq!(
            self.len(),
            other.len(),
            "Tensor trains must have the same length for inner product"
        );

        // Empty trains have no indices; define their overlap as 0.
        if self.is_empty() {
            return AnyScalar::new_real(0.0);
        }

        // Sequential bra-ket contraction along the chain: O(N·D²·d).
        // TreeTN::inner() uses contract_naive which is O(d^N) and OOMs for large N.
        // Simulate other's internal (link) indices so only site indices are shared.
        let other_sim = other.treetn.sim_internal_inds();

        let n = self.len();
        let node_idx = |ttn: &TreeTN<TensorDynLen, usize>, site: usize| {
            ttn.node_index(&site).expect("node not found")
        };

        // Start with leftmost tensors - contract over site indices only
        let mut env = {
            let a0_conj = self.tensor(0).conj();
            let b0 = other_sim
                .tensor(node_idx(&other_sim, 0))
                .expect("tensor not found")
                .clone();
            a0_conj.contract(&b0)
        };

        // Sweep through remaining sites, absorbing one bra and one ket tensor
        // per step so the environment never grows beyond bond × bond.
        for i in 1..n {
            let ai_conj = self.tensor(i).conj();
            let bi = other_sim
                .tensor(node_idx(&other_sim, i))
                .expect("tensor not found");

            // Contract: env * conj(A_i) (over self's link index)
            env = env.contract(&ai_conj);
            // Contract: result * B_i (over other's link index and site indices)
            env = env.contract(bi);
        }

        // Result should be a scalar (0-dimensional tensor)
        env.only()
    }
901
902    /// Compute the squared norm of the tensor train.
903    ///
904    /// Returns `<self | self>` = ||self||^2.
905    ///
906    /// # Note
907    /// For linear tensor trains with one site index per site, this uses a
908    /// specialized chain contraction instead of the generic inner-product path.
909    /// Due to numerical errors, the final scalar can be very slightly negative,
910    /// so the returned value is clamped to be non-negative.
911    pub fn norm_squared(&self) -> f64 {
912        self.norm_squared_fast_path()
913            .unwrap_or_else(|| self.inner(self).real().max(0.0))
914    }
915
916    /// Compute the norm of the tensor train.
917    ///
918    /// Returns ||self|| = sqrt(<self | self>).
919    pub fn norm(&self) -> f64 {
920        self.norm_squared().sqrt()
921    }
922
923    fn norm_squared_fast_path(&self) -> Option<f64> {
924        if self.is_empty() {
925            return Some(0.0);
926        }
927        if !self.has_simple_linear_links() {
928            return None;
929        }
930        if self
931            .siteinds()
932            .iter()
933            .any(|site_indices| site_indices.len() != 1)
934        {
935            return None;
936        }
937
938        let mut normalized = self.clone();
939        normalized.normalize_site_tensor_orders().ok()?;
940
941        if let Some(sites) = Self::pack_normalized_sites::<f64>(&normalized) {
942            return Some(Self::norm_squared_from_packed_sites(&sites));
943        }
944        if let Some(sites) = Self::pack_normalized_sites::<Complex64>(&normalized) {
945            return Some(Self::norm_squared_from_packed_sites(&sites));
946        }
947
948        None
949    }
950
951    fn pack_normalized_sites<T: TensorElement>(tt: &Self) -> Option<Vec<PackedSiteTensor<T>>> {
952        let mut sites = Vec::with_capacity(tt.len());
953
954        for site in 0..tt.len() {
955            let tensor = tt.tensor(site);
956            let left_dim = if site == 0 {
957                1
958            } else {
959                tt.linkind(site - 1)?.size()
960            };
961            let right_dim = if site + 1 == tt.len() {
962                1
963            } else {
964                tt.linkind(site)?.size()
965            };
966            let total_size: usize = tensor.dims().iter().product();
967            let boundary_size = left_dim.checked_mul(right_dim)?;
968            if boundary_size == 0 || !total_size.is_multiple_of(boundary_size) {
969                return None;
970            }
971
972            sites.push(PackedSiteTensor {
973                left_dim,
974                physical_dim: total_size / boundary_size,
975                right_dim,
976                data: tensor.to_vec::<T>().ok()?,
977            });
978        }
979
980        Some(sites)
981    }
982
    /// Evaluate ||tt||² from packed `(left, physical, right)` site tensors by
    /// sweeping a bra-ket environment along the chain.
    ///
    /// `current` holds the environment E[right, right_conj] accumulated from
    /// all sites processed so far (row-major, stride = right_dim). Each step
    /// absorbs the next site tensor and its conjugate. Cost per site is
    /// O(D⁴·d), where D is the bond dimension and d the physical dimension.
    fn norm_squared_from_packed_sites<T: NormAccumScalar>(sites: &[PackedSiteTensor<T>]) -> f64 {
        if sites.is_empty() {
            return 0.0;
        }

        // Seed the environment with the first site. Its left bond has
        // dimension 1 (by construction in pack_normalized_sites), so only the
        // physical index is summed here.
        let first = &sites[0];
        let mut current = vec![T::zero(); first.right_dim * first.right_dim];

        for physical in 0..first.physical_dim {
            for right in 0..first.right_dim {
                let value = first.get(0, physical, right);
                for right_conj in 0..first.right_dim {
                    // E[right, right_conj] += A[0,p,right] * conj(A[0,p,right_conj])
                    let idx = right * first.right_dim + right_conj;
                    current[idx] = current[idx] + value * first.get(0, physical, right_conj).conj();
                }
            }
        }

        // Transfer the environment through the remaining sites. The incoming
        // environment is indexed by (left, left_conj); site.left_dim equals
        // the previous site's right_dim, so the row stride lines up.
        for site in &sites[1..] {
            let mut next = vec![T::zero(); site.right_dim * site.right_dim];

            for left in 0..site.left_dim {
                for left_conj in 0..site.left_dim {
                    let env = current[left * site.left_dim + left_conj];
                    for physical in 0..site.physical_dim {
                        for right in 0..site.right_dim {
                            let value = site.get(left, physical, right);
                            for right_conj in 0..site.right_dim {
                                // next[right, right_conj] +=
                                //   E[left, left_conj] * A[l,p,r] * conj(A[l',p,r'])
                                let idx = right * site.right_dim + right_conj;
                                next[idx] = next[idx]
                                    + env
                                        * value
                                        * site.get(left_conj, physical, right_conj).conj();
                            }
                        }
                    }
                }
            }

            current = next;
        }

        // The final environment is 1×1 (the trailing right bond has dimension
        // 1); clamp tiny negative rounding noise to zero.
        current[0].into_nonnegative_real()
    }
1027
1028    /// Convert the tensor train to a single dense tensor.
1029    ///
1030    /// This contracts all tensors in the train along their link indices,
1031    /// producing a single tensor with only site indices.
1032    ///
1033    /// # Warning
1034    /// This operation can be very expensive for large tensor trains,
1035    /// as the result size grows exponentially with the number of sites.
1036    ///
1037    /// # Returns
1038    /// A single tensor containing all site indices, or an error if the
1039    /// tensor train is empty.
1040    ///
1041    /// # Example
1042    /// ```
1043    /// use tensor4all_core::{DynIndex, TensorDynLen};
1044    /// use tensor4all_itensorlike::TensorTrain;
1045    ///
1046    /// # fn main() -> anyhow::Result<()> {
1047    /// let s0 = DynIndex::new_dyn(2);
1048    /// let link = DynIndex::new_dyn(1);
1049    /// let s1 = DynIndex::new_dyn(2);
1050    /// let t0 = TensorDynLen::from_dense(vec![s0.clone(), link.clone()], vec![1.0, 2.0])?;
1051    /// let t1 = TensorDynLen::from_dense(vec![link.clone(), s1.clone()], vec![3.0, 4.0])?;
1052    ///
1053    /// let tt = TensorTrain::new(vec![t0, t1])?;
1054    /// let dense = tt.to_dense()?;
1055    ///
1056    /// assert_eq!(dense.dims(), vec![2, 2]);
1057    /// assert_eq!(dense.to_vec::<f64>()?, vec![3.0, 6.0, 4.0, 8.0]);
1058    /// # Ok(())
1059    /// # }
1060    /// ```
1061    pub fn to_dense(&self) -> Result<TensorDynLen> {
1062        if self.is_empty() {
1063            return Err(TensorTrainError::InvalidStructure {
1064                message: "Cannot convert empty tensor train to dense".to_string(),
1065            });
1066        }
1067
1068        self.treetn
1069            .contract_to_tensor()
1070            .map_err(|e| TensorTrainError::InvalidStructure {
1071                message: format!("Failed to contract to dense: {}", e),
1072            })
1073    }
1074
1075    /// Add two tensor trains using direct-sum construction.
1076    ///
1077    /// This creates a new tensor train where each tensor is the direct sum of the
1078    /// corresponding tensors from self and other, with bond dimensions merged.
1079    /// The result has bond dimensions equal to the sum of the input bond dimensions.
1080    ///
1081    /// # Arguments
1082    /// * `other` - The other tensor train to add
1083    ///
1084    /// # Returns
1085    /// A new tensor train representing the sum.
1086    ///
1087    /// # Errors
1088    /// Returns an error if the tensor trains have incompatible structures.
1089    pub fn add(&self, other: &Self) -> Result<Self> {
1090        if self.is_empty() && other.is_empty() {
1091            return Ok(Self::default());
1092        }
1093
1094        if self.is_empty() {
1095            return Ok(other.clone());
1096        }
1097
1098        if other.is_empty() {
1099            return Ok(self.clone());
1100        }
1101
1102        if self.len() != other.len() {
1103            return Err(TensorTrainError::InvalidStructure {
1104                message: format!(
1105                    "Tensor trains must have the same length for addition: {} vs {}",
1106                    self.len(),
1107                    other.len()
1108                ),
1109            });
1110        }
1111
1112        let result_inner =
1113            self.treetn
1114                .add(&other.treetn)
1115                .map_err(|e| TensorTrainError::InvalidStructure {
1116                    message: format!("TT addition failed: {}", e),
1117                })?;
1118
1119        Self::from_inner(result_inner, None)
1120    }
1121
1122    /// Scale the tensor train by a scalar.
1123    ///
1124    /// Only one tensor (the first non-empty site) is scaled to avoid scalar^n scaling.
1125    /// This is correct because the tensor train represents a product of tensors,
1126    /// so scaling one factor scales the entire product.
1127    ///
1128    /// # Arguments
1129    /// * `scalar` - The scalar to multiply by
1130    ///
1131    /// # Returns
1132    /// A new tensor train scaled by the given scalar.
1133    ///
1134    /// # Example
1135    /// ```
1136    /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen};
1137    /// use tensor4all_itensorlike::TensorTrain;
1138    ///
1139    /// # fn main() -> anyhow::Result<()> {
1140    /// let s0 = DynIndex::new_dyn(2);
1141    /// let tt = TensorTrain::new(vec![TensorDynLen::from_dense(
1142    ///     vec![s0.clone()],
1143    ///     vec![1.0, 2.0],
1144    /// )?])?;
1145    ///
1146    /// let scaled = tt.scale(AnyScalar::new_real(2.0))?;
1147    /// assert_eq!(scaled.to_dense()?.to_vec::<f64>()?, vec![2.0, 4.0]);
1148    /// # Ok(())
1149    /// # }
1150    /// ```
1151    pub fn scale(&self, scalar: AnyScalar) -> Result<Self> {
1152        if self.is_empty() {
1153            return Ok(self.clone());
1154        }
1155
1156        let mut tensors = Vec::with_capacity(self.len());
1157        for site in 0..self.len() {
1158            let tensor = self.tensor(site);
1159            if site == 0 {
1160                // Scale only the first tensor
1161                let scaled =
1162                    tensor
1163                        .scale(scalar.clone())
1164                        .map_err(|e| TensorTrainError::OperationError {
1165                            message: format!("Failed to scale tensor at site 0: {}", e),
1166                        })?;
1167                tensors.push(scaled);
1168            } else {
1169                tensors.push(tensor.clone());
1170            }
1171        }
1172
1173        Self::new(tensors)
1174    }
1175
1176    /// Compute a linear combination: `a * self + b * other`.
1177    ///
1178    /// This is equivalent to `self.scale(a)?.add(&other.scale(b)?)`.
1179    ///
1180    /// # Arguments
1181    /// * `a` - Scalar coefficient for self
1182    /// * `other` - The other tensor train
1183    /// * `b` - Scalar coefficient for other
1184    ///
1185    /// # Returns
1186    /// A new tensor train representing `a * self + b * other`.
1187    ///
1188    /// # Note
1189    /// The bond dimension of the result is the sum of the bond dimensions
1190    /// of the two input tensor trains (before any truncation).
1191    pub fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
1192        let scaled_self = self.scale(a)?;
1193        let scaled_other = other.scale(b)?;
1194        scaled_self.add(&scaled_other)
1195    }
1196}
1197
1198// Implement Default for TensorTrain to allow std::mem::take
1199impl Default for TensorTrain {
1200    fn default() -> Self {
1201        Self::new(vec![]).expect("Failed to create empty TensorTrain")
1202    }
1203}
1204
1205// ============================================================================
1206// TensorIndex implementation for TensorTrain
1207// ============================================================================
1208
1209impl TensorIndex for TensorTrain {
1210    type Index = DynIndex;
1211
1212    fn external_indices(&self) -> Vec<Self::Index> {
1213        // Delegate to the internal TreeTN's TensorIndex implementation
1214        self.treetn.external_indices()
1215    }
1216
1217    fn num_external_indices(&self) -> usize {
1218        self.treetn.num_external_indices()
1219    }
1220
1221    fn replaceind(&self, old: &Self::Index, new: &Self::Index) -> anyhow::Result<Self> {
1222        // Delegate to the internal TreeTN's replaceind
1223        // After replacement, canonical form may be invalid, so set to None
1224        let treetn = self.treetn.replaceind(old, new)?;
1225        Self::from_inner(treetn, None).map_err(|e| anyhow::anyhow!("{}", e))
1226    }
1227
1228    fn replaceinds(&self, old: &[Self::Index], new: &[Self::Index]) -> anyhow::Result<Self> {
1229        let treetn = self.treetn.replaceinds(old, new)?;
1230        Self::from_inner(treetn, None).map_err(|e| anyhow::anyhow!("{}", e))
1231    }
1232}
1233
1234// ============================================================================
1235// TensorLike implementation for TensorTrain
1236// ============================================================================
1237
1238impl TensorLike for TensorTrain {
1239    // ========================================================================
1240    // GMRES-required methods (fully supported)
1241    // ========================================================================
1242
1243    fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> anyhow::Result<Self> {
1244        TensorTrain::axpby(self, a, other, b).map_err(|e| anyhow::anyhow!("{}", e))
1245    }
1246
1247    fn scale(&self, scalar: AnyScalar) -> anyhow::Result<Self> {
1248        TensorTrain::scale(self, scalar).map_err(|e| anyhow::anyhow!("{}", e))
1249    }
1250
1251    fn inner_product(&self, other: &Self) -> anyhow::Result<AnyScalar> {
1252        Ok(self.inner(other))
1253    }
1254
1255    fn norm_squared(&self) -> f64 {
1256        TensorTrain::norm_squared(self)
1257    }
1258
1259    fn maxabs(&self) -> f64 {
1260        self.to_dense().map(|t| t.maxabs()).unwrap_or(0.0)
1261    }
1262
1263    fn conj(&self) -> Self {
1264        // Clone and conjugate each site tensor
1265        // Note: conj() cannot return Result, so we ensure this never fails
1266        let mut result = self.clone();
1267        for site in 0..result.len() {
1268            let t = result.tensor(site).conj();
1269            result.set_tensor(site, t);
1270        }
1271        result
1272    }
1273
1274    // ========================================================================
1275    // Methods not supported by TensorTrain
1276    // ========================================================================
1277
1278    fn factorize(
1279        &self,
1280        _left_inds: &[Self::Index],
1281        _options: &FactorizeOptions,
1282    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
1283        Err(FactorizeError::UnsupportedStorage(
1284            "TensorTrain does not support factorize; use orthogonalize() instead",
1285        ))
1286    }
1287
1288    fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> anyhow::Result<Self> {
1289        anyhow::bail!("TensorTrain does not support TensorLike::contract; use TensorTrain::contract() method instead")
1290    }
1291
1292    fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> anyhow::Result<Self> {
1293        anyhow::bail!("TensorTrain does not support TensorLike::contract_connected; use TensorTrain::contract() method instead")
1294    }
1295
1296    fn direct_sum(
1297        &self,
1298        _other: &Self,
1299        _pairs: &[(Self::Index, Self::Index)],
1300    ) -> anyhow::Result<DirectSumResult<Self>> {
1301        anyhow::bail!("TensorTrain does not support direct_sum; use add() instead")
1302    }
1303
1304    fn outer_product(&self, _other: &Self) -> anyhow::Result<Self> {
1305        anyhow::bail!("TensorTrain does not support outer_product")
1306    }
1307
1308    fn permuteinds(&self, _new_order: &[Self::Index]) -> anyhow::Result<Self> {
1309        anyhow::bail!("TensorTrain does not support permuteinds")
1310    }
1311
1312    fn diagonal(input: &Self::Index, output: &Self::Index) -> anyhow::Result<Self> {
1313        // Create a single-site TensorTrain with an identity tensor
1314        let delta = TensorDynLen::diagonal(input, output)?;
1315        Self::new(vec![delta]).map_err(|e| anyhow::anyhow!("{}", e))
1316    }
1317
1318    fn scalar_one() -> anyhow::Result<Self> {
1319        // Empty tensor train represents scalar 1
1320        Self::new(vec![]).map_err(|e| anyhow::anyhow!("{}", e))
1321    }
1322
1323    fn ones(indices: &[Self::Index]) -> anyhow::Result<Self> {
1324        let t = TensorDynLen::ones(indices)?;
1325        Self::new(vec![t]).map_err(|e| anyhow::anyhow!("{}", e))
1326    }
1327
1328    fn onehot(index_vals: &[(Self::Index, usize)]) -> anyhow::Result<Self> {
1329        let t = TensorDynLen::onehot(index_vals)?;
1330        Self::new(vec![t]).map_err(|e| anyhow::anyhow!("{}", e))
1331    }
1332}
1333
1334#[cfg(test)]
1335mod tests;