tensor4all_treetn/random.rs
1//! Random tensor network generation.
2//!
3//! Provides utilities for creating random tensor networks, useful for testing.
4//!
5//! Note: Currently only supports `DynId` indices (the default dynamic index type).
6
7use crate::site_index_network::SiteIndexNetwork;
8use crate::treetn::TreeTN;
9use rand::Rng;
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::hash::Hash;
13use tensor4all_core::index::{DynId, Index, TagSet};
14use tensor4all_core::{RandomScalar, TensorDynLen};
15
/// Specification for link (bond) dimensions.
///
/// Used when creating random tensor networks to specify the dimension of each bond.
#[derive(Debug, Clone)]
pub enum LinkSpace<V> {
    /// All links have the same dimension.
    Uniform(usize),
    /// Each edge has its own dimension.
    ///
    /// Keys are node pairs. The canonical orientation is `(min(a, b), max(a, b))`, but
    /// [`LinkSpace::get`] also finds entries stored in the opposite orientation `(max, min)`.
    PerEdge(HashMap<(V, V), usize>),
}

impl<V> LinkSpace<V> {
    /// Create a uniform link space where all bonds have the same dimension.
    pub fn uniform(dim: usize) -> Self {
        Self::Uniform(dim)
    }

    /// Create a per-edge link space from a map of edge dimensions.
    ///
    /// The map is stored as given. Keys may use either orientation of a node pair
    /// (see [`LinkSpace::get`]).
    pub fn per_edge(dims: HashMap<(V, V), usize>) -> Self {
        Self::PerEdge(dims)
    }
}

impl<V: Clone + Ord + Hash> LinkSpace<V> {
    /// Get the dimension for the (undirected) edge between `a` and `b`.
    ///
    /// - `Uniform(dim)` always returns `Some(dim)`.
    /// - `PerEdge` looks up the canonical key `(min(a, b), max(a, b))` first. If that key is
    ///   absent, it tries the reversed key `(max, min)`, so maps built in either orientation
    ///   work. Returns `None` when neither key is present.
    pub fn get(&self, a: &V, b: &V) -> Option<usize> {
        match self {
            LinkSpace::Uniform(dim) => Some(*dim),
            LinkSpace::PerEdge(map) => {
                let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                map.get(&(lo.clone(), hi.clone()))
                    // Fallback for callers that inserted the non-canonical orientation.
                    .or_else(|| map.get(&(hi.clone(), lo.clone())))
                    .copied()
            }
        }
    }
}
57}
58
59/// Type alias for the default index type used in random generation.
60pub type DefaultIndex = Index<DynId, TagSet>;
61
62/// Create a random TreeTN from a site index network (generic over scalar type).
63///
64/// Generates random tensors at each node with:
65/// - Site indices from the `site_network`
66/// - Link indices created according to `link_space`
67///
68/// # Type Parameters
69/// * `T` - Scalar type (e.g. `f64` or `Complex64`)
70/// * `R` - RNG type
71/// * `V` - Node name type
72///
73/// # Arguments
74/// * `rng` - Random number generator for tensor data
75/// * `site_network` - Network topology and site (physical) indices
76/// * `link_space` - Specification for bond dimensions
77///
78/// # Example
79/// ```
80/// use tensor4all_treetn::{SiteIndexNetwork, random_treetn, LinkSpace};
81/// use tensor4all_core::index::{Index, DynId, TagSet};
82/// use rand::SeedableRng;
83/// use rand_chacha::ChaCha8Rng;
84/// use std::collections::HashSet;
85///
86/// // Create a simple 2-node network
87/// let mut site_network = SiteIndexNetwork::<String, Index<DynId, TagSet>>::new();
88/// let i = Index::new_dyn(2);
89/// let j = Index::new_dyn(3);
90/// site_network.add_node("A".to_string(), HashSet::from([i.clone()])).unwrap();
91/// site_network.add_node("B".to_string(), HashSet::from([j.clone()])).unwrap();
92/// site_network.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
93///
94/// let mut rng = ChaCha8Rng::seed_from_u64(42);
95/// let treetn = random_treetn::<f64, _, _>(&mut rng, &site_network, LinkSpace::uniform(4));
96///
97/// assert_eq!(treetn.node_count(), 2);
98/// ```
99pub fn random_treetn<T, R, V>(
100 rng: &mut R,
101 site_network: &SiteIndexNetwork<V, DefaultIndex>,
102 link_space: LinkSpace<V>,
103) -> TreeTN<TensorDynLen, V>
104where
105 T: RandomScalar,
106 R: Rng,
107 V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
108{
109 // Step 1: Create link indices for each edge
110 // Key: (smaller_name, larger_name), Value: link index
111 let mut link_indices: HashMap<(V, V), DefaultIndex> = HashMap::new();
112
113 // Get all edges from the site network topology
114 for (a, b) in site_network.edges() {
115 let key = if a < b {
116 (a.clone(), b.clone())
117 } else {
118 (b.clone(), a.clone())
119 };
120 let key_clone = (key.0.clone(), key.1.clone());
121
122 link_indices.entry(key).or_insert_with(|| {
123 let dim = link_space
124 .get(&key_clone.0, &key_clone.1)
125 .expect("LinkSpace must provide dimension for all edges");
126 Index::new_dyn(dim)
127 });
128 }
129
130 // Step 2: For each node, collect all indices and create random tensor
131 let mut tensors = Vec::new();
132 let mut node_names = Vec::new();
133
134 for node_name in site_network.node_names() {
135 let node_name = node_name.clone();
136
137 // Collect site indices
138 let site_inds = site_network
139 .site_space(&node_name)
140 .cloned()
141 .unwrap_or_default();
142
143 // Collect link indices from edges connected to this node
144 let mut all_indices: Vec<DefaultIndex> = site_inds.into_iter().collect();
145
146 for neighbor in site_network.neighbors(&node_name) {
147 let key = if node_name < neighbor {
148 (node_name.clone(), neighbor.clone())
149 } else {
150 (neighbor.clone(), node_name.clone())
151 };
152
153 if let Some(link_idx) = link_indices.get(&key) {
154 all_indices.push(link_idx.clone());
155 }
156 }
157
158 // Create random tensor
159 let tensor = TensorDynLen::random::<T, R>(rng, all_indices);
160
161 tensors.push(tensor);
162 node_names.push(node_name);
163 }
164
165 // Step 3: Create TreeTN from tensors
166 TreeTN::from_tensors(tensors, node_names).expect("Failed to create TreeTN from random tensors")
167}
168
// Unit tests for this module live in a separate file (`random/tests.rs` under the standard
// Rust module layout). They are compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;