// tensor4all_treetn/treetn/tensor_like.rs
1//! TensorIndex and TensorLike implementations for TreeTN.
2//!
3//! **Design Decision**: TreeTN implements TensorIndex but NOT TensorLike.
4//!
5//! ## TensorIndex (Implemented)
6//!
7//! TreeTN implements TensorIndex because index operations are well-defined:
8//! - `external_indices()`: Returns all site (physical) indices
9//! - `replaceind()` / `replaceinds()`: Replace indices in tensors and metadata
10//!
11//! ## TensorLike (NOT Implemented)
12//!
13//! TreeTN does NOT implement TensorLike because:
14//! 1. **Unclear semantics**: What would `tensordot` between two TreeTNs mean?
15//! 2. **Hidden costs**: Full contraction has exponential cost
16//! 3. **Separation of concerns**: Dense tensors and TNs are fundamentally different
17//!
18//! ## Alternative API
19//!
20//! TreeTN provides its own methods instead:
21//! - `site_indices()`: Returns physical indices (not bonds)
22//! - `contract_to_tensor()`: Explicit method for full contraction (exponential cost)
23//! - `contract_nodes()`: Graph operations for node contraction
24
25use std::hash::Hash;
26
27use anyhow::Result;
28use tensor4all_core::{
29 DynIndex, IndexLike, LinearizationOrder, TensorDynLen, TensorIndex, TensorLike,
30};
31
32use super::TreeTN;
33
34// ============================================================================
35// TensorIndex implementation for TreeTN
36// ============================================================================
37
38impl<T, V> TensorIndex for TreeTN<T, V>
39where
40 T: TensorLike,
41 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
42 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
43{
44 type Index = T::Index;
45
46 /// Return all external (site/physical) indices from all nodes.
47 ///
48 /// This collects all site indices from `site_index_network`.
49 /// Bond indices are NOT included (they are internal to the network).
50 fn external_indices(&self) -> Vec<Self::Index> {
51 let mut result = Vec::new();
52 for node_name in self.node_names() {
53 if let Some(site_space) = self.site_space(&node_name) {
54 result.extend(site_space.iter().cloned());
55 }
56 }
57 result
58 }
59
60 fn num_external_indices(&self) -> usize {
61 self.node_names()
62 .iter()
63 .filter_map(|name| self.site_space(name))
64 .map(|space| space.len())
65 .sum()
66 }
67
68 /// Replace an index in this TreeTN.
69 ///
70 /// Looks up the index location (site or link) and replaces it in:
71 /// - The tensor containing it
72 /// - The appropriate index network (site_index_network or link_index_network)
73 ///
74 /// Note: `replace_tensor` automatically updates the `site_index_network` based on
75 /// the new tensor's indices, so we don't need to manually call `replace_site_index`.
76 fn replaceind(&self, old_index: &Self::Index, new_index: &Self::Index) -> Result<Self> {
77 // Validate dimension match
78 if old_index.dim() != new_index.dim() {
79 return Err(anyhow::anyhow!(
80 "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
81 old_index.dim(),
82 new_index.dim()
83 ));
84 }
85
86 let mut result = self.clone();
87
88 // Check if it's a site index
89 if let Some(node_name) = self.site_index_network.find_node_by_index(old_index) {
90 let node_idx = result
91 .node_index(node_name)
92 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", node_name))?;
93
94 // Replace in tensor - this also updates site_index_network via replace_tensor
95 let tensor = result
96 .tensor(node_idx)
97 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))?;
98 let old_in_tensor = tensor
99 .external_indices()
100 .iter()
101 .find(|idx| idx.id() == old_index.id())
102 .ok_or_else(|| {
103 anyhow::anyhow!("Index not found in tensor at node {:?}", node_name)
104 })?
105 .clone();
106 let new_tensor = tensor.replaceind(&old_in_tensor, new_index)?;
107 result.replace_tensor(node_idx, new_tensor)?;
108
109 // Keep ortho_towards consistent (if present)
110 if let Some(dir) = result.ortho_towards.remove(old_index) {
111 result.ortho_towards.insert(new_index.clone(), dir);
112 }
113
114 return Ok(result);
115 }
116
117 // Check if it's a link index
118 if let Some(edge) = self.link_index_network.find_edge(old_index) {
119 let (node_a, node_b) = result
120 .graph
121 .graph()
122 .edge_endpoints(edge)
123 .ok_or_else(|| anyhow::anyhow!("Edge {:?} not found", edge))?;
124
125 // IMPORTANT: Update edge weight FIRST so replace_tensor validation matches.
126 *result
127 .bond_index_mut(edge)
128 .ok_or_else(|| anyhow::anyhow!("Bond index not found"))? = new_index.clone();
129
130 // Replace in both endpoint tensors - this also updates site_index_network
131 for node in [node_a, node_b] {
132 let tensor = result
133 .tensor(node)
134 .ok_or_else(|| anyhow::anyhow!("Tensor not found"))?;
135 let old_in_tensor = tensor
136 .external_indices()
137 .iter()
138 .find(|idx| idx.id() == old_index.id())
139 .ok_or_else(|| anyhow::anyhow!("Bond index not found in endpoint tensor"))?
140 .clone();
141 let new_tensor = tensor.replaceind(&old_in_tensor, new_index)?;
142 result.replace_tensor(node, new_tensor)?;
143 }
144
145 // Keep ortho_towards consistent (if present)
146 if let Some(dir) = result.ortho_towards.remove(old_index) {
147 result.ortho_towards.insert(new_index.clone(), dir);
148 }
149
150 // Replace in link_index_network
151 result
152 .link_index_network
153 .replace_index(old_index, new_index, edge)
154 .map_err(|e| anyhow::anyhow!("{}", e))?;
155
156 return Ok(result);
157 }
158
159 Err(anyhow::anyhow!(
160 "Index {:?} not found in TreeTN",
161 old_index.id()
162 ))
163 }
164
165 /// Replace multiple indices in this TreeTN.
166 fn replaceinds(
167 &self,
168 old_indices: &[Self::Index],
169 new_indices: &[Self::Index],
170 ) -> Result<Self> {
171 if old_indices.len() != new_indices.len() {
172 return Err(anyhow::anyhow!(
173 "Length mismatch: {} old indices, {} new indices",
174 old_indices.len(),
175 new_indices.len()
176 ));
177 }
178
179 let mut result = self.clone();
180 for (old, new) in old_indices.iter().zip(new_indices.iter()) {
181 result = result.replaceind(old, new)?;
182 }
183 Ok(result)
184 }
185}
186
187impl<V> TreeTN<TensorDynLen, V>
188where
189 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
190 <DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
191{
192 /// Replace one site index with multiple site indices using an exact reshape.
193 ///
194 /// This is a TreeTN-level wrapper around [`TensorDynLen::unfuse_index`].
195 /// It updates the owning node tensor and the site-index metadata, without
196 /// introducing any approximation.
197 ///
198 /// # Examples
199 /// ```
200 /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen};
201 /// use tensor4all_treetn::TreeTN;
202 ///
203 /// let fused = DynIndex::new_dyn(4);
204 /// let tensor = TensorDynLen::from_dense(vec![fused.clone()], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
205 /// let tn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![tensor], vec![0]).unwrap();
206 /// let left = DynIndex::new_dyn(2);
207 /// let right = DynIndex::new_dyn(2);
208 ///
209 /// let unfused = tn
210 /// .replace_site_index_with_indices(
211 /// &fused,
212 /// &[left.clone(), right.clone()],
213 /// LinearizationOrder::ColumnMajor,
214 /// )
215 /// .unwrap();
216 ///
217 /// let dense = unfused.contract_to_tensor().unwrap();
218 /// let expected = TensorDynLen::from_dense(vec![left, right], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
219 /// assert!((&dense - &expected).maxabs() < 1.0e-12);
220 /// ```
221 pub fn replace_site_index_with_indices(
222 &self,
223 old_index: &DynIndex,
224 new_indices: &[DynIndex],
225 order: LinearizationOrder,
226 ) -> Result<Self> {
227 let node_name = self
228 .site_index_network
229 .find_node_by_index(old_index)
230 .cloned()
231 .ok_or_else(|| {
232 anyhow::anyhow!(
233 "site index {:?} not found in TreeTN site index network",
234 old_index.id()
235 )
236 })?;
237 let node_idx = self
238 .node_index(&node_name)
239 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", node_name))?;
240 let tensor = self
241 .tensor(node_idx)
242 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))?;
243 let new_tensor = tensor.unfuse_index(old_index, new_indices, order)?;
244
245 let mut result = self.clone();
246 result.replace_tensor(node_idx, new_tensor)?;
247
248 if let Some(dir) = result.ortho_towards.remove(old_index) {
249 for new_index in new_indices {
250 result.ortho_towards.insert(new_index.clone(), dir.clone());
251 }
252 }
253
254 Ok(result)
255 }
256}
257
258#[cfg(test)]
259mod tests;