// tensor4all_treetn/treetn/ops.rs
1//! Trait implementations and operations for TreeTN.
2//!
3//! This module provides:
4//! - `Default` implementation
5//! - `Clone` implementation
6//! - `Debug` implementation
7//! - `log_norm` for computing the logarithm of the Frobenius norm
8//! - `norm`, `norm_squared` for computing the Frobenius norm
9//! - `inner` for computing inner products of two TreeTNs
10//! - `to_dense` for contracting to a single tensor
11//! - `evaluate` for evaluating at specific index values
12//! - `evaluate_at` for evaluating using `Index` objects instead of raw IDs
13//! - `all_site_indices` for retrieving all site indices and their owning vertices
14
15use std::collections::{HashMap, HashSet};
16use std::hash::Hash;
17
18use tensor4all_core::{AllowedPairs, AnyScalar, ColMajorArrayRef, IndexLike, TensorLike};
19
20use super::TreeTN;
21
22// ============================================================================
23// Default implementation
24// ============================================================================
25
impl<T, V> Default for TreeTN<T, V>
where
    T: TensorLike,
    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
{
    /// Returns an empty `TreeTN`, equivalent to calling [`TreeTN::new`].
    fn default() -> Self {
        Self::new()
    }
}
35
36// ============================================================================
37// Clone implementation
38// ============================================================================
39
40impl<T, V> Clone for TreeTN<T, V>
41where
42    T: TensorLike,
43    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
44{
45    fn clone(&self) -> Self {
46        Self {
47            graph: self.graph.clone(),
48            canonical_region: self.canonical_region.clone(),
49            canonical_form: self.canonical_form,
50            site_index_network: self.site_index_network.clone(),
51            link_index_network: self.link_index_network.clone(),
52            ortho_towards: self.ortho_towards.clone(),
53        }
54    }
55}
56
57// ============================================================================
58// Debug implementation
59// ============================================================================
60
impl<T, V> std::fmt::Debug for TreeTN<T, V>
where
    T: TensorLike,
    V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
{
    /// Formats a compact summary (node/edge counts and the canonical region)
    /// rather than the full tensor contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TreeTN")
            .field("node_count", &self.node_count())
            .field("edge_count", &self.edge_count())
            .field("canonical_region", &self.canonical_region)
            // `finish_non_exhaustive` prints `..` to signal omitted fields.
            .finish_non_exhaustive()
    }
}
74
75// ============================================================================
76// Norm Computation
77// ============================================================================
78
79use anyhow::{Context, Result};
80
81use crate::algorithm::CanonicalForm;
82use crate::CanonicalizationOptions;
83
84impl<T, V> TreeTN<T, V>
85where
86    T: TensorLike,
87    V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug,
88{
    /// Compute log(||TreeTN||_F), the log of the Frobenius norm.
    ///
    /// Uses canonicalization to avoid numerical overflow:
    /// when canonicalized to a single site with Unitary form,
    /// the Frobenius norm of the whole network equals the norm of the center tensor.
    ///
    /// # Note
    /// This method is mutable because it may need to canonicalize the network
    /// to a single Unitary center. If the network is already Unitary-canonicalized
    /// to a single site, no extra canonicalization work is performed.
    ///
    /// # Returns
    /// The natural logarithm of the Frobenius norm.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The network is empty
    /// - Canonicalization fails
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_treetn::TreeTN;
    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
    ///
    /// let s = DynIndex::new_dyn(2);
    /// let t = TensorDynLen::from_dense(vec![s], vec![3.0_f64, 4.0]).unwrap();
    /// let mut tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
    ///
    /// // log(||[3, 4]||) = log(5)
    /// let ln = tn.log_norm().unwrap();
    /// assert!((ln - 5.0_f64.ln()).abs() < 1e-10);
    /// ```
    pub fn log_norm(&mut self) -> Result<f64> {
        let n = self.node_count();
        if n == 0 {
            return Err(anyhow::anyhow!("Cannot compute log_norm of empty TreeTN"))
                .context("log_norm: network must have at least one node");
        }

        // Determine the single center site (by name)
        let center_name: V =
            if self.is_canonicalized() && self.canonical_form() == Some(CanonicalForm::Unitary) {
                if self.canonical_region.len() == 1 {
                    // Already Unitary canonicalized to single site - use it
                    self.canonical_region.iter().next().unwrap().clone()
                } else {
                    // Unitary canonicalized to multiple sites - canonicalize to min site
                    // (`min` gives a deterministic choice of center).
                    let min_center = self.canonical_region.iter().min().unwrap().clone();
                    self.canonicalize_mut(
                        std::iter::once(min_center.clone()),
                        CanonicalizationOptions::default(),
                    )
                    .context("log_norm: failed to canonicalize to single site")?;
                    min_center
                }
            } else {
                // Not canonicalized or not Unitary - canonicalize to min node name
                let min_node_name = self
                    .node_names()
                    .into_iter()
                    .min()
                    .ok_or_else(|| anyhow::anyhow!("No nodes in TreeTN"))
                    .context("log_norm: network must have nodes")?;
                self.canonicalize_mut(
                    std::iter::once(min_node_name.clone()),
                    CanonicalizationOptions::default(),
                )
                .context("log_norm: failed to canonicalize")?;
                min_node_name
            };

        // Get center node index and tensor
        let center_node = self
            .node_index(&center_name)
            .ok_or_else(|| anyhow::anyhow!("Center node not found"))
            .context("log_norm: center node must exist")?;

        let center_tensor = self
            .tensor(center_node)
            .ok_or_else(|| anyhow::anyhow!("Center tensor not found"))
            .context("log_norm: center tensor must exist")?;

        // With a single Unitary center, all other tensors contract to the
        // identity, so the network norm equals the center tensor's norm.
        let norm_sq = center_tensor.norm_squared();
        let norm = norm_sq.sqrt();

        Ok(norm.ln())
    }
177
178    /// Compute the Frobenius norm of the TreeTN.
179    ///
180    /// Uses `log_norm` internally: `norm = exp(log_norm)`.
181    ///
182    /// # Note
183    /// This method is mutable because it may need to canonicalize the network.
184    ///
185    /// # Errors
186    /// Returns an error if the network is empty or canonicalization fails.
187    ///
188    /// # Examples
189    ///
190    /// ```
191    /// use tensor4all_treetn::TreeTN;
192    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
193    ///
194    /// // Single-node TreeTN with tensor [1, 0, 0, 1] (identity 2x2)
195    /// let s0 = DynIndex::new_dyn(2);
196    /// let s1 = DynIndex::new_dyn(2);
197    /// let t = TensorDynLen::from_dense(
198    ///     vec![s0.clone(), s1.clone()],
199    ///     vec![1.0_f64, 0.0, 0.0, 1.0],
200    /// ).unwrap();
201    ///
202    /// let mut tn = TreeTN::<_, String>::from_tensors(
203    ///     vec![t],
204    ///     vec!["A".to_string()],
205    /// ).unwrap();
206    ///
207    /// // Frobenius norm of [[1,0],[0,1]] = sqrt(2)
208    /// let n = tn.norm().unwrap();
209    /// assert!((n - 2.0_f64.sqrt()).abs() < 1e-10);
210    /// ```
211    pub fn norm(&mut self) -> Result<f64> {
212        let log_n = self
213            .log_norm()
214            .context("norm: failed to compute log_norm")?;
215        Ok(log_n.exp())
216    }
217
218    /// Compute the squared Frobenius norm of the TreeTN.
219    ///
220    /// Returns `||self||^2 = norm()^2`.
221    ///
222    /// # Note
223    /// This method is mutable because it may need to canonicalize the network.
224    ///
225    /// # Errors
226    /// Returns an error if the network is empty or canonicalization fails.
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use tensor4all_treetn::TreeTN;
232    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
233    ///
234    /// let s = DynIndex::new_dyn(2);
235    /// let t = TensorDynLen::from_dense(vec![s], vec![3.0_f64, 4.0]).unwrap();
236    /// let mut tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
237    ///
238    /// // ||[3, 4]||^2 = 9 + 16 = 25
239    /// let nsq = tn.norm_squared().unwrap();
240    /// assert!((nsq - 25.0).abs() < 1e-10);
241    /// ```
242    pub fn norm_squared(&mut self) -> Result<f64> {
243        let n = self
244            .norm()
245            .context("norm_squared: failed to compute norm")?;
246        Ok(n * n)
247    }
248
249    /// Scale the tensor network by a complex scalar.
250    ///
251    /// This multiplies a single node tensor, chosen deterministically as the
252    /// minimum-named node, so the represented state is scaled once rather than
253    /// applying `scalar^n` across all nodes.
254    ///
255    /// Scaling a non-center tensor generally invalidates any existing
256    /// canonicalization metadata, so this method clears the cached canonical
257    /// region and orthogonality directions after updating the tensor.
258    ///
259    /// # Arguments
260    /// * `scalar` - Scalar multiplier applied to the represented tensor network
261    ///
262    /// # Returns
263    /// `Ok(())` after the selected node tensor has been updated in place
264    ///
265    /// # Errors
266    /// Returns an error if the TreeTN is empty or the selected node/tensor
267    /// cannot be found
268    ///
269    /// # Examples
270    ///
271    /// ```
272    /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen, TensorIndex, TensorLike};
273    /// use tensor4all_treetn::TreeTN;
274    ///
275    /// let s = DynIndex::new_dyn(2);
276    /// let t = TensorDynLen::from_dense(vec![s], vec![1.0_f64, -2.0]).unwrap();
277    /// let mut tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
278    ///
279    /// tn.scale(AnyScalar::new_real(2.0)).unwrap();
280    ///
281    /// let dense = tn.to_dense().unwrap();
282    /// let expected = TensorDynLen::from_dense(
283    ///     dense.external_indices(),
284    ///     vec![2.0_f64, -4.0],
285    /// ).unwrap();
286    /// assert!((&dense - &expected).maxabs() < 1e-12);
287    /// ```
288    pub fn scale(&mut self, scalar: AnyScalar) -> Result<()> {
289        let min_node = self
290            .node_names()
291            .into_iter()
292            .min()
293            .ok_or_else(|| anyhow::anyhow!("Cannot scale empty TreeTN"))
294            .context("scale: network must have at least one node")?;
295        let node_idx = self
296            .node_index(&min_node)
297            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", min_node))
298            .context("scale: selected node must exist")?;
299        let tensor = self
300            .tensor(node_idx)
301            .ok_or_else(|| anyhow::anyhow!("Node tensor not found for {:?}", min_node))
302            .context("scale: selected node tensor must exist")?
303            .clone();
304        let scaled = tensor
305            .scale(scalar)
306            .context("scale: tensor scaling failed")?;
307        self.replace_tensor(node_idx, scaled)?
308            .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", min_node))
309            .context("scale: failed to replace scaled tensor")?;
310
311        self.clear_canonical_region();
312        self.ortho_towards.clear();
313
314        Ok(())
315    }
316
    /// Compute the inner product of two TreeTNs.
    ///
    /// Computes `<self | other>` = sum over all indices of `conj(self) * other`.
    ///
    /// Both TreeTNs must have the same site indices (same IDs).
    /// Link indices may differ between the two TreeTNs.
    ///
    /// # Algorithm
    /// 1. Replace link indices in `other` with fresh IDs to avoid collision.
    /// 2. At each node, contract `conj(self_tensor) * other_tensor` pairwise.
    /// 3. Sweep from leaves to root, contracting the environment.
    ///
    /// This is equivalent to contracting the entire network
    /// `conj(self) * other` into a scalar.
    ///
    /// # Errors
    /// Returns an error if the networks have incompatible topologies.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_treetn::TreeTN;
    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike};
    ///
    /// let s = DynIndex::new_dyn(2);
    /// let t = TensorDynLen::from_dense(vec![s], vec![3.0_f64, 4.0]).unwrap();
    /// let tn = TreeTN::<_, usize>::from_tensors(vec![t], vec![0]).unwrap();
    ///
    /// // <v|v> = 3^2 + 4^2 = 25
    /// let ip = tn.inner(&tn).unwrap();
    /// assert!((ip.real() - 25.0).abs() < 1e-10);
    /// ```
    pub fn inner(&self, other: &Self) -> Result<AnyScalar>
    where
        <T::Index as IndexLike>::Id:
            Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
    {
        // The inner product of two empty networks is defined as 0 here.
        if self.node_count() == 0 && other.node_count() == 0 {
            return Ok(AnyScalar::new_real(0.0));
        }
        if !self.share_equivalent_site_index_network(other) {
            return Err(anyhow::anyhow!(
                "inner: TreeTNs must have the same topology and site indices"
            ));
        }

        // Pick the minimum node name as the root so the sweep is deterministic.
        let root_name = self
            .node_names()
            .into_iter()
            .min()
            .ok_or_else(|| anyhow::anyhow!("Cannot compute inner product of empty TreeTN"))
            .context("inner: network must have at least one node")?;
        // Re-ID `other`'s internal (link) indices so they cannot collide
        // with `self`'s link indices during contraction.
        let other_sim = other.sim_internal_inds();

        let post_order = self
            .site_index_network()
            .post_order_dfs(&root_name)
            .ok_or_else(|| anyhow::anyhow!("Root node {:?} not found", root_name))
            .context("inner: failed to build post-order traversal")?;

        // Build parent pointers with an explicit DFS so that, for each node,
        // we can later find its children (the nodes whose environments it
        // must absorb). The root maps to `None`.
        let mut parent_of: HashMap<V, Option<V>> = HashMap::new();
        parent_of.insert(root_name.clone(), None);
        let mut stack = vec![root_name.clone()];
        while let Some(node_name) = stack.pop() {
            // Sort neighbors for a deterministic traversal order.
            let mut neighbors: Vec<V> = self.site_index_network().neighbors(&node_name).collect();
            neighbors.sort();
            for neighbor in neighbors {
                if parent_of.contains_key(&neighbor) {
                    continue;
                }
                parent_of.insert(neighbor.clone(), Some(node_name.clone()));
                stack.push(neighbor);
            }
        }

        // Per-node environments: the fully contracted subtree below each node.
        let mut envs: HashMap<V, T> = HashMap::new();

        // Post-order guarantees children are processed before their parent,
        // so every child environment exists when the parent needs it.
        for node_name in post_order {
            let node_idx_self = self
                .node_index(&node_name)
                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in self", node_name))
                .context("inner: self node must exist")?;
            let node_idx_other = other_sim
                .node_index(&node_name)
                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in other", node_name))
                .context("inner: other node must exist")?;

            // Start from the bra tensor: conj(self) at this node.
            let mut env = self
                .tensor(node_idx_self)
                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))
                .context("inner: self tensor must exist")?
                .conj();

            // Collect this node's children from the parent map.
            let mut children: Vec<V> = parent_of
                .iter()
                .filter_map(|(child, parent)| {
                    if parent.as_ref() == Some(&node_name) {
                        Some(child.clone())
                    } else {
                        None
                    }
                })
                .collect();
            children.sort();

            // Absorb each child's environment (removing it from the map so
            // leftover entries can be detected as an error below).
            for child_name in children {
                let child_env = envs.remove(&child_name).ok_or_else(|| {
                    anyhow::anyhow!(
                        "Missing child environment for child {:?} of node {:?}",
                        child_name,
                        node_name
                    )
                })?;
                env = T::contract(&[&env, &child_env], AllowedPairs::All)
                    .context("inner: failed to absorb child environment")?;
            }

            // Contract with the ket tensor (other) at this node.
            let other_tensor = other_sim
                .tensor(node_idx_other)
                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))
                .context("inner: other tensor must exist")?;
            env = T::contract(&[&env, other_tensor], AllowedPairs::All)
                .context("inner: failed to contract node bra-ket tensors")?;

            envs.insert(node_name, env);
        }

        // After the sweep, only the root environment (a scalar tensor) remains.
        let result_tensor = envs
            .remove(&root_name)
            .ok_or_else(|| anyhow::anyhow!("Root environment was not produced"))
            .context("inner: root contraction failed")?;
        if !envs.is_empty() {
            return Err(anyhow::anyhow!(
                "inner: contraction left {} dangling environments",
                envs.len()
            ));
        }

        // Extract the scalar via an inner product with the scalar one tensor.
        let scalar_one = T::scalar_one().context("inner: failed to create scalar_one")?;
        scalar_one
            .inner_product(&result_tensor)
            .context("inner: failed to extract scalar value")
    }
