// tensor4all_treetn/operator/apply.rs
1//! Apply LinearOperator to TreeTN state.
2//!
3//! This module provides the `apply_linear_operator` function for computing `A|x⟩`
4//! where A is a LinearOperator (MPO with index mappings) and |x⟩ is a TreeTN state.
5//!
6//! # Algorithm
7//!
8//! The application works as follows:
9//! 1. **Partial Site Handling**: If the operator only covers some nodes of the state,
10//!    use `compose_exclusive_linear_operators` to fill gaps with identity operators.
11//! 2. **Index Transformation**: Replace state's site indices with operator's input indices.
12//! 3. **Contraction**: Contract the transformed state with the operator using
13//!    `contract_zipup`, `contract_fit`, or `contract_naive` depending on options.
14//! 4. **Output Transformation**: Replace operator's output indices with true output indices.
15//!
16//! # Example
17//!
18//! ```
19//! use std::collections::HashMap;
20//!
21//! use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
22//! use tensor4all_treetn::{apply_linear_operator, ApplyOptions, IndexMapping, LinearOperator, TreeTN};
23//!
24//! # fn main() -> anyhow::Result<()> {
25//! let site = DynIndex::new_dyn(2);
26//! let state_tensor = TensorDynLen::from_dense(vec![site.clone()], vec![1.0, 2.0])?;
27//! let state = TreeTN::<TensorDynLen, usize>::from_tensors(vec![state_tensor], vec![0])?;
28//!
29//! let input_internal = DynIndex::new_dyn(2);
30//! let output_internal = DynIndex::new_dyn(2);
31//! let mpo_tensor = TensorDynLen::from_dense(
32//!     vec![input_internal.clone(), output_internal.clone()],
33//!     vec![1.0, 0.0, 0.0, 1.0],
34//! )?;
35//! let mpo = TreeTN::<TensorDynLen, usize>::from_tensors(vec![mpo_tensor], vec![0])?;
36//!
37//! let mut input_mapping = HashMap::new();
38//! input_mapping.insert(
39//!     0usize,
40//!     IndexMapping {
41//!         true_index: site.clone(),
42//!         internal_index: input_internal,
43//!     },
44//! );
45//! let mut output_mapping = HashMap::new();
46//! output_mapping.insert(
47//!     0usize,
48//!     IndexMapping {
49//!         true_index: site.clone(),
50//!         internal_index: output_internal,
51//!     },
52//! );
53//!
54//! let operator = LinearOperator::new(mpo, input_mapping, output_mapping);
55//! let result = apply_linear_operator(&operator, &state, ApplyOptions::default())?;
56//! assert_eq!(result.node_count(), 1);
57//!
58//! // Applying identity preserves the state
59//! let result_dense = result.to_dense()?;
60//! let state_dense = state.to_dense()?;
61//! assert!((&result_dense - &state_dense).maxabs() < 1e-12);
62//! # Ok(())
63//! # }
64//! ```
65
66use std::collections::{HashMap, HashSet};
67use std::hash::Hash;
68use std::sync::Arc;
69
70use anyhow::{Context, Result};
71
72use tensor4all_core::{IndexLike, SvdTruncationPolicy, TensorIndex, TensorLike};
73
74use super::index_mapping::IndexMapping;
75use super::linear_operator::LinearOperator;
76use super::Operator;
77use crate::operator::compose::{
78    compose_exclusive_linear_operators, compose_exclusive_linear_operators_unchecked,
79};
80use crate::treetn::contraction::{contract, ContractionMethod, ContractionOptions};
81use crate::treetn::TreeTN;
82
83/// Options for [`apply_linear_operator`].
84///
85/// Controls the contraction algorithm, truncation parameters, and
86/// iterative sweep settings.
87///
88/// # Defaults
89///
90/// - `method`: [`ContractionMethod::Zipup`] (single-sweep, no iteration)
91/// - `max_rank`: `None` (no rank limit)
92/// - `svd_policy`: `None` (uses the SVD global default policy)
93/// - `qr_rtol`: `None` (uses the QR global default tolerance)
94/// - `nfullsweeps`: `1` (only used by Fit method)
95/// - `convergence_tol`: `None` (only used by Fit method)
96///
97/// # Examples
98///
99/// ```
100/// use tensor4all_treetn::ApplyOptions;
101/// use tensor4all_core::SvdTruncationPolicy;
102///
103/// // Default: Zipup with no truncation
104/// let opts = ApplyOptions::default();
105/// assert_eq!(opts.max_rank, None);
106///
107/// // Zipup with rank and tolerance limits
108/// let opts = ApplyOptions::zipup()
109///     .with_max_rank(50)
110///     .with_svd_policy(SvdTruncationPolicy::new(1e-8));
111/// assert_eq!(opts.max_rank, Some(50));
112/// assert_eq!(opts.svd_policy, Some(SvdTruncationPolicy::new(1e-8)));
113///
114/// // Fit method with sweep control
115/// let opts = ApplyOptions::fit().with_nfullsweeps(3).with_max_rank(20);
116/// assert_eq!(opts.nfullsweeps, 3);
117///
118/// // Naive contraction (exact, no truncation)
119/// let opts = ApplyOptions::naive();
120/// assert_eq!(opts.max_rank, None);
121/// ```
#[derive(Debug, Clone)]
pub struct ApplyOptions {
    /// Contraction method to use.
    pub method: ContractionMethod,
    /// Maximum bond dimension for truncation.
    ///
    /// `None` means no rank limit.
    pub max_rank: Option<usize>,
    /// Explicit SVD truncation policy.
    ///
    /// `None` falls back to the SVD global default policy.
    pub svd_policy: Option<SvdTruncationPolicy>,
    /// QR-specific relative tolerance.
    ///
    /// `None` falls back to the QR global default tolerance.
    pub qr_rtol: Option<f64>,
    /// Number of full sweeps for Fit method.
    ///
    /// A full sweep visits each edge twice (forward and backward) using an Euler tour.
    pub nfullsweeps: usize,
    /// Convergence tolerance for Fit method.
    ///
    /// Only used by the Fit method; `None` means no convergence tolerance is set.
    pub convergence_tol: Option<f64>,
}
139
140impl Default for ApplyOptions {
141    fn default() -> Self {
142        Self {
143            method: ContractionMethod::Zipup,
144            max_rank: None,
145            svd_policy: None,
146            qr_rtol: None,
147            nfullsweeps: 1,
148            convergence_tol: None,
149        }
150    }
151}
152
153impl ApplyOptions {
154    /// Create options with ZipUp method (default).
155    pub fn zipup() -> Self {
156        Self::default()
157    }
158
159    /// Create options with Fit method.
160    pub fn fit() -> Self {
161        Self {
162            method: ContractionMethod::Fit,
163            ..Default::default()
164        }
165    }
166
167    /// Create options with Naive method.
168    pub fn naive() -> Self {
169        Self {
170            method: ContractionMethod::Naive,
171            ..Default::default()
172        }
173    }
174
175    /// Set maximum bond dimension.
176    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
177        self.max_rank = Some(max_rank);
178        self
179    }
180
181    /// Set the SVD truncation policy.
182    pub fn with_svd_policy(mut self, policy: SvdTruncationPolicy) -> Self {
183        self.svd_policy = Some(policy);
184        self
185    }
186
187    /// Set the QR-specific truncation tolerance.
188    pub fn with_qr_rtol(mut self, rtol: f64) -> Self {
189        self.qr_rtol = Some(rtol);
190        self
191    }
192
193    /// Set number of full sweeps for Fit method.
194    pub fn with_nfullsweeps(mut self, nfullsweeps: usize) -> Self {
195        self.nfullsweeps = nfullsweeps;
196        self
197    }
198}
199
200/// Apply a LinearOperator to a TreeTN state: compute `A|x⟩`.
201///
202/// This function handles:
203/// - Partial operators (fills gaps with identity via compose_exclusive_linear_operators)
204/// - Index transformations (input/output mappings)
205/// - Multiple contraction algorithms (ZipUp, Fit, Naive)
206///
207/// # Arguments
208///
209/// * `operator` - The LinearOperator to apply
210/// * `state` - The input state |x⟩
211/// * `options` - Options controlling the contraction algorithm
212///
213/// # Returns
214///
215/// The result `A|x⟩` as a TreeTN, or an error if application fails.
216///
217/// # Example
218///
219/// ```
220/// use std::collections::HashMap;
221///
222/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
223/// use tensor4all_treetn::{apply_linear_operator, ApplyOptions, IndexMapping, LinearOperator, TreeTN};
224///
225/// # fn main() -> anyhow::Result<()> {
226/// let site = DynIndex::new_dyn(2);
227/// let state_tensor = TensorDynLen::from_dense(vec![site.clone()], vec![1.0, 2.0])?;
228/// let state = TreeTN::<TensorDynLen, usize>::from_tensors(vec![state_tensor], vec![0])?;
229///
230/// let input_internal = DynIndex::new_dyn(2);
231/// let output_internal = DynIndex::new_dyn(2);
232/// let mpo_tensor = TensorDynLen::from_dense(
233///     vec![input_internal.clone(), output_internal.clone()],
234///     vec![1.0, 0.0, 0.0, 1.0],
235/// )?;
236/// let mpo = TreeTN::<TensorDynLen, usize>::from_tensors(vec![mpo_tensor], vec![0])?;
237///
238/// let mut input_mapping = HashMap::new();
239/// input_mapping.insert(
240///     0usize,
241///     IndexMapping {
242///         true_index: site.clone(),
243///         internal_index: input_internal,
244///     },
245/// );
246/// let mut output_mapping = HashMap::new();
247/// output_mapping.insert(
248///     0usize,
249///     IndexMapping {
250///         true_index: site.clone(),
251///         internal_index: output_internal,
252///     },
253/// );
254///
255/// let operator = LinearOperator::new(mpo, input_mapping, output_mapping);
256///
257/// let result = apply_linear_operator(&operator, &state, ApplyOptions::default())?;
258/// assert_eq!(result.node_count(), 1);
259///
260/// // Applying identity preserves the state
261/// let result_dense = result.to_dense()?;
262/// let state_dense = state.to_dense()?;
263/// assert!((&result_dense - &state_dense).maxabs() < 1e-12);
264///
265/// let truncated = apply_linear_operator(
266///     &operator,
267///     &state,
268///     ApplyOptions::zipup()
269///         .with_max_rank(4)
270///         .with_svd_policy(tensor4all_core::SvdTruncationPolicy::new(1e-10)),
271/// )?;
272/// assert_eq!(truncated.node_count(), 1);
273/// # Ok(())
274/// # }
275/// ```
pub fn apply_linear_operator<T, V>(
    operator: &LinearOperator<T, V>,
    state: &TreeTN<T, V>,
    options: ApplyOptions,
) -> Result<TreeTN<T, V>>
where
    T: TensorLike,
    T::Index: IndexLike + Clone + Hash + Eq + std::fmt::Debug,
    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
{
    // 1. Check if operator covers all state nodes
    let state_nodes: HashSet<V> = state.node_names().into_iter().collect();
    let op_nodes: HashSet<V> = operator.node_names();

    let full_operator = if op_nodes == state_nodes {
        // Operator covers all nodes - use directly
        operator.clone()
    } else if op_nodes.is_subset(&state_nodes) {
        // Partial operator - need to compose with identity on gaps
        extend_operator_to_full_space(operator, state)?
    } else {
        // The operator acts on nodes the state does not have; there is no
        // sensible way to apply it, so fail early with both node sets shown.
        return Err(anyhow::anyhow!(
            "Operator nodes {:?} are not a subset of state nodes {:?}",
            op_nodes,
            state_nodes
        ));
    };

    // 2. Transform state's site indices to operator's input indices
    // (uses full_operator's mappings, which include the identity-gap nodes).
    let transformed_state = transform_state_to_input(&full_operator, state)?;

    // 3. Contract state with operator MPO
    // Choose a center node (use first node in sorted order for determinism)
    let mut node_names: Vec<_> = state.node_names();
    node_names.sort();
    let center = node_names
        .first()
        .ok_or_else(|| anyhow::anyhow!("Empty state"))?;

    // Forward the apply options into the lower-level contraction options;
    // fields not covered by ApplyOptions keep their defaults.
    let contraction_options = ContractionOptions {
        method: options.method,
        max_rank: options.max_rank,
        svd_policy: options.svd_policy,
        qr_rtol: options.qr_rtol,
        nfullsweeps: options.nfullsweeps,
        convergence_tol: options.convergence_tol,
        ..Default::default()
    };

    let contracted = contract(
        &transformed_state,
        full_operator.mpo(),
        center,
        contraction_options,
    )
    .context("Failed to contract state with operator")?;

    // 4. Transform operator's output indices to true output indices
    let result = transform_output_to_true(&full_operator, contracted)?;

    Ok(result)
}
339
/// Extend a partial operator to cover the full state space.
///
/// Uses the operator support's Steiner tree to detect disconnected regions and
/// fills all missing nodes with identity operators.
/// For gap nodes, creates proper index mappings where:
/// - True indices = state's actual site indices
/// - Internal indices = new simulated indices for the MPO tensor
fn extend_operator_to_full_space<T, V>(
    operator: &LinearOperator<T, V>,
    state: &TreeTN<T, V>,
) -> Result<LinearOperator<T, V>>
where
    T: TensorLike,
    T::Index: IndexLike + Clone + Hash + Eq + std::fmt::Debug,
    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
{
    let state_network = state.site_index_network();
    let op_nodes: HashSet<V> = operator.node_names();
    let state_nodes: HashSet<V> = state.node_names().into_iter().collect();
    // Resolve each operator node name to its index in the state's network;
    // a miss here means the operator references a node the state lacks.
    let mut op_node_indices: HashSet<petgraph::stable_graph::NodeIndex> = HashSet::new();
    for name in &op_nodes {
        let node_index = state_network.node_index(name).ok_or_else(|| {
            anyhow::anyhow!("Operator node {:?} is missing from the state network", name)
        })?;
        op_node_indices.insert(node_index);
    }

    // Steiner-tree nodes minus the operator's own nodes = nodes that lie
    // strictly between disconnected parts of the operator's support.
    let steiner_tree_nodes = state_network.steiner_tree_nodes(&op_node_indices);
    let steiner_gap_nodes: HashSet<_> = steiner_tree_nodes
        .difference(&op_node_indices)
        .copied()
        .collect();
    let gap_nodes: Vec<V> = state_nodes.difference(&op_nodes).cloned().collect();

    // Build gap site indices: for each gap node, create internal indices for the identity tensor.
    // The (input_internal, output_internal) pairs are used to build the delta tensor.
    #[allow(clippy::type_complexity)]
    let mut gap_site_indices: HashMap<V, Vec<(T::Index, T::Index)>> = HashMap::new();

    // Also track true<->internal mappings for gap nodes
    #[allow(clippy::type_complexity)]
    let mut gap_input_mappings: HashMap<V, IndexMapping<T::Index>> = HashMap::new();
    #[allow(clippy::type_complexity)]
    let mut gap_output_mappings: HashMap<V, IndexMapping<T::Index>> = HashMap::new();

    for gap_name in &gap_nodes {
        let site_space = state
            .site_space(gap_name)
            .ok_or_else(|| anyhow::anyhow!("Gap node {:?} has no site space", gap_name))?;

        // For identity at gap nodes:
        // - True indices = state's site indices (what apply_linear_operator maps from/to)
        // - Internal indices = new simulated indices for the MPO tensor
        let mut pairs: Vec<(T::Index, T::Index)> = Vec::new();

        for (i, true_idx) in site_space.iter().enumerate() {
            let input_internal = true_idx.sim();
            let output_internal = true_idx.sim();
            pairs.push((input_internal.clone(), output_internal.clone()));

            // Store mapping for the first site index of each gap node
            // NOTE(review): only the first site index of a multi-index gap
            // node gets a true<->internal mapping — confirm downstream code
            // only needs one mapping per node.
            if i == 0 {
                gap_input_mappings.insert(
                    gap_name.clone(),
                    IndexMapping {
                        true_index: true_idx.clone(),
                        internal_index: input_internal,
                    },
                );
                gap_output_mappings.insert(
                    gap_name.clone(),
                    IndexMapping {
                        true_index: true_idx.clone(),
                        internal_index: output_internal,
                    },
                );
            }
        }

        gap_site_indices.insert(gap_name.clone(), pairs);
    }

    // Pick the composition strategy:
    // - edge-less operator: unchecked compose (no bonds to route);
    // - connected support (no Steiner gaps): checked compose;
    // - disconnected support: route operator bonds along state paths.
    let mut composed = if operator.mpo.edge_count() == 0 {
        compose_exclusive_linear_operators_unchecked(state_network, &[operator], &gap_site_indices)
            .context("Failed to compose operator with identity gaps")?
    } else if steiner_gap_nodes.is_empty() {
        compose_exclusive_linear_operators(state_network, &[operator], &gap_site_indices)
            .context("Failed to compose operator with identity gaps")?
    } else {
        compose_operator_along_state_paths(
            operator,
            state_network,
            &gap_site_indices,
            gap_input_mappings.clone(),
            gap_output_mappings.clone(),
        )
        .context("Failed to compose operator along state paths")?
    };

    // Override the mappings for gap nodes to use the correct true indices
    // (compose_exclusive_linear_operators uses the internal indices as true indices for gaps)
    for (gap_name, mapping) in gap_input_mappings {
        composed.input_mapping.insert(gap_name, mapping);
    }
    for (gap_name, mapping) in gap_output_mappings {
        composed.output_mapping.insert(gap_name, mapping);
    }

    Ok(composed)
}
451
/// Build a full-support MPO by routing each operator bond along the state's
/// tree paths.
///
/// Used when the operator's support is disconnected in the state topology.
/// Each state node gets a tensor: operator nodes keep their MPO tensor, gap
/// nodes get identity (delta) tensors. For every operator edge whose
/// endpoints are not adjacent in the state network, the bond index is
/// re-routed through a chain of fresh simulated bonds, with bridge delta
/// tensors attached to intermediate nodes. State edges untouched by any path
/// receive dummy links (via `create_dummy_link_pair`) so the final TreeTN has
/// the state's full connectivity.
#[allow(clippy::type_complexity)]
fn compose_operator_along_state_paths<T, V>(
    operator: &LinearOperator<T, V>,
    state_network: &crate::site_index_network::SiteIndexNetwork<V, T::Index>,
    gap_site_indices: &HashMap<V, Vec<(T::Index, T::Index)>>,
    input_mappings: HashMap<V, IndexMapping<T::Index>>,
    output_mappings: HashMap<V, IndexMapping<T::Index>>,
) -> Result<LinearOperator<T, V>>
where
    T: TensorLike,
    T::Index: IndexLike + Clone + Hash + Eq + std::fmt::Debug,
    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
{
    let op_nodes: HashSet<V> = operator.node_names();
    let mut tensors_by_node: HashMap<V, T> = HashMap::new();

    // Sort node names so the resulting tensor order is deterministic.
    let mut state_node_names: Vec<V> = state_network.node_names().into_iter().cloned().collect();
    state_node_names.sort();

    // Stage 1: seed one tensor per state node — the operator's own tensor
    // where the operator has support, an identity delta at gap nodes.
    for node in &state_node_names {
        if op_nodes.contains(node) {
            let node_idx = operator.mpo.node_index(node).ok_or_else(|| {
                anyhow::anyhow!(
                    "compose_operator_along_state_paths: missing node {:?}",
                    node
                )
            })?;
            let tensor = operator.mpo.tensor(node_idx).ok_or_else(|| {
                anyhow::anyhow!(
                    "compose_operator_along_state_paths: missing tensor for {:?}",
                    node
                )
            })?;
            tensors_by_node.insert(node.clone(), tensor.clone());
        } else {
            let index_pairs = gap_site_indices.get(node).ok_or_else(|| {
                anyhow::anyhow!(
                    "compose_operator_along_state_paths: missing gap indices for {:?}",
                    node
                )
            })?;
            let input_indices: Vec<T::Index> = index_pairs.iter().map(|(i, _)| i.clone()).collect();
            let output_indices: Vec<T::Index> =
                index_pairs.iter().map(|(_, o)| o.clone()).collect();
            // A gap node with no site indices still needs a (scalar) tensor
            // so the TreeTN keeps one tensor per node.
            let tensor = if input_indices.is_empty() {
                T::delta(&[], &[]).context(
                    "compose_operator_along_state_paths: failed to build scalar identity",
                )?
            } else {
                T::delta(&input_indices, &output_indices).with_context(|| {
                    format!(
                        "compose_operator_along_state_paths: failed to build identity for gap {:?}",
                        node
                    )
                })?
            };
            tensors_by_node.insert(node.clone(), tensor);
        }
    }

    // Stage 2: for each operator edge, reroute its bond along the state path
    // between the two endpoints, inserting bridge deltas at intermediate nodes.
    let mut op_edges: Vec<(V, V)> = operator.mpo.site_index_network().edges().collect();
    op_edges.sort();
    let mut used_state_edges: HashSet<(V, V)> = HashSet::new();

    for (node_a, node_b) in op_edges {
        let idx_a = state_network.node_index(&node_a).ok_or_else(|| {
            anyhow::anyhow!(
                "compose_operator_along_state_paths: missing state node {:?}",
                node_a
            )
        })?;
        let idx_b = state_network.node_index(&node_b).ok_or_else(|| {
            anyhow::anyhow!(
                "compose_operator_along_state_paths: missing state node {:?}",
                node_b
            )
        })?;
        let path = state_network.path_between(idx_a, idx_b).ok_or_else(|| {
            anyhow::anyhow!(
                "compose_operator_along_state_paths: no path between {:?} and {:?}",
                node_a,
                node_b
            )
        })?;
        // path.len() < 2 would mean the endpoints coincide; nothing to route.
        if path.len() < 2 {
            continue;
        }

        let edge = operator.mpo.edge_between(&node_a, &node_b).ok_or_else(|| {
            anyhow::anyhow!(
                "compose_operator_along_state_paths: missing operator edge between {:?} and {:?}",
                node_a,
                node_b
            )
        })?;
        let bond = operator
            .mpo
            .bond_index(edge)
            .ok_or_else(|| {
                anyhow::anyhow!("compose_operator_along_state_paths: missing bond index")
            })?
            .clone();

        // One fresh simulated bond per path segment; each bridges two
        // consecutive nodes on the path.
        let mut chain_bonds = Vec::with_capacity(path.len() - 1);
        chain_bonds.push(bond.sim());
        for _ in 1..(path.len() - 1) {
            let next = chain_bonds[chain_bonds.len() - 1].sim();
            chain_bonds.push(next);
        }

        let start_name = state_network
            .node_name(path[0])
            .ok_or_else(|| anyhow::anyhow!("compose_operator_along_state_paths: missing start"))?
            .clone();
        let end_name = state_network
            .node_name(path[path.len() - 1])
            .ok_or_else(|| anyhow::anyhow!("compose_operator_along_state_paths: missing end"))?
            .clone();

        // Endpoint tensors: swap the original bond for the first/last chain bond.
        {
            let tensor = tensors_by_node.get_mut(&start_name).ok_or_else(|| {
                anyhow::anyhow!(
                    "compose_operator_along_state_paths: missing tensor for {:?}",
                    start_name
                )
            })?;
            *tensor = tensor.replaceind(&bond, &chain_bonds[0]).with_context(|| {
                format!(
                    "compose_operator_along_state_paths: failed to reroute bond at {:?}",
                    start_name
                )
            })?;
        }
        {
            let tensor = tensors_by_node.get_mut(&end_name).ok_or_else(|| {
                anyhow::anyhow!(
                    "compose_operator_along_state_paths: missing tensor for {:?}",
                    end_name
                )
            })?;
            let last_bond = &chain_bonds[chain_bonds.len() - 1];
            *tensor = tensor.replaceind(&bond, last_bond).with_context(|| {
                format!(
                    "compose_operator_along_state_paths: failed to reroute bond at {:?}",
                    end_name
                )
            })?;
        }

        // Intermediate nodes: attach a delta bridging the incoming and
        // outgoing chain bonds via outer product.
        for i in 1..(path.len() - 1) {
            let mid_name = state_network
                .node_name(path[i])
                .ok_or_else(|| anyhow::anyhow!("compose_operator_along_state_paths: missing mid"))?
                .clone();
            let delta = T::delta(
                std::slice::from_ref(&chain_bonds[i - 1]),
                std::slice::from_ref(&chain_bonds[i]),
            )
            .with_context(|| {
                format!(
                    "compose_operator_along_state_paths: failed to build bridge at {:?}",
                    mid_name
                )
            })?;
            let tensor = tensors_by_node.get_mut(&mid_name).ok_or_else(|| {
                anyhow::anyhow!(
                    "compose_operator_along_state_paths: missing tensor for {:?}",
                    mid_name
                )
            })?;
            *tensor = tensor.outer_product(&delta).with_context(|| {
                format!(
                    "compose_operator_along_state_paths: failed to attach bridge at {:?}",
                    mid_name
                )
            })?;
        }

        // Record every state edge the path traversed (normalized so the
        // smaller name comes first, matching the lookup below).
        for window in path.windows(2) {
            let a = state_network
                .node_name(window[0])
                .ok_or_else(|| {
                    anyhow::anyhow!("compose_operator_along_state_paths: missing path node")
                })?
                .clone();
            let b = state_network
                .node_name(window[1])
                .ok_or_else(|| {
                    anyhow::anyhow!("compose_operator_along_state_paths: missing path node")
                })?
                .clone();
            let edge_key = if a <= b { (a, b) } else { (b, a) };
            used_state_edges.insert(edge_key);
        }
    }

    // Stage 3: any state edge not used by a routed bond still needs a link
    // so the final TreeTN has the state's connectivity — attach dummy links.
    let mut state_edges: Vec<(V, V)> = state_network.edges().collect();
    state_edges.sort();
    for (node_a, node_b) in state_edges {
        let edge_key = if node_a <= node_b {
            (node_a.clone(), node_b.clone())
        } else {
            (node_b.clone(), node_a.clone())
        };
        if used_state_edges.contains(&edge_key) {
            continue;
        }
        let (link_a, link_b) = T::Index::create_dummy_link_pair();
        let ones_a = T::ones(std::slice::from_ref(&link_a)).with_context(|| {
            format!(
                "compose_operator_along_state_paths: failed to create dummy link tensor for {:?}",
                node_a
            )
        })?;
        let ones_b = T::ones(std::slice::from_ref(&link_b)).with_context(|| {
            format!(
                "compose_operator_along_state_paths: failed to create dummy link tensor for {:?}",
                node_b
            )
        })?;

        let tensor_a = tensors_by_node.get_mut(&node_a).ok_or_else(|| {
            anyhow::anyhow!(
                "compose_operator_along_state_paths: missing tensor for {:?}",
                node_a
            )
        })?;
        *tensor_a = tensor_a.outer_product(&ones_a).with_context(|| {
            format!(
                "compose_operator_along_state_paths: failed to attach dummy link at {:?}",
                node_a
            )
        })?;

        let tensor_b = tensors_by_node.get_mut(&node_b).ok_or_else(|| {
            anyhow::anyhow!(
                "compose_operator_along_state_paths: missing tensor for {:?}",
                node_b
            )
        })?;
        *tensor_b = tensor_b.outer_product(&ones_b).with_context(|| {
            format!(
                "compose_operator_along_state_paths: failed to attach dummy link at {:?}",
                node_b
            )
        })?;
    }

    // Stage 4: assemble tensors in the deterministic node order and wrap
    // them back into a LinearOperator with the caller-supplied mappings.
    let tensors: Vec<T> = state_node_names
        .iter()
        .map(|node| {
            tensors_by_node.get(node).cloned().ok_or_else(|| {
                anyhow::anyhow!(
                    "compose_operator_along_state_paths: missing tensor for {:?}",
                    node
                )
            })
        })
        .collect::<Result<Vec<_>>>()?;

    let mpo = TreeTN::from_tensors(tensors, state_node_names.clone())
        .context("compose_operator_along_state_paths: failed to create TreeTN")?;

    Ok(LinearOperator::new(mpo, input_mappings, output_mappings))
}
718
719/// Transform state's site indices to operator's input indices.
720fn transform_state_to_input<T, V>(
721    operator: &LinearOperator<T, V>,
722    state: &TreeTN<T, V>,
723) -> Result<TreeTN<T, V>>
724where
725    T: TensorLike,
726    T::Index: IndexLike + Clone + Hash + Eq,
727    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
728    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
729{
730    let mut result = state.clone();
731
732    for (node, mapping) in operator.input_mappings() {
733        // Replace true_index with internal_index in the state
734        result = result
735            .replaceind(&mapping.true_index, &mapping.internal_index)
736            .with_context(|| format!("Failed to transform input index at node {:?}", node))?;
737    }
738
739    Ok(result)
740}
741
742/// Transform operator's output indices to true output indices.
743fn transform_output_to_true<T, V>(
744    operator: &LinearOperator<T, V>,
745    mut result: TreeTN<T, V>,
746) -> Result<TreeTN<T, V>>
747where
748    T: TensorLike,
749    T::Index: IndexLike + Clone + Hash + Eq,
750    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
751    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
752{
753    for (node, mapping) in operator.output_mappings() {
754        // Replace internal_index with true_index in the result
755        result = result
756            .replaceind(&mapping.internal_index, &mapping.true_index)
757            .with_context(|| format!("Failed to transform output index at node {:?}", node))?;
758    }
759
760    Ok(result)
761}
762
763// ============================================================================
764// TensorIndex implementation for LinearOperator
765// ============================================================================
766
767impl<T, V> TensorIndex for LinearOperator<T, V>
768where
769    T: TensorLike,
770    T::Index: IndexLike + Clone + Hash + Eq + std::fmt::Debug,
771    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
772    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
773{
774    type Index = T::Index;
775
776    /// Return all external indices (true input and output indices).
777    fn external_indices(&self) -> Vec<Self::Index> {
778        let mut result: Vec<Self::Index> = self
779            .input_mapping
780            .values()
781            .map(|m| m.true_index.clone())
782            .collect();
783        result.extend(self.output_mapping.values().map(|m| m.true_index.clone()));
784        result
785    }
786
787    fn num_external_indices(&self) -> usize {
788        self.input_mapping.len() + self.output_mapping.len()
789    }
790
791    /// Replace an external index (true index) in this operator.
792    ///
793    /// This updates the mapping but does NOT modify the internal MPO tensors.
794    fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
795        // Validate dimension match
796        if old_index.dim() != new_index.dim() {
797            return Err(anyhow::anyhow!(
798                "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
799                old_index.dim(),
800                new_index.dim()
801            ));
802        }
803
804        let mut result = self.clone();
805
806        // Check input mappings
807        for (node, mapping) in &self.input_mapping {
808            if mapping.true_index.same_id(old_index) {
809                result.input_mapping.insert(
810                    node.clone(),
811                    IndexMapping {
812                        true_index: new_index.clone(),
813                        internal_index: mapping.internal_index.clone(),
814                    },
815                );
816                return Ok(result);
817            }
818        }
819
820        // Check output mappings
821        for (node, mapping) in &self.output_mapping {
822            if mapping.true_index.same_id(old_index) {
823                result.output_mapping.insert(
824                    node.clone(),
825                    IndexMapping {
826                        true_index: new_index.clone(),
827                        internal_index: mapping.internal_index.clone(),
828                    },
829                );
830                return Ok(result);
831            }
832        }
833
834        Err(anyhow::anyhow!(
835            "Index {:?} not found in LinearOperator mappings",
836            old_index.id()
837        ))
838    }
839
840    /// Replace multiple external indices.
841    fn replaceinds(
842        &self,
843        old_indices: &[Self::Index],
844        new_indices: &[Self::Index],
845    ) -> Result<Self> {
846        if old_indices.len() != new_indices.len() {
847            return Err(anyhow::anyhow!(
848                "Length mismatch: {} old indices, {} new indices",
849                old_indices.len(),
850                new_indices.len()
851            ));
852        }
853
854        let mut result = self.clone();
855        for (old, new) in old_indices.iter().zip(new_indices.iter()) {
856            result = result.replaceind(old, new)?;
857        }
858        Ok(result)
859    }
860}
861
862// ============================================================================
863// Arc-based CoW wrapper for LinearOperator
864// ============================================================================
865
/// LinearOperator with Arc-based Copy-on-Write semantics.
///
/// This wrapper uses `Arc` for the internal MPO to enable cheap cloning
/// and efficient sharing. When mutation is needed, `make_mut` performs
/// a clone only if there are other references.
///
/// The index-mapping tables are stored by value (not behind the `Arc`),
/// so cloning this type copies the two `HashMap`s but shares the MPO.
#[derive(Debug, Clone)]
pub struct ArcLinearOperator<T, V>
where
    T: TensorLike,
    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
{
    /// The MPO with internal index IDs (wrapped in Arc for CoW)
    pub mpo: Arc<TreeTN<T, V>>,
    /// Input index mapping: node -> (true s_in, internal s_in_tmp)
    pub input_mapping: HashMap<V, IndexMapping<T::Index>>,
    /// Output index mapping: node -> (true s_out, internal s_out_tmp)
    pub output_mapping: HashMap<V, IndexMapping<T::Index>>,
}
884
885impl<T, V> ArcLinearOperator<T, V>
886where
887    T: TensorLike,
888    T::Index: IndexLike + Clone,
889    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
890    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
891{
892    /// Create from an existing LinearOperator.
893    pub fn from_linear_operator(op: LinearOperator<T, V>) -> Self {
894        Self {
895            mpo: Arc::new(op.mpo),
896            input_mapping: op.input_mapping,
897            output_mapping: op.output_mapping,
898        }
899    }
900
901    /// Create a new ArcLinearOperator.
902    pub fn new(
903        mpo: TreeTN<T, V>,
904        input_mapping: HashMap<V, IndexMapping<T::Index>>,
905        output_mapping: HashMap<V, IndexMapping<T::Index>>,
906    ) -> Self {
907        Self {
908            mpo: Arc::new(mpo),
909            input_mapping,
910            output_mapping,
911        }
912    }
913
914    /// Get a mutable reference to the MPO, cloning if necessary.
915    ///
916    /// This implements Copy-on-Write semantics: if this is the only reference,
917    /// no copy is made. If there are other references, the MPO is cloned first.
918    pub fn mpo_mut(&mut self) -> &mut TreeTN<T, V> {
919        Arc::make_mut(&mut self.mpo)
920    }
921
922    /// Get an immutable reference to the MPO.
923    pub fn mpo(&self) -> &TreeTN<T, V> {
924        &self.mpo
925    }
926
927    /// Convert back to a LinearOperator (unwraps Arc if possible).
928    pub fn into_linear_operator(self) -> LinearOperator<T, V> {
929        LinearOperator {
930            mpo: Arc::try_unwrap(self.mpo).unwrap_or_else(|arc| (*arc).clone()),
931            input_mapping: self.input_mapping,
932            output_mapping: self.output_mapping,
933        }
934    }
935
936    /// Get input mapping for a node.
937    pub fn get_input_mapping(&self, node: &V) -> Option<&IndexMapping<T::Index>> {
938        self.input_mapping.get(node)
939    }
940
941    /// Get output mapping for a node.
942    pub fn get_output_mapping(&self, node: &V) -> Option<&IndexMapping<T::Index>> {
943        self.output_mapping.get(node)
944    }
945
946    /// Get all input mappings.
947    pub fn input_mappings(&self) -> &HashMap<V, IndexMapping<T::Index>> {
948        &self.input_mapping
949    }
950
951    /// Get all output mappings.
952    pub fn output_mappings(&self) -> &HashMap<V, IndexMapping<T::Index>> {
953        &self.output_mapping
954    }
955
956    /// Get node names covered by this operator.
957    pub fn node_names(&self) -> HashSet<V> {
958        self.mpo
959            .site_index_network()
960            .node_names()
961            .into_iter()
962            .cloned()
963            .collect()
964    }
965}
966
// Unit tests live in the sibling `tests` module file; compiled only for test builds.
#[cfg(test)]
mod tests;