Skip to main content

tensor4all_core/defaults/
tensordynlen.rs

1use crate::defaults::DynIndex;
2use crate::index_like::IndexLike;
3use crate::index_ops::{common_ind_positions, prepare_contraction, prepare_contraction_pairs};
4use crate::tensor_like::LinearizationOrder;
5use crate::AnyScalar;
6use anyhow::{Context, Result};
7use num_complex::Complex64;
8use num_traits::Zero;
9use rand::Rng;
10use rand_distr::{Distribution, StandardNormal};
11use std::cell::RefCell;
12use std::cmp::Reverse;
13use std::collections::{HashMap, HashSet};
14use std::env;
15use std::sync::{Arc, OnceLock};
16use std::time::{Duration, Instant};
17use tenferro::{DType, DotGeneralConfig, Tensor as NativeTensor};
18use tenferro_ad::EagerTensor;
19use tenferro_einsum::eager_tensor::einsum_subscripts as eager_einsum_ad;
20use tenferro_einsum::EinsumSubscripts;
21use tensor4all_tensorbackend::{
22    axpby_native_tensor, contract_native_tensor, default_eager_ctx,
23    dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
24    native_tensor_primal_to_dense_col_major, native_tensor_primal_to_diag_c64,
25    native_tensor_primal_to_diag_f64, native_tensor_primal_to_storage, scale_native_tensor,
26    storage_payload_native_read_input, storage_to_native_tensor, AnyScalar as BackendScalar,
27    StorageScalar, TensorElement,
28};
29use tensor4all_tensorbackend::{Storage, StorageKind};
30
31use super::contract::PairwiseContractionOptions;
32use super::structured_contraction::{
33    normalize_payload_read_for_roots, storage_from_payload_native, storage_payload_native,
34    OperandLayout, StructuredContractionPlan, StructuredContractionSpec,
35};
36
/// Aggregated profiling counters for one named section of the pairwise
/// contraction path.
#[derive(Debug, Default, Clone)]
struct PairwiseContractProfileEntry {
    /// Number of timed invocations recorded for the section.
    calls: usize,
    /// Sum of the elapsed durations recorded for the section.
    total_time: Duration,
    /// Sum of the byte counts recorded for the section.
    total_bytes: usize,
}
43
thread_local! {
    // Profile table keyed by section name. Being thread-local, every thread
    // accumulates (and later prints or resets) only its own entries.
    static PAIRWISE_CONTRACT_PROFILE_STATE: RefCell<HashMap<&'static str, PairwiseContractProfileEntry>> =
        RefCell::new(HashMap::new());
}
48
/// Report whether pairwise-contraction profiling is switched on.
///
/// Profiling is enabled when the `T4A_PROFILE_PAIRWISE_CONTRACT` environment
/// variable is present, whatever its value. The answer is computed on the
/// first call and cached for the rest of the process.
fn pairwise_contract_profile_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    // `var_os` is a pure presence check; `env::var(..).is_ok()` would report
    // "disabled" for a variable that is set to a non-Unicode value.
    *ENABLED.get_or_init(|| env::var_os("T4A_PROFILE_PAIRWISE_CONTRACT").is_some())
}
53
54fn record_pairwise_contract_profile(section: &'static str, elapsed: Duration) {
55    if !pairwise_contract_profile_enabled() {
56        return;
57    }
58    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
59        let mut state = state.borrow_mut();
60        let entry = state.entry(section).or_default();
61        entry.calls += 1;
62        entry.total_time += elapsed;
63    });
64}
65
66fn record_pairwise_contract_profile_bytes(section: &'static str, bytes: usize) {
67    if !pairwise_contract_profile_enabled() {
68        return;
69    }
70    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
71        let mut state = state.borrow_mut();
72        let entry = state.entry(section).or_default();
73        entry.total_bytes += bytes;
74    });
75}
76
77fn profile_pairwise_contract_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
78    if !pairwise_contract_profile_enabled() {
79        return f();
80    }
81    let started = Instant::now();
82    let result = f();
83    record_pairwise_contract_profile(section, started.elapsed());
84    result
85}
86
87/// Reset the aggregated pairwise `TensorDynLen` contraction profile.
88pub fn reset_pairwise_contract_profile() {
89    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
90}
91
92/// Print and clear the aggregated pairwise `TensorDynLen` contraction profile.
93pub fn print_and_reset_pairwise_contract_profile() {
94    if !pairwise_contract_profile_enabled() {
95        return;
96    }
97    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
98        let mut entries: Vec<_> = state
99            .borrow()
100            .iter()
101            .map(|(section, entry)| (*section, entry.clone()))
102            .collect();
103        state.borrow_mut().clear();
104        entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
105
106        eprintln!("=== TensorDynLen pairwise contract profile ===");
107        for (section, entry) in entries {
108            let per_call_us = if entry.calls == 0 {
109                0.0
110            } else {
111                entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64
112            };
113            eprintln!(
114                "{section}: calls={} total={:.6}ms per_call={:.3}us bytes={}",
115                entry.calls,
116                entry.total_time.as_secs_f64() * 1.0e3,
117                per_call_us,
118                entry.total_bytes,
119            );
120        }
121    });
122}
123
124fn native_tensor_profile_bytes(native: &NativeTensor) -> usize {
125    let element_size = match native.dtype() {
126        DType::F32 => 4,
127        DType::F64 => 8,
128        DType::C32 => 8,
129        DType::C64 => 16,
130        DType::I32 => 4,
131        DType::I64 => 8,
132        DType::Bool => 1,
133    };
134    native.shape().iter().product::<usize>() * element_size
135}
136
/// Trait for scalar types that can generate random values from a standard
/// normal distribution.
///
/// This enables the generic [`TensorDynLen::random`] constructor.
/// This file implements it for `f64` and `Complex64`.
pub trait RandomScalar: TensorElement {
    /// Generate a random value from the standard normal distribution.
    ///
    /// `rng` is the random number generator the sample is drawn from.
    fn random_value<R: Rng>(rng: &mut R) -> Self;
}
145
146impl RandomScalar for f64 {
147    fn random_value<R: Rng>(rng: &mut R) -> Self {
148        StandardNormal.sample(rng)
149    }
150}
151
152impl RandomScalar for Complex64 {
153    fn random_value<R: Rng>(rng: &mut R) -> Self {
154        Complex64::new(StandardNormal.sample(rng), StandardNormal.sample(rng))
155    }
156}
157
158/// Compute the permutation array from original indices to new indices.
159///
160/// This function finds the mapping from new indices to original indices by
161/// matching index IDs. The result is a permutation array `perm` such that
162/// `new_indices[i]` corresponds to `original_indices[perm[i]]`.
163///
164/// # Arguments
165/// * `original_indices` - The original indices in their current order
166/// * `new_indices` - The desired new indices order (must be a permutation of original_indices)
167///
168/// # Returns
169/// A `Vec<usize>` representing the permutation: `perm[i]` is the position in
170/// `original_indices` of the index that should be at position `i` in `new_indices`.
171///
172/// # Errors
173/// Returns an error if the slices have different lengths, if `new_indices`
174/// is not a permutation of `original_indices`, or if `new_indices` contains
175/// duplicate indices.
176///
177/// # Example
178/// ```
179/// use tensor4all_core::tensor::compute_permutation_from_indices;
180/// use tensor4all_core::DynIndex;
181///
182/// let i = DynIndex::new_dyn(2);
183/// let j = DynIndex::new_dyn(3);
184/// let original = vec![i.clone(), j.clone()];
185/// let new_order = vec![j.clone(), i.clone()];
186///
187/// let perm = compute_permutation_from_indices(&original, &new_order).unwrap();
188/// assert_eq!(perm, vec![1, 0]);  // j is at position 1, i is at position 0
189/// ```
190pub fn compute_permutation_from_indices(
191    original_indices: &[DynIndex],
192    new_indices: &[DynIndex],
193) -> Result<Vec<usize>> {
194    anyhow::ensure!(
195        new_indices.len() == original_indices.len(),
196        "new_indices length must match original_indices length"
197    );
198
199    let mut perm = Vec::with_capacity(new_indices.len());
200    let mut used = std::collections::HashSet::new();
201
202    for new_idx in new_indices {
203        // Find the position of this index in the original indices
204        // DynIndex implements Eq, so we can compare directly
205        let pos = original_indices
206            .iter()
207            .position(|old_idx| old_idx == new_idx)
208            .ok_or_else(|| {
209                anyhow::anyhow!("new_indices must be a permutation of original_indices")
210            })?;
211
212        anyhow::ensure!(used.insert(pos), "duplicate index in new_indices");
213        perm.push(pos);
214    }
215
216    Ok(perm)
217}
218
/// Tracked compact payload kept next to a tensor's storage so that a
/// structured layout can stay attached to its eager (AD) tensor.
#[derive(Clone)]
pub(crate) struct StructuredAdValue {
    /// Eager tensor holding the compact payload.
    payload: Arc<EagerTensor>,
    /// Dimensions of the compact payload.
    payload_dims: Vec<usize>,
    /// Axis class assigned to each logical axis of the owning tensor.
    axis_classes: Vec<usize>,
}
225
/// Payload of a [`TensorDynLen`]: either compact [`Storage`] or an eager
/// tensor paired with the axis classes of the logical axes.
#[derive(Clone)]
pub(crate) enum TensorDynLenStorage {
    /// Compact storage that has already been materialized.
    Materialized(Arc<Storage>),
    /// Payload held as an eager tensor.
    Eager {
        /// The eager tensor that holds the payload.
        inner: Arc<EagerTensor>,
        /// Axis class of each logical axis.
        axis_classes: Vec<usize>,
    },
}
234
235impl TensorDynLenStorage {
236    fn from_storage(storage: Arc<Storage>) -> Self {
237        Self::Materialized(storage)
238    }
239
240    fn from_eager_dense(inner: EagerTensor, rank: usize) -> Self {
241        Self::Eager {
242            inner: Arc::new(inner),
243            axis_classes: TensorDynLen::dense_axis_classes(rank),
244        }
245    }
246
247    fn eager(&self) -> Option<&EagerTensor> {
248        match self {
249            Self::Materialized(_) => None,
250            Self::Eager { inner, .. } => Some(inner.as_ref()),
251        }
252    }
253
254    fn axis_classes(&self) -> &[usize] {
255        match self {
256            Self::Materialized(storage) => storage.axis_classes(),
257            Self::Eager { axis_classes, .. } => axis_classes,
258        }
259    }
260
261    fn payload_dims(&self) -> &[usize] {
262        match self {
263            Self::Materialized(storage) => storage.payload_dims(),
264            Self::Eager { inner, .. } => inner.data().shape(),
265        }
266    }
267
268    fn payload_strides_vec(&self) -> Vec<isize> {
269        match self {
270            Self::Materialized(storage) => storage.payload_strides().to_vec(),
271            Self::Eager { inner, .. } => {
272                let mut stride = 1isize;
273                inner
274                    .data()
275                    .shape()
276                    .iter()
277                    .map(|&dim| {
278                        let current = stride;
279                        stride *= isize::try_from(dim).unwrap_or(isize::MAX);
280                        current
281                    })
282                    .collect()
283            }
284        }
285    }
286
287    fn is_f64(&self) -> bool {
288        match self {
289            Self::Materialized(storage) => storage.is_f64(),
290            Self::Eager { inner, .. } => inner.data().dtype() == DType::F64,
291        }
292    }
293
294    fn is_c64(&self) -> bool {
295        match self {
296            Self::Materialized(storage) => storage.is_c64(),
297            Self::Eager { inner, .. } => inner.data().dtype() == DType::C64,
298        }
299    }
300
301    fn is_complex(&self) -> bool {
302        match self {
303            Self::Materialized(storage) => storage.is_complex(),
304            Self::Eager { inner, .. } => matches!(inner.data().dtype(), DType::C32 | DType::C64),
305        }
306    }
307
308    fn is_diag(&self) -> bool {
309        match self {
310            Self::Materialized(storage) => storage.is_diag(),
311            Self::Eager { axis_classes, .. } => TensorDynLen::is_diag_axis_classes(axis_classes),
312        }
313    }
314
315    fn storage_kind(&self) -> StorageKind {
316        match self {
317            Self::Materialized(storage) => storage.storage_kind(),
318            Self::Eager { axis_classes, .. } => {
319                if axis_classes.iter().copied().eq(0..axis_classes.len()) {
320                    StorageKind::Dense
321                } else if TensorDynLen::is_diag_axis_classes(axis_classes) {
322                    StorageKind::Diagonal
323                } else {
324                    StorageKind::Structured
325                }
326            }
327        }
328    }
329
330    fn materialize(&self, logical_rank: usize) -> Result<Arc<Storage>> {
331        match self {
332            Self::Materialized(storage) => Ok(Arc::clone(storage)),
333            Self::Eager {
334                inner,
335                axis_classes,
336            } => Ok(Arc::new(
337                TensorDynLen::storage_from_native_with_axis_classes(
338                    inner.data(),
339                    axis_classes,
340                    logical_rank,
341                )?,
342            )),
343        }
344    }
345
346    fn scale(&self, scalar: &BackendScalar) -> Result<Storage> {
347        Ok(self.materialize(self.axis_classes().len())?.scale(scalar))
348    }
349
350    fn conj(&self) -> Result<Self> {
351        match self {
352            Self::Materialized(storage) => Ok(Self::Materialized(Arc::new(storage.conj()))),
353            Self::Eager {
354                inner,
355                axis_classes,
356            } => Ok(Self::Eager {
357                inner: Arc::new(inner.conj()?),
358                axis_classes: axis_classes.clone(),
359            }),
360        }
361    }
362
363    fn max_abs(&self) -> Result<f64> {
364        Ok(self.materialize(self.axis_classes().len())?.max_abs())
365    }
366}
367
/// Dynamic-rank tensor with structured payload storage -- the central data type
/// of tensor4all.
///
/// `TensorDynLen` stores a logical multi-dimensional tensor of `f64` or
/// `Complex64` values together with a list of [`DynIndex`] labels. The
/// authoritative payload is compact [`Storage`], which may be dense, diagonal,
/// or explicitly structured. The indices carry unique identities (UUIDs) so
/// that contraction, addition, and other binary operations can automatically
/// match legs by identity rather than position.
///
/// # Key Operations
///
/// | Operation | Method |
/// |-----------|--------|
/// | Create from data | [`from_dense`](Self::from_dense), [`from_diag`](Self::from_diag), [`zeros`](Self::zeros) |
/// | Extract data | [`to_vec`](Self::to_vec), [`into_dense_col_major_parts`](Self::into_dense_col_major_parts), [`sum`](Self::sum), [`only`](Self::only) |
/// | Contraction | [`contract`](Self::contract) |
/// | Arithmetic | [`add`](Self::add), [`scale`](Self::scale), [`axpby`](Self::axpby) |
/// | Factorization | via [`TensorFactorizationLike::factorize`](crate::TensorFactorizationLike::factorize) |
/// | Norms | [`norm`](Self::norm), [`norm_squared`](Self::norm_squared), [`maxabs`](Self::maxabs) |
/// | Index ops | [`replaceind`](Self::replaceind), [`permute_indices`](Self::permute_indices) |
///
/// # Data Layout
///
/// Logical dense extraction uses **column-major** order (first index varies
/// fastest), matching Fortran, Julia, and ITensors.jl conventions. Compact
/// structured payloads additionally carry explicit payload dimensions, strides,
/// and logical-axis classes.
///
/// # Examples
///
/// ```
/// use tensor4all_core::{TensorDynLen, DynIndex};
///
/// // Create a 2x3 real tensor
/// let i = DynIndex::new_dyn(2);
/// let j = DynIndex::new_dyn(3);
/// let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
///
/// assert_eq!(t.dims(), vec![2, 3]);
/// assert!(t.is_f64());
///
/// // Sum all elements: 1+2+3+4+5+6 = 21
/// let s = t.sum().unwrap();
/// assert!((s.real() - 21.0).abs() < 1e-12);
///
/// // Extract data back out
/// let data_out = t.to_vec::<f64>().unwrap();
/// assert_eq!(data_out, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// ```
#[derive(Clone)]
pub struct TensorDynLen {
    /// Full index information (includes tags and other metadata).
    pub indices: Vec<DynIndex>,
    /// Authoritative compact payload storage.
    pub(crate) storage: TensorDynLenStorage,
    /// Optional tracked compact payload used to preserve structured AD layouts.
    pub(crate) structured_ad: Option<Arc<StructuredAdValue>>,
    /// Lazily materialized eager payload for native execution and AD.
    ///
    /// The cell sits behind an `Arc`, so the derived `Clone` makes clones
    /// share the same cache cell.
    pub(crate) eager_cache: Arc<OnceLock<Arc<EagerTensor>>>,
}
430
431impl TensorDynLen {
432    fn dense_axis_classes(rank: usize) -> Vec<usize> {
433        (0..rank).collect()
434    }
435
436    fn diag_axis_classes(rank: usize) -> Vec<usize> {
437        if rank == 0 {
438            vec![]
439        } else {
440            vec![0; rank]
441        }
442    }
443
444    fn canonicalize_axis_classes(axis_classes: &[usize]) -> Vec<usize> {
445        let mut map = std::collections::HashMap::new();
446        let mut next = 0usize;
447        axis_classes
448            .iter()
449            .map(|&class_id| {
450                *map.entry(class_id).or_insert_with(|| {
451                    let canonical = next;
452                    next += 1;
453                    canonical
454                })
455            })
456            .collect()
457    }
458
459    fn permute_axis_classes(&self, perm: &[usize]) -> Vec<usize> {
460        let axis_classes = self.storage.axis_classes();
461        let permuted: Vec<usize> = perm.iter().map(|&index| axis_classes[index]).collect();
462        Self::canonicalize_axis_classes(&permuted)
463    }
464
465    fn normalize_insert_axis(op: &str, axis: isize, rank: usize) -> Result<usize> {
466        let normalized = if axis < 0 {
467            rank as isize + 1 + axis
468        } else {
469            axis
470        };
471        anyhow::ensure!(
472            normalized >= 0 && normalized <= rank as isize,
473            "{op}: axis {axis} is out of bounds for inserting into rank {rank}"
474        );
475        Ok(normalized as usize)
476    }
477
478    fn is_diag_axis_classes(axis_classes: &[usize]) -> bool {
479        axis_classes.len() >= 2 && axis_classes.iter().all(|&class_id| class_id == 0)
480    }
481
482    fn einsum_subscripts_from_usize_ids(
483        inputs: &[Vec<usize>],
484        output: &[usize],
485    ) -> Result<EinsumSubscripts> {
486        let input_labels = inputs
487            .iter()
488            .map(|ids| {
489                ids.iter()
490                    .map(|&id| {
491                        u32::try_from(id)
492                            .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
493                    })
494                    .collect::<Result<Vec<_>>>()
495            })
496            .collect::<Result<Vec<_>>>()?;
497        let output_labels = output
498            .iter()
499            .map(|&id| {
500                u32::try_from(id)
501                    .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
502            })
503            .collect::<Result<Vec<_>>>()?;
504        let input_refs = input_labels.iter().map(Vec::as_slice).collect::<Vec<_>>();
505        Ok(EinsumSubscripts::new(&input_refs, &output_labels))
506    }
507
508    fn build_binary_einsum_subscripts(
509        lhs_rank: usize,
510        axes_a: &[usize],
511        rhs_rank: usize,
512        axes_b: &[usize],
513    ) -> Result<EinsumSubscripts> {
514        anyhow::ensure!(
515            axes_a.len() == axes_b.len(),
516            "contract axis length mismatch: lhs {:?}, rhs {:?}",
517            axes_a,
518            axes_b
519        );
520
521        let mut lhs_ids = vec![usize::MAX; lhs_rank];
522        let mut rhs_ids = vec![usize::MAX; rhs_rank];
523        let mut next_id = 0usize;
524
525        let mut seen_lhs = vec![false; lhs_rank];
526        let mut seen_rhs = vec![false; rhs_rank];
527
528        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
529            anyhow::ensure!(
530                lhs_axis < lhs_rank,
531                "lhs contract axis {lhs_axis} out of range"
532            );
533            anyhow::ensure!(
534                rhs_axis < rhs_rank,
535                "rhs contract axis {rhs_axis} out of range"
536            );
537            anyhow::ensure!(
538                !seen_lhs[lhs_axis],
539                "duplicate lhs contract axis {lhs_axis}"
540            );
541            anyhow::ensure!(
542                !seen_rhs[rhs_axis],
543                "duplicate rhs contract axis {rhs_axis}"
544            );
545            seen_lhs[lhs_axis] = true;
546            seen_rhs[rhs_axis] = true;
547            lhs_ids[lhs_axis] = next_id;
548            rhs_ids[rhs_axis] = next_id;
549            next_id += 1;
550        }
551
552        let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
553        for id in &mut lhs_ids {
554            if *id == usize::MAX {
555                *id = next_id;
556                output_ids.push(next_id);
557                next_id += 1;
558            }
559        }
560        for id in &mut rhs_ids {
561            if *id == usize::MAX {
562                *id = next_id;
563                output_ids.push(next_id);
564                next_id += 1;
565            }
566        }
567
568        Self::einsum_subscripts_from_usize_ids(&[lhs_ids, rhs_ids], &output_ids)
569    }
570
571    fn binary_dot_general_config(axes_a: &[usize], axes_b: &[usize]) -> Result<DotGeneralConfig> {
572        anyhow::ensure!(
573            axes_a.len() == axes_b.len(),
574            "contract axis length mismatch: lhs {:?}, rhs {:?}",
575            axes_a,
576            axes_b
577        );
578        Ok(DotGeneralConfig {
579            lhs_contracting_dims: axes_a.to_vec(),
580            rhs_contracting_dims: axes_b.to_vec(),
581            lhs_batch_dims: vec![],
582            rhs_batch_dims: vec![],
583        })
584    }
585
586    fn binary_contraction_axis_classes(
587        lhs_axis_classes: &[usize],
588        axes_a: &[usize],
589        rhs_axis_classes: &[usize],
590        axes_b: &[usize],
591    ) -> Vec<usize> {
592        debug_assert_eq!(axes_a.len(), axes_b.len());
593
594        fn find(parent: &mut [usize], value: usize) -> usize {
595            if parent[value] != value {
596                parent[value] = find(parent, parent[value]);
597            }
598            parent[value]
599        }
600
601        fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
602            let lhs_root = find(parent, lhs);
603            let rhs_root = find(parent, rhs);
604            if lhs_root != rhs_root {
605                parent[rhs_root] = lhs_root;
606            }
607        }
608
609        let lhs_payload_rank = lhs_axis_classes
610            .iter()
611            .copied()
612            .max()
613            .map(|value| value + 1)
614            .unwrap_or(0);
615        let rhs_payload_rank = rhs_axis_classes
616            .iter()
617            .copied()
618            .max()
619            .map(|value| value + 1)
620            .unwrap_or(0);
621        let rhs_offset = lhs_payload_rank;
622        let mut parent: Vec<usize> = (0..lhs_payload_rank + rhs_payload_rank).collect();
623
624        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
625            union(
626                &mut parent,
627                lhs_axis_classes[lhs_axis],
628                rhs_offset + rhs_axis_classes[rhs_axis],
629            );
630        }
631
632        let mut lhs_contracted = vec![false; lhs_axis_classes.len()];
633        for &axis in axes_a {
634            lhs_contracted[axis] = true;
635        }
636        let mut rhs_contracted = vec![false; rhs_axis_classes.len()];
637        for &axis in axes_b {
638            rhs_contracted[axis] = true;
639        }
640
641        let mut root_to_class = std::collections::HashMap::new();
642        let mut next_class = 0usize;
643        let mut axis_classes = Vec::new();
644
645        for (axis, &class_id) in lhs_axis_classes.iter().enumerate() {
646            if !lhs_contracted[axis] {
647                let root = find(&mut parent, class_id);
648                let class = *root_to_class.entry(root).or_insert_with(|| {
649                    let value = next_class;
650                    next_class += 1;
651                    value
652                });
653                axis_classes.push(class);
654            }
655        }
656        for (axis, &class_id) in rhs_axis_classes.iter().enumerate() {
657            if !rhs_contracted[axis] {
658                let root = find(&mut parent, rhs_offset + class_id);
659                let class = *root_to_class.entry(root).or_insert_with(|| {
660                    let value = next_class;
661                    next_class += 1;
662                    value
663                });
664                axis_classes.push(class);
665            }
666        }
667
668        axis_classes
669    }
670
671    fn scale_subscripts(rank: usize) -> Result<EinsumSubscripts> {
672        let ids: Vec<usize> = (0..rank).collect();
673        Self::einsum_subscripts_from_usize_ids(&[ids.clone(), Vec::new()], &ids)
674    }
675
676    fn validate_indices(indices: &[DynIndex]) -> Result<()> {
677        let mut seen = HashSet::new();
678        for idx in indices {
679            anyhow::ensure!(
680                seen.insert(idx.clone()),
681                "Tensor indices must all be unique"
682            );
683        }
684        Ok(())
685    }
686
687    fn validate_diag_dims(dims: &[usize]) -> Result<()> {
688        if !dims.is_empty() {
689            let first_dim = dims[0];
690            for (i, &dim) in dims.iter().enumerate() {
691                anyhow::ensure!(
692                    dim == first_dim,
693                    "DiagTensor requires all indices to have the same dimension, but dims[{i}] = {dim} != dims[0] = {first_dim}"
694                );
695            }
696        }
697        Ok(())
698    }
699
    /// Convert `storage` into a native tensor with logical dimensions `dims`
    /// by delegating to `storage_to_native_tensor`.
    fn seed_native_payload(storage: &Storage, dims: &[usize]) -> Result<NativeTensor> {
        storage_to_native_tensor(storage, dims)
    }
