tensor4all_core/tensor_like.rs
1//! TensorLike trait for unifying tensor types.
2//!
3//! This module provides a fully generic trait for tensor-like objects that expose
4//! external indices and support contraction operations.
5//!
6//! # Design
7//!
8//! The trait is **fully generic** (monomorphic), meaning:
9//! - No trait objects (`dyn TensorLike`)
10//! - Uses associated type for `Index`
//! - Methods that produce tensors return `Self` rather than a concrete tensor type
12//!
13//! For heterogeneous tensor collections, use an enum wrapper.
14
15use crate::any_scalar::AnyScalar;
16use crate::tensor_index::TensorIndex;
17use anyhow::Result;
18use std::fmt::Debug;
19
20// ============================================================================
// Factorization types (algorithm selection, options, errors, and results)
22// ============================================================================
23
24use thiserror::Error;
25
26/// Error type for factorize operations.
27#[derive(Debug, Error)]
28pub enum FactorizeError {
29 /// Factorization computation failed.
30 #[error("Factorization failed: {0}")]
31 ComputationError(
32 /// The underlying error
33 #[from]
34 anyhow::Error,
35 ),
36 /// Invalid relative tolerance value (must be finite and non-negative).
37 #[error("Invalid rtol value: {0}. rtol must be finite and non-negative.")]
38 InvalidRtol(
39 /// The invalid rtol value
40 f64,
41 ),
42 /// The storage type is not supported for this operation.
43 #[error("Unsupported storage type: {0}")]
44 UnsupportedStorage(
45 /// Description of the unsupported storage type
46 &'static str,
47 ),
48 /// The canonical direction is not supported for this algorithm.
49 #[error("Unsupported canonical direction for this algorithm: {0}")]
50 UnsupportedCanonical(
51 /// Description of the unsupported canonical direction
52 &'static str,
53 ),
54 /// Error from SVD operation.
55 #[error("SVD error: {0}")]
56 SvdError(
57 /// The underlying SVD error
58 #[from]
59 crate::svd::SvdError,
60 ),
61 /// Error from QR operation.
62 #[error("QR error: {0}")]
63 QrError(
64 /// The underlying QR error
65 #[from]
66 crate::qr::QrError,
67 ),
68 /// Error from matrix CI operation.
69 #[error("Matrix CI error: {0}")]
70 MatrixCIError(
71 /// The underlying matrix CI error
72 #[from]
73 tensor4all_tcicore::MatrixCIError,
74 ),
75}
76
/// Decomposition algorithm used when factorizing a tensor.
///
/// The default is [`FactorizeAlg::SVD`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FactorizeAlg {
    /// Singular value decomposition (the default choice).
    #[default]
    SVD,
    /// QR decomposition.
    QR,
    /// LU decomposition with rank revelation.
    LU,
    /// Cross interpolation, an LU-based scheme.
    CI,
}
90
/// Canonical direction for factorization.
///
/// Selects which of the two returned factors is "canonical": orthogonal for
/// SVD/QR, or unit-diagonal for LU/CI. Below, `L` and `R` always denote the
/// left and right factors of the factorization result.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Canonical {
    /// Left factor is canonical (the default).
    /// - SVD: `L = U` (orthogonal), `R = S * V`
    /// - QR: `L = Q` (orthogonal), `R` is the triangular factor
    /// - LU/CI: `L` has unit diagonal
    #[default]
    Left,
    /// Right factor is canonical.
    /// - SVD: `L = U * S`, `R = V` (orthogonal)
    /// - QR: not supported (would require an LQ decomposition)
    /// - LU/CI: `R` (the upper-triangular factor) has unit diagonal
    Right,
}
109
110/// Options for tensor factorization.
111#[derive(Debug, Clone)]
112pub struct FactorizeOptions {
113 /// Factorization algorithm to use.
114 pub alg: FactorizeAlg,
115 /// Canonical direction.
116 pub canonical: Canonical,
117 /// Relative tolerance for truncation.
118 /// If `None`, uses the algorithm's default.
119 pub rtol: Option<f64>,
120 /// Maximum rank for truncation.
121 /// If `None`, no rank limit is applied.
122 pub max_rank: Option<usize>,
123}
124
125impl Default for FactorizeOptions {
126 fn default() -> Self {
127 Self {
128 alg: FactorizeAlg::SVD,
129 canonical: Canonical::Left,
130 rtol: None,
131 max_rank: None,
132 }
133 }
134}
135
136impl FactorizeOptions {
137 /// Create options for SVD factorization.
138 pub fn svd() -> Self {
139 Self {
140 alg: FactorizeAlg::SVD,
141 ..Default::default()
142 }
143 }
144
145 /// Create options for QR factorization.
146 pub fn qr() -> Self {
147 Self {
148 alg: FactorizeAlg::QR,
149 ..Default::default()
150 }
151 }
152
153 /// Create options for LU factorization.
154 pub fn lu() -> Self {
155 Self {
156 alg: FactorizeAlg::LU,
157 ..Default::default()
158 }
159 }
160
161 /// Create options for CI factorization.
162 pub fn ci() -> Self {
163 Self {
164 alg: FactorizeAlg::CI,
165 ..Default::default()
166 }
167 }
168
169 /// Set canonical direction.
170 pub fn with_canonical(mut self, canonical: Canonical) -> Self {
171 self.canonical = canonical;
172 self
173 }
174
175 /// Set relative tolerance.
176 pub fn with_rtol(mut self, rtol: f64) -> Self {
177 self.rtol = Some(rtol);
178 self
179 }
180
181 /// Set maximum rank.
182 pub fn with_max_rank(mut self, max_rank: usize) -> Self {
183 self.max_rank = Some(max_rank);
184 self
185 }
186}
187
188/// Result of tensor factorization.
189///
190/// Generic over the tensor type `T`.
191#[derive(Debug, Clone)]
192pub struct FactorizeResult<T: TensorLike> {
193 /// Left factor tensor.
194 pub left: T,
195 /// Right factor tensor.
196 pub right: T,
197 /// Bond index connecting left and right factors.
198 pub bond_index: T::Index,
199 /// Singular values (only for SVD).
200 pub singular_values: Option<Vec<f64>>,
201 /// Rank of the factorization.
202 pub rank: usize,
203}
204
205// ============================================================================
206// Contraction types
207// ============================================================================
208
/// Specifies which tensor pairs are allowed to contract.
///
/// This enum controls which tensor pairs can have their indices contracted
/// in multi-tensor contraction operations. This is useful for tensor networks
/// where the graph structure determines which tensors are connected.
///
/// [`AllowedPairs::All`] is the [`Default`].
///
/// # Example
///
/// ```ignore
/// use tensor4all_core::{TensorLike, AllowedPairs};
///
/// // Contract all contractable index pairs (default behavior)
/// let tensor_refs: Vec<&T> = tensors.iter().collect();
/// let result = T::contract(&tensor_refs, AllowedPairs::default())?;
///
/// // Only contract indices between specified tensor pairs
/// let edges = vec![(0, 1), (1, 2)]; // tensor 0-1 and tensor 1-2
/// let result = T::contract(&tensor_refs, AllowedPairs::Specified(&edges))?;
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub enum AllowedPairs<'a> {
    /// All tensor pairs are allowed to contract.
    ///
    /// Indices with matching IDs across any two tensors will be contracted.
    /// This is the default behavior, equivalent to ITensor's `*` operator.
    #[default]
    All,
    /// Only specified tensor pairs are allowed to contract.
    ///
    /// Each pair is `(tensor_idx_a, tensor_idx_b)` into the input tensor slice.
    /// Indices are only contracted if they belong to an allowed pair.
    ///
    /// This is useful for tensor networks where the graph structure
    /// determines which tensors are connected (e.g., TreeTN edges).
    Specified(&'a [(usize, usize)]),
}
244
245// ============================================================================
246// TensorLike trait (fully generic)
247// ============================================================================
248
249/// Trait for tensor-like objects that expose external indices and support contraction.
250///
251/// This trait is **fully generic** (monomorphic), meaning it does not support
252/// trait objects (`dyn TensorLike`). For heterogeneous tensor collections,
253/// use an enum wrapper instead.
254///
255/// # Design Principles
256///
257/// - **Minimal interface**: Only external indices and automatic contraction
258/// - **Fully generic**: Uses associated type for `Index`, returns `Self`
259/// - **Stable ordering**: `external_indices()` returns indices in deterministic order
260/// - **No trait objects**: Requires `Sized`, cannot use `dyn TensorLike`
261///
262/// # Example
263///
264/// ```ignore
265/// use tensor4all_core::{TensorLike, AllowedPairs};
266///
267/// fn contract_pair<T: TensorLike>(a: &T, b: &T) -> Result<T> {
268/// T::contract(&[a, b], AllowedPairs::All)
269/// }
270/// ```
271///
272/// # Heterogeneous Collections
273///
274/// For mixing different tensor types, define an enum:
275///
276/// ```ignore
277/// enum TensorNetwork {
278/// Dense(TensorDynLen),
279/// MPS(MatrixProductState),
280/// }
281/// ```
282///
283/// # Supertrait
284///
285/// `TensorLike` extends `TensorIndex`, which provides:
286/// - `external_indices()` - Get all external indices
287/// - `num_external_indices()` - Count external indices
288/// - `replaceind()` / `replaceinds()` - Replace indices
289///
290/// This separation allows tensor networks (like `TreeTN`) to implement
291/// index operations without implementing contraction/factorization.
292pub trait TensorLike: TensorIndex {
293 /// Factorize this tensor into left and right factors.
294 ///
295 /// This function dispatches to the appropriate algorithm based on `options.alg`:
296 /// - `SVD`: Singular Value Decomposition
297 /// - `QR`: QR decomposition
298 /// - `LU`: Rank-revealing LU decomposition
299 /// - `CI`: Cross Interpolation
300 ///
301 /// The `canonical` option controls which factor is "canonical":
302 /// - `Canonical::Left`: Left factor is orthogonal (SVD/QR) or unit-diagonal (LU/CI)
303 /// - `Canonical::Right`: Right factor is orthogonal (SVD) or unit-diagonal (LU/CI)
304 ///
305 /// # Arguments
306 /// * `left_inds` - Indices to place on the left side
307 /// * `options` - Factorization options
308 ///
309 /// # Returns
310 /// A `FactorizeResult` containing the left and right factors, bond index,
311 /// singular values (for SVD), and rank.
312 ///
313 /// # Errors
314 /// Returns `FactorizeError` if:
315 /// - The storage type is not supported (only DenseF64 and DenseC64)
316 /// - QR is used with `Canonical::Right`
317 /// - The underlying algorithm fails
318 fn factorize(
319 &self,
320 left_inds: &[<Self as TensorIndex>::Index],
321 options: &FactorizeOptions,
322 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError>;
323
324 /// Tensor conjugate operation.
325 ///
326 /// This is a generalized conjugate operation that depends on the tensor type:
327 /// - For dense tensors (TensorDynLen): element-wise complex conjugate
328 /// - For symmetric tensors: tensor conjugate considering symmetry sectors
329 ///
330 /// This operation is essential for computing inner products and overlaps
331 /// in tensor network algorithms like fitting.
332 ///
333 /// # Returns
334 /// A new tensor representing the tensor conjugate.
335 fn conj(&self) -> Self;
336
337 /// Direct sum of two tensors along specified index pairs.
338 ///
339 /// For tensors A and B with indices to be summed specified as pairs,
340 /// creates a new tensor C where each paired index has dimension = dim_A + dim_B.
341 /// Non-paired indices must match exactly between A and B (same ID).
342 ///
343 /// # Arguments
344 ///
345 /// * `other` - Second tensor
346 /// * `pairs` - Pairs of (self_index, other_index) to be summed. Each pair creates
347 /// a new index in the result with dimension = dim(self_index) + dim(other_index).
348 ///
349 /// # Returns
350 ///
351 /// A `DirectSumResult` containing the result tensor and new indices created
352 /// for the summed dimensions (one per pair).
353 ///
354 /// # Example
355 ///
356 /// ```ignore
357 /// // A has indices [i, j] with dims [2, 3]
358 /// // B has indices [i, k] with dims [2, 4]
359 /// // If we pair (j, k), result has indices [i, m] with dims [2, 7]
360 /// // where m is a new index with dim = 3 + 4 = 7
361 /// let result = a.direct_sum(&b, &[(j, k)])?;
362 /// ```
363 fn direct_sum(
364 &self,
365 other: &Self,
366 pairs: &[(<Self as TensorIndex>::Index, <Self as TensorIndex>::Index)],
367 ) -> Result<DirectSumResult<Self>>;
368
369 /// Outer product (tensor product) of two tensors.
370 ///
371 /// Computes the tensor product of `self` and `other`, resulting in a tensor
372 /// with all indices from both tensors. No indices are contracted.
373 ///
374 /// # Arguments
375 ///
376 /// * `other` - The other tensor to compute outer product with
377 ///
378 /// # Returns
379 ///
380 /// A new tensor representing the outer product.
381 ///
382 /// # Errors
383 ///
384 /// Returns an error if the tensors have common indices (by ID).
385 /// Use `tensordot` for contraction when indices overlap.
386 fn outer_product(&self, other: &Self) -> Result<Self>;
387
388 /// Compute the squared Frobenius norm of the tensor.
389 ///
390 /// The squared Frobenius norm is defined as the sum of squared absolute values
391 /// of all tensor elements: `||T||_F^2 = sum_i |T_i|^2`.
392 ///
393 /// This is used for computing norms in tensor network algorithms,
394 /// convergence checks, and normalization.
395 ///
396 /// # Returns
397 /// The squared Frobenius norm as a non-negative f64.
398 fn norm_squared(&self) -> f64;
399
400 /// Permute tensor indices to match the specified order.
401 ///
402 /// This reorders the tensor's axes to match the order specified by `new_order`.
403 /// The indices in `new_order` are matched by ID with the tensor's current indices.
404 ///
405 /// # Arguments
406 ///
407 /// * `new_order` - The desired order of indices (matched by ID)
408 ///
409 /// # Returns
410 ///
411 /// A new tensor with permuted indices.
412 ///
413 /// # Errors
414 ///
415 /// Returns an error if:
416 /// - The number of indices doesn't match
417 /// - An index ID in `new_order` is not found in the tensor
418 fn permuteinds(&self, new_order: &[<Self as TensorIndex>::Index]) -> Result<Self>;
419
420 /// Contract multiple tensors over their contractable indices.
421 ///
422 /// This method contracts 2 or more tensors. Pairs of indices that satisfy
423 /// `is_contractable()` (same ID, same dimension, compatible ConjState)
424 /// are contracted based on the `allowed` parameter.
425 ///
426 /// Handles disconnected tensor graphs automatically by:
427 /// 1. Finding connected components based on contractable indices
428 /// 2. Contracting each connected component separately
429 /// 3. Combining results using outer product
430 ///
431 /// # Arguments
432 ///
433 /// * `tensors` - Slice of tensor references to contract (must have length >= 1)
434 /// * `allowed` - Specifies which tensor pairs can have their indices contracted:
435 /// - `AllowedPairs::All`: Contract all contractable index pairs (default behavior)
436 /// - `AllowedPairs::Specified(&[(i, j)])`: Only contract indices between specified tensor pairs
437 ///
438 /// # Returns
439 ///
440 /// A new tensor representing the contracted result.
441 /// If tensors form disconnected components, they are combined via outer product.
442 ///
443 /// # Behavior by N
444 /// - N=0: Error
445 /// - N=1: Clone of input
446 /// - N>=2: Contract connected components, combine with outer product
447 ///
448 /// # Errors
449 ///
450 /// Returns an error if:
451 /// - No tensors are provided
452 /// - `AllowedPairs::Specified` contains a pair with no contractable indices
453 ///
454 /// # Example
455 ///
456 /// ```ignore
457 /// // Contract all contractable pairs
458 /// let result = T::contract(&[&a, &b, &c], AllowedPairs::All)?;
459 ///
460 /// // Only contract between tensor pairs (0,1) and (1,2)
461 /// let result = T::contract(&[&a, &b, &c], AllowedPairs::Specified(&[(0, 1), (1, 2)]))?;
462 /// ```
463 fn contract(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result<Self>;
464
465 /// Contract multiple tensors that must form a connected graph.
466 ///
467 /// This is the core contraction method that requires all tensors to be
468 /// connected through contractable indices. Use [`Self::contract`] if you want
469 /// automatic handling of disconnected components via outer product.
470 ///
471 /// # Arguments
472 ///
473 /// * `tensors` - Slice of tensor references to contract (must form a connected graph)
474 /// * `allowed` - Specifies which tensor pairs can have their indices contracted
475 ///
476 /// # Returns
477 ///
478 /// A new tensor representing the contracted result.
479 ///
480 /// # Errors
481 ///
482 /// Returns an error if:
483 /// - No tensors are provided
484 /// - The tensors form a disconnected graph
485 ///
486 /// # Example
487 ///
488 /// ```ignore
489 /// // All tensors must be connected through contractable indices
490 /// let result = T::contract_connected(&[&a, &b, &c], AllowedPairs::All)?;
491 /// ```
492 fn contract_connected(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result<Self>;
493
494 // ========================================================================
495 // Vector space operations (for Krylov solvers)
496 // ========================================================================
497
498 /// Compute a linear combination: `a * self + b * other`.
499 ///
500 /// This is the fundamental vector space operation.
501 fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self>;
502
503 /// Scalar multiplication.
504 fn scale(&self, scalar: AnyScalar) -> Result<Self>;
505
506 /// Inner product (dot product) of two tensors.
507 ///
508 /// Computes `⟨self, other⟩ = Σ conj(self)_i * other_i`.
509 fn inner_product(&self, other: &Self) -> Result<AnyScalar>;
510
511 /// Compute the Frobenius norm of the tensor.
512 fn norm(&self) -> f64 {
513 self.norm_squared().sqrt()
514 }
515
516 /// Maximum absolute value of all elements (L-infinity norm).
517 fn maxabs(&self) -> f64;
518
519 /// Element-wise subtraction: `self - other`.
520 ///
521 /// Indices are automatically permuted to match `self`'s order via `axpby`.
522 fn sub(&self, other: &Self) -> Result<Self> {
523 self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0))
524 }
525
526 /// Negate all elements: `-self`.
527 fn neg(&self) -> Result<Self> {
528 self.scale(AnyScalar::new_real(-1.0))
529 }
530
531 /// Approximate equality check (Julia `isapprox` semantics).
532 ///
533 /// Returns `true` if `||self - other|| <= max(atol, rtol * max(||self||, ||other||))`.
534 fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool {
535 let diff = match self.sub(other) {
536 Ok(d) => d,
537 Err(_) => return false,
538 };
539 let diff_norm = diff.norm();
540 diff_norm <= atol.max(rtol * self.norm().max(other.norm()))
541 }
542
543 /// Validate structural consistency of this tensor.
544 ///
545 /// The default implementation does nothing (always succeeds).
546 /// Types with internal structure (e.g., [`BlockTensor`]) can override
547 /// this to check invariants such as index sharing between blocks.
548 fn validate(&self) -> Result<()> {
549 Ok(())
550 }
551
552 /// Create a diagonal (Kronecker delta) tensor for a single index pair.
553 ///
554 /// Creates a 2D tensor `T[i, o]` where `T[i, o] = δ_{i,o}` (1 if i==o, 0 otherwise).
555 ///
556 /// # Arguments
557 ///
558 /// * `input_index` - Input index
559 /// * `output_index` - Output index (must have same dimension as input)
560 ///
561 /// # Returns
562 ///
563 /// A 2D tensor with shape `[dim, dim]` representing the identity matrix.
564 ///
565 /// # Errors
566 ///
567 /// Returns an error if dimensions don't match.
568 ///
569 /// # Example
570 ///
571 /// For dimension 2:
572 /// ```text
573 /// diagonal(i, o) = [[1, 0], [0, 1]]
574 /// ```
575 fn diagonal(
576 input_index: &<Self as TensorIndex>::Index,
577 output_index: &<Self as TensorIndex>::Index,
578 ) -> Result<Self>;
579
580 /// Create a delta (identity) tensor as outer product of diagonals.
581 ///
582 /// For paired indices `(i1, o1), (i2, o2), ...`, creates a tensor where:
583 /// `T[i1, o1, i2, o2, ...] = δ_{i1,o1} × δ_{i2,o2} × ...`
584 ///
585 /// This is computed as the outer product of individual diagonal tensors.
586 ///
587 /// # Arguments
588 ///
589 /// * `input_indices` - Input indices
590 /// * `output_indices` - Output indices (must have same length and matching dimensions)
591 ///
592 /// # Returns
593 ///
594 /// A tensor representing the identity operator on the given index space.
595 ///
596 /// # Errors
597 ///
598 /// Returns an error if:
599 /// - Number of input and output indices don't match
600 /// - Dimensions of paired indices don't match
601 ///
602 /// # Example
603 ///
604 /// For a single index pair with dimension 2:
605 /// ```text
606 /// delta([i], [o]) = [[1, 0], [0, 1]]
607 /// ```
608 fn delta(
609 input_indices: &[<Self as TensorIndex>::Index],
610 output_indices: &[<Self as TensorIndex>::Index],
611 ) -> Result<Self> {
612 // Validate same number of input and output indices
613 if input_indices.len() != output_indices.len() {
614 return Err(anyhow::anyhow!(
615 "Number of input indices ({}) must match output indices ({})",
616 input_indices.len(),
617 output_indices.len()
618 ));
619 }
620
621 if input_indices.is_empty() {
622 // Return a scalar tensor with value 1.0
623 return Self::scalar_one();
624 }
625
626 // Build as outer product of diagonal tensors
627 let mut result = Self::diagonal(&input_indices[0], &output_indices[0])?;
628 for (inp, out) in input_indices[1..].iter().zip(output_indices[1..].iter()) {
629 let diag = Self::diagonal(inp, out)?;
630 result = result.outer_product(&diag)?;
631 }
632 Ok(result)
633 }
634
635 /// Create a scalar tensor with value 1.0.
636 ///
637 /// This is used as the identity element for outer products.
638 fn scalar_one() -> Result<Self>;
639
640 /// Create a tensor filled with 1.0 for the given indices.
641 ///
642 /// This is useful for adding indices to tensors via outer product
643 /// without changing tensor values (since multiplying by 1 is identity).
644 ///
645 /// # Example
646 /// To add a dummy index `l` to tensor `T`:
647 /// ```ignore
648 /// let ones = T::ones(&[l])?;
649 /// let t_with_l = t.outer_product(&ones)?;
650 /// ```
651 fn ones(indices: &[<Self as TensorIndex>::Index]) -> Result<Self>;
652
653 /// Create a one-hot tensor with value 1.0 at the specified index positions.
654 ///
655 /// Similar to ITensors.jl's `onehot(i => 1, j => 2)`.
656 ///
657 /// # Arguments
658 /// * `index_vals` - Pairs of (Index, 0-indexed position)
659 ///
660 /// # Errors
661 /// Returns error if any value >= corresponding index dimension.
662 fn onehot(index_vals: &[(<Self as TensorIndex>::Index, usize)]) -> Result<Self>;
663}
664
665/// Result of direct sum operation.
666#[derive(Debug, Clone)]
667pub struct DirectSumResult<T: TensorLike> {
668 /// The resulting tensor from direct sum.
669 pub tensor: T,
670 /// New indices created for the summed dimensions (one per pair).
671 pub new_indices: Vec<T::Index>,
672}
673
// Unit tests for this module live in a separate `tests` submodule file and
// are compiled only when building with `cfg(test)`.
#[cfg(test)]
mod tests;