1use std::collections::{HashMap, HashSet};
17use std::hash::Hash;
18
19use crate::node_name_network::NodeNameNetwork;
20use anyhow::{bail, Context, Result};
21use petgraph::stable_graph::NodeIndex;
22use tensor4all_core::{IndexLike, TensorLike};
23
24use super::TreeTN;
25use crate::{RestructureOptions, SiteIndexNetwork};
26
/// Node name used in the intermediate network of a split-then-fuse plan.
///
/// `build_split_then_fuse_target` creates one of these for every group of
/// site indices of a current node that shares the same target node.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct FragmentNode<CurrentV, TargetV> {
    /// Name of the current node this fragment is carved out of.
    current: CurrentV,
    /// Position of this fragment among the fragments of the same `current` node.
    split_rank: usize,
    /// Name of the target node whose site indices this fragment holds.
    target: TargetV,
}
33
/// Site index network whose nodes are [`FragmentNode`]s; it is the
/// intermediate target handed to the split phase of a split-then-fuse plan.
type SplitThenFuseTarget<CurrentV, TargetV, I> =
    SiteIndexNetwork<FragmentNode<CurrentV, TargetV>, I>;
36
/// The strategy selected by `build_plan` and carried out by `execute_plan`.
#[derive(Debug, Clone)]
enum RestructurePlanKind<CurrentV, TargetV, I>
where
    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
    I: IndexLike,
    I::Id: Eq + Hash,
{
    /// Executed as a single `fuse_to(target)` call.
    FuseOnly,
    /// Executed as a single `split_to(target, ..)` call.
    SplitOnly,
    /// Executed as `swap_site_indices` followed by `fuse_to(target)`.
    SwapOnly {
        /// Current node each site index id is moved to during the swap phase.
        target_assignment: HashMap<I::Id, CurrentV>,
    },
    /// Executed as `swap_site_indices` followed by `fuse_to(target)`.
    SwapThenFuse {
        /// Current node each site index id is moved to during the swap phase.
        target_assignment: HashMap<I::Id, CurrentV>,
    },
    /// Executed as `split_to(split_target, ..)` followed by `fuse_to(target)`.
    SplitThenFuse {
        /// Intermediate fragment network used as the target of the split phase.
        split_target: Box<SplitThenFuseTarget<CurrentV, TargetV, I>>,
    },
}
57
/// Result of planning a restructure: currently just the chosen strategy.
#[derive(Debug, Clone)]
struct RestructurePlan<CurrentV, TargetV, I>
where
    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
    I: IndexLike,
    I::Id: Eq + Hash,
{
    /// Strategy (and its payload) that `execute_plan` dispatches on.
    kind: RestructurePlanKind<CurrentV, TargetV, I>,
}
68
69fn collect_site_targets<T, TargetV>(
70 target: &SiteIndexNetwork<TargetV, T::Index>,
71) -> Result<HashMap<<T::Index as IndexLike>::Id, TargetV>>
72where
73 T: TensorLike,
74 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
75 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
76{
77 let mut site_to_target = HashMap::new();
78 for target_node_name in target.node_names() {
79 let site_space = target.site_space(target_node_name).ok_or_else(|| {
80 anyhow::anyhow!(
81 "restructure_to: target node {:?} has no registered site space",
82 target_node_name
83 )
84 })?;
85 for site_idx in site_space {
86 let existing = site_to_target.insert(site_idx.id().clone(), target_node_name.clone());
87 if let Some(previous_target) = existing {
88 bail!(
89 "restructure_to: site index {:?} appears in both target nodes {:?} and {:?}",
90 site_idx.id(),
91 previous_target,
92 target_node_name
93 );
94 }
95 }
96 }
97 Ok(site_to_target)
98}
99
100fn collect_current_site_ids<T, CurrentV>(
101 current: &SiteIndexNetwork<CurrentV, T::Index>,
102) -> Result<HashSet<<T::Index as IndexLike>::Id>>
103where
104 T: TensorLike,
105 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
106 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
107{
108 let mut site_ids = HashSet::new();
109 for current_node_name in current.node_names() {
110 let site_space = current.site_space(current_node_name).ok_or_else(|| {
111 anyhow::anyhow!(
112 "restructure_to: current node {:?} has no registered site space",
113 current_node_name
114 )
115 })?;
116 for site_idx in site_space {
117 site_ids.insert(site_idx.id().clone());
118 }
119 }
120 Ok(site_ids)
121}
122
123fn current_nodes_map_uniquely_to_targets<T, CurrentV, TargetV>(
124 current: &SiteIndexNetwork<CurrentV, T::Index>,
125 site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
126) -> Result<bool>
127where
128 T: TensorLike,
129 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
130 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
131 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
132{
133 for current_node_name in current.node_names() {
134 let site_space = current.site_space(current_node_name).ok_or_else(|| {
135 anyhow::anyhow!(
136 "restructure_to: current node {:?} has no registered site space",
137 current_node_name
138 )
139 })?;
140 let target_names: HashSet<_> = site_space
141 .iter()
142 .map(|site_idx| {
143 site_to_target
144 .get(site_idx.id())
145 .cloned()
146 .ok_or_else(|| {
147 anyhow::anyhow!(
148 "restructure_to: site index {:?} is present in the current network but missing from the target",
149 site_idx.id()
150 )
151 })
152 })
153 .collect::<Result<_>>()?;
154 if target_names.len() > 1 {
155 return Ok(false);
156 }
157 }
158 Ok(true)
159}
160
161fn collect_site_currents<T, CurrentV>(
162 current: &SiteIndexNetwork<CurrentV, T::Index>,
163) -> Result<HashMap<<T::Index as IndexLike>::Id, CurrentV>>
164where
165 T: TensorLike,
166 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
167 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
168{
169 let mut site_to_current = HashMap::new();
170 for current_node_name in current.node_names() {
171 let site_space = current.site_space(current_node_name).ok_or_else(|| {
172 anyhow::anyhow!(
173 "restructure_to: current node {:?} has no registered site space",
174 current_node_name
175 )
176 })?;
177 for site_idx in site_space {
178 let existing = site_to_current.insert(site_idx.id().clone(), current_node_name.clone());
179 if let Some(previous_current) = existing {
180 bail!(
181 "restructure_to: site index {:?} appears in both current nodes {:?} and {:?}",
182 site_idx.id(),
183 previous_current,
184 current_node_name
185 );
186 }
187 }
188 }
189 Ok(site_to_current)
190}
191
192fn target_nodes_map_uniquely_to_currents<T, CurrentV, TargetV>(
193 target: &SiteIndexNetwork<TargetV, T::Index>,
194 site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
195) -> Result<bool>
196where
197 T: TensorLike,
198 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
199 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
200 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
201{
202 for target_node_name in target.node_names() {
203 let site_space = target.site_space(target_node_name).ok_or_else(|| {
204 anyhow::anyhow!(
205 "restructure_to: target node {:?} has no registered site space",
206 target_node_name
207 )
208 })?;
209 let current_names: HashSet<_> = site_space
210 .iter()
211 .map(|site_idx| {
212 site_to_current
213 .get(site_idx.id())
214 .cloned()
215 .ok_or_else(|| {
216 anyhow::anyhow!(
217 "restructure_to: site index {:?} is present in the target but missing from the current network",
218 site_idx.id()
219 )
220 })
221 })
222 .collect::<Result<_>>()?;
223 if current_names.len() > 1 {
224 return Ok(false);
225 }
226 }
227 Ok(true)
228}
229
230fn target_nodes_span_connected_currents<T, CurrentV, TargetV>(
231 current: &SiteIndexNetwork<CurrentV, T::Index>,
232 target: &SiteIndexNetwork<TargetV, T::Index>,
233 site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
234) -> Result<bool>
235where
236 T: TensorLike,
237 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
238 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
239 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
240{
241 for target_node_name in target.node_names() {
242 let site_space = target.site_space(target_node_name).ok_or_else(|| {
243 anyhow::anyhow!(
244 "restructure_to: target node {:?} has no registered site space",
245 target_node_name
246 )
247 })?;
248 let current_nodes: HashSet<_> = site_space
249 .iter()
250 .map(|site_idx| {
251 let current_name = site_to_current.get(site_idx.id()).ok_or_else(|| {
252 anyhow::anyhow!(
253 "restructure_to: site index {:?} is present in the target but missing from the current network",
254 site_idx.id()
255 )
256 })?;
257 current.node_index(current_name).ok_or_else(|| {
258 anyhow::anyhow!(
259 "restructure_to: current node {:?} is missing from the topology",
260 current_name
261 )
262 })
263 })
264 .collect::<Result<_>>()?;
265 if !current.is_connected_subset(¤t_nodes) {
266 return Ok(false);
267 }
268 }
269
270 Ok(true)
271}
272
273fn collect_shared_targets<T, CurrentV, TargetV>(
274 target: &SiteIndexNetwork<TargetV, T::Index>,
275 site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
276) -> Result<HashSet<TargetV>>
277where
278 T: TensorLike,
279 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
280 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
281 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
282{
283 let mut shared_targets = HashSet::new();
284 for target_node_name in target.node_names() {
285 let site_space = target.site_space(target_node_name).ok_or_else(|| {
286 anyhow::anyhow!(
287 "restructure_to: target node {:?} has no registered site space",
288 target_node_name
289 )
290 })?;
291 let current_names: HashSet<_> = site_space
292 .iter()
293 .map(|site_idx| {
294 site_to_current
295 .get(site_idx.id())
296 .cloned()
297 .ok_or_else(|| {
298 anyhow::anyhow!(
299 "restructure_to: site index {:?} is present in the target but missing from the current network",
300 site_idx.id()
301 )
302 })
303 })
304 .collect::<Result<_>>()?;
305 if current_names.len() > 1 {
306 shared_targets.insert(target_node_name.clone());
307 }
308 }
309 Ok(shared_targets)
310}
311
312fn build_split_then_fuse_target<T, CurrentV, TargetV>(
313 current: &SiteIndexNetwork<CurrentV, T::Index>,
314 target: &SiteIndexNetwork<TargetV, T::Index>,
315 site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
316 site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
317) -> Result<Option<SplitThenFuseTarget<CurrentV, TargetV, T::Index>>>
318where
319 T: TensorLike,
320 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
321 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
322 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
323{
324 if !target_nodes_span_connected_currents::<T, CurrentV, TargetV>(
325 current,
326 target,
327 site_to_current,
328 )? {
329 return Ok(None);
330 }
331
332 let shared_targets = collect_shared_targets::<T, CurrentV, TargetV>(target, site_to_current)?;
333 let mut split_target = SiteIndexNetwork::with_capacity(current.node_count(), 0);
334 let mut current_node_names: Vec<_> = current.node_names().into_iter().cloned().collect();
335 current_node_names.sort();
336
337 for current_node_name in current_node_names {
338 let site_space = current.site_space(¤t_node_name).ok_or_else(|| {
339 anyhow::anyhow!(
340 "restructure_to: current node {:?} has no registered site space",
341 current_node_name
342 )
343 })?;
344 let mut fragments: HashMap<TargetV, HashSet<T::Index>> = HashMap::new();
345 for site_idx in site_space {
346 let target_node_name = site_to_target.get(site_idx.id()).cloned().ok_or_else(|| {
347 anyhow::anyhow!(
348 "restructure_to: site index {:?} is present in the current network but missing from the target",
349 site_idx.id()
350 )
351 })?;
352 fragments
353 .entry(target_node_name)
354 .or_default()
355 .insert(site_idx.clone());
356 }
357
358 let shared_targets_here: Vec<_> = fragments
359 .keys()
360 .filter(|target_name| shared_targets.contains(*target_name))
361 .cloned()
362 .collect();
363 if shared_targets_here.len() > 1 {
364 return Ok(None);
365 }
366 let boundary_target = shared_targets_here.first().cloned();
367
368 let mut fragments: Vec<_> = fragments.into_iter().collect();
369 fragments.sort_by(|(left_name, _), (right_name, _)| {
370 let left_is_boundary = boundary_target.as_ref() == Some(left_name);
371 let right_is_boundary = boundary_target.as_ref() == Some(right_name);
372 left_is_boundary
373 .cmp(&right_is_boundary)
374 .then_with(|| left_name.cmp(right_name))
375 });
376
377 for (split_rank, (target_node_name, fragment_site_space)) in
378 fragments.into_iter().enumerate()
379 {
380 split_target
381 .add_node(
382 FragmentNode {
383 current: current_node_name.clone(),
384 split_rank,
385 target: target_node_name,
386 },
387 fragment_site_space,
388 )
389 .map_err(anyhow::Error::msg)?;
390 }
391 }
392
393 Ok(Some(split_target))
394}
395
396fn ordered_path_nodes<V>(topology: &NodeNameNetwork<V>) -> Option<Vec<V>>
397where
398 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
399{
400 let graph = topology.graph();
401 match topology.node_count() {
402 0 => return Some(Vec::new()),
403 1 => {
404 let node = graph.node_indices().next()?;
405 return Some(vec![topology.node_name(node)?.clone()]);
406 }
407 _ => {}
408 }
409
410 let mut leaves = Vec::new();
411 for node in graph.node_indices() {
412 let degree = graph.neighbors(node).count();
413 if degree > 2 {
414 return None;
415 }
416 if degree == 1 {
417 leaves.push(node);
418 }
419 }
420 if leaves.len() != 2 {
421 return None;
422 }
423
424 leaves.sort_by_key(|node| topology.node_name(*node).cloned());
425 let mut ordered = Vec::with_capacity(topology.node_count());
426 let mut previous = None;
427 let mut current = *leaves.first()?;
428
429 loop {
430 ordered.push(topology.node_name(current)?.clone());
431 let next = graph
432 .neighbors(current)
433 .find(|neighbor| Some(*neighbor) != previous);
434 let Some(next) = next else {
435 break;
436 };
437 previous = Some(current);
438 current = next;
439 }
440
441 if ordered.len() == topology.node_count() {
442 Some(ordered)
443 } else {
444 None
445 }
446}
447
448fn build_path_swap_then_fuse_assignment<T, CurrentV, TargetV>(
449 current: &SiteIndexNetwork<CurrentV, T::Index>,
450 target: &SiteIndexNetwork<TargetV, T::Index>,
451 site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
452) -> Result<Option<HashMap<<T::Index as IndexLike>::Id, CurrentV>>>
453where
454 T: TensorLike,
455 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
456 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
457 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
458{
459 let current_path = match ordered_path_nodes(current.topology()) {
460 Some(path) => path,
461 None => return Ok(None),
462 };
463 let target_path = match ordered_path_nodes(target.topology()) {
464 Some(path) => path,
465 None => return Ok(None),
466 };
467
468 let mut contributor_counts: HashMap<TargetV, usize> = HashMap::new();
469 for current_node_name in ¤t_path {
470 let site_space = current.site_space(current_node_name).ok_or_else(|| {
471 anyhow::anyhow!(
472 "restructure_to: current node {:?} has no registered site space",
473 current_node_name
474 )
475 })?;
476 let mut target_names: Vec<_> = site_space
477 .iter()
478 .map(|site_idx| {
479 site_to_target
480 .get(site_idx.id())
481 .cloned()
482 .ok_or_else(|| {
483 anyhow::anyhow!(
484 "restructure_to: site index {:?} is present in the current network but missing from the target",
485 site_idx.id()
486 )
487 })
488 })
489 .collect::<Result<_>>()?;
490 target_names.sort();
491 target_names.dedup();
492 if target_names.len() != 1 {
493 return Ok(None);
494 }
495 let target_name = target_names.into_iter().next().ok_or_else(|| {
496 anyhow::anyhow!(
497 "restructure_to: current node {:?} has no target mapping",
498 current_node_name
499 )
500 })?;
501 *contributor_counts.entry(target_name).or_default() += 1;
502 }
503
504 let total_contributors: usize = contributor_counts.values().sum();
505 if total_contributors != current_path.len() {
506 return Ok(None);
507 }
508
509 let mut contiguous_blocks: HashMap<TargetV, Vec<CurrentV>> = HashMap::new();
510 let mut cursor = 0usize;
511 for target_node_name in &target_path {
512 let block_len = *contributor_counts.get(target_node_name).unwrap_or(&0);
513 if block_len == 0 || cursor + block_len > current_path.len() {
514 return Ok(None);
515 }
516 contiguous_blocks.insert(
517 target_node_name.clone(),
518 current_path[cursor..cursor + block_len].to_vec(),
519 );
520 cursor += block_len;
521 }
522 if cursor != current_path.len() {
523 return Ok(None);
524 }
525
526 let mut target_assignment = HashMap::new();
527 for target_node_name in &target_path {
528 let block = contiguous_blocks.get(target_node_name).ok_or_else(|| {
529 anyhow::anyhow!(
530 "restructure_to: missing contiguous block for target {:?}",
531 target_node_name
532 )
533 })?;
534 let mut site_ids: Vec<_> = target
535 .site_space(target_node_name)
536 .ok_or_else(|| {
537 anyhow::anyhow!(
538 "restructure_to: target node {:?} has no registered site space",
539 target_node_name
540 )
541 })?
542 .iter()
543 .map(|site_idx| site_idx.id().clone())
544 .collect();
545 site_ids.sort();
546 if site_ids.len() < block.len() {
547 return Ok(None);
548 }
549
550 for (position, site_id) in site_ids.into_iter().enumerate() {
551 let block_index = position.min(block.len() - 1);
552 target_assignment.insert(site_id, block[block_index].clone());
553 }
554 }
555
556 Ok(Some(target_assignment))
557}
558
559fn tree_children<V>(
560 topology: &NodeNameNetwork<V>,
561 node: NodeIndex,
562 parent: Option<NodeIndex>,
563) -> Vec<NodeIndex>
564where
565 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
566{
567 topology
568 .graph()
569 .neighbors(node)
570 .filter(|neighbor| Some(*neighbor) != parent)
571 .collect()
572}
573
574fn rooted_signature<V>(
575 topology: &NodeNameNetwork<V>,
576 node: NodeIndex,
577 parent: Option<NodeIndex>,
578 cache: &mut HashMap<(usize, Option<usize>), String>,
579) -> String
580where
581 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
582{
583 let key = (node.index(), parent.map(|p| p.index()));
584 if let Some(signature) = cache.get(&key) {
585 return signature.clone();
586 }
587
588 let mut child_signatures: Vec<String> = tree_children(topology, node, parent)
589 .into_iter()
590 .map(|child| rooted_signature(topology, child, Some(node), cache))
591 .collect();
592 child_signatures.sort();
593
594 let signature = format!("({})", child_signatures.concat());
595 cache.insert(key, signature.clone());
596 signature
597}
598
/// Mutable scratch state threaded through `match_isomorphic_subtrees`.
#[derive(Default)]
struct IsomorphicMatchState {
    /// Memoised `rooted_signature` results for the current topology.
    current_cache: HashMap<(usize, Option<usize>), String>,
    /// Memoised `rooted_signature` results for the target topology.
    target_cache: HashMap<(usize, Option<usize>), String>,
    /// Node pairing found so far, keyed by target node with the current node as value.
    mapping: HashMap<NodeIndex, NodeIndex>,
}
605
606fn match_isomorphic_subtrees<CurrentV, TargetV>(
607 current_topology: &NodeNameNetwork<CurrentV>,
608 target_topology: &NodeNameNetwork<TargetV>,
609 current_node: NodeIndex,
610 current_parent: Option<NodeIndex>,
611 target_node: NodeIndex,
612 target_parent: Option<NodeIndex>,
613 state: &mut IsomorphicMatchState,
614) -> bool
615where
616 CurrentV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
617 TargetV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
618{
619 if let Some(existing) = state.mapping.insert(target_node, current_node) {
620 if existing != current_node {
621 return false;
622 }
623 }
624
625 let current_children = tree_children(current_topology, current_node, current_parent);
626 let target_children = tree_children(target_topology, target_node, target_parent);
627 if current_children.len() != target_children.len() {
628 return false;
629 }
630
631 let mut current_groups: HashMap<String, Vec<NodeIndex>> = HashMap::new();
632 for child in current_children {
633 let signature = rooted_signature(
634 current_topology,
635 child,
636 Some(current_node),
637 &mut state.current_cache,
638 );
639 current_groups.entry(signature).or_default().push(child);
640 }
641
642 let mut target_groups: HashMap<String, Vec<NodeIndex>> = HashMap::new();
643 for child in target_children {
644 let signature = rooted_signature(
645 target_topology,
646 child,
647 Some(target_node),
648 &mut state.target_cache,
649 );
650 target_groups.entry(signature).or_default().push(child);
651 }
652
653 if current_groups.len() != target_groups.len() {
654 return false;
655 }
656
657 let mut signatures: Vec<String> = current_groups.keys().cloned().collect();
658 signatures.sort();
659 for signature in signatures {
660 let mut current_bucket = match current_groups.remove(&signature) {
661 Some(bucket) => bucket,
662 None => return false,
663 };
664 let mut target_bucket = match target_groups.remove(&signature) {
665 Some(bucket) => bucket,
666 None => return false,
667 };
668 if current_bucket.len() != target_bucket.len() {
669 return false;
670 }
671
672 current_bucket.sort_by_key(|node| node.index());
673 target_bucket.sort_by_key(|node| node.index());
674
675 for (current_child, target_child) in current_bucket.into_iter().zip(target_bucket) {
676 if !match_isomorphic_subtrees(
677 current_topology,
678 target_topology,
679 current_child,
680 Some(current_node),
681 target_child,
682 Some(target_node),
683 state,
684 ) {
685 return false;
686 }
687 }
688 }
689
690 true
691}
692
693fn match_tree_topologies<CurrentV, TargetV>(
694 current_topology: &NodeNameNetwork<CurrentV>,
695 target_topology: &NodeNameNetwork<TargetV>,
696) -> Option<HashMap<NodeIndex, NodeIndex>>
697where
698 CurrentV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
699 TargetV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
700{
701 if current_topology.node_count() != target_topology.node_count() {
702 return None;
703 }
704 if current_topology.edge_count() != target_topology.edge_count() {
705 return None;
706 }
707
708 let mut current_roots: Vec<(String, NodeIndex)> = current_topology
709 .graph()
710 .node_indices()
711 .map(|node| {
712 (
713 rooted_signature(current_topology, node, None, &mut HashMap::new()),
714 node,
715 )
716 })
717 .collect();
718 current_roots.sort_by_key(|(signature, node)| (signature.clone(), node.index()));
719
720 let mut target_roots: Vec<(String, NodeIndex)> = target_topology
721 .graph()
722 .node_indices()
723 .map(|node| {
724 (
725 rooted_signature(target_topology, node, None, &mut HashMap::new()),
726 node,
727 )
728 })
729 .collect();
730 target_roots.sort_by_key(|(signature, node)| (signature.clone(), node.index()));
731
732 for (target_signature, target_root) in &target_roots {
733 for (current_signature, current_root) in ¤t_roots {
734 if current_signature != target_signature {
735 continue;
736 }
737
738 let mut state = IsomorphicMatchState::default();
739 if match_isomorphic_subtrees(
740 current_topology,
741 target_topology,
742 *current_root,
743 None,
744 *target_root,
745 None,
746 &mut state,
747 ) && state.mapping.len() == target_topology.node_count()
748 {
749 return Some(state.mapping);
750 }
751 }
752 }
753
754 None
755}
756
757fn build_swap_assignment<T, CurrentV, TargetV>(
758 current: &SiteIndexNetwork<CurrentV, T::Index>,
759 target: &SiteIndexNetwork<TargetV, T::Index>,
760) -> Result<Option<HashMap<<T::Index as IndexLike>::Id, CurrentV>>>
761where
762 T: TensorLike,
763 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
764 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
765 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
766{
767 let topology_mapping = match match_tree_topologies(current.topology(), target.topology()) {
768 Some(mapping) => mapping,
769 None => return Ok(None),
770 };
771
772 let mut assignment = HashMap::new();
773 for target_node in target.graph().node_indices() {
774 let target_name = target.node_name(target_node).ok_or_else(|| {
775 anyhow::anyhow!(
776 "restructure_to: target topology mapping referenced a missing node index {:?}",
777 target_node
778 )
779 })?;
780 let current_node = *topology_mapping.get(&target_node).ok_or_else(|| {
781 anyhow::anyhow!(
782 "restructure_to: target topology mapping did not assign a current node to {:?}",
783 target_name
784 )
785 })?;
786 let current_name = current.node_name(current_node).ok_or_else(|| {
787 anyhow::anyhow!(
788 "restructure_to: current topology mapping referenced a missing node index {:?}",
789 current_node
790 )
791 })?;
792 let site_space = target.site_space(target_name).ok_or_else(|| {
793 anyhow::anyhow!(
794 "restructure_to: target node {:?} has no registered site space",
795 target_name
796 )
797 })?;
798 for site_idx in site_space {
799 assignment.insert(site_idx.id().clone(), current_name.clone());
800 }
801 }
802
803 Ok(Some(assignment))
804}
805
806fn clone_tree<T, V>(tree: &TreeTN<T, V>) -> Result<TreeTN<T, V>>
807where
808 T: TensorLike,
809 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
810{
811 Ok(TreeTN {
812 graph: tree.graph.clone(),
813 canonical_region: tree.canonical_region.clone(),
814 canonical_form: tree.canonical_form,
815 site_index_network: tree.site_index_network.clone(),
816 link_index_network: tree.link_index_network.clone(),
817 ortho_towards: tree.ortho_towards.clone(),
818 })
819}
820
821fn apply_final_truncation<T, V>(
822 tree: TreeTN<T, V>,
823 options: &RestructureOptions,
824) -> Result<TreeTN<T, V>>
825where
826 T: TensorLike,
827 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
828 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
829{
830 let Some(final_truncation) = options.final_truncation else {
831 return Ok(tree);
832 };
833 let center = tree
834 .node_names()
835 .into_iter()
836 .min()
837 .ok_or_else(|| anyhow::anyhow!("restructure_to: cannot truncate an empty network"))?;
838 tree.truncate([center], final_truncation)
839 .context("restructure_to: final truncation")
840}
841
842fn build_plan<T, CurrentV, TargetV>(
843 current: &SiteIndexNetwork<CurrentV, T::Index>,
844 target: &SiteIndexNetwork<TargetV, T::Index>,
845) -> Result<RestructurePlan<CurrentV, TargetV, T::Index>>
846where
847 T: TensorLike,
848 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
849 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
850 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
851{
852 let site_to_target = collect_site_targets::<T, TargetV>(target)?;
853 let site_to_current = collect_site_currents::<T, CurrentV>(current)?;
854 let current_site_ids = collect_current_site_ids::<T, CurrentV>(current)?;
855 let target_site_ids: HashSet<_> = site_to_target.keys().cloned().collect();
856
857 if current_site_ids != target_site_ids {
858 bail!("restructure_to: current and target must contain the same site index ids");
859 }
860
861 if current_nodes_map_uniquely_to_targets::<T, CurrentV, TargetV>(current, &site_to_target)? {
862 if target_nodes_span_connected_currents::<T, CurrentV, TargetV>(
863 current,
864 target,
865 &site_to_current,
866 )? {
867 return Ok(RestructurePlan {
868 kind: RestructurePlanKind::FuseOnly,
869 });
870 }
871
872 if let Some(target_assignment) = build_path_swap_then_fuse_assignment::<T, CurrentV, TargetV>(
873 current,
874 target,
875 &site_to_target,
876 )? {
877 return Ok(RestructurePlan {
878 kind: RestructurePlanKind::SwapThenFuse { target_assignment },
879 });
880 }
881 }
882
883 if target_nodes_map_uniquely_to_currents::<T, CurrentV, TargetV>(target, &site_to_current)? {
884 return Ok(RestructurePlan {
885 kind: RestructurePlanKind::SplitOnly,
886 });
887 }
888
889 if let Some(target_assignment) = build_swap_assignment::<T, CurrentV, TargetV>(current, target)?
890 {
891 return Ok(RestructurePlan {
892 kind: RestructurePlanKind::SwapOnly { target_assignment },
893 });
894 }
895
896 if let Some(split_target) = build_split_then_fuse_target::<T, CurrentV, TargetV>(
897 current,
898 target,
899 &site_to_target,
900 &site_to_current,
901 )? {
902 return Ok(RestructurePlan {
903 kind: RestructurePlanKind::SplitThenFuse {
904 split_target: Box::new(split_target),
905 },
906 });
907 }
908
909 bail!(
910 "restructure_to: planner placeholder only; split/move/mixed restructure planning is not implemented yet"
911 )
912}
913
914fn execute_plan<T, CurrentV, TargetV>(
915 tree: &TreeTN<T, CurrentV>,
916 plan: RestructurePlan<CurrentV, TargetV, T::Index>,
917 target: &SiteIndexNetwork<TargetV, T::Index>,
918 options: &RestructureOptions,
919) -> Result<TreeTN<T, TargetV>>
920where
921 T: TensorLike,
922 CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
923 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
924 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
925{
926 let result = match plan.kind {
927 RestructurePlanKind::FuseOnly => tree.fuse_to(target),
928 RestructurePlanKind::SplitOnly => tree.split_to(target, &options.split),
929 RestructurePlanKind::SwapOnly { target_assignment } => {
930 let mut working = clone_tree(tree)?;
931 working
932 .swap_site_indices(&target_assignment, &options.swap)
933 .context("restructure_to: swap phase")?;
934 Ok(working
935 .fuse_to(target)
936 .context("restructure_to: finalize after swap")?)
937 }
938 RestructurePlanKind::SwapThenFuse { target_assignment } => {
939 let mut working = clone_tree(tree)?;
940 working
941 .swap_site_indices(&target_assignment, &options.swap)
942 .context("restructure_to: swap phase")?;
943 Ok(working
944 .fuse_to(target)
945 .context("restructure_to: finalize after swap")?)
946 }
947 RestructurePlanKind::SplitThenFuse { split_target } => {
948 let split = tree
949 .split_to(split_target.as_ref(), &options.split)
950 .context("restructure_to: split phase")?;
951 split.fuse_to(target).context("restructure_to: fuse phase")
952 }
953 }?;
954
955 apply_final_truncation(result, options)
956}
957
958impl<T, V> TreeTN<T, V>
959where
960 T: TensorLike,
961 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
962{
963 pub fn restructure_to<TargetV>(
1043 &self,
1044 target: &SiteIndexNetwork<TargetV, T::Index>,
1045 options: &RestructureOptions,
1046 ) -> Result<TreeTN<T, TargetV>>
1047 where
1048 TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
1049 <T::Index as IndexLike>::Id:
1050 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1051 {
1052 let plan = build_plan::<T, V, TargetV>(self.site_index_network(), target)
1053 .context("restructure_to: build plan")?;
1054 execute_plan(self, plan, target, options).context("restructure_to: execute plan")
1055 }
1056}
1057
1058#[cfg(test)]
1059mod tests {
1060 use std::collections::HashSet;
1061
1062 use tensor4all_core::{DynIndex, IndexLike, TensorDynLen};
1063
1064 use super::*;
1065
1066 type FourSiteChainCase = (
1067 TreeTN<TensorDynLen, String>,
1068 DynIndex,
1069 DynIndex,
1070 DynIndex,
1071 DynIndex,
1072 );
1073
1074 fn two_node_chain() -> anyhow::Result<(TreeTN<TensorDynLen, String>, DynIndex, DynIndex)> {
1075 let left = DynIndex::new_dyn(2);
1076 let right = DynIndex::new_dyn(2);
1077 let bond = DynIndex::new_dyn(1);
1078 let t0 = TensorDynLen::from_dense(vec![left.clone(), bond.clone()], vec![1.0, 0.0])?;
1079 let t1 = TensorDynLen::from_dense(vec![bond, right.clone()], vec![1.0, 0.0])?;
1080 let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1081 vec![t0, t1],
1082 vec!["A".to_string(), "B".to_string()],
1083 )?;
1084 Ok((treetn, left, right))
1085 }
1086
1087 fn two_node_groups_of_two() -> anyhow::Result<FourSiteChainCase> {
1088 let x0 = DynIndex::new_dyn(2);
1089 let x1 = DynIndex::new_dyn(2);
1090 let y0 = DynIndex::new_dyn(2);
1091 let y1 = DynIndex::new_dyn(2);
1092 let bond = DynIndex::new_dyn(2);
1093 let left_tensor = TensorDynLen::from_dense(
1094 vec![x0.clone(), x1.clone(), bond.clone()],
1095 vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
1096 )?;
1097 let right_tensor = TensorDynLen::from_dense(
1098 vec![bond, y0.clone(), y1.clone()],
1099 vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
1100 )?;
1101 let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1102 vec![left_tensor, right_tensor],
1103 vec!["Left".to_string(), "Right".to_string()],
1104 )?;
1105 Ok((treetn, x0, x1, y0, y1))
1106 }
1107
1108 fn three_node_chain_for_swap() -> anyhow::Result<FourSiteChainCase> {
1109 let s0 = DynIndex::new_dyn(2);
1110 let s1 = DynIndex::new_dyn(2);
1111 let s2 = DynIndex::new_dyn(2);
1112 let s3 = DynIndex::new_dyn(2);
1113 let b01 = DynIndex::new_dyn(2);
1114 let b12 = DynIndex::new_dyn(2);
1115 let t0 = TensorDynLen::from_dense(
1116 vec![s0.clone(), s1.clone(), b01.clone()],
1117 vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0],
1118 )?;
1119 let t1 = TensorDynLen::from_dense(
1120 vec![b01.clone(), s2.clone(), b12.clone()],
1121 vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
1122 )?;
1123 let t2 = TensorDynLen::from_dense(vec![b12, s3.clone()], vec![1.0, 2.0, 3.0, 4.0])?;
1124 let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1125 vec![t0, t1, t2],
1126 vec!["A".to_string(), "B".to_string(), "C".to_string()],
1127 )?;
1128 Ok((treetn, s0, s1, s2, s3))
1129 }
1130
1131 fn four_node_interleaved_chain() -> anyhow::Result<FourSiteChainCase> {
1132 let x0 = DynIndex::new_dyn(2);
1133 let x1 = DynIndex::new_dyn(2);
1134 let y0 = DynIndex::new_dyn(2);
1135 let y1 = DynIndex::new_dyn(2);
1136 let b01 = DynIndex::new_dyn(2);
1137 let b12 = DynIndex::new_dyn(2);
1138 let b23 = DynIndex::new_dyn(2);
1139 let t0 = TensorDynLen::from_dense(vec![x0.clone(), b01.clone()], vec![1.0, 0.0, 0.0, 1.0])?;
1140 let t1 = TensorDynLen::from_dense(
1141 vec![b01.clone(), x1.clone(), b12.clone()],
1142 vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
1143 )?;
1144 let t2 = TensorDynLen::from_dense(
1145 vec![b12.clone(), y0.clone(), b23.clone()],
1146 vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
1147 )?;
1148 let t3 = TensorDynLen::from_dense(vec![b23, y1.clone()], vec![1.0, 0.0, 0.0, 1.0])?;
1149 let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1150 vec![t0, t1, t2, t3],
1151 vec![
1152 "0".to_string(),
1153 "1".to_string(),
1154 "2".to_string(),
1155 "3".to_string(),
1156 ],
1157 )?;
1158 Ok((treetn, x0, x1, y0, y1))
1159 }
1160
/// Restructures the two-node chain onto a target with a single node `"AB"` holding both site
/// indices, then checks that one node remains and that the contracted dense tensor agrees with
/// the original network's to within `1e-12`.
#[test]
fn test_restructure_to_fuse_only_matches_target_structure() -> anyhow::Result<()> {
    let (treetn, left, right) = two_node_chain()?;

    // Target: one node that owns both site indices.
    let mut target = SiteIndexNetwork::<String, DynIndex>::new();
    let fused_sites = HashSet::from([left, right]);
    target
        .add_node(String::from("AB"), fused_sites)
        .map_err(anyhow::Error::msg)?;

    let options = RestructureOptions::default();
    let result = treetn.restructure_to(&target, &options)?;

    let dense_expected = treetn.contract_to_tensor()?;
    let dense_actual = result.contract_to_tensor()?;

    assert_eq!(result.node_count(), 1);
    assert_eq!(result.site_index_network().node_count(), 1);
    assert!((&dense_actual - &dense_expected).maxabs() < 1e-12);

    Ok(())
}
1180
/// First restructures the two-node chain into a single node `"AB"`, then restructures that
/// result onto a two-node target (`"Left"` — `"Right"`, one site index each).
///
/// Checks that two nodes come back and that the dense tensor agrees with the single-node
/// network's to within `1e-12`.
#[test]
fn test_restructure_to_split_only_matches_target_structure() -> anyhow::Result<()> {
    let (treetn, left, right) = two_node_chain()?;

    // Step 1: produce the single-node starting point for the split.
    let mut fused_target = SiteIndexNetwork::<String, DynIndex>::new();
    let both_sites = HashSet::from([left.clone(), right.clone()]);
    fused_target
        .add_node(String::from("AB"), both_sites)
        .map_err(anyhow::Error::msg)?;
    let fused = treetn.restructure_to(&fused_target, &RestructureOptions::default())?;

    // Step 2: describe the two-node target the single node is restructured onto.
    let left_name = String::from("Left");
    let right_name = String::from("Right");
    let mut split_target = SiteIndexNetwork::<String, DynIndex>::new();
    split_target
        .add_node(left_name.clone(), HashSet::from([left]))
        .map_err(anyhow::Error::msg)?;
    split_target
        .add_node(right_name.clone(), HashSet::from([right]))
        .map_err(anyhow::Error::msg)?;
    split_target
        .add_edge(&left_name, &right_name)
        .map_err(anyhow::Error::msg)?;

    let result = fused.restructure_to(&split_target, &RestructureOptions::default())?;

    let dense_expected = fused.contract_to_tensor()?;
    let dense_actual = result.contract_to_tensor()?;

    assert_eq!(result.node_count(), 2);
    assert!((&dense_actual - &dense_expected).maxabs() < 1e-12);

    Ok(())
}
1214
/// Restructures the three-node fixture (`{s0, s1}`, `{s2}`, `{s3}`) onto the target chain
/// `X{s0} — Y{s1, s2} — Z{s3}`, so `s1` has to change nodes while the node count stays at three.
///
/// Verifies, in line with the other structural tests in this module:
/// * the result has three nodes and two edges,
/// * its site-index network is equivalent to the target,
/// * every site index is reported under the expected target node name,
/// * the contracted dense tensor agrees with the original network's to within `1e-10`.
#[test]
fn test_restructure_to_swap_only_matches_target_structure() -> anyhow::Result<()> {
    let (treetn, s0, s1, s2, s3) = three_node_chain_for_swap()?;

    let name_x = "X".to_string();
    let name_y = "Y".to_string();
    let name_z = "Z".to_string();

    let mut target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
    target
        .add_node(name_x.clone(), HashSet::from([s0.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_node(name_y.clone(), HashSet::from([s1.clone(), s2.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_node(name_z.clone(), HashSet::from([s3.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_edge(&name_x, &name_y)
        .map_err(anyhow::Error::msg)?;
    target
        .add_edge(&name_y, &name_z)
        .map_err(anyhow::Error::msg)?;

    let result = treetn.restructure_to(&target, &RestructureOptions::default())?;
    let dense_expected = treetn.contract_to_tensor()?;
    let dense_actual = result.contract_to_tensor()?;

    // Structural checks: the test name promises a structure match, so assert node/edge counts
    // and network equivalence rather than relying on site ownership alone.
    assert_eq!(result.node_count(), 3);
    assert_eq!(result.edge_count(), 2);
    assert!(result
        .site_index_network()
        .share_equivalent_site_index_network(&target));

    // Site ownership: each site index must be found under its target node name.
    for (site, expected_node) in [(&s0, "X"), (&s1, "Y"), (&s2, "Y"), (&s3, "Z")] {
        assert_eq!(
            result
                .site_index_network()
                .find_node_by_index_id(site.id())
                .map(|name| name.as_str()),
            Some(expected_node)
        );
    }

    assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);

    Ok(())
}
1272
/// Restructures the `Left{x0, x1} — Right{y0, y1}` fixture onto the target chain
/// `X{x0} — Y{x1, y0} — Z{y1}`.
///
/// Checks node and edge counts, equivalence of the resulting site-index network with the
/// target, the node name reported for every site index, and agreement of the dense tensors to
/// within `1e-10`.
#[test]
fn test_restructure_to_split_then_fuse_mixed_case() -> anyhow::Result<()> {
    let (treetn, x0, x1, y0, y1) = two_node_groups_of_two()?;

    let name_x = String::from("X");
    let name_y = String::from("Y");
    let name_z = String::from("Z");

    let mut target = SiteIndexNetwork::<String, DynIndex>::new();
    target
        .add_node(name_x.clone(), HashSet::from([x0.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_node(name_y.clone(), HashSet::from([x1.clone(), y0.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_node(name_z.clone(), HashSet::from([y1.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_edge(&name_x, &name_y)
        .map_err(anyhow::Error::msg)?;
    target
        .add_edge(&name_y, &name_z)
        .map_err(anyhow::Error::msg)?;

    let options = RestructureOptions::default();
    let result = treetn.restructure_to(&target, &options)?;
    let dense_expected = treetn.contract_to_tensor()?;
    let dense_actual = result.contract_to_tensor()?;

    assert_eq!(result.node_count(), 3);
    assert_eq!(result.edge_count(), 2);
    assert!(result
        .site_index_network()
        .share_equivalent_site_index_network(&target));

    // Same ownership checks as before, driven by a table in the original assertion order.
    for (site, expected_node) in [(&x0, "X"), (&x1, "Y"), (&y0, "Y"), (&y1, "Z")] {
        assert_eq!(
            result
                .site_index_network()
                .find_node_by_index_id(site.id())
                .map(|name| name.as_str()),
            Some(expected_node)
        );
    }

    assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);

    Ok(())
}
1335
/// Restructures the four-node fixture (one site index per node, in the order `x0, x1, y0, y1`)
/// onto the two-node target `X{x0, y0} — Y{x1, y1}`.
///
/// Checks node and edge counts, equivalence of the resulting site-index network with the
/// target, the node name reported for every site index, and agreement of the dense tensors to
/// within `1e-10`.
#[test]
fn test_restructure_to_swap_then_fuse_mixed_case() -> anyhow::Result<()> {
    let (treetn, x0, x1, y0, y1) = four_node_interleaved_chain()?;

    let name_x = String::from("X");
    let name_y = String::from("Y");

    let mut target = SiteIndexNetwork::<String, DynIndex>::new();
    target
        .add_node(name_x.clone(), HashSet::from([x0.clone(), y0.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_node(name_y.clone(), HashSet::from([x1.clone(), y1.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_edge(&name_x, &name_y)
        .map_err(anyhow::Error::msg)?;

    let options = RestructureOptions::default();
    let result = treetn.restructure_to(&target, &options)?;
    let dense_expected = treetn.contract_to_tensor()?;
    let dense_actual = result.contract_to_tensor()?;

    assert_eq!(result.node_count(), 2);
    assert_eq!(result.edge_count(), 1);
    assert!(result
        .site_index_network()
        .share_equivalent_site_index_network(&target));

    // Ownership table, kept in the original assertion order: x0, y0, x1, y1.
    for (site, expected_node) in [(&x0, "X"), (&y0, "X"), (&x1, "Y"), (&y1, "Y")] {
        assert_eq!(
            result
                .site_index_network()
                .find_node_by_index_id(site.id())
                .map(|name| name.as_str()),
            Some(expected_node)
        );
    }

    assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);

    Ok(())
}
1392
/// Restructures the `Left{x0, x1} — Right{y0, y1}` fixture onto the two-node target
/// `X{x0, y0} — Y{x1, y1}`, which pairs one site index from each original node.
///
/// Checks node and edge counts, equivalence of the resulting site-index network with the
/// target, the node name reported for every site index, and agreement of the dense tensors to
/// within `1e-10`.
#[test]
fn test_restructure_to_two_node_swap_only_cross_pairing() -> anyhow::Result<()> {
    let (treetn, x0, x1, y0, y1) = two_node_groups_of_two()?;

    let name_x = String::from("X");
    let name_y = String::from("Y");

    let mut target = SiteIndexNetwork::<String, DynIndex>::new();
    target
        .add_node(name_x.clone(), HashSet::from([x0.clone(), y0.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_node(name_y.clone(), HashSet::from([x1.clone(), y1.clone()]))
        .map_err(anyhow::Error::msg)?;
    target
        .add_edge(&name_x, &name_y)
        .map_err(anyhow::Error::msg)?;

    let options = RestructureOptions::default();
    let result = treetn.restructure_to(&target, &options)?;
    let dense_expected = treetn.contract_to_tensor()?;
    let dense_actual = result.contract_to_tensor()?;

    assert_eq!(result.node_count(), 2);
    assert_eq!(result.edge_count(), 1);
    assert!(result
        .site_index_network()
        .share_equivalent_site_index_network(&target));

    // Ownership table, kept in the original assertion order: x0, y0, x1, y1.
    for (site, expected_node) in [(&x0, "X"), (&y0, "X"), (&x1, "Y"), (&y1, "Y")] {
        assert_eq!(
            result
                .site_index_network()
                .find_node_by_index_id(site.id())
                .map(|name| name.as_str()),
            Some(expected_node)
        );
    }

    assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);

    Ok(())
}
1449}