1use std::collections::{HashMap, HashSet, VecDeque};
7use std::hash::Hash;
8
9use anyhow::{Context, Result};
10use petgraph::stable_graph::NodeIndex;
11
12use tensor4all_core::{FactorizeOptions, FactorizeResult, IndexLike, TensorLike};
13
14use crate::node_name_network::NodeNameNetwork;
15
16use super::{localupdate::LocalUpdateSweepPlan, TreeTN};
17
18pub(crate) fn factorize_or_trivial<T>(
36 tensor: &T,
37 left_inds: &[T::Index],
38 all_inds: &[T::Index],
39 factorize_options: &FactorizeOptions,
40) -> anyhow::Result<FactorizeResult<T>>
41where
42 T: TensorLike,
43 <T::Index as IndexLike>::Id: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
44{
45 if left_inds.is_empty() {
46 let bond = <T::Index as IndexLike>::create_dummy_link_pair().0;
48 let left = T::onehot(&[(bond.clone(), 0)])
49 .map_err(|e| anyhow::anyhow!("factorize_or_trivial: left onehot: {}", e))?;
50 let right_bond = T::onehot(&[(bond.clone(), 0)])
51 .map_err(|e| anyhow::anyhow!("factorize_or_trivial: right onehot: {}", e))?;
52 let right = tensor
53 .outer_product(&right_bond)
54 .context("factorize_or_trivial: right outer_product")?;
55 return Ok(FactorizeResult {
56 left,
57 right,
58 bond_index: bond,
59 singular_values: None,
60 rank: 1,
61 });
62 }
63
64 if left_inds.len() == all_inds.len() {
65 let bond = <T::Index as IndexLike>::create_dummy_link_pair().0;
67 let left_bond = T::onehot(&[(bond.clone(), 0)])
68 .map_err(|e| anyhow::anyhow!("factorize_or_trivial: left onehot: {}", e))?;
69 let mut left = tensor
70 .outer_product(&left_bond)
71 .context("factorize_or_trivial: left outer_product")?;
72 let mut right = T::onehot(&[(bond.clone(), 0)])
73 .map_err(|e| anyhow::anyhow!("factorize_or_trivial: right onehot: {}", e))?;
74 let left_norm = left.norm();
75 if left_norm > 0.0 {
76 left = left
77 .scale(tensor4all_core::AnyScalar::new_real(1.0 / left_norm))
78 .context("factorize_or_trivial: normalize left")?;
79 right = right
80 .scale(tensor4all_core::AnyScalar::new_real(left_norm))
81 .context("factorize_or_trivial: scale right")?;
82 }
83 return Ok(FactorizeResult {
84 left,
85 right,
86 bond_index: bond,
87 singular_values: None,
88 rank: 1,
89 });
90 }
91
92 tensor
94 .factorize(left_inds, factorize_options)
95 .map_err(|e| anyhow::anyhow!("factorize_or_trivial: factorize: {}", e))
96}
97
/// Options for swap-based site reordering.
///
/// NOTE(review): neither field is read in this section of the file. Presumably
/// they are forwarded as truncation settings to the 2-site factorizations that
/// execute a [`SwapSchedule`] — confirm against the caller.
#[derive(Debug, Clone, Default)]
pub struct SwapOptions {
    /// Optional upper bound on the bond rank. Presumably `None` means "no
    /// explicit cap" — confirm.
    pub max_rank: Option<usize>,

