Skip to main content

tensor4all_treetn/operator/
compose.rs

//! Composition of exclusive (non-overlapping) operators.
//!
//! This module provides functions to compose multiple operators that act on
//! non-overlapping regions into a single operator on the full target space.
//! Target nodes not covered by any operator are filled with identity tensors.
5
6use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9
10use anyhow::{Context, Result};
11use petgraph::stable_graph::NodeIndex;
12
13use tensor4all_core::{IndexLike, TensorLike};
14
15use super::index_mapping::IndexMapping;
16use super::linear_operator::LinearOperator;
17use super::Operator;
18use crate::site_index_network::SiteIndexNetwork;
19use crate::treetn::TreeTN;
20
21/// Check if a set of operators are exclusive (non-overlapping) on the target network.
22///
23/// Operators are exclusive if:
24/// 1. **Vertex-disjoint**: No two operators share a node
25/// 2. **Connected subtrees**: Each operator's nodes form a connected subtree
26/// 3. **Path-exclusive**: Paths between different operators don't cross other operators
27///
28/// # Arguments
29///
30/// * `target` - The target site index network (full space)
31/// * `operators` - The operators to check
32///
33/// # Returns
34///
35/// `true` if operators are exclusive, `false` otherwise.
36pub fn are_exclusive_operators<T, V, O>(
37    target: &SiteIndexNetwork<V, T::Index>,
38    operators: &[&O],
39) -> bool
40where
41    T: TensorLike,
42    V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
43    O: Operator<T, V>,
44{
45    // Collect node sets for each operator
46    let node_sets: Vec<HashSet<V>> = operators.iter().map(|op| op.node_names()).collect();
47
48    // 1. Check vertex-disjoint
49    for i in 0..node_sets.len() {
50        for j in (i + 1)..node_sets.len() {
51            if !node_sets[i].is_disjoint(&node_sets[j]) {
52                return false;
53            }
54        }
55    }
56
57    // 2. Check each operator's nodes form a connected subtree in target
58    for node_set in &node_sets {
59        if node_set.is_empty() {
60            continue;
61        }
62
63        // Convert to NodeIndex set
64        let node_indices: HashSet<NodeIndex> = node_set
65            .iter()
66            .filter_map(|name| target.node_index(name))
67            .collect();
68
69        if node_indices.len() != node_set.len() {
70            // Some nodes don't exist in target
71            return false;
72        }
73
74        if !target.is_connected_subset(&node_indices) {
75            return false;
76        }
77    }
78
79    // 3. Path-exclusive check: paths between operators should not cross other operators
80    for i in 0..node_sets.len() {
81        for j in (i + 1)..node_sets.len() {
82            if !check_path_exclusive::<T, V>(target, &node_sets[i], &node_sets[j], &node_sets) {
83                return false;
84            }
85        }
86    }
87
88    true
89}
90
91/// Check if paths between two operator regions don't cross other operators.
92fn check_path_exclusive<T, V>(
93    target: &SiteIndexNetwork<V, T::Index>,
94    set_a: &HashSet<V>,
95    set_b: &HashSet<V>,
96    all_sets: &[HashSet<V>],
97) -> bool
98where
99    T: TensorLike,
100    V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
101{
102    // Find a node from each set
103    let node_a = match set_a.iter().next() {
104        Some(n) => n,
105        None => return true, // Empty set
106    };
107    let node_b = match set_b.iter().next() {
108        Some(n) => n,
109        None => return true,
110    };
111
112    // Get path between them
113    let idx_a = match target.node_index(node_a) {
114        Some(idx) => idx,
115        None => return false,
116    };
117    let idx_b = match target.node_index(node_b) {
118        Some(idx) => idx,
119        None => return false,
120    };
121
122    let path = match target.path_between(idx_a, idx_b) {
123        Some(p) => p,
124        None => return false, // No path means disconnected, which is fine for exclusivity
125    };
126
127    // Check that path nodes (excluding endpoints) don't belong to other operators
128    let other_operator_nodes: HashSet<&V> = all_sets
129        .iter()
130        .filter(|s| *s != set_a && *s != set_b)
131        .flat_map(|s| s.iter())
132        .collect();
133
134    for node_idx in &path[1..path.len().saturating_sub(1)] {
135        if let Some(name) = target.node_name(*node_idx) {
136            if other_operator_nodes.contains(name) {
137                return false;
138            }
139        }
140    }
141
142    true
143}
144
145/// Compose exclusive LinearOperators into a single LinearOperator.
146///
147/// This function takes multiple non-overlapping operators and combines them into
148/// a single operator that acts on the full target space. Gap positions (nodes not
149/// covered by any operator) are filled with identity operators using `T::delta()`.
150///
151/// # Arguments
152///
153/// * `target` - The full site index network (defines the output structure)
154/// * `operators` - Non-overlapping LinearOperators to compose
155/// * `gap_site_indices` - Site indices for gap nodes: node_name -> [(input_index, output_index), ...]
156///
157/// # Returns
158///
159/// A LinearOperator representing the composed operator on the full target space.
160///
161/// # Errors
162///
163/// Returns an error if:
164/// - Operators are not exclusive (overlapping)
165/// - Operator nodes don't exist in target
166/// - Gap node site indices not provided
167#[allow(clippy::type_complexity)]
168pub fn compose_exclusive_linear_operators<T, V>(
169    target: &SiteIndexNetwork<V, T::Index>,
170    operators: &[&LinearOperator<T, V>],
171    gap_site_indices: &HashMap<V, Vec<(T::Index, T::Index)>>,
172) -> Result<LinearOperator<T, V>>
173where
174    T: TensorLike,
175    T::Index: IndexLike + Clone + Hash + Eq + Debug,
176    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
177    V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
178{
179    compose_exclusive_linear_operators_inner(target, operators, gap_site_indices, true)
180}
181
182#[allow(clippy::type_complexity)]
183#[allow(dead_code)]
184pub(crate) fn compose_exclusive_linear_operators_unchecked<T, V>(
185    target: &SiteIndexNetwork<V, T::Index>,
186    operators: &[&LinearOperator<T, V>],
187    gap_site_indices: &HashMap<V, Vec<(T::Index, T::Index)>>,
188) -> Result<LinearOperator<T, V>>
189where
190    T: TensorLike,
191    T::Index: IndexLike + Clone + Hash + Eq + Debug,
192    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
193    V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
194{
195    compose_exclusive_linear_operators_inner(target, operators, gap_site_indices, false)
196}
197
198#[allow(clippy::type_complexity)]
199fn compose_exclusive_linear_operators_inner<T, V>(
200    target: &SiteIndexNetwork<V, T::Index>,
201    operators: &[&LinearOperator<T, V>],
202    gap_site_indices: &HashMap<V, Vec<(T::Index, T::Index)>>,
203    validate_exclusivity: bool,
204) -> Result<LinearOperator<T, V>>
205where
206    T: TensorLike,
207    T::Index: IndexLike + Clone + Hash + Eq + Debug,
208    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
209    V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
210{
211    // 1. Validate exclusivity
212    if validate_exclusivity && !are_exclusive_operators::<T, V, _>(target, operators) {
213        return Err(anyhow::anyhow!(
214            "Operators are not exclusive: they may overlap or not form connected subtrees"
215        ))
216        .context("compose_exclusive_linear_operators: operators must be exclusive");
217    }
218
219    // 2. Collect covered nodes and build node-to-operator map
220    let covered: HashSet<V> = operators.iter().flat_map(|op| op.node_names()).collect();
221
222    let mut node_to_operator: HashMap<V, usize> = HashMap::new();
223    for (op_idx, op) in operators.iter().enumerate() {
224        for name in op.node_names() {
225            node_to_operator.insert(name, op_idx);
226        }
227    }
228
229    // 3. Identify gap nodes
230    let all_target_nodes: HashSet<V> = target.node_names().into_iter().cloned().collect();
231    let gaps: Vec<V> = all_target_nodes.difference(&covered).cloned().collect();
232    let gap_set: HashSet<V> = gaps.iter().cloned().collect();
233
234    // 4. Identify cross-component edges and create dummy link pairs
235    // An edge is cross-component if endpoints are in different operators or one is a gap
236    let mut dummy_links_for_node: HashMap<V, Vec<T::Index>> = HashMap::new();
237
238    for (node_a, node_b) in target.edges() {
239        let comp_a = node_to_operator.get(&node_a);
240        let comp_b = node_to_operator.get(&node_b);
241        let is_gap_a = gap_set.contains(&node_a);
242        let is_gap_b = gap_set.contains(&node_b);
243
244        let is_cross = match (comp_a, comp_b, is_gap_a, is_gap_b) {
245            (Some(a), Some(b), false, false) => a != b,
246            _ => true,
247        };
248
249        if is_cross {
250            let (link_a, link_b) = T::Index::create_dummy_link_pair();
251            dummy_links_for_node
252                .entry(node_a.clone())
253                .or_default()
254                .push(link_a);
255            dummy_links_for_node
256                .entry(node_b.clone())
257                .or_default()
258                .push(link_b);
259        }
260    }
261
262    // 5. Build tensors and mappings
263    let mut tensors: Vec<T> = Vec::new();
264    let mut result_node_names: Vec<V> = Vec::new();
265    let mut combined_input_mapping: HashMap<V, IndexMapping<T::Index>> = HashMap::new();
266    let mut combined_output_mapping: HashMap<V, IndexMapping<T::Index>> = HashMap::new();
267
268    // 5a. Add tensors from operators (with dummy links added via outer product)
269    for op in operators {
270        for name in op.node_names() {
271            let node_idx = op
272                .mpo()
273                .node_index(&name)
274                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", name))?;
275            let mut tensor = op
276                .mpo()
277                .tensor(node_idx)
278                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", name))?
279                .clone();
280
281            // Add dummy links via outer product with ones tensor
282            if let Some(links) = dummy_links_for_node.get(&name) {
283                for link in links {
284                    let ones = T::ones(std::slice::from_ref(link)).with_context(|| {
285                        format!("Failed to create ones tensor for dummy link at {:?}", name)
286                    })?;
287                    tensor = tensor.outer_product(&ones).with_context(|| {
288                        format!("Failed to add dummy link to tensor at {:?}", name)
289                    })?;
290                }
291            }
292
293            tensors.push(tensor);
294            result_node_names.push(name.clone());
295
296            // Copy mappings
297            if let Some(input_map) = op.get_input_mapping(&name) {
298                combined_input_mapping.insert(name.clone(), input_map.clone());
299            }
300            if let Some(output_map) = op.get_output_mapping(&name) {
301                combined_output_mapping.insert(name.clone(), output_map.clone());
302            }
303        }
304    }
305
306    // 5b. Add identity tensors at gaps (with dummy links added via outer product)
307    for gap_name in gaps {
308        let index_pairs = gap_site_indices.get(&gap_name).ok_or_else(|| {
309            anyhow::anyhow!("Site indices not provided for gap node {:?}", gap_name)
310        })?;
311
312        // Collect site indices for delta
313        let input_indices: Vec<T::Index> = index_pairs.iter().map(|(i, _)| i.clone()).collect();
314        let output_indices: Vec<T::Index> = index_pairs.iter().map(|(_, o)| o.clone()).collect();
315
316        // Store mappings for the first site pair (if any)
317        if let Some((true_input, true_output)) = index_pairs.first() {
318            if !combined_input_mapping.contains_key(&gap_name) {
319                combined_input_mapping.insert(
320                    gap_name.clone(),
321                    IndexMapping {
322                        true_index: true_input.clone(),
323                        internal_index: input_indices[0].clone(),
324                    },
325                );
326            }
327            if !combined_output_mapping.contains_key(&gap_name) {
328                combined_output_mapping.insert(
329                    gap_name.clone(),
330                    IndexMapping {
331                        true_index: true_output.clone(),
332                        internal_index: output_indices[0].clone(),
333                    },
334                );
335            }
336        }
337
338        // Create delta tensor for site indices
339        let mut identity_tensor = if input_indices.is_empty() {
340            T::delta(&[], &[]).context("Failed to create scalar identity tensor")?
341        } else {
342            T::delta(&input_indices, &output_indices).with_context(|| {
343                format!("Failed to build identity tensor for gap {:?}", gap_name)
344            })?
345        };
346
347        // Add dummy links via outer product (bond indices, not site indices)
348        if let Some(links) = dummy_links_for_node.get(&gap_name) {
349            for link in links {
350                let ones = T::ones(std::slice::from_ref(link)).with_context(|| {
351                    format!(
352                        "Failed to create ones tensor for dummy link at gap {:?}",
353                        gap_name
354                    )
355                })?;
356                identity_tensor = identity_tensor.outer_product(&ones).with_context(|| {
357                    format!("Failed to add dummy link to gap tensor {:?}", gap_name)
358                })?;
359            }
360        }
361
362        tensors.push(identity_tensor);
363        result_node_names.push(gap_name);
364    }
365
366    // 6. Create TreeTN from tensors
367    let mpo = TreeTN::from_tensors(tensors, result_node_names)
368        .context("compose_exclusive_linear_operators: failed to create TreeTN")?;
369
370    Ok(LinearOperator::new(
371        mpo,
372        combined_input_mapping,
373        combined_output_mapping,
374    ))
375}
376
// Unit tests for this module live in a separate `tests` submodule file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;