Skip to main content

tensor4all_treetn/treetn/
addition.rs

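//! Addition operations for TreeTN using direct-sum (block) construction.
//!
//! This module provides the methods and types used to add two TreeTNs:
//! - [`TreeTN::add`]: Direct-sum addition of two networks that share site index IDs
//! - [`TreeTN::add_aligned`]: Align the second operand's site index IDs, then add
//! - [`TreeTN::reindex_site_space_like`]: Rewrite site index IDs to match a template network
//! - [`MergedBondInfo`]: Information about merged bond indices
//! - [`TreeTN::compute_merged_bond_indices`]: Compute merged bond information from two networks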
6
7use petgraph::visit::EdgeRef;
8use std::collections::{HashMap, HashSet};
9use std::hash::Hash;
10
11use anyhow::{bail, Result};
12
13use tensor4all_core::{AnyScalar, IndexLike, TensorIndex, TensorLike};
14
15use super::TreeTN;
16
17/// Information about a merged bond index for direct-sum addition.
18///
19/// When adding two TreeTNs, each bond index in the result has dimension
20/// `dim_a + dim_b`, where `dim_a` and `dim_b` are the original bond dimensions.
21#[derive(Debug, Clone)]
22pub struct MergedBondInfo<I>
23where
24    I: IndexLike,
25{
26    /// Bond dimension from the first TreeTN
27    pub dim_a: usize,
28    /// Bond dimension from the second TreeTN
29    pub dim_b: usize,
30    /// The new merged bond index (with dimension dim_a + dim_b)
31    pub merged_index: I,
32}
33
34impl<T, V> TreeTN<T, V>
35where
36    T: TensorLike,
37    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
38{
39    fn sorted_site_space(site_space: &HashSet<T::Index>) -> Vec<T::Index>
40    where
41        T::Index: Clone,
42        <T::Index as IndexLike>::Id: Ord,
43    {
44        let mut indices: Vec<_> = site_space.iter().cloned().collect();
45        indices.sort_by(|left, right| {
46            left.dim()
47                .cmp(&right.dim())
48                .then_with(|| left.id().cmp(right.id()))
49        });
50        indices
51    }
52
53    /// Reindex this TreeTN's site space to match a template network.
54    ///
55    /// The topology must match, and each corresponding node must carry the same
56    /// number of site indices with the same dimensions. Site indices are paired
57    /// node-by-node after sorting by `(dim, id)` for deterministic matching.
58    ///
59    /// # Arguments
60    /// * `template` - Reference TreeTN whose site index IDs should be adopted
61    ///
62    /// # Returns
63    /// A new TreeTN with the same tensor data as `self`, but site index IDs
64    /// rewritten to match `template`.
65    ///
66    /// # Errors
67    /// Returns an error if the two networks have different topologies or
68    /// incompatible site-space dimensions on any node.
69    ///
70    /// # Examples
71    /// ```
72    /// use tensor4all_core::{DynIndex, TensorDynLen};
73    /// use tensor4all_treetn::TreeTN;
74    ///
75    /// # fn make_chain(site0: DynIndex, site1: DynIndex) -> TreeTN<TensorDynLen, usize> {
76    /// #     let bond = DynIndex::new_dyn(1);
77    /// #     let t0 = TensorDynLen::from_dense(vec![site0, bond.clone()], vec![1.0, 0.0]).unwrap();
78    /// #     let t1 = TensorDynLen::from_dense(vec![bond, site1], vec![1.0, 0.0]).unwrap();
79    /// #     TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1]).unwrap()
80    /// # }
81    /// let state_a = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
82    /// let state_b = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
83    ///
84    /// let aligned = state_b.reindex_site_space_like(&state_a).unwrap();
85    /// assert!(aligned.share_equivalent_site_index_network(&state_a));
86    /// ```
87    pub fn reindex_site_space_like(&self, template: &Self) -> Result<Self>
88    where
89        V: Ord,
90        T::Index: Clone,
91        <T::Index as IndexLike>::Id:
92            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
93    {
94        if !self.same_topology(template) {
95            bail!("reindex_site_space_like: networks have incompatible topologies");
96        }
97
98        let mut old_indices = Vec::new();
99        let mut new_indices = Vec::new();
100
101        for node_name in self.node_names() {
102            let self_site_space = self
103                .site_space(&node_name)
104                .ok_or_else(|| anyhow::anyhow!("site space not found for node {:?}", node_name))?;
105            let template_site_space = template.site_space(&node_name).ok_or_else(|| {
106                anyhow::anyhow!("template site space not found for node {:?}", node_name)
107            })?;
108
109            if self_site_space.len() != template_site_space.len() {
110                bail!(
111                    "reindex_site_space_like: node {:?} has {} site indices in self but {} in template",
112                    node_name,
113                    self_site_space.len(),
114                    template_site_space.len()
115                );
116            }
117
118            let self_sorted = Self::sorted_site_space(self_site_space);
119            let template_sorted = Self::sorted_site_space(template_site_space);
120
121            for (old_index, new_index) in self_sorted.iter().zip(template_sorted.iter()) {
122                if old_index.dim() != new_index.dim() {
123                    bail!(
124                        "reindex_site_space_like: node {:?} site dimension mismatch {} != {}",
125                        node_name,
126                        old_index.dim(),
127                        new_index.dim()
128                    );
129                }
130                old_indices.push(old_index.clone());
131                new_indices.push(new_index.clone());
132            }
133        }
134
135        self.replaceinds(&old_indices, &new_indices)
136    }
137
138    /// Add two TreeTNs after aligning the second operand's site index IDs to the first.
139    ///
140    /// This is useful when two states share the same topology and site dimensions
141    /// but were constructed with different site index IDs.
142    ///
143    /// # Arguments
144    /// * `other` - The other TreeTN to align and add
145    ///
146    /// # Returns
147    /// The direct-sum addition result with site IDs matching `self`.
148    ///
149    /// # Examples
150    /// ```
151    /// use tensor4all_core::{DynIndex, TensorDynLen};
152    /// use tensor4all_treetn::TreeTN;
153    ///
154    /// # fn make_chain(site0: DynIndex, site1: DynIndex) -> TreeTN<TensorDynLen, usize> {
155    /// #     let bond = DynIndex::new_dyn(1);
156    /// #     let t0 = TensorDynLen::from_dense(vec![site0, bond.clone()], vec![1.0, 0.0]).unwrap();
157    /// #     let t1 = TensorDynLen::from_dense(vec![bond, site1], vec![1.0, 0.0]).unwrap();
158    /// #     TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1]).unwrap()
159    /// # }
160    /// let state_a = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
161    /// let state_b = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
162    ///
163    /// let sum = state_a.add_aligned(&state_b).unwrap();
164    /// assert_eq!(sum.node_count(), 2);
165    /// assert!(sum.share_equivalent_site_index_network(&state_a));
166    /// ```
167    pub fn add_aligned(&self, other: &Self) -> Result<Self>
168    where
169        V: Ord,
170        T::Index: Clone,
171        <T::Index as IndexLike>::Id:
172            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
173    {
174        let other_aligned = other.reindex_site_space_like(self)?;
175        self.add(&other_aligned)
176    }
177
178    /// Compute merged bond indices for direct-sum addition.
179    ///
180    /// For each edge in the network, compute the merged bond information
181    /// containing dimensions from both networks and a new merged index.
182    ///
183    /// # Arguments
184    /// * `other` - The other TreeTN to compute merged bonds with
185    ///
186    /// # Returns
187    /// A HashMap mapping edge keys (node_name_pair in canonical order) to MergedBondInfo.
188    ///
189    /// # Errors
190    /// Returns an error if:
191    /// - Networks have incompatible topologies
192    /// - Bond indices cannot be found
193    #[allow(clippy::type_complexity)]
194    pub fn compute_merged_bond_indices(
195        &self,
196        other: &Self,
197    ) -> Result<HashMap<(V, V), MergedBondInfo<T::Index>>>
198    where
199        V: Ord,
200    {
201        let mut result = HashMap::new();
202
203        for edge in self.graph.graph().edge_indices() {
204            let (src, tgt) = self
205                .graph
206                .graph()
207                .edge_endpoints(edge)
208                .ok_or_else(|| anyhow::anyhow!("Edge has no endpoints"))?;
209
210            let bond_index_a = self
211                .bond_index(edge)
212                .ok_or_else(|| anyhow::anyhow!("Bond index not found in self"))?;
213            let dim_a = bond_index_a.dim();
214
215            let src_name = self
216                .graph
217                .node_name(src)
218                .ok_or_else(|| anyhow::anyhow!("Source node name not found"))?
219                .clone();
220            let tgt_name = self
221                .graph
222                .node_name(tgt)
223                .ok_or_else(|| anyhow::anyhow!("Target node name not found"))?
224                .clone();
225
226            // Find corresponding edge in other
227            let src_idx_other = other
228                .graph
229                .node_index(&src_name)
230                .ok_or_else(|| anyhow::anyhow!("Source node not found in other"))?;
231            let tgt_idx_other = other
232                .graph
233                .node_index(&tgt_name)
234                .ok_or_else(|| anyhow::anyhow!("Target node not found in other"))?;
235
236            // Find edge between these nodes in other
237            let edge_other = other
238                .graph
239                .graph()
240                .edges_connecting(src_idx_other, tgt_idx_other)
241                .next()
242                .or_else(|| {
243                    other
244                        .graph
245                        .graph()
246                        .edges_connecting(tgt_idx_other, src_idx_other)
247                        .next()
248                })
249                .ok_or_else(|| anyhow::anyhow!("Edge not found in other"))?;
250
251            let bond_index_b = other
252                .bond_index(edge_other.id())
253                .ok_or_else(|| anyhow::anyhow!("Bond index not found in other"))?;
254            let dim_b = bond_index_b.dim();
255
256            // Create merged bond index using direct_sum on dummy tensors
257            // For now, we just store dimensions; the actual merged index will be
258            // created during the direct sum operation using TensorLike::direct_sum
259            //
260            // Note: We need a way to create a new index with dim_a + dim_b.
261            // This requires the TensorLike implementation to handle index creation.
262            // For now, we clone one of the existing indices as a placeholder.
263            // The actual merging happens in the direct_sum operation.
264            let merged_index = bond_index_a.clone();
265
266            // Store in canonical order (smaller name first)
267            let key = if src_name < tgt_name {
268                (src_name, tgt_name)
269            } else {
270                (tgt_name, src_name)
271            };
272
273            result.insert(
274                key,
275                MergedBondInfo {
276                    dim_a,
277                    dim_b,
278                    merged_index,
279                },
280            );
281        }
282
283        Ok(result)
284    }
285
286    /// Add two TreeTNs using direct-sum construction.
287    ///
288    /// This creates a new TreeTN where each tensor is the direct sum of the
289    /// corresponding tensors from self and other, with bond dimensions merged.
290    /// The two networks must share the same topology **and** the same site
291    /// index IDs. Use [`add_aligned`](Self::add_aligned) if site index IDs differ.
292    ///
293    /// # Arguments
294    /// * `other` - The other TreeTN to add
295    ///
296    /// # Returns
297    /// A new TreeTN representing the sum.
298    ///
299    /// # Errors
300    /// Returns an error if the networks have incompatible structures.
301    ///
302    /// # Examples
303    ///
304    /// ```
305    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
306    /// use tensor4all_treetn::TreeTN;
307    ///
308    /// let s = DynIndex::new_dyn(2);
309    /// let t = TensorDynLen::from_dense(vec![s.clone()], vec![1.0_f64, 2.0]).unwrap();
310    /// let tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
311    ///
312    /// // Adding a single-node network to itself doubles the values
313    /// let sum = tn.add(&tn).unwrap();
314    /// let dense = sum.to_dense().unwrap();
315    /// let expected = TensorDynLen::from_dense(vec![s], vec![2.0, 4.0]).unwrap();
316    /// assert!((&dense - &expected).maxabs() < 1e-12);
317    /// ```
318    pub fn add(&self, other: &Self) -> Result<Self>
319    where
320        V: Ord,
321        <T::Index as IndexLike>::Id:
322            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
323    {
324        // Verify same topology
325        if !self.same_topology(other) {
326            return Err(anyhow::anyhow!(
327                "Cannot add TreeTNs with different topologies"
328            ));
329        }
330
331        // Track merged indices for each edge.
332        // Key: (smaller_node_name, larger_node_name) for canonical ordering
333        // Value: the merged bond index to use for this edge
334        let mut edge_merged_indices: HashMap<(V, V), T::Index> = HashMap::new();
335
336        // For each node, compute the direct sum of tensors
337        let mut result_tensors: Vec<T> = Vec::new();
338        let mut result_node_names: Vec<V> = Vec::new();
339
340        for node_name in self.node_names() {
341            let self_idx = self.node_index(&node_name).unwrap();
342            let other_idx = other.node_index(&node_name).unwrap();
343
344            let tensor_a = self.tensor(self_idx).unwrap();
345            let tensor_b = other.tensor(other_idx).unwrap();
346
347            // Find bond index pairs for this node and track neighbors
348            let mut bond_pairs: Vec<(T::Index, T::Index)> = Vec::new();
349            let mut neighbors_for_edges: Vec<V> = Vec::new();
350
351            for neighbor in self.site_index_network().neighbors(&node_name) {
352                // Get bond index from self
353                let self_edge = self.edge_between(&node_name, &neighbor).unwrap();
354                let self_bond = self.bond_index(self_edge).unwrap();
355
356                // Get bond index from other
357                let other_edge = other.edge_between(&node_name, &neighbor).unwrap();
358                let other_bond = other.bond_index(other_edge).unwrap();
359
360                bond_pairs.push((self_bond.clone(), other_bond.clone()));
361                neighbors_for_edges.push(neighbor);
362            }
363
364            // For nodes with no bonds (single-node network), use element-wise addition
365            // instead of direct_sum (which requires at least one index pair).
366            if bond_pairs.is_empty() {
367                let sum_tensor =
368                    tensor_a.axpby(AnyScalar::new_real(1.0), tensor_b, AnyScalar::new_real(1.0))?;
369                result_tensors.push(sum_tensor);
370                result_node_names.push(node_name);
371                continue;
372            }
373
374            // Compute direct sum
375            let direct_sum_result = tensor_a.direct_sum(tensor_b, &bond_pairs)?;
376            let mut result_tensor = direct_sum_result.tensor;
377
378            // For each edge, ensure we use consistent merged indices:
379            // - If we've already seen this edge, replace the auto-generated index with the stored one
380            // - If this is the first time seeing this edge, store the auto-generated index
381            for (i, neighbor) in neighbors_for_edges.iter().enumerate() {
382                // Create canonical edge key (smaller name first)
383                let edge_key = if node_name < *neighbor {
384                    (node_name.clone(), neighbor.clone())
385                } else {
386                    (neighbor.clone(), node_name.clone())
387                };
388
389                let new_index = &direct_sum_result.new_indices[i];
390
391                if let Some(stored_index) = edge_merged_indices.get(&edge_key) {
392                    // Edge already processed - replace auto-generated index with stored one
393                    result_tensor = result_tensor.replaceind(new_index, stored_index)?;
394                } else {
395                    // First time seeing this edge - store the auto-generated index
396                    edge_merged_indices.insert(edge_key, new_index.clone());
397                }
398            }
399
400            result_tensors.push(result_tensor);
401            result_node_names.push(node_name);
402        }
403
404        // Build result TreeTN
405        TreeTN::from_tensors(result_tensors, result_node_names)
406    }
407}
408
// Unit tests for this module live in a separate `tests` submodule file.
#[cfg(test)]
mod tests;