// tensor4all_treetn/treetn/transform.rs
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::Hash;

use anyhow::{Context, Result};
use petgraph::stable_graph::NodeIndex;

use tensor4all_core::{
    AllowedPairs, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, TensorLike,
};

use super::TreeTN;
use crate::options::SplitOptions;
use crate::site_index_network::SiteIndexNetwork;
21impl<T, V> TreeTN<T, V>
22where
23 T: TensorLike,
24 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
25 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
26{
27 pub fn fuse_to<TargetV>(
60 &self,
61 target: &SiteIndexNetwork<TargetV, T::Index>,
62 ) -> Result<TreeTN<T, TargetV>>
63 where
64 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
65 {
66 let mut site_to_current_node: HashMap<<T::Index as IndexLike>::Id, V> = HashMap::new();
68 for current_node_name in self.node_names() {
69 if let Some(site_space) = self.site_space(¤t_node_name) {
70 for site_idx in site_space {
71 site_to_current_node.insert(site_idx.id().clone(), current_node_name.clone());
72 }
73 }
74 }
75
76 let mut target_to_current: HashMap<TargetV, HashSet<V>> = HashMap::new();
79
80 for target_node_name in target.node_names() {
81 let target_site_space = target.site_space(target_node_name).ok_or_else(|| {
82 anyhow::anyhow!("Target node {:?} has no site space", target_node_name)
83 })?;
84
85 let mut current_nodes_for_target: HashSet<V> = HashSet::new();
86 for target_site_idx in target_site_space {
87 if let Some(current_node) = site_to_current_node.get(target_site_idx.id()) {
88 current_nodes_for_target.insert(current_node.clone());
89 }
90 }
91
92 if current_nodes_for_target.is_empty() {
93 return Err(anyhow::anyhow!(
94 "Target node {:?} has site indices not found in current TreeTN",
95 target_node_name
96 ))
97 .context("fuse_to: incompatible target structure");
98 }
99
100 target_to_current.insert(target_node_name.clone(), current_nodes_for_target);
101 }
102
103 let mut current_to_target: HashMap<V, TargetV> = HashMap::new();
105 for (target_name, current_nodes) in &target_to_current {
106 for current_node in current_nodes {
107 if let Some(existing_target) = current_to_target.get(current_node) {
108 return Err(anyhow::anyhow!(
109 "Current node {:?} maps to multiple target nodes: {:?} and {:?}",
110 current_node,
111 existing_target,
112 target_name
113 ))
114 .context("fuse_to: ambiguous mapping");
115 }
116 current_to_target.insert(current_node.clone(), target_name.clone());
117 }
118 }
119
120 for current_name in self.node_names() {
122 if !current_to_target.contains_key(¤t_name) {
123 return Err(anyhow::anyhow!(
124 "Current node {:?} has no corresponding target node",
125 current_name
126 ))
127 .context("fuse_to: missing target for current node");
128 }
129 }
130
131 let mut result_tensors: HashMap<TargetV, T> = HashMap::new();
133
134 for (target_name, current_nodes) in &target_to_current {
135 let contracted = self.contract_node_group(current_nodes).with_context(|| {
136 format!(
137 "fuse_to: failed to contract nodes for target {:?}",
138 target_name
139 )
140 })?;
141 result_tensors.insert(target_name.clone(), contracted);
142 }
143
144 let mut target_names: Vec<TargetV> = target.node_names().into_iter().cloned().collect();
147 target_names.sort();
148
149 let tensors: Vec<T> = target_names
150 .iter()
151 .map(|name| result_tensors.remove(name).unwrap())
152 .collect();
153
154 let result = TreeTN::<T, TargetV>::from_tensors(tensors, target_names)
155 .context("fuse_to: failed to build result TreeTN")?;
156
157 Ok(result)
158 }
159
160 fn contract_node_group(&self, nodes: &HashSet<V>) -> Result<T>
166 where
167 V: Ord,
168 {
169 if nodes.is_empty() {
170 return Err(anyhow::anyhow!("Cannot contract empty node group"));
171 }
172
173 let node_indices: HashSet<NodeIndex> = nodes
175 .iter()
176 .filter_map(|name| self.graph.node_index(name))
177 .collect();
178
179 if node_indices.len() != nodes.len() {
180 return Err(anyhow::anyhow!(
181 "Some nodes not found in graph: expected {} nodes, found {}",
182 nodes.len(),
183 node_indices.len()
184 ));
185 }
186
187 if nodes.len() == 1 {
189 let node_name = nodes.iter().next().unwrap();
190 let node_idx = self.graph.node_index(node_name).unwrap();
191 return self
192 .tensor(node_idx)
193 .cloned()
194 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name));
195 }
196
197 if !self.site_index_network.is_connected_subset(&node_indices) {
199 return Err(anyhow::anyhow!(
200 "Nodes to contract do not form a connected subtree"
201 ));
202 }
203
204 let root_name = nodes.iter().min().unwrap();
206 let root_idx = self.graph.node_index(root_name).unwrap();
207
208 let edges = self
210 .site_index_network
211 .edges_to_canonicalize(None, root_idx);
212
213 let internal_edges: Vec<(NodeIndex, NodeIndex)> = edges
215 .iter()
216 .filter(|(from, to)| node_indices.contains(from) && node_indices.contains(to))
217 .cloned()
218 .collect();
219
220 let mut tensors: HashMap<NodeIndex, T> = node_indices
222 .iter()
223 .filter_map(|&idx| self.tensor(idx).cloned().map(|t| (idx, t)))
224 .collect();
225
226 for (from, to) in internal_edges {
228 let from_tensor = tensors
229 .remove(&from)
230 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", from))?;
231 let to_tensor = tensors
232 .remove(&to)
233 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", to))?;
234
235 let contracted = T::contract(&[&to_tensor, &from_tensor], AllowedPairs::All)
238 .map_err(|e| anyhow::anyhow!("Failed to contract tensors: {}", e))?;
239
240 tensors.insert(to, contracted);
241 }
242
243 tensors
245 .remove(&root_idx)
246 .ok_or_else(|| anyhow::anyhow!("Contraction produced no result at root"))
247 }
248
249 pub fn split_to<TargetV>(
287 &self,
288 target: &SiteIndexNetwork<TargetV, T::Index>,
289 options: &SplitOptions,
290 ) -> Result<TreeTN<T, TargetV>>
291 where
292 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
293 {
294 let mut site_to_target: HashMap<<T::Index as IndexLike>::Id, TargetV> = HashMap::new();
296 for target_node_name in target.node_names() {
297 if let Some(site_space) = target.site_space(target_node_name) {
298 for site_idx in site_space {
299 site_to_target.insert(site_idx.id().clone(), target_node_name.clone());
300 }
301 }
302 }
303
304 let mut current_to_targets: HashMap<V, HashSet<TargetV>> = HashMap::new();
307 for current_node_name in self.node_names() {
308 if let Some(site_space) = self.site_space(¤t_node_name) {
309 let mut targets_for_node: HashSet<TargetV> = HashSet::new();
310 for site_idx in site_space {
311 if let Some(target_name) = site_to_target.get(site_idx.id()) {
312 targets_for_node.insert(target_name.clone());
313 } else {
314 return Err(anyhow::anyhow!(
315 "Site index {:?} in current node {:?} has no corresponding target node",
316 site_idx.id(),
317 current_node_name
318 ))
319 .context("split_to: incompatible target structure");
320 }
321 }
322 current_to_targets.insert(current_node_name.clone(), targets_for_node);
323 }
324 }
325
326 let mut result_tensors: Vec<(TargetV, T)> = Vec::new();
329
330 for current_node_name in self.node_names() {
331 let node_idx = self
332 .node_index(¤t_node_name)
333 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", current_node_name))?;
334 let tensor = self
335 .tensor(node_idx)
336 .ok_or_else(|| anyhow::anyhow!("Tensor not found for {:?}", current_node_name))?;
337
338 let targets_for_node = current_to_targets.get(¤t_node_name).ok_or_else(|| {
339 anyhow::anyhow!("No target mapping for node {:?}", current_node_name)
340 })?;
341
342 if targets_for_node.len() == 1 {
343 let target_name = targets_for_node.iter().next().unwrap().clone();
345 result_tensors.push((target_name, tensor.clone()));
346 } else {
347 let split_tensors = self
349 .split_tensor_for_targets(tensor, &site_to_target)
350 .with_context(|| {
351 format!("split_to: failed to split node {:?}", current_node_name)
352 })?;
353 result_tensors.extend(split_tensors);
354 }
355 }
356
357 result_tensors.sort_by(|(a, _), (b, _)| a.cmp(b));
360
361 let names: Vec<TargetV> = result_tensors
362 .iter()
363 .map(|(name, _)| name.clone())
364 .collect();
365 let tensors: Vec<T> = result_tensors.into_iter().map(|(_, t)| t).collect();
366
367 let result = TreeTN::<T, TargetV>::from_tensors(tensors, names)
368 .context("split_to: failed to build result TreeTN")?;
369
370 if options.final_sweep {
372 let center = result.node_names().into_iter().min().ok_or_else(|| {
374 anyhow::anyhow!("split_to: no nodes in result for truncation sweep")
375 })?;
376
377 let truncation_options = crate::TruncationOptions {
378 form: options.form,
379 truncation: options.truncation,
380 };
381
382 return result
383 .truncate([center], truncation_options)
384 .context("split_to: truncation sweep failed");
385 }
386
387 Ok(result)
388 }
389
390 fn split_tensor_for_targets<TargetV>(
397 &self,
398 tensor: &T,
399 site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
400 ) -> Result<Vec<(TargetV, T)>>
401 where
402 TargetV: Clone + Hash + Eq + Ord + std::fmt::Debug,
403 {
404 let mut partition: HashMap<TargetV, HashSet<<T::Index as IndexLike>::Id>> = HashMap::new();
406 for idx in tensor.external_indices() {
407 if let Some(target_name) = site_to_target.get(idx.id()) {
408 partition
409 .entry(target_name.clone())
410 .or_default()
411 .insert(idx.id().clone());
412 }
413 }
415
416 let mut target_names: Vec<TargetV> = partition.keys().cloned().collect();
418 target_names.sort();
419
420 if target_names.len() <= 1 {
421 let target_name = target_names
423 .first()
424 .cloned()
425 .ok_or_else(|| anyhow::anyhow!("No site indices found in tensor"))?;
426 return Ok(vec![(target_name, tensor.clone())]);
427 }
428
429 let mut remaining_tensor = tensor.clone();
431 let mut result: Vec<(TargetV, T)> = Vec::new();
432
433 for target_name in target_names.iter().take(target_names.len() - 1) {
435 let site_ids_for_target = partition.get(target_name).unwrap();
436
437 let left_inds: Vec<_> = remaining_tensor
439 .external_indices()
440 .iter()
441 .filter(|idx| site_ids_for_target.contains(idx.id()))
442 .cloned()
443 .collect();
444
445 if left_inds.is_empty() {
446 continue;
447 }
448
449 let factorize_options = FactorizeOptions {
451 alg: FactorizeAlg::QR,
452 canonical: Canonical::Left,
453 max_rank: None,
454 svd_policy: None,
455 qr_rtol: None,
456 };
457
458 let factorize_result = remaining_tensor
459 .factorize(&left_inds, &factorize_options)
460 .map_err(|e| anyhow::anyhow!("Factorization failed: {:?}", e))?;
461
462 result.push((target_name.clone(), factorize_result.left));
464
465 remaining_tensor = factorize_result.right;
467 }
468
469 let last_target = target_names.last().unwrap().clone();
471 result.push((last_target, remaining_tensor));
472
473 Ok(result)
474 }
475}
476
// Unit tests for this module, loaded from the `tests` submodule file and
// compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;