// tensor4all_treetn/treetn/addition.rs
1//! Addition operations for TreeTN using direct-sum (block) construction.
2//!
3//! This module provides helper functions and types for adding two TreeTNs:
4//! - [`MergedBondInfo`]: Information about merged bond indices
5//! - [`compute_merged_bond_indices`]: Compute merged bond index information from two networks
6
7use petgraph::visit::EdgeRef;
8use std::collections::{HashMap, HashSet};
9use std::hash::Hash;
10
11use anyhow::{bail, Result};
12
13use tensor4all_core::{AnyScalar, IndexLike, TensorIndex, TensorLike};
14
15use super::TreeTN;
16
/// Information about a merged bond index for direct-sum addition.
///
/// When adding two TreeTNs, each bond index in the result has dimension
/// `dim_a + dim_b`, where `dim_a` and `dim_b` are the original bond dimensions.
#[derive(Debug, Clone)]
pub struct MergedBondInfo<I>
where
    I: IndexLike,
{
    /// Bond dimension contributed by the first (`self`) TreeTN operand.
    pub dim_a: usize,
    /// Bond dimension contributed by the second (`other`) TreeTN operand.
    pub dim_b: usize,
    /// The merged bond index for the result network.
    ///
    /// NOTE(review): as produced by `compute_merged_bond_indices`, this is
    /// currently a clone of the first operand's bond index (so its `dim()` is
    /// `dim_a`, not `dim_a + dim_b`); the true merged index is created later
    /// by the direct-sum operation — confirm before relying on its dimension.
    pub merged_index: I,
}
33
34impl<T, V> TreeTN<T, V>
35where
36 T: TensorLike,
37 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
38{
39 fn sorted_site_space(site_space: &HashSet<T::Index>) -> Vec<T::Index>
40 where
41 T::Index: Clone,
42 <T::Index as IndexLike>::Id: Ord,
43 {
44 let mut indices: Vec<_> = site_space.iter().cloned().collect();
45 indices.sort_by(|left, right| {
46 left.dim()
47 .cmp(&right.dim())
48 .then_with(|| left.id().cmp(right.id()))
49 });
50 indices
51 }
52
53 /// Reindex this TreeTN's site space to match a template network.
54 ///
55 /// The topology must match, and each corresponding node must carry the same
56 /// number of site indices with the same dimensions. Site indices are paired
57 /// node-by-node after sorting by `(dim, id)` for deterministic matching.
58 ///
59 /// # Arguments
60 /// * `template` - Reference TreeTN whose site index IDs should be adopted
61 ///
62 /// # Returns
63 /// A new TreeTN with the same tensor data as `self`, but site index IDs
64 /// rewritten to match `template`.
65 ///
66 /// # Errors
67 /// Returns an error if the two networks have different topologies or
68 /// incompatible site-space dimensions on any node.
69 ///
70 /// # Examples
71 /// ```
72 /// use tensor4all_core::{DynIndex, TensorDynLen};
73 /// use tensor4all_treetn::TreeTN;
74 ///
75 /// # fn make_chain(site0: DynIndex, site1: DynIndex) -> TreeTN<TensorDynLen, usize> {
76 /// # let bond = DynIndex::new_dyn(1);
77 /// # let t0 = TensorDynLen::from_dense(vec![site0, bond.clone()], vec![1.0, 0.0]).unwrap();
78 /// # let t1 = TensorDynLen::from_dense(vec![bond, site1], vec![1.0, 0.0]).unwrap();
79 /// # TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1]).unwrap()
80 /// # }
81 /// let state_a = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
82 /// let state_b = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
83 ///
84 /// let aligned = state_b.reindex_site_space_like(&state_a).unwrap();
85 /// assert!(aligned.share_equivalent_site_index_network(&state_a));
86 /// ```
87 pub fn reindex_site_space_like(&self, template: &Self) -> Result<Self>
88 where
89 V: Ord,
90 T::Index: Clone,
91 <T::Index as IndexLike>::Id:
92 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
93 {
94 if !self.same_topology(template) {
95 bail!("reindex_site_space_like: networks have incompatible topologies");
96 }
97
98 let mut old_indices = Vec::new();
99 let mut new_indices = Vec::new();
100
101 for node_name in self.node_names() {
102 let self_site_space = self
103 .site_space(&node_name)
104 .ok_or_else(|| anyhow::anyhow!("site space not found for node {:?}", node_name))?;
105 let template_site_space = template.site_space(&node_name).ok_or_else(|| {
106 anyhow::anyhow!("template site space not found for node {:?}", node_name)
107 })?;
108
109 if self_site_space.len() != template_site_space.len() {
110 bail!(
111 "reindex_site_space_like: node {:?} has {} site indices in self but {} in template",
112 node_name,
113 self_site_space.len(),
114 template_site_space.len()
115 );
116 }
117
118 let self_sorted = Self::sorted_site_space(self_site_space);
119 let template_sorted = Self::sorted_site_space(template_site_space);
120
121 for (old_index, new_index) in self_sorted.iter().zip(template_sorted.iter()) {
122 if old_index.dim() != new_index.dim() {
123 bail!(
124 "reindex_site_space_like: node {:?} site dimension mismatch {} != {}",
125 node_name,
126 old_index.dim(),
127 new_index.dim()
128 );
129 }
130 old_indices.push(old_index.clone());
131 new_indices.push(new_index.clone());
132 }
133 }
134
135 self.replaceinds(&old_indices, &new_indices)
136 }
137
138 /// Add two TreeTNs after aligning the second operand's site index IDs to the first.
139 ///
140 /// This is useful when two states share the same topology and site dimensions
141 /// but were constructed with different site index IDs.
142 ///
143 /// # Arguments
144 /// * `other` - The other TreeTN to align and add
145 ///
146 /// # Returns
147 /// The direct-sum addition result with site IDs matching `self`.
148 ///
149 /// # Examples
150 /// ```
151 /// use tensor4all_core::{DynIndex, TensorDynLen};
152 /// use tensor4all_treetn::TreeTN;
153 ///
154 /// # fn make_chain(site0: DynIndex, site1: DynIndex) -> TreeTN<TensorDynLen, usize> {
155 /// # let bond = DynIndex::new_dyn(1);
156 /// # let t0 = TensorDynLen::from_dense(vec![site0, bond.clone()], vec![1.0, 0.0]).unwrap();
157 /// # let t1 = TensorDynLen::from_dense(vec![bond, site1], vec![1.0, 0.0]).unwrap();
158 /// # TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1]).unwrap()
159 /// # }
160 /// let state_a = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
161 /// let state_b = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2));
162 ///
163 /// let sum = state_a.add_aligned(&state_b).unwrap();
164 /// assert_eq!(sum.node_count(), 2);
165 /// assert!(sum.share_equivalent_site_index_network(&state_a));
166 /// ```
167 pub fn add_aligned(&self, other: &Self) -> Result<Self>
168 where
169 V: Ord,
170 T::Index: Clone,
171 <T::Index as IndexLike>::Id:
172 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
173 {
174 let other_aligned = other.reindex_site_space_like(self)?;
175 self.add(&other_aligned)
176 }
177
178 /// Compute merged bond indices for direct-sum addition.
179 ///
180 /// For each edge in the network, compute the merged bond information
181 /// containing dimensions from both networks and a new merged index.
182 ///
183 /// # Arguments
184 /// * `other` - The other TreeTN to compute merged bonds with
185 ///
186 /// # Returns
187 /// A HashMap mapping edge keys (node_name_pair in canonical order) to MergedBondInfo.
188 ///
189 /// # Errors
190 /// Returns an error if:
191 /// - Networks have incompatible topologies
192 /// - Bond indices cannot be found
193 #[allow(clippy::type_complexity)]
194 pub fn compute_merged_bond_indices(
195 &self,
196 other: &Self,
197 ) -> Result<HashMap<(V, V), MergedBondInfo<T::Index>>>
198 where
199 V: Ord,
200 {
201 let mut result = HashMap::new();
202
203 for edge in self.graph.graph().edge_indices() {
204 let (src, tgt) = self
205 .graph
206 .graph()
207 .edge_endpoints(edge)
208 .ok_or_else(|| anyhow::anyhow!("Edge has no endpoints"))?;
209
210 let bond_index_a = self
211 .bond_index(edge)
212 .ok_or_else(|| anyhow::anyhow!("Bond index not found in self"))?;
213 let dim_a = bond_index_a.dim();
214
215 let src_name = self
216 .graph
217 .node_name(src)
218 .ok_or_else(|| anyhow::anyhow!("Source node name not found"))?
219 .clone();
220 let tgt_name = self
221 .graph
222 .node_name(tgt)
223 .ok_or_else(|| anyhow::anyhow!("Target node name not found"))?
224 .clone();
225
226 // Find corresponding edge in other
227 let src_idx_other = other
228 .graph
229 .node_index(&src_name)
230 .ok_or_else(|| anyhow::anyhow!("Source node not found in other"))?;
231 let tgt_idx_other = other
232 .graph
233 .node_index(&tgt_name)
234 .ok_or_else(|| anyhow::anyhow!("Target node not found in other"))?;
235
236 // Find edge between these nodes in other
237 let edge_other = other
238 .graph
239 .graph()
240 .edges_connecting(src_idx_other, tgt_idx_other)
241 .next()
242 .or_else(|| {
243 other
244 .graph
245 .graph()
246 .edges_connecting(tgt_idx_other, src_idx_other)
247 .next()
248 })
249 .ok_or_else(|| anyhow::anyhow!("Edge not found in other"))?;
250
251 let bond_index_b = other
252 .bond_index(edge_other.id())
253 .ok_or_else(|| anyhow::anyhow!("Bond index not found in other"))?;
254 let dim_b = bond_index_b.dim();
255
256 // Create merged bond index using direct_sum on dummy tensors
257 // For now, we just store dimensions; the actual merged index will be
258 // created during the direct sum operation using TensorLike::direct_sum
259 //
260 // Note: We need a way to create a new index with dim_a + dim_b.
261 // This requires the TensorLike implementation to handle index creation.
262 // For now, we clone one of the existing indices as a placeholder.
263 // The actual merging happens in the direct_sum operation.
264 let merged_index = bond_index_a.clone();
265
266 // Store in canonical order (smaller name first)
267 let key = if src_name < tgt_name {
268 (src_name, tgt_name)
269 } else {
270 (tgt_name, src_name)
271 };
272
273 result.insert(
274 key,
275 MergedBondInfo {
276 dim_a,
277 dim_b,
278 merged_index,
279 },
280 );
281 }
282
283 Ok(result)
284 }
285
    /// Add two TreeTNs using direct-sum construction.
    ///
    /// This creates a new TreeTN where each tensor is the direct sum of the
    /// corresponding tensors from self and other, with bond dimensions merged.
    /// The two networks must share the same topology **and** the same site
    /// index IDs. Use [`add_aligned`](Self::add_aligned) if site index IDs differ.
    ///
    /// # Arguments
    /// * `other` - The other TreeTN to add
    ///
    /// # Returns
    /// A new TreeTN representing the sum.
    ///
    /// # Errors
    /// Returns an error if the networks have incompatible structures.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
    /// use tensor4all_treetn::TreeTN;
    ///
    /// let s = DynIndex::new_dyn(2);
    /// let t = TensorDynLen::from_dense(vec![s.clone()], vec![1.0_f64, 2.0]).unwrap();
    /// let tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
    ///
    /// // Adding a single-node network to itself doubles the values
    /// let sum = tn.add(&tn).unwrap();
    /// let dense = sum.to_dense().unwrap();
    /// let expected = TensorDynLen::from_dense(vec![s], vec![2.0, 4.0]).unwrap();
    /// assert!((&dense - &expected).maxabs() < 1e-12);
    /// ```
    pub fn add(&self, other: &Self) -> Result<Self>
    where
        V: Ord,
        <T::Index as IndexLike>::Id:
            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
    {
        // Verify same topology
        if !self.same_topology(other) {
            return Err(anyhow::anyhow!(
                "Cannot add TreeTNs with different topologies"
            ));
        }

        // Track merged indices for each edge.
        // Key: (smaller_node_name, larger_node_name) for canonical ordering
        // Value: the merged bond index to use for this edge
        //
        // Each edge is visited twice (once from each endpoint's node loop);
        // this map ensures both endpoint tensors end up carrying the SAME
        // merged index so they contract correctly in the result network.
        let mut edge_merged_indices: HashMap<(V, V), T::Index> = HashMap::new();

        // For each node, compute the direct sum of tensors
        let mut result_tensors: Vec<T> = Vec::new();
        let mut result_node_names: Vec<V> = Vec::new();

        for node_name in self.node_names() {
            // Topology was verified above, so every node name resolves in both
            // networks; unwraps here encode that invariant.
            let self_idx = self.node_index(&node_name).unwrap();
            let other_idx = other.node_index(&node_name).unwrap();

            let tensor_a = self.tensor(self_idx).unwrap();
            let tensor_b = other.tensor(other_idx).unwrap();

            // Find bond index pairs for this node and track neighbors.
            // `neighbors_for_edges[i]` records which neighbor bond pair `i`
            // belongs to, so we can map direct_sum's new indices back to edges.
            let mut bond_pairs: Vec<(T::Index, T::Index)> = Vec::new();
            let mut neighbors_for_edges: Vec<V> = Vec::new();

            for neighbor in self.site_index_network().neighbors(&node_name) {
                // Get bond index from self
                let self_edge = self.edge_between(&node_name, &neighbor).unwrap();
                let self_bond = self.bond_index(self_edge).unwrap();

                // Get bond index from other
                let other_edge = other.edge_between(&node_name, &neighbor).unwrap();
                let other_bond = other.bond_index(other_edge).unwrap();

                bond_pairs.push((self_bond.clone(), other_bond.clone()));
                neighbors_for_edges.push(neighbor);
            }

            // For nodes with no bonds (single-node network), use element-wise addition
            // instead of direct_sum (which requires at least one index pair).
            // axpby computes 1.0 * tensor_a + 1.0 * tensor_b.
            if bond_pairs.is_empty() {
                let sum_tensor =
                    tensor_a.axpby(AnyScalar::new_real(1.0), tensor_b, AnyScalar::new_real(1.0))?;
                result_tensors.push(sum_tensor);
                result_node_names.push(node_name);
                continue;
            }

            // Compute direct sum
            let direct_sum_result = tensor_a.direct_sum(tensor_b, &bond_pairs)?;
            let mut result_tensor = direct_sum_result.tensor;

            // For each edge, ensure we use consistent merged indices:
            // - If we've already seen this edge, replace the auto-generated index with the stored one
            // - If this is the first time seeing this edge, store the auto-generated index
            for (i, neighbor) in neighbors_for_edges.iter().enumerate() {
                // Create canonical edge key (smaller name first)
                let edge_key = if node_name < *neighbor {
                    (node_name.clone(), neighbor.clone())
                } else {
                    (neighbor.clone(), node_name.clone())
                };

                // direct_sum returns its new indices in the same order as
                // `bond_pairs`, which matches `neighbors_for_edges`.
                let new_index = &direct_sum_result.new_indices[i];

                if let Some(stored_index) = edge_merged_indices.get(&edge_key) {
                    // Edge already processed - replace auto-generated index with stored one
                    result_tensor = result_tensor.replaceind(new_index, stored_index)?;
                } else {
                    // First time seeing this edge - store the auto-generated index
                    edge_merged_indices.insert(edge_key, new_index.clone());
                }
            }

            result_tensors.push(result_tensor);
            result_node_names.push(node_name);
        }

        // Build result TreeTN
        TreeTN::from_tensors(result_tensors, result_node_names)
    }
407}
408
409#[cfg(test)]
410mod tests;