Skip to main content

tensor4all_treetn/
random.rs

1//! Random tensor network generation.
2//!
3//! Provides utilities for creating random tensor networks, useful for testing.
4//!
5//! Note: Currently only supports `DynId` indices (the default dynamic index type).
6
7use crate::site_index_network::SiteIndexNetwork;
8use crate::treetn::TreeTN;
9use rand::Rng;
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::hash::Hash;
13use tensor4all_core::index::{DynId, Index, TagSet};
14use tensor4all_core::{RandomScalar, TensorDynLen};
15
/// Specification for link (bond) dimensions.
///
/// Used when creating random tensor networks to specify the dimension of each bond.
#[derive(Debug, Clone)]
pub enum LinkSpace<V> {
    /// All links have the same dimension.
    Uniform(usize),
    /// Each edge has its own dimension.
    ///
    /// Keys are conventionally written as ordered pairs `(min(a, b), max(a, b))`.
    /// [`LinkSpace::get`] also finds entries stored in the reverse order, so a
    /// `(max, min)` key is not silently ignored.
    PerEdge(HashMap<(V, V), usize>),
}

impl<V> LinkSpace<V> {
    /// Create a uniform link space where all bonds have the same dimension.
    pub fn uniform(dim: usize) -> Self {
        Self::Uniform(dim)
    }

    /// Create a per-edge link space from a map of edge dimensions.
    ///
    /// Keys should preferably be `(min(a, b), max(a, b))`; see [`LinkSpace::get`]
    /// for how lookups treat key order.
    pub fn per_edge(dims: HashMap<(V, V), usize>) -> Self {
        Self::PerEdge(dims)
    }
}

impl<V: Clone + Ord + Hash> LinkSpace<V> {
    /// Get the dimension for the (undirected) edge between `a` and `b`.
    ///
    /// - `Uniform(dim)` returns `Some(dim)` for every pair of nodes.
    /// - `PerEdge(map)` first looks up the normalized key `(min(a, b), max(a, b))`.
    ///   If that is absent, it falls back to the reversed key `(max(a, b), min(a, b))`.
    ///   The order of `a` and `b` therefore does not matter. When both keys are
    ///   present, the normalized one wins.
    ///
    /// Returns `None` if the per-edge map has no entry for this edge in either order.
    pub fn get(&self, a: &V, b: &V) -> Option<usize> {
        match self {
            LinkSpace::Uniform(dim) => Some(*dim),
            LinkSpace::PerEdge(map) => {
                let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                // `HashMap<(V, V), _>::get` needs an owned tuple key, hence the clones.
                map.get(&(lo.clone(), hi.clone()))
                    .or_else(|| map.get(&(hi.clone(), lo.clone())))
                    .copied()
            }
        }
    }
}
58
59/// Type alias for the default index type used in random generation.
60pub type DefaultIndex = Index<DynId, TagSet>;
61
62/// Create a random TreeTN from a site index network (generic over scalar type).
63///
64/// Generates random tensors at each node with:
65/// - Site indices from the `site_network`
66/// - Link indices created according to `link_space`
67///
68/// # Type Parameters
69/// * `T` - Scalar type (e.g. `f64` or `Complex64`)
70/// * `R` - RNG type
71/// * `V` - Node name type
72///
73/// # Arguments
74/// * `rng` - Random number generator for tensor data
75/// * `site_network` - Network topology and site (physical) indices
76/// * `link_space` - Specification for bond dimensions
77///
78/// # Example
79/// ```
80/// use tensor4all_treetn::{SiteIndexNetwork, random_treetn, LinkSpace};
81/// use tensor4all_core::index::{Index, DynId, TagSet};
82/// use rand::SeedableRng;
83/// use rand_chacha::ChaCha8Rng;
84/// use std::collections::HashSet;
85///
86/// // Create a simple 2-node network
87/// let mut site_network = SiteIndexNetwork::<String, Index<DynId, TagSet>>::new();
88/// let i = Index::new_dyn(2);
89/// let j = Index::new_dyn(3);
90/// site_network.add_node("A".to_string(), HashSet::from([i.clone()])).unwrap();
91/// site_network.add_node("B".to_string(), HashSet::from([j.clone()])).unwrap();
92/// site_network.add_edge(&"A".to_string(), &"B".to_string()).unwrap();
93///
94/// let mut rng = ChaCha8Rng::seed_from_u64(42);
95/// let treetn = random_treetn::<f64, _, _>(&mut rng, &site_network, LinkSpace::uniform(4));
96///
97/// assert_eq!(treetn.node_count(), 2);
98/// ```
99pub fn random_treetn<T, R, V>(
100    rng: &mut R,
101    site_network: &SiteIndexNetwork<V, DefaultIndex>,
102    link_space: LinkSpace<V>,
103) -> TreeTN<TensorDynLen, V>
104where
105    T: RandomScalar,
106    R: Rng,
107    V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
108{
109    // Step 1: Create link indices for each edge
110    // Key: (smaller_name, larger_name), Value: link index
111    let mut link_indices: HashMap<(V, V), DefaultIndex> = HashMap::new();
112
113    // Get all edges from the site network topology
114    for (a, b) in site_network.edges() {
115        let key = if a < b {
116            (a.clone(), b.clone())
117        } else {
118            (b.clone(), a.clone())
119        };
120        let key_clone = (key.0.clone(), key.1.clone());
121
122        link_indices.entry(key).or_insert_with(|| {
123            let dim = link_space
124                .get(&key_clone.0, &key_clone.1)
125                .expect("LinkSpace must provide dimension for all edges");
126            Index::new_dyn(dim)
127        });
128    }
129
130    // Step 2: For each node, collect all indices and create random tensor
131    let mut tensors = Vec::new();
132    let mut node_names = Vec::new();
133
134    for node_name in site_network.node_names() {
135        let node_name = node_name.clone();
136
137        // Collect site indices
138        let site_inds = site_network
139            .site_space(&node_name)
140            .cloned()
141            .unwrap_or_default();
142
143        // Collect link indices from edges connected to this node
144        let mut all_indices: Vec<DefaultIndex> = site_inds.into_iter().collect();
145
146        for neighbor in site_network.neighbors(&node_name) {
147            let key = if node_name < neighbor {
148                (node_name.clone(), neighbor.clone())
149            } else {
150                (neighbor.clone(), node_name.clone())
151            };
152
153            if let Some(link_idx) = link_indices.get(&key) {
154                all_indices.push(link_idx.clone());
155            }
156        }
157
158        // Create random tensor
159        let tensor = TensorDynLen::random::<T, R>(rng, all_indices);
160
161        tensors.push(tensor);
162        node_names.push(node_name);
163    }
164
165    // Step 3: Create TreeTN from tensors
166    TreeTN::from_tensors(tensors, node_names).expect("Failed to create TreeTN from random tensors")
167}
168
// Unit tests for this module live in a separate `tests` submodule file.
// They are compiled only for test builds.
#[cfg(test)]
mod tests;