460
461    /// Convert the TreeTN to a single dense tensor.
462    ///
463    /// This contracts all tensors in the network along their link/bond indices,
464    /// producing a single tensor with only site (physical) indices.
465    ///
466    /// This is an alias for `contract_to_tensor()`.
467    ///
468    /// # Warning
469    /// This operation can be very expensive for large networks,
470    /// as the result size grows exponentially with the number of sites.
471    ///
472    /// # Errors
473    /// Returns an error if the network is empty or contraction fails.
474    ///
475    /// # Examples
476    ///
477    /// ```
478    /// use tensor4all_treetn::TreeTN;
479    /// use tensor4all_core::{DynIndex, TensorDynLen, TensorIndex, TensorLike};
480    ///
481    /// // Build a 2-node chain
482    /// let s0 = DynIndex::new_dyn(2);
483    /// let bond = DynIndex::new_dyn(2);
484    /// let s1 = DynIndex::new_dyn(2);
485    ///
486    /// // Identity matrices
487    /// let t0 = TensorDynLen::from_dense(
488    ///     vec![s0.clone(), bond.clone()],
489    ///     vec![1.0_f64, 0.0, 0.0, 1.0],
490    /// ).unwrap();
491    /// let t1 = TensorDynLen::from_dense(
492    ///     vec![bond.clone(), s1.clone()],
493    ///     vec![1.0_f64, 0.0, 0.0, 1.0],
494    /// ).unwrap();
495    ///
496    /// let tn = TreeTN::<_, String>::from_tensors(
497    ///     vec![t0, t1],
498    ///     vec!["A".to_string(), "B".to_string()],
499    /// ).unwrap();
500    ///
501    /// // Contract to a single dense tensor over site indices s0 and s1
502    /// let dense = tn.to_dense().unwrap();
503    /// // Result is rank-2 (two site indices s0 and s1)
504    /// assert_eq!(dense.num_external_indices(), 2);
505    /// ```
506    pub fn to_dense(&self) -> Result<T> {
507        self.contract_to_tensor()
508            .context("to_dense: failed to contract network to tensor")
509    }
510
511    /// Returns all site index IDs and their owning vertex names.
512    ///
513    /// Returns `(index_ids, vertex_names)` where `index_ids[i]` belongs to
514    /// vertex `vertex_names[i]`. Order is unspecified but consistent
515    /// between the two vectors.
516    ///
517    /// For [`evaluate()`](Self::evaluate), pass `index_ids` and arrange
518    /// values in the same order.
519    #[allow(clippy::type_complexity)]
520    pub fn all_site_index_ids(&self) -> Result<(Vec<<T::Index as IndexLike>::Id>, Vec<V>)>
521    where
522        V: Clone,
523        <T::Index as IndexLike>::Id: Clone,
524    {
525        let mut ids = Vec::new();
526        let mut vertex_names = Vec::new();
527        for node_name in self.node_names() {
528            let site_space = self
529                .site_space(&node_name)
530                .ok_or_else(|| anyhow::anyhow!("Site space not found for node {:?}", node_name))
531                .context("all_site_index_ids: site space must exist")?;
532            for index in site_space {
533                ids.push(index.id().clone());
534                vertex_names.push(node_name.clone());
535            }
536        }
537        Ok((ids, vertex_names))
538    }
539
540    /// Evaluate the TreeTN at multiple multi-indices (batch).
541    ///
542    /// # Arguments
543    /// * `index_ids` - Identifies each site index by its ID (from
544    ///   [`all_site_index_ids()`](Self::all_site_index_ids)).
545    ///   Must enumerate every site index exactly once.
546    /// * `values` - Column-major array of shape `[n_indices, n_points]`.
547    ///   `values.get(&[i, p])` is the value of `index_ids[i]` at point `p`.
548    ///
549    /// # Returns
550    /// A `Vec<AnyScalar>` of length `n_points`.
551    ///
552    /// # Errors
553    /// Returns an error if:
554    /// - The network is empty
555    /// - `values` shape is inconsistent with `index_ids`
556    /// - An index ID is unknown
557    /// - Index values are out of bounds
558    /// - Contraction fails
559    pub fn evaluate(
560        &self,
561        index_ids: &[<T::Index as IndexLike>::Id],
562        values: ColMajorArrayRef<'_, usize>,
563    ) -> Result<Vec<AnyScalar>>
564    where
565        <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
566    {
567        if self.node_count() == 0 {
568            return Err(anyhow::anyhow!("Cannot evaluate empty TreeTN"))
569                .context("evaluate: network must have at least one node");
570        }
571
572        let n_indices = index_ids.len();
573        anyhow::ensure!(
574            values.shape().len() == 2,
575            "evaluate: values must be 2D, got {}D",
576            values.shape().len()
577        );
578        anyhow::ensure!(
579            values.shape()[0] == n_indices,
580            "evaluate: values.shape()[0] ({}) != index_ids.len() ({})",
581            values.shape()[0],
582            n_indices
583        );
584        let n_points = values.shape()[1];
585
586        // Build index_id -> position lookup (Vec-based linear scan is fine for
587        // the small number of site indices typical in practice).
588        let mut known_ids: HashSet<<T::Index as IndexLike>::Id> = HashSet::new();
589        let mut total_site_indices: usize = 0;
590        for node_name in self.node_names() {
591            let site_space = self
592                .site_space(&node_name)
593                .ok_or_else(|| anyhow::anyhow!("Site space not found for node {:?}", node_name))
594                .context("evaluate: site space must exist")?;
595            for index in site_space {
596                known_ids.insert(index.id().clone());
597                total_site_indices += 1;
598            }
599        }
600
601        // Validate: index_ids.len() must equal total number of site indices.
602        anyhow::ensure!(
603            n_indices == total_site_indices,
604            "evaluate: index_ids.len() ({}) != total site indices ({})",
605            n_indices,
606            total_site_indices
607        );
608
609        // Validate: no duplicate index IDs.
610        {
611            let mut seen = HashSet::with_capacity(n_indices);
612            for id in index_ids {
613                anyhow::ensure!(seen.insert(id), "evaluate: duplicate index ID {:?}", id);
614            }
615        }
616
617        // Validate: all provided IDs must be known (exist in the network).
618        for id in index_ids {
619            anyhow::ensure!(
620                known_ids.contains(id),
621                "evaluate: unknown index ID {:?}",
622                id
623            );
624        }
625
626        // Pre-compute per-node data: (node_name, node_index, tensor_ref,
627        //   site_entries: Vec<(Index, position_in_index_ids)>)
628        // This avoids HashMap lookups and repeated node_index/tensor lookups
629        // inside the per-point loop.
630        struct NodeEntry<'a, T: TensorLike, V> {
631            name: V,
632            tensor: &'a T,
633            /// (site_index, position in `index_ids`)
634            site_entries: Vec<(T::Index, usize)>,
635        }
636
637        let node_names = self.node_names();
638        let mut node_entries: Vec<NodeEntry<'_, T, V>> = Vec::with_capacity(node_names.len());
639
640        for node_name in &node_names {
641            let node_idx = self
642                .node_index(node_name)
643                .ok_or_else(|| anyhow::anyhow!("Node {:?} not found", node_name))
644                .context("evaluate: node must exist")?;
645
646            let tensor = self
647                .tensor(node_idx)
648                .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name))
649                .context("evaluate: tensor must exist")?;
650
651            let site_space = self.site_space(node_name);
652            let mut site_entries = Vec::new();
653            if let Some(space) = site_space {
654                for index in space {
655                    let id = index.id();
656                    let pos = index_ids
657                        .iter()
658                        .position(|x| x == id)
659                        .ok_or_else(|| anyhow::anyhow!("Index ID {:?} not found in index_ids", id))
660                        .context("evaluate: all site indices must be covered by index_ids")?;
661                    site_entries.push((index.clone(), pos));
662                }
663            }
664
665            node_entries.push(NodeEntry {
666                name: node_name.clone(),
667                tensor,
668                site_entries,
669            });
670        }
671
672        let mut results = Vec::with_capacity(n_points);
673        for point in 0..n_points {
674            let mut contracted_tensors: Vec<T> = Vec::with_capacity(node_entries.len());
675            let mut contracted_names: Vec<V> = Vec::with_capacity(node_entries.len());
676
677            for entry in &node_entries {
678                if entry.site_entries.is_empty() {
679                    // No site indices - just use the tensor as is
680                    contracted_tensors.push(entry.tensor.clone());
681                    contracted_names.push(entry.name.clone());
682                    continue;
683                }
684
685                let index_vals: Vec<(T::Index, usize)> = entry
686                    .site_entries
687                    .iter()
688                    .map(|(idx, pos)| {
689                        let val = *values.get(&[*pos, point]).unwrap();
690                        (idx.clone(), val)
691                    })
692                    .collect();
693
694                let onehot =
695                    T::onehot(&index_vals).context("evaluate: failed to create one-hot tensor")?;
696
697                let result =
698                    T::contract(&[entry.tensor, &onehot], tensor4all_core::AllowedPairs::All)
699                        .context("evaluate: failed to contract tensor with one-hot")?;
700
701                contracted_tensors.push(result);
702                contracted_names.push(entry.name.clone());
703            }
704
705            // Build a temporary TreeTN from the contracted tensors and contract to scalar
706            let temp_tn = TreeTN::<T, V>::from_tensors(contracted_tensors, contracted_names)
707                .context("evaluate: failed to build temporary TreeTN")?;
708            let result_tensor = temp_tn
709                .contract_to_tensor()
710                .context("evaluate: failed to contract to scalar")?;
711
712            let scalar_one = T::scalar_one().context("evaluate: failed to create scalar_one")?;
713            let scalar = scalar_one
714                .inner_product(&result_tensor)
715                .context("evaluate: failed to extract scalar value")?;
716            results.push(scalar);
717        }
718
719        Ok(results)
720    }
721
722    /// Returns all site indices and their owning vertex names.
723    ///
724    /// Returns `(indices, vertex_names)` where `indices[i]` belongs to
725    /// vertex `vertex_names[i]`. Order is unspecified but consistent
726    /// between the two vectors.
727    ///
728    /// This is the `Index`-based counterpart of
729    /// [`all_site_index_ids()`](Self::all_site_index_ids), returning
730    /// full `Index` objects instead of raw IDs.
731    ///
732    /// # Errors
733    /// Returns an error if a node's site space cannot be found.
734    ///
735    /// # Examples
736    /// ```
737    /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorLike};
738    /// use tensor4all_treetn::TreeTN;
739    ///
740    /// let s0 = DynIndex::new_dyn(2);
741    /// let bond = DynIndex::new_dyn(3);
742    /// let s1 = DynIndex::new_dyn(2);
743    /// let t0 = TensorDynLen::from_dense(
744    ///     vec![s0.clone(), bond.clone()], vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
745    /// ).unwrap();
746    /// let t1 = TensorDynLen::from_dense(
747    ///     vec![bond.clone(), s1.clone()], vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
748    /// ).unwrap();
749    /// let tn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0, t1], vec![0, 1]).unwrap();
750    ///
751    /// let (indices, vertices) = tn.all_site_indices().unwrap();
752    /// assert_eq!(indices.len(), 2);
753    /// assert_eq!(vertices.len(), 2);
754    ///
755    /// // The returned indices contain both s0 and s1
756    /// let id_set: std::collections::HashSet<_> = indices.iter().map(|i| *i.id()).collect();
757    /// assert!(id_set.contains(s0.id()));
758    /// assert!(id_set.contains(s1.id()));
759    /// ```
760    #[allow(clippy::type_complexity)]
761    pub fn all_site_indices(&self) -> Result<(Vec<T::Index>, Vec<V>)>
762    where
763        V: Clone,
764        T::Index: Clone,
765    {
766        let mut indices = Vec::new();
767        let mut node_names = Vec::new();
768        for node_name in self.node_names() {
769            let site_space = self
770                .site_space(&node_name)
771                .ok_or_else(|| anyhow::anyhow!("Site space not found for node {:?}", node_name))
772                .context("all_site_indices: site space must exist")?;
773            for index in site_space {
774                indices.push(index.clone());
775                node_names.push(node_name.clone());
776            }
777        }
778        Ok((indices, node_names))
779    }
780
781    /// Evaluate the TreeTN at multiple multi-indices (batch), using
782    /// `Index` objects instead of raw IDs.
783    ///
784    /// This is a convenience wrapper around [`evaluate()`](Self::evaluate)
785    /// that accepts `&[T::Index]` directly, extracting the IDs
786    /// internally.
787    ///
788    /// # Arguments
789    /// * `indices` - Identifies each site index by its `Index` object
790    ///   (e.g. from [`all_site_indices()`](Self::all_site_indices)).
791    ///   Must enumerate every site index exactly once.
792    /// * `values` - Column-major array of shape `[n_indices, n_points]`.
793    ///   `values.get(&[i, p])` is the value of `indices[i]` at point `p`.
794    ///
795    /// # Returns
796    /// A `Vec<AnyScalar>` of length `n_points`.
797    ///
798    /// # Errors
799    /// Returns an error if the underlying [`evaluate()`](Self::evaluate)
800    /// call fails (see its documentation for details).
801    ///
802    /// # Examples
803    /// ```
804    /// use tensor4all_core::{ColMajorArrayRef, DynIndex, IndexLike, TensorDynLen, TensorLike};
805    /// use tensor4all_treetn::TreeTN;
806    ///
807    /// let s0 = DynIndex::new_dyn(3);
808    /// let t0 = TensorDynLen::from_dense(vec![s0.clone()], vec![10.0, 20.0, 30.0]).unwrap();
809    /// let tn = TreeTN::<TensorDynLen, usize>::from_tensors(vec![t0], vec![0]).unwrap();
810    ///
811    /// let (indices, _vertices) = tn.all_site_indices().unwrap();
812    ///
813    /// // Evaluate at index value 2
814    /// let data = [2usize];
815    /// let shape = [indices.len(), 1];
816    /// let values = ColMajorArrayRef::new(&data, &shape);
817    /// let result = tn.evaluate_at(&indices, values).unwrap();
818    /// assert!((result[0].real() - 30.0).abs() < 1e-10);
819    /// ```
820    pub fn evaluate_at(
821        &self,
822        indices: &[T::Index],
823        values: ColMajorArrayRef<'_, usize>,
824    ) -> Result<Vec<AnyScalar>>
825    where
826        <T::Index as IndexLike>::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
827    {
828        let index_ids: Vec<_> = indices.iter().map(|idx| idx.id().clone()).collect();
829        self.evaluate(&index_ids, values)
830    }
831}
832
833#[cfg(test)]
834mod tests;