tensor4all_treetn/treetn/
decompose.rs1use std::collections::{HashMap, HashSet, VecDeque};
7use std::hash::Hash;
8
9use anyhow::Result;
10
11use tensor4all_core::{Canonical, FactorizeOptions, IndexLike, TensorLike};
12
13use super::TreeTN;
14
/// Description of a tree-shaped tensor-network layout used to decompose a tensor.
///
/// `V` is the node-name type and `I` the index-ID type. [`TreeTopology::validate`]
/// performs consistency checks on the description.
#[derive(Debug, Clone)]
pub struct TreeTopology<V, I> {
    /// Index IDs assigned to each node, i.e. the tensor indices the node will carry as its
    /// own ("physical") legs after decomposition.
    pub nodes: HashMap<V, Vec<I>>,
    /// Edges between node names. The decomposition treats them as undirected; a tree on
    /// `n` nodes has exactly `n - 1` edges.
    pub edges: Vec<(V, V)>,
}
31
32impl<V: Clone + Hash + Eq, I: Clone + Eq> TreeTopology<V, I> {
33 pub fn new(nodes: HashMap<V, Vec<I>>, edges: Vec<(V, V)>) -> Self {
39 Self { nodes, edges }
40 }
41
42 pub fn validate(&self) -> Result<()> {
44 let n = self.nodes.len();
45 if n == 0 {
46 return Err(anyhow::anyhow!("Tree topology must have at least one node"));
47 }
48 if n > 1 && self.edges.len() != n - 1 {
49 return Err(anyhow::anyhow!(
50 "Tree must have exactly n-1 edges: got {} nodes and {} edges",
51 n,
52 self.edges.len()
53 ));
54 }
55 for (a, b) in &self.edges {
57 if !self.nodes.contains_key(a) {
58 return Err(anyhow::anyhow!("Edge refers to unknown node"));
59 }
60 if !self.nodes.contains_key(b) {
61 return Err(anyhow::anyhow!("Edge refers to unknown node"));
62 }
63 }
64 Ok(())
65 }
66}
67
68pub fn factorize_tensor_to_treetn<T, V>(
96 tensor: &T,
97 topology: &TreeTopology<V, <T::Index as IndexLike>::Id>,
98 root: &V,
99) -> Result<TreeTN<T, V>>
100where
101 T: TensorLike,
102 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
103 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug + Ord,
104{
105 factorize_tensor_to_treetn_with(tensor, topology, FactorizeOptions::qr(), root)
106}
107
108pub fn factorize_tensor_to_treetn_with<T, V>(
133 tensor: &T,
134 topology: &TreeTopology<V, <T::Index as IndexLike>::Id>,
135 options: FactorizeOptions,
136 root: &V,
137) -> Result<TreeTN<T, V>>
138where
139 T: TensorLike,
140 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
141 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug + Ord,
142{
143 factorize_tensor_to_treetn_with_root_impl(tensor, topology, options, root)
144}
145
146fn factorize_tensor_to_treetn_with_root_impl<T, V>(
147 tensor: &T,
148 topology: &TreeTopology<V, <T::Index as IndexLike>::Id>,
149 options: FactorizeOptions,
150 root: &V,
151) -> Result<TreeTN<T, V>>
152where
153 T: TensorLike,
154 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
155 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug + Ord,
156{
157 topology.validate()?;
158
159 let tensor_indices = tensor.external_indices();
160
161 if topology.nodes.len() == 1 {
162 let node_name = topology.nodes.keys().next().unwrap().clone();
164 if &node_name != root {
165 return Err(anyhow::anyhow!("Requested root node not found in topology"));
166 }
167 let mut tn = TreeTN::<T, V>::new();
168 tn.add_tensor(node_name.clone(), tensor.clone())?;
169 tn.set_canonical_region([node_name])?;
170 return Ok(tn);
171 }
172
173 let tensor_ids: HashSet<_> = tensor_indices.iter().map(|idx| idx.id().clone()).collect();
175 for (node, ids) in &topology.nodes {
176 for id in ids {
177 if !tensor_ids.contains(id) {
178 return Err(anyhow::anyhow!(
179 "Index ID {:?} for node {:?} not found in tensor (tensor has {} indices)",
180 id,
181 node,
182 tensor_indices.len()
183 ));
184 }
185 }
186 }
187
188 let mut assigned_ids: HashSet<<T::Index as IndexLike>::Id> = HashSet::new();
192 for (node, ids) in &topology.nodes {
193 for id in ids {
194 if !assigned_ids.insert(id.clone()) {
195 return Err(anyhow::anyhow!(
196 "Index ID {:?} is assigned to multiple nodes (at least {:?})",
197 id,
198 node
199 ));
200 }
201 }
202 }
203
204 let mut adj: HashMap<V, Vec<V>> = HashMap::new();
206 for node in topology.nodes.keys() {
207 adj.insert(node.clone(), Vec::new());
208 }
209 for (a, b) in &topology.edges {
210 adj.get_mut(a).unwrap().push(b.clone());
211 adj.get_mut(b).unwrap().push(a.clone());
212 }
213 for neighbors in adj.values_mut() {
215 neighbors.sort();
216 }
217
218 if !adj.contains_key(root) {
221 return Err(anyhow::anyhow!("Requested root node not found in topology"));
222 }
223
224 let mut traversal_order: Vec<(V, Option<V>)> = Vec::new(); let mut visited: HashSet<V> = HashSet::new();
227 let mut queue = VecDeque::new();
228 queue.push_back((root.clone(), None::<V>));
229
230 while let Some((node, parent)) = queue.pop_front() {
231 if visited.contains(&node) {
232 continue;
233 }
234 visited.insert(node.clone());
235 traversal_order.push((node.clone(), parent));
236
237 for neighbor in adj.get(&node).unwrap() {
238 if !visited.contains(neighbor) {
239 queue.push_back((neighbor.clone(), Some(node.clone())));
240 }
241 }
242 }
243
244 let mut children_by_parent: HashMap<V, Vec<V>> = HashMap::new();
245 for (node, parent) in &traversal_order {
246 if let Some(parent) = parent {
247 children_by_parent
248 .entry(parent.clone())
249 .or_default()
250 .push(node.clone());
251 }
252 }
253 for children in children_by_parent.values_mut() {
254 children.sort();
255 }
256
257 traversal_order.reverse();
259
260 let mut current_tensor = tensor.clone();
262
263 let mut node_tensors: HashMap<V, T> = HashMap::new();
265 let mut child_bonds: HashMap<V, T::Index> = HashMap::new();
267
268 let factorize_options = FactorizeOptions {
270 canonical: Canonical::Left,
271 ..options
272 };
273
274 #[allow(clippy::needless_range_loop)]
276 for i in 0..traversal_order.len() - 1 {
277 let (node, _parent) = &traversal_order[i];
278 let node_ids = topology.nodes.get(node).unwrap();
280
281 let current_indices = current_tensor.external_indices();
285 let mut desired_ids: HashSet<<T::Index as IndexLike>::Id> =
286 node_ids.iter().cloned().collect();
287 if let Some(children) = children_by_parent.get(node) {
288 for child in children {
289 let bond = child_bonds.get(child).ok_or_else(|| {
290 anyhow::anyhow!(
291 "Missing child bond for node {:?} while processing parent {:?}",
292 child,
293 node
294 )
295 })?;
296 desired_ids.insert(bond.id().clone());
297 }
298 }
299 let left_inds: Vec<_> = current_indices
300 .iter()
301 .filter(|idx| desired_ids.contains(idx.id()))
302 .cloned()
303 .collect();
304
305 if left_inds.is_empty() && current_indices.len() > 1 {
306 return Err(anyhow::anyhow!(
310 "No physical indices found for node {:?} (requested ids={:?}) in current tensor indices={:?}",
311 node,
312 node_ids,
313 current_indices
314 .iter()
315 .map(|idx| idx.id().clone())
316 .collect::<Vec<_>>()
317 ));
318 }
319
320 let factorize_result = current_tensor
324 .factorize(&left_inds, &factorize_options)
325 .map_err(|e| anyhow::anyhow!("Factorization failed: {:?}", e))?;
326
327 let left_indices = factorize_result.left.external_indices();
328 let right_indices = factorize_result.right.external_indices();
329 let shared_bonds =
330 tensor4all_core::index_ops::common_inds::<T::Index>(&left_indices, &right_indices);
331 if shared_bonds.len() != 1 {
332 return Err(anyhow::anyhow!(
333 "Expected exactly one parent bond for node {:?}, found {}",
334 node,
335 shared_bonds.len()
336 ));
337 }
338 child_bonds.insert(node.clone(), shared_bonds[0].clone());
339
340 node_tensors.insert(node.clone(), factorize_result.left);
342
343 current_tensor = factorize_result.right;
345 }
346
347 let (root_node, _) = &traversal_order.last().unwrap();
349 node_tensors.insert(root_node.clone(), current_tensor);
350
351 let mut node_names: Vec<V> = topology.nodes.keys().cloned().collect();
355 node_names.sort();
356 let tensors: Vec<T> = node_names
357 .iter()
358 .map(|name| node_tensors.get(name).cloned().unwrap())
359 .collect();
360
361 let mut tn = TreeTN::from_tensors(tensors, node_names)?;
362 tn.set_canonical_region([root.clone()])?;
363 Ok(tn)
364}
365
366#[cfg(test)]
367mod tests;