Skip to main content

tensor4all_treetn/treetn/restructure/
mod.rs

1//! Public scaffolding for multi-phase TreeTN restructuring.
2//!
3//! The approved B2a design is plan-first:
4//! 1. Build a pure restructure plan from the current and target site-index networks.
5//! 2. Execute that plan through split, move, and fuse phases.
6//!
//! The current implementation supports fuse-only, split-only, and swap-only
//! restructurings, plus two conservative mixed strategies: a path-based
//! swap-then-fuse and a split-then-fuse.
//!
//! Unsupported patterns are reported as explicit errors. In particular, mixed
//! cases that require both splitting a node into multiple cross-node fragments
//! and a subsequent swap/move phase may still be rejected with placeholder
//! errors.
15
16use std::collections::{HashMap, HashSet};
17use std::hash::Hash;
18
19use crate::node_name_network::NodeNameNetwork;
20use anyhow::{bail, Context, Result};
21use petgraph::stable_graph::NodeIndex;
22use tensor4all_core::{IndexLike, TensorLike};
23
24use super::TreeTN;
25use crate::{RestructureOptions, SiteIndexNetwork};
26
/// One fragment of a current node, produced by the split phase of a
/// split-then-fuse plan.
///
/// Field order is significant: the derived `Ord` compares `current` first,
/// then `split_rank`, then `target`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct FragmentNode<CurrentV, TargetV> {
    /// Name of the current node this fragment is carved out of.
    current: CurrentV,
    /// Position of this fragment among the fragments of `current`.
    split_rank: usize,
    /// Name of the target node whose site indices this fragment carries.
    target: TargetV,
}
33
34type SplitThenFuseTarget<CurrentV, TargetV, I> =
35    SiteIndexNetwork<FragmentNode<CurrentV, TargetV>, I>;
36
37#[derive(Debug, Clone)]
38enum RestructurePlanKind<CurrentV, TargetV, I>
39where
40    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
41    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
42    I: IndexLike,
43    I::Id: Eq + Hash,
44{
45    FuseOnly,
46    SplitOnly,
47    SwapOnly {
48        target_assignment: HashMap<I::Id, CurrentV>,
49    },
50    SwapThenFuse {
51        target_assignment: HashMap<I::Id, CurrentV>,
52    },
53    SplitThenFuse {
54        split_target: Box<SplitThenFuseTarget<CurrentV, TargetV, I>>,
55    },
56}
57
58#[derive(Debug, Clone)]
59struct RestructurePlan<CurrentV, TargetV, I>
60where
61    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
62    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
63    I: IndexLike,
64    I::Id: Eq + Hash,
65{
66    kind: RestructurePlanKind<CurrentV, TargetV, I>,
67}
68
69fn collect_site_targets<T, TargetV>(
70    target: &SiteIndexNetwork<TargetV, T::Index>,
71) -> Result<HashMap<<T::Index as IndexLike>::Id, TargetV>>
72where
73    T: TensorLike,
74    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
75    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
76{
77    let mut site_to_target = HashMap::new();
78    for target_node_name in target.node_names() {
79        let site_space = target.site_space(target_node_name).ok_or_else(|| {
80            anyhow::anyhow!(
81                "restructure_to: target node {:?} has no registered site space",
82                target_node_name
83            )
84        })?;
85        for site_idx in site_space {
86            let existing = site_to_target.insert(site_idx.id().clone(), target_node_name.clone());
87            if let Some(previous_target) = existing {
88                bail!(
89                    "restructure_to: site index {:?} appears in both target nodes {:?} and {:?}",
90                    site_idx.id(),
91                    previous_target,
92                    target_node_name
93                );
94            }
95        }
96    }
97    Ok(site_to_target)
98}
99
100fn collect_current_site_ids<T, CurrentV>(
101    current: &SiteIndexNetwork<CurrentV, T::Index>,
102) -> Result<HashSet<<T::Index as IndexLike>::Id>>
103where
104    T: TensorLike,
105    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
106    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
107{
108    let mut site_ids = HashSet::new();
109    for current_node_name in current.node_names() {
110        let site_space = current.site_space(current_node_name).ok_or_else(|| {
111            anyhow::anyhow!(
112                "restructure_to: current node {:?} has no registered site space",
113                current_node_name
114            )
115        })?;
116        for site_idx in site_space {
117            site_ids.insert(site_idx.id().clone());
118        }
119    }
120    Ok(site_ids)
121}
122
123fn current_nodes_map_uniquely_to_targets<T, CurrentV, TargetV>(
124    current: &SiteIndexNetwork<CurrentV, T::Index>,
125    site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
126) -> Result<bool>
127where
128    T: TensorLike,
129    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
130    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
131    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
132{
133    for current_node_name in current.node_names() {
134        let site_space = current.site_space(current_node_name).ok_or_else(|| {
135            anyhow::anyhow!(
136                "restructure_to: current node {:?} has no registered site space",
137                current_node_name
138            )
139        })?;
140        let target_names: HashSet<_> = site_space
141            .iter()
142            .map(|site_idx| {
143                site_to_target
144                    .get(site_idx.id())
145                    .cloned()
146                    .ok_or_else(|| {
147                        anyhow::anyhow!(
148                            "restructure_to: site index {:?} is present in the current network but missing from the target",
149                            site_idx.id()
150                        )
151                    })
152            })
153            .collect::<Result<_>>()?;
154        if target_names.len() > 1 {
155            return Ok(false);
156        }
157    }
158    Ok(true)
159}
160
161fn collect_site_currents<T, CurrentV>(
162    current: &SiteIndexNetwork<CurrentV, T::Index>,
163) -> Result<HashMap<<T::Index as IndexLike>::Id, CurrentV>>
164where
165    T: TensorLike,
166    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
167    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
168{
169    let mut site_to_current = HashMap::new();
170    for current_node_name in current.node_names() {
171        let site_space = current.site_space(current_node_name).ok_or_else(|| {
172            anyhow::anyhow!(
173                "restructure_to: current node {:?} has no registered site space",
174                current_node_name
175            )
176        })?;
177        for site_idx in site_space {
178            let existing = site_to_current.insert(site_idx.id().clone(), current_node_name.clone());
179            if let Some(previous_current) = existing {
180                bail!(
181                    "restructure_to: site index {:?} appears in both current nodes {:?} and {:?}",
182                    site_idx.id(),
183                    previous_current,
184                    current_node_name
185                );
186            }
187        }
188    }
189    Ok(site_to_current)
190}
191
192fn target_nodes_map_uniquely_to_currents<T, CurrentV, TargetV>(
193    target: &SiteIndexNetwork<TargetV, T::Index>,
194    site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
195) -> Result<bool>
196where
197    T: TensorLike,
198    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
199    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
200    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
201{
202    for target_node_name in target.node_names() {
203        let site_space = target.site_space(target_node_name).ok_or_else(|| {
204            anyhow::anyhow!(
205                "restructure_to: target node {:?} has no registered site space",
206                target_node_name
207            )
208        })?;
209        let current_names: HashSet<_> = site_space
210            .iter()
211            .map(|site_idx| {
212                site_to_current
213                    .get(site_idx.id())
214                    .cloned()
215                    .ok_or_else(|| {
216                        anyhow::anyhow!(
217                            "restructure_to: site index {:?} is present in the target but missing from the current network",
218                            site_idx.id()
219                        )
220                    })
221            })
222            .collect::<Result<_>>()?;
223        if current_names.len() > 1 {
224            return Ok(false);
225        }
226    }
227    Ok(true)
228}
229
230fn target_nodes_span_connected_currents<T, CurrentV, TargetV>(
231    current: &SiteIndexNetwork<CurrentV, T::Index>,
232    target: &SiteIndexNetwork<TargetV, T::Index>,
233    site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
234) -> Result<bool>
235where
236    T: TensorLike,
237    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
238    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
239    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
240{
241    for target_node_name in target.node_names() {
242        let site_space = target.site_space(target_node_name).ok_or_else(|| {
243            anyhow::anyhow!(
244                "restructure_to: target node {:?} has no registered site space",
245                target_node_name
246            )
247        })?;
248        let current_nodes: HashSet<_> = site_space
249            .iter()
250            .map(|site_idx| {
251                let current_name = site_to_current.get(site_idx.id()).ok_or_else(|| {
252                    anyhow::anyhow!(
253                        "restructure_to: site index {:?} is present in the target but missing from the current network",
254                        site_idx.id()
255                    )
256                })?;
257                current.node_index(current_name).ok_or_else(|| {
258                    anyhow::anyhow!(
259                        "restructure_to: current node {:?} is missing from the topology",
260                        current_name
261                    )
262                })
263            })
264            .collect::<Result<_>>()?;
265        if !current.is_connected_subset(&current_nodes) {
266            return Ok(false);
267        }
268    }
269
270    Ok(true)
271}
272
273fn collect_shared_targets<T, CurrentV, TargetV>(
274    target: &SiteIndexNetwork<TargetV, T::Index>,
275    site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
276) -> Result<HashSet<TargetV>>
277where
278    T: TensorLike,
279    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
280    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
281    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
282{
283    let mut shared_targets = HashSet::new();
284    for target_node_name in target.node_names() {
285        let site_space = target.site_space(target_node_name).ok_or_else(|| {
286            anyhow::anyhow!(
287                "restructure_to: target node {:?} has no registered site space",
288                target_node_name
289            )
290        })?;
291        let current_names: HashSet<_> = site_space
292            .iter()
293            .map(|site_idx| {
294                site_to_current
295                    .get(site_idx.id())
296                    .cloned()
297                    .ok_or_else(|| {
298                        anyhow::anyhow!(
299                            "restructure_to: site index {:?} is present in the target but missing from the current network",
300                            site_idx.id()
301                        )
302                    })
303            })
304            .collect::<Result<_>>()?;
305        if current_names.len() > 1 {
306            shared_targets.insert(target_node_name.clone());
307        }
308    }
309    Ok(shared_targets)
310}
311
312fn build_split_then_fuse_target<T, CurrentV, TargetV>(
313    current: &SiteIndexNetwork<CurrentV, T::Index>,
314    target: &SiteIndexNetwork<TargetV, T::Index>,
315    site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
316    site_to_current: &HashMap<<T::Index as IndexLike>::Id, CurrentV>,
317) -> Result<Option<SplitThenFuseTarget<CurrentV, TargetV, T::Index>>>
318where
319    T: TensorLike,
320    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
321    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
322    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
323{
324    if !target_nodes_span_connected_currents::<T, CurrentV, TargetV>(
325        current,
326        target,
327        site_to_current,
328    )? {
329        return Ok(None);
330    }
331
332    let shared_targets = collect_shared_targets::<T, CurrentV, TargetV>(target, site_to_current)?;
333    let mut split_target = SiteIndexNetwork::with_capacity(current.node_count(), 0);
334    let mut current_node_names: Vec<_> = current.node_names().into_iter().cloned().collect();
335    current_node_names.sort();
336
337    for current_node_name in current_node_names {
338        let site_space = current.site_space(&current_node_name).ok_or_else(|| {
339            anyhow::anyhow!(
340                "restructure_to: current node {:?} has no registered site space",
341                current_node_name
342            )
343        })?;
344        let mut fragments: HashMap<TargetV, HashSet<T::Index>> = HashMap::new();
345        for site_idx in site_space {
346            let target_node_name = site_to_target.get(site_idx.id()).cloned().ok_or_else(|| {
347                anyhow::anyhow!(
348                    "restructure_to: site index {:?} is present in the current network but missing from the target",
349                    site_idx.id()
350                )
351            })?;
352            fragments
353                .entry(target_node_name)
354                .or_default()
355                .insert(site_idx.clone());
356        }
357
358        let shared_targets_here: Vec<_> = fragments
359            .keys()
360            .filter(|target_name| shared_targets.contains(*target_name))
361            .cloned()
362            .collect();
363        if shared_targets_here.len() > 1 {
364            return Ok(None);
365        }
366        let boundary_target = shared_targets_here.first().cloned();
367
368        let mut fragments: Vec<_> = fragments.into_iter().collect();
369        fragments.sort_by(|(left_name, _), (right_name, _)| {
370            let left_is_boundary = boundary_target.as_ref() == Some(left_name);
371            let right_is_boundary = boundary_target.as_ref() == Some(right_name);
372            left_is_boundary
373                .cmp(&right_is_boundary)
374                .then_with(|| left_name.cmp(right_name))
375        });
376
377        for (split_rank, (target_node_name, fragment_site_space)) in
378            fragments.into_iter().enumerate()
379        {
380            split_target
381                .add_node(
382                    FragmentNode {
383                        current: current_node_name.clone(),
384                        split_rank,
385                        target: target_node_name,
386                    },
387                    fragment_site_space,
388                )
389                .map_err(anyhow::Error::msg)?;
390        }
391    }
392
393    Ok(Some(split_target))
394}
395
396fn ordered_path_nodes<V>(topology: &NodeNameNetwork<V>) -> Option<Vec<V>>
397where
398    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
399{
400    let graph = topology.graph();
401    match topology.node_count() {
402        0 => return Some(Vec::new()),
403        1 => {
404            let node = graph.node_indices().next()?;
405            return Some(vec![topology.node_name(node)?.clone()]);
406        }
407        _ => {}
408    }
409
410    let mut leaves = Vec::new();
411    for node in graph.node_indices() {
412        let degree = graph.neighbors(node).count();
413        if degree > 2 {
414            return None;
415        }
416        if degree == 1 {
417            leaves.push(node);
418        }
419    }
420    if leaves.len() != 2 {
421        return None;
422    }
423
424    leaves.sort_by_key(|node| topology.node_name(*node).cloned());
425    let mut ordered = Vec::with_capacity(topology.node_count());
426    let mut previous = None;
427    let mut current = *leaves.first()?;
428
429    loop {
430        ordered.push(topology.node_name(current)?.clone());
431        let next = graph
432            .neighbors(current)
433            .find(|neighbor| Some(*neighbor) != previous);
434        let Some(next) = next else {
435            break;
436        };
437        previous = Some(current);
438        current = next;
439    }
440
441    if ordered.len() == topology.node_count() {
442        Some(ordered)
443    } else {
444        None
445    }
446}
447
448fn build_path_swap_then_fuse_assignment<T, CurrentV, TargetV>(
449    current: &SiteIndexNetwork<CurrentV, T::Index>,
450    target: &SiteIndexNetwork<TargetV, T::Index>,
451    site_to_target: &HashMap<<T::Index as IndexLike>::Id, TargetV>,
452) -> Result<Option<HashMap<<T::Index as IndexLike>::Id, CurrentV>>>
453where
454    T: TensorLike,
455    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
456    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
457    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
458{
459    let current_path = match ordered_path_nodes(current.topology()) {
460        Some(path) => path,
461        None => return Ok(None),
462    };
463    let target_path = match ordered_path_nodes(target.topology()) {
464        Some(path) => path,
465        None => return Ok(None),
466    };
467
468    let mut contributor_counts: HashMap<TargetV, usize> = HashMap::new();
469    for current_node_name in &current_path {
470        let site_space = current.site_space(current_node_name).ok_or_else(|| {
471            anyhow::anyhow!(
472                "restructure_to: current node {:?} has no registered site space",
473                current_node_name
474            )
475        })?;
476        let mut target_names: Vec<_> = site_space
477            .iter()
478            .map(|site_idx| {
479                site_to_target
480                    .get(site_idx.id())
481                    .cloned()
482                    .ok_or_else(|| {
483                        anyhow::anyhow!(
484                            "restructure_to: site index {:?} is present in the current network but missing from the target",
485                            site_idx.id()
486                        )
487                    })
488            })
489            .collect::<Result<_>>()?;
490        target_names.sort();
491        target_names.dedup();
492        if target_names.len() != 1 {
493            return Ok(None);
494        }
495        let target_name = target_names.into_iter().next().ok_or_else(|| {
496            anyhow::anyhow!(
497                "restructure_to: current node {:?} has no target mapping",
498                current_node_name
499            )
500        })?;
501        *contributor_counts.entry(target_name).or_default() += 1;
502    }
503
504    let total_contributors: usize = contributor_counts.values().sum();
505    if total_contributors != current_path.len() {
506        return Ok(None);
507    }
508
509    let mut contiguous_blocks: HashMap<TargetV, Vec<CurrentV>> = HashMap::new();
510    let mut cursor = 0usize;
511    for target_node_name in &target_path {
512        let block_len = *contributor_counts.get(target_node_name).unwrap_or(&0);
513        if block_len == 0 || cursor + block_len > current_path.len() {
514            return Ok(None);
515        }
516        contiguous_blocks.insert(
517            target_node_name.clone(),
518            current_path[cursor..cursor + block_len].to_vec(),
519        );
520        cursor += block_len;
521    }
522    if cursor != current_path.len() {
523        return Ok(None);
524    }
525
526    let mut target_assignment = HashMap::new();
527    for target_node_name in &target_path {
528        let block = contiguous_blocks.get(target_node_name).ok_or_else(|| {
529            anyhow::anyhow!(
530                "restructure_to: missing contiguous block for target {:?}",
531                target_node_name
532            )
533        })?;
534        let mut site_ids: Vec<_> = target
535            .site_space(target_node_name)
536            .ok_or_else(|| {
537                anyhow::anyhow!(
538                    "restructure_to: target node {:?} has no registered site space",
539                    target_node_name
540                )
541            })?
542            .iter()
543            .map(|site_idx| site_idx.id().clone())
544            .collect();
545        site_ids.sort();
546        if site_ids.len() < block.len() {
547            return Ok(None);
548        }
549
550        for (position, site_id) in site_ids.into_iter().enumerate() {
551            let block_index = position.min(block.len() - 1);
552            target_assignment.insert(site_id, block[block_index].clone());
553        }
554    }
555
556    Ok(Some(target_assignment))
557}
558
559fn tree_children<V>(
560    topology: &NodeNameNetwork<V>,
561    node: NodeIndex,
562    parent: Option<NodeIndex>,
563) -> Vec<NodeIndex>
564where
565    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
566{
567    topology
568        .graph()
569        .neighbors(node)
570        .filter(|neighbor| Some(*neighbor) != parent)
571        .collect()
572}
573
574fn rooted_signature<V>(
575    topology: &NodeNameNetwork<V>,
576    node: NodeIndex,
577    parent: Option<NodeIndex>,
578    cache: &mut HashMap<(usize, Option<usize>), String>,
579) -> String
580where
581    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
582{
583    let key = (node.index(), parent.map(|p| p.index()));
584    if let Some(signature) = cache.get(&key) {
585        return signature.clone();
586    }
587
588    let mut child_signatures: Vec<String> = tree_children(topology, node, parent)
589        .into_iter()
590        .map(|child| rooted_signature(topology, child, Some(node), cache))
591        .collect();
592    child_signatures.sort();
593
594    let signature = format!("({})", child_signatures.concat());
595    cache.insert(key, signature.clone());
596    signature
597}
598
599#[derive(Default)]
600struct IsomorphicMatchState {
601    current_cache: HashMap<(usize, Option<usize>), String>,
602    target_cache: HashMap<(usize, Option<usize>), String>,
603    mapping: HashMap<NodeIndex, NodeIndex>,
604}
605
606fn match_isomorphic_subtrees<CurrentV, TargetV>(
607    current_topology: &NodeNameNetwork<CurrentV>,
608    target_topology: &NodeNameNetwork<TargetV>,
609    current_node: NodeIndex,
610    current_parent: Option<NodeIndex>,
611    target_node: NodeIndex,
612    target_parent: Option<NodeIndex>,
613    state: &mut IsomorphicMatchState,
614) -> bool
615where
616    CurrentV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
617    TargetV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
618{
619    if let Some(existing) = state.mapping.insert(target_node, current_node) {
620        if existing != current_node {
621            return false;
622        }
623    }
624
625    let current_children = tree_children(current_topology, current_node, current_parent);
626    let target_children = tree_children(target_topology, target_node, target_parent);
627    if current_children.len() != target_children.len() {
628        return false;
629    }
630
631    let mut current_groups: HashMap<String, Vec<NodeIndex>> = HashMap::new();
632    for child in current_children {
633        let signature = rooted_signature(
634            current_topology,
635            child,
636            Some(current_node),
637            &mut state.current_cache,
638        );
639        current_groups.entry(signature).or_default().push(child);
640    }
641
642    let mut target_groups: HashMap<String, Vec<NodeIndex>> = HashMap::new();
643    for child in target_children {
644        let signature = rooted_signature(
645            target_topology,
646            child,
647            Some(target_node),
648            &mut state.target_cache,
649        );
650        target_groups.entry(signature).or_default().push(child);
651    }
652
653    if current_groups.len() != target_groups.len() {
654        return false;
655    }
656
657    let mut signatures: Vec<String> = current_groups.keys().cloned().collect();
658    signatures.sort();
659    for signature in signatures {
660        let mut current_bucket = match current_groups.remove(&signature) {
661            Some(bucket) => bucket,
662            None => return false,
663        };
664        let mut target_bucket = match target_groups.remove(&signature) {
665            Some(bucket) => bucket,
666            None => return false,
667        };
668        if current_bucket.len() != target_bucket.len() {
669            return false;
670        }
671
672        current_bucket.sort_by_key(|node| node.index());
673        target_bucket.sort_by_key(|node| node.index());
674
675        for (current_child, target_child) in current_bucket.into_iter().zip(target_bucket) {
676            if !match_isomorphic_subtrees(
677                current_topology,
678                target_topology,
679                current_child,
680                Some(current_node),
681                target_child,
682                Some(target_node),
683                state,
684            ) {
685                return false;
686            }
687        }
688    }
689
690    true
691}
692
693fn match_tree_topologies<CurrentV, TargetV>(
694    current_topology: &NodeNameNetwork<CurrentV>,
695    target_topology: &NodeNameNetwork<TargetV>,
696) -> Option<HashMap<NodeIndex, NodeIndex>>
697where
698    CurrentV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
699    TargetV: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
700{
701    if current_topology.node_count() != target_topology.node_count() {
702        return None;
703    }
704    if current_topology.edge_count() != target_topology.edge_count() {
705        return None;
706    }
707
708    let mut current_roots: Vec<(String, NodeIndex)> = current_topology
709        .graph()
710        .node_indices()
711        .map(|node| {
712            (
713                rooted_signature(current_topology, node, None, &mut HashMap::new()),
714                node,
715            )
716        })
717        .collect();
718    current_roots.sort_by_key(|(signature, node)| (signature.clone(), node.index()));
719
720    let mut target_roots: Vec<(String, NodeIndex)> = target_topology
721        .graph()
722        .node_indices()
723        .map(|node| {
724            (
725                rooted_signature(target_topology, node, None, &mut HashMap::new()),
726                node,
727            )
728        })
729        .collect();
730    target_roots.sort_by_key(|(signature, node)| (signature.clone(), node.index()));
731
732    for (target_signature, target_root) in &target_roots {
733        for (current_signature, current_root) in &current_roots {
734            if current_signature != target_signature {
735                continue;
736            }
737
738            let mut state = IsomorphicMatchState::default();
739            if match_isomorphic_subtrees(
740                current_topology,
741                target_topology,
742                *current_root,
743                None,
744                *target_root,
745                None,
746                &mut state,
747            ) && state.mapping.len() == target_topology.node_count()
748            {
749                return Some(state.mapping);
750            }
751        }
752    }
753
754    None
755}
756
757fn build_swap_assignment<T, CurrentV, TargetV>(
758    current: &SiteIndexNetwork<CurrentV, T::Index>,
759    target: &SiteIndexNetwork<TargetV, T::Index>,
760) -> Result<Option<HashMap<<T::Index as IndexLike>::Id, CurrentV>>>
761where
762    T: TensorLike,
763    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
764    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
765    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
766{
767    let topology_mapping = match match_tree_topologies(current.topology(), target.topology()) {
768        Some(mapping) => mapping,
769        None => return Ok(None),
770    };
771
772    let mut assignment = HashMap::new();
773    for target_node in target.graph().node_indices() {
774        let target_name = target.node_name(target_node).ok_or_else(|| {
775            anyhow::anyhow!(
776                "restructure_to: target topology mapping referenced a missing node index {:?}",
777                target_node
778            )
779        })?;
780        let current_node = *topology_mapping.get(&target_node).ok_or_else(|| {
781            anyhow::anyhow!(
782                "restructure_to: target topology mapping did not assign a current node to {:?}",
783                target_name
784            )
785        })?;
786        let current_name = current.node_name(current_node).ok_or_else(|| {
787            anyhow::anyhow!(
788                "restructure_to: current topology mapping referenced a missing node index {:?}",
789                current_node
790            )
791        })?;
792        let site_space = target.site_space(target_name).ok_or_else(|| {
793            anyhow::anyhow!(
794                "restructure_to: target node {:?} has no registered site space",
795                target_name
796            )
797        })?;
798        for site_idx in site_space {
799            assignment.insert(site_idx.id().clone(), current_name.clone());
800        }
801    }
802
803    Ok(Some(assignment))
804}
805
806fn clone_tree<T, V>(tree: &TreeTN<T, V>) -> Result<TreeTN<T, V>>
807where
808    T: TensorLike,
809    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
810{
811    Ok(TreeTN {
812        graph: tree.graph.clone(),
813        canonical_region: tree.canonical_region.clone(),
814        canonical_form: tree.canonical_form,
815        site_index_network: tree.site_index_network.clone(),
816        link_index_network: tree.link_index_network.clone(),
817        ortho_towards: tree.ortho_towards.clone(),
818    })
819}
820
821fn apply_final_truncation<T, V>(
822    tree: TreeTN<T, V>,
823    options: &RestructureOptions,
824) -> Result<TreeTN<T, V>>
825where
826    T: TensorLike,
827    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
828    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
829{
830    let Some(final_truncation) = options.final_truncation else {
831        return Ok(tree);
832    };
833    let center = tree
834        .node_names()
835        .into_iter()
836        .min()
837        .ok_or_else(|| anyhow::anyhow!("restructure_to: cannot truncate an empty network"))?;
838    tree.truncate([center], final_truncation)
839        .context("restructure_to: final truncation")
840}
841
842fn build_plan<T, CurrentV, TargetV>(
843    current: &SiteIndexNetwork<CurrentV, T::Index>,
844    target: &SiteIndexNetwork<TargetV, T::Index>,
845) -> Result<RestructurePlan<CurrentV, TargetV, T::Index>>
846where
847    T: TensorLike,
848    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
849    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
850    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
851{
852    let site_to_target = collect_site_targets::<T, TargetV>(target)?;
853    let site_to_current = collect_site_currents::<T, CurrentV>(current)?;
854    let current_site_ids = collect_current_site_ids::<T, CurrentV>(current)?;
855    let target_site_ids: HashSet<_> = site_to_target.keys().cloned().collect();
856
857    if current_site_ids != target_site_ids {
858        bail!("restructure_to: current and target must contain the same site index ids");
859    }
860
861    if current_nodes_map_uniquely_to_targets::<T, CurrentV, TargetV>(current, &site_to_target)? {
862        if target_nodes_span_connected_currents::<T, CurrentV, TargetV>(
863            current,
864            target,
865            &site_to_current,
866        )? {
867            return Ok(RestructurePlan {
868                kind: RestructurePlanKind::FuseOnly,
869            });
870        }
871
872        if let Some(target_assignment) = build_path_swap_then_fuse_assignment::<T, CurrentV, TargetV>(
873            current,
874            target,
875            &site_to_target,
876        )? {
877            return Ok(RestructurePlan {
878                kind: RestructurePlanKind::SwapThenFuse { target_assignment },
879            });
880        }
881    }
882
883    if target_nodes_map_uniquely_to_currents::<T, CurrentV, TargetV>(target, &site_to_current)? {
884        return Ok(RestructurePlan {
885            kind: RestructurePlanKind::SplitOnly,
886        });
887    }
888
889    if let Some(target_assignment) = build_swap_assignment::<T, CurrentV, TargetV>(current, target)?
890    {
891        return Ok(RestructurePlan {
892            kind: RestructurePlanKind::SwapOnly { target_assignment },
893        });
894    }
895
896    if let Some(split_target) = build_split_then_fuse_target::<T, CurrentV, TargetV>(
897        current,
898        target,
899        &site_to_target,
900        &site_to_current,
901    )? {
902        return Ok(RestructurePlan {
903            kind: RestructurePlanKind::SplitThenFuse {
904                split_target: Box::new(split_target),
905            },
906        });
907    }
908
909    bail!(
910        "restructure_to: planner placeholder only; split/move/mixed restructure planning is not implemented yet"
911    )
912}
913
914fn execute_plan<T, CurrentV, TargetV>(
915    tree: &TreeTN<T, CurrentV>,
916    plan: RestructurePlan<CurrentV, TargetV, T::Index>,
917    target: &SiteIndexNetwork<TargetV, T::Index>,
918    options: &RestructureOptions,
919) -> Result<TreeTN<T, TargetV>>
920where
921    T: TensorLike,
922    CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
923    TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
924    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
925{
926    let result = match plan.kind {
927        RestructurePlanKind::FuseOnly => tree.fuse_to(target),
928        RestructurePlanKind::SplitOnly => tree.split_to(target, &options.split),
929        RestructurePlanKind::SwapOnly { target_assignment } => {
930            let mut working = clone_tree(tree)?;
931            working
932                .swap_site_indices(&target_assignment, &options.swap)
933                .context("restructure_to: swap phase")?;
934            Ok(working
935                .fuse_to(target)
936                .context("restructure_to: finalize after swap")?)
937        }
938        RestructurePlanKind::SwapThenFuse { target_assignment } => {
939            let mut working = clone_tree(tree)?;
940            working
941                .swap_site_indices(&target_assignment, &options.swap)
942                .context("restructure_to: swap phase")?;
943            Ok(working
944                .fuse_to(target)
945                .context("restructure_to: finalize after swap")?)
946        }
947        RestructurePlanKind::SplitThenFuse { split_target } => {
948            let split = tree
949                .split_to(split_target.as_ref(), &options.split)
950                .context("restructure_to: split phase")?;
951            split.fuse_to(target).context("restructure_to: fuse phase")
952        }
953    }?;
954
955    apply_final_truncation(result, options)
956}
957
958impl<T, V> TreeTN<T, V>
959where
960    T: TensorLike,
961    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
962{
963    /// Restructure this TreeTN to match a target site-index network.
964    ///
965    /// This is the plan-first public entry point for Issue #423 B2a.
966    ///
967    /// The current staged implementation already handles:
968    /// - fuse-only restructures, where each current node maps to exactly one
969    ///   target node;
970    /// - split-only restructures, where each target node maps to exactly one
971    ///   current node;
972    /// - swap-only restructures, where the current and target topologies are
973    ///   tree-isomorphic and only the site assignments differ;
974    /// - conservative path-based swap-then-fuse restructures, where the
975    ///   current nodes already map uniquely to target nodes but their target
976    ///   groups must be rearranged into contiguous path blocks before fusing;
977    /// - conservative mixed split-then-fuse restructures, where each current
978    ///   node has at most one cross-node target fragment that must retain the
979    ///   original external bonds.
980    ///
981    /// Unsupported patterns are reported explicitly. In particular, mixed
982    /// cases that require both split planning and a subsequent move/swap phase
983    /// may still remain intentionally staged behind placeholder errors while
984    /// the pure planner is expanded.
985    ///
986    /// Related types:
987    /// - [`RestructureOptions`] configures the split, transport, and optional
988    ///   final truncation phases.
989    /// - [`SiteIndexNetwork`] describes the desired final topology plus site
990    ///   grouping.
991    /// - [`TreeTN::split_to`](crate::treetn::TreeTN::split_to) and
992    ///   [`TreeTN::swap_site_indices`](crate::treetn::TreeTN::swap_site_indices)
993    ///   remain the lower-level building blocks that the executor will use.
994    ///
995    /// # Arguments
996    /// * `target` - Desired final topology and site grouping.
997    /// * `options` - Phase-specific options for split, transport, and optional
998    ///   final truncation.
999    ///
1000    /// # Returns
1001    /// A new `TreeTN` with the target node naming and target site-index
1002    /// network.
1003    ///
1004    /// # Errors
1005    /// Returns an error when the target is structurally incompatible with the
1006    /// current network, or when the requested restructure still needs the
1007    /// staged planner paths for mixed split/move/fuse execution.
1008    ///
1009    /// # Examples
1010    ///
1011    /// ```
1012    /// use std::collections::HashSet;
1013    ///
1014    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
1015    /// use tensor4all_treetn::{RestructureOptions, SiteIndexNetwork, TreeTN};
1016    ///
1017    /// # fn main() -> anyhow::Result<()> {
1018    /// let left = DynIndex::new_dyn(2);
1019    /// let right = DynIndex::new_dyn(2);
1020    /// let bond = DynIndex::new_dyn(1);
1021    /// let t0 = TensorDynLen::from_dense(vec![left.clone(), bond.clone()], vec![1.0, 0.0])?;
1022    /// let t1 = TensorDynLen::from_dense(vec![bond, right.clone()], vec![1.0, 0.0])?;
1023    /// let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1024    ///     vec![t0, t1],
1025    ///     vec!["A".to_string(), "B".to_string()],
1026    /// )?;
1027    ///
1028    /// let mut target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1029    /// assert!(target
1030    ///     .add_node("AB".to_string(), HashSet::from([left.clone(), right.clone()]))
1031    ///     .is_ok());
1032    ///
1033    /// let result = treetn.restructure_to(&target, &RestructureOptions::default())?;
1034    ///
1035    /// assert_eq!(result.node_count(), 1);
1036    /// let dense = result.contract_to_tensor()?;
1037    /// let expected = treetn.contract_to_tensor()?;
1038    /// assert!((&dense - &expected).maxabs() < 1e-12);
1039    /// # Ok::<(), anyhow::Error>(())
1040    /// # }
1041    /// ```
1042    pub fn restructure_to<TargetV>(
1043        &self,
1044        target: &SiteIndexNetwork<TargetV, T::Index>,
1045        options: &RestructureOptions,
1046    ) -> Result<TreeTN<T, TargetV>>
1047    where
1048        TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
1049        <T::Index as IndexLike>::Id:
1050            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
1051    {
1052        let plan = build_plan::<T, V, TargetV>(self.site_index_network(), target)
1053            .context("restructure_to: build plan")?;
1054        execute_plan(self, plan, target, options).context("restructure_to: execute plan")
1055    }
1056}
1057
1058#[cfg(test)]
1059mod tests {
1060    use std::collections::HashSet;
1061
1062    use tensor4all_core::{DynIndex, IndexLike, TensorDynLen};
1063
1064    use super::*;
1065
1066    type FourSiteChainCase = (
1067        TreeTN<TensorDynLen, String>,
1068        DynIndex,
1069        DynIndex,
1070        DynIndex,
1071        DynIndex,
1072    );
1073
1074    fn two_node_chain() -> anyhow::Result<(TreeTN<TensorDynLen, String>, DynIndex, DynIndex)> {
1075        let left = DynIndex::new_dyn(2);
1076        let right = DynIndex::new_dyn(2);
1077        let bond = DynIndex::new_dyn(1);
1078        let t0 = TensorDynLen::from_dense(vec![left.clone(), bond.clone()], vec![1.0, 0.0])?;
1079        let t1 = TensorDynLen::from_dense(vec![bond, right.clone()], vec![1.0, 0.0])?;
1080        let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1081            vec![t0, t1],
1082            vec!["A".to_string(), "B".to_string()],
1083        )?;
1084        Ok((treetn, left, right))
1085    }
1086
1087    fn two_node_groups_of_two() -> anyhow::Result<FourSiteChainCase> {
1088        let x0 = DynIndex::new_dyn(2);
1089        let x1 = DynIndex::new_dyn(2);
1090        let y0 = DynIndex::new_dyn(2);
1091        let y1 = DynIndex::new_dyn(2);
1092        let bond = DynIndex::new_dyn(2);
1093        let left_tensor = TensorDynLen::from_dense(
1094            vec![x0.clone(), x1.clone(), bond.clone()],
1095            vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
1096        )?;
1097        let right_tensor = TensorDynLen::from_dense(
1098            vec![bond, y0.clone(), y1.clone()],
1099            vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
1100        )?;
1101        let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1102            vec![left_tensor, right_tensor],
1103            vec!["Left".to_string(), "Right".to_string()],
1104        )?;
1105        Ok((treetn, x0, x1, y0, y1))
1106    }
1107
1108    fn three_node_chain_for_swap() -> anyhow::Result<FourSiteChainCase> {
1109        let s0 = DynIndex::new_dyn(2);
1110        let s1 = DynIndex::new_dyn(2);
1111        let s2 = DynIndex::new_dyn(2);
1112        let s3 = DynIndex::new_dyn(2);
1113        let b01 = DynIndex::new_dyn(2);
1114        let b12 = DynIndex::new_dyn(2);
1115        let t0 = TensorDynLen::from_dense(
1116            vec![s0.clone(), s1.clone(), b01.clone()],
1117            vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0],
1118        )?;
1119        let t1 = TensorDynLen::from_dense(
1120            vec![b01.clone(), s2.clone(), b12.clone()],
1121            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
1122        )?;
1123        let t2 = TensorDynLen::from_dense(vec![b12, s3.clone()], vec![1.0, 2.0, 3.0, 4.0])?;
1124        let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1125            vec![t0, t1, t2],
1126            vec!["A".to_string(), "B".to_string(), "C".to_string()],
1127        )?;
1128        Ok((treetn, s0, s1, s2, s3))
1129    }
1130
1131    fn four_node_interleaved_chain() -> anyhow::Result<FourSiteChainCase> {
1132        let x0 = DynIndex::new_dyn(2);
1133        let x1 = DynIndex::new_dyn(2);
1134        let y0 = DynIndex::new_dyn(2);
1135        let y1 = DynIndex::new_dyn(2);
1136        let b01 = DynIndex::new_dyn(2);
1137        let b12 = DynIndex::new_dyn(2);
1138        let b23 = DynIndex::new_dyn(2);
1139        let t0 = TensorDynLen::from_dense(vec![x0.clone(), b01.clone()], vec![1.0, 0.0, 0.0, 1.0])?;
1140        let t1 = TensorDynLen::from_dense(
1141            vec![b01.clone(), x1.clone(), b12.clone()],
1142            vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
1143        )?;
1144        let t2 = TensorDynLen::from_dense(
1145            vec![b12.clone(), y0.clone(), b23.clone()],
1146            vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
1147        )?;
1148        let t3 = TensorDynLen::from_dense(vec![b23, y1.clone()], vec![1.0, 0.0, 0.0, 1.0])?;
1149        let treetn = TreeTN::<TensorDynLen, String>::from_tensors(
1150            vec![t0, t1, t2, t3],
1151            vec![
1152                "0".to_string(),
1153                "1".to_string(),
1154                "2".to_string(),
1155                "3".to_string(),
1156            ],
1157        )?;
1158        Ok((treetn, x0, x1, y0, y1))
1159    }
1160
1161    #[test]
1162    fn test_restructure_to_fuse_only_matches_target_structure() -> anyhow::Result<()> {
1163        let (treetn, left, right) = two_node_chain()?;
1164
1165        let mut target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1166        target
1167            .add_node("AB".to_string(), HashSet::from([left, right]))
1168            .map_err(anyhow::Error::msg)?;
1169
1170        let result = treetn.restructure_to(&target, &RestructureOptions::default())?;
1171        let dense_expected = treetn.contract_to_tensor()?;
1172        let dense_actual = result.contract_to_tensor()?;
1173
1174        assert_eq!(result.node_count(), 1);
1175        assert_eq!(result.site_index_network().node_count(), 1);
1176        assert!((&dense_actual - &dense_expected).maxabs() < 1e-12);
1177
1178        Ok(())
1179    }
1180
1181    #[test]
1182    fn test_restructure_to_split_only_matches_target_structure() -> anyhow::Result<()> {
1183        let (treetn, left, right) = two_node_chain()?;
1184
1185        let mut fused_target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1186        fused_target
1187            .add_node(
1188                "AB".to_string(),
1189                HashSet::from([left.clone(), right.clone()]),
1190            )
1191            .map_err(anyhow::Error::msg)?;
1192        let fused = treetn.restructure_to(&fused_target, &RestructureOptions::default())?;
1193
1194        let mut split_target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1195        split_target
1196            .add_node("Left".to_string(), HashSet::from([left]))
1197            .map_err(anyhow::Error::msg)?;
1198        split_target
1199            .add_node("Right".to_string(), HashSet::from([right]))
1200            .map_err(anyhow::Error::msg)?;
1201        split_target
1202            .add_edge(&"Left".to_string(), &"Right".to_string())
1203            .map_err(anyhow::Error::msg)?;
1204
1205        let result = fused.restructure_to(&split_target, &RestructureOptions::default())?;
1206        let dense_expected = fused.contract_to_tensor()?;
1207        let dense_actual = result.contract_to_tensor()?;
1208
1209        assert_eq!(result.node_count(), 2);
1210        assert!((&dense_actual - &dense_expected).maxabs() < 1e-12);
1211
1212        Ok(())
1213    }
1214
1215    #[test]
1216    fn test_restructure_to_swap_only_matches_target_structure() -> anyhow::Result<()> {
1217        let (treetn, s0, s1, s2, s3) = three_node_chain_for_swap()?;
1218
1219        let mut target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1220        target
1221            .add_node("X".to_string(), HashSet::from([s0.clone()]))
1222            .map_err(anyhow::Error::msg)?;
1223        target
1224            .add_node("Y".to_string(), HashSet::from([s1.clone(), s2.clone()]))
1225            .map_err(anyhow::Error::msg)?;
1226        target
1227            .add_node("Z".to_string(), HashSet::from([s3.clone()]))
1228            .map_err(anyhow::Error::msg)?;
1229        target
1230            .add_edge(&"X".to_string(), &"Y".to_string())
1231            .map_err(anyhow::Error::msg)?;
1232        target
1233            .add_edge(&"Y".to_string(), &"Z".to_string())
1234            .map_err(anyhow::Error::msg)?;
1235
1236        let result = treetn.restructure_to(&target, &RestructureOptions::default())?;
1237        let dense_expected = treetn.contract_to_tensor()?;
1238        let dense_actual = result.contract_to_tensor()?;
1239
1240        assert_eq!(
1241            result
1242                .site_index_network()
1243                .find_node_by_index_id(s0.id())
1244                .map(|name| name.as_str()),
1245            Some("X")
1246        );
1247        assert_eq!(
1248            result
1249                .site_index_network()
1250                .find_node_by_index_id(s1.id())
1251                .map(|name| name.as_str()),
1252            Some("Y")
1253        );
1254        assert_eq!(
1255            result
1256                .site_index_network()
1257                .find_node_by_index_id(s2.id())
1258                .map(|name| name.as_str()),
1259            Some("Y")
1260        );
1261        assert_eq!(
1262            result
1263                .site_index_network()
1264                .find_node_by_index_id(s3.id())
1265                .map(|name| name.as_str()),
1266            Some("Z")
1267        );
1268        assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);
1269
1270        Ok(())
1271    }
1272
1273    #[test]
1274    fn test_restructure_to_split_then_fuse_mixed_case() -> anyhow::Result<()> {
1275        let (treetn, x0, x1, y0, y1) = two_node_groups_of_two()?;
1276
1277        let mut target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1278        target
1279            .add_node("X".to_string(), HashSet::from([x0.clone()]))
1280            .map_err(anyhow::Error::msg)?;
1281        target
1282            .add_node("Y".to_string(), HashSet::from([x1.clone(), y0.clone()]))
1283            .map_err(anyhow::Error::msg)?;
1284        target
1285            .add_node("Z".to_string(), HashSet::from([y1.clone()]))
1286            .map_err(anyhow::Error::msg)?;
1287        target
1288            .add_edge(&"X".to_string(), &"Y".to_string())
1289            .map_err(anyhow::Error::msg)?;
1290        target
1291            .add_edge(&"Y".to_string(), &"Z".to_string())
1292            .map_err(anyhow::Error::msg)?;
1293
1294        let result = treetn.restructure_to(&target, &RestructureOptions::default())?;
1295        let dense_expected = treetn.contract_to_tensor()?;
1296        let dense_actual = result.contract_to_tensor()?;
1297
1298        assert_eq!(result.node_count(), 3);
1299        assert_eq!(result.edge_count(), 2);
1300        assert!(result
1301            .site_index_network()
1302            .share_equivalent_site_index_network(&target));
1303        assert_eq!(
1304            result
1305                .site_index_network()
1306                .find_node_by_index_id(x0.id())
1307                .map(|name| name.as_str()),
1308            Some("X")
1309        );
1310        assert_eq!(
1311            result
1312                .site_index_network()
1313                .find_node_by_index_id(x1.id())
1314                .map(|name| name.as_str()),
1315            Some("Y")
1316        );
1317        assert_eq!(
1318            result
1319                .site_index_network()
1320                .find_node_by_index_id(y0.id())
1321                .map(|name| name.as_str()),
1322            Some("Y")
1323        );
1324        assert_eq!(
1325            result
1326                .site_index_network()
1327                .find_node_by_index_id(y1.id())
1328                .map(|name| name.as_str()),
1329            Some("Z")
1330        );
1331        assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);
1332
1333        Ok(())
1334    }
1335
1336    #[test]
1337    fn test_restructure_to_swap_then_fuse_mixed_case() -> anyhow::Result<()> {
1338        let (treetn, x0, x1, y0, y1) = four_node_interleaved_chain()?;
1339
1340        let mut target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1341        target
1342            .add_node("X".to_string(), HashSet::from([x0.clone(), y0.clone()]))
1343            .map_err(anyhow::Error::msg)?;
1344        target
1345            .add_node("Y".to_string(), HashSet::from([x1.clone(), y1.clone()]))
1346            .map_err(anyhow::Error::msg)?;
1347        target
1348            .add_edge(&"X".to_string(), &"Y".to_string())
1349            .map_err(anyhow::Error::msg)?;
1350
1351        let result = treetn.restructure_to(&target, &RestructureOptions::default())?;
1352        let dense_expected = treetn.contract_to_tensor()?;
1353        let dense_actual = result.contract_to_tensor()?;
1354
1355        assert_eq!(result.node_count(), 2);
1356        assert_eq!(result.edge_count(), 1);
1357        assert!(result
1358            .site_index_network()
1359            .share_equivalent_site_index_network(&target));
1360        assert_eq!(
1361            result
1362                .site_index_network()
1363                .find_node_by_index_id(x0.id())
1364                .map(|name| name.as_str()),
1365            Some("X")
1366        );
1367        assert_eq!(
1368            result
1369                .site_index_network()
1370                .find_node_by_index_id(y0.id())
1371                .map(|name| name.as_str()),
1372            Some("X")
1373        );
1374        assert_eq!(
1375            result
1376                .site_index_network()
1377                .find_node_by_index_id(x1.id())
1378                .map(|name| name.as_str()),
1379            Some("Y")
1380        );
1381        assert_eq!(
1382            result
1383                .site_index_network()
1384                .find_node_by_index_id(y1.id())
1385                .map(|name| name.as_str()),
1386            Some("Y")
1387        );
1388        assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);
1389
1390        Ok(())
1391    }
1392
1393    #[test]
1394    fn test_restructure_to_two_node_swap_only_cross_pairing() -> anyhow::Result<()> {
1395        let (treetn, x0, x1, y0, y1) = two_node_groups_of_two()?;
1396
1397        let mut target: SiteIndexNetwork<String, DynIndex> = SiteIndexNetwork::new();
1398        target
1399            .add_node("X".to_string(), HashSet::from([x0.clone(), y0.clone()]))
1400            .map_err(anyhow::Error::msg)?;
1401        target
1402            .add_node("Y".to_string(), HashSet::from([x1.clone(), y1.clone()]))
1403            .map_err(anyhow::Error::msg)?;
1404        target
1405            .add_edge(&"X".to_string(), &"Y".to_string())
1406            .map_err(anyhow::Error::msg)?;
1407
1408        let result = treetn.restructure_to(&target, &RestructureOptions::default())?;
1409        let dense_expected = treetn.contract_to_tensor()?;
1410        let dense_actual = result.contract_to_tensor()?;
1411
1412        assert_eq!(result.node_count(), 2);
1413        assert_eq!(result.edge_count(), 1);
1414        assert!(result
1415            .site_index_network()
1416            .share_equivalent_site_index_network(&target));
1417        assert_eq!(
1418            result
1419                .site_index_network()
1420                .find_node_by_index_id(x0.id())
1421                .map(|name| name.as_str()),
1422            Some("X")
1423        );
1424        assert_eq!(
1425            result
1426                .site_index_network()
1427                .find_node_by_index_id(y0.id())
1428                .map(|name| name.as_str()),
1429            Some("X")
1430        );
1431        assert_eq!(
1432            result
1433                .site_index_network()
1434                .find_node_by_index_id(x1.id())
1435                .map(|name| name.as_str()),
1436            Some("Y")
1437        );
1438        assert_eq!(
1439            result
1440                .site_index_network()
1441                .find_node_by_index_id(y1.id())
1442                .map(|name| name.as_str()),
1443            Some("Y")
1444        );
1445        assert!((&dense_actual - &dense_expected).maxabs() < 1e-10);
1446
1447        Ok(())
1448    }
1449}