tensor4all_treetn/operator/
compose.rs1use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9
10use anyhow::{Context, Result};
11use petgraph::stable_graph::NodeIndex;
12
13use tensor4all_core::{IndexLike, TensorLike};
14
15use super::index_mapping::IndexMapping;
16use super::linear_operator::LinearOperator;
17use super::Operator;
18use crate::site_index_network::SiteIndexNetwork;
19use crate::treetn::TreeTN;
20
21pub fn are_exclusive_operators<T, V, O>(
37 target: &SiteIndexNetwork<V, T::Index>,
38 operators: &[&O],
39) -> bool
40where
41 T: TensorLike,
42 V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
43 O: Operator<T, V>,
44{
45 let node_sets: Vec<HashSet<V>> = operators.iter().map(|op| op.node_names()).collect();
47
48 for i in 0..node_sets.len() {
50 for j in (i + 1)..node_sets.len() {
51 if !node_sets[i].is_disjoint(&node_sets[j]) {
52 return false;
53 }
54 }
55 }
56
57 for node_set in &node_sets {
59 if node_set.is_empty() {
60 continue;
61 }
62
63 let node_indices: HashSet<NodeIndex> = node_set
65 .iter()
66 .filter_map(|name| target.node_index(name))
67 .collect();
68
69 if node_indices.len() != node_set.len() {
70 return false;
72 }
73
74 if !target.is_connected_subset(&node_indices) {
75 return false;
76 }
77 }
78
79 for i in 0..node_sets.len() {
81 for j in (i + 1)..node_sets.len() {
82 if !check_path_exclusive::<T, V>(target, &node_sets[i], &node_sets[j], &node_sets) {
83 return false;
84 }
85 }
86 }
87
88 true
89}
90
91fn check_path_exclusive<T, V>(
93 target: &SiteIndexNetwork<V, T::Index>,
94 set_a: &HashSet<V>,
95 set_b: &HashSet<V>,
96 all_sets: &[HashSet<V>],
97) -> bool
98where
99 T: TensorLike,
100 V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
101{
102 let node_a = match set_a.iter().next() {
104 Some(n) => n,
105 None => return true, };
107 let node_b = match set_b.iter().next() {
108 Some(n) => n,
109 None => return true,
110 };
111
112 let idx_a = match target.node_index(node_a) {
114 Some(idx) => idx,
115 None => return false,
116 };
117 let idx_b = match target.node_index(node_b) {
118 Some(idx) => idx,
119 None => return false,
120 };
121
122 let path = match target.path_between(idx_a, idx_b) {
123 Some(p) => p,
124 None => return false, };
126
127 let other_operator_nodes: HashSet<&V> = all_sets
129 .iter()
130 .filter(|s| *s != set_a && *s != set_b)
131 .flat_map(|s| s.iter())
132 .collect();
133
134 for node_idx in &path[1..path.len().saturating_sub(1)] {
135 if let Some(name) = target.node_name(*node_idx) {
136 if other_operator_nodes.contains(name) {
137 return false;
138 }
139 }
140 }
141
142 true
143}
144
145#[allow(clippy::type_complexity)]
168pub fn compose_exclusive_linear_operators<T, V>(
169 target: &SiteIndexNetwork<V, T::Index>,
170 operators: &[&LinearOperator<T, V>],
171 gap_site_indices: &HashMap<V, Vec<(T::Index, T::Index)>>,
172) -> Result<LinearOperator<T, V>>
173where
174 T: TensorLike,
175 T::Index: IndexLike + Clone + Hash + Eq + Debug,
176 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
177 V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
178{
179 compose_exclusive_linear_operators_inner(target, operators, gap_site_indices, true)
180}
181
182#[allow(clippy::type_complexity)]
183#[allow(dead_code)]
184pub(crate) fn compose_exclusive_linear_operators_unchecked<T, V>(
185 target: &SiteIndexNetwork<V, T::Index>,
186 operators: &[&LinearOperator<T, V>],
187 gap_site_indices: &HashMap<V, Vec<(T::Index, T::Index)>>,
188) -> Result<LinearOperator<T, V>>
189where
190 T: TensorLike,
191 T::Index: IndexLike + Clone + Hash + Eq + Debug,
192 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
193 V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
194{
195 compose_exclusive_linear_operators_inner(target, operators, gap_site_indices, false)
196}
197
198#[allow(clippy::type_complexity)]
199fn compose_exclusive_linear_operators_inner<T, V>(
200 target: &SiteIndexNetwork<V, T::Index>,
201 operators: &[&LinearOperator<T, V>],
202 gap_site_indices: &HashMap<V, Vec<(T::Index, T::Index)>>,
203 validate_exclusivity: bool,
204) -> Result<LinearOperator<T, V>>
205where
206 T: TensorLike,
207 T::Index: IndexLike + Clone + Hash + Eq + Debug,
208 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
209 V: Clone + Hash + Eq + Ord + Send + Sync + Debug,
210{
211 if validate_exclusivity && !are_exclusive_operators::<T, V, _>(target, operators) {
213 return Err(anyhow::anyhow!(
214 "Operators are not exclusive: they may overlap or not form connected subtrees"
215 ))
216 .context("compose_exclusive_linear_operators: operators must be exclusive");
217 }
218
219 let covered: HashSet<V> = operators.iter().flat_map(|op| op.node_names()).collect();
221
222 let mut node_to_operator: HashMap<V, usize> = HashMap::new();
223 for (op_idx, op) in operators.iter().enumerate() {
224 for name in op.node_names() {
225 node_to_operator.insert(name, op_idx);
226 }
227 }
228
229 let all_target_nodes: HashSet<V> = target.node_names().into_iter().cloned().collect();
231 let gaps: Vec<V> = all_target_nodes.difference(&covered).cloned().collect();
232 let gap_set: HashSet<V> = gaps.iter().cloned().collect();
233
234 let mut dummy_links_for_node: HashMap<V, Vec<T::Index>> = HashMap::new();
237
238 for (node_a, node_b) in target.edges() {
239 let comp_a = node_to_operator.get(&node_a);
240 let comp_b = node_to_operator.get(&node_b);
241 let is_gap_a = gap_set.contains(&node_a);
242 let is_gap_b = gap_set.contains(&node_b);
243
244 let is_cross = match (comp_a, comp_b, is_gap_a, is_gap_b) {
245 (Some(a), Some(b), false, false) => a != b,
246 _ => true,
247 };
248
249 if is_cross {
250 let (link_a, link_b) = T::Index::create_dummy_link_pair();
251 dummy_links_for_node
252 .entry(node_a.clone())
253 .or_default()
254 .push(link_a);
255 dummy_links_for_node
256 .entry(node_b.clone())
257 .or_default()
258 .push(link_b);
259 }
260 }
261
262 let mut tensors: Vec<T> = Vec::new();
264 let mut result_node_names: Vec<V> = Vec::new();
265 let mut combined_input_mapping: HashMap<V, IndexMapping<T::Index>> = HashMap::new();
266 let mut combined_output_mapping: HashMap<V, IndexMapping<T::Index>> = HashMap::new();
267
268 for op in operators {
270 for name in op.node_names() {
271 let node_idx = op
272 .mpo()
273 .node_index(&name)
274 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", name))?;
275 let mut tensor = op
276 .mpo()
277 .tensor(node_idx)
278 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", name))?
279 .clone();
280
281 if let Some(links) = dummy_links_for_node.get(&name) {
283 for link in links {
284 let ones = T::ones(std::slice::from_ref(link)).with_context(|| {
285 format!("Failed to create ones tensor for dummy link at {:?}", name)
286 })?;
287 tensor = tensor.outer_product(&ones).with_context(|| {
288 format!("Failed to add dummy link to tensor at {:?}", name)
289 })?;
290 }
291 }
292
293 tensors.push(tensor);
294 result_node_names.push(name.clone());
295
296 if let Some(input_map) = op.get_input_mapping(&name) {
298 combined_input_mapping.insert(name.clone(), input_map.clone());
299 }
300 if let Some(output_map) = op.get_output_mapping(&name) {
301 combined_output_mapping.insert(name.clone(), output_map.clone());
302 }
303 }
304 }
305
306 for gap_name in gaps {
308 let index_pairs = gap_site_indices.get(&gap_name).ok_or_else(|| {
309 anyhow::anyhow!("Site indices not provided for gap node {:?}", gap_name)
310 })?;
311
312 let input_indices: Vec<T::Index> = index_pairs.iter().map(|(i, _)| i.clone()).collect();
314 let output_indices: Vec<T::Index> = index_pairs.iter().map(|(_, o)| o.clone()).collect();
315
316 if let Some((true_input, true_output)) = index_pairs.first() {
318 if !combined_input_mapping.contains_key(&gap_name) {
319 combined_input_mapping.insert(
320 gap_name.clone(),
321 IndexMapping {
322 true_index: true_input.clone(),
323 internal_index: input_indices[0].clone(),
324 },
325 );
326 }
327 if !combined_output_mapping.contains_key(&gap_name) {
328 combined_output_mapping.insert(
329 gap_name.clone(),
330 IndexMapping {
331 true_index: true_output.clone(),
332 internal_index: output_indices[0].clone(),
333 },
334 );
335 }
336 }
337
338 let mut identity_tensor = if input_indices.is_empty() {
340 T::delta(&[], &[]).context("Failed to create scalar identity tensor")?
341 } else {
342 T::delta(&input_indices, &output_indices).with_context(|| {
343 format!("Failed to build identity tensor for gap {:?}", gap_name)
344 })?
345 };
346
347 if let Some(links) = dummy_links_for_node.get(&gap_name) {
349 for link in links {
350 let ones = T::ones(std::slice::from_ref(link)).with_context(|| {
351 format!(
352 "Failed to create ones tensor for dummy link at gap {:?}",
353 gap_name
354 )
355 })?;
356 identity_tensor = identity_tensor.outer_product(&ones).with_context(|| {
357 format!("Failed to add dummy link to gap tensor {:?}", gap_name)
358 })?;
359 }
360 }
361
362 tensors.push(identity_tensor);
363 result_node_names.push(gap_name);
364 }
365
366 let mpo = TreeTN::from_tensors(tensors, result_node_names)
368 .context("compose_exclusive_linear_operators: failed to create TreeTN")?;
369
370 Ok(LinearOperator::new(
371 mpo,
372 combined_input_mapping,
373 combined_output_mapping,
374 ))
375}
376
377#[cfg(test)]
378mod tests;