    /// Optional relative truncation tolerance. Presumably `None` means "use the
    /// backend default" — confirm.
    pub rtol: Option<f64>,
}
114
/// A single 2-site step of a [`SwapSchedule`].
///
/// The step acts on the pair `(node_a, node_b)` taken from a 2-site sweep step.
/// Every site index that sits on either node before the step is listed in
/// exactly one of `a_side_sites` / `b_side_sites`, according to where it must
/// end up.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledSwapStep<V, Id>
where
    Id: Eq + Hash,
{
    /// Nodes to walk through before this step. The walk runs from the node
    /// reached by the previous recorded step (the schedule root for the first
    /// step) to `node_a`, as produced by `tree_path`. The path is empty when
    /// that node already is `node_a` or `node_b`. NOTE(review): whether both
    /// endpoints are included depends on `NodeNameNetwork::path_between` —
    /// confirm.
    pub transport_path: Vec<V>,

    /// First node of the step (the sweep step's `nodes[0]`).
    pub node_a: V,

    /// Second node of the step (the sweep step's `nodes[1]`). The planner
    /// treats it as the position reached after this step.
    pub node_b: V,

    /// Site index ids that sit on `node_a` after the step.
    pub a_side_sites: HashSet<Id>,

    /// Site index ids that sit on `node_b` after the step.
    pub b_side_sites: HashSet<Id>,
}
167
/// An ordered plan of 2-site swap steps that moves site indices to their
/// target nodes.
///
/// Produced by [`SwapSchedule::build`].
#[derive(Debug, Clone)]
pub struct SwapSchedule<V, Id>
where
    Id: Eq + Hash,
{
    /// Root the 2-site sweep was built from. The first step's
    /// `transport_path` starts here.
    pub root: V,

    /// Steps in the order they are meant to be executed.
    pub steps: Vec<ScheduledSwapStep<V, Id>>,
}
215
216impl<V, Id> SwapSchedule<V, Id>
217where
218 V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
219 Id: Clone + Hash + Eq + std::fmt::Debug,
220{
221 pub fn build(
243 topology: &NodeNameNetwork<V>,
244 current_assignment: &HashMap<Id, V>,
245 target_assignment: &HashMap<Id, V>,
246 root: &V,
247 ) -> Result<Self> {
248 if !topology.has_node(root) {
249 return Err(anyhow::anyhow!(
250 "SwapSchedule::build: root {:?} not in topology",
251 root
252 ));
253 }
254
255 for (index_id, current_node) in current_assignment {
256 if !topology.has_node(current_node) {
257 return Err(anyhow::anyhow!(
258 "SwapSchedule::build: current node {:?} for index {:?} is not in the topology",
259 current_node,
260 index_id
261 ));
262 }
263 }
264
265 for (index_id, target_node) in target_assignment {
266 if !current_assignment.contains_key(index_id) {
267 return Err(anyhow::anyhow!(
268 "SwapSchedule::build: target_assignment contains index id {:?} which is not in the network",
269 index_id
270 ));
271 }
272 if !topology.has_node(target_node) {
273 return Err(anyhow::anyhow!(
274 "SwapSchedule::build: target node {:?} for index {:?} is not in the topology",
275 target_node,
276 index_id
277 ));
278 }
279 }
280
281 let oracle = SubtreeOracle::new(topology, root)?;
282 let base_sweep = LocalUpdateSweepPlan::new(topology, root, 2)
283 .ok_or_else(|| anyhow::anyhow!("SwapSchedule::build: failed to build 2-site sweep"))?;
284 let max_passes = tree_diameter(topology)?;
285
286 let mut position = current_assignment.clone();
287 let mut center = root.clone();
288 let mut steps = Vec::new();
289
290 for _pass in 0..max_passes {
291 if positions_satisfy_targets(&position, target_assignment) {
292 break;
293 }
294
295 let mut any_moved_this_pass = false;
296
297 for sweep_step in base_sweep.iter() {
298 if sweep_step.nodes.len() != 2 {
299 continue;
300 }
301
302 let node_a = sweep_step.nodes[0].clone();
303 let node_b = sweep_step.nodes[1].clone();
304
305 let mut a_side_sites = HashSet::new();
306 let mut b_side_sites = HashSet::new();
307 let mut any_crossing = false;
308 let mut any_site_on_edge = false;
309
310 for (index_id, current_node) in &position {
311 if current_node != &node_a && current_node != &node_b {
312 continue;
313 }
314
315 any_site_on_edge = true;
316
317 if let Some(target_node) = target_assignment.get(index_id) {
318 if oracle.is_target_on_a_side(&node_a, &node_b, target_node) {
319 a_side_sites.insert(index_id.clone());
320 if current_node == &node_b {
321 any_crossing = true;
322 }
323 } else {
324 b_side_sites.insert(index_id.clone());
325 if current_node == &node_a {
326 any_crossing = true;
327 }
328 }
329 } else if current_node == &node_a {
330 a_side_sites.insert(index_id.clone());
331 } else {
332 b_side_sites.insert(index_id.clone());
333 }
334 }
335
336 if !any_site_on_edge || !any_crossing {
337 continue;
338 }
339
340 let transport_path = if center == node_a || center == node_b {
341 Vec::new()
342 } else {
343 tree_path(topology, ¢er, &node_a)?
344 };
345
346 steps.push(ScheduledSwapStep {
347 transport_path,
348 node_a: node_a.clone(),
349 node_b: node_b.clone(),
350 a_side_sites: a_side_sites.clone(),
351 b_side_sites: b_side_sites.clone(),
352 });
353
354 for index_id in &a_side_sites {
355 position.insert(index_id.clone(), node_a.clone());
356 }
357 for index_id in &b_side_sites {
358 position.insert(index_id.clone(), node_b.clone());
359 }
360
361 center = node_b;
362 any_moved_this_pass = true;
363 }
364
365 if !any_moved_this_pass {
366 break;
367 }
368 }
369
370 if !positions_satisfy_targets(&position, target_assignment) {
371 return Err(anyhow::anyhow!(
372 "SwapSchedule::build: did not converge within {} passes",
373 max_passes
374 ));
375 }
376
377 Ok(Self {
378 root: root.clone(),
379 steps,
380 })
381 }
382}
383
/// Returns `true` when every index id in `target_assignment` is currently
/// placed on its target node according to `position`.
///
/// An id that is missing from `position` counts as not satisfied. Ids that
/// appear only in `position` are ignored.
fn positions_satisfy_targets<V, Id>(
    position: &HashMap<Id, V>,
    target_assignment: &HashMap<Id, V>,
) -> bool
where
    V: Eq,
    Id: Hash + Eq,
{
    // A single missing or misplaced index is enough to fail.
    !target_assignment
        .iter()
        .any(|(index_id, wanted_node)| position.get(index_id) != Some(wanted_node))
}
398
399fn tree_path<V>(topology: &NodeNameNetwork<V>, from: &V, to: &V) -> Result<Vec<V>>
400where
401 V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
402{
403 let from_idx = topology
404 .node_index(from)
405 .ok_or_else(|| anyhow::anyhow!("tree_path: node {:?} not found", from))?;
406 let to_idx = topology
407 .node_index(to)
408 .ok_or_else(|| anyhow::anyhow!("tree_path: node {:?} not found", to))?;
409
410 topology
411 .path_between(from_idx, to_idx)
412 .ok_or_else(|| anyhow::anyhow!("tree_path: no path between {:?} and {:?}", from, to))?
413 .into_iter()
414 .map(|node_idx| {
415 topology
416 .node_name(node_idx)
417 .cloned()
418 .ok_or_else(|| anyhow::anyhow!("tree_path: node name not found"))
419 })
420 .collect()
421}
422
423fn tree_diameter<V>(topology: &NodeNameNetwork<V>) -> Result<usize>
424where
425 V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
426{
427 let mut node_indices = topology.graph().node_indices();
428 let Some(start) = node_indices.next() else {
429 return Ok(0);
430 };
431
432 let (farthest, _) = farthest_node(topology, start)?;
433 let (_, diameter) = farthest_node(topology, farthest)?;
434 Ok(diameter)
435}
436
437fn farthest_node<V>(topology: &NodeNameNetwork<V>, start: NodeIndex) -> Result<(NodeIndex, usize)>
438where
439 V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
440{
441 let graph = topology.graph();
442 let mut visited = HashSet::new();
443 let mut queue = VecDeque::from([(start, 0usize)]);
444 let mut farthest = (start, 0usize);
445
446 visited.insert(start);
447
448 while let Some((node, distance)) = queue.pop_front() {
449 if distance > farthest.1 {
450 farthest = (node, distance);
451 }
452
453 for neighbor in graph.neighbors(node) {
454 if visited.insert(neighbor) {
455 queue.push_back((neighbor, distance + 1));
456 }
457 }
458 }
459
460 if visited.len() != graph.node_count() {
461 return Err(anyhow::anyhow!(
462 "SwapSchedule::build: topology must be connected"
463 ));
464 }
465
466 Ok(farthest)
467}
468
469pub(crate) fn current_site_assignment<T, V>(
475 treetn: &TreeTN<T, V>,
476) -> HashMap<<T::Index as IndexLike>::Id, V>
477where
478 T: TensorLike,
479 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
480{
481 let mut out: HashMap<<T::Index as IndexLike>::Id, V> = HashMap::new();
482 for node_name in treetn.node_names() {
483 if let Some(site_space) = treetn.site_space(&node_name) {
484 for idx in site_space {
485 out.insert(idx.id().to_owned(), node_name.clone());
486 }
487 }
488 }
489 out
490}
491
/// Answers "on which side of the tree edge `(a, b)` does node `t` lie?" using
/// DFS entry/exit timestamps computed once from a fixed root.
///
/// Queries are hash-map lookups plus integer comparisons.
pub(crate) struct SubtreeOracle<V> {
    /// DFS entry timestamp of every node reached from the root.
    in_time: HashMap<V, usize>,
    /// DFS exit timestamp of every node reached from the root. Node `u` is in
    /// the subtree of `v` iff `in_time[v] <= in_time[u]` and
    /// `out_time[u] <= out_time[v]`.
    out_time: HashMap<V, usize>,
}
504
505impl<V> SubtreeOracle<V>
506where
507 V: Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
508{
509 pub(crate) fn new(topology: &NodeNameNetwork<V>, root: &V) -> Result<Self> {
512 let root_idx = topology
513 .node_index(root)
514 .ok_or_else(|| anyhow::anyhow!("SubtreeOracle: root {:?} not in topology", root))?;
515
516 let mut in_time: HashMap<V, usize> = HashMap::new();
517 let mut out_time: HashMap<V, usize> = HashMap::new();
518 let mut timer = 0usize;
519
520 let mut stack: Vec<(NodeIndex, Option<NodeIndex>, bool)> = vec![(root_idx, None, false)];
522
523 while let Some((node_idx, parent_idx, is_exit)) = stack.pop() {
524 let name = topology
525 .node_name(node_idx)
526 .ok_or_else(|| anyhow::anyhow!("SubtreeOracle: node name not found"))?
527 .clone();
528 if is_exit {
529 out_time.insert(name, timer);
530 timer += 1;
531 } else {
532 in_time.insert(name, timer);
533 timer += 1;
534 stack.push((node_idx, parent_idx, true));
535 let graph = topology.graph();
536 for neighbor in graph.neighbors(node_idx) {
537 if Some(neighbor) != parent_idx {
538 stack.push((neighbor, Some(node_idx), false));
539 }
540 }
541 }
542 }
543
544 Ok(Self { in_time, out_time })
545 }
546
547 pub(crate) fn is_target_on_a_side(&self, node_a: &V, node_b: &V, target: &V) -> bool {
551 if target == node_a {
552 return true;
553 }
554 if target == node_b {
555 return false;
556 }
557 let in_a = match self.in_time.get(node_a) {
558 Some(&t) => t,
559 None => return false,
560 };
561 let out_a = match self.out_time.get(node_a) {
562 Some(&t) => t,
563 None => return false,
564 };
565 let in_b = match self.in_time.get(node_b) {
566 Some(&t) => t,
567 None => return false,
568 };
569 let out_b = match self.out_time.get(node_b) {
570 Some(&t) => t,
571 None => return false,
572 };
573 let in_t = match self.in_time.get(target) {
574 Some(&t) => t,
575 None => return false,
576 };
577 let out_t = match self.out_time.get(target) {
578 Some(&t) => t,
579 None => return false,
580 };
581
582 if in_a <= in_b && out_b <= out_a {
583 !(in_b <= in_t && out_t <= out_b)
584 } else {
585 in_a <= in_t && out_t <= out_a
586 }
587 }
588}
589
590#[cfg(test)]
591mod tests;