tensor4all_treetn/treetn/ops.rs
//! Trait implementations and operations for TreeTN.
//!
//! This module provides:
//! - `Default` implementation
//! - `Clone` implementation
//! - `Debug` implementation
//! - `log_norm` for computing the logarithm of the Frobenius norm
//! - `norm`, `norm_squared` for computing the Frobenius norm
//! - `scale` for multiplying the represented tensor by a scalar
//! - `inner` for computing inner products of two TreeTNs
//! - `to_dense` for contracting to a single tensor
//! - `all_site_index_ids` for retrieving all site index IDs and their owning vertices
//! - `evaluate` for evaluating at specific index values
//! - `evaluate_at` for evaluating using `Index` objects instead of raw IDs
//! - `all_site_indices` for retrieving all site indices and their owning vertices
14
15use std::collections::{HashMap, HashSet};
16use std::hash::Hash;
17
18use tensor4all_core::{AllowedPairs, AnyScalar, ColMajorArrayRef, IndexLike, TensorLike};
19
20use super::TreeTN;
21
22// ============================================================================
23// Default implementation
24// ============================================================================
25
26impl<T, V> Default for TreeTN<T, V>
27where
28 T: TensorLike,
29 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
30{
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36// ============================================================================
37// Clone implementation
38// ============================================================================
39
40impl<T, V> Clone for TreeTN<T, V>
41where
42 T: TensorLike,
43 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
44{
45 fn clone(&self) -> Self {
46 Self {
47 graph: self.graph.clone(),
48 canonical_region: self.canonical_region.clone(),
49 canonical_form: self.canonical_form,
50 site_index_network: self.site_index_network.clone(),
51 link_index_network: self.link_index_network.clone(),
52 ortho_towards: self.ortho_towards.clone(),
53 }
54 }
55}
56
57// ============================================================================
58// Debug implementation
59// ============================================================================
60
61impl<T, V> std::fmt::Debug for TreeTN<T, V>
62where
63 T: TensorLike,
64 V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
65{
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("TreeTN")
68 .field("node_count", &self.node_count())
69 .field("edge_count", &self.edge_count())
70 .field("canonical_region", &self.canonical_region)
71 .finish_non_exhaustive()
72 }
73}
74
75// ============================================================================
// Norm computation, scaling, inner products, dense conversion and evaluation
77// ============================================================================
78
79use anyhow::{Context, Result};
80
81use crate::algorithm::CanonicalForm;
82use crate::CanonicalizationOptions;
83
84impl<T, V> TreeTN<T, V>
85where
86 T: TensorLike,
87 V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
88{
89 /// Compute log(||TreeTN||_F), the log of the Frobenius norm.
90 ///
91 /// Uses canonicalization to avoid numerical overflow:
92 /// when canonicalized to a single site with Unitary form,
93 /// the Frobenius norm of the whole network equals the norm of the center tensor.
94 ///
95 /// # Note
96 /// This method is mutable because it may need to canonicalize the network
97 /// to a single Unitary center. Use `log_norm` (without canonicalization) if you
98 /// already have a properly canonicalized network.
99 ///
100 /// # Returns
101 /// The natural logarithm of the Frobenius norm.
102 ///
103 /// # Errors
104 /// Returns an error if:
105 /// - The network is empty
106 /// - Canonicalization fails
107 ///
108 /// # Examples
109 ///
110 /// ```
111 /// use tensor4all_treetn::TreeTN;
112 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
113 ///
114 /// let s = DynIndex::new_dyn(2);
115 /// let t = TensorDynLen::from_dense(vec![s], vec![3.0_f64, 4.0]).unwrap();
116 /// let mut tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
117 ///
118 /// // log(||[3, 4]||) = log(5)
119 /// let ln = tn.log_norm().unwrap();
120 /// assert!((ln - 5.0_f64.ln()).abs() < 1e-10);
121 /// ```
122 pub fn log_norm(&mut self) -> Result<f64> {
123 let n = self.node_count();
124 if n == 0 {
125 return Err(anyhow::anyhow!("Cannot compute log_norm of empty TreeTN"))
126 .context("log_norm: network must have at least one node");
127 }
128
129 // Determine the single center site (by name)
130 let center_name: V =
131 if self.is_canonicalized() && self.canonical_form() == Some(CanonicalForm::Unitary) {
132 if self.canonical_region.len() == 1 {
133 // Already Unitary canonicalized to single site - use it
134 self.canonical_region.iter().next().unwrap().clone()
135 } else {
136 // Unitary canonicalized to multiple sites - canonicalize to min site
137 let min_center = self.canonical_region.iter().min().unwrap().clone();
138 self.canonicalize_mut(
139 std::iter::once(min_center.clone()),
140 CanonicalizationOptions::default(),
141 )
142 .context("log_norm: failed to canonicalize to single site")?;
143 min_center
144 }
145 } else {
146 // Not canonicalized or not Unitary - canonicalize to min node name
147 let min_node_name = self
148 .node_names()
149 .into_iter()
150 .min()
151 .ok_or_else(|| anyhow::anyhow!("No nodes in TreeTN"))
152 .context("log_norm: network must have nodes")?;
153 self.canonicalize_mut(
154 std::iter::once(min_node_name.clone()),
155 CanonicalizationOptions::default(),
156 )
157 .context("log_norm: failed to canonicalize")?;
158 min_node_name
159 };
160
161 // Get center node index and tensor
162 let center_node = self
163 .node_index(¢er_name)
164 .ok_or_else(|| anyhow::anyhow!("Center node not found"))
165 .context("log_norm: center node must exist")?;
166
167 let center_tensor = self
168 .tensor(center_node)
169 .ok_or_else(|| anyhow::anyhow!("Center tensor not found"))
170 .context("log_norm: center tensor must exist")?;
171
172 let norm_sq = center_tensor.norm_squared();
173 let norm = norm_sq.sqrt();
174
175 Ok(norm.ln())
176 }
177
178 /// Compute the Frobenius norm of the TreeTN.
179 ///
180 /// Uses `log_norm` internally: `norm = exp(log_norm)`.
181 ///
182 /// # Note
183 /// This method is mutable because it may need to canonicalize the network.
184 ///
185 /// # Errors
186 /// Returns an error if the network is empty or canonicalization fails.
187 ///
188 /// # Examples
189 ///
190 /// ```
191 /// use tensor4all_treetn::TreeTN;
192 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
193 ///
194 /// // Single-node TreeTN with tensor [1, 0, 0, 1] (identity 2x2)
195 /// let s0 = DynIndex::new_dyn(2);
196 /// let s1 = DynIndex::new_dyn(2);
197 /// let t = TensorDynLen::from_dense(
198 /// vec![s0.clone(), s1.clone()],
199 /// vec![1.0_f64, 0.0, 0.0, 1.0],
200 /// ).unwrap();
201 ///
202 /// let mut tn = TreeTN::<_, String>::from_tensors(
203 /// vec![t],
204 /// vec!["A".to_string()],
205 /// ).unwrap();
206 ///
207 /// // Frobenius norm of [[1,0],[0,1]] = sqrt(2)
208 /// let n = tn.norm().unwrap();
209 /// assert!((n - 2.0_f64.sqrt()).abs() < 1e-10);
210 /// ```
211 pub fn norm(&mut self) -> Result<f64> {
212 let log_n = self
213 .log_norm()
214 .context("norm: failed to compute log_norm")?;
215 Ok(log_n.exp())
216 }
217
218 /// Compute the squared Frobenius norm of the TreeTN.
219 ///
220 /// Returns `||self||^2 = norm()^2`.
221 ///
222 /// # Note
223 /// This method is mutable because it may need to canonicalize the network.
224 ///
225 /// # Errors
226 /// Returns an error if the network is empty or canonicalization fails.
227 ///
228 /// # Examples
229 ///
230 /// ```
231 /// use tensor4all_treetn::TreeTN;
232 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
233 ///
234 /// let s = DynIndex::new_dyn(2);
235 /// let t = TensorDynLen::from_dense(vec![s], vec![3.0_f64, 4.0]).unwrap();
236 /// let mut tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
237 ///
238 /// // ||[3, 4]||^2 = 9 + 16 = 25
239 /// let nsq = tn.norm_squared().unwrap();
240 /// assert!((nsq - 25.0).abs() < 1e-10);
241 /// ```
242 pub fn norm_squared(&mut self) -> Result<f64> {
243 let n = self
244 .norm()
245 .context("norm_squared: failed to compute norm")?;
246 Ok(n * n)
247 }
248
249 /// Scale the tensor network by a complex scalar.
250 ///
251 /// This multiplies a single node tensor, chosen deterministically as the
252 /// minimum-named node, so the represented state is scaled once rather than
253 /// applying `scalar^n` across all nodes.
254 ///
255 /// Scaling a non-center tensor generally invalidates any existing
256 /// canonicalization metadata, so this method clears the cached canonical
257 /// region and orthogonality directions after updating the tensor.
258 ///
259 /// # Arguments
260 /// * `scalar` - Scalar multiplier applied to the represented tensor network
261 ///
262 /// # Returns
263 /// `Ok(())` after the selected node tensor has been updated in place
264 ///
265 /// # Errors
266 /// Returns an error if the TreeTN is empty or the selected node/tensor
267 /// cannot be found
268 ///
269 /// # Examples
270 ///
271 /// ```
272 /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen, TensorIndex, TensorLike};
273 /// use tensor4all_treetn::TreeTN;
274 ///
275 /// let s = DynIndex::new_dyn(2);
276 /// let t = TensorDynLen::from_dense(vec![s], vec![1.0_f64, -2.0]).unwrap();
277 /// let mut tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
278 ///
279 /// tn.scale(AnyScalar::new_real(2.0)).unwrap();
280 ///
281 /// let dense = tn.to_dense().unwrap();
282 /// let expected = TensorDynLen::from_dense(
283 /// dense.external_indices(),
284 /// vec![2.0_f64, -4.0],
285 /// ).unwrap();
286 /// assert!((&dense - &expected).maxabs() < 1e-12);
287 /// ```
288 pub fn scale(&mut self, scalar: AnyScalar) -> Result<()> {
289 let min_node = self
290 .node_names()
291 .into_iter()
292 .min()
293 .ok_or_else(|| anyhow::anyhow!("Cannot scale empty TreeTN"))
294 .context("scale: network must have at least one node")?;
295 let node_idx = self
296 .node_index(&min_node)
297 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", min_node))
298 .context("scale: selected node must exist")?;
299 let tensor = self
300 .tensor(node_idx)
301 .ok_or_else(|| anyhow::anyhow!("Node tensor not found for {:?}", min_node))
302 .context("scale: selected node tensor must exist")?
303 .clone();
304 let scaled = tensor
305 .scale(scalar)
306 .context("scale: tensor scaling failed")?;
307 self.replace_tensor(node_idx, scaled)?
308 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", min_node))
309 .context("scale: failed to replace scaled tensor")?;
310
311 self.clear_canonical_region();
312 self.ortho_towards.clear();
313
314 Ok(())
315 }
316
317 /// Compute the inner product of two TreeTNs.
318 ///
319 /// Computes `<self | other>` = sum over all indices of `conj(self) * other`.
320 ///
321 /// Both TreeTNs must have the same site indices (same IDs).
322 /// Link indices may differ between the two TreeTNs.
323 ///
324 /// # Algorithm
325 /// 1. Replace link indices in `other` with fresh IDs to avoid collision.
326 /// 2. At each node, contract `conj(self_tensor) * other_tensor` pairwise.
327 /// 3. Sweep from leaves to root, contracting the environment.
328 ///
329 /// This is equivalent to contracting the entire network
330 /// `conj(self) * other` into a scalar.
331 ///
332 /// # Errors
333 /// Returns an error if the networks have incompatible topologies.
334 ///
335 /// # Examples
336 ///
337 /// ```
338 /// use tensor4all_treetn::TreeTN;
339 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
340 ///
341 /// let s = DynIndex::new_dyn(2);
342 /// let t = TensorDynLen::from_dense(vec![s], vec![3.0_f64, 4.0]).unwrap();
343 /// let tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
344 ///
345 /// // <v|v> = 3^2 + 4^2 = 25
346 /// let ip = tn.inner(&tn).unwrap();
347 /// assert!((ip.real() - 25.0).abs() < 1e-10);
348 /// ```
349 pub fn inner(&self, other: &Self) -> Result<AnyScalar>
350 where
351 <T::Index as IndexLike>::Id:
352 Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
353 {
354 if self.node_count() == 0 && other.node_count() == 0 {
355 return Ok(AnyScalar::new_real(0.0));
356 }
357 if !self.share_equivalent_site_index_network(other) {
358 return Err(anyhow::anyhow!(
359 "inner: TreeTNs must have the same topology and site indices"
360 ));
361 }
362
363 let root_name = self
364 .node_names()
365 .into_iter()
366 .min()
367 .ok_or_else(|| anyhow::anyhow!("Cannot compute inner product of empty TreeTN"))
368 .context("inner: network must have at least one node")?;
369 let other_sim = other.sim_internal_inds();
370
371 let post_order = self
372 .site_index_network()
373 .post_order_dfs(&root_name)
374 .ok_or_else(|| anyhow::anyhow!("Root node {:?} not found", root_name))
375 .context("inner: failed to build post-order traversal")?;
376
377 let mut parent_of: HashMap<V, Option<V>> = HashMap::new();
378 parent_of.insert(root_name.clone(), None);
379 let mut stack = vec![root_name.clone()];
380 while let Some(node_name) = stack.pop() {
381 let mut neighbors: Vec<V> = self.site_index_network().neighbors(&node_name).collect();
382 neighbors.sort();
383 for neighbor in neighbors {
384 if parent_of.contains_key(&neighbor) {
385 continue;
386 }
387 parent_of.insert(neighbor.clone(), Some(node_name.clone()));
388 stack.push(neighbor);
389 }
390 }
391
392 let mut envs: HashMap<V, T> = HashMap::new();
393
394 for node_name in post_order {
395 let node_idx_self = self
396 .node_index(&node_name)
397 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in self", node_name))
398 .context("inner: self node must exist")?;
399 let node_idx_other = other_sim
400 .node_index(&node_name)
401 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in other", node_name))
402 .context("inner: other node must exist")?;
403
404 let mut env = self
405 .tensor(node_idx_self)
406 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))
407 .context("inner: self tensor must exist")?
408 .conj();
409
410 let mut children: Vec<V> = parent_of
411 .iter()
412 .filter_map(|(child, parent)| {
413 if parent.as_ref() == Some(&node_name) {
414 Some(child.clone())
415 } else {
416 None
417 }
418 })
419 .collect();
420 children.sort();
421
422 for child_name in children {
423 let child_env = envs.remove(&child_name).ok_or_else(|| {
424 anyhow::anyhow!(
425 "Missing child environment for child {:?} of node {:?}",
426 child_name,
427 node_name
428 )
429 })?;
430 env = T::contract(&[&env, &child_env], AllowedPairs::All)
431 .context("inner: failed to absorb child environment")?;
432 }
433
434 let other_tensor = other_sim
435 .tensor(node_idx_other)
436 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))
437 .context("inner: other tensor must exist")?;
438 env = T::contract(&[&env, other_tensor], AllowedPairs::All)
439 .context("inner: failed to contract node bra-ket tensors")?;
440
441 envs.insert(node_name, env);
442 }
443
444 let result_tensor = envs
445 .remove(&root_name)
446 .ok_or_else(|| anyhow::anyhow!("Root environment was not produced"))
447 .context("inner: root contraction failed")?;
448 if !envs.is_empty() {
449 return Err(anyhow::anyhow!(
450 "inner: contraction left {} dangling environments",
451 envs.len()
452 ));
453 }
454
455 let scalar_one = T::scalar_one().context("inner: failed to create scalar_one")?;
456 scalar_one
457 .inner_product(&result_tensor)
458 .context("inner: failed to extract scalar value")
459 }
460
461 /// Convert the TreeTN to a single dense tensor.
462 ///
463 /// This contracts all tensors in the network along their link/bond indices,
464 /// producing a single tensor with only site (physical) indices.
465 ///
466 /// This is an alias for `contract_to_tensor()`.
467 ///
468 /// # Warning
469 /// This operation can be very expensive for large networks,
470 /// as the result size grows exponentially with the number of sites.
471 ///
472 /// # Errors
473 /// Returns an error if the network is empty or contraction fails.
474 ///
475 /// # Examples
476 ///
477 /// ```
478 /// use tensor4all_treetn::TreeTN;
479 /// use tensor4all_core::{DynIndex, TensorDynLen, TensorIndex, TensorLike};
480 ///
481 /// // Build a 2-node chain
482 /// let s0 = DynIndex::new_dyn(2);
483 /// let bond = DynIndex::new_dyn(2);
484 /// let s1 = DynIndex::new_dyn(2);
485 ///
486 /// // Identity matrices
487 /// let t0 = TensorDynLen::from_dense(
488 /// vec![s0.clone(), bond.clone()],
489 /// vec![1.0_f64, 0.0, 0.0, 1.0],
490 /// ).unwrap();
491 /// let t1 = TensorDynLen::from_dense(
492 /// vec![bond.clone(), s1.clone()],
493 /// vec![1.0_f64, 0.0, 0.0, 1.0],
494 /// ).unwrap();
495 ///
496 /// let tn = TreeTN::<_, String>::from_tensors(
497 /// vec![t0, t1],
498 /// vec!["A".to_string(), "B".to_string()],
499 /// ).unwrap();
500 ///
501 /// // Contract to a single dense tensor over site indices s0 and s1
502 /// let dense = tn.to_dense().unwrap();
503 /// // Result is rank-2 (two site indices s0 and s1)
504 /// assert_eq!(dense.num_external_indices(), 2);
505 /// ```
506 pub fn to_dense(&self) -> Result<T> {
507 self.contract_to_tensor()
508 .context("to_dense: failed to contract network to tensor")
509 }
510
511 /// Returns all site index IDs and their owning vertex names.
512 ///
513 /// Returns `(index_ids, vertex_names)` where `index_ids[i]` belongs to
514 /// vertex `vertex_names[i]`. Order is unspecified but consistent
515 /// between the two vectors.
516 ///
517 /// For [`evaluate()`](Self::evaluate), pass `index_ids` and arrange
518 /// values in the same order.
519 #[allow(clippy::type_complexity)]
520 pub fn all_site_index_ids(&self) -> Result<(Vec<<T::Index as IndexLike>::Id>, Vec<V>)>
521 where
522 V: Clone,
523 <T::Index as IndexLike>::Id: Clone,
524 {
525 let mut ids = Vec::new();
526 let mut vertex_names = Vec::new();
527 for node_name in self.node_names() {
528 let site_space = self
529 .site_space(&node_name)
530 .ok_or_else(|| anyhow::anyhow!("Site space not found for node {:?}", node_name))
531 .context("all_site_index_ids: site space must exist")?;
532 for index in site_space {
533 ids.push(index.id().clone());
534 vertex_names.push(node_name.clone());
535 }
536 }
537 Ok((ids, vertex_names))
538 }
539
540 /// Evaluate the TreeTN at multiple multi-indices (batch).
541 ///
542 /// # Arguments
543 /// * `index_ids` - Identifies each site index by its ID (from
544 /// [`all_site_index_ids()`](Self::all_site_index_ids)).
545 /// Must enumerate every site index exactly once.
546 /// * `values` - Column-major array of shape `[n_indices, n_points]`.
547 /// `values.get(&[i, p])` is the value of `index_ids[i]` at point `p`.
548 ///
549 /// # Returns
550 /// A `Vec<AnyScalar>` of length `n_points`.
551 ///
552 /// # Errors
553 /// Returns an error if:
554 /// - The network is empty
555 /// - `values` shape is inconsistent with `index_ids`
556 /// - An index ID is unknown
557 /// - Index values are out of bounds
558 /// - Contraction fails
559 pub fn evaluate(
560 &self,
561 index_ids: &[<T::Index as IndexLike>::Id],
562 values: ColMajorArrayRef<'_, usize>,
563 ) -> Result<Vec<AnyScalar>>
564 where
565 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
566 {
567 if self.node_count() == 0 {
568 return Err(anyhow::anyhow!("Cannot evaluate empty TreeTN"))
569 .context("evaluate: network must have at least one node");
570 }
571
572 let n_indices = index_ids.len();
573 anyhow::ensure!(
574 values.shape().len() == 2,
575 "evaluate: values must be 2D, got {}D",
576 values.shape().len()
577 );
578 anyhow::ensure!(
579 values.shape()[0] == n_indices,
580 "evaluate: values.shape()[0] ({}) != index_ids.len() ({})",
581 values.shape()[0],
582 n_indices
583 );
584 let n_points = values.shape()[1];
585
586 // Build index_id -> position lookup (Vec-based linear scan is fine for
587 // the small number of site indices typical in practice).
588 let mut known_ids: HashSet<<T::Index as IndexLike>::Id> = HashSet::new();
589 let mut total_site_indices: usize = 0;
590 for node_name in self.node_names() {
591 let site_space = self
592 .site_space(&node_name)
593 .ok_or_else(|| anyhow::anyhow!("Site space not found for node {:?}", node_name))
594 .context("evaluate: site space must exist")?;
595 for index in site_space {
596 known_ids.insert(index.id().clone());
597 total_site_indices += 1;
598 }
599 }
600
601 // Validate: index_ids.len() must equal total number of site indices.
602 anyhow::ensure!(
603 n_indices == total_site_indices,
604 "evaluate: index_ids.len() ({}) != total site indices ({})",
605 n_indices,
606 total_site_indices
607 );
608
609 // Validate: no duplicate index IDs.
610 {
611 let mut seen = HashSet::with_capacity(n_indices);
612 for id in index_ids {
613 anyhow::ensure!(seen.insert(id), "evaluate: duplicate index ID {:?}", id);
614 }
615 }
616
617 // Validate: all provided IDs must be known (exist in the network).
618 for id in index_ids {
619 anyhow::ensure!(
620 known_ids.contains(id),
621 "evaluate: unknown index ID {:?}",
622 id
623 );
624 }
625
626 // Pre-compute per-node data: (node_name, node_index, tensor_ref,
627 // site_entries: Vec<(Index, position_in_index_ids)>)
628 // This avoids HashMap lookups and repeated node_index/tensor lookups
629 // inside the per-point loop.
630 struct NodeEntry<'a, T: TensorLike, V> {
631 name: V,
632 tensor: &'a T,
633 /// (site_index, position in `index_ids`)
634 site_entries: Vec<(T::Index, usize)>,
635 }
636
637 let node_names = self.node_names();
638 let mut node_entries: Vec<NodeEntry<'_, T, V>> = Vec::with_capacity(node_names.len());
639
640 for node_name in &node_names {
641 let node_idx = self
642 .node_index(node_name)
643 .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", node_name))
644 .context("evaluate: node must exist")?;
645
646 let tensor = self
647 .tensor(node_idx)
648 .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))
649 .context("evaluate: tensor must exist")?;
650
651 let site_space = self.site_space(node_name);
652 let mut site_entries = Vec::new();
653 if let Some(space) = site_space {
654 for index in space {
655 let id = index.id();
656 let pos = index_ids
657 .iter()
658 .position(|x| x == id)
659 .ok_or_else(|| anyhow::anyhow!("Index ID {:?} not found in index_ids", id))
660 .context("evaluate: all site indices must be covered by index_ids")?;
661 site_entries.push((index.clone(), pos));
662 }
663 }
664
665 node_entries.push(NodeEntry {
666 name: node_name.clone(),
667 tensor,
668 site_entries,
669 });
670 }
671
672 let mut results = Vec::with_capacity(n_points);
673 for point in 0..n_points {
674 let mut contracted_tensors: Vec<T> = Vec::with_capacity(node_entries.len());
675 let mut contracted_names: Vec<V> = Vec::with_capacity(node_entries.len());
676
677 for entry in &node_entries {
678 if entry.site_entries.is_empty() {
679 // No site indices - just use the tensor as is
680 contracted_tensors.push(entry.tensor.clone());
681 contracted_names.push(entry.name.clone());
682 continue;
683 }
684
685 let index_vals: Vec<(T::Index, usize)> = entry
686 .site_entries
687 .iter()
688 .map(|(idx, pos)| {
689 let val = *values.get(&[*pos, point]).unwrap();
690 (idx.clone(), val)
691 })
692 .collect();
693
694 let onehot =
695 T::onehot(&index_vals).context("evaluate: failed to create one-hot tensor")?;
696
697 let result =
698 T::contract(&[entry.tensor, &onehot], tensor4all_core::AllowedPairs::All)
699 .context("evaluate: failed to contract tensor with one-hot")?;
700
701 contracted_tensors.push(result);
702 contracted_names.push(entry.name.clone());
703 }
704
705 // Build a temporary TreeTN from the contracted tensors and contract to scalar
706 let temp_tn = TreeTN::<T, V>::from_tensors(contracted_tensors, contracted_names)
707 .context("evaluate: failed to build temporary TreeTN")?;
708 let result_tensor = temp_tn
709 .contract_to_tensor()
710 .context("evaluate: failed to contract to scalar")?;
711
712 let scalar_one = T::scalar_one().context("evaluate: failed to create scalar_one")?;
713 let scalar = scalar_one
714 .inner_product(&result_tensor)
715 .context("evaluate: failed to extract scalar value")?;
716 results.push(scalar);
717 }
718
719 Ok(results)
720 }
721
722 /// Returns all site indices and their owning vertex names.
723 ///
724 /// Returns `(indices, vertex_names)` where `indices[i]` belongs to
725 /// vertex `vertex_names[i]`. Order is unspecified but consistent
726 /// between the two vectors.
727 ///
728 /// This is the `Index`-based counterpart of
729 /// [`all_site_index_ids()`](Self::all_site_index_ids), returning
730 /// full `Index` objects instead of raw IDs.
731 ///
732 /// # Errors
733 /// Returns an error if a node's site space cannot be found.
734 ///
735 /// # Examples
736 /// ```
737 /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorLike};
738 /// use tensor4all_treetn::TreeTN;
739 ///
740 /// let s0 = DynIndex::new_dyn(2);
741 /// let bond = DynIndex::new_dyn(3);
742 /// let s1 = DynIndex::new_dyn(2);
743 /// let t0 = TensorDynLen::from_dense(
744 /// vec![s0.clone(), bond.clone()], vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
745 /// ).unwrap();
746 /// let t1 = TensorDynLen::from_dense(
747 /// vec![bond.clone(), s1.clone()], vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
748 /// ).unwrap();
749 /// let tn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1]).unwrap();
750 ///
751 /// let (indices, vertices) = tn.all_site_indices().unwrap();
752 /// assert_eq!(indices.len(), 2);
753 /// assert_eq!(vertices.len(), 2);
754 ///
755 /// // The returned indices contain both s0 and s1
756 /// let id_set: std::collections::HashSet<_> = indices.iter().map(|i| *i.id()).collect();
757 /// assert!(id_set.contains(s0.id()));
758 /// assert!(id_set.contains(s1.id()));
759 /// ```
760 #[allow(clippy::type_complexity)]
761 pub fn all_site_indices(&self) -> Result<(Vec<T::Index>, Vec<V>)>
762 where
763 V: Clone,
764 T::Index: Clone,
765 {
766 let mut indices = Vec::new();
767 let mut node_names = Vec::new();
768 for node_name in self.node_names() {
769 let site_space = self
770 .site_space(&node_name)
771 .ok_or_else(|| anyhow::anyhow!("Site space not found for node {:?}", node_name))
772 .context("all_site_indices: site space must exist")?;
773 for index in site_space {
774 indices.push(index.clone());
775 node_names.push(node_name.clone());
776 }
777 }
778 Ok((indices, node_names))
779 }
780
781 /// Evaluate the TreeTN at multiple multi-indices (batch), using
782 /// `Index` objects instead of raw IDs.
783 ///
784 /// This is a convenience wrapper around [`evaluate()`](Self::evaluate)
785 /// that accepts `&[T::Index]` directly, extracting the IDs
786 /// internally.
787 ///
788 /// # Arguments
789 /// * `indices` - Identifies each site index by its `Index` object
790 /// (e.g. from [`all_site_indices()`](Self::all_site_indices)).
791 /// Must enumerate every site index exactly once.
792 /// * `values` - Column-major array of shape `[n_indices, n_points]`.
793 /// `values.get(&[i, p])` is the value of `indices[i]` at point `p`.
794 ///
795 /// # Returns
796 /// A `Vec<AnyScalar>` of length `n_points`.
797 ///
798 /// # Errors
799 /// Returns an error if the underlying [`evaluate()`](Self::evaluate)
800 /// call fails (see its documentation for details).
801 ///
802 /// # Examples
803 /// ```
804 /// use tensor4all_core::{ColMajorArrayRef, DynIndex, IndexLike, TensorDynLen, TensorLike};
805 /// use tensor4all_treetn::TreeTN;
806 ///
807 /// let s0 = DynIndex::new_dyn(3);
808 /// let t0 = TensorDynLen::from_dense(vec![s0.clone()], vec![10.0, 20.0, 30.0]).unwrap();
809 /// let tn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0], vec![0]).unwrap();
810 ///
811 /// let (indices, _vertices) = tn.all_site_indices().unwrap();
812 ///
813 /// // Evaluate at index value 2
814 /// let data = [2usize];
815 /// let shape = [indices.len(), 1];
816 /// let values = ColMajorArrayRef::new(&data, &shape);
817 /// let result = tn.evaluate_at(&indices, values).unwrap();
818 /// assert!((result[0].real() - 30.0).abs() < 1e-10);
819 /// ```
820 pub fn evaluate_at(
821 &self,
822 indices: &[T::Index],
823 values: ColMajorArrayRef<'_, usize>,
824 ) -> Result<Vec<AnyScalar>>
825 where
826 <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
827 {
828 let index_ids: Vec<_> = indices.iter().map(|idx| idx.id().clone()).collect();
829 self.evaluate(&index_ids, values)
830 }
831}
832
833#[cfg(test)]
834mod tests;