tensor4all_treetn/operator/linear_operator.rs
1//! LinearOperator: Wrapper for MPO with index mapping.
2//!
3//! This module provides a `LinearOperator` that wraps an MPO (Matrix Product Operator)
4//! and handles the index ID mapping between true site indices and internal MPO indices.
5//!
6//! # Problem
7//!
8//! In the equation `A * x = b`:
9//! - `A.s_in` should match `x`'s site indices
10//! - `A.s_out` should match `b`'s site indices
11//!
//! However, the true `s_in` and `s_out` can be the very same index (e.g. when
//! `space(x) = space(b)`), and a tensor cannot have two indices with the same ID.
//! So the MPO internally uses `s_in_tmp` and `s_out_tmp` with independent IDs.
14//!
15//! # Solution
16//!
17//! `LinearOperator` stores:
18//! - The MPO with internal index IDs (`s_in_tmp`, `s_out_tmp`)
19//! - Mapping from true `s_in`/`s_out` to internal `s_in_tmp`/`s_out_tmp`
20//!
21//! When applying to `x`, it automatically handles the index transformations.
22
23use std::collections::HashMap;
24use std::hash::Hash;
25
26use anyhow::Result;
27
28use tensor4all_core::AllowedPairs;
29use tensor4all_core::IndexLike;
30use tensor4all_core::TensorLike;
31
32use super::index_mapping::IndexMapping;
33use crate::treetn::TreeTN;
34
35/// LinearOperator: Wraps an MPO with index mapping for automatic transformations.
36///
37/// # Type Parameters
38///
39/// * `T` - Tensor type implementing `TensorLike`
40/// * `V` - Node name type
41///
42/// # Examples
43///
44/// A `LinearOperator` is typically obtained from a constructor function rather
45/// than built directly. The `mpo` field contains the underlying TreeTN:
46///
47/// ```
48/// use tensor4all_treetn::LinearOperator;
49/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
50/// use tensor4all_treetn::TreeTN;
51/// use std::collections::HashMap;
52///
53/// // Build a trivial single-node LinearOperator wrapping a 2x2 identity
54/// let s_in = DynIndex::new_dyn(2);
55/// let s_out = DynIndex::new_dyn(2);
56/// let t = TensorDynLen::from_dense(
57/// vec![s_in.clone(), s_out.clone()],
58/// vec![1.0_f64, 0.0, 0.0, 1.0],
59/// ).unwrap();
60///
61/// let mpo = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
62/// assert_eq!(mpo.node_count(), 1);
63/// ```
64#[derive(Debug, Clone)]
65pub struct LinearOperator<T, V>
66where
67 T: TensorLike,
68 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
69{
70 /// The MPO with internal index IDs
71 pub mpo: TreeTN<T, V>,
72 /// Input index mapping: node -> (true s_in, internal s_in_tmp)
73 pub input_mapping: HashMap<V, IndexMapping<T::Index>>,
74 /// Output index mapping: node -> (true s_out, internal s_out_tmp)
75 pub output_mapping: HashMap<V, IndexMapping<T::Index>>,
76}
77
78impl<T, V> LinearOperator<T, V>
79where
80 T: TensorLike,
81 T::Index: IndexLike,
82 <T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
83 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
84{
85 /// Create a new LinearOperator from an MPO and index mappings.
86 ///
87 /// # Arguments
88 ///
89 /// * `mpo` - The MPO with internal index IDs
90 /// * `input_mapping` - Mapping from true input indices to internal indices
91 /// * `output_mapping` - Mapping from true output indices to internal indices
92 ///
93 /// # Examples
94 ///
95 /// ```
96 /// use std::collections::HashMap;
97 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
98 /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN};
99 ///
100 /// // Build a 2x2 identity operator (single node)
101 /// let site = DynIndex::new_dyn(2);
102 /// let s_in = DynIndex::new_dyn(2);
103 /// let s_out = DynIndex::new_dyn(2);
104 /// let mpo_tensor = TensorDynLen::from_dense(
105 /// vec![s_in.clone(), s_out.clone()],
106 /// vec![1.0_f64, 0.0, 0.0, 1.0],
107 /// ).unwrap();
108 /// let mpo = TreeTN::<_, usize>::from_tensors(vec![mpo_tensor], vec![0]).unwrap();
109 ///
110 /// let mut input_mapping = HashMap::new();
111 /// input_mapping.insert(0usize, IndexMapping { true_index: site.clone(), internal_index: s_in });
112 /// let mut output_mapping = HashMap::new();
113 /// output_mapping.insert(0usize, IndexMapping { true_index: site.clone(), internal_index: s_out });
114 ///
115 /// let op = LinearOperator::new(mpo, input_mapping, output_mapping);
116 /// assert_eq!(op.mpo().node_count(), 1);
117 /// ```
118 pub fn new(
119 mpo: TreeTN<T, V>,
120 input_mapping: HashMap<V, IndexMapping<T::Index>>,
121 output_mapping: HashMap<V, IndexMapping<T::Index>>,
122 ) -> Self {
123 Self {
124 mpo,
125 input_mapping,
126 output_mapping,
127 }
128 }
129
130 /// Create a LinearOperator from an MPO and a reference state.
131 ///
132 /// This assumes:
133 /// - The MPO has site indices that we need to map
134 /// - The state's site indices define the true input space
135 /// - For `A * x = b` with `space(x) = space(b)`, the output space equals input space
136 ///
137 /// # Arguments
138 ///
139 /// * `mpo` - The MPO (operator A)
140 /// * `state` - Reference state (defines the true site index space)
141 ///
142 /// # Returns
143 ///
144 /// A LinearOperator with proper index mappings, or an error if structure is incompatible.
145 pub fn from_mpo_and_state(mpo: TreeTN<T, V>, state: &TreeTN<T, V>) -> Result<Self> {
146 let mut input_mapping = HashMap::new();
147 let mut output_mapping = HashMap::new();
148
149 for node in mpo.site_index_network().node_names() {
150 // Get state's site indices for this node
151 let state_site = state.site_space(node);
152
153 // Get MPO's site indices for this node
154 let mpo_site = mpo.site_space(node);
155
156 match (state_site, mpo_site) {
157 (Some(state_indices), Some(mpo_indices)) => {
158 // MPO should have exactly 2 site indices per state site index:
159 // one for input (s_in_tmp) and one for output (s_out_tmp)
160 // Both should have the same dimension as the state's site index.
161
162 if state_indices.len() * 2 != mpo_indices.len() {
163 return Err(anyhow::anyhow!(
164 "Node {:?}: MPO should have 2x site indices. State has {}, MPO has {}",
165 node,
166 state_indices.len(),
167 mpo_indices.len()
168 ));
169 }
170
171 // For each state site index, find matching MPO indices by dimension
172 for state_idx in state_indices {
173 let dim = state_idx.dim();
174
175 // Find MPO indices with matching dimension
176 let matching_mpo: Vec<_> =
177 mpo_indices.iter().filter(|idx| idx.dim() == dim).collect();
178
179 if matching_mpo.len() < 2 {
180 return Err(anyhow::anyhow!(
181 "Node {:?}: Not enough MPO indices with dimension {}. Found {}",
182 node,
183 dim,
184 matching_mpo.len()
185 ));
186 }
187
188 // Convention: first matching is s_in_tmp, second is s_out_tmp
189 // (This depends on how the MPO was constructed)
190 input_mapping.insert(
191 node.clone(),
192 IndexMapping {
193 true_index: state_idx.clone(),
194 internal_index: matching_mpo[0].clone(),
195 },
196 );
197
198 // For output, use the same true index (space(x) = space(b))
199 output_mapping.insert(
200 node.clone(),
201 IndexMapping {
202 true_index: state_idx.clone(),
203 internal_index: matching_mpo[1].clone(),
204 },
205 );
206 }
207 }
208 (None, None) => {
209 // No site indices for this node, OK
210 }
211 _ => {
212 return Err(anyhow::anyhow!(
213 "Node {:?}: Mismatched site space presence between state and MPO",
214 node
215 ));
216 }
217 }
218 }
219
220 Ok(Self {
221 mpo,
222 input_mapping,
223 output_mapping,
224 })
225 }
226
227 /// Apply the operator to a local tensor at a specific region.
228 ///
229 /// This is used during the sweep for local updates.
230 ///
231 /// # Arguments
232 ///
233 /// * `local_tensor` - The local tensor (merged tensors from the region)
234 /// * `region` - The nodes in the current region
235 ///
236 /// # Returns
237 ///
238 /// The result of applying the operator to the local tensor.
239 pub fn apply_local(&self, local_tensor: &T, region: &[V]) -> Result<T> {
240 // Step 1: Replace input indices in local_tensor with internal indices
241 let mut transformed = local_tensor.clone();
242 for node in region {
243 if let Some(mapping) = self.input_mapping.get(node) {
244 transformed =
245 transformed.replaceind(&mapping.true_index, &mapping.internal_index)?;
246 }
247 }
248
249 // Step 2: Contract with local operator tensors
250 let mut op_tensor: Option<T> = None;
251 for node in region {
252 let node_idx = self
253 .mpo
254 .node_index(node)
255 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in MPO", node))?;
256 let tensor = self
257 .mpo
258 .tensor(node_idx)
259 .ok_or_else(|| anyhow::anyhow!("Tensor not found in MPO for node {:?}", node))?
260 .clone();
261
262 op_tensor = Some(match op_tensor {
263 None => tensor,
264 Some(t) => T::contract(&[&t, &tensor], AllowedPairs::All)?,
265 });
266 }
267
268 let op_tensor = op_tensor.ok_or_else(|| anyhow::anyhow!("Empty region"))?;
269
270 // Contract transformed tensor with operator
271 let contracted = T::contract(&[&transformed, &op_tensor], AllowedPairs::All)?;
272
273 // Step 3: Replace output indices back to true indices
274 let mut result = contracted;
275 for node in region {
276 if let Some(mapping) = self.output_mapping.get(node) {
277 result = result.replaceind(&mapping.internal_index, &mapping.true_index)?;
278 }
279 }
280
281 Ok(result)
282 }
283
284 /// Get the internal MPO.
285 pub fn mpo(&self) -> &TreeTN<T, V> {
286 &self.mpo
287 }
288
289 /// Get input mapping for a node.
290 pub fn get_input_mapping(&self, node: &V) -> Option<&IndexMapping<T::Index>> {
291 self.input_mapping.get(node)
292 }
293
294 /// Get output mapping for a node.
295 pub fn get_output_mapping(&self, node: &V) -> Option<&IndexMapping<T::Index>> {
296 self.output_mapping.get(node)
297 }
298
299 fn single_site_index_from_state(state: &TreeTN<T, V>, node: &V) -> Result<T::Index> {
300 let site_space = state
301 .site_space(node)
302 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in state site space", node))?;
303
304 if site_space.len() != 1 {
305 return Err(anyhow::anyhow!(
306 "Node {:?}: expected exactly 1 site index in state, found {}",
307 node,
308 site_space.len()
309 ));
310 }
311
312 site_space
313 .iter()
314 .next()
315 .cloned()
316 .ok_or_else(|| anyhow::anyhow!("Node {:?}: missing site index", node))
317 }
318
319 /// Reset true input indices to match the given state's site space.
320 ///
321 /// This only rewrites the external mapping. The internal MPO indices are unchanged.
322 pub fn set_input_space_from_state(&mut self, state: &TreeTN<T, V>) -> Result<()> {
323 let nodes: Vec<V> = self.input_mapping.keys().cloned().collect();
324 for node in nodes {
325 let new_true_index = Self::single_site_index_from_state(state, &node)?;
326 let mapping = self
327 .input_mapping
328 .get_mut(&node)
329 .ok_or_else(|| anyhow::anyhow!("Input mapping missing for node {:?}", node))?;
330 if mapping.internal_index.dim() != new_true_index.dim() {
331 return Err(anyhow::anyhow!(
332 "Node {:?}: input mapping dimension {} does not match state site dimension {}",
333 node,
334 mapping.internal_index.dim(),
335 new_true_index.dim()
336 ));
337 }
338 mapping.true_index = new_true_index;
339 }
340 Ok(())
341 }
342
343 /// Reset true output indices to match the given state's site space.
344 ///
345 /// This only rewrites the external mapping. The internal MPO indices are unchanged.
346 pub fn set_output_space_from_state(&mut self, state: &TreeTN<T, V>) -> Result<()> {
347 let nodes: Vec<V> = self.output_mapping.keys().cloned().collect();
348 for node in nodes {
349 let new_true_index = Self::single_site_index_from_state(state, &node)?;
350 let mapping = self
351 .output_mapping
352 .get_mut(&node)
353 .ok_or_else(|| anyhow::anyhow!("Output mapping missing for node {:?}", node))?;
354 if mapping.internal_index.dim() != new_true_index.dim() {
355 return Err(anyhow::anyhow!(
356 "Node {:?}: output mapping dimension {} does not match state site dimension {}",
357 node,
358 mapping.internal_index.dim(),
359 new_true_index.dim()
360 ));
361 }
362 mapping.true_index = new_true_index;
363 }
364 Ok(())
365 }
366
367 /// Align this operator's input and output site index mappings to match a target state.
368 ///
369 /// For each node in the operator's input and output mappings, the `true_index` is
370 /// updated to the corresponding site index from the target state. The internal MPO
371 /// indices remain unchanged.
372 ///
373 /// This is useful when an operator (e.g., from `shift_operator` or `affine_operator`)
374 /// was constructed with its own site indices, but needs to be applied to a state that
375 /// has different site index IDs. After calling `align_to_state`, the operator's
376 /// `true_index` fields will reference the state's site indices, enabling correct
377 /// index contraction during `apply_local`.
378 ///
379 /// # Arguments
380 ///
381 /// * `state` - The target state whose site indices define the true index space.
382 /// Each node in the operator's mappings must exist in the state with exactly one
383 /// site index of matching dimension.
384 ///
385 /// # Errors
386 ///
387 /// Returns an error if:
388 /// - A node in the operator's mapping is not found in the state
389 /// - A node in the state has more than one site index
390 /// - The dimension of the state's site index does not match the operator's internal index
391 ///
392 /// # Examples
393 ///
394 /// ```
395 /// use std::collections::HashMap;
396 ///
397 /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen};
398 /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN};
399 ///
400 /// let state_index = DynIndex::new_dyn(2);
401 /// let state_tensor = TensorDynLen::from_dense(vec![state_index.clone()], vec![1.0, 2.0]).unwrap();
402 /// let state = TreeTN::<TensorDynLen, usize>::from_tensors(vec![state_tensor], vec![0]).unwrap();
403 ///
404 /// let input_internal = DynIndex::new_dyn(2);
405 /// let output_internal = DynIndex::new_dyn(2);
406 /// let mpo_tensor = TensorDynLen::from_dense(
407 /// vec![input_internal.clone(), output_internal.clone()],
408 /// vec![1.0, 0.0, 0.0, 1.0],
409 /// ).unwrap();
410 /// let mpo = TreeTN::<TensorDynLen, usize>::from_tensors(vec![mpo_tensor], vec![0]).unwrap();
411 ///
412 /// let mut input_mapping = HashMap::new();
413 /// input_mapping.insert(
414 /// 0usize,
415 /// IndexMapping {
416 /// true_index: DynIndex::new_dyn(2),
417 /// internal_index: input_internal,
418 /// },
419 /// );
420 /// let mut output_mapping = HashMap::new();
421 /// output_mapping.insert(
422 /// 0usize,
423 /// IndexMapping {
424 /// true_index: DynIndex::new_dyn(2),
425 /// internal_index: output_internal,
426 /// },
427 /// );
428 ///
429 /// let mut op = LinearOperator::new(mpo, input_mapping, output_mapping);
430 /// op.align_to_state(&state).unwrap();
431 ///
432 /// assert!(op.input_mappings()[&0].true_index.same_id(&state_index));
433 /// assert!(op.output_mappings()[&0].true_index.same_id(&state_index));
434 /// ```
435 pub fn align_to_state(&mut self, state: &TreeTN<T, V>) -> Result<()> {
436 self.set_input_space_from_state(state)?;
437 self.set_output_space_from_state(state)?;
438 Ok(())
439 }
440
441 /// Returns the transposed operator by swapping input and output mappings.
442 ///
443 /// The pullback of a forward operator is its transpose: if the forward
444 /// operator realizes the matrix `M_{y,x}`, the transposed operator
445 /// realizes `M_{x,y}`. This method swaps `input_mapping` and
446 /// `output_mapping` without copying the underlying MPO tensors — it is
447 /// an O(1) operation.
448 ///
449 /// `.transpose().transpose()` yields an operator equivalent to the
450 /// original (mappings restored, MPO unchanged).
451 ///
452 /// # Examples
453 ///
454 /// ```
455 /// use std::collections::HashMap;
456 /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen};
457 /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN};
458 ///
459 /// let site_in = DynIndex::new_dyn(2);
460 /// let site_out = DynIndex::new_dyn(3);
461 /// let s_in_tmp = DynIndex::new_dyn(2);
462 /// let s_out_tmp = DynIndex::new_dyn(3);
463 ///
464 /// let mpo_tensor = TensorDynLen::from_dense(
465 /// vec![s_in_tmp.clone(), s_out_tmp.clone()],
466 /// vec![1.0_f64, 0.0, 0.0, 0.0, 1.0, 0.0],
467 /// ).unwrap();
468 /// let mpo = TreeTN::<_, usize>::from_tensors(vec![mpo_tensor], vec![0]).unwrap();
469 ///
470 /// let mut input_mapping = HashMap::new();
471 /// input_mapping.insert(
472 /// 0usize,
473 /// IndexMapping { true_index: site_in.clone(), internal_index: s_in_tmp.clone() },
474 /// );
475 /// let mut output_mapping = HashMap::new();
476 /// output_mapping.insert(
477 /// 0usize,
478 /// IndexMapping { true_index: site_out.clone(), internal_index: s_out_tmp.clone() },
479 /// );
480 ///
481 /// let op = LinearOperator::new(mpo, input_mapping, output_mapping);
482 /// let t = op.transpose();
483 ///
484 /// // Input/output mappings are swapped.
485 /// assert!(t.input_mapping[&0].true_index.same_id(&site_out));
486 /// assert!(t.output_mapping[&0].true_index.same_id(&site_in));
487 /// ```
488 pub fn transpose(self) -> Self {
489 Self {
490 mpo: self.mpo,
491 input_mapping: self.output_mapping,
492 output_mapping: self.input_mapping,
493 }
494 }
495}
496
497// ============================================================================
498// Helper methods
499// ============================================================================
500
501use std::collections::HashSet;
502
503use crate::operator::Operator;
504use crate::SiteIndexNetwork;
505
506// Implement Operator trait for LinearOperator
507impl<T, V> Operator<T, V> for LinearOperator<T, V>
508where
509 T: TensorLike,
510 T::Index: IndexLike + Clone + Hash + Eq,
511 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
512{
513 fn site_indices(&self) -> HashSet<T::Index> {
514 // Return union of input and output true indices
515 let mut result: HashSet<T::Index> = self
516 .input_mapping
517 .values()
518 .map(|m| m.true_index.clone())
519 .collect();
520 result.extend(self.output_mapping.values().map(|m| m.true_index.clone()));
521 result
522 }
523
524 fn site_index_network(&self) -> &SiteIndexNetwork<V, T::Index> {
525 self.mpo.site_index_network()
526 }
527
528 fn node_names(&self) -> HashSet<V> {
529 self.mpo
530 .site_index_network()
531 .node_names()
532 .into_iter()
533 .cloned()
534 .collect()
535 }
536}
537
538impl<T, V> LinearOperator<T, V>
539where
540 T: TensorLike,
541 T::Index: IndexLike,
542 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
543{
544 /// Get all input site indices (true indices from state space).
545 pub fn input_site_indices(&self) -> HashSet<T::Index> {
546 self.input_mapping
547 .values()
548 .map(|m| m.true_index.clone())
549 .collect()
550 }
551
552 /// Get all output site indices (true indices from result space).
553 pub fn output_site_indices(&self) -> HashSet<T::Index> {
554 self.output_mapping
555 .values()
556 .map(|m| m.true_index.clone())
557 .collect()
558 }
559
560 /// Get all input mappings.
561 pub fn input_mappings(&self) -> &HashMap<V, IndexMapping<T::Index>> {
562 &self.input_mapping
563 }
564
565 /// Get all output mappings.
566 pub fn output_mappings(&self) -> &HashMap<V, IndexMapping<T::Index>> {
567 &self.output_mapping
568 }
569}
570
571#[cfg(test)]
572mod tests;