tensor4all_treetn/
simplett_bridge.rs1use anyhow::{ensure, Result};
2use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorElement};
3use tensor4all_simplett::{AbstractTensorTrain, TTScalar, Tensor3Ops, TensorTrain};
4
5use crate::TreeTN;
6
7pub fn tensor_train_to_treetn<T>(
32 tt: &TensorTrain<T>,
33) -> Result<(TreeTN<TensorDynLen, usize>, Vec<DynIndex>)>
34where
35 T: TTScalar + TensorElement + Clone,
36{
37 tensor_train_to_treetn_with_names(tt, (0..tt.len()).collect())
38}
39
40pub fn tensor_train_to_treetn_with_names<T, V>(
62 tt: &TensorTrain<T>,
63 node_names: Vec<V>,
64) -> Result<(TreeTN<TensorDynLen, V>, Vec<DynIndex>)>
65where
66 T: TTScalar + TensorElement + Clone,
67 V: Clone + std::hash::Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
68{
69 tensor_train_to_treetn_impl(tt, node_names, None)
70}
71
72pub fn tensor_train_to_treetn_with_names_and_site_indices<T, V>(
100 tt: &TensorTrain<T>,
101 node_names: Vec<V>,
102 site_indices: Vec<DynIndex>,
103) -> Result<TreeTN<TensorDynLen, V>>
104where
105 T: TTScalar + TensorElement + Clone,
106 V: Clone + std::hash::Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
107{
108 let (treetn, _) = tensor_train_to_treetn_impl(tt, node_names, Some(site_indices))?;
109 Ok(treetn)
110}
111
112fn tensor_train_to_treetn_impl<T, V>(
113 tt: &TensorTrain<T>,
114 node_names: Vec<V>,
115 site_indices: Option<Vec<DynIndex>>,
116) -> Result<(TreeTN<TensorDynLen, V>, Vec<DynIndex>)>
117where
118 T: TTScalar + TensorElement + Clone,
119 V: Clone + std::hash::Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
120{
121 ensure!(
122 tt.len() == node_names.len(),
123 "tensor_train_to_treetn: node_names length {} must match tensor-train length {}",
124 node_names.len(),
125 tt.len()
126 );
127
128 if tt.is_empty() {
129 let site_indices = site_indices.unwrap_or_default();
130 ensure!(
131 site_indices.is_empty(),
132 "tensor_train_to_treetn: empty tensor train requires zero site indices"
133 );
134 return Ok((TreeTN::new(), Vec::new()));
135 }
136
137 let site_indices = match site_indices {
138 Some(indices) => {
139 ensure!(
140 indices.len() == tt.len(),
141 "tensor_train_to_treetn: site_indices length {} must match tensor-train length {}",
142 indices.len(),
143 tt.len()
144 );
145 for (site, index) in indices.iter().enumerate() {
146 ensure!(
147 index.dim() == tt.site_dim(site),
148 "tensor_train_to_treetn: site index {} has dim {} but tensor-train site {} has dim {}",
149 site,
150 index.dim(),
151 site,
152 tt.site_dim(site)
153 );
154 }
155 indices
156 }
157 None => tt.site_dims().into_iter().map(DynIndex::new_dyn).collect(),
158 };
159 let bond_indices: Vec<DynIndex> = tt
160 .link_dims()
161 .into_iter()
162 .map(DynIndex::new_bond)
163 .collect::<Result<_>>()?;
164 let nsites = tt.len();
165
166 let mut tensors = Vec::with_capacity(nsites);
167 for site in 0..nsites {
168 let site_tensor = tt.site_tensor(site);
169 let tensor = if nsites == 1 {
170 let data = single_site_data(site_tensor);
171 TensorDynLen::from_dense(vec![site_indices[site].clone()], data)?
172 } else if site == 0 {
173 let data = left_boundary_data(site_tensor);
174 TensorDynLen::from_dense(
175 vec![site_indices[site].clone(), bond_indices[site].clone()],
176 data,
177 )?
178 } else if site + 1 == nsites {
179 let data = right_boundary_data(site_tensor);
180 TensorDynLen::from_dense(
181 vec![bond_indices[site - 1].clone(), site_indices[site].clone()],
182 data,
183 )?
184 } else {
185 let data = middle_site_data(site_tensor);
186 TensorDynLen::from_dense(
187 vec![
188 bond_indices[site - 1].clone(),
189 site_indices[site].clone(),
190 bond_indices[site].clone(),
191 ],
192 data,
193 )?
194 };
195 tensors.push(tensor);
196 }
197
198 let treetn = TreeTN::from_tensors(tensors, node_names)?;
199 Ok((treetn, site_indices))
200}
201
202fn single_site_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
203where
204 T: TTScalar + Clone,
205{
206 let mut data = Vec::with_capacity(tensor.site_dim());
207 for s in 0..tensor.site_dim() {
208 data.push(*tensor.get3(0, s, 0));
209 }
210 data
211}
212
213fn left_boundary_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
214where
215 T: TTScalar + Clone,
216{
217 let mut data = Vec::with_capacity(tensor.site_dim() * tensor.right_dim());
218 for r in 0..tensor.right_dim() {
219 for s in 0..tensor.site_dim() {
220 data.push(*tensor.get3(0, s, r));
221 }
222 }
223 data
224}
225
226fn right_boundary_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
227where
228 T: TTScalar + Clone,
229{
230 let mut data = Vec::with_capacity(tensor.left_dim() * tensor.site_dim());
231 for s in 0..tensor.site_dim() {
232 for l in 0..tensor.left_dim() {
233 data.push(*tensor.get3(l, s, 0));
234 }
235 }
236 data
237}
238
239fn middle_site_data<T>(tensor: &tensor4all_simplett::Tensor3<T>) -> Vec<T>
240where
241 T: TTScalar + Clone,
242{
243 let mut data = Vec::with_capacity(tensor.left_dim() * tensor.site_dim() * tensor.right_dim());
244 for r in 0..tensor.right_dim() {
245 for s in 0..tensor.site_dim() {
246 for l in 0..tensor.left_dim() {
247 data.push(*tensor.get3(l, s, r));
248 }
249 }
250 }
251 data
252}