703
704    fn empty_eager_cache() -> Arc<OnceLock<Arc<EagerTensor>>> {
705        Arc::new(OnceLock::new())
706    }
707
708    fn eager_cache_with(inner: EagerTensor) -> Arc<OnceLock<Arc<EagerTensor>>> {
709        let cache = Arc::new(OnceLock::new());
710        let _ = cache.set(Arc::new(inner));
711        cache
712    }
713
714    fn compact_payload_inner(&self) -> Result<EagerTensor> {
715        Ok(EagerTensor::from_tensor_in(
716            storage_payload_native(self.storage.materialize(self.indices.len())?.as_ref())?,
717            default_eager_ctx(),
718        ))
719    }
720
721    fn tracked_compact_payload_value(&self) -> Option<&StructuredAdValue> {
722        self.structured_ad.as_deref()
723    }
724
725    fn compact_payload_is_logical_dense(&self, payload_dims: &[usize]) -> bool {
726        self.storage.axis_classes() == Self::dense_axis_classes(self.indices.len())
727            && payload_dims == self.dims()
728    }
729
730    fn uses_tracked_compact_storage(&self) -> bool {
731        self.tracked_compact_payload_value()
732            .is_some_and(|value| !self.compact_payload_is_logical_dense(&value.payload_dims))
733    }
734
735    fn ensure_shape_packing_preserves_ad(&self, op_name: &str) -> Result<()> {
736        anyhow::ensure!(
737            !self.uses_tracked_compact_storage(),
738            "{op_name}: structured AD tensors with compact storage are not supported because materializing compact storage would detach gradients"
739        );
740        Ok(())
741    }
742
743    fn operand_indices_for_contraction(&self, conjugate: bool) -> Vec<DynIndex> {
744        if conjugate {
745            self.indices.iter().map(|index| index.conj()).collect()
746        } else {
747            self.indices.clone()
748        }
749    }
750
751    fn build_binary_contraction_labels(
752        lhs_rank: usize,
753        axes_a: &[usize],
754        rhs_rank: usize,
755        axes_b: &[usize],
756    ) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> {
757        anyhow::ensure!(
758            axes_a.len() == axes_b.len(),
759            "contract axis length mismatch: lhs {:?}, rhs {:?}",
760            axes_a,
761            axes_b
762        );
763
764        let mut lhs_ids = vec![usize::MAX; lhs_rank];
765        let mut rhs_ids = vec![usize::MAX; rhs_rank];
766        let mut next_id = 0usize;
767
768        let mut seen_lhs = vec![false; lhs_rank];
769        let mut seen_rhs = vec![false; rhs_rank];
770
771        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
772            anyhow::ensure!(
773                lhs_axis < lhs_rank,
774                "lhs contract axis {lhs_axis} out of range"
775            );
776            anyhow::ensure!(
777                rhs_axis < rhs_rank,
778                "rhs contract axis {rhs_axis} out of range"
779            );
780            anyhow::ensure!(
781                !seen_lhs[lhs_axis],
782                "duplicate lhs contract axis {lhs_axis}"
783            );
784            anyhow::ensure!(
785                !seen_rhs[rhs_axis],
786                "duplicate rhs contract axis {rhs_axis}"
787            );
788            seen_lhs[lhs_axis] = true;
789            seen_rhs[rhs_axis] = true;
790            lhs_ids[lhs_axis] = next_id;
791            rhs_ids[rhs_axis] = next_id;
792            next_id += 1;
793        }
794
795        let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
796        for id in &mut lhs_ids {
797            if *id == usize::MAX {
798                *id = next_id;
799                output_ids.push(next_id);
800                next_id += 1;
801            }
802        }
803        for id in &mut rhs_ids {
804            if *id == usize::MAX {
805                *id = next_id;
806                output_ids.push(next_id);
807                next_id += 1;
808            }
809        }
810
811        Ok((lhs_ids, rhs_ids, output_ids))
812    }
813
    /// Build einsum subscripts from per-operand root labels (`input_roots`)
    /// and the output's root labels (`output_roots`).
    ///
    /// Thin wrapper over `einsum_subscripts_from_usize_ids`; any error from
    /// that function is returned unchanged.
    fn build_payload_einsum_subscripts(
        input_roots: &[Vec<usize>],
        output_roots: &[usize],
    ) -> Result<EinsumSubscripts> {
        Self::einsum_subscripts_from_usize_ids(input_roots, output_roots)
    }
