// tensor4all_treetn/simplett_bridge.rs

1use anyhow::{ensure, Result};
2use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorElement};
3use tensor4all_simplett::{AbstractTensorTrain, TTScalar, Tensor3Ops, TensorTrain};
4
5use crate::TreeTN;
6
7/// Convert a linear-chain simple tensor train into a `TreeTN` with node names `0..n-1`.
8///
9/// The returned site indices are ordered by tensor-train site position, which is
10/// convenient for downstream state/layout bookkeeping.
11///
12/// # Examples
13///
14/// ```
15/// use tensor4all_core::{IndexLike, TensorIndex, TensorLike};
16/// use tensor4all_simplett::{tensor3_from_data, AbstractTensorTrain, TensorTrain};
17/// use tensor4all_treetn::tensor_train_to_treetn;
18///
19/// let tt = TensorTrain::new(vec![
20///     tensor3_from_data(vec![1.0_f64, 2.0], 1, 2, 1),
21/// ]).unwrap();
22///
23/// let (treetn, site_indices) = tensor_train_to_treetn(&tt).unwrap();
24/// let dense = treetn.contract_to_tensor().unwrap();
25///
26/// assert_eq!(treetn.node_names(), vec![0]);
27/// assert_eq!(site_indices.len(), 1);
28/// assert_eq!(dense.external_indices()[0].id(), site_indices[0].id());
29/// assert_eq!(dense.dims(), vec![2]);
30/// ```
31pub fn tensor_train_to_treetn<T>(
32    tt: &TensorTrain<T>,
33) -> Result<(TreeTN<TensorDynLen, usize>, Vec<DynIndex>)>
34where
35    T: TTScalar + TensorElement + Clone,
36{
37    tensor_train_to_treetn_with_names(tt, (0..tt.len()).collect())
38}
39
40/// Convert a linear-chain simple tensor train into a `TreeTN` with explicit node names.
41///
42/// The returned site indices are ordered by tensor-train site position, not by
43/// sorted node-name order.
44///
45/// # Examples
46///
47/// ```
48/// use tensor4all_simplett::{tensor3_from_data, TensorTrain};
49/// use tensor4all_treetn::tensor_train_to_treetn_with_names;
50///
51/// let tt = TensorTrain::new(vec![
52///     tensor3_from_data(vec![1.0_f64, 2.0], 1, 2, 1),
53/// ]).unwrap();
54///
55/// let (treetn, site_indices) =
56///     tensor_train_to_treetn_with_names(&tt, vec!["site0".to_string()]).unwrap();
57///
58/// assert_eq!(treetn.node_names(), vec!["site0".to_string()]);
59/// assert_eq!(site_indices.len(), 1);
60/// ```
61pub fn tensor_train_to_treetn_with_names<T, V>(
62    tt: &TensorTrain<T>,
63    node_names: Vec<V>,
64) -> Result<(TreeTN<TensorDynLen, V>, Vec<DynIndex>)>
65where
66    T: TTScalar + TensorElement + Clone,
67    V: Clone + std::hash::Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
68{
69    tensor_train_to_treetn_impl(tt, node_names, None)
70}
71
72/// Convert a linear-chain simple tensor train into a `TreeTN` with explicit node names
73/// and caller-provided site indices.
74///
75/// This is useful when downstream code must preserve external site-index identities
76/// across a conversion boundary while still allowing internal bond indices to be
77/// created fresh.
78///
79/// # Examples
80///
81/// ```
82/// use tensor4all_core::DynIndex;
83/// use tensor4all_simplett::{tensor3_from_data, TensorTrain};
84/// use tensor4all_treetn::tensor_train_to_treetn_with_names_and_site_indices;
85///
86/// let tt = TensorTrain::new(vec![
87///     tensor3_from_data(vec![1.0_f64, 2.0], 1, 2, 1),
88/// ]).unwrap();
89/// let site = DynIndex::new_dyn(2);
90///
91/// let treetn = tensor_train_to_treetn_with_names_and_site_indices(
92///     &tt,
93///     vec!["site0".to_string()],
94///     vec![site],
95/// ).unwrap();
96///
97/// assert_eq!(treetn.node_names(), vec!["site0".to_string()]);
98/// ```
99pub fn tensor_train_to_treetn_with_names_and_site_indices<T, V>(
100    tt: &TensorTrain<T>,
101    node_names: Vec<V>,
102    site_indices: Vec<DynIndex>,
103) -> Result<TreeTN<TensorDynLen, V>>
104where
105    T: TTScalar + TensorElement + Clone,
106    V: Clone + std::hash::Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
107{
108    let (treetn, _) = tensor_train_to_treetn_impl(tt, node_names, Some(site_indices))?;
109    Ok(treetn)
110}
111
112fn tensor_train_to_treetn_impl<T, V>(
113    tt: &TensorTrain<T>,
114    node_names: Vec<V>,
115    site_indices: Option<Vec<DynIndex>>,
116) -> Result<(TreeTN<TensorDynLen, V>, Vec<DynIndex>)>
117where
118    T: TTScalar + TensorElement + Clone,
119    V: Clone + std::hash::Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
120{
121    ensure!(
122        tt.len() == node_names.len(),
123        "tensor_train_to_treetn: node_names length {} must match tensor-train length {}",
124        node_names.len(),
125        tt.len()
126    );
127
128    if tt.is_empty() {
129        let site_indices = site_indices.unwrap_or_default();
130        ensure!(
131            site_indices.is_empty(),
132            "tensor_train_to_treetn: empty tensor train requires zero site indices"
133        );
134        return Ok((TreeTN::new(), Vec::new()));
135    }
136
137    let site_indices = match site_indices {
138        Some(indices) => {
139            ensure!(
140                indices.len() == tt.len(),
141                "tensor_train_to_treetn: site_indices length {} must match tensor-train length {}",
142                indices.len(),
143                tt.len()
144            );
145            for (site, index) in indices.iter().enumerate() {
146                ensure!(
147                    index.dim() == tt.site_dim(site),
148                    "tensor_train_to_treetn: site index {} has dim {} but tensor-train site {} has dim {}",
149                    site,
150                    index.dim(),
151                    site,
152                    tt.site_dim(site)
153                );
154            }
155            indices
156        }
157        None => tt.site_dims().into_iter().map(DynIndex::new_dyn).collect(),
158    };
159    let bond_indices: Vec<DynIndex> = tt
160        .link_dims()
161        .into_iter()
162        .map(DynIndex::new_bond)
163        .collect::<Result<_>>()?;
164    let nsites = tt.len();
165
166    let mut tensors = Vec::with_capacity(nsites);
167    for site in 0..nsites {
168        let site_tensor = tt.site_tensor(site);
169        let tensor = if nsites == 1 {
170            let data = single_site_data(site_tensor);
171            TensorDynLen::from_dense(vec![site_indices[site].clone()], data)?
172        } else if site == 0 {
173            let data = left_boundary_data(site_tensor);
174            TensorDynLen::from_dense(
175                vec![site_indices[site].clone(), bond_indices[site].clone()],
176                data,
177            )?
178        } else if site + 1 == nsites {
179            let data = right_boundary_data(site_tensor);
180            TensorDynLen::from_dense(
181                vec![bond_indices[site - 1].clone(), site_indices[site].clone()],
182                data,
183            )?
184        } else {
185            let data = middle_site_data(site_tensor);
186            TensorDynLen::from_dense(
187                vec![
188                    bond_indices[site - 1].clone(),
189                    site_indices[site].clone(),
190                    bond_indices[site].clone(),
191                ],
192                data,
193            )?
194        };
195        tensors.push(tensor);
196    }
197
198    let treetn = TreeTN::from_tensors(tensors, node_names)?;
199    Ok((treetn, site_indices))
200}
201
202fn single_site_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
203where
204    T: TTScalar + Clone,
205{
206    let mut data = Vec::with_capacity(tensor.site_dim());
207    for s in 0..tensor.site_dim() {
208        data.push(*tensor.get3(0, s, 0));
209    }
210    data
211}
212
213fn left_boundary_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
214where
215    T: TTScalar + Clone,
216{
217    let mut data = Vec::with_capacity(tensor.site_dim() * tensor.right_dim());
218    for r in 0..tensor.right_dim() {
219        for s in 0..tensor.site_dim() {
220            data.push(*tensor.get3(0, s, r));
221        }
222    }
223    data
224}
225
226fn right_boundary_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
227where
228    T: TTScalar + Clone,
229{
230    let mut data = Vec::with_capacity(tensor.left_dim() * tensor.site_dim());
231    for s in 0..tensor.site_dim() {
232        for l in 0..tensor.left_dim() {
233            data.push(*tensor.get3(l, s, 0));
234        }
235    }
236    data
237}
238
239fn middle_site_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
240where
241    T: TTScalar + Clone,
242{
243    let mut data = Vec::with_capacity(tensor.left_dim() * tensor.site_dim() * tensor.right_dim());
244    for r in 0..tensor.right_dim() {
245        for s in 0..tensor.site_dim() {
246            for l in 0..tensor.left_dim() {
247                data.push(*tensor.get3(l, s, r));
248            }
249        }
250    }
251    data
252}