tensor4all_core/
tensor_like.rs

1//! TensorLike trait for unifying tensor types.
2//!
3//! This module provides a fully generic trait for tensor-like objects that expose
4//! external indices and support contraction operations.
5//!
6//! # Design
7//!
8//! The trait is **fully generic** (monomorphic), meaning:
9//! - No trait objects (`dyn TensorLike`)
10//! - Uses associated type for `Index`
11//! - All methods return `Self` instead of concrete types
12//!
13//! For heterogeneous tensor collections, use an enum wrapper.
14
15use crate::any_scalar::AnyScalar;
16use crate::tensor_index::TensorIndex;
17use anyhow::Result;
18use std::fmt::Debug;
19
20// ============================================================================
21// Factorization types (non-generic, algorithm-specific)
22// ============================================================================
23
24use thiserror::Error;
25
/// Error type for factorize operations.
///
/// Wraps failures from every backend algorithm (SVD, QR, LU, CI) as well as
/// option-validation failures, so [`TensorLike::factorize`] exposes a single
/// error type regardless of which algorithm was selected.
#[derive(Debug, Error)]
pub enum FactorizeError {
    /// Factorization computation failed.
    #[error("Factorization failed: {0}")]
    ComputationError(
        /// The underlying error
        #[from]
        anyhow::Error,
    ),
    /// Invalid relative tolerance value (must be finite and non-negative).
    #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
    InvalidRtol(
        /// The invalid rtol value
        f64,
    ),
    /// The storage type is not supported for this operation.
    #[error("Unsupported storage type: {0}")]
    UnsupportedStorage(
        /// Description of the unsupported storage type
        &'static str,
    ),
    /// The canonical direction is not supported for this algorithm.
    ///
    /// For example, QR supports only [`Canonical::Left`] (an LQ decomposition
    /// would be needed for the right-canonical case).
    #[error("Unsupported canonical direction for this algorithm: {0}")]
    UnsupportedCanonical(
        /// Description of the unsupported canonical direction
        &'static str,
    ),
    /// Error from SVD operation.
    #[error("SVD error: {0}")]
    SvdError(
        /// The underlying SVD error
        #[from]
        crate::svd::SvdError,
    ),
    /// Error from QR operation.
    #[error("QR error: {0}")]
    QrError(
        /// The underlying QR error
        #[from]
        crate::qr::QrError,
    ),
    /// Error from matrix CI operation.
    #[error("Matrix CI error: {0}")]
    MatrixCIError(
        /// The underlying matrix CI error
        #[from]
        tensor4all_tcicore::MatrixCIError,
    ),
}
76
/// Factorization algorithm.
///
/// Selected via [`FactorizeOptions::alg`]; defaults to [`FactorizeAlg::SVD`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FactorizeAlg {
    /// Singular Value Decomposition.
    #[default]
    SVD,
    /// QR decomposition.
    QR,
    /// Rank-revealing LU decomposition.
    LU,
    /// Cross Interpolation (LU-based).
    CI,
}
90
/// Canonical direction for factorization.
///
/// This determines which factor is "canonical" (orthogonal for SVD/QR,
/// or unit-diagonal for LU/CI). Defaults to [`Canonical::Left`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Canonical {
    /// Left factor is canonical.
    /// - SVD: L=U (orthogonal), R=S*V
    /// - QR: L=Q (orthogonal), R=R
    /// - LU/CI: L has unit diagonal
    #[default]
    Left,
    /// Right factor is canonical.
    /// - SVD: L=U*S, R=V (orthogonal)
    /// - QR: Not supported (would need LQ)
    /// - LU/CI: U has unit diagonal
    Right,
}
109
/// Options for tensor factorization.
///
/// Construct via [`Default`] or the per-algorithm constructors
/// (`FactorizeOptions::svd()`, `qr()`, `lu()`, `ci()`), then refine with the
/// builder-style `with_*` methods.
#[derive(Debug, Clone)]
pub struct FactorizeOptions {
    /// Factorization algorithm to use.
    pub alg: FactorizeAlg,
    /// Canonical direction.
    pub canonical: Canonical,
    /// Relative tolerance for truncation.
    /// If `None`, uses the algorithm's default.
    pub rtol: Option<f64>,
    /// Maximum rank for truncation.
    /// If `None`, no rank limit is applied.
    pub max_rank: Option<usize>,
}
124
125impl Default for FactorizeOptions {
126    fn default() -> Self {
127        Self {
128            alg: FactorizeAlg::SVD,
129            canonical: Canonical::Left,
130            rtol: None,
131            max_rank: None,
132        }
133    }
134}
135
136impl FactorizeOptions {
137    /// Create options for SVD factorization.
138    pub fn svd() -> Self {
139        Self {
140            alg: FactorizeAlg::SVD,
141            ..Default::default()
142        }
143    }
144
145    /// Create options for QR factorization.
146    pub fn qr() -> Self {
147        Self {
148            alg: FactorizeAlg::QR,
149            ..Default::default()
150        }
151    }
152
153    /// Create options for LU factorization.
154    pub fn lu() -> Self {
155        Self {
156            alg: FactorizeAlg::LU,
157            ..Default::default()
158        }
159    }
160
161    /// Create options for CI factorization.
162    pub fn ci() -> Self {
163        Self {
164            alg: FactorizeAlg::CI,
165            ..Default::default()
166        }
167    }
168
169    /// Set canonical direction.
170    pub fn with_canonical(mut self, canonical: Canonical) -> Self {
171        self.canonical = canonical;
172        self
173    }
174
175    /// Set relative tolerance.
176    pub fn with_rtol(mut self, rtol: f64) -> Self {
177        self.rtol = Some(rtol);
178        self
179    }
180
181    /// Set maximum rank.
182    pub fn with_max_rank(mut self, max_rank: usize) -> Self {
183        self.max_rank = Some(max_rank);
184        self
185    }
186}
187
/// Result of tensor factorization.
///
/// Generic over the tensor type `T`.
///
/// Produced by [`TensorLike::factorize`]. `left` and `right` share the newly
/// created `bond_index`, so contracting them over that bond reconstructs
/// (an approximation of) the original tensor.
#[derive(Debug, Clone)]
pub struct FactorizeResult<T: TensorLike> {
    /// Left factor tensor.
    pub left: T,
    /// Right factor tensor.
    pub right: T,
    /// Bond index connecting left and right factors.
    pub bond_index: T::Index,
    /// Singular values (only for SVD).
    pub singular_values: Option<Vec<f64>>,
    /// Rank of the factorization.
    pub rank: usize,
}
204
205// ============================================================================
206// Contraction types
207// ============================================================================
208
/// Specifies which tensor pairs are allowed to contract.
///
/// This enum controls which tensor pairs can have their indices contracted
/// in multi-tensor contraction operations. This is useful for tensor networks
/// where the graph structure determines which tensors are connected.
///
/// # Example
///
/// ```ignore
/// use tensor4all_core::{TensorLike, AllowedPairs};
///
/// // Contract all contractable index pairs (default behavior)
/// let tensor_refs: Vec<&T> = tensors.iter().collect();
/// let result = T::contract(&tensor_refs, AllowedPairs::All)?;
///
/// // Only contract indices between specified tensor pairs
/// let edges = vec![(0, 1), (1, 2)];  // tensor 0-1 and tensor 1-2
/// let result = T::contract(&tensor_refs, AllowedPairs::Specified(&edges))?;
/// ```
#[derive(Debug, Clone, Copy)]
pub enum AllowedPairs<'a> {
    /// All tensor pairs are allowed to contract.
    ///
    /// Indices with matching IDs across any two tensors will be contracted.
    /// This is the default behavior, equivalent to ITensor's `*` operator.
    All,
    /// Only specified tensor pairs are allowed to contract.
    ///
    /// Each pair is `(tensor_idx_a, tensor_idx_b)` into the input tensor slice.
    /// Indices are only contracted if they belong to an allowed pair.
    ///
    /// This is useful for tensor networks where the graph structure
    /// determines which tensors are connected (e.g., TreeTN edges).
    Specified(&'a [(usize, usize)]),
}
244
245// ============================================================================
246// TensorLike trait (fully generic)
247// ============================================================================
248
/// Trait for tensor-like objects that expose external indices and support contraction.
///
/// This trait is **fully generic** (monomorphic), meaning it does not support
/// trait objects (`dyn TensorLike`). For heterogeneous tensor collections,
/// use an enum wrapper instead.
///
/// # Design Principles
///
/// - **Minimal interface**: Only external indices and automatic contraction
/// - **Fully generic**: Uses associated type for `Index`, returns `Self`
/// - **Stable ordering**: `external_indices()` returns indices in deterministic order
/// - **No trait objects**: Requires `Sized`, cannot use `dyn TensorLike`
///
/// # Example
///
/// ```ignore
/// use tensor4all_core::{TensorLike, AllowedPairs};
///
/// fn contract_pair<T: TensorLike>(a: &T, b: &T) -> Result<T> {
///     T::contract(&[a, b], AllowedPairs::All)
/// }
/// ```
///
/// # Heterogeneous Collections
///
/// For mixing different tensor types, define an enum:
///
/// ```ignore
/// enum TensorNetwork {
///     Dense(TensorDynLen),
///     MPS(MatrixProductState),
/// }
/// ```
///
/// # Supertrait
///
/// `TensorLike` extends `TensorIndex`, which provides:
/// - `external_indices()` - Get all external indices
/// - `num_external_indices()` - Count external indices
/// - `replaceind()` / `replaceinds()` - Replace indices
///
/// This separation allows tensor networks (like `TreeTN`) to implement
/// index operations without implementing contraction/factorization.
pub trait TensorLike: TensorIndex {
    /// Factorize this tensor into left and right factors.
    ///
    /// This function dispatches to the appropriate algorithm based on `options.alg`:
    /// - `SVD`: Singular Value Decomposition
    /// - `QR`: QR decomposition
    /// - `LU`: Rank-revealing LU decomposition
    /// - `CI`: Cross Interpolation
    ///
    /// The `canonical` option controls which factor is "canonical":
    /// - `Canonical::Left`: Left factor is orthogonal (SVD/QR) or unit-diagonal (LU/CI)
    /// - `Canonical::Right`: Right factor is orthogonal (SVD) or unit-diagonal (LU/CI)
    ///
    /// # Arguments
    /// * `left_inds` - Indices to place on the left side
    /// * `options` - Factorization options
    ///
    /// # Returns
    /// A `FactorizeResult` containing the left and right factors, bond index,
    /// singular values (for SVD), and rank.
    ///
    /// # Errors
    /// Returns `FactorizeError` if:
    /// - The storage type is not supported (only DenseF64 and DenseC64)
    /// - QR is used with `Canonical::Right`
    /// - The underlying algorithm fails
    fn factorize(
        &self,
        left_inds: &[<Self as TensorIndex>::Index],
        options: &FactorizeOptions,
    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError>;

    /// Tensor conjugate operation.
    ///
    /// This is a generalized conjugate operation that depends on the tensor type:
    /// - For dense tensors (TensorDynLen): element-wise complex conjugate
    /// - For symmetric tensors: tensor conjugate considering symmetry sectors
    ///
    /// This operation is essential for computing inner products and overlaps
    /// in tensor network algorithms like fitting.
    ///
    /// # Returns
    /// A new tensor representing the tensor conjugate.
    fn conj(&self) -> Self;

    /// Direct sum of two tensors along specified index pairs.
    ///
    /// For tensors A and B with indices to be summed specified as pairs,
    /// creates a new tensor C where each paired index has dimension = dim_A + dim_B.
    /// Non-paired indices must match exactly between A and B (same ID).
    ///
    /// # Arguments
    ///
    /// * `other` - Second tensor
    /// * `pairs` - Pairs of (self_index, other_index) to be summed. Each pair creates
    ///   a new index in the result with dimension = dim(self_index) + dim(other_index).
    ///
    /// # Returns
    ///
    /// A `DirectSumResult` containing the result tensor and new indices created
    /// for the summed dimensions (one per pair).
    ///
    /// # Example
    ///
    /// ```ignore
    /// // A has indices [i, j] with dims [2, 3]
    /// // B has indices [i, k] with dims [2, 4]
    /// // If we pair (j, k), result has indices [i, m] with dims [2, 7]
    /// // where m is a new index with dim = 3 + 4 = 7
    /// let result = a.direct_sum(&b, &[(j, k)])?;
    /// ```
    fn direct_sum(
        &self,
        other: &Self,
        pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
    ) -> Result<DirectSumResult<Self>>;

    /// Outer product (tensor product) of two tensors.
    ///
    /// Computes the tensor product of `self` and `other`, resulting in a tensor
    /// with all indices from both tensors. No indices are contracted.
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to compute outer product with
    ///
    /// # Returns
    ///
    /// A new tensor representing the outer product.
    ///
    /// # Errors
    ///
    /// Returns an error if the tensors have common indices (by ID).
    /// Use `tensordot` for contraction when indices overlap.
    fn outer_product(&self, other: &Self) -> Result<Self>;

    /// Compute the squared Frobenius norm of the tensor.
    ///
    /// The squared Frobenius norm is defined as the sum of squared absolute values
    /// of all tensor elements: `||T||_F^2 = sum_i |T_i|^2`.
    ///
    /// This is used for computing norms in tensor network algorithms,
    /// convergence checks, and normalization.
    ///
    /// # Returns
    /// The squared Frobenius norm as a non-negative f64.
    fn norm_squared(&self) -> f64;

    /// Permute tensor indices to match the specified order.
    ///
    /// This reorders the tensor's axes to match the order specified by `new_order`.
    /// The indices in `new_order` are matched by ID with the tensor's current indices.
    ///
    /// # Arguments
    ///
    /// * `new_order` - The desired order of indices (matched by ID)
    ///
    /// # Returns
    ///
    /// A new tensor with permuted indices.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The number of indices doesn't match
    /// - An index ID in `new_order` is not found in the tensor
    fn permuteinds(&self, new_order: &[<Self as TensorIndex>::Index]) -> Result<Self>;

    /// Contract multiple tensors over their contractable indices.
    ///
    /// This method contracts 2 or more tensors. Pairs of indices that satisfy
    /// `is_contractable()` (same ID, same dimension, compatible ConjState)
    /// are contracted based on the `allowed` parameter.
    ///
    /// Handles disconnected tensor graphs automatically by:
    /// 1. Finding connected components based on contractable indices
    /// 2. Contracting each connected component separately
    /// 3. Combining results using outer product
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of tensor references to contract (must have length >= 1)
    /// * `allowed` - Specifies which tensor pairs can have their indices contracted:
    ///   - `AllowedPairs::All`: Contract all contractable index pairs (default behavior)
    ///   - `AllowedPairs::Specified(&[(i, j)])`: Only contract indices between specified tensor pairs
    ///
    /// # Returns
    ///
    /// A new tensor representing the contracted result.
    /// If tensors form disconnected components, they are combined via outer product.
    ///
    /// # Behavior by N
    /// - N=0: Error
    /// - N=1: Clone of input
    /// - N>=2: Contract connected components, combine with outer product
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - No tensors are provided
    /// - `AllowedPairs::Specified` contains a pair with no contractable indices
    ///
    /// # Example
    ///
    /// ```ignore
    /// // Contract all contractable pairs
    /// let result = T::contract(&[&a, &b, &c], AllowedPairs::All)?;
    ///
    /// // Only contract between tensor pairs (0,1) and (1,2)
    /// let result = T::contract(&[&a, &b, &c], AllowedPairs::Specified(&[(0, 1), (1, 2)]))?;
    /// ```
    fn contract(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result<Self>;

    /// Contract multiple tensors that must form a connected graph.
    ///
    /// This is the core contraction method that requires all tensors to be
    /// connected through contractable indices. Use [`Self::contract`] if you want
    /// automatic handling of disconnected components via outer product.
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of tensor references to contract (must form a connected graph)
    /// * `allowed` - Specifies which tensor pairs can have their indices contracted
    ///
    /// # Returns
    ///
    /// A new tensor representing the contracted result.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - No tensors are provided
    /// - The tensors form a disconnected graph
    ///
    /// # Example
    ///
    /// ```ignore
    /// // All tensors must be connected through contractable indices
    /// let result = T::contract_connected(&[&a, &b, &c], AllowedPairs::All)?;
    /// ```
    fn contract_connected(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result<Self>;

    // ========================================================================
    // Vector space operations (for Krylov solvers)
    // ========================================================================

    /// Compute a linear combination: `a * self + b * other`.
    ///
    /// This is the fundamental vector space operation.
    fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self>;

    /// Scalar multiplication.
    fn scale(&self, scalar: AnyScalar) -> Result<Self>;

    /// Inner product (dot product) of two tensors.
    ///
    /// Computes `⟨self, other⟩ = Σ conj(self)_i * other_i`.
    fn inner_product(&self, other: &Self) -> Result<AnyScalar>;

    /// Compute the Frobenius norm of the tensor.
    ///
    /// Default implementation: square root of [`Self::norm_squared`], so
    /// implementors only need to provide the squared norm.
    fn norm(&self) -> f64 {
        self.norm_squared().sqrt()
    }

    /// Maximum absolute value of all elements (L-infinity norm).
    fn maxabs(&self) -> f64;

    /// Element-wise subtraction: `self - other`.
    ///
    /// Indices are automatically permuted to match `self`'s order via `axpby`.
    fn sub(&self, other: &Self) -> Result<Self> {
        // self - other == 1.0 * self + (-1.0) * other
        self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0))
    }

    /// Negate all elements: `-self`.
    fn neg(&self) -> Result<Self> {
        self.scale(AnyScalar::new_real(-1.0))
    }

    /// Approximate equality check (Julia `isapprox` semantics).
    ///
    /// Returns `true` if `||self - other|| <= max(atol, rtol * max(||self||, ||other||))`.
    fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool {
        let diff = match self.sub(other) {
            Ok(d) => d,
            // A failed subtraction (e.g. incompatible indices) is treated as
            // "not approximately equal" rather than propagated as an error.
            Err(_) => return false,
        };
        let diff_norm = diff.norm();
        diff_norm <= atol.max(rtol * self.norm().max(other.norm()))
    }

    /// Validate structural consistency of this tensor.
    ///
    /// The default implementation does nothing (always succeeds).
    /// Types with internal structure (e.g., [`BlockTensor`]) can override
    /// this to check invariants such as index sharing between blocks.
    fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Create a diagonal (Kronecker delta) tensor for a single index pair.
    ///
    /// Creates a 2D tensor `T[i, o]` where `T[i, o] = δ_{i,o}` (1 if i==o, 0 otherwise).
    ///
    /// # Arguments
    ///
    /// * `input_index` - Input index
    /// * `output_index` - Output index (must have same dimension as input)
    ///
    /// # Returns
    ///
    /// A 2D tensor with shape `[dim, dim]` representing the identity matrix.
    ///
    /// # Errors
    ///
    /// Returns an error if dimensions don't match.
    ///
    /// # Example
    ///
    /// For dimension 2:
    /// ```text
    /// diagonal(i, o) = [[1, 0], [0, 1]]
    /// ```
    fn diagonal(
        input_index: &<Self as TensorIndex>::Index,
        output_index: &<Self as TensorIndex>::Index,
    ) -> Result<Self>;

    /// Create a delta (identity) tensor as outer product of diagonals.
    ///
    /// For paired indices `(i1, o1), (i2, o2), ...`, creates a tensor where:
    /// `T[i1, o1, i2, o2, ...] = δ_{i1,o1} × δ_{i2,o2} × ...`
    ///
    /// This is computed as the outer product of individual diagonal tensors.
    ///
    /// # Arguments
    ///
    /// * `input_indices` - Input indices
    /// * `output_indices` - Output indices (must have same length and matching dimensions)
    ///
    /// # Returns
    ///
    /// A tensor representing the identity operator on the given index space.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Number of input and output indices don't match
    /// - Dimensions of paired indices don't match
    ///
    /// # Example
    ///
    /// For a single index pair with dimension 2:
    /// ```text
    /// delta([i], [o]) = [[1, 0], [0, 1]]
    /// ```
    fn delta(
        input_indices: &[<Self as TensorIndex>::Index],
        output_indices: &[<Self as TensorIndex>::Index],
    ) -> Result<Self> {
        // Validate same number of input and output indices
        if input_indices.len() != output_indices.len() {
            return Err(anyhow::anyhow!(
                "Number of input indices ({}) must match output indices ({})",
                input_indices.len(),
                output_indices.len()
            ));
        }

        if input_indices.is_empty() {
            // Return a scalar tensor with value 1.0
            // (the identity element for outer products).
            return Self::scalar_one();
        }

        // Build as outer product of diagonal tensors: start with the first
        // pair, then fold in the remaining pairs one by one.
        let mut result = Self::diagonal(&input_indices[0], &output_indices[0])?;
        for (inp, out) in input_indices[1..].iter().zip(output_indices[1..].iter()) {
            let diag = Self::diagonal(inp, out)?;
            result = result.outer_product(&diag)?;
        }
        Ok(result)
    }

    /// Create a scalar tensor with value 1.0.
    ///
    /// This is used as the identity element for outer products.
    fn scalar_one() -> Result<Self>;

    /// Create a tensor filled with 1.0 for the given indices.
    ///
    /// This is useful for adding indices to tensors via outer product
    /// without changing tensor values (since multiplying by 1 is identity).
    ///
    /// # Example
    /// To add a dummy index `l` to tensor `T`:
    /// ```ignore
    /// let ones = T::ones(&[l])?;
    /// let t_with_l = t.outer_product(&ones)?;
    /// ```
    fn ones(indices: &[<Self as TensorIndex>::Index]) -> Result<Self>;

    /// Create a one-hot tensor with value 1.0 at the specified index positions.
    ///
    /// Similar to ITensors.jl's `onehot(i => 1, j => 2)`.
    ///
    /// # Arguments
    /// * `index_vals` - Pairs of (Index, 0-indexed position)
    ///
    /// # Errors
    /// Returns error if any value >= corresponding index dimension.
    fn onehot(index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self>;
}
664
/// Result of direct sum operation.
///
/// Returned by [`TensorLike::direct_sum`]. One new index is created per pair
/// passed to `direct_sum` (presumably in the same order as `pairs` —
/// confirm against concrete implementations).
#[derive(Debug, Clone)]
pub struct DirectSumResult<T: TensorLike> {
    /// The resulting tensor from direct sum.
    pub tensor: T,
    /// New indices created for the summed dimensions (one per pair).
    pub new_indices: Vec<T::Index>,
}
673
// Unit tests live in the sibling `tests` module file; compiled only under `cargo test`.
#[cfg(test)]
mod tests;