Skip to main content

tensor4all_treetn/operator/
linear_operator.rs

1//! LinearOperator: Wrapper for MPO with index mapping.
2//!
3//! This module provides a `LinearOperator` that wraps an MPO (Matrix Product Operator)
4//! and handles the index ID mapping between true site indices and internal MPO indices.
5//!
6//! # Problem
7//!
8//! In the equation `A * x = b`:
9//! - `A.s_in` should match `x`'s site indices
10//! - `A.s_out` should match `b`'s site indices
11//!
12//! However, a tensor cannot have two indices with the same ID. So the MPO internally
13//! uses `s_in_tmp` and `s_out_tmp` with independent IDs.
14//!
15//! # Solution
16//!
17//! `LinearOperator` stores:
18//! - The MPO with internal index IDs (`s_in_tmp`, `s_out_tmp`)
19//! - Mapping from true `s_in`/`s_out` to internal `s_in_tmp`/`s_out_tmp`
20//!
21//! When applying to `x`, it automatically handles the index transformations.
22
23use std::collections::HashMap;
24use std::hash::Hash;
25
26use anyhow::Result;
27
28use tensor4all_core::AllowedPairs;
29use tensor4all_core::IndexLike;
30use tensor4all_core::TensorLike;
31
32use super::index_mapping::IndexMapping;
33use crate::treetn::TreeTN;
34
35/// LinearOperator: Wraps an MPO with index mapping for automatic transformations.
36///
37/// # Type Parameters
38///
39/// * `T` - Tensor type implementing `TensorLike`
40/// * `V` - Node name type
41///
42/// # Examples
43///
44/// A `LinearOperator` is typically obtained from a constructor function rather
45/// than built directly. The `mpo` field contains the underlying TreeTN:
46///
47/// ```
48/// use tensor4all_treetn::LinearOperator;
49/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
50/// use tensor4all_treetn::TreeTN;
51/// use std::collections::HashMap;
52///
53/// // Build a trivial single-node LinearOperator wrapping a 2x2 identity
54/// let s_in = DynIndex::new_dyn(2);
55/// let s_out = DynIndex::new_dyn(2);
56/// let t = TensorDynLen::from_dense(
57///     vec![s_in.clone(), s_out.clone()],
58///     vec![1.0_f64, 0.0, 0.0, 1.0],
59/// ).unwrap();
60///
61/// let mpo = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
62/// assert_eq!(mpo.node_count(), 1);
63/// ```
64#[derive(Debug, Clone)]
65pub struct LinearOperator<T, V>
66where
67    T: TensorLike,
68    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
69{
70    /// The MPO with internal index IDs
71    pub mpo: TreeTN<T, V>,
72    /// Input index mapping: node -> (true s_in, internal s_in_tmp)
73    pub input_mapping: HashMap<V, IndexMapping<T::Index>>,
74    /// Output index mapping: node -> (true s_out, internal s_out_tmp)
75    pub output_mapping: HashMap<V, IndexMapping<T::Index>>,
76}
77
78impl<T, V> LinearOperator<T, V>
79where
80    T: TensorLike,
81    T::Index: IndexLike,
82    <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
83    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
84{
85    /// Create a new LinearOperator from an MPO and index mappings.
86    ///
87    /// # Arguments
88    ///
89    /// * `mpo` - The MPO with internal index IDs
90    /// * `input_mapping` - Mapping from true input indices to internal indices
91    /// * `output_mapping` - Mapping from true output indices to internal indices
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use std::collections::HashMap;
97    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
98    /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN};
99    ///
100    /// // Build a 2x2 identity operator (single node)
101    /// let site = DynIndex::new_dyn(2);
102    /// let s_in = DynIndex::new_dyn(2);
103    /// let s_out = DynIndex::new_dyn(2);
104    /// let mpo_tensor = TensorDynLen::from_dense(
105    ///     vec![s_in.clone(), s_out.clone()],
106    ///     vec![1.0_f64, 0.0, 0.0, 1.0],
107    /// ).unwrap();
108    /// let mpo = TreeTN::<_, usize>::from_tensors(vec![mpo_tensor], vec![0]).unwrap();
109    ///
110    /// let mut input_mapping = HashMap::new();
111    /// input_mapping.insert(0usize, IndexMapping { true_index: site.clone(), internal_index: s_in });
112    /// let mut output_mapping = HashMap::new();
113    /// output_mapping.insert(0usize, IndexMapping { true_index: site.clone(), internal_index: s_out });
114    ///
115    /// let op = LinearOperator::new(mpo, input_mapping, output_mapping);
116    /// assert_eq!(op.mpo().node_count(), 1);
117    /// ```
118    pub fn new(
119        mpo: TreeTN<T, V>,
120        input_mapping: HashMap<V, IndexMapping<T::Index>>,
121        output_mapping: HashMap<V, IndexMapping<T::Index>>,
122    ) -> Self {
123        Self {
124            mpo,
125            input_mapping,
126            output_mapping,
127        }
128    }
129
130    /// Create a LinearOperator from an MPO and a reference state.
131    ///
132    /// This assumes:
133    /// - The MPO has site indices that we need to map
134    /// - The state's site indices define the true input space
135    /// - For `A * x = b` with `space(x) = space(b)`, the output space equals input space
136    ///
137    /// # Arguments
138    ///
139    /// * `mpo` - The MPO (operator A)
140    /// * `state` - Reference state (defines the true site index space)
141    ///
142    /// # Returns
143    ///
144    /// A LinearOperator with proper index mappings, or an error if structure is incompatible.
145    pub fn from_mpo_and_state(mpo: TreeTN<T, V>, state: &TreeTN<T, V>) -> Result<Self> {
146        let mut input_mapping = HashMap::new();
147        let mut output_mapping = HashMap::new();
148
149        for node in mpo.site_index_network().node_names() {
150            // Get state's site indices for this node
151            let state_site = state.site_space(node);
152
153            // Get MPO's site indices for this node
154            let mpo_site = mpo.site_space(node);
155
156            match (state_site, mpo_site) {
157                (Some(state_indices), Some(mpo_indices)) => {
158                    // MPO should have exactly 2 site indices per state site index:
159                    // one for input (s_in_tmp) and one for output (s_out_tmp)
160                    // Both should have the same dimension as the state's site index.
161
162                    if state_indices.len() * 2 != mpo_indices.len() {
163                        return Err(anyhow::anyhow!(
164                            "Node {:?}: MPO should have 2x site indices. State has {}, MPO has {}",
165                            node,
166                            state_indices.len(),
167                            mpo_indices.len()
168                        ));
169                    }
170
171                    // For each state site index, find matching MPO indices by dimension
172                    for state_idx in state_indices {
173                        let dim = state_idx.dim();
174
175                        // Find MPO indices with matching dimension
176                        let matching_mpo: Vec<_> =
177                            mpo_indices.iter().filter(|idx| idx.dim() == dim).collect();
178
179                        if matching_mpo.len() < 2 {
180                            return Err(anyhow::anyhow!(
181                                "Node {:?}: Not enough MPO indices with dimension {}. Found {}",
182                                node,
183                                dim,
184                                matching_mpo.len()
185                            ));
186                        }
187
188                        // Convention: first matching is s_in_tmp, second is s_out_tmp
189                        // (This depends on how the MPO was constructed)
190                        input_mapping.insert(
191                            node.clone(),
192                            IndexMapping {
193                                true_index: state_idx.clone(),
194                                internal_index: matching_mpo[0].clone(),
195                            },
196                        );
197
198                        // For output, use the same true index (space(x) = space(b))
199                        output_mapping.insert(
200                            node.clone(),
201                            IndexMapping {
202                                true_index: state_idx.clone(),
203                                internal_index: matching_mpo[1].clone(),
204                            },
205                        );
206                    }
207                }
208                (None, None) => {
209                    // No site indices for this node, OK
210                }
211                _ => {
212                    return Err(anyhow::anyhow!(
213                        "Node {:?}: Mismatched site space presence between state and MPO",
214                        node
215                    ));
216                }
217            }
218        }
219
220        Ok(Self {
221            mpo,
222            input_mapping,
223            output_mapping,
224        })
225    }
226
227    /// Apply the operator to a local tensor at a specific region.
228    ///
229    /// This is used during the sweep for local updates.
230    ///
231    /// # Arguments
232    ///
233    /// * `local_tensor` - The local tensor (merged tensors from the region)
234    /// * `region` - The nodes in the current region
235    ///
236    /// # Returns
237    ///
238    /// The result of applying the operator to the local tensor.
239    pub fn apply_local(&self, local_tensor: &T, region: &[V]) -> Result<T> {
240        // Step 1: Replace input indices in local_tensor with internal indices
241        let mut transformed = local_tensor.clone();
242        for node in region {
243            if let Some(mapping) = self.input_mapping.get(node) {
244                transformed =
245                    transformed.replaceind(&mapping.true_index, &mapping.internal_index)?;
246            }
247        }
248
249        // Step 2: Contract with local operator tensors
250        let mut op_tensor: Option<T> = None;
251        for node in region {
252            let node_idx = self
253                .mpo
254                .node_index(node)
255                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in MPO", node))?;
256            let tensor = self
257                .mpo
258                .tensor(node_idx)
259                .ok_or_else(|| anyhow::anyhow!("Tensor not found in MPO for node {:?}", node))?
260                .clone();
261
262            op_tensor = Some(match op_tensor {
263                None => tensor,
264                Some(t) => T::contract(&[&t, &tensor], AllowedPairs::All)?,
265            });
266        }
267
268        let op_tensor = op_tensor.ok_or_else(|| anyhow::anyhow!("Empty region"))?;
269
270        // Contract transformed tensor with operator
271        let contracted = T::contract(&[&transformed, &op_tensor], AllowedPairs::All)?;
272
273        // Step 3: Replace output indices back to true indices
274        let mut result = contracted;
275        for node in region {
276            if let Some(mapping) = self.output_mapping.get(node) {
277                result = result.replaceind(&mapping.internal_index, &mapping.true_index)?;
278            }
279        }
280
281        Ok(result)
282    }
283
284    /// Get the internal MPO.
285    pub fn mpo(&self) -> &TreeTN<T, V> {
286        &self.mpo
287    }
288
289    /// Get input mapping for a node.
290    pub fn get_input_mapping(&self, node: &V) -> Option<&IndexMapping<T::Index>> {
291        self.input_mapping.get(node)
292    }
293
294    /// Get output mapping for a node.
295    pub fn get_output_mapping(&self, node: &V) -> Option<&IndexMapping<T::Index>> {
296        self.output_mapping.get(node)
297    }
298
299    fn single_site_index_from_state(state: &TreeTN<T, V>, node: &V) -> Result<T::Index> {
300        let site_space = state
301            .site_space(node)
302            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in state site space", node))?;
303
304        if site_space.len() != 1 {
305            return Err(anyhow::anyhow!(
306                "Node {:?}: expected exactly 1 site index in state, found {}",
307                node,
308                site_space.len()
309            ));
310        }
311
312        site_space
313            .iter()
314            .next()
315            .cloned()
316            .ok_or_else(|| anyhow::anyhow!("Node {:?}: missing site index", node))
317    }
318
319    /// Reset true input indices to match the given state's site space.
320    ///
321    /// This only rewrites the external mapping. The internal MPO indices are unchanged.
322    pub fn set_input_space_from_state(&mut self, state: &TreeTN<T, V>) -> Result<()> {
323        let nodes: Vec<V> = self.input_mapping.keys().cloned().collect();
324        for node in nodes {
325            let new_true_index = Self::single_site_index_from_state(state, &node)?;
326            let mapping = self
327                .input_mapping
328                .get_mut(&node)
329                .ok_or_else(|| anyhow::anyhow!("Input mapping missing for node {:?}", node))?;
330            if mapping.internal_index.dim() != new_true_index.dim() {
331                return Err(anyhow::anyhow!(
332                    "Node {:?}: input mapping dimension {} does not match state site dimension {}",
333                    node,
334                    mapping.internal_index.dim(),
335                    new_true_index.dim()
336                ));
337            }
338            mapping.true_index = new_true_index;
339        }
340        Ok(())
341    }
342
343    /// Reset true output indices to match the given state's site space.
344    ///
345    /// This only rewrites the external mapping. The internal MPO indices are unchanged.
346    pub fn set_output_space_from_state(&mut self, state: &TreeTN<T, V>) -> Result<()> {
347        let nodes: Vec<V> = self.output_mapping.keys().cloned().collect();
348        for node in nodes {
349            let new_true_index = Self::single_site_index_from_state(state, &node)?;
350            let mapping = self
351                .output_mapping
352                .get_mut(&node)
353                .ok_or_else(|| anyhow::anyhow!("Output mapping missing for node {:?}", node))?;
354            if mapping.internal_index.dim() != new_true_index.dim() {
355                return Err(anyhow::anyhow!(
356                    "Node {:?}: output mapping dimension {} does not match state site dimension {}",
357                    node,
358                    mapping.internal_index.dim(),
359                    new_true_index.dim()
360                ));
361            }
362            mapping.true_index = new_true_index;
363        }
364        Ok(())
365    }
366
367    /// Align this operator's input and output site index mappings to match a target state.
368    ///
369    /// For each node in the operator's input and output mappings, the `true_index` is
370    /// updated to the corresponding site index from the target state. The internal MPO
371    /// indices remain unchanged.
372    ///
373    /// This is useful when an operator (e.g., from `shift_operator` or `affine_operator`)
374    /// was constructed with its own site indices, but needs to be applied to a state that
375    /// has different site index IDs. After calling `align_to_state`, the operator's
376    /// `true_index` fields will reference the state's site indices, enabling correct
377    /// index contraction during `apply_local`.
378    ///
379    /// # Arguments
380    ///
381    /// * `state` - The target state whose site indices define the true index space.
382    ///   Each node in the operator's mappings must exist in the state with exactly one
383    ///   site index of matching dimension.
384    ///
385    /// # Errors
386    ///
387    /// Returns an error if:
388    /// - A node in the operator's mapping is not found in the state
389    /// - A node in the state has more than one site index
390    /// - The dimension of the state's site index does not match the operator's internal index
391    ///
392    /// # Examples
393    ///
394    /// ```
395    /// use std::collections::HashMap;
396    ///
397    /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen};
398    /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN};
399    ///
400    /// let state_index = DynIndex::new_dyn(2);
401    /// let state_tensor = TensorDynLen::from_dense(vec![state_index.clone()], vec![1.0, 2.0]).unwrap();
402    /// let state = TreeTN::<TensorDynLen, usize>::from_tensors(vec![state_tensor], vec![0]).unwrap();
403    ///
404    /// let input_internal = DynIndex::new_dyn(2);
405    /// let output_internal = DynIndex::new_dyn(2);
406    /// let mpo_tensor = TensorDynLen::from_dense(
407    ///     vec![input_internal.clone(), output_internal.clone()],
408    ///     vec![1.0, 0.0, 0.0, 1.0],
409    /// ).unwrap();
410    /// let mpo = TreeTN::<TensorDynLen, usize>::from_tensors(vec![mpo_tensor], vec![0]).unwrap();
411    ///
412    /// let mut input_mapping = HashMap::new();
413    /// input_mapping.insert(
414    ///     0usize,
415    ///     IndexMapping {
416    ///         true_index: DynIndex::new_dyn(2),
417    ///         internal_index: input_internal,
418    ///     },
419    /// );
420    /// let mut output_mapping = HashMap::new();
421    /// output_mapping.insert(
422    ///     0usize,
423    ///     IndexMapping {
424    ///         true_index: DynIndex::new_dyn(2),
425    ///         internal_index: output_internal,
426    ///     },
427    /// );
428    ///
429    /// let mut op = LinearOperator::new(mpo, input_mapping, output_mapping);
430    /// op.align_to_state(&state).unwrap();
431    ///
432    /// assert!(op.input_mappings()[&0].true_index.same_id(&state_index));
433    /// assert!(op.output_mappings()[&0].true_index.same_id(&state_index));
434    /// ```
435    pub fn align_to_state(&mut self, state: &TreeTN<T, V>) -> Result<()> {
436        self.set_input_space_from_state(state)?;
437        self.set_output_space_from_state(state)?;
438        Ok(())
439    }
440
441    /// Returns the transposed operator by swapping input and output mappings.
442    ///
443    /// The pullback of a forward operator is its transpose: if the forward
444    /// operator realizes the matrix `M_{y,x}`, the transposed operator
445    /// realizes `M_{x,y}`. This method swaps `input_mapping` and
446    /// `output_mapping` without copying the underlying MPO tensors — it is
447    /// an O(1) operation.
448    ///
449    /// `.transpose().transpose()` yields an operator equivalent to the
450    /// original (mappings restored, MPO unchanged).
451    ///
452    /// # Examples
453    ///
454    /// ```
455    /// use std::collections::HashMap;
456    /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen};
457    /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN};
458    ///
459    /// let site_in = DynIndex::new_dyn(2);
460    /// let site_out = DynIndex::new_dyn(3);
461    /// let s_in_tmp = DynIndex::new_dyn(2);
462    /// let s_out_tmp = DynIndex::new_dyn(3);
463    ///
464    /// let mpo_tensor = TensorDynLen::from_dense(
465    ///     vec![s_in_tmp.clone(), s_out_tmp.clone()],
466    ///     vec![1.0_f64, 0.0, 0.0, 0.0, 1.0, 0.0],
467    /// ).unwrap();
468    /// let mpo = TreeTN::<_, usize>::from_tensors(vec![mpo_tensor], vec![0]).unwrap();
469    ///
470    /// let mut input_mapping = HashMap::new();
471    /// input_mapping.insert(
472    ///     0usize,
473    ///     IndexMapping { true_index: site_in.clone(), internal_index: s_in_tmp.clone() },
474    /// );
475    /// let mut output_mapping = HashMap::new();
476    /// output_mapping.insert(
477    ///     0usize,
478    ///     IndexMapping { true_index: site_out.clone(), internal_index: s_out_tmp.clone() },
479    /// );
480    ///
481    /// let op = LinearOperator::new(mpo, input_mapping, output_mapping);
482    /// let t = op.transpose();
483    ///
484    /// // Input/output mappings are swapped.
485    /// assert!(t.input_mapping[&0].true_index.same_id(&site_out));
486    /// assert!(t.output_mapping[&0].true_index.same_id(&site_in));
487    /// ```
488    pub fn transpose(self) -> Self {
489        Self {
490            mpo: self.mpo,
491            input_mapping: self.output_mapping,
492            output_mapping: self.input_mapping,
493        }
494    }
495}
496
// ============================================================================
// Operator trait implementation and index-mapping accessors
// ============================================================================
500
501use std::collections::HashSet;
502
503use crate::operator::Operator;
504use crate::SiteIndexNetwork;
505
506// Implement Operator trait for LinearOperator
507impl<T, V> Operator<T, V> for LinearOperator<T, V>
508where
509    T: TensorLike,
510    T::Index: IndexLike + Clone + Hash + Eq,
511    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
512{
513    fn site_indices(&self) -> HashSet<T::Index> {
514        // Return union of input and output true indices
515        let mut result: HashSet<T::Index> = self
516            .input_mapping
517            .values()
518            .map(|m| m.true_index.clone())
519            .collect();
520        result.extend(self.output_mapping.values().map(|m| m.true_index.clone()));
521        result
522    }
523
524    fn site_index_network(&self) -> &SiteIndexNetwork<V, T::Index> {
525        self.mpo.site_index_network()
526    }
527
528    fn node_names(&self) -> HashSet<V> {
529        self.mpo
530            .site_index_network()
531            .node_names()
532            .into_iter()
533            .cloned()
534            .collect()
535    }
536}
537
538impl<T, V> LinearOperator<T, V>
539where
540    T: TensorLike,
541    T::Index: IndexLike,
542    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
543{
544    /// Get all input site indices (true indices from state space).
545    pub fn input_site_indices(&self) -> HashSet<T::Index> {
546        self.input_mapping
547            .values()
548            .map(|m| m.true_index.clone())
549            .collect()
550    }
551
552    /// Get all output site indices (true indices from result space).
553    pub fn output_site_indices(&self) -> HashSet<T::Index> {
554        self.output_mapping
555            .values()
556            .map(|m| m.true_index.clone())
557            .collect()
558    }
559
560    /// Get all input mappings.
561    pub fn input_mappings(&self) -> &HashMap<V, IndexMapping<T::Index>> {
562        &self.input_mapping
563    }
564
565    /// Get all output mappings.
566    pub fn output_mappings(&self) -> &HashMap<V, IndexMapping<T::Index>> {
567        &self.output_mapping
568    }
569}
570
571#[cfg(test)]
572mod tests;