Skip to main content

tensor4all_treetn/treetn/
swap.rs

1//! Site index swap: reorder which node holds which site index.
2//!
3//! Implements swapping site indices between adjacent nodes along the tree
4//! so that the network reaches a target assignment (index id -> node name).
5
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::hash::Hash;
8
9use anyhow::{Context, Result};
10use petgraph::stable_graph::NodeIndex;
11
12use tensor4all_core::{FactorizeOptions, FactorizeResult, IndexLike, TensorLike};
13
14use crate::node_name_network::NodeNameNetwork;
15
16use super::{localupdate::LocalUpdateSweepPlan, TreeTN};
17
18// ============================================================================
19// Factorize with trivial-bond handling
20// ============================================================================
21
22/// Factorize a tensor into left and right parts connected by a bond index.
23///
24/// Extends [`TensorLike::factorize`] to handle degenerate cases where all
25/// indices go to one side (empty `left_inds` or `left_inds == all_inds`).
26/// For these cases a dimension-1 trivial bond is created so that
27/// `contract(left, right)` recovers the input tensor exactly.
28///
29/// With `Canonical::Left` (the only mode used by swap):
30/// - **Normal case**: delegates to `TensorLike::factorize`.
31/// - **Empty `left_inds`**: `left = [1]` (dim-1 scalar isometry),
32///   `right = tensor ⊗ [1]` (acquires the trivial bond).
33/// - **Full `left_inds`**: `left = (tensor ⊗ [1]) / ‖tensor‖`,
34///   `right = [‖tensor‖]` (norm on the right side, left is isometric).
35pub(crate) fn factorize_or_trivial<T>(
36    tensor: &T,
37    left_inds: &[T::Index],
38    all_inds: &[T::Index],
39    factorize_options: &FactorizeOptions,
40) -> anyhow::Result<FactorizeResult<T>>
41where
42    T: TensorLike,
43    <T::Index as IndexLike>::Id: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
44{
45    if left_inds.is_empty() {
46        // All indices go to the right side.
47        let bond = <T::Index as IndexLike>::create_dummy_link_pair().0;
48        let left = T::onehot(&[(bond.clone(), 0)])
49            .map_err(|e| anyhow::anyhow!("factorize_or_trivial: left onehot: {}", e))?;
50        let right_bond = T::onehot(&[(bond.clone(), 0)])
51            .map_err(|e| anyhow::anyhow!("factorize_or_trivial: right onehot: {}", e))?;
52        let right = tensor
53            .outer_product(&right_bond)
54            .context("factorize_or_trivial: right outer_product")?;
55        return Ok(FactorizeResult {
56            left,
57            right,
58            bond_index: bond,
59            singular_values: None,
60            rank: 1,
61        });
62    }
63
64    if left_inds.len() == all_inds.len() {
65        // All indices go to the left side.
66        let bond = <T::Index as IndexLike>::create_dummy_link_pair().0;
67        let left_bond = T::onehot(&[(bond.clone(), 0)])
68            .map_err(|e| anyhow::anyhow!("factorize_or_trivial: left onehot: {}", e))?;
69        let mut left = tensor
70            .outer_product(&left_bond)
71            .context("factorize_or_trivial: left outer_product")?;
72        let mut right = T::onehot(&[(bond.clone(), 0)])
73            .map_err(|e| anyhow::anyhow!("factorize_or_trivial: right onehot: {}", e))?;
74        let left_norm = left.norm();
75        if left_norm > 0.0 {
76            left = left
77                .scale(tensor4all_core::AnyScalar::new_real(1.0 / left_norm))
78                .context("factorize_or_trivial: normalize left")?;
79            right = right
80                .scale(tensor4all_core::AnyScalar::new_real(left_norm))
81                .context("factorize_or_trivial: scale right")?;
82        }
83        return Ok(FactorizeResult {
84            left,
85            right,
86            bond_index: bond,
87            singular_values: None,
88            rank: 1,
89        });
90    }
91
92    // Normal case: delegate to TensorLike::factorize
93    tensor
94        .factorize(left_inds, factorize_options)
95        .map_err(|e| anyhow::anyhow!("factorize_or_trivial: factorize: {}", e))
96}
97
98// ============================================================================
99// SwapOptions
100// ============================================================================
101
/// Truncation settings for the SVDs performed while swapping site indices.
///
/// With both fields left at `None` (the [`Default`]), bond dimensions may grow
/// as needed and the swap reproduces the network exactly. Setting either limit
/// can truncate bonds and therefore introduce an approximation error.
#[derive(Clone, Debug, Default)]
pub struct SwapOptions {
    /// Upper bound on the bond dimension kept after each SVD; `None` leaves it unbounded.
    pub max_rank: Option<usize>,
    /// Relative cutoff for discarding singular values; `None` disables truncation by tolerance.
    pub rtol: Option<f64>,
}
114
115// ============================================================================
116// ScheduledSwapStep
117// ============================================================================
118
/// One local two-site update inside a precomputed swap schedule.
///
/// A step names the directed sweep edge `(node_a, node_b)` it acts on. It also
/// records the route, if any, along which the canonical center must be moved
/// before the update. Finally, it gives the split of site index IDs between
/// the two endpoints that the local factorization has to produce.
///
/// See also:
/// - [`SwapSchedule`], which holds the ordered list of steps.
/// - [`SwapOptions`], which only affects truncation during execution and plays
///   no role when a schedule is built.
///
/// # Examples
///
/// ```
/// use std::collections::HashSet;
///
/// use tensor4all_treetn::ScheduledSwapStep;
///
/// let step = ScheduledSwapStep {
///     transport_path: vec!["L0".to_string(), "C".to_string()],
///     node_a: "C".to_string(),
///     node_b: "L1".to_string(),
///     a_side_sites: HashSet::from(["s1".to_string()]),
///     b_side_sites: HashSet::from(["s0".to_string()]),
/// };
///
/// // The center travels from L0 to C before the update on edge (C, L1).
/// assert_eq!(step.transport_path.first().map(String::as_str), Some("L0"));
/// assert_eq!(step.transport_path.last(), Some(&step.node_a));
/// assert!(step.a_side_sites.contains("s1"));
/// assert!(step.b_side_sites.contains("s0"));
/// ```
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScheduledSwapStep<V, Id>
where
    Id: Eq + Hash,
{
    /// Route the canonical center takes before this step runs.
    ///
    /// Empty if the center already sits on `node_a` or `node_b`. Otherwise it
    /// reads `[current_center, ..., node_a]`.
    pub transport_path: Vec<V>,
    /// First endpoint of the directed sweep edge.
    pub node_a: V,
    /// Second endpoint of the directed sweep edge.
    pub node_b: V,
    /// IDs of the site indices that must be on `node_a`'s side once the step is done.
    pub a_side_sites: HashSet<Id>,
    /// IDs of the site indices that must be on `node_b`'s side once the step is done.
    pub b_side_sites: HashSet<Id>,
}
167
168// ============================================================================
169// SwapSchedule
170// ============================================================================
171
172/// Pre-computed swap schedule for `swap_site_indices`.
173///
174/// The schedule is derived purely from graph structure plus current and target
175/// site assignments. It contains no tensor data and can therefore be built,
176/// inspected, and unit-tested without performing any tensor contractions.
177///
178/// Related types:
179/// - [`ScheduledSwapStep`] is one local two-site update in this schedule.
180/// - [`SwapOptions`] affects execution of the schedule, but not its contents.
181///
182/// # Examples
183///
184/// ```
185/// use std::collections::HashMap;
186///
187/// use tensor4all_treetn::{NodeNameNetwork, SwapSchedule};
188///
189/// let mut topology = NodeNameNetwork::new();
190/// topology.add_node("A".to_string()).unwrap();
191/// topology.add_node("B".to_string()).unwrap();
192/// topology.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
193///
194/// let current = HashMap::from([("s0".to_string(), "A".to_string())]);
195/// let target = HashMap::from([("s0".to_string(), "B".to_string())]);
196/// let root = "A".to_string();
197///
198/// let schedule = SwapSchedule::build(&topology, &current, &target, &root).unwrap();
199///
200/// assert_eq!(schedule.root, "A");
201/// assert_eq!(schedule.steps.len(), 1);
202/// assert_eq!(schedule.steps[0].node_a, "A");
203/// assert_eq!(schedule.steps[0].node_b, "B");
204/// ```
205#[derive(Debug, Clone)]
206pub struct SwapSchedule<V, Id>
207where
208    Id: Eq + Hash,
209{
210    /// Root used for the base Euler sweep and initial canonicalization.
211    pub root: V,
212    /// Fully expanded sequence of swap steps.
213    pub steps: Vec<ScheduledSwapStep<V, Id>>,
214}
215
216impl<V, Id> SwapSchedule<V, Id>
217where
218    V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
219    Id: Clone + Hash + Eq + std::fmt::Debug,
220{
221    /// Build a swap schedule from topology plus current and target assignments.
222    ///
223    /// The returned schedule is a pure graph computation. It simulates site
224    /// positions through repeated Euler-tour sweeps, emits only edges where at
225    /// least one targeted site index crosses, and records any required
226    /// canonical-center transport between non-adjacent emitted swap steps.
227    ///
228    /// # Arguments
229    /// * `topology` - Tree topology whose nodes are named by `V`.
230    /// * `current_assignment` - Current node for every site index ID in the network.
231    /// * `target_assignment` - Partial target map; indices not listed keep their current side.
232    /// * `root` - Sweep root and assumed initial canonical center.
233    ///
234    /// # Returns
235    /// A [`SwapSchedule`] containing the ordered local updates needed to realize `target_assignment`.
236    ///
237    /// # Errors
238    /// Returns an error if `root` is missing, an index ID in `target_assignment`
239    /// is unknown, a referenced node is missing from `topology`, no tree path
240    /// exists between required nodes, or the simulated sweeps fail to satisfy
241    /// the requested target assignment within the tree-diameter pass bound.
242    pub fn build(
243        topology: &NodeNameNetwork<V>,
244        current_assignment: &HashMap<Id, V>,
245        target_assignment: &HashMap<Id, V>,
246        root: &V,
247    ) -> Result<Self> {
248        if !topology.has_node(root) {
249            return Err(anyhow::anyhow!(
250                "SwapSchedule::build: root {:?} not in topology",
251                root
252            ));
253        }
254
255        for (index_id, current_node) in current_assignment {
256            if !topology.has_node(current_node) {
257                return Err(anyhow::anyhow!(
258                    "SwapSchedule::build: current node {:?} for index {:?} is not in the topology",
259                    current_node,
260                    index_id
261                ));
262            }
263        }
264
265        for (index_id, target_node) in target_assignment {
266            if !current_assignment.contains_key(index_id) {
267                return Err(anyhow::anyhow!(
268                    "SwapSchedule::build: target_assignment contains index id {:?} which is not in the network",
269                    index_id
270                ));
271            }
272            if !topology.has_node(target_node) {
273                return Err(anyhow::anyhow!(
274                    "SwapSchedule::build: target node {:?} for index {:?} is not in the topology",
275                    target_node,
276                    index_id
277                ));
278            }
279        }
280
281        let oracle = SubtreeOracle::new(topology, root)?;
282        let base_sweep = LocalUpdateSweepPlan::new(topology, root, 2)
283            .ok_or_else(|| anyhow::anyhow!("SwapSchedule::build: failed to build 2-site sweep"))?;
284        let max_passes = tree_diameter(topology)?;
285
286        let mut position = current_assignment.clone();
287        let mut center = root.clone();
288        let mut steps = Vec::new();
289
290        for _pass in 0..max_passes {
291            if positions_satisfy_targets(&position, target_assignment) {
292                break;
293            }
294
295            let mut any_moved_this_pass = false;
296
297            for sweep_step in base_sweep.iter() {
298                if sweep_step.nodes.len() != 2 {
299                    continue;
300                }
301
302                let node_a = sweep_step.nodes[0].clone();
303                let node_b = sweep_step.nodes[1].clone();
304
305                let mut a_side_sites = HashSet::new();
306                let mut b_side_sites = HashSet::new();
307                let mut any_crossing = false;
308                let mut any_site_on_edge = false;
309
310                for (index_id, current_node) in &position {
311                    if current_node != &node_a && current_node != &node_b {
312                        continue;
313                    }
314
315                    any_site_on_edge = true;
316
317                    if let Some(target_node) = target_assignment.get(index_id) {
318                        if oracle.is_target_on_a_side(&node_a, &node_b, target_node) {
319                            a_side_sites.insert(index_id.clone());
320                            if current_node == &node_b {
321                                any_crossing = true;
322                            }
323                        } else {
324                            b_side_sites.insert(index_id.clone());
325                            if current_node == &node_a {
326                                any_crossing = true;
327                            }
328                        }
329                    } else if current_node == &node_a {
330                        a_side_sites.insert(index_id.clone());
331                    } else {
332                        b_side_sites.insert(index_id.clone());
333                    }
334                }
335
336                if !any_site_on_edge || !any_crossing {
337                    continue;
338                }
339
340                let transport_path = if center == node_a || center == node_b {
341                    Vec::new()
342                } else {
343                    tree_path(topology, &center, &node_a)?
344                };
345
346                steps.push(ScheduledSwapStep {
347                    transport_path,
348                    node_a: node_a.clone(),
349                    node_b: node_b.clone(),
350                    a_side_sites: a_side_sites.clone(),
351                    b_side_sites: b_side_sites.clone(),
352                });
353
354                for index_id in &a_side_sites {
355                    position.insert(index_id.clone(), node_a.clone());
356                }
357                for index_id in &b_side_sites {
358                    position.insert(index_id.clone(), node_b.clone());
359                }
360
361                center = node_b;
362                any_moved_this_pass = true;
363            }
364
365            if !any_moved_this_pass {
366                break;
367            }
368        }
369
370        if !positions_satisfy_targets(&position, target_assignment) {
371            return Err(anyhow::anyhow!(
372                "SwapSchedule::build: did not converge within {} passes",
373                max_passes
374            ));
375        }
376
377        Ok(Self {
378            root: root.clone(),
379            steps,
380        })
381    }
382}
383
/// Returns `true` when every entry of `target_assignment` is already realized in `position`.
///
/// An index that is absent from `position` counts as unsatisfied. An empty
/// target map is trivially satisfied.
fn positions_satisfy_targets<V, Id>(
    position: &HashMap<Id, V>,
    target_assignment: &HashMap<Id, V>,
) -> bool
where
    V: Eq,
    Id: Hash + Eq,
{
    // A single target that is missing or placed elsewhere breaks the match.
    !target_assignment
        .iter()
        .any(|(site_id, wanted_node)| position.get(site_id) != Some(wanted_node))
}
398
399fn tree_path<V>(topology: &NodeNameNetwork<V>, from: &V, to: &V) -> Result<Vec<V>>
400where
401    V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
402{
403    let from_idx = topology
404        .node_index(from)
405        .ok_or_else(|| anyhow::anyhow!("tree_path: node {:?} not found", from))?;
406    let to_idx = topology
407        .node_index(to)
408        .ok_or_else(|| anyhow::anyhow!("tree_path: node {:?} not found", to))?;
409
410    topology
411        .path_between(from_idx, to_idx)
412        .ok_or_else(|| anyhow::anyhow!("tree_path: no path between {:?} and {:?}", from, to))?
413        .into_iter()
414        .map(|node_idx| {
415            topology
416                .node_name(node_idx)
417                .cloned()
418                .ok_or_else(|| anyhow::anyhow!("tree_path: node name not found"))
419        })
420        .collect()
421}
422
423fn tree_diameter<V>(topology: &NodeNameNetwork<V>) -> Result<usize>
424where
425    V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
426{
427    let mut node_indices = topology.graph().node_indices();
428    let Some(start) = node_indices.next() else {
429        return Ok(0);
430    };
431
432    let (farthest, _) = farthest_node(topology, start)?;
433    let (_, diameter) = farthest_node(topology, farthest)?;
434    Ok(diameter)
435}
436
437fn farthest_node<V>(topology: &NodeNameNetwork<V>, start: NodeIndex) -> Result<(NodeIndex, usize)>
438where
439    V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
440{
441    let graph = topology.graph();
442    let mut visited = HashSet::new();
443    let mut queue = VecDeque::from([(start, 0usize)]);
444    let mut farthest = (start, 0usize);
445
446    visited.insert(start);
447
448    while let Some((node, distance)) = queue.pop_front() {
449        if distance > farthest.1 {
450            farthest = (node, distance);
451        }
452
453        for neighbor in graph.neighbors(node) {
454            if visited.insert(neighbor) {
455                queue.push_back((neighbor, distance + 1));
456            }
457        }
458    }
459
460    if visited.len() != graph.node_count() {
461        return Err(anyhow::anyhow!(
462            "SwapSchedule::build: topology must be connected"
463        ));
464    }
465
466    Ok(farthest)
467}
468
469// ============================================================================
470// Helpers: current assignment
471// ============================================================================
472
473/// Build index id -> node name from a TreeTN (all site indices).
474pub(crate) fn current_site_assignment<T, V>(
475    treetn: &TreeTN<T, V>,
476) -> HashMap<<T::Index as IndexLike>::Id, V>
477where
478    T: TensorLike,
479    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
480{
481    let mut out: HashMap<<T::Index as IndexLike>::Id, V> = HashMap::new();
482    for node_name in treetn.node_names() {
483        if let Some(site_space) = treetn.site_space(&node_name) {
484            for idx in site_space {
485                out.insert(idx.id().to_owned(), node_name.clone());
486            }
487        }
488    }
489    out
490}
491
492// ============================================================================
493// SubtreeOracle
494// ============================================================================
495
496/// Pre-computed DFS timestamps enabling O(1) "which side of edge?" queries.
497///
498/// For an edge (A, B): `is_target_on_a_side(A, B, target)` returns true iff
499/// `target` is in the component containing A when the edge is removed.
500pub(crate) struct SubtreeOracle<V> {
501    in_time: HashMap<V, usize>,
502    out_time: HashMap<V, usize>,
503}
504
505impl<V> SubtreeOracle<V>
506where
507    V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
508{
509    /// Build from a tree topology rooted at `root`.
510    /// DFS entry/exit timestamps are computed iteratively.
511    pub(crate) fn new(topology: &NodeNameNetwork<V>, root: &V) -> Result<Self> {
512        let root_idx = topology
513            .node_index(root)
514            .ok_or_else(|| anyhow::anyhow!("SubtreeOracle: root {:?} not in topology", root))?;
515
516        let mut in_time: HashMap<V, usize> = HashMap::new();
517        let mut out_time: HashMap<V, usize> = HashMap::new();
518        let mut timer = 0usize;
519
520        // Stack: (node_idx, parent_idx, is_exit)
521        let mut stack: Vec<(NodeIndex, Option<NodeIndex>, bool)> = vec![(root_idx, None, false)];
522
523        while let Some((node_idx, parent_idx, is_exit)) = stack.pop() {
524            let name = topology
525                .node_name(node_idx)
526                .ok_or_else(|| anyhow::anyhow!("SubtreeOracle: node name not found"))?
527                .clone();
528            if is_exit {
529                out_time.insert(name, timer);
530                timer += 1;
531            } else {
532                in_time.insert(name, timer);
533                timer += 1;
534                stack.push((node_idx, parent_idx, true));
535                let graph = topology.graph();
536                for neighbor in graph.neighbors(node_idx) {
537                    if Some(neighbor) != parent_idx {
538                        stack.push((neighbor, Some(node_idx), false));
539                    }
540                }
541            }
542        }
543
544        Ok(Self { in_time, out_time })
545    }
546
547    /// Returns `true` iff `target` is on the A-side of edge (A, B).
548    ///
549    /// A-side = the connected component containing A after the (A,B) edge is removed.
550    pub(crate) fn is_target_on_a_side(&self, node_a: &V, node_b: &V, target: &V) -> bool {
551        if target == node_a {
552            return true;
553        }
554        if target == node_b {
555            return false;
556        }
557        let in_a = match self.in_time.get(node_a) {
558            Some(&t) => t,
559            None => return false,
560        };
561        let out_a = match self.out_time.get(node_a) {
562            Some(&t) => t,
563            None => return false,
564        };
565        let in_b = match self.in_time.get(node_b) {
566            Some(&t) => t,
567            None => return false,
568        };
569        let out_b = match self.out_time.get(node_b) {
570            Some(&t) => t,
571            None => return false,
572        };
573        let in_t = match self.in_time.get(target) {
574            Some(&t) => t,
575            None => return false,
576        };
577        let out_t = match self.out_time.get(target) {
578            Some(&t) => t,
579            None => return false,
580        };
581
582        if in_a <= in_b && out_b <= out_a {
583            !(in_b <= in_t && out_t <= out_b)
584        } else {
585            in_a <= in_t && out_t <= out_a
586        }
587    }
588}
589
590#[cfg(test)]
591mod tests;