Skip to main content

tensor4all_treetn/treetn/
tensor_like.rs

//! TensorIndex implementation for TreeTN (TensorLike is intentionally not implemented).
2//!
3//! **Design Decision**: TreeTN implements TensorIndex but NOT TensorLike.
4//!
5//! ## TensorIndex (Implemented)
6//!
7//! TreeTN implements TensorIndex because index operations are well-defined:
8//! - `external_indices()`: Returns all site (physical) indices
9//! - `replaceind()` / `replaceinds()`: Replace indices in tensors and metadata
10//!
11//! ## TensorLike (NOT Implemented)
12//!
13//! TreeTN does NOT implement TensorLike because:
14//! 1. **Unclear semantics**: What would `tensordot` between two TreeTNs mean?
15//! 2. **Hidden costs**: Full contraction has exponential cost
16//! 3. **Separation of concerns**: Dense tensors and TNs are fundamentally different
17//!
18//! ## Alternative API
19//!
20//! TreeTN provides its own methods instead:
21//! - `site_indices()`: Returns physical indices (not bonds)
22//! - `contract_to_tensor()`: Explicit method for full contraction (exponential cost)
23//! - `contract_nodes()`: Graph operations for node contraction
24
25use std::hash::Hash;
26
27use anyhow::Result;
28use tensor4all_core::{
29    DynIndex, IndexLike, LinearizationOrder, TensorDynLen, TensorIndex, TensorLike,
30};
31
32use super::TreeTN;
33
34// ============================================================================
35// TensorIndex implementation for TreeTN
36// ============================================================================
37
38impl<T, V> TensorIndex for TreeTN<T, V>
39where
40    T: TensorLike,
41    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
42    <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
43{
44    type Index = T::Index;
45
46    /// Return all external (site/physical) indices from all nodes.
47    ///
48    /// This collects all site indices from `site_index_network`.
49    /// Bond indices are NOT included (they are internal to the network).
50    fn external_indices(&self) -> Vec<Self::Index> {
51        let mut result = Vec::new();
52        for node_name in self.node_names() {
53            if let Some(site_space) = self.site_space(&node_name) {
54                result.extend(site_space.iter().cloned());
55            }
56        }
57        result
58    }
59
60    fn num_external_indices(&self) -> usize {
61        self.node_names()
62            .iter()
63            .filter_map(|name| self.site_space(name))
64            .map(|space| space.len())
65            .sum()
66    }
67
68    /// Replace an index in this TreeTN.
69    ///
70    /// Looks up the index location (site or link) and replaces it in:
71    /// - The tensor containing it
72    /// - The appropriate index network (site_index_network or link_index_network)
73    ///
74    /// Note: `replace_tensor` automatically updates the `site_index_network` based on
75    /// the new tensor's indices, so we don't need to manually call `replace_site_index`.
76    fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
77        // Validate dimension match
78        if old_index.dim() != new_index.dim() {
79            return Err(anyhow::anyhow!(
80                "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
81                old_index.dim(),
82                new_index.dim()
83            ));
84        }
85
86        let mut result = self.clone();
87
88        // Check if it's a site index
89        if let Some(node_name) = self.site_index_network.find_node_by_index(old_index) {
90            let node_idx = result
91                .node_index(node_name)
92                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", node_name))?;
93
94            // Replace in tensor - this also updates site_index_network via replace_tensor
95            let tensor = result
96                .tensor(node_idx)
97                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))?;
98            let old_in_tensor = tensor
99                .external_indices()
100                .iter()
101                .find(|idx| idx.id() == old_index.id())
102                .ok_or_else(|| {
103                    anyhow::anyhow!("Index not found in tensor at node {:?}", node_name)
104                })?
105                .clone();
106            let new_tensor = tensor.replaceind(&old_in_tensor, new_index)?;
107            result.replace_tensor(node_idx, new_tensor)?;
108
109            // Keep ortho_towards consistent (if present)
110            if let Some(dir) = result.ortho_towards.remove(old_index) {
111                result.ortho_towards.insert(new_index.clone(), dir);
112            }
113
114            return Ok(result);
115        }
116
117        // Check if it's a link index
118        if let Some(edge) = self.link_index_network.find_edge(old_index) {
119            let (node_a, node_b) = result
120                .graph
121                .graph()
122                .edge_endpoints(edge)
123                .ok_or_else(|| anyhow::anyhow!("Edge {:?} not found", edge))?;
124
125            // IMPORTANT: Update edge weight FIRST so replace_tensor validation matches.
126            *result
127                .bond_index_mut(edge)
128                .ok_or_else(|| anyhow::anyhow!("Bond index not found"))? = new_index.clone();
129
130            // Replace in both endpoint tensors - this also updates site_index_network
131            for node in [node_a, node_b] {
132                let tensor = result
133                    .tensor(node)
134                    .ok_or_else(|| anyhow::anyhow!("Tensor not found"))?;
135                let old_in_tensor = tensor
136                    .external_indices()
137                    .iter()
138                    .find(|idx| idx.id() == old_index.id())
139                    .ok_or_else(|| anyhow::anyhow!("Bond index not found in endpoint tensor"))?
140                    .clone();
141                let new_tensor = tensor.replaceind(&old_in_tensor, new_index)?;
142                result.replace_tensor(node, new_tensor)?;
143            }
144
145            // Keep ortho_towards consistent (if present)
146            if let Some(dir) = result.ortho_towards.remove(old_index) {
147                result.ortho_towards.insert(new_index.clone(), dir);
148            }
149
150            // Replace in link_index_network
151            result
152                .link_index_network
153                .replace_index(old_index, new_index, edge)
154                .map_err(|e| anyhow::anyhow!("{}", e))?;
155
156            return Ok(result);
157        }
158
159        Err(anyhow::anyhow!(
160            "Index {:?} not found in TreeTN",
161            old_index.id()
162        ))
163    }
164
165    /// Replace multiple indices in this TreeTN.
166    fn replaceinds(
167        &self,
168        old_indices: &[Self::Index],
169        new_indices: &[Self::Index],
170    ) -> Result<Self> {
171        if old_indices.len() != new_indices.len() {
172            return Err(anyhow::anyhow!(
173                "Length mismatch: {} old indices, {} new indices",
174                old_indices.len(),
175                new_indices.len()
176            ));
177        }
178
179        let mut result = self.clone();
180        for (old, new) in old_indices.iter().zip(new_indices.iter()) {
181            result = result.replaceind(old, new)?;
182        }
183        Ok(result)
184    }
185}
186
187impl<V> TreeTN<TensorDynLen, V>
188where
189    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
190    <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
191{
192    /// Replace one site index with multiple site indices using an exact reshape.
193    ///
194    /// This is a TreeTN-level wrapper around [`TensorDynLen::unfuse_index`].
195    /// It updates the owning node tensor and the site-index metadata, without
196    /// introducing any approximation.
197    ///
198    /// # Examples
199    /// ```
200    /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen};
201    /// use tensor4all_treetn::TreeTN;
202    ///
203    /// let fused = DynIndex::new_dyn(4);
204    /// let tensor = TensorDynLen::from_dense(vec![fused.clone()], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
205    /// let tn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![tensor], vec![0]).unwrap();
206    /// let left = DynIndex::new_dyn(2);
207    /// let right = DynIndex::new_dyn(2);
208    ///
209    /// let unfused = tn
210    ///     .replace_site_index_with_indices(
211    ///         &fused,
212    ///         &[left.clone(), right.clone()],
213    ///         LinearizationOrder::ColumnMajor,
214    ///     )
215    ///     .unwrap();
216    ///
217    /// let dense = unfused.contract_to_tensor().unwrap();
218    /// let expected = TensorDynLen::from_dense(vec![left, right], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
219    /// assert!((&dense - &expected).maxabs() < 1.0e-12);
220    /// ```
221    pub fn replace_site_index_with_indices(
222        &self,
223        old_index: &DynIndex,
224        new_indices: &[DynIndex],
225        order: LinearizationOrder,
226    ) -> Result<Self> {
227        let node_name = self
228            .site_index_network
229            .find_node_by_index(old_index)
230            .cloned()
231            .ok_or_else(|| {
232                anyhow::anyhow!(
233                    "site index {:?} not found in TreeTN site index network",
234                    old_index.id()
235                )
236            })?;
237        let node_idx = self
238            .node_index(&node_name)
239            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", node_name))?;
240        let tensor = self
241            .tensor(node_idx)
242            .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))?;
243        let new_tensor = tensor.unfuse_index(old_index, new_indices, order)?;
244
245        let mut result = self.clone();
246        result.replace_tensor(node_idx, new_tensor)?;
247
248        if let Some(dir) = result.ortho_towards.remove(old_index) {
249            for new_index in new_indices {
250                result.ortho_towards.insert(new_index.clone(), dir.clone());
251            }
252        }
253
254        Ok(result)
255    }
256}
257
258#[cfg(test)]
259mod tests;