820
821    fn normalize_eager_payload_for_roots(
822        payload: &EagerTensor,
823        roots: &[usize],
824    ) -> Result<(Option<EagerTensor>, Vec<usize>)> {
825        anyhow::ensure!(
826            payload.data().shape().len() == roots.len(),
827            "payload rank {} does not match root label count {}",
828            payload.data().shape().len(),
829            roots.len()
830        );
831
832        let mut current_payload = None;
833        let mut current_roots = roots.to_vec();
834        while let Some((axis_a, axis_b)) = Self::first_duplicate_pair(&current_roots) {
835            let source = current_payload.as_ref().unwrap_or(payload);
836            current_payload = Some(source.extract_diag(axis_a, axis_b)?);
837            current_roots.remove(axis_b);
838        }
839
840        Ok((current_payload, current_roots))
841    }
842
843    fn first_duplicate_pair(values: &[usize]) -> Option<(usize, usize)> {
844        let mut first_axis_by_value = std::collections::HashMap::new();
845        for (axis, &value) in values.iter().enumerate() {
846            if let Some(&first_axis) = first_axis_by_value.get(&value) {
847                return Some((first_axis, axis));
848            }
849            first_axis_by_value.insert(value, axis);
850        }
851        None
852    }
853
854    fn binary_structured_contraction_plan(
855        &self,
856        other: &Self,
857        axes_a: &[usize],
858        axes_b: &[usize],
859    ) -> Result<(StructuredContractionPlan, Vec<Vec<usize>>, Vec<usize>)> {
860        let (lhs_labels, rhs_labels, output_labels) = Self::build_binary_contraction_labels(
861            self.indices.len(),
862            axes_a,
863            other.indices.len(),
864            axes_b,
865        )?;
866        let operands = vec![
867            OperandLayout::new(self.dims(), self.storage.axis_classes().to_vec())?,
868            OperandLayout::new(other.dims(), other.storage.axis_classes().to_vec())?,
869        ];
870        let spec = StructuredContractionSpec {
871            input_labels: vec![lhs_labels, rhs_labels],
872            output_labels,
873            retained_labels: Default::default(),
874        };
875        let plan = StructuredContractionPlan::new(&operands, &spec)?;
876        Ok((plan, spec.input_labels, spec.output_labels))
877    }
878
879    fn from_structured_payload_inner(
880        indices: Vec<DynIndex>,
881        payload_inner: EagerTensor,
882        payload_dims: Vec<usize>,
883        axis_classes: Vec<usize>,
884    ) -> Result<Self> {
885        Self::validate_indices(&indices)?;
886        if payload_inner.data().shape() != payload_dims {
887            return Err(anyhow::anyhow!(
888                "structured payload dims {:?} do not match planned payload dims {:?}",
889                payload_inner.data().shape(),
890                payload_dims
891            ));
892        }
893        let storage = storage_from_payload_native(
894            payload_inner.data().clone(),
895            &payload_dims,
896            axis_classes.clone(),
897        )?;
898        Self::validate_storage_matches_indices(&indices, &storage)?;
899        Ok(Self {
900            indices,
901            storage: TensorDynLenStorage::from_storage(Arc::new(storage)),
902            structured_ad: Some(Arc::new(StructuredAdValue {
903                payload: Arc::new(payload_inner),
904                payload_dims,
905                axis_classes,
906            })),
907            eager_cache: Self::empty_eager_cache(),
908        })
909    }
910
911    fn contract_structured_payloads(
912        &self,
913        other: &Self,
914        result_indices: Vec<DynIndex>,
915        axes_a: &[usize],
916        axes_b: &[usize],
917    ) -> Result<Self> {
918        let (plan, _, _) = self.binary_structured_contraction_plan(other, axes_a, axes_b)?;
919        let lhs_roots = plan.operand_plans[0].class_roots.clone();
920        let rhs_roots = plan.operand_plans[1].class_roots.clone();
921        let scalar_multiply =
922            lhs_roots.is_empty() && rhs_roots.is_empty() && plan.output_payload_roots.is_empty();
923
924        if let (Some(lhs_ad), Some(rhs_ad)) = (
925            self.tracked_compact_payload_value(),
926            other.tracked_compact_payload_value(),
927        ) {
928            if lhs_ad.payload.data().dtype() != rhs_ad.payload.data().dtype() {
929                return Err(anyhow::anyhow!(
930                    "structured AD contraction requires matching payload dtypes"
931                ));
932            }
933            let (lhs_normalized, lhs_labels) =
934                Self::normalize_eager_payload_for_roots(lhs_ad.payload.as_ref(), &lhs_roots)?;
935            let (rhs_normalized, rhs_labels) =
936                Self::normalize_eager_payload_for_roots(rhs_ad.payload.as_ref(), &rhs_roots)?;
937            let lhs_payload = lhs_normalized
938                .as_ref()
939                .unwrap_or_else(|| lhs_ad.payload.as_ref());
940            let rhs_payload = rhs_normalized
941                .as_ref()
942                .unwrap_or_else(|| rhs_ad.payload.as_ref());
943            let payload = if scalar_multiply {
944                lhs_payload.mul(rhs_payload)?
945            } else {
946                let subscripts = Self::build_payload_einsum_subscripts(
947                    &[lhs_labels, rhs_labels],
948                    &plan.output_payload_roots,
949                )?;
950                eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
951            };
952            return Self::from_structured_payload_inner(
953                result_indices,
954                payload,
955                plan.output_payload_dims,
956                plan.output_axis_classes,
957            );
958        }
959
960        if self.tracked_compact_payload_value().is_some()
961            || other.tracked_compact_payload_value().is_some()
962        {
963            let lhs_owned = if self.tracked_compact_payload_value().is_some() {
964                None
965            } else {
966                Some(self.compact_payload_inner()?)
967            };
968            let rhs_owned = if other.tracked_compact_payload_value().is_some() {
969                None
970            } else {
971                Some(other.compact_payload_inner()?)
972            };
973            let lhs = if let Some(value) = self.tracked_compact_payload_value() {
974                value.payload.as_ref()
975            } else {
976                lhs_owned
977                    .as_ref()
978                    .ok_or_else(|| anyhow::anyhow!("missing untracked left compact payload"))?
979            };
980            let rhs = if let Some(value) = other.tracked_compact_payload_value() {
981                value.payload.as_ref()
982            } else {
983                rhs_owned
984                    .as_ref()
985                    .ok_or_else(|| anyhow::anyhow!("missing untracked right compact payload"))?
986            };
987            if lhs.data().dtype() != rhs.data().dtype() {
988                return Err(anyhow::anyhow!(
989                    "structured AD contraction requires matching payload dtypes"
990                ));
991            }
992            let (lhs_normalized, lhs_labels) =
993                Self::normalize_eager_payload_for_roots(lhs, &lhs_roots)?;
994            let (rhs_normalized, rhs_labels) =
995                Self::normalize_eager_payload_for_roots(rhs, &rhs_roots)?;
996            let lhs_payload = lhs_normalized.as_ref().unwrap_or(lhs);
997            let rhs_payload = rhs_normalized.as_ref().unwrap_or(rhs);
998            let payload = if scalar_multiply {
999                lhs_payload.mul(rhs_payload)?
1000            } else {
1001                let subscripts = Self::build_payload_einsum_subscripts(
1002                    &[lhs_labels, rhs_labels],
1003                    &plan.output_payload_roots,
1004                )?;
1005                eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
1006            };
1007            return Self::from_structured_payload_inner(
1008                result_indices,
1009                payload,
1010                plan.output_payload_dims,
1011                plan.output_axis_classes,
1012            );
1013        }
1014
1015        let lhs_storage = self.storage.materialize(self.indices.len())?;
1016        let rhs_storage = other.storage.materialize(other.indices.len())?;
1017        let lhs = storage_payload_native_read_input(lhs_storage.as_ref())?;
1018        let rhs = storage_payload_native_read_input(rhs_storage.as_ref())?;
1019        if lhs.dtype() != rhs.dtype() {
1020            return Err(anyhow::anyhow!(
1021                "structured payload contraction requires matching payload dtypes"
1022            ));
1023        }
1024        let (lhs, lhs_labels) = normalize_payload_read_for_roots(lhs, &lhs_roots)?;
1025        let (rhs, rhs_labels) = normalize_payload_read_for_roots(rhs, &rhs_roots)?;
1026        let payload = tensor4all_tensorbackend::einsum_native_tensor_reads(
1027            &[(&lhs, lhs_labels.as_slice()), (&rhs, rhs_labels.as_slice())],
1028            &plan.output_payload_roots,
1029        )?;
1030        let storage = storage_from_payload_native(
1031            payload,
1032            &plan.output_payload_dims,
1033            plan.output_axis_classes,
1034        )?;
1035        Self::from_storage(result_indices, Arc::new(storage))
1036    }
1037
1038    fn should_use_structured_payload_contract(&self, other: &Self) -> bool {
1039        let same_payload_dtype = self.storage.is_f64() == other.storage.is_f64()
1040            && self.storage.is_complex() == other.storage.is_complex();
1041        same_payload_dtype
1042            && (self.tracked_compact_payload_value().is_some()
1043                || other.tracked_compact_payload_value().is_some()
1044                || self.storage.axis_classes() != Self::dense_axis_classes(self.indices.len())
1045                || other.storage.axis_classes() != Self::dense_axis_classes(other.indices.len()))
1046    }
1047
1048    fn storage_from_native_with_axis_classes(
1049        native: &NativeTensor,
1050        axis_classes: &[usize],
1051        logical_rank: usize,
1052    ) -> Result<Storage> {
1053        if Self::is_diag_axis_classes(axis_classes) {
1054            match native.dtype() {
1055                DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => {
1056                    Storage::from_diag_col_major(
1057                        native_tensor_primal_to_diag_f64(native)?,
1058                        logical_rank,
1059                    )
1060                }
1061                DType::C32 | DType::C64 => Storage::from_diag_col_major(
1062                    native_tensor_primal_to_diag_c64(native)?,
1063                    logical_rank,
1064                ),
1065            }
1066        } else {
1067            native_tensor_primal_to_storage(native)
1068        }
1069    }
1070
1071    fn dense_selected_diag_payload<T: TensorElement + Copy + Zero>(
1072        payload: Vec<T>,
1073        kept_dims: &[usize],
1074        selected_positions: &[usize],
1075    ) -> Vec<T> {
1076        let output_len = kept_dims.iter().product::<usize>();
1077        let mut data = vec![T::zero(); output_len];
1078        if output_len == 0 {
1079            return data;
1080        }
1081
1082        let Some((&first_position, rest)) = selected_positions.split_first() else {
1083            return data;
1084        };
1085        if rest.iter().any(|&position| position != first_position) {
1086            return data;
1087        }
1088
1089        let value = payload[first_position];
1090        if kept_dims.is_empty() {
1091            data[0] = value;
1092            return data;
1093        }
1094
1095        let mut offset = 0usize;
1096        let mut stride = 1usize;
1097        for &dim in kept_dims {
1098            offset += first_position * stride;
1099            stride *= dim;
1100        }
1101        data[offset] = value;
1102        data
1103    }
1104
1105    fn select_diag_indices(
1106        &self,
1107        kept_indices: Vec<DynIndex>,
1108        kept_dims: Vec<usize>,
1109        positions: &[usize],
1110    ) -> Result<Self> {
1111        if self.storage.is_f64() {
1112            let storage = self.storage.materialize(self.indices.len())?;
1113            let payload = storage
1114                .payload_f64_col_major_vec()
1115                .map_err(anyhow::Error::msg)?;
1116            let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1117            Self::from_dense(kept_indices, data)
1118        } else if self.storage.is_c64() {
1119            let storage = self.storage.materialize(self.indices.len())?;
1120            let payload = storage
1121                .payload_c64_col_major_vec()
1122                .map_err(anyhow::Error::msg)?;
1123            let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1124            Self::from_dense(kept_indices, data)
1125        } else {
1126            Err(anyhow::anyhow!("unsupported diagonal storage scalar type"))
1127        }
1128    }
1129
1130    fn col_major_strides(dims: &[usize]) -> Result<Vec<isize>> {
1131        let mut strides = Vec::with_capacity(dims.len());
1132        let mut stride = 1isize;
1133        for &dim in dims {
1134            strides.push(stride);
1135            let dim = isize::try_from(dim)
1136                .map_err(|_| anyhow::anyhow!("dimension does not fit in isize"))?;
1137            stride = stride
1138                .checked_mul(dim)
1139                .ok_or_else(|| anyhow::anyhow!("column-major stride overflow"))?;
1140        }
1141        Ok(strides)
1142    }
1143
1144    fn zero_structured_selection<T>(
1145        kept_indices: Vec<DynIndex>,
1146        kept_dims: &[usize],
1147    ) -> Result<Self>
1148    where
1149        T: TensorElement + Zero,
1150    {
1151        let output_len = checked_product(kept_dims)?;
1152        Self::from_dense(kept_indices, vec![T::zero(); output_len])
1153    }
1154
1155    fn select_structured_indices_typed<T>(
1156        &self,
1157        payload: Vec<T>,
1158        kept_axes: &[usize],
1159        kept_indices: Vec<DynIndex>,
1160        kept_dims: Vec<usize>,
1161        selected_axes: &[usize],
1162        positions: &[usize],
1163    ) -> Result<Self>
1164    where
1165        T: TensorElement + StorageScalar + Zero,
1166    {
1167        let payload_dims = self.storage.payload_dims();
1168        let axis_classes = self.storage.axis_classes();
1169        let payload_rank = payload_dims.len();
1170        let mut selected_class_positions = vec![None; payload_rank];
1171
1172        for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1173            let class_id = axis_classes[axis];
1174            if let Some(existing) = selected_class_positions[class_id] {
1175                if existing != position {
1176                    return Self::zero_structured_selection::<T>(kept_indices, &kept_dims);
1177                }
1178            } else {
1179                selected_class_positions[class_id] = Some(position);
1180            }
1181        }
1182
1183        let selected_class_kept = kept_axes
1184            .iter()
1185            .any(|&axis| selected_class_positions[axis_classes[axis]].is_some());
1186        if selected_class_kept {
1187            return self.select_structured_indices_dense(
1188                payload,
1189                kept_axes,
1190                kept_indices,
1191                kept_dims,
1192                &selected_class_positions,
1193            );
1194        }
1195
1196        let mut old_to_new_class = vec![None; payload_rank];
1197        let mut output_payload_dims = Vec::new();
1198        let mut output_axis_classes = Vec::with_capacity(kept_axes.len());
1199        for &axis in kept_axes {
1200            let class_id = axis_classes[axis];
1201            let new_class = match old_to_new_class[class_id] {
1202                Some(new_class) => new_class,
1203                None => {
1204                    let new_class = output_payload_dims.len();
1205                    old_to_new_class[class_id] = Some(new_class);
1206                    output_payload_dims.push(payload_dims[class_id]);
1207                    new_class
1208                }
1209            };
1210            output_axis_classes.push(new_class);
1211        }
1212
1213        let output_len = checked_product(&output_payload_dims)?;
1214        let mut output_payload = Vec::with_capacity(output_len);
1215        for linear in 0..output_len {
1216            let output_payload_index = decode_col_major_linear(linear, &output_payload_dims)?;
1217            let mut input_payload_index = vec![0usize; payload_rank];
1218            for class_id in 0..payload_rank {
1219                input_payload_index[class_id] =
1220                    if let Some(position) = selected_class_positions[class_id] {
1221                        position
1222                    } else if let Some(new_class) = old_to_new_class[class_id] {
1223                        output_payload_index[new_class]
1224                    } else {
1225                        return Err(anyhow::anyhow!(
1226                            "structured payload class {class_id} is neither selected nor kept"
1227                        ));
1228                    };
1229            }
1230            let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1231            output_payload.push(payload[input_linear]);
1232        }
1233
1234        let output_strides = Self::col_major_strides(&output_payload_dims)?;
1235        let storage = Storage::new_structured(
1236            output_payload,
1237            output_payload_dims,
1238            output_strides,
1239            output_axis_classes,
1240        )?;
1241        Self::from_storage(kept_indices, Arc::new(storage))
1242    }
1243
1244    fn select_structured_indices_dense<T>(
1245        &self,
1246        payload: Vec<T>,
1247        kept_axes: &[usize],
1248        kept_indices: Vec<DynIndex>,
1249        kept_dims: Vec<usize>,
1250        selected_class_positions: &[Option<usize>],
1251    ) -> Result<Self>
1252    where
1253        T: TensorElement + Zero,
1254    {
1255        let payload_dims = self.storage.payload_dims();
1256        let axis_classes = self.storage.axis_classes();
1257        let output_len = checked_product(&kept_dims)?;
1258        let mut output = Vec::with_capacity(output_len);
1259
1260        for linear in 0..output_len {
1261            let kept_position = decode_col_major_linear(linear, &kept_dims)?;
1262            let mut input_payload_index = selected_class_positions.to_vec();
1263            let mut is_structural_zero = false;
1264
1265            for (&axis, &position) in kept_axes.iter().zip(kept_position.iter()) {
1266                let class_id = axis_classes[axis];
1267                match input_payload_index[class_id] {
1268                    Some(existing) if existing != position => {
1269                        is_structural_zero = true;
1270                        break;
1271                    }
1272                    Some(_) => {}
1273                    None => input_payload_index[class_id] = Some(position),
1274                }
1275            }
1276
1277            if is_structural_zero {
1278                output.push(T::zero());
1279                continue;
1280            }
1281
1282            let input_payload_index = input_payload_index
1283                .into_iter()
1284                .enumerate()
1285                .map(|(class_id, position)| {
1286                    position.ok_or_else(|| {
1287                        anyhow::anyhow!(
1288                            "structured payload class {class_id} is neither selected nor kept"
1289                        )
1290                    })
1291                })
1292                .collect::<Result<Vec<_>>>()?;
1293            let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1294            output.push(payload[input_linear]);
1295        }
1296
1297        Self::from_dense(kept_indices, output)
1298    }
1299
1300    fn select_structured_indices(
1301        &self,
1302        kept_axes: &[usize],
1303        kept_indices: Vec<DynIndex>,
1304        kept_dims: Vec<usize>,
1305        selected_axes: &[usize],
1306        positions: &[usize],
1307    ) -> Result<Self> {
1308        if self.storage.is_f64() {
1309            let storage = self.storage.materialize(self.indices.len())?;
1310            let payload = storage
1311                .payload_f64_col_major_vec()
1312                .map_err(anyhow::Error::msg)?;
1313            self.select_structured_indices_typed(
1314                payload,
1315                kept_axes,
1316                kept_indices,
1317                kept_dims,
1318                selected_axes,
1319                positions,
1320            )
1321        } else if self.storage.is_c64() {
1322            let storage = self.storage.materialize(self.indices.len())?;
1323            let payload = storage
1324                .payload_c64_col_major_vec()
1325                .map_err(anyhow::Error::msg)?;
1326            self.select_structured_indices_typed(
1327                payload,
1328                kept_axes,
1329                kept_indices,
1330                kept_dims,
1331                selected_axes,
1332                positions,
1333            )
1334        } else {
1335            Err(anyhow::anyhow!(
1336                "unsupported structured storage scalar type"
1337            ))
1338        }
1339    }
1340
1341    fn validate_storage_matches_indices(indices: &[DynIndex], storage: &Storage) -> Result<()> {
1342        let dims = Self::expected_dims_from_indices(indices);
1343        let storage_dims = storage.logical_dims();
1344        if storage_dims != dims {
1345            return Err(anyhow::anyhow!(
1346                "storage logical dims {:?} do not match indices dims {:?}",
1347                storage_dims,
1348                dims
1349            ));
1350        }
1351        if storage.is_diag() {
1352            Self::validate_diag_dims(&dims)?;
1353        }
1354        Ok(())
1355    }
1356
1357    fn try_materialized_inner(&self) -> Result<&EagerTensor> {
1358        if let Some(value) = self.tracked_compact_payload_value() {
1359            if self.compact_payload_is_logical_dense(&value.payload_dims) {
1360                return Ok(value.payload.as_ref());
1361            }
1362        }
1363        if let Some(inner) = self.storage.eager() {
1364            return Ok(inner);
1365        }
1366        if self.eager_cache.get().is_none() {
1367            let dims = self.dims();
1368            let native = profile_pairwise_contract_section("materialize_storage_to_native", || {
1369                let storage = self.storage.materialize(self.indices.len())?;
1370                Self::seed_native_payload(storage.as_ref(), &dims)
1371            })
1372            .context("TensorDynLen materialization failed")?;
1373            record_pairwise_contract_profile_bytes(
1374                "materialize_storage_to_native",
1375                native_tensor_profile_bytes(&native),
1376            );
1377            let _ = self.eager_cache.set(Arc::new(EagerTensor::from_tensor_in(
1378                native,
1379                default_eager_ctx(),
1380            )));
1381        }
1382        self.eager_cache
1383            .get()
1384            .map(|inner| inner.as_ref())
1385            .ok_or_else(|| {
1386                anyhow::anyhow!("TensorDynLen materialization cache was not initialized")
1387            })
1388    }
1389
1390    pub(crate) fn as_inner(&self) -> Result<&EagerTensor> {
1391        self.try_materialized_inner()
1392    }
1393
1394    /// Compute dims from `indices` order.
1395    #[inline]
1396    fn expected_dims_from_indices(indices: &[DynIndex]) -> Vec<usize> {
1397        indices.iter().map(|idx| idx.dim()).collect()
1398    }
1399
1400    /// Get dims in the current `indices` order.
1401    ///
1402    /// This is computed on-demand from `indices` (single source of truth).
1403    ///
1404    /// # Examples
1405    ///
1406    /// ```
1407    /// use tensor4all_core::{DynIndex, TensorDynLen};
1408    ///
1409    /// let i = DynIndex::new_dyn(2);
1410    /// let j = DynIndex::new_dyn(3);
1411    /// let k = DynIndex::new_dyn(4);
1412    /// let t = TensorDynLen::from_dense(
1413    ///     vec![i, j, k],
1414    ///     vec![0.0; 24],
1415    /// ).unwrap();
1416    /// assert_eq!(t.dims(), vec![2, 3, 4]);
1417    /// ```
1418    pub fn dims(&self) -> Vec<usize> {
1419        Self::expected_dims_from_indices(&self.indices)
1420    }
1421
1422    /// Select fixed coordinates for tensor indices and drop those axes.
1423    ///
1424    /// The `selected_indices` slice identifies tensor axes by index identity,
1425    /// and `positions` gives the zero-based coordinate to take on each
1426    /// selected axis. Unselected indices are preserved in their original order.
1427    ///
1428    /// # Arguments
1429    ///
1430    /// * `selected_indices` - Indices to fix and remove from the result. Each
1431    ///   index must appear exactly once in this tensor.
1432    /// * `positions` - Coordinates for `selected_indices`. Each coordinate must
1433    ///   be less than the corresponding index dimension.
1434    ///
1435    /// # Returns
1436    ///
1437    /// A tensor over the unselected indices. Selecting no indices returns a
1438    /// clone of the original tensor. Selecting all indices returns a rank-0
1439    /// scalar tensor. Diagonal and structured tensors are sliced from their
1440    /// compact payload without materializing the original full tensor; the
1441    /// result keeps structured storage when the remaining logical axes can
1442    /// still be represented by axis classes.
1443    ///
1444    /// # Errors
1445    ///
1446    /// Returns an error if the argument lengths differ, a selected index is not
1447    /// present, a selected index is duplicated, or a coordinate is out of range.
1448    ///
1449    /// # Examples
1450    ///
1451    /// ```
1452    /// use tensor4all_core::{DynIndex, TensorDynLen};
1453    ///
1454    /// let i = DynIndex::new_dyn(2);
1455    /// let j = DynIndex::new_dyn(3);
1456    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1457    /// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
1458    ///
1459    /// let selected = tensor.select_indices(&[j], &[1]).unwrap();
1460    /// assert_eq!(selected.dims(), vec![2]);
1461    /// assert_eq!(selected.to_vec::<f64>().unwrap(), vec![3.0, 4.0]);
1462    /// ```
1463    pub fn select_indices(
1464        &self,
1465        selected_indices: &[DynIndex],
1466        positions: &[usize],
1467    ) -> Result<Self> {
1468        if selected_indices.len() != positions.len() {
1469            return Err(anyhow::anyhow!(
1470                "selected_indices length {} does not match positions length {}",
1471                selected_indices.len(),
1472                positions.len()
1473            ));
1474        }
1475        if selected_indices.is_empty() {
1476            return Ok(self.clone());
1477        }
1478
1479        let mut selected_axes = Vec::with_capacity(selected_indices.len());
1480        let mut seen_axes = HashSet::with_capacity(selected_indices.len());
1481        for (selected, &position) in selected_indices.iter().zip(positions.iter()) {
1482            let axis = self
1483                .indices
1484                .iter()
1485                .position(|index| index == selected)
1486                .ok_or_else(|| anyhow::anyhow!("selected index is not present in tensor"))?;
1487            if !seen_axes.insert(axis) {
1488                return Err(anyhow::anyhow!("selected index appears more than once"));
1489            }
1490            let dim = self.indices[axis].dim();
1491            if position >= dim {
1492                return Err(anyhow::anyhow!(
1493                    "selected coordinate {position} is out of range for axis {axis} with dim {dim}"
1494                ));
1495            }
1496            selected_axes.push(axis);
1497        }
1498
1499        let kept_axes = self
1500            .indices
1501            .iter()
1502            .enumerate()
1503            .filter(|(axis, _)| !seen_axes.contains(axis))
1504            .map(|(axis, _)| axis)
1505            .collect::<Vec<_>>();
1506        let kept_indices = kept_axes
1507            .iter()
1508            .map(|&axis| self.indices[axis].clone())
1509            .collect::<Vec<_>>();
1510        let kept_dims = kept_axes
1511            .iter()
1512            .map(|&axis| self.indices[axis].dim())
1513            .collect::<Vec<_>>();
1514
1515        if self.storage.storage_kind() == StorageKind::Diagonal {
1516            return self.select_diag_indices(kept_indices, kept_dims, positions);
1517        }
1518        if self.storage.storage_kind() == StorageKind::Structured {
1519            return self.select_structured_indices(
1520                &kept_axes,
1521                kept_indices,
1522                kept_dims,
1523                &selected_axes,
1524                positions,
1525            );
1526        }
1527        if self.storage.storage_kind() != StorageKind::Dense {
1528            return Err(anyhow::anyhow!(
1529                "select_indices got unsupported storage kind {:?}",
1530                self.storage.storage_kind()
1531            ));
1532        }
1533
1534        let rank = self.indices.len();
1535        let mut starts = vec![0_i64; rank];
1536        let mut slice_sizes = self.dims();
1537        for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1538            starts[axis] = i64::try_from(position)
1539                .map_err(|_| anyhow::anyhow!("selected coordinate does not fit in i64"))?;
1540            slice_sizes[axis] = 1;
1541        }
1542
1543        let starts_tensor = EagerTensor::from_tensor_in(
1544            NativeTensor::from_vec_col_major(vec![rank], starts),
1545            default_eager_ctx(),
1546        );
1547        let sliced = self
1548            .try_materialized_inner()?
1549            .dynamic_slice(&starts_tensor, &slice_sizes)?;
1550        Self::from_inner(kept_indices, sliced.reshape(&kept_dims)?)
1551    }
1552
1553    /// Stack tensors along a newly inserted index.
1554    ///
1555    /// Each input must have exactly the same index order and dimensions. The
1556    /// `new_index` dimension must match the number of input tensors. The
1557    /// `axis` argument follows tenferro/PyTorch-style insertion semantics:
1558    /// `0` inserts before the first existing axis and `-1` appends a trailing
1559    /// axis. Use `axis = -1` for batched contractions because tenferro uses
1560    /// trailing batch dimensions as the canonical batched-GEMM layout.
1561    ///
1562    /// # Errors
1563    ///
1564    /// Returns an error if no tensors are provided, the new index dimension
1565    /// does not match the number of tensors, an input has a different index
1566    /// order, `axis` is outside the valid insertion range, or a tracked
1567    /// structured-AD tensor uses compact storage that would need dense
1568    /// materialization.
1569    ///
1570    /// # Examples
1571    ///
1572    /// ```
1573    /// use tensor4all_core::{DynIndex, TensorDynLen};
1574    ///
1575    /// let i = DynIndex::new_dyn(2);
1576    /// let batch = DynIndex::new_dyn(2);
1577    /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0_f64, 2.0]).unwrap();
1578    /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0_f64, 4.0]).unwrap();
1579    ///
1580    /// let stacked = TensorDynLen::stack_along_new_index(&[&a, &b], batch.clone(), -1).unwrap();
1581    ///
1582    /// assert_eq!(stacked.indices(), &[i, batch]);
1583    /// assert_eq!(stacked.to_vec::<f64>().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
1584    /// ```
1585    pub fn stack_along_new_index(
1586        tensors: &[&Self],
1587        new_index: DynIndex,
1588        axis: isize,
1589    ) -> Result<Self> {
1590        let first = tensors
1591            .first()
1592            .copied()
1593            .ok_or_else(|| anyhow::anyhow!("stack_along_new_index requires at least one tensor"))?;
1594        anyhow::ensure!(
1595            new_index.dim() == tensors.len(),
1596            "stack_along_new_index: new index dim {} does not match tensor count {}",
1597            new_index.dim(),
1598            tensors.len()
1599        );
1600
1601        let base_indices = first.indices.clone();
1602        for tensor in tensors.iter().copied().skip(1) {
1603            anyhow::ensure!(
1604                tensor.indices == base_indices,
1605                "stack_along_new_index: input tensors must have identical index order"
1606            );
1607        }
1608        for &tensor in tensors {
1609            tensor.ensure_shape_packing_preserves_ad("stack_along_new_index")?;
1610        }
1611
1612        let insert_axis =
1613            Self::normalize_insert_axis("stack_along_new_index", axis, base_indices.len())?;
1614        let mut result_indices = base_indices;
1615        result_indices.insert(insert_axis, new_index);
1616
1617        let inner_refs = tensors
1618            .iter()
1619            .map(|tensor| tensor.try_materialized_inner())
1620            .collect::<Result<Vec<_>>>()?;
1621        let stacked = EagerTensor::stack(&inner_refs, axis)?;
1622        Self::from_inner(result_indices, stacked)
1623    }
1624
1625    /// Select positions along one index and replace it with a new index.
1626    ///
1627    /// This is the retained-axis counterpart to [`Self::select_indices`]:
1628    /// instead of fixing one coordinate and removing the index, it gathers a
1629    /// list of positions and keeps the gathered axis under `target_index`.
1630    /// Repeated positions are allowed; reverse-mode AD accumulates repeated
1631    /// cotangents through tenferro's scatter-add gather transpose.
1632    ///
1633    /// # Errors
1634    ///
1635    /// Returns an error if `source_index` is not present, `target_index.dim()`
1636    /// differs from `positions.len()`, or any position is out of range for the
1637    /// source index. A tracked structured-AD tensor with compact storage is
1638    /// also rejected because dense materialization would detach gradients.
1639    ///
1640    /// # Examples
1641    ///
1642    /// ```
1643    /// use tensor4all_core::{DynIndex, TensorDynLen};
1644    ///
1645    /// let source = DynIndex::new_dyn(3);
1646    /// let target = DynIndex::new_dyn(2);
1647    /// let tensor = TensorDynLen::from_dense(
1648    ///     vec![source.clone()],
1649    ///     vec![10.0_f64, 20.0, 30.0],
1650    /// ).unwrap();
1651    ///
1652    /// let selected = tensor.index_select(&source, target.clone(), &[2, 0]).unwrap();
1653    ///
1654    /// assert_eq!(selected.indices(), &[target]);
1655    /// assert_eq!(selected.to_vec::<f64>().unwrap(), vec![30.0, 10.0]);
1656    /// ```
1657    pub fn index_select(
1658        &self,
1659        source_index: &DynIndex,
1660        target_index: DynIndex,
1661        positions: &[usize],
1662    ) -> Result<Self> {
1663        anyhow::ensure!(
1664            target_index.dim() == positions.len(),
1665            "index_select: target index dim {} does not match position count {}",
1666            target_index.dim(),
1667            positions.len()
1668        );
1669        let axis = self
1670            .indices
1671            .iter()
1672            .position(|index| index == source_index)
1673            .ok_or_else(|| anyhow::anyhow!("index_select: source index is not present"))?;
1674        let source_dim = self.indices[axis].dim();
1675        for &position in positions {
1676            anyhow::ensure!(
1677                position < source_dim,
1678                "index_select: position {position} is out of range for source dim {source_dim}"
1679            );
1680        }
1681        self.ensure_shape_packing_preserves_ad("index_select")?;
1682
1683        let axis = isize::try_from(axis)
1684            .map_err(|_| anyhow::anyhow!("index_select: axis does not fit in isize"))?;
1685        let selected = self
1686            .try_materialized_inner()?
1687            .index_select(axis, positions)?;
1688        let mut result_indices = self.indices.clone();
1689        result_indices[axis as usize] = target_index;
1690        Self::from_inner(result_indices, selected)
1691    }
1692
1693    /// Create a new tensor with dynamic rank.
1694    ///
1695    /// # Errors
1696    /// Returns an error if the storage logical dimensions do not match the
1697    /// supplied indices, if diagonal storage has unequal logical dimensions,
1698    /// or if duplicate indices are provided.
1699    ///
1700    /// # Examples
1701    ///
1702    /// ```
1703    /// use tensor4all_core::{DynIndex, TensorDynLen};
1704    /// use tensor4all_tensorbackend::Storage;
1705    /// use std::sync::Arc;
1706    ///
1707    /// let i = DynIndex::new_dyn(3);
1708    /// let storage = Arc::new(Storage::new_dense::<f64>(3).unwrap());
1709    /// let t = TensorDynLen::new(vec![i], storage).unwrap();
1710    /// assert_eq!(t.dims(), vec![3]);
1711    /// ```
1712    pub fn new(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1713        Self::from_storage(indices, storage)
1714    }
1715
1716    /// Create a new tensor with dynamic rank, automatically computing dimensions from indices.
1717    ///
1718    /// This is a convenience constructor that extracts dimensions from indices using `IndexLike::dim()`.
1719    ///
1720    /// # Errors
1721    /// Returns an error if the storage logical dimensions do not match the
1722    /// supplied indices, if diagonal storage has unequal logical dimensions,
1723    /// or if duplicate indices are provided.
1724    ///
1725    /// # Examples
1726    ///
1727    /// ```
1728    /// use tensor4all_core::{DynIndex, TensorDynLen};
1729    /// use tensor4all_tensorbackend::Storage;
1730    /// use std::sync::Arc;
1731    ///
1732    /// let i = DynIndex::new_dyn(4);
1733    /// let storage = Arc::new(Storage::new_dense::<f64>(4).unwrap());
1734    /// let t = TensorDynLen::from_indices(vec![i], storage).unwrap();
1735    /// assert_eq!(t.dims(), vec![4]);
1736    /// ```
1737    pub fn from_indices(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1738        Self::new(indices, storage)
1739    }
1740
1741    /// Create a tensor from explicit compact storage.
1742    ///
1743    /// # Examples
1744    ///
1745    /// ```
1746    /// use tensor4all_core::{DynIndex, TensorDynLen};
1747    /// use tensor4all_tensorbackend::Storage;
1748    /// use std::sync::Arc;
1749    ///
1750    /// let i = DynIndex::new_dyn(2);
1751    /// let j = DynIndex::new_dyn(2);
1752    /// let storage = Arc::new(Storage::new_diag(vec![1.0_f64, 2.0]).unwrap());
1753    /// let t = TensorDynLen::from_storage(vec![i, j], storage).unwrap();
1754    /// assert_eq!(t.dims(), vec![2, 2]);
1755    /// ```
1756    pub fn from_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1757        Self::validate_indices(&indices)?;
1758        Self::validate_storage_matches_indices(&indices, storage.as_ref())?;
1759        Ok(Self {
1760            indices,
1761            storage: TensorDynLenStorage::from_storage(storage),
1762            structured_ad: None,
1763            eager_cache: Self::empty_eager_cache(),
1764        })
1765    }
1766
1767    /// Create a tensor from explicit structured storage.
1768    ///
1769    /// This is an alias for [`TensorDynLen::from_storage`] with a name that
1770    /// emphasizes that compact structured metadata is preserved.
1771    ///
1772    /// # Errors
1773    ///
1774    /// Returns an error if the storage logical dimensions do not match the
1775    /// supplied indices, or if duplicate indices are provided.
1776    ///
1777    /// # Examples
1778    ///
1779    /// ```
1780    /// use std::sync::Arc;
1781    /// use tensor4all_core::{DynIndex, TensorDynLen};
1782    /// use tensor4all_tensorbackend::{Storage, StorageKind};
1783    ///
1784    /// let i = DynIndex::new_dyn(2);
1785    /// let j = DynIndex::new_dyn(2);
1786    /// let storage = Arc::new(Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap());
1787    /// let tensor = TensorDynLen::from_structured_storage(vec![i, j], storage).unwrap();
1788    /// assert_eq!(tensor.storage().storage_kind(), StorageKind::Diagonal);
1789    /// ```
1790    pub fn from_structured_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1791        Self::from_storage(indices, storage)
1792    }
1793
1794    /// Create a tensor from a native tenferro payload.
1795    pub(crate) fn from_native(indices: Vec<DynIndex>, native: NativeTensor) -> Result<Self> {
1796        let axis_classes = Self::dense_axis_classes(indices.len());
1797        Self::from_native_with_axis_classes(indices, native, axis_classes)
1798    }
1799
1800    pub(crate) fn from_native_with_axis_classes(
1801        indices: Vec<DynIndex>,
1802        native: NativeTensor,
1803        axis_classes: Vec<usize>,
1804    ) -> Result<Self> {
1805        Self::from_inner_with_axis_classes(
1806            indices,
1807            EagerTensor::from_tensor_in(native, default_eager_ctx()),
1808            axis_classes,
1809        )
1810    }
1811
1812    pub(crate) fn from_inner(indices: Vec<DynIndex>, inner: EagerTensor) -> Result<Self> {
1813        let axis_classes = Self::dense_axis_classes(indices.len());
1814        Self::from_inner_with_axis_classes(indices, inner, axis_classes)
1815    }
1816
1817    pub(crate) fn from_diag_inner(
1818        indices: Vec<DynIndex>,
1819        payload_inner: EagerTensor,
1820    ) -> Result<Self> {
1821        let dims = Self::expected_dims_from_indices(&indices);
1822        Self::validate_indices(&indices)?;
1823        Self::validate_diag_dims(&dims)?;
1824        Self::validate_diag_payload_len(payload_inner.data().shape().iter().product(), &dims)?;
1825        let axis_classes = Self::diag_axis_classes(dims.len());
1826        let diag_inner = payload_inner.embed_diag(0, 1)?;
1827        Self::from_inner_with_axis_classes(indices, diag_inner, axis_classes)
1828    }
1829
1830    pub(crate) fn from_inner_with_axis_classes(
1831        indices: Vec<DynIndex>,
1832        inner: EagerTensor,
1833        axis_classes: Vec<usize>,
1834    ) -> Result<Self> {
1835        let dims = profile_pairwise_contract_section("from_inner_expected_dims", || {
1836            Self::expected_dims_from_indices(&indices)
1837        });
1838        profile_pairwise_contract_section("from_inner_validate_indices", || {
1839            Self::validate_indices(&indices)
1840        })?;
1841        if dims != inner.data().shape() {
1842            return Err(anyhow::anyhow!(
1843                "native payload dims {:?} do not match indices dims {:?}",
1844                inner.data().shape(),
1845                dims
1846            ));
1847        }
1848        if Self::is_diag_axis_classes(&axis_classes) {
1849            profile_pairwise_contract_section("from_inner_validate_diag_dims", || {
1850                Self::validate_diag_dims(&dims)
1851            })?;
1852        }
1853        let (storage, eager_cache) = if axis_classes == Self::dense_axis_classes(indices.len()) {
1854            (
1855                TensorDynLenStorage::from_eager_dense(inner, indices.len()),
1856                Self::empty_eager_cache(),
1857            )
1858        } else {
1859            let storage = profile_pairwise_contract_section("from_inner_storage_snapshot", || {
1860                Self::storage_from_native_with_axis_classes(
1861                    inner.data(),
1862                    &axis_classes,
1863                    indices.len(),
1864                )
1865            })?;
1866            record_pairwise_contract_profile_bytes(
1867                "from_inner_storage_snapshot",
1868                native_tensor_profile_bytes(inner.data()),
1869            );
1870            (
1871                TensorDynLenStorage::from_storage(Arc::new(storage)),
1872                profile_pairwise_contract_section("from_inner_eager_cache", || {
1873                    Self::eager_cache_with(inner)
1874                }),
1875            )
1876        };
1877        Ok(Self {
1878            indices,
1879            storage,
1880            structured_ad: None,
1881            eager_cache,
1882        })
1883    }
1884
1885    /// Borrow the indices.
1886    pub fn indices(&self) -> &[DynIndex] {
1887        &self.indices
1888    }
1889
1890    /// Borrow the native payload.
1891    pub(crate) fn as_native(&self) -> Result<&NativeTensor> {
1892        Ok(self.try_materialized_inner()?.data())
1893    }
1894
1895    /// Enable reverse-mode AD tracking on this tensor by creating a tracked leaf.
1896    pub fn enable_grad(self) -> Result<Self> {
1897        let materialized = self.storage.materialize(self.indices.len())?;
1898        let payload = storage_payload_native(materialized.as_ref())
1899            .context("TensorDynLen::enable_grad failed")?;
1900        let payload_dims = self.storage.payload_dims().to_vec();
1901        let axis_classes = self.storage.axis_classes().to_vec();
1902        Ok(Self {
1903            indices: self.indices,
1904            storage: self.storage,
1905            structured_ad: Some(Arc::new(StructuredAdValue {
1906                payload: Arc::new(EagerTensor::requires_grad_in(payload, default_eager_ctx())),
1907                payload_dims,
1908                axis_classes,
1909            })),
1910            eager_cache: Self::empty_eager_cache(),
1911        })
1912    }
1913
1914    /// Report whether this tensor participates in gradient tracking.
1915    pub fn tracks_grad(&self) -> bool {
1916        self.structured_ad
1917            .as_ref()
1918            .is_some_and(|value| value.payload.tracks_grad())
1919            || self.storage.eager().is_some_and(EagerTensor::tracks_grad)
1920            || self
1921                .eager_cache
1922                .get()
1923                .is_some_and(|inner| inner.tracks_grad())
1924    }
1925
1926    /// Return the accumulated gradient, if one has been stored.
1927    pub fn grad(&self) -> Result<Option<Self>> {
1928        if let Some(value) = self.tracked_compact_payload_value() {
1929            return value
1930                .payload
1931                .grad()
1932                .map(|grad| {
1933                    let storage = storage_from_payload_native(
1934                        grad.as_ref().clone(),
1935                        &value.payload_dims,
1936                        value.axis_classes.clone(),
1937                    )?;
1938                    Self::from_storage(self.indices.clone(), Arc::new(storage))
1939                })
1940                .transpose();
1941        }
1942        self.try_materialized_inner()?
1943            .grad()
1944            .map(|grad| {
1945                Self::from_native_with_axis_classes(
1946                    self.indices.clone(),
1947                    grad.as_ref().clone(),
1948                    self.storage.axis_classes().to_vec(),
1949                )
1950            })
1951            .transpose()
1952    }
1953
1954    /// Clear the accumulated gradient stored for this tensor.
1955    pub fn clear_grad(&self) -> Result<()> {
1956        if let Some(value) = self.tracked_compact_payload_value() {
1957            value.payload.clear_grad();
1958        }
1959        if let Some(inner) = self.storage.eager() {
1960            inner.clear_grad();
1961        }
1962        if let Some(inner) = self.eager_cache.get() {
1963            inner.clear_grad();
1964        }
1965        Ok(())
1966    }
1967
1968    /// Run reverse-mode autodiff from this scalar tensor.
1969    pub fn backward(&self) -> Result<()> {
1970        if let Some(value) = self.tracked_compact_payload_value() {
1971            return value
1972                .payload
1973                .backward()
1974                .map(|_| ())
1975                .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"));
1976        }
1977        self.try_materialized_inner()?
1978            .backward()
1979            .map(|_| ())
1980            .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"))
1981    }
1982
1983    /// Detach this tensor from the reverse graph.
1984    pub fn detach(&self) -> Result<Self> {
1985        if self.tracked_compact_payload_value().is_some() {
1986            return Self::from_storage(
1987                self.indices.clone(),
1988                self.storage.materialize(self.indices.len())?,
1989            );
1990        }
1991        Self::from_inner_with_axis_classes(
1992            self.indices.clone(),
1993            self.try_materialized_inner()?.detach(),
1994            self.storage.axis_classes().to_vec(),
1995        )
1996    }
1997
1998    /// Check if this tensor is already in canonical form.
1999    pub fn is_simple(&self) -> bool {
2000        true
2001    }
2002
2003    /// Materialize the primal snapshot as storage.
2004    pub fn to_storage(&self) -> Result<Arc<Storage>> {
2005        self.storage.materialize(self.indices.len())
2006    }
2007
2008    /// Returns the authoritative compact storage.
2009    pub fn storage(&self) -> Arc<Storage> {
2010        self.storage
2011            .materialize(self.indices.len())
2012            .expect("TensorDynLen storage materialization failed")
2013    }
2014
2015    /// Sum all elements, returning `AnyScalar`.
2016    ///
2017    /// # Examples
2018    ///
2019    /// ```
2020    /// use tensor4all_core::{DynIndex, TensorDynLen};
2021    ///
2022    /// let i = DynIndex::new_dyn(3);
2023    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0, 3.0]).unwrap();
2024    /// let s = t.sum().unwrap();
2025    /// assert!((s.real() - 6.0).abs() < 1e-12);
2026    /// ```
2027    pub fn sum(&self) -> Result<AnyScalar> {
2028        if self.indices.is_empty() {
2029            return AnyScalar::from_tensor(self.clone());
2030        }
2031        let axes: Vec<usize> = (0..self.indices.len()).collect();
2032        let reduced = self.try_materialized_inner()?.reduce_sum(&axes)?;
2033        AnyScalar::from_tensor(Self::from_inner(Vec::new(), reduced)?)
2034    }
2035
2036    /// Extract the scalar value from a 0-dimensional tensor (or 1-element tensor).
2037    ///
2038    /// This is similar to Julia's `only()` function.
2039    ///
2040    /// # Panics
2041    ///
2042    /// Panics if the tensor has more than one element.
2043    ///
2044    /// # Example
2045    ///
2046    /// ```
2047    /// use tensor4all_core::{TensorDynLen, AnyScalar};
2048    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2049    ///
2050    /// // Create a scalar tensor (0 dimensions, 1 element)
2051    /// let indices: Vec<Index<DynId>> = vec![];
2052    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![42.0]).unwrap();
2053    ///
2054    /// assert_eq!(tensor.only().unwrap().real(), 42.0);
2055    /// ```
2056    pub fn only(&self) -> Result<AnyScalar> {
2057        let dims = self.dims();
2058        let total_size = checked_product(&dims)?;
2059        anyhow::ensure!(
2060            total_size == 1 || dims.is_empty(),
2061            "only() requires a scalar tensor (1 element), got {} elements with dims {:?}",
2062            if dims.is_empty() { 1 } else { total_size },
2063            dims
2064        );
2065        self.sum()
2066    }
2067
2068    /// Permute the tensor dimensions using the given new indices order.
2069    ///
2070    /// This is the main permutation method that takes the desired new indices
2071    /// and automatically computes the corresponding permutation of dimensions
2072    /// and data. The new indices must be a permutation of the original indices
2073    /// (matched by ID).
2074    ///
2075    /// # Arguments
2076    /// * `new_indices` - The desired new indices order. Must be a permutation
2077    ///   of `self.indices` (matched by ID).
2078    ///
2079    /// # Panics
2080    /// Panics if `new_indices.len() != self.indices.len()`, if any index ID
2081    /// doesn't match, or if there are duplicate indices.
2082    ///
2083    /// # Example
2084    /// ```
2085    /// use tensor4all_core::TensorDynLen;
2086    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2087    ///
2088    /// // Create a 2×3 tensor
2089    /// let i = Index::new_dyn(2);
2090    /// let j = Index::new_dyn(3);
2091    /// let indices = vec![i.clone(), j.clone()];
2092    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2093    ///
2094    /// // Permute to 3×2: swap the two dimensions by providing new indices order
2095    /// let permuted = tensor.permute_indices(&[j, i]).unwrap();
2096    /// assert_eq!(permuted.dims(), vec![3, 2]);
2097    /// ```
2098    pub fn permute_indices(&self, new_indices: &[DynIndex]) -> Result<Self> {
2099        // Compute permutation by matching IDs
2100        let perm = compute_permutation_from_indices(&self.indices, new_indices)?;
2101        if perm.iter().copied().eq(0..perm.len()) {
2102            return Ok(Self {
2103                indices: new_indices.to_vec(),
2104                storage: self.storage.clone(),
2105                structured_ad: self.structured_ad.clone(),
2106                eager_cache: Arc::clone(&self.eager_cache),
2107            });
2108        }
2109
2110        let permuted = self.try_materialized_inner()?.transpose(&perm)?;
2111        let axis_classes = self.permute_axis_classes(&perm);
2112        Self::from_inner_with_axis_classes(new_indices.to_vec(), permuted, axis_classes)
2113    }
2114
2115    /// Permute the tensor dimensions, returning a new tensor.
2116    ///
2117    /// This method reorders the indices, dimensions, and data according to the
2118    /// given permutation. The permutation specifies which old axis each new
2119    /// axis corresponds to: `new_axis[i] = old_axis[perm[i]]`.
2120    ///
2121    /// # Arguments
2122    /// * `perm` - The permutation: `perm[i]` is the old axis index for new axis `i`
2123    ///
2124    /// # Panics
2125    /// Panics if `perm.len() != self.indices.len()` or if the permutation is invalid.
2126    ///
2127    /// # Example
2128    /// ```
2129    /// use tensor4all_core::TensorDynLen;
2130    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2131    ///
2132    /// // Create a 2×3 tensor
2133    /// let indices = vec![
2134    ///     Index::new_dyn(2),
2135    ///     Index::new_dyn(3),
2136    /// ];
2137    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2138    ///
2139    /// // Permute to 3×2: swap the two dimensions
2140    /// let permuted = tensor.permute(&[1, 0]).unwrap();
2141    /// assert_eq!(permuted.dims(), vec![3, 2]);
2142    /// ```
2143    pub fn permute(&self, perm: &[usize]) -> Result<Self> {
2144        anyhow::ensure!(
2145            perm.len() == self.indices.len(),
2146            "permutation length must match tensor rank"
2147        );
2148        let mut seen = HashSet::new();
2149        for &axis in perm {
2150            anyhow::ensure!(
2151                axis < self.indices.len(),
2152                "permutation axis {axis} out of range"
2153            );
2154            anyhow::ensure!(seen.insert(axis), "duplicate axis {axis} in permutation");
2155        }
2156        if perm.iter().copied().eq(0..perm.len()) {
2157            return Ok(self.clone());
2158        }
2159
2160        // Permute indices
2161        let new_indices: Vec<DynIndex> = perm.iter().map(|&i| self.indices[i].clone()).collect();
2162        let permuted = self.try_materialized_inner()?.transpose(perm)?;
2163        let axis_classes = self.permute_axis_classes(perm);
2164        Self::from_inner_with_axis_classes(new_indices, permuted, axis_classes)
2165    }
2166
2167    pub(crate) fn try_contract_pairwise_default(&self, other: &Self) -> Result<Self> {
2168        self.try_contract_pairwise_default_with_options(other, PairwiseContractionOptions::new())
2169    }
2170
2171    pub(crate) fn try_contract_pairwise_default_with_options(
2172        &self,
2173        other: &Self,
2174        options: PairwiseContractionOptions,
2175    ) -> Result<Self> {
2176        let self_indices = profile_pairwise_contract_section("operand_indices", || {
2177            self.operand_indices_for_contraction(options.lhs_conj)
2178        });
2179        let other_indices = profile_pairwise_contract_section("operand_indices", || {
2180            other.operand_indices_for_contraction(options.rhs_conj)
2181        });
2182        let self_dims = profile_pairwise_contract_section("expected_dims", || {
2183            Self::expected_dims_from_indices(&self_indices)
2184        });
2185        let other_dims = profile_pairwise_contract_section("expected_dims", || {
2186            Self::expected_dims_from_indices(&other_indices)
2187        });
2188        let spec = profile_pairwise_contract_section("prepare_contraction", || {
2189            prepare_contraction(&self_indices, &self_dims, &other_indices, &other_dims)
2190        })
2191        .context("contraction preparation failed")?;
2192        let result_axis_classes = profile_pairwise_contract_section("result_axis_classes", || {
2193            Self::binary_contraction_axis_classes(
2194                self.storage.axis_classes(),
2195                &spec.axes_a,
2196                other.storage.axis_classes(),
2197                &spec.axes_b,
2198            )
2199        });
2200
2201        if profile_pairwise_contract_section("structured_check", || {
2202            self.should_use_structured_payload_contract(other)
2203        }) {
2204            if options.has_conj() {
2205                let lhs = if options.lhs_conj {
2206                    self.conj()
2207                } else {
2208                    self.clone()
2209                };
2210                let rhs = if options.rhs_conj {
2211                    other.conj()
2212                } else {
2213                    other.clone()
2214                };
2215                return profile_pairwise_contract_section("structured_conj_fallback", || {
2216                    lhs.try_contract_pairwise_default(&rhs)
2217                });
2218            }
2219            return profile_pairwise_contract_section("structured_payload_contract", || {
2220                self.contract_structured_payloads(
2221                    other,
2222                    spec.result_indices.into_vec(),
2223                    &spec.axes_a,
2224                    &spec.axes_b,
2225                )
2226            });
2227        }
2228
2229        if self.indices.is_empty() && other.indices.is_empty() {
2230            if options.has_conj() {
2231                let lhs = if options.lhs_conj {
2232                    self.conj()
2233                } else {
2234                    self.clone()
2235                };
2236                let rhs = if options.rhs_conj {
2237                    other.conj()
2238                } else {
2239                    other.clone()
2240                };
2241                return lhs.try_contract_pairwise_default(&rhs);
2242            }
2243            let result = profile_pairwise_contract_section("scalar_mul", || {
2244                Ok::<_, anyhow::Error>(
2245                    self.try_materialized_inner()?
2246                        .mul(other.try_materialized_inner()?)?,
2247                )
2248            })?;
2249            return profile_pairwise_contract_section("from_inner", || {
2250                Self::from_inner(spec.result_indices.into_vec(), result)
2251            });
2252        }
2253
2254        let self_native = profile_pairwise_contract_section("as_native", || self.as_native())?;
2255        let other_native = profile_pairwise_contract_section("as_native", || other.as_native())?;
2256        if self_native.dtype() != other_native.dtype() {
2257            if options.has_conj() {
2258                let lhs = if options.lhs_conj {
2259                    self.conj()
2260                } else {
2261                    self.clone()
2262                };
2263                let rhs = if options.rhs_conj {
2264                    other.conj()
2265                } else {
2266                    other.clone()
2267                };
2268                return lhs.try_contract_pairwise_default(&rhs);
2269            }
2270            let result_native = profile_pairwise_contract_section("native_contract", || {
2271                contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)
2272            })?;
2273            return profile_pairwise_contract_section("from_native", || {
2274                Self::from_native_with_axis_classes(
2275                    spec.result_indices.into_vec(),
2276                    result_native,
2277                    result_axis_classes,
2278                )
2279            });
2280        }
2281
2282        let config = profile_pairwise_contract_section("build_dot_general_config", || {
2283            Self::binary_dot_general_config(&spec.axes_a, &spec.axes_b)
2284        })?;
2285        let result = profile_pairwise_contract_section("dot_general_with_conj", || {
2286            let lhs = profile_pairwise_contract_section("lhs_try_materialized_inner", || {
2287                self.try_materialized_inner()
2288            })?;
2289            let rhs = profile_pairwise_contract_section("rhs_try_materialized_inner", || {
2290                other.try_materialized_inner()
2291            })?;
2292            profile_pairwise_contract_section("dot_general_execute", || {
2293                lhs.dot_general_with_conj(rhs, &config, options.lhs_conj, options.rhs_conj)
2294            })
2295            .map_err(anyhow::Error::from)
2296        })?;
2297        record_pairwise_contract_profile_bytes(
2298            "dot_general_output",
2299            native_tensor_profile_bytes(result.data()),
2300        );
2301        profile_pairwise_contract_section("from_inner_axis_classes", || {
2302            Self::from_inner_with_axis_classes(
2303                spec.result_indices.into_vec(),
2304                result,
2305                result_axis_classes,
2306            )
2307        })
2308    }
2309
2310    pub(crate) fn try_tensordot_pairwise_explicit(
2311        &self,
2312        other: &Self,
2313        pairs: &[(DynIndex, DynIndex)],
2314    ) -> Result<Self> {
2315        use crate::index_ops::ContractionError;
2316
2317        let self_dims = Self::expected_dims_from_indices(&self.indices);
2318        let other_dims = Self::expected_dims_from_indices(&other.indices);
2319        let spec = prepare_contraction_pairs(
2320            &self.indices,
2321            &self_dims,
2322            &other.indices,
2323            &other_dims,
2324            pairs,
2325        )
2326        .map_err(|e| match e {
2327            ContractionError::NoCommonIndices => {
2328                anyhow::anyhow!("tensordot: No pairs specified for contraction")
2329            }
2330            ContractionError::BatchContractionNotImplemented => anyhow::anyhow!(
2331                "tensordot: Common index found but not in contraction pairs. \
2332                         Batch contraction is not yet implemented."
2333            ),
2334            ContractionError::IndexNotFound { tensor } => {
2335                anyhow::anyhow!("tensordot: Index not found in {} tensor", tensor)
2336            }
2337            ContractionError::DimensionMismatch {
2338                pos_a,
2339                pos_b,
2340                dim_a,
2341                dim_b,
2342            } => anyhow::anyhow!(
2343                "tensordot: Dimension mismatch: self[{}]={} != other[{}]={}",
2344                pos_a,
2345                dim_a,
2346                pos_b,
2347                dim_b
2348            ),
2349            ContractionError::DuplicateAxis { tensor, pos } => {
2350                anyhow::anyhow!("tensordot: Duplicate axis {} in {} tensor", pos, tensor)
2351            }
2352        })?;
2353        let result_axis_classes = Self::binary_contraction_axis_classes(
2354            self.storage.axis_classes(),
2355            &spec.axes_a,
2356            other.storage.axis_classes(),
2357            &spec.axes_b,
2358        );
2359
2360        if self.should_use_structured_payload_contract(other) {
2361            return self.contract_structured_payloads(
2362                other,
2363                spec.result_indices.into_vec(),
2364                &spec.axes_a,
2365                &spec.axes_b,
2366            );
2367        }
2368
2369        if self.indices.is_empty() && other.indices.is_empty() {
2370            let result = self
2371                .try_materialized_inner()?
2372                .mul(other.try_materialized_inner()?)
2373                .map_err(|e| anyhow::anyhow!("tensordot scalar multiply failed: {e}"))?;
2374            return Self::from_inner(spec.result_indices.into_vec(), result);
2375        }
2376
2377        let self_native = self.as_native()?;
2378        let other_native = other.as_native()?;
2379        if self_native.dtype() != other_native.dtype() {
2380            let result_native =
2381                contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)?;
2382            return Self::from_native_with_axis_classes(
2383                spec.result_indices.into_vec(),
2384                result_native,
2385                result_axis_classes,
2386            );
2387        }
2388
2389        let subscripts = Self::build_binary_einsum_subscripts(
2390            self.indices.len(),
2391            &spec.axes_a,
2392            other.indices.len(),
2393            &spec.axes_b,
2394        )?;
2395        let result = eager_einsum_ad(
2396            &[
2397                self.try_materialized_inner()?,
2398                other.try_materialized_inner()?,
2399            ],
2400            &subscripts,
2401        )
2402        .map_err(|e| anyhow::anyhow!("tensordot failed: {e}"))?;
2403        Self::from_inner_with_axis_classes(
2404            spec.result_indices.into_vec(),
2405            result,
2406            result_axis_classes,
2407        )
2408    }
2409
2410    pub(crate) fn try_outer_product_pairwise(&self, other: &Self) -> Result<Self> {
2411        use anyhow::Context;
2412
2413        // Check for common indices - outer product should have none
2414        let common_positions = common_ind_positions(&self.indices, &other.indices);
2415        if !common_positions.is_empty() {
2416            let common_ids: Vec<_> = common_positions
2417                .iter()
2418                .map(|(pos_a, _)| self.indices[*pos_a].id())
2419                .collect();
2420            return Err(anyhow::anyhow!(
2421                "outer_product: tensors have common indices {:?}. \
2422                 Use tensordot to contract common indices, or use sim() to replace \
2423                 indices with fresh IDs before computing outer product.",
2424                common_ids
2425            ))
2426            .context("outer_product: common indices found");
2427        }
2428
2429        // Build result indices and dimensions
2430        let mut result_indices = self.indices.clone();
2431        result_indices.extend(other.indices.iter().cloned());
2432        let result_axis_classes = Self::binary_contraction_axis_classes(
2433            self.storage.axis_classes(),
2434            &[],
2435            other.storage.axis_classes(),
2436            &[],
2437        );
2438        if self.should_use_structured_payload_contract(other) {
2439            return self.contract_structured_payloads(other, result_indices, &[], &[]);
2440        }
2441        let self_native = self.as_native()?;
2442        let other_native = other.as_native()?;
2443        if self_native.dtype() != other_native.dtype() {
2444            let result_native = contract_native_tensor(self_native, &[], other_native, &[])?;
2445            return Self::from_native_with_axis_classes(
2446                result_indices,
2447                result_native,
2448                result_axis_classes,
2449            );
2450        }
2451
2452        let subscripts = Self::build_binary_einsum_subscripts(
2453            self.indices.len(),
2454            &[],
2455            other.indices.len(),
2456            &[],
2457        )?;
2458        let result = eager_einsum_ad(
2459            &[
2460                self.try_materialized_inner()?,
2461                other.try_materialized_inner()?,
2462            ],
2463            &subscripts,
2464        )
2465        .map_err(|e| anyhow::anyhow!("outer_product failed: {e}"))?;
2466        Self::from_inner_with_axis_classes(result_indices, result, result_axis_classes)
2467    }
2468}
2469
2470// ============================================================================
2471// Random tensor generation
2472// ============================================================================
2473
2474impl TensorDynLen {
2475    /// Create a random tensor with values from standard normal distribution (generic over scalar type).
2476    ///
2477    /// For `f64`, each element is drawn from the standard normal distribution.
2478    /// For `Complex64`, both real and imaginary parts are drawn independently.
2479    ///
2480    /// # Type Parameters
2481    /// * `T` - The scalar element type (must implement [`RandomScalar`])
2482    /// * `R` - The random number generator type
2483    ///
2484    /// # Arguments
2485    /// * `rng` - Random number generator
2486    /// * `indices` - The indices for the tensor
2487    ///
2488    /// # Example
2489    /// ```
2490    /// use tensor4all_core::TensorDynLen;
2491    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2492    /// use rand::SeedableRng;
2493    /// use rand_chacha::ChaCha8Rng;
2494    ///
2495    /// let mut rng = ChaCha8Rng::seed_from_u64(42);
2496    /// let i = Index::new_dyn(2);
2497    /// let j = Index::new_dyn(3);
2498    /// let tensor: TensorDynLen = TensorDynLen::random::<f64, _>(&mut rng, vec![i, j]).unwrap();
2499    /// assert_eq!(tensor.dims(), vec![2, 3]);
2500    /// ```
2501    pub fn random<T: RandomScalar, R: Rng>(rng: &mut R, indices: Vec<DynIndex>) -> Result<Self> {
2502        let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
2503        let size = checked_product(&dims)?;
2504        let data: Vec<T> = (0..size).map(|_| T::random_value(rng)).collect();
2505        Self::from_dense(indices, data)
2506    }
2507}
2508
2509impl TensorDynLen {
2510    /// Add two tensors element-wise.
2511    ///
2512    /// The tensors must have the same index set (matched by ID). If the indices
2513    /// are in a different order, the other tensor will be permuted to match `self`.
2514    ///
2515    /// # Arguments
2516    /// * `other` - The tensor to add
2517    ///
2518    /// # Returns
2519    /// A new tensor representing `self + other`, or an error if:
2520    /// - The tensors have different index sets
2521    /// - The dimensions don't match
2522    /// - Storage types are incompatible
2523    ///
2524    /// # Example
2525    /// ```
2526    /// use tensor4all_core::TensorDynLen;
2527    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2528    ///
2529    /// let i = Index::new_dyn(2);
2530    /// let j = Index::new_dyn(3);
2531    ///
2532    /// let indices_a = vec![i.clone(), j.clone()];
2533    /// let data_a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
2534    /// let tensor_a: TensorDynLen = TensorDynLen::from_dense(indices_a, data_a).unwrap();
2535    ///
2536    /// let indices_b = vec![i.clone(), j.clone()];
2537    /// let data_b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
2538    /// let tensor_b: TensorDynLen = TensorDynLen::from_dense(indices_b, data_b).unwrap();
2539    ///
2540    /// let sum = tensor_a.add(&tensor_b).unwrap();
2541    /// // sum = [[2, 3, 4], [5, 6, 7]]
2542    /// ```
2543    pub fn add(&self, other: &Self) -> Result<Self> {
2544        // Validate that both tensors have the same number of indices
2545        if self.indices.len() != other.indices.len() {
2546            return Err(anyhow::anyhow!(
2547                "Index count mismatch: self has {} indices, other has {}",
2548                self.indices.len(),
2549                other.indices.len()
2550            ));
2551        }
2552
2553        // Validate that both tensors have the same set of indices
2554        let self_set: HashSet<_> = self.indices.iter().collect();
2555        let other_set: HashSet<_> = other.indices.iter().collect();
2556
2557        if self_set != other_set {
2558            return Err(anyhow::anyhow!(
2559                "Index set mismatch: tensors must have the same indices"
2560            ));
2561        }
2562
2563        // Permute other to match self's index order (no-op if already aligned)
2564        let other_aligned = other.permute_indices(&self.indices)?;
2565
2566        // Validate dimensions match after alignment
2567        let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2568        let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2569        if self_expected_dims != other_expected_dims {
2570            use crate::TagSetLike;
2571            let fmt = |indices: &[DynIndex]| -> Vec<String> {
2572                indices
2573                    .iter()
2574                    .map(|idx| {
2575                        let tags: Vec<String> = idx.tags().iter().collect();
2576                        format!("{:?}(dim={},tags={:?})", idx.id(), idx.dim(), tags)
2577                    })
2578                    .collect()
2579            };
2580            return Err(anyhow::anyhow!(
2581                "Dimension mismatch after alignment.\n\
2582                 self: dims={:?}, indices(order)={:?}\n\
2583                 other_aligned: dims={:?}, indices(order)={:?}",
2584                self_expected_dims,
2585                fmt(&self.indices),
2586                other_expected_dims,
2587                fmt(&other_aligned.indices)
2588            ));
2589        }
2590
2591        self.axpby(
2592            AnyScalar::new_real(1.0),
2593            &other_aligned,
2594            AnyScalar::new_real(1.0),
2595        )
2596    }
2597
2598    /// Compute a linear combination: `a * self + b * other`.
2599    ///
2600    /// Both tensors must have the same set of indices (matched by ID).
2601    /// If indices are in a different order, `other` is automatically permuted
2602    /// to match `self`.
2603    ///
2604    /// # Examples
2605    ///
2606    /// ```
2607    /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen};
2608    ///
2609    /// let i = DynIndex::new_dyn(2);
2610    /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
2611    /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 4.0]).unwrap();
2612    ///
2613    /// // 2*a + 3*b = [2+9, 4+12] = [11, 16]
2614    /// let result = a.axpby(AnyScalar::new_real(2.0), &b, AnyScalar::new_real(3.0)).unwrap();
2615    /// let data = result.to_vec::<f64>().unwrap();
2616    /// assert!((data[0] - 11.0).abs() < 1e-12);
2617    /// assert!((data[1] - 16.0).abs() < 1e-12);
2618    /// ```
2619    pub fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
2620        // Validate that both tensors have the same number of indices.
2621        if self.indices.len() != other.indices.len() {
2622            return Err(anyhow::anyhow!(
2623                "Index count mismatch: self has {} indices, other has {}",
2624                self.indices.len(),
2625                other.indices.len()
2626            ));
2627        }
2628
2629        // Validate that both tensors have the same set of indices.
2630        let self_set: HashSet<_> = self.indices.iter().collect();
2631        let other_set: HashSet<_> = other.indices.iter().collect();
2632        if self_set != other_set {
2633            return Err(anyhow::anyhow!(
2634                "Index set mismatch: tensors must have the same indices"
2635            ));
2636        }
2637
2638        // Align other tensor axis order to self.
2639        let other_aligned = other.permute_indices(&self.indices)?;
2640
2641        // Validate dimensions match after alignment.
2642        let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2643        let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2644        if self_expected_dims != other_expected_dims {
2645            return Err(anyhow::anyhow!(
2646                "Dimension mismatch after alignment: self={:?}, other_aligned={:?}",
2647                self_expected_dims,
2648                other_expected_dims
2649            ));
2650        }
2651
2652        let axis_classes = if self.storage.axis_classes() == other_aligned.storage.axis_classes() {
2653            self.storage.axis_classes().to_vec()
2654        } else {
2655            Self::dense_axis_classes(self.indices.len())
2656        };
2657
2658        let same_compact_layout = self.storage.payload_dims()
2659            == other_aligned.storage.payload_dims()
2660            && self.storage.payload_strides_vec() == other_aligned.storage.payload_strides_vec()
2661            && self.storage.axis_classes() == other_aligned.storage.axis_classes();
2662        if same_compact_layout
2663            && !self.tracks_grad()
2664            && !other_aligned.tracks_grad()
2665            && !a.tracks_grad()
2666            && !b.tracks_grad()
2667        {
2668            let lhs_storage = self.storage.materialize(self.indices.len())?;
2669            let rhs_storage = other_aligned
2670                .storage
2671                .materialize(other_aligned.indices.len())?;
2672            let combined = lhs_storage
2673                .axpby(
2674                    &a.to_backend_scalar(),
2675                    rhs_storage.as_ref(),
2676                    &b.to_backend_scalar(),
2677                )
2678                .map_err(|e| anyhow::anyhow!("storage axpby failed: {e}"))?;
2679            return Self::from_storage(self.indices.clone(), Arc::new(combined));
2680        }
2681
2682        let self_native = self.as_native()?;
2683        let other_native = other_aligned.as_native()?;
2684        let a_native = a.as_tensor()?.as_native()?;
2685        let b_native = b.as_tensor()?.as_native()?;
2686        if self_native.dtype() != other_native.dtype()
2687            || self_native.dtype() != a_native.dtype()
2688            || other_native.dtype() != b_native.dtype()
2689        {
2690            let combined = axpby_native_tensor(
2691                self_native,
2692                &a.to_backend_scalar(),
2693                other_native,
2694                &b.to_backend_scalar(),
2695            )?;
2696            return Self::from_native_with_axis_classes(
2697                self.indices.clone(),
2698                combined,
2699                axis_classes,
2700            );
2701        }
2702
2703        let lhs = self.scale(a)?;
2704        let rhs = other_aligned.scale(b)?;
2705        let combined = lhs
2706            .try_materialized_inner()?
2707            .add(rhs.try_materialized_inner()?)
2708            .map_err(|e| anyhow::anyhow!("tensor addition failed: {e}"))?;
2709        Self::from_inner_with_axis_classes(self.indices.clone(), combined, axis_classes)
2710    }
2711
2712    /// Scalar multiplication.
2713    ///
2714    /// Multiplies every element by `scalar`.
2715    ///
2716    /// # Examples
2717    ///
2718    /// ```
2719    /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen};
2720    ///
2721    /// let i = DynIndex::new_dyn(3);
2722    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0, 3.0]).unwrap();
2723    /// let scaled = t.scale(AnyScalar::new_real(2.0)).unwrap();
2724    /// assert_eq!(scaled.to_vec::<f64>().unwrap(), vec![2.0, 4.0, 6.0]);
2725    /// ```
2726    pub fn scale(&self, scalar: AnyScalar) -> Result<Self> {
2727        if !self.tracks_grad() && !scalar.tracks_grad() {
2728            let scaled = self.storage.scale(&scalar.to_backend_scalar())?;
2729            return Self::from_storage(self.indices.clone(), Arc::new(scaled));
2730        }
2731
2732        let self_native = self.as_native()?;
2733        let scalar_native = scalar.as_tensor()?.as_native()?;
2734        if self_native.dtype() != scalar_native.dtype() {
2735            let scaled = scale_native_tensor(self_native, &scalar.to_backend_scalar())?;
2736            return Self::from_native_with_axis_classes(
2737                self.indices.clone(),
2738                scaled,
2739                self.storage.axis_classes().to_vec(),
2740            );
2741        }
2742
2743        let scaled = if self.indices.is_empty() {
2744            self.try_materialized_inner()?
2745                .mul(scalar.as_tensor()?.try_materialized_inner()?)
2746                .map_err(|e| anyhow::anyhow!("scalar multiplication failed: {e}"))?
2747        } else {
2748            let subscripts = Self::scale_subscripts(self.indices.len())?;
2749            eager_einsum_ad(
2750                &[
2751                    self.try_materialized_inner()?,
2752                    scalar.as_tensor()?.try_materialized_inner()?,
2753                ],
2754                &subscripts,
2755            )
2756            .map_err(|e| anyhow::anyhow!("tensor scaling failed: {e}"))?
2757        };
2758        Self::from_inner_with_axis_classes(
2759            self.indices.clone(),
2760            scaled,
2761            self.storage.axis_classes().to_vec(),
2762        )
2763    }
2764
2765    /// Inner product (dot product) of two tensors.
2766    ///
2767    /// Computes `⟨self, other⟩ = Σ conj(self)_i * other_i`.
2768    ///
2769    /// # Examples
2770    ///
2771    /// ```
2772    /// use tensor4all_core::{DynIndex, TensorDynLen};
2773    ///
2774    /// let i = DynIndex::new_dyn(3);
2775    /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0, 3.0]).unwrap();
2776    /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![4.0, 5.0, 6.0]).unwrap();
2777    ///
2778    /// // <a, b> = 1*4 + 2*5 + 3*6 = 32
2779    /// let ip = a.inner_product(&b).unwrap();
2780    /// assert!((ip.real() - 32.0).abs() < 1e-12);
2781    /// ```
2782    pub fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
2783        if self.indices.len() == other.indices.len() {
2784            let self_set: HashSet<_> = self.indices.iter().collect();
2785            let other_set: HashSet<_> = other.indices.iter().collect();
2786            if self_set == other_set {
2787                let other_aligned = other.permute_indices(&self.indices)?;
2788                let result = super::contract::contract_pair_with_operand_options(
2789                    self,
2790                    &other_aligned,
2791                    PairwiseContractionOptions::new().with_lhs_conj(true),
2792                )?;
2793                return result.sum();
2794            }
2795        }
2796
2797        // Contract self.conj() with other over all indices
2798        let result = super::contract::contract_pair_with_operand_options(
2799            self,
2800            other,
2801            PairwiseContractionOptions::new().with_lhs_conj(true),
2802        )?;
2803        // Result should be a scalar (no indices)
2804        result.sum()
2805    }
2806}
2807
2808// ============================================================================
2809// Index Replacement Methods
2810// ============================================================================
2811
2812impl TensorDynLen {
2813    /// Replace an index in the tensor with a new index.
2814    ///
2815    /// This replaces the index matching `old_index` by ID with `new_index`.
2816    /// The storage data is not modified, only the index metadata is changed.
2817    ///
2818    /// # Arguments
2819    /// * `old_index` - The index to replace (matched by ID)
2820    /// * `new_index` - The new index to use
2821    ///
2822    /// # Returns
2823    /// A new tensor with the index replaced. If no index matches `old_index`,
2824    /// returns a clone of the original tensor.
2825    ///
2826    /// # Errors
2827    /// Returns an error if the replacement index has a different dimension.
2828    ///
2829    /// # Example
2830    /// ```
2831    /// use tensor4all_core::TensorDynLen;
2832    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2833    ///
2834    /// let i = Index::new_dyn(2);
2835    /// let j = Index::new_dyn(3);
2836    /// let new_i = Index::new_dyn(2);  // Same dimension, different ID
2837    ///
2838    /// let indices = vec![i.clone(), j.clone()];
2839    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2840    ///
2841    /// // Replace index i with new_i
2842    /// let replaced = tensor.replaceind(&i, &new_i).unwrap();
2843    /// assert_eq!(replaced.indices[0].id, new_i.id);
2844    /// assert_eq!(replaced.indices[1].id, j.id);
2845    /// ```
2846    pub fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
2847        // Validate dimension match
2848        if old_index.dim() != new_index.dim() {
2849            return Err(anyhow::anyhow!(
2850                "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2851                old_index.dim(),
2852                new_index.dim()
2853            ));
2854        }
2855
2856        let new_indices: Vec<_> = self
2857            .indices
2858            .iter()
2859            .map(|idx| {
2860                if *idx == *old_index {
2861                    new_index.clone()
2862                } else {
2863                    idx.clone()
2864                }
2865            })
2866            .collect();
2867
2868        Ok(Self {
2869            indices: new_indices,
2870            storage: self.storage.clone(),
2871            structured_ad: self.structured_ad.clone(),
2872            eager_cache: Arc::clone(&self.eager_cache),
2873        })
2874    }
2875
2876    /// Replace multiple indices in the tensor.
2877    ///
2878    /// This replaces each index in `old_indices` (matched by ID) with the corresponding
2879    /// index in `new_indices`. The storage data is not modified.
2880    ///
2881    /// # Arguments
2882    /// * `old_indices` - The indices to replace (matched by ID)
2883    /// * `new_indices` - The new indices to use
2884    ///
2885    /// # Returns
2886    /// A new tensor with the indices replaced. Indices not found in `old_indices`
2887    /// are kept unchanged.
2888    ///
2889    /// # Errors
2890    /// Returns an error if `old_indices` and `new_indices` have different
2891    /// lengths or if any replacement index has a different dimension.
2892    ///
2893    /// # Example
2894    /// ```
2895    /// use tensor4all_core::TensorDynLen;
2896    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2897    ///
2898    /// let i = Index::new_dyn(2);
2899    /// let j = Index::new_dyn(3);
2900    /// let new_i = Index::new_dyn(2);
2901    /// let new_j = Index::new_dyn(3);
2902    ///
2903    /// let indices = vec![i.clone(), j.clone()];
2904    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2905    ///
2906    /// // Replace both indices
2907    /// let replaced = tensor
2908    ///     .replaceinds(&[i.clone(), j.clone()], &[new_i.clone(), new_j.clone()])
2909    ///     .unwrap();
2910    /// assert_eq!(replaced.indices[0].id, new_i.id);
2911    /// assert_eq!(replaced.indices[1].id, new_j.id);
2912    /// ```
2913    pub fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
2914        anyhow::ensure!(
2915            old_indices.len() == new_indices.len(),
2916            "old_indices and new_indices must have the same length"
2917        );
2918
2919        // Validate dimension matches for all replacements
2920        for (old, new) in old_indices.iter().zip(new_indices.iter()) {
2921            if old.dim() != new.dim() {
2922                return Err(anyhow::anyhow!(
2923                    "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2924                    old.dim(),
2925                    new.dim()
2926                ));
2927            }
2928        }
2929
2930        // Build a map from old indices to new indices
2931        let replacement_map: std::collections::HashMap<_, _> =
2932            old_indices.iter().zip(new_indices.iter()).collect();
2933
2934        let new_indices_vec: Vec<_> = self
2935            .indices
2936            .iter()
2937            .map(|idx| {
2938                if let Some(new_idx) = replacement_map.get(idx) {
2939                    (*new_idx).clone()
2940                } else {
2941                    idx.clone()
2942                }
2943            })
2944            .collect();
2945
2946        Ok(Self {
2947            indices: new_indices_vec,
2948            storage: self.storage.clone(),
2949            structured_ad: self.structured_ad.clone(),
2950            eager_cache: Arc::clone(&self.eager_cache),
2951        })
2952    }
2953}
2954
2955// ============================================================================
2956// Complex Conjugation
2957// ============================================================================
2958
2959impl TensorDynLen {
2960    /// Complex conjugate of all tensor elements.
2961    ///
2962    /// For real (f64) tensors, returns a copy (conjugate of real is identity).
2963    /// For complex (Complex64) tensors, conjugates each element.
2964    ///
2965    /// The indices and dimensions remain unchanged.
2966    ///
2967    /// This is inspired by the `conj` operation in ITensorMPS.jl.
2968    ///
2969    /// # Example
2970    /// ```
2971    /// use tensor4all_core::TensorDynLen;
2972    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2973    /// use num_complex::Complex64;
2974    ///
2975    /// let i = Index::new_dyn(2);
2976    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -4.0)];
2977    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i], data).unwrap();
2978    ///
2979    /// let conj_tensor = tensor.conj();
2980    /// // Elements are now conjugated: 1-2i, 3+4i
2981    /// ```
2982    pub fn conj(&self) -> Self {
2983        // Conjugate tensor: conjugate storage data and map indices via IndexLike::conj()
2984        // For default undirected indices, conj() is a no-op, so this is future-proof
2985        // for QSpace-compatible directed indices where conj() flips Ket <-> Bra
2986        let new_indices: Vec<DynIndex> = self.indices.iter().map(|idx| idx.conj()).collect();
2987        let structured_ad = self.tracked_compact_payload_value().and_then(|value| {
2988            value.payload.conj().ok().map(|payload| {
2989                Arc::new(StructuredAdValue {
2990                    payload: Arc::new(payload),
2991                    payload_dims: value.payload_dims.clone(),
2992                    axis_classes: value.axis_classes.clone(),
2993                })
2994            })
2995        });
2996        let eager_cache = self
2997            .eager_cache
2998            .get()
2999            .and_then(|inner| inner.conj().ok())
3000            .map(Self::eager_cache_with)
3001            .unwrap_or_else(Self::empty_eager_cache);
3002        Self {
3003            indices: new_indices,
3004            storage: self.storage.conj().unwrap_or_else(|_| {
3005                TensorDynLenStorage::from_storage(Arc::new(self.storage().conj()))
3006            }),
3007            structured_ad,
3008            eager_cache,
3009        }
3010    }
3011}
3012
3013// ============================================================================
3014// Norm Computation
3015// ============================================================================
3016
3017impl TensorDynLen {
3018    /// Compute the squared Frobenius norm of the tensor: ||T||² = Σ|T_ijk...|²
3019    ///
3020    /// For real tensors: sum of squares of all elements.
3021    /// For complex tensors: sum of |z|² = z * conj(z) for all elements.
3022    ///
3023    /// # Example
3024    /// ```
3025    /// use tensor4all_core::TensorDynLen;
3026    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3027    ///
3028    /// let i = Index::new_dyn(2);
3029    /// let j = Index::new_dyn(3);
3030    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];  // 1² + 2² + ... + 6² = 91
3031    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i, j], data).unwrap();
3032    ///
3033    /// assert!((tensor.norm_squared() - 91.0).abs() < 1e-10);
3034    /// ```
3035    pub fn norm_squared(&self) -> f64 {
3036        self.try_norm_squared().unwrap_or(f64::NAN)
3037    }
3038
3039    /// Try to compute the squared Frobenius norm of the tensor.
3040    ///
3041    /// # Errors
3042    /// Returns an error if conjugation, contraction, or scalar extraction fails.
3043    pub fn try_norm_squared(&self) -> Result<f64> {
3044        // Special case: scalar tensor (no indices)
3045        if self.indices.is_empty() {
3046            // For a scalar, ||T||² = |value|²
3047            let value = self.sum()?;
3048            let abs_val = value.abs();
3049            return Ok(abs_val * abs_val);
3050        }
3051
3052        // Contract tensor with its conjugate over all indices → scalar
3053        // ||T||² = Σ T_ijk... * conj(T_ijk...) = Σ |T_ijk...|²
3054        let conj = self.conj();
3055        let scalar = super::contract::contract_pair(self, &conj)?;
3056        // The mathematical result is nonnegative and real. Clamp tiny negative
3057        // roundoff so downstream `sqrt` stays well-defined for complex tensors.
3058        Ok(scalar.sum()?.real().max(0.0))
3059    }
3060
3061    /// Compute the Frobenius norm of the tensor: ||T|| = sqrt(Σ|T_ijk...|²)
3062    ///
3063    /// # Example
3064    /// ```
3065    /// use tensor4all_core::TensorDynLen;
3066    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3067    ///
3068    /// let i = Index::new_dyn(2);
3069    /// let data = vec![3.0, 4.0];  // sqrt(9 + 16) = 5
3070    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i], data).unwrap();
3071    ///
3072    /// assert!((tensor.norm() - 5.0).abs() < 1e-10);
3073    /// ```
3074    pub fn norm(&self) -> f64 {
3075        self.norm_squared().sqrt()
3076    }
3077
3078    /// Maximum absolute value of all elements (L-infinity norm).
3079    ///
3080    /// # Examples
3081    ///
3082    /// ```
3083    /// use tensor4all_core::{DynIndex, TensorDynLen};
3084    ///
3085    /// let i = DynIndex::new_dyn(4);
3086    /// let t = TensorDynLen::from_dense(vec![i], vec![-5.0, 1.0, 3.0, -2.0]).unwrap();
3087    /// assert!((t.maxabs() - 5.0).abs() < 1e-12);
3088    /// ```
3089    pub fn maxabs(&self) -> f64 {
3090        self.storage.max_abs().unwrap_or(0.0)
3091    }
3092
3093    /// Element-wise subtraction with index alignment.
3094    ///
3095    /// This computes `self - other` using the same vector-space semantics as
3096    /// [`TensorVectorSpace`](crate::TensorVectorSpace).
3097    ///
3098    /// # Errors
3099    /// Returns an error if the tensors cannot be aligned or subtracted.
3100    pub fn sub(&self, other: &Self) -> Result<Self> {
3101        self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0))
3102    }
3103
3104    /// Negate all elements.
3105    ///
3106    /// # Errors
3107    /// Returns an error if scalar multiplication fails for the tensor storage.
3108    pub fn neg(&self) -> Result<Self> {
3109        self.scale(AnyScalar::new_real(-1.0))
3110    }
3111
3112    /// Approximate equality check using Julia `isapprox`-style semantics.
3113    ///
3114    /// Returns `true` when `||self - other|| <= max(atol, rtol *
3115    /// max(||self||, ||other||))`.
3116    pub fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool {
3117        let diff = match self.sub(other) {
3118            Ok(d) => d,
3119            Err(_) => return false,
3120        };
3121        let diff_norm = diff.norm();
3122        diff_norm <= atol.max(rtol * self.norm().max(other.norm()))
3123    }
3124
3125    /// Create a diagonal Kronecker-delta tensor for one input/output index pair.
3126    ///
3127    /// # Errors
3128    /// Returns an error if the two indices have different dimensions.
3129    pub fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3130        <Self as TensorConstructionLike>::diagonal(input_index, output_index)
3131    }
3132
3133    /// Create a product of Kronecker-delta tensors for paired index lists.
3134    ///
3135    /// # Errors
3136    /// Returns an error if the index lists have different lengths or paired
3137    /// dimensions do not match.
3138    pub fn delta(input_indices: &[DynIndex], output_indices: &[DynIndex]) -> Result<Self> {
3139        <Self as TensorConstructionLike>::delta(input_indices, output_indices)
3140    }
3141
3142    /// Create a scalar tensor equal to one.
3143    ///
3144    /// # Errors
3145    /// Returns an error if dense scalar construction fails.
3146    pub fn scalar_one() -> Result<Self> {
3147        <Self as TensorConstructionLike>::scalar_one()
3148    }
3149
3150    /// Create a tensor filled with ones over the given indices.
3151    ///
3152    /// # Errors
3153    /// Returns an error if the tensor size overflows or dense construction fails.
3154    pub fn ones(indices: &[DynIndex]) -> Result<Self> {
3155        <Self as TensorConstructionLike>::ones(indices)
3156    }
3157
3158    /// Create a one-hot tensor with value one at the specified index positions.
3159    ///
3160    /// # Errors
3161    /// Returns an error if any coordinate is outside its index dimension.
3162    pub fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3163        <Self as TensorConstructionLike>::onehot(index_vals)
3164    }
3165
3166    /// Compute the relative distance between two tensors.
3167    ///
3168    /// Returns `||A - B|| / ||A||` (Frobenius norm).
3169    /// If `||A|| = 0`, returns `||B||` instead to avoid division by zero.
3170    ///
3171    /// This is the ITensor-style distance function useful for comparing tensors.
3172    ///
3173    /// # Arguments
3174    /// * `other` - The other tensor to compare with
3175    ///
3176    /// # Returns
3177    /// The relative distance as a f64 value.
3178    ///
3179    /// # Note
3180    /// The indices of both tensors must be permutable to each other.
3181    /// The result tensor (A - B) uses the index ordering from self.
3182    ///
3183    /// # Example
3184    /// ```
3185    /// use tensor4all_core::TensorDynLen;
3186    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3187    ///
3188    /// let i = Index::new_dyn(2);
3189    /// let data_a = vec![1.0, 0.0];
3190    /// let data_b = vec![1.0, 0.0];  // Same tensor
3191    /// let tensor_a: TensorDynLen = TensorDynLen::from_dense(vec![i.clone()], data_a).unwrap();
3192    /// let tensor_b: TensorDynLen = TensorDynLen::from_dense(vec![i.clone()], data_b).unwrap();
3193    ///
3194    /// assert!(tensor_a.distance(&tensor_b).unwrap() < 1e-10);  // Zero distance
3195    /// ```
3196    pub fn distance(&self, other: &Self) -> Result<f64> {
3197        let norm_self = self.norm();
3198
3199        // Compute A - B = A + (-1) * B
3200        let neg_other = other.scale(AnyScalar::new_real(-1.0))?;
3201        let diff = self.add(&neg_other)?;
3202        let norm_diff = diff.norm();
3203
3204        if norm_self > 0.0 {
3205            Ok(norm_diff / norm_self)
3206        } else {
3207            Ok(norm_diff)
3208        }
3209    }
3210}
3211
3212impl std::fmt::Debug for TensorDynLen {
3213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3214        f.debug_struct("TensorDynLen")
3215            .field("indices", &self.indices)
3216            .field("dims", &self.dims())
3217            .field("is_diag", &self.is_diag())
3218            .finish()
3219    }
3220}
3221
3222/// Create a diagonal tensor with dynamic rank from diagonal data.
3223///
3224/// # Arguments
3225/// * `indices` - The indices for the tensor (all must have the same dimension)
3226/// * `diag_data` - The diagonal elements (length must equal the dimension of indices)
3227///
3228/// The returned tensor preserves compact diagonal payload metadata; use
3229/// [`TensorDynLen::is_diag`] or [`TensorDynLen::storage`] to inspect that
3230/// representation.
3231///
3232/// # Panics
3233/// Panics if indices have different dimensions, or if diag_data length doesn't match.
3234///
3235/// # Examples
3236///
3237/// ```
3238/// use tensor4all_core::{DynIndex, diag_tensor_dyn_len};
3239///
3240/// let i = DynIndex::new_dyn(3);
3241/// let j = DynIndex::new_dyn(3);
3242/// let t = diag_tensor_dyn_len(vec![i, j], vec![1.0, 2.0, 3.0]).unwrap();
3243/// assert_eq!(t.dims(), vec![3, 3]);
3244/// assert!(t.is_diag());
3245/// ```
3246pub fn diag_tensor_dyn_len(indices: Vec<DynIndex>, diag_data: Vec<f64>) -> Result<TensorDynLen> {
3247    TensorDynLen::from_diag(indices, diag_data)
3248}
3249
/// Tuple produced by `unfold_split_inner`, in the order
/// `(matrix, left_len, m, n, left_indices, right_indices)`:
///
/// - `matrix`: the tensor reshaped to `[m, n]`, kept as an `EagerTensor`
/// - `left_len`: number of left (row-side) indices
/// - `m`: product of the left index dimensions
/// - `n`: product of the right index dimensions
/// - `left_indices`: the left indices, in the order supplied by the caller
/// - `right_indices`: the remaining indices, in their original tensor order
// The six-element tuple trips clippy's `type_complexity` lint; naming it with
// this alias is the mitigation, so the lint is silenced on the alias itself.
#[allow(clippy::type_complexity)]
pub(crate) type UnfoldSplitInnerResult = (
    EagerTensor,
    usize,
    usize,
    usize,
    Vec<DynIndex>,
    Vec<DynIndex>,
);
3259
3260/// Unfold a tensor into a matrix by splitting indices into left and right groups.
3261///
3262/// This function validates the split, permutes the tensor so that left indices
3263/// come first, and returns a rank-2 native tenferro tensor along with metadata.
3264///
3265/// # Arguments
3266/// * `t` - Input tensor
3267/// * `left_inds` - Indices to place on the left (row) side of the matrix
3268///
3269/// # Returns
3270/// A tuple `(matrix_tensor, left_len, m, n, left_indices, right_indices)` where:
3271/// - `matrix_tensor` is a rank-2 `tenferro::Tensor` with shape `[m, n]`
3272/// - `left_len` is the number of left indices
3273/// - `m` is the product of left index dimensions
3274/// - `n` is the product of right index dimensions
3275/// - `left_indices` is the vector of left indices (cloned)
3276/// - `right_indices` is the vector of right indices (cloned)
3277///
3278/// # Errors
3279/// Returns an error if:
3280/// - The tensor rank is < 2
3281/// - `left_inds` is empty or contains all indices
3282/// - `left_inds` contains indices not in the tensor or duplicates
3283/// - Native reshape fails
3284///
3285/// # Examples
3286///
3287/// ```
3288/// use tensor4all_core::{DynIndex, TensorDynLen, unfold_split};
3289///
3290/// let i = DynIndex::new_dyn(2);
3291/// let j = DynIndex::new_dyn(3);
3292/// // 2x3 dense tensor with data [1..6]
3293/// let t = TensorDynLen::from_dense(
3294///     vec![i.clone(), j.clone()],
3295///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3296/// ).unwrap();
3297///
3298/// let (matrix, left_len, m, n, left_indices, right_indices) =
3299///     unfold_split(&t, &[i]).unwrap();
3300/// assert_eq!(left_len, 1);
3301/// assert_eq!(m, 2);
3302/// assert_eq!(n, 3);
3303/// assert_eq!(left_indices.len(), 1);
3304/// assert_eq!(right_indices.len(), 1);
3305/// ```
3306#[allow(clippy::type_complexity)]
3307pub fn unfold_split(
3308    t: &TensorDynLen,
3309    left_inds: &[DynIndex],
3310) -> Result<(
3311    NativeTensor,
3312    usize,
3313    usize,
3314    usize,
3315    Vec<DynIndex>,
3316    Vec<DynIndex>,
3317)> {
3318    let (matrix_inner, left_len, m, n, left_indices, right_indices) =
3319        unfold_split_inner(t, left_inds)?;
3320
3321    Ok((
3322        matrix_inner.data().clone(),
3323        left_len,
3324        m,
3325        n,
3326        left_indices,
3327        right_indices,
3328    ))
3329}
3330
3331pub(crate) fn unfold_split_inner(
3332    t: &TensorDynLen,
3333    left_inds: &[DynIndex],
3334) -> Result<UnfoldSplitInnerResult> {
3335    let rank = t.indices.len();
3336
3337    // Validate rank
3338    anyhow::ensure!(rank >= 2, "Tensor must have rank >= 2, got rank {}", rank);
3339
3340    let left_len = left_inds.len();
3341
3342    // Validate split: must be a proper subset
3343    anyhow::ensure!(
3344        left_len > 0 && left_len < rank,
3345        "Left indices must be a non-empty proper subset of tensor indices (0 < left_len < rank), got left_len={}, rank={}",
3346        left_len,
3347        rank
3348    );
3349
3350    // Validate that all left_inds are in the tensor and there are no duplicates
3351    let tensor_set: HashSet<_> = t.indices.iter().collect();
3352    let mut left_set = HashSet::new();
3353
3354    for left_idx in left_inds {
3355        anyhow::ensure!(
3356            tensor_set.contains(left_idx),
3357            "Index in left_inds not found in tensor"
3358        );
3359        anyhow::ensure!(left_set.insert(left_idx), "Duplicate index in left_inds");
3360    }
3361
3362    // Build right_inds: all indices not in left_inds, in original order
3363    let mut right_inds = Vec::new();
3364    for idx in &t.indices {
3365        if !left_set.contains(idx) {
3366            right_inds.push(idx.clone());
3367        }
3368    }
3369
3370    // Build new_indices: left_inds first, then right_inds
3371    let mut new_indices = Vec::with_capacity(rank);
3372    new_indices.extend_from_slice(left_inds);
3373    new_indices.extend_from_slice(&right_inds);
3374
3375    // Permute tensor to have left indices first, then right indices
3376    let unfolded = t.permute_indices(&new_indices)?;
3377
3378    // Compute matrix dimensions
3379    let unfolded_dims = unfolded.dims();
3380    let m: usize = unfolded_dims[..left_len].iter().product();
3381    let n: usize = unfolded_dims[left_len..].iter().product();
3382
3383    let matrix_tensor = unfolded.try_materialized_inner()?.reshape(&[m, n])?;
3384
3385    Ok((
3386        matrix_tensor,
3387        left_len,
3388        m,
3389        n,
3390        left_inds.to_vec(),
3391        right_inds,
3392    ))
3393}
3394
3395// ============================================================================
3396// TensorIndex implementation for TensorDynLen
3397// ============================================================================
3398
3399use crate::tensor_index::TensorIndex;
3400
3401impl TensorIndex for TensorDynLen {
3402    type Index = DynIndex;
3403
3404    fn external_indices(&self) -> Vec<DynIndex> {
3405        // For TensorDynLen, all indices are external.
3406        self.indices.clone()
3407    }
3408
3409    fn num_external_indices(&self) -> usize {
3410        self.indices.len()
3411    }
3412
3413    fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
3414        // Delegate to the inherent method
3415        TensorDynLen::replaceind(self, old_index, new_index)
3416    }
3417
3418    fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
3419        // Delegate to the inherent method
3420        TensorDynLen::replaceinds(self, old_indices, new_indices)
3421    }
3422}
3423
3424// ============================================================================
3425// TensorLike implementation for TensorDynLen
3426// ============================================================================
3427
3428use crate::tensor_like::{
3429    FactorizeError, FactorizeOptions, FactorizeResult, TensorConstructionLike,
3430    TensorContractionLike, TensorFactorizationLike, TensorVectorSpace,
3431};
3432
3433impl TensorVectorSpace for TensorDynLen {
3434    fn norm_squared(&self) -> f64 {
3435        TensorDynLen::norm_squared(self)
3436    }
3437
3438    fn maxabs(&self) -> f64 {
3439        TensorDynLen::maxabs(self)
3440    }
3441
3442    fn axpby(&self, a: crate::AnyScalar, other: &Self, b: crate::AnyScalar) -> Result<Self> {
3443        TensorDynLen::axpby(self, a, other, b)
3444    }
3445
3446    fn scale(&self, scalar: crate::AnyScalar) -> Result<Self> {
3447        TensorDynLen::scale(self, scalar)
3448    }
3449
3450    fn inner_product(&self, other: &Self) -> Result<crate::AnyScalar> {
3451        TensorDynLen::inner_product(self, other)
3452    }
3453}
3454
3455impl TensorFactorizationLike for TensorDynLen {
3456    fn factorize(
3457        &self,
3458        left_inds: &[DynIndex],
3459        options: &FactorizeOptions,
3460    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3461        crate::factorize::factorize(self, left_inds, options)
3462    }
3463
3464    fn factorize_full_rank(
3465        &self,
3466        left_inds: &[DynIndex],
3467        alg: crate::FactorizeAlg,
3468        canonical: crate::Canonical,
3469    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3470        crate::factorize::factorize_full_rank(self, left_inds, alg, canonical)
3471    }
3472}
3473
3474impl TensorContractionLike for TensorDynLen {
3475    fn conj(&self) -> Self {
3476        // Delegate to the inherent method (complex conjugate for dense tensors)
3477        TensorDynLen::conj(self)
3478    }
3479
3480    fn direct_sum(
3481        &self,
3482        other: &Self,
3483        pairs: &[(DynIndex, DynIndex)],
3484    ) -> Result<crate::tensor_like::DirectSumResult<Self>> {
3485        let (tensor, new_indices) = crate::direct_sum::direct_sum(self, other, pairs)?;
3486        Ok(crate::tensor_like::DirectSumResult {
3487            tensor,
3488            new_indices,
3489        })
3490    }
3491
3492    fn outer_product(&self, other: &Self) -> Result<Self> {
3493        super::contract::outer_product(self, other)
3494    }
3495
3496    fn permuteinds(&self, new_order: &[DynIndex]) -> Result<Self> {
3497        // Delegate to the inherent method
3498        TensorDynLen::permute_indices(self, new_order)
3499    }
3500
3501    fn fuse_indices(
3502        &self,
3503        old_indices: &[DynIndex],
3504        new_index: DynIndex,
3505        order: LinearizationOrder,
3506    ) -> Result<Self> {
3507        TensorDynLen::fuse_indices(self, old_indices, new_index, order)
3508    }
3509
3510    fn contract(tensors: &[&Self]) -> Result<Self> {
3511        super::contract::contract(tensors)
3512    }
3513
3514    fn contract_pair(&self, other: &Self) -> Result<Self> {
3515        super::contract::contract_pair(self, other)
3516    }
3517}
3518
3519impl TensorConstructionLike for TensorDynLen {
3520    fn select_indices(&self, selected_indices: &[DynIndex], positions: &[usize]) -> Result<Self> {
3521        TensorDynLen::select_indices(self, selected_indices, positions)
3522    }
3523
3524    fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3525        let dim = input_index.dim();
3526        if dim != output_index.dim() {
3527            return Err(anyhow::anyhow!(
3528                "Dimension mismatch: input index has dim {}, output has dim {}",
3529                dim,
3530                output_index.dim(),
3531            ));
3532        }
3533
3534        TensorDynLen::from_diag(
3535            vec![input_index.clone(), output_index.clone()],
3536            vec![1.0_f64; dim],
3537        )
3538    }
3539
3540    fn scalar_one() -> Result<Self> {
3541        TensorDynLen::from_dense(vec![], vec![1.0_f64])
3542    }
3543
3544    fn ones(indices: &[DynIndex]) -> Result<Self> {
3545        if indices.is_empty() {
3546            return Self::scalar_one();
3547        }
3548        let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3549        let total_size = checked_total_size(&dims)?;
3550        TensorDynLen::from_dense(indices.to_vec(), vec![1.0_f64; total_size])
3551    }
3552
3553    fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3554        if index_vals.is_empty() {
3555            return Self::scalar_one();
3556        }
3557        let indices: Vec<DynIndex> = index_vals.iter().map(|(idx, _)| idx.clone()).collect();
3558        let vals: Vec<usize> = index_vals.iter().map(|(_, v)| *v).collect();
3559        let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3560
3561        for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3562            if v >= d {
3563                return Err(anyhow::anyhow!(
3564                    "onehot: value {} at position {} is >= dimension {}",
3565                    v,
3566                    k,
3567                    d
3568                ));
3569            }
3570        }
3571
3572        let total_size = checked_total_size(&dims)?;
3573        let mut data = vec![0.0_f64; total_size];
3574
3575        let offset = column_major_offset(&dims, &vals)?;
3576        data[offset] = 1.0;
3577
3578        Self::from_dense(indices, data)
3579    }
3580
3581    // delta() uses the default implementation via diagonal() and outer_product()
3582}
3583
3584fn checked_total_size(dims: &[usize]) -> Result<usize> {
3585    dims.iter().try_fold(1_usize, |acc, &d| {
3586        if d == 0 {
3587            return Err(anyhow::anyhow!("invalid dimension 0"));
3588        }
3589        acc.checked_mul(d)
3590            .ok_or_else(|| anyhow::anyhow!("tensor size overflow"))
3591    })
3592}
3593
3594fn column_major_offset(dims: &[usize], vals: &[usize]) -> Result<usize> {
3595    if dims.len() != vals.len() {
3596        return Err(anyhow::anyhow!(
3597            "column_major_offset: dims.len() != vals.len()"
3598        ));
3599    }
3600    checked_total_size(dims)?;
3601
3602    let mut offset = 0usize;
3603    let mut stride = 1usize;
3604    for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3605        if d == 0 {
3606            return Err(anyhow::anyhow!("invalid dimension 0 at position {}", k));
3607        }
3608        if v >= d {
3609            return Err(anyhow::anyhow!(
3610                "column_major_offset: value {} at position {} is >= dimension {}",
3611                v,
3612                k,
3613                d
3614            ));
3615        }
3616        let term = v
3617            .checked_mul(stride)
3618            .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3619        offset = offset
3620            .checked_add(term)
3621            .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3622        stride = stride
3623            .checked_mul(d)
3624            .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3625    }
3626    Ok(offset)
3627}
3628
3629// ============================================================================
3630// High-level API for tensor construction (avoids direct Storage access)
3631// ============================================================================
3632
3633impl TensorDynLen {
3634    fn any_scalar_payload_to_complex(data: Vec<AnyScalar>) -> Vec<Complex64> {
3635        data.into_iter()
3636            .map(|value| {
3637                value
3638                    .as_c64()
3639                    .unwrap_or_else(|| Complex64::new(value.real(), 0.0))
3640            })
3641            .collect()
3642    }
3643
3644    fn any_scalar_payload_to_real(data: Vec<AnyScalar>) -> Vec<f64> {
3645        data.into_iter().map(|value| value.real()).collect()
3646    }
3647
3648    fn validate_dense_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3649        let expected_len = checked_total_size(dims)?;
3650        anyhow::ensure!(
3651            data_len == expected_len,
3652            "dense payload length {} does not match dims {:?} (expected {})",
3653            data_len,
3654            dims,
3655            expected_len
3656        );
3657        Ok(())
3658    }
3659
3660    fn validate_diag_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3661        anyhow::ensure!(
3662            !dims.is_empty(),
3663            "diagonal tensor construction requires at least one index"
3664        );
3665        Self::validate_diag_dims(dims)?;
3666        anyhow::ensure!(
3667            data_len == dims[0],
3668            "diagonal payload length {} does not match diagonal dimension {}",
3669            data_len,
3670            dims[0]
3671        );
3672        Ok(())
3673    }
3674
3675    /// Create a tensor from dense data with explicit indices.
3676    ///
3677    /// This is the recommended high-level API for creating tensors from raw data.
3678    /// It avoids direct access to `Storage` internals.
3679    ///
3680    /// # Type Parameters
3681    /// * `T` - Scalar type (`f64` or `Complex64`)
3682    ///
3683    /// # Arguments
3684    /// * `indices` - Vector of indices for the tensor
3685    /// * `data` - Tensor data in column-major order
3686    ///
3687    /// # Panics
3688    /// Panics if data length doesn't match the product of index dimensions.
3689    ///
3690    /// # Example
3691    /// ```
3692    /// use tensor4all_core::TensorDynLen;
3693    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3694    ///
3695    /// let i = Index::new_dyn(2);
3696    /// let j = Index::new_dyn(3);
3697    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
3698    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i, j], data).unwrap();
3699    /// assert_eq!(tensor.dims(), vec![2, 3]);
3700    /// ```
3701    pub fn from_dense<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3702        let dims = Self::expected_dims_from_indices(&indices);
3703        Self::validate_indices(&indices)?;
3704        Self::validate_dense_payload_len(data.len(), &dims)?;
3705        let native = dense_native_tensor_from_col_major(&data, &dims)?;
3706        Self::from_native(indices, native)
3707    }
3708
3709    /// Create a tensor from dense payload data provided as [`AnyScalar`] values.
3710    ///
3711    /// This is the preferred public API when the caller only knows the scalar
3712    /// type at runtime.
3713    ///
3714    /// # Examples
3715    /// ```
3716    /// use tensor4all_core::{AnyScalar, TensorDynLen};
3717    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3718    ///
3719    /// let i = Index::new_dyn(2);
3720    /// let j = Index::new_dyn(2);
3721    /// let tensor = TensorDynLen::from_dense_any(
3722    ///     vec![i, j],
3723    ///     vec![
3724    ///         AnyScalar::new_real(1.0),
3725    ///         AnyScalar::new_complex(0.0, 1.0),
3726    ///         AnyScalar::new_real(2.0),
3727    ///         AnyScalar::new_real(3.0),
3728    ///     ],
3729    /// ).unwrap();
3730    ///
3731    /// assert!(tensor.is_complex());
3732    /// assert_eq!(tensor.dims(), vec![2, 2]);
3733    /// ```
3734    pub fn from_dense_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3735        if data.iter().any(AnyScalar::is_complex) {
3736            Self::from_dense(indices, Self::any_scalar_payload_to_complex(data))
3737        } else {
3738            Self::from_dense(indices, Self::any_scalar_payload_to_real(data))
3739        }
3740    }
3741
3742    /// Create a diagonal tensor from diagonal payload data with explicit indices.
3743    ///
3744    /// All indices must have the same dimension, and `data.len()` must equal
3745    /// that dimension. The resulting tensor has nonzero entries only on
3746    /// the multi-index diagonal (`T[i,i,...,i] = data[i]`).
3747    ///
3748    /// The returned tensor preserves compact diagonal payload metadata; use
3749    /// [`TensorDynLen::is_diag`] or [`TensorDynLen::storage`] to inspect that
3750    /// representation.
3751    ///
3752    /// # Examples
3753    ///
3754    /// ```
3755    /// use tensor4all_core::{DynIndex, TensorDynLen};
3756    ///
3757    /// let i = DynIndex::new_dyn(3);
3758    /// let j = DynIndex::new_dyn(3);
3759    /// let diag = TensorDynLen::from_diag(vec![i, j], vec![1.0, 2.0, 3.0]).unwrap();
3760    /// assert!(diag.is_diag());
3761    ///
3762    /// let data = diag.to_vec::<f64>().unwrap();
3763    /// // 3x3 identity-like: [1,0,0, 0,2,0, 0,0,3] in column-major
3764    /// assert!((data[0] - 1.0).abs() < 1e-12);
3765    /// assert!((data[4] - 2.0).abs() < 1e-12);
3766    /// assert!((data[8] - 3.0).abs() < 1e-12);
3767    /// assert!((data[1]).abs() < 1e-12);  // off-diagonal is zero
3768    /// ```
3769    pub fn from_diag<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3770        let dims = Self::expected_dims_from_indices(&indices);
3771        Self::validate_indices(&indices)?;
3772        Self::validate_diag_payload_len(data.len(), &dims)?;
3773        let native = diag_native_tensor_from_col_major(&data, dims.len())?;
3774        Self::from_native_with_axis_classes(indices, native, Self::diag_axis_classes(dims.len()))
3775    }
3776
3777    /// Create a diagonal tensor from diagonal payload data provided as
3778    /// [`AnyScalar`] values.
3779    ///
3780    /// This is the preferred public API when the caller only knows the scalar
3781    /// type at runtime.
3782    ///
3783    /// # Examples
3784    /// ```
3785    /// use tensor4all_core::{AnyScalar, TensorDynLen};
3786    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3787    ///
3788    /// let i = Index::new_dyn(2);
3789    /// let j = Index::new_dyn(2);
3790    /// let tensor = TensorDynLen::from_diag_any(
3791    ///     vec![i, j],
3792    ///     vec![AnyScalar::new_real(1.0), AnyScalar::new_complex(2.0, -1.0)],
3793    /// ).unwrap();
3794    ///
3795    /// assert!(tensor.is_complex());
3796    /// assert_eq!(tensor.dims(), vec![2, 2]);
3797    /// ```
3798    pub fn from_diag_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3799        if data.iter().any(AnyScalar::is_complex) {
3800            Self::from_diag(indices, Self::any_scalar_payload_to_complex(data))
3801        } else {
3802            Self::from_diag(indices, Self::any_scalar_payload_to_real(data))
3803        }
3804    }
3805
3806    /// Create a copy tensor whose nonzero entries are `value` on the diagonal.
3807    ///
3808    /// For indices `[i, j, k]`, the returned tensor satisfies
3809    /// `T[i, j, k] = value` when `i = j = k`, and zero otherwise.
3810    ///
3811    /// # Examples
3812    /// ```
3813    /// use tensor4all_core::{AnyScalar, TensorDynLen};
3814    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3815    ///
3816    /// let i = Index::new_dyn(2);
3817    /// let j = Index::new_dyn(2);
3818    /// let k = Index::new_dyn(2);
3819    /// let tensor = TensorDynLen::copy_tensor(
3820    ///     vec![i, j, k],
3821    ///     AnyScalar::new_real(1.0),
3822    /// ).unwrap();
3823    ///
3824    /// assert_eq!(tensor.dims(), vec![2, 2, 2]);
3825    /// ```
3826    pub fn copy_tensor(indices: Vec<DynIndex>, value: AnyScalar) -> Result<Self> {
3827        if indices.is_empty() {
3828            return Self::from_dense_any(vec![], vec![value]);
3829        }
3830        let dim = indices[0].dim();
3831        let data = vec![value; dim];
3832        Self::from_diag_any(indices, data)
3833    }
3834
3835    /// Replace multiple tensor indices with one fused index using an exact local reshape.
3836    ///
3837    /// The indices in `old_indices` identify the axes to fuse by ID and also
3838    /// define the coordinate order used inside `new_index`. The new fused index
3839    /// is inserted at the earliest axis position among the fused axes; all
3840    /// other axes keep their original relative order. Use
3841    /// [`LinearizationOrder::ColumnMajor`] to match tensor4all's dense vector
3842    /// layout, or [`LinearizationOrder::RowMajor`] when interoperating with
3843    /// row-major fused coordinates.
3844    ///
3845    /// # Arguments
3846    /// * `old_indices` - Non-empty list of existing tensor indices to replace.
3847    ///   Each index is matched by ID, must appear exactly once in the tensor,
3848    ///   must have the same dimension as the matched tensor axis, and must not
3849    ///   be duplicated in this list.
3850    /// * `new_index` - Replacement index whose dimension must equal the product
3851    ///   of the dimensions in `old_indices`.
3852    /// * `order` - Linearization convention used to encode the old coordinates
3853    ///   into the single coordinate of `new_index`.
3854    ///
3855    /// # Returns
3856    /// A tensor with the same element type and values, but with `old_indices`
3857    /// replaced by `new_index`.
3858    ///
3859    /// # Errors
3860    /// Returns an error if `old_indices` is empty, contains duplicate IDs,
3861    /// references an index not present in the tensor, if the fused dimension
3862    /// does not match the product of the old dimensions, if the replacement
3863    /// would duplicate a kept index, or if the dense reshape cannot be
3864    /// represented without overflow.
3865    ///
3866    /// # Examples
3867    /// ```
3868    /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen};
3869    ///
3870    /// let i = DynIndex::new_dyn(2);
3871    /// let j = DynIndex::new_dyn(2);
3872    /// let fused = DynIndex::new_link(4).unwrap();
3873    /// let tensor = TensorDynLen::from_dense(
3874    ///     vec![i.clone(), j.clone()],
3875    ///     vec![1.0, 2.0, 3.0, 4.0],
3876    /// ).unwrap();
3877    ///
3878    /// let fused_tensor = tensor
3879    ///     .fuse_indices(&[i.clone(), j.clone()], fused.clone(), LinearizationOrder::ColumnMajor)
3880    ///     .unwrap();
3881    /// assert_eq!(fused_tensor.dims(), vec![4]);
3882    ///
3883    /// let roundtrip = fused_tensor
3884    ///     .unfuse_index(&fused, &[i, j], LinearizationOrder::ColumnMajor)
3885    ///     .unwrap();
3886    /// assert!(roundtrip.isapprox(&tensor, 1e-12, 0.0));
3887    /// ```
3888    pub fn fuse_indices(
3889        &self,
3890        old_indices: &[DynIndex],
3891        new_index: DynIndex,
3892        order: LinearizationOrder,
3893    ) -> Result<Self> {
3894        anyhow::ensure!(
3895            !old_indices.is_empty(),
3896            "fuse_indices requires at least one index to fuse"
3897        );
3898
3899        let old_dims = self.dims();
3900        let mut seen_indices = HashSet::new();
3901        let mut old_axes = Vec::with_capacity(old_indices.len());
3902        for old_index in old_indices {
3903            anyhow::ensure!(
3904                seen_indices.insert(old_index),
3905                "duplicate index in old_indices"
3906            );
3907            let axis = self
3908                .indices
3909                .iter()
3910                .position(|idx| idx == old_index)
3911                .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
3912            anyhow::ensure!(
3913                old_index.dim() == old_dims[axis],
3914                "old index dimension does not match tensor axis dimension"
3915            );
3916            old_axes.push(axis);
3917        }
3918
3919        let fused_dims: Vec<usize> = old_axes.iter().map(|&axis| old_dims[axis]).collect();
3920        let fused_product = checked_product(&fused_dims)?;
3921        anyhow::ensure!(
3922            fused_product == new_index.dim(),
3923            "product of old index dimensions must match the replacement index dimension"
3924        );
3925
3926        let insertion_axis =
3927            old_axes.iter().copied().min().ok_or_else(|| {
3928                anyhow::anyhow!("fuse_indices requires at least one index to fuse")
3929            })?;
3930        let old_axis_set: HashSet<usize> = old_axes.iter().copied().collect();
3931
3932        let mut result_indices =
3933            Vec::with_capacity(self.indices.len() - old_indices.len() + 1usize);
3934        for (axis, index) in self.indices.iter().enumerate() {
3935            if axis == insertion_axis {
3936                result_indices.push(new_index.clone());
3937            }
3938            if !old_axis_set.contains(&axis) {
3939                result_indices.push(index.clone());
3940            }
3941        }
3942        let mut result_seen = HashSet::new();
3943        for index in &result_indices {
3944            anyhow::ensure!(
3945                result_seen.insert(index),
3946                "fuse_indices result would contain duplicate index"
3947            );
3948        }
3949        Self::validate_indices(&result_indices)?;
3950
3951        let mut new_dims = Vec::with_capacity(old_dims.len() - old_indices.len() + 1usize);
3952        for (axis, dim) in old_dims.iter().copied().enumerate() {
3953            if axis == insertion_axis {
3954                new_dims.push(new_index.dim());
3955            }
3956            if !old_axis_set.contains(&axis) {
3957                new_dims.push(dim);
3958            }
3959        }
3960
3961        let old_data = self.to_vec_any()?;
3962        let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
3963        for (old_linear, value) in old_data.into_iter().enumerate() {
3964            let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
3965            let fused_multi: Vec<usize> = old_axes.iter().map(|&axis| old_multi[axis]).collect();
3966            let fused_linear = encode_linear_with_order(&fused_multi, &fused_dims, order)?;
3967
3968            let mut new_multi = Vec::with_capacity(new_dims.len());
3969            for (axis, old_coord) in old_multi.iter().copied().enumerate() {
3970                if axis == insertion_axis {
3971                    new_multi.push(fused_linear);
3972                }
3973                if !old_axis_set.contains(&axis) {
3974                    new_multi.push(old_coord);
3975                }
3976            }
3977            let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
3978            new_data[new_linear] = value;
3979        }
3980
3981        Self::from_dense_any(result_indices, new_data)
3982    }
3983
3984    /// Replace one fused index with multiple indices using an exact reshape.
3985    ///
3986    /// The caller must specify how the old fused index should be decoded into
3987    /// the new indices via `order`.
3988    ///
3989    /// # Examples
3990    /// ```
3991    /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen};
3992    ///
3993    /// let fused = DynIndex::new_dyn(4);
3994    /// let i = DynIndex::new_dyn(2);
3995    /// let j = DynIndex::new_dyn(2);
3996    /// let tensor = TensorDynLen::from_dense(vec![fused.clone()], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
3997    ///
3998    /// let unfused = tensor
3999    ///     .unfuse_index(&fused, &[i.clone(), j.clone()], LinearizationOrder::ColumnMajor)
4000    ///     .unwrap();
4001    ///
4002    /// let expected = TensorDynLen::from_dense(vec![i, j], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
4003    /// assert!(unfused.isapprox(&expected, 1e-12, 0.0));
4004    /// ```
4005    pub fn unfuse_index(
4006        &self,
4007        old_index: &DynIndex,
4008        new_indices: &[DynIndex],
4009        order: LinearizationOrder,
4010    ) -> Result<Self> {
4011        anyhow::ensure!(
4012            !new_indices.is_empty(),
4013            "unfuse_index requires at least one replacement index"
4014        );
4015
4016        let axis = self
4017            .indices
4018            .iter()
4019            .position(|idx| idx == old_index)
4020            .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
4021
4022        let replacement_dims: Vec<usize> = new_indices.iter().map(DynIndex::dim).collect();
4023        let replacement_product = checked_product(&replacement_dims)?;
4024        anyhow::ensure!(
4025            replacement_product == old_index.dim(),
4026            "product of new index dimensions must match the replaced index dimension"
4027        );
4028
4029        let mut result_indices =
4030            Vec::with_capacity(self.indices.len() - 1usize + new_indices.len());
4031        result_indices.extend_from_slice(&self.indices[..axis]);
4032        result_indices.extend(new_indices.iter().cloned());
4033        result_indices.extend_from_slice(&self.indices[axis + 1..]);
4034        Self::validate_indices(&result_indices)?;
4035
4036        let old_dims = self.dims();
4037        let mut new_dims = Vec::with_capacity(old_dims.len() - 1usize + replacement_dims.len());
4038        new_dims.extend_from_slice(&old_dims[..axis]);
4039        new_dims.extend_from_slice(&replacement_dims);
4040        new_dims.extend_from_slice(&old_dims[axis + 1..]);
4041
4042        let old_data = self.to_vec_any()?;
4043        let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
4044        for (old_linear, value) in old_data.into_iter().enumerate() {
4045            let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
4046            let split_multi = decode_linear_with_order(old_multi[axis], &replacement_dims, order)?;
4047            let mut new_multi = Vec::with_capacity(new_dims.len());
4048            new_multi.extend_from_slice(&old_multi[..axis]);
4049            new_multi.extend_from_slice(&split_multi);
4050            new_multi.extend_from_slice(&old_multi[axis + 1..]);
4051            let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
4052            new_data[new_linear] = value;
4053        }
4054
4055        Self::from_dense_any(result_indices, new_data)
4056    }
4057
4058    /// Create a scalar (0-dimensional) tensor from a supported element value.
4059    ///
4060    /// # Example
4061    /// ```
4062    /// use tensor4all_core::TensorDynLen;
4063    ///
4064    /// let scalar = TensorDynLen::scalar(42.0).unwrap();
4065    /// assert_eq!(scalar.dims(), Vec::<usize>::new());
4066    /// assert_eq!(scalar.only().unwrap().real(), 42.0);
4067    /// ```
4068    pub fn scalar<T: TensorElement>(value: T) -> Result<Self> {
4069        Self::from_dense(vec![], vec![value])
4070    }
4071
4072    /// Create a tensor filled with zeros of a supported element type.
4073    ///
4074    /// # Example
4075    /// ```
4076    /// use tensor4all_core::TensorDynLen;
4077    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
4078    ///
4079    /// let i = Index::new_dyn(2);
4080    /// let j = Index::new_dyn(3);
4081    /// let tensor = TensorDynLen::zeros::<f64>(vec![i, j]).unwrap();
4082    /// assert_eq!(tensor.dims(), vec![2, 3]);
4083    /// ```
4084    pub fn zeros<T: TensorElement + Zero + Clone>(indices: Vec<DynIndex>) -> Result<Self> {
4085        let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
4086        let size: usize = dims.iter().product();
4087        Self::from_dense(indices, vec![T::zero(); size])
4088    }
4089}
4090
4091// ============================================================================
4092// High-level API for data extraction (avoids direct .storage() access)
4093// ============================================================================
4094
4095impl TensorDynLen {
4096    /// Extract tensor data as a column-major `Vec<T>`.
4097    ///
4098    /// # Type Parameters
4099    /// * `T` - The scalar element type (`f64` or `Complex64`).
4100    ///
4101    /// # Returns
4102    /// A vector of the tensor data in column-major order.
4103    ///
4104    /// # Errors
4105    /// Returns an error if the tensor's scalar type does not match `T`.
4106    ///
4107    /// # Example
4108    /// ```
4109    /// use tensor4all_core::TensorDynLen;
4110    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
4111    ///
4112    /// let i = Index::new_dyn(2);
4113    /// let tensor = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
4114    /// let data = tensor.to_vec::<f64>().unwrap();
4115    /// assert_eq!(data, &[1.0, 2.0]);
4116    /// ```
4117    pub fn to_vec<T: TensorElement>(&self) -> Result<Vec<T>> {
4118        native_tensor_primal_to_dense_col_major(self.as_native()?)
4119    }
4120
4121    /// Consume the tensor and return its indices with dense column-major values.
4122    ///
4123    /// Use this when a caller needs to move index metadata and dense payload
4124    /// values across an API boundary. The returned values are ordered with the
4125    /// first tensor index varying fastest. Compact diagonal or structured
4126    /// storage is materialized into dense logical values.
4127    ///
4128    /// # Type Parameters
4129    /// * `T` - The scalar element type to extract. Use `f64` for real tensors
4130    ///   and `Complex64` for complex tensors.
4131    ///
4132    /// # Returns
4133    /// The tensor's original indices and dense column-major flat data.
4134    ///
4135    /// # Errors
4136    /// Returns an error if the tensor has tracked autodiff state, if the
4137    /// requested scalar type does not match the tensor payload, or if dense
4138    /// materialization fails.
4139    ///
4140    /// # Examples
4141    /// ```
4142    /// use tensor4all_core::{DynIndex, TensorDynLen};
4143    ///
4144    /// let i = DynIndex::new_dyn(2);
4145    /// let j = DynIndex::new_dyn(2);
4146    /// let tensor = TensorDynLen::from_dense(
4147    ///     vec![i.clone(), j.clone()],
4148    ///     vec![1.0_f64, 2.0, 3.0, 4.0],
4149    /// ).unwrap();
4150    ///
4151    /// let (indices, data) = tensor.into_dense_col_major_parts::<f64>().unwrap();
4152    ///
4153    /// assert_eq!(indices, vec![i, j]);
4154    /// assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
4155    /// ```
4156    pub fn into_dense_col_major_parts<T: TensorElement>(self) -> Result<(Vec<DynIndex>, Vec<T>)> {
4157        anyhow::ensure!(
4158            self.structured_ad.is_none() && !self.tracks_grad(),
4159            "TensorDynLen::into_dense_col_major_parts cannot consume tensors with tracked autodiff state"
4160        );
4161        let data = self.to_vec::<T>()?;
4162        Ok((self.indices, data))
4163    }
4164
4165    fn to_vec_any(&self) -> Result<Vec<AnyScalar>> {
4166        if self.is_complex() {
4167            self.to_vec::<Complex64>().map(|data| {
4168                data.into_iter()
4169                    .map(|value| AnyScalar::new_complex(value.re, value.im))
4170                    .collect()
4171            })
4172        } else {
4173            self.to_vec::<f64>()
4174                .map(|data| data.into_iter().map(AnyScalar::new_real).collect())
4175        }
4176    }
4177
4178    /// Extract tensor data as a column-major `Vec<f64>`.
4179    ///
4180    /// Prefer the generic [`to_vec::<f64>()`](Self::to_vec) method.
4181    /// This wrapper is kept for C API compatibility.
4182    pub fn as_slice_f64(&self) -> Result<Vec<f64>> {
4183        self.to_vec::<f64>()
4184    }
4185
4186    /// Extract tensor data as a column-major `Vec<Complex64>`.
4187    ///
4188    /// Prefer the generic [`to_vec::<Complex64>()`](Self::to_vec) method.
4189    /// This wrapper is kept for C API compatibility.
4190    pub fn as_slice_c64(&self) -> Result<Vec<Complex64>> {
4191        self.to_vec::<Complex64>()
4192    }
4193
4194    /// Check if the tensor has f64 storage.
4195    ///
4196    /// # Example
4197    /// ```
4198    /// use tensor4all_core::TensorDynLen;
4199    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
4200    ///
4201    /// let i = Index::new_dyn(2);
4202    /// let tensor = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
4203    /// assert!(tensor.is_f64());
4204    /// assert!(!tensor.is_complex());
4205    /// ```
4206    pub fn is_f64(&self) -> bool {
4207        self.storage.is_f64()
4208    }
4209
4210    /// Check whether the tensor carries diagonal logical axis metadata.
4211    ///
4212    /// # Examples
4213    ///
4214    /// ```
4215    /// use tensor4all_core::{DynIndex, TensorDynLen};
4216    /// use tensor4all_tensorbackend::Storage;
4217    ///
4218    /// // Tensors from `from_dense` use dense storage
4219    /// let i = DynIndex::new_dyn(2);
4220    /// let j = DynIndex::new_dyn(2);
4221    /// let dense = TensorDynLen::from_dense(vec![i, j], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
4222    /// assert!(!dense.is_diag());
4223    ///
4224    /// // Diagonal metadata is preserved when constructing from diagonal storage.
4225    /// let k = DynIndex::new_dyn(2);
4226    /// let l = DynIndex::new_dyn(2);
4227    /// let diag = TensorDynLen::from_storage(
4228    ///     vec![k, l],
4229    ///     Storage::from_diag_col_major(vec![1.0, 2.0], 2)
4230    ///         .map(std::sync::Arc::new)
4231    ///         .unwrap(),
4232    /// )
4233    /// .unwrap();
4234    /// assert!(diag.is_diag());
4235    /// ```
4236    pub fn is_diag(&self) -> bool {
4237        self.storage.is_diag()
4238    }
4239
4240    /// Check if the tensor has complex storage (C64).
4241    ///
4242    /// # Examples
4243    ///
4244    /// ```
4245    /// use tensor4all_core::{DynIndex, TensorDynLen};
4246    /// use num_complex::Complex64;
4247    ///
4248    /// let i = DynIndex::new_dyn(2);
4249    /// let real_t = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
4250    /// assert!(!real_t.is_complex());
4251    ///
4252    /// let complex_t = TensorDynLen::from_dense(
4253    ///     vec![i],
4254    ///     vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
4255    /// ).unwrap();
4256    /// assert!(complex_t.is_complex());
4257    /// ```
4258    pub fn is_complex(&self) -> bool {
4259        self.storage.is_complex()
4260    }
4261}
4262
4263fn checked_product(dims: &[usize]) -> Result<usize> {
4264    dims.iter().try_fold(1usize, |acc, &dim| {
4265        acc.checked_mul(dim)
4266            .ok_or_else(|| anyhow::anyhow!("dimension product overflow"))
4267    })
4268}
4269
4270fn decode_col_major_linear(linear: usize, dims: &[usize]) -> Result<Vec<usize>> {
4271    let total = checked_product(dims)?;
4272    anyhow::ensure!(
4273        linear < total,
4274        "linear offset {} out of bounds for dims {:?}",
4275        linear,
4276        dims
4277    );
4278    let mut remaining = linear;
4279    let mut out = Vec::with_capacity(dims.len());
4280    for &dim in dims {
4281        out.push(remaining % dim);
4282        remaining /= dim;
4283    }
4284    Ok(out)
4285}
4286
4287fn encode_col_major_linear(indices: &[usize], dims: &[usize]) -> Result<usize> {
4288    anyhow::ensure!(
4289        indices.len() == dims.len(),
4290        "index rank {} does not match dims {:?}",
4291        indices.len(),
4292        dims
4293    );
4294    let mut linear = 0usize;
4295    let mut stride = 1usize;
4296    for (&index, &dim) in indices.iter().zip(dims.iter()) {
4297        anyhow::ensure!(
4298            index < dim,
4299            "index {} out of bounds for dimension {}",
4300            index,
4301            dim
4302        );
4303        linear += index * stride;
4304        stride = stride
4305            .checked_mul(dim)
4306            .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4307    }
4308    Ok(linear)
4309}
4310
4311fn decode_linear_with_order(
4312    linear: usize,
4313    dims: &[usize],
4314    order: LinearizationOrder,
4315) -> Result<Vec<usize>> {
4316    let total = checked_product(dims)?;
4317    anyhow::ensure!(
4318        linear < total,
4319        "linear offset {} out of bounds for dims {:?}",
4320        linear,
4321        dims
4322    );
4323
4324    let mut remaining = linear;
4325    let mut out = vec![0usize; dims.len()];
4326    match order {
4327        LinearizationOrder::ColumnMajor => {
4328            for (slot, &dim) in out.iter_mut().zip(dims.iter()) {
4329                *slot = remaining % dim;
4330                remaining /= dim;
4331            }
4332        }
4333        LinearizationOrder::RowMajor => {
4334            for (slot, &dim) in out.iter_mut().rev().zip(dims.iter().rev()) {
4335                *slot = remaining % dim;
4336                remaining /= dim;
4337            }
4338        }
4339    }
4340    Ok(out)
4341}
4342
4343fn encode_linear_with_order(
4344    indices: &[usize],
4345    dims: &[usize],
4346    order: LinearizationOrder,
4347) -> Result<usize> {
4348    match order {
4349        LinearizationOrder::ColumnMajor => encode_col_major_linear(indices, dims),
4350        LinearizationOrder::RowMajor => {
4351            anyhow::ensure!(
4352                indices.len() == dims.len(),
4353                "index rank {} does not match dims {:?}",
4354                indices.len(),
4355                dims
4356            );
4357            let mut linear = 0usize;
4358            let mut stride = 1usize;
4359            for (&index, &dim) in indices.iter().rev().zip(dims.iter().rev()) {
4360                anyhow::ensure!(
4361                    index < dim,
4362                    "index {} out of bounds for dimension {}",
4363                    index,
4364                    dim
4365                );
4366                linear += index * stride;
4367                stride = stride
4368                    .checked_mul(dim)
4369                    .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4370            }
4371            Ok(linear)
4372        }
4373    }
4374}