Skip to main content

tensor4all_core/defaults/
tensordynlen.rs

1use crate::defaults::DynIndex;
2use crate::index_like::IndexLike;
3use crate::index_ops::{common_ind_positions, prepare_contraction, prepare_contraction_pairs};
4use crate::tensor_like::LinearizationOrder;
5use crate::AnyScalar;
6use anyhow::{Context, Result};
7use num_complex::Complex64;
8use num_traits::Zero;
9use rand::Rng;
10use rand_distr::{Distribution, StandardNormal};
11use std::cell::RefCell;
12use std::cmp::Reverse;
13use std::collections::{HashMap, HashSet};
14use std::env;
15use std::sync::{Arc, OnceLock};
16use std::time::{Duration, Instant};
17use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad;
18use tenferro::{
19    CpuBackend, DType, DotGeneralConfig, EagerTensor, EinsumSubscripts, Tensor as NativeTensor,
20};
21use tensor4all_tensorbackend::{
22    axpby_native_tensor, contract_native_tensor, default_eager_ctx,
23    dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
24    native_tensor_primal_to_dense_col_major, native_tensor_primal_to_diag_c64,
25    native_tensor_primal_to_diag_f64, native_tensor_primal_to_storage, scale_native_tensor,
26    storage_payload_native_read_input, storage_to_native_tensor, AnyScalar as BackendScalar,
27    StorageScalar, TensorElement,
28};
29use tensor4all_tensorbackend::{Storage, StorageKind};
30
31use super::contract::PairwiseContractionOptions;
32use super::structured_contraction::{
33    normalize_payload_read_for_roots, storage_from_payload_native, storage_payload_native,
34    OperandLayout, StructuredContractionPlan, StructuredContractionSpec,
35};
36
/// Accumulated statistics for one named profiling section.
#[derive(Debug, Default, Clone)]
struct PairwiseContractProfileEntry {
    /// Number of timing samples recorded for the section.
    calls: usize,
    /// Sum of all elapsed durations recorded for the section.
    total_time: Duration,
    /// Sum of all byte counts recorded for the section.
    total_bytes: usize,
}
43
// Per-thread profiling table: maps a static section name to the statistics
// accumulated for that section on the current thread.
thread_local! {
    static PAIRWISE_CONTRACT_PROFILE_STATE: RefCell<HashMap<&'static str, PairwiseContractProfileEntry>> =
        RefCell::new(HashMap::new());
}
48
/// Reports whether pairwise-contraction profiling is switched on.
///
/// The `T4A_PROFILE_PAIRWISE_CONTRACT` environment variable is inspected on
/// the first call only; the answer is cached for the rest of the process.
fn pairwise_contract_profile_enabled() -> bool {
    static PROFILE_FLAG: OnceLock<bool> = OnceLock::new();
    let flag = PROFILE_FLAG.get_or_init(|| env::var("T4A_PROFILE_PAIRWISE_CONTRACT").is_ok());
    *flag
}
53
54fn record_pairwise_contract_profile(section: &'static str, elapsed: Duration) {
55    if !pairwise_contract_profile_enabled() {
56        return;
57    }
58    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
59        let mut state = state.borrow_mut();
60        let entry = state.entry(section).or_default();
61        entry.calls += 1;
62        entry.total_time += elapsed;
63    });
64}
65
66fn record_pairwise_contract_profile_bytes(section: &'static str, bytes: usize) {
67    if !pairwise_contract_profile_enabled() {
68        return;
69    }
70    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
71        let mut state = state.borrow_mut();
72        let entry = state.entry(section).or_default();
73        entry.total_bytes += bytes;
74    });
75}
76
77fn profile_pairwise_contract_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
78    if !pairwise_contract_profile_enabled() {
79        return f();
80    }
81    let started = Instant::now();
82    let result = f();
83    record_pairwise_contract_profile(section, started.elapsed());
84    result
85}
86
87/// Reset the aggregated pairwise `TensorDynLen` contraction profile.
88pub fn reset_pairwise_contract_profile() {
89    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
90}
91
92/// Print and clear the aggregated pairwise `TensorDynLen` contraction profile.
93pub fn print_and_reset_pairwise_contract_profile() {
94    if !pairwise_contract_profile_enabled() {
95        return;
96    }
97    PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
98        let mut entries: Vec<_> = state
99            .borrow()
100            .iter()
101            .map(|(section, entry)| (*section, entry.clone()))
102            .collect();
103        state.borrow_mut().clear();
104        entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
105
106        eprintln!("=== TensorDynLen pairwise contract profile ===");
107        for (section, entry) in entries {
108            let per_call_us = if entry.calls == 0 {
109                0.0
110            } else {
111                entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64
112            };
113            eprintln!(
114                "{section}: calls={} total={:.6}ms per_call={:.3}us bytes={}",
115                entry.calls,
116                entry.total_time.as_secs_f64() * 1.0e3,
117                per_call_us,
118                entry.total_bytes,
119            );
120        }
121    });
122}
123
124fn native_tensor_profile_bytes(native: &NativeTensor) -> usize {
125    let element_size = match native.dtype() {
126        DType::F32 => 4,
127        DType::F64 => 8,
128        DType::C32 => 8,
129        DType::C64 => 16,
130        DType::I64 => 8,
131    };
132    native.shape().iter().product::<usize>() * element_size
133}
134
/// Trait for scalar types that can generate random values from a standard
/// normal distribution.
///
/// This enables the generic [`TensorDynLen::random`] constructor.
///
/// Implementations are provided in this module for `f64` and `Complex64`.
pub trait RandomScalar: TensorElement {
    /// Generate a random value from the standard normal distribution.
    ///
    /// `rng` is the random number generator the sample is drawn from.
    fn random_value<R: Rng>(rng: &mut R) -> Self;
}
143
144impl RandomScalar for f64 {
145    fn random_value<R: Rng>(rng: &mut R) -> Self {
146        StandardNormal.sample(rng)
147    }
148}
149
150impl RandomScalar for Complex64 {
151    fn random_value<R: Rng>(rng: &mut R) -> Self {
152        Complex64::new(StandardNormal.sample(rng), StandardNormal.sample(rng))
153    }
154}
155
156/// Compute the permutation array from original indices to new indices.
157///
158/// This function finds the mapping from new indices to original indices by
159/// matching index IDs. The result is a permutation array `perm` such that
160/// `new_indices[i]` corresponds to `original_indices[perm[i]]`.
161///
162/// # Arguments
163/// * `original_indices` - The original indices in their current order
164/// * `new_indices` - The desired new indices order (must be a permutation of original_indices)
165///
166/// # Returns
167/// A `Vec<usize>` representing the permutation: `perm[i]` is the position in
168/// `original_indices` of the index that should be at position `i` in `new_indices`.
169///
170/// # Errors
171/// Returns an error if the slices have different lengths, if `new_indices`
172/// is not a permutation of `original_indices`, or if `new_indices` contains
173/// duplicate indices.
174///
175/// # Example
176/// ```
177/// use tensor4all_core::tensor::compute_permutation_from_indices;
178/// use tensor4all_core::DynIndex;
179///
180/// let i = DynIndex::new_dyn(2);
181/// let j = DynIndex::new_dyn(3);
182/// let original = vec![i.clone(), j.clone()];
183/// let new_order = vec![j.clone(), i.clone()];
184///
185/// let perm = compute_permutation_from_indices(&original, &new_order).unwrap();
186/// assert_eq!(perm, vec![1, 0]);  // j is at position 1, i is at position 0
187/// ```
188pub fn compute_permutation_from_indices(
189    original_indices: &[DynIndex],
190    new_indices: &[DynIndex],
191) -> Result<Vec<usize>> {
192    anyhow::ensure!(
193        new_indices.len() == original_indices.len(),
194        "new_indices length must match original_indices length"
195    );
196
197    let mut perm = Vec::with_capacity(new_indices.len());
198    let mut used = std::collections::HashSet::new();
199
200    for new_idx in new_indices {
201        // Find the position of this index in the original indices
202        // DynIndex implements Eq, so we can compare directly
203        let pos = original_indices
204            .iter()
205            .position(|old_idx| old_idx == new_idx)
206            .ok_or_else(|| {
207                anyhow::anyhow!("new_indices must be a permutation of original_indices")
208            })?;
209
210        anyhow::ensure!(used.insert(pos), "duplicate index in new_indices");
211        perm.push(pos);
212    }
213
214    Ok(perm)
215}
216
/// Eager payload kept alongside a tensor's storage, together with the layout
/// metadata it was built with.
#[derive(Clone)]
pub(crate) struct StructuredAdValue {
    /// Shared eager tensor holding the payload.
    payload: Arc<EagerTensor<CpuBackend>>,
    /// Dimensions of `payload`; the constructor in this file verifies that
    /// they equal the payload's shape.
    payload_dims: Vec<usize>,
    /// Axis-class labels associated with the payload.
    // NOTE(review): presumably one class id per logical axis - confirm.
    axis_classes: Vec<usize>,
}
223
/// Backing payload of a tensor: either a materialized `Storage` or an eager
/// tensor paired with its axis classes.
#[derive(Clone)]
pub(crate) enum TensorDynLenStorage {
    /// Payload held as a shared `Storage`.
    Materialized(Arc<Storage>),
    /// Payload held as a shared eager tensor.
    Eager {
        /// The eager tensor carrying the payload data.
        inner: Arc<EagerTensor<CpuBackend>>,
        /// Axis-class labels describing how the payload relates to the
        /// logical axes.
        axis_classes: Vec<usize>,
    },
}
232
233impl TensorDynLenStorage {
234    fn from_storage(storage: Arc<Storage>) -> Self {
235        Self::Materialized(storage)
236    }
237
238    fn from_eager_dense(inner: EagerTensor<CpuBackend>, rank: usize) -> Self {
239        Self::Eager {
240            inner: Arc::new(inner),
241            axis_classes: TensorDynLen::dense_axis_classes(rank),
242        }
243    }
244
245    fn eager(&self) -> Option<&EagerTensor<CpuBackend>> {
246        match self {
247            Self::Materialized(_) => None,
248            Self::Eager { inner, .. } => Some(inner.as_ref()),
249        }
250    }
251
252    fn axis_classes(&self) -> &[usize] {
253        match self {
254            Self::Materialized(storage) => storage.axis_classes(),
255            Self::Eager { axis_classes, .. } => axis_classes,
256        }
257    }
258
259    fn payload_dims(&self) -> &[usize] {
260        match self {
261            Self::Materialized(storage) => storage.payload_dims(),
262            Self::Eager { inner, .. } => inner.data().shape(),
263        }
264    }
265
266    fn payload_strides_vec(&self) -> Vec<isize> {
267        match self {
268            Self::Materialized(storage) => storage.payload_strides().to_vec(),
269            Self::Eager { inner, .. } => {
270                let mut stride = 1isize;
271                inner
272                    .data()
273                    .shape()
274                    .iter()
275                    .map(|&dim| {
276                        let current = stride;
277                        stride *= isize::try_from(dim).unwrap_or(isize::MAX);
278                        current
279                    })
280                    .collect()
281            }
282        }
283    }
284
285    fn is_f64(&self) -> bool {
286        match self {
287            Self::Materialized(storage) => storage.is_f64(),
288            Self::Eager { inner, .. } => inner.data().dtype() == DType::F64,
289        }
290    }
291
292    fn is_c64(&self) -> bool {
293        match self {
294            Self::Materialized(storage) => storage.is_c64(),
295            Self::Eager { inner, .. } => inner.data().dtype() == DType::C64,
296        }
297    }
298
299    fn is_complex(&self) -> bool {
300        match self {
301            Self::Materialized(storage) => storage.is_complex(),
302            Self::Eager { inner, .. } => matches!(inner.data().dtype(), DType::C32 | DType::C64),
303        }
304    }
305
306    fn is_diag(&self) -> bool {
307        match self {
308            Self::Materialized(storage) => storage.is_diag(),
309            Self::Eager { axis_classes, .. } => TensorDynLen::is_diag_axis_classes(axis_classes),
310        }
311    }
312
313    fn storage_kind(&self) -> StorageKind {
314        match self {
315            Self::Materialized(storage) => storage.storage_kind(),
316            Self::Eager { axis_classes, .. } => {
317                if axis_classes.iter().copied().eq(0..axis_classes.len()) {
318                    StorageKind::Dense
319                } else if TensorDynLen::is_diag_axis_classes(axis_classes) {
320                    StorageKind::Diagonal
321                } else {
322                    StorageKind::Structured
323                }
324            }
325        }
326    }
327
328    fn materialize(&self, logical_rank: usize) -> Result<Arc<Storage>> {
329        match self {
330            Self::Materialized(storage) => Ok(Arc::clone(storage)),
331            Self::Eager {
332                inner,
333                axis_classes,
334            } => Ok(Arc::new(
335                TensorDynLen::storage_from_native_with_axis_classes(
336                    inner.data(),
337                    axis_classes,
338                    logical_rank,
339                )?,
340            )),
341        }
342    }
343
344    fn scale(&self, scalar: &BackendScalar) -> Result<Storage> {
345        Ok(self.materialize(self.axis_classes().len())?.scale(scalar))
346    }
347
348    fn conj(&self) -> Result<Self> {
349        match self {
350            Self::Materialized(storage) => Ok(Self::Materialized(Arc::new(storage.conj()))),
351            Self::Eager {
352                inner,
353                axis_classes,
354            } => Ok(Self::Eager {
355                inner: Arc::new(inner.conj()?),
356                axis_classes: axis_classes.clone(),
357            }),
358        }
359    }
360
361    fn max_abs(&self) -> Result<f64> {
362        Ok(self.materialize(self.axis_classes().len())?.max_abs())
363    }
364}
365
/// Dynamic-rank tensor with structured payload storage -- the central data type
/// of tensor4all.
///
/// `TensorDynLen` stores a logical multi-dimensional tensor of `f64` or
/// `Complex64` values together with a list of [`DynIndex`] labels. The
/// authoritative payload is compact [`Storage`], which may be dense, diagonal,
/// or explicitly structured. The indices carry unique identities (UUIDs) so
/// that contraction, addition, and other binary operations can automatically
/// match legs by identity rather than position.
///
/// # Key Operations
///
/// | Operation | Method |
/// |-----------|--------|
/// | Create from data | [`from_dense`](Self::from_dense), [`from_diag`](Self::from_diag), [`zeros`](Self::zeros) |
/// | Extract data | [`to_vec`](Self::to_vec), [`into_dense_col_major_parts`](Self::into_dense_col_major_parts), [`sum`](Self::sum), [`only`](Self::only) |
/// | Contraction | [`contract`](Self::contract) |
/// | Arithmetic | [`add`](Self::add), [`scale`](Self::scale), [`axpby`](Self::axpby) |
/// | Factorization | via [`TensorFactorizationLike::factorize`](crate::TensorFactorizationLike::factorize) |
/// | Norms | [`norm`](Self::norm), [`norm_squared`](Self::norm_squared), [`maxabs`](Self::maxabs) |
/// | Index ops | [`replaceind`](Self::replaceind), [`permute_indices`](Self::permute_indices) |
///
/// # Data Layout
///
/// Logical dense extraction uses **column-major** order (first index varies
/// fastest), matching Fortran, Julia, and ITensors.jl conventions. Compact
/// structured payloads additionally carry explicit payload dimensions, strides,
/// and logical-axis classes.
///
/// # Cloning
///
/// Cloning copies the index list. The eager cache is held behind an `Arc`,
/// so a clone shares the same cache cell as the tensor it was cloned from.
///
/// # Examples
///
/// ```
/// use tensor4all_core::{TensorDynLen, DynIndex};
///
/// // Create a 2x3 real tensor
/// let i = DynIndex::new_dyn(2);
/// let j = DynIndex::new_dyn(3);
/// let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let t = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
///
/// assert_eq!(t.dims(), vec![2, 3]);
/// assert!(t.is_f64());
///
/// // Sum all elements: 1+2+3+4+5+6 = 21
/// let s = t.sum().unwrap();
/// assert!((s.real() - 21.0).abs() < 1e-12);
///
/// // Extract data back out
/// let data_out = t.to_vec::<f64>().unwrap();
/// assert_eq!(data_out, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// ```
#[derive(Clone)]
pub struct TensorDynLen {
    /// Full index information (includes tags and other metadata).
    pub indices: Vec<DynIndex>,
    /// Authoritative compact payload storage.
    pub(crate) storage: TensorDynLenStorage,
    /// Optional tracked compact payload used to preserve structured AD layouts.
    ///
    /// Populated by `from_structured_payload_inner`.
    pub(crate) structured_ad: Option<Arc<StructuredAdValue>>,
    /// Lazily materialized eager payload for native execution and AD.
    ///
    /// The cell can be set at most once (`OnceLock`).
    pub(crate) eager_cache: Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>>,
}
428
429impl TensorDynLen {
430    fn dense_axis_classes(rank: usize) -> Vec<usize> {
431        (0..rank).collect()
432    }
433
434    fn diag_axis_classes(rank: usize) -> Vec<usize> {
435        if rank == 0 {
436            vec![]
437        } else {
438            vec![0; rank]
439        }
440    }
441
442    fn canonicalize_axis_classes(axis_classes: &[usize]) -> Vec<usize> {
443        let mut map = std::collections::HashMap::new();
444        let mut next = 0usize;
445        axis_classes
446            .iter()
447            .map(|&class_id| {
448                *map.entry(class_id).or_insert_with(|| {
449                    let canonical = next;
450                    next += 1;
451                    canonical
452                })
453            })
454            .collect()
455    }
456
457    fn permute_axis_classes(&self, perm: &[usize]) -> Vec<usize> {
458        let axis_classes = self.storage.axis_classes();
459        let permuted: Vec<usize> = perm.iter().map(|&index| axis_classes[index]).collect();
460        Self::canonicalize_axis_classes(&permuted)
461    }
462
463    fn normalize_insert_axis(op: &str, axis: isize, rank: usize) -> Result<usize> {
464        let normalized = if axis < 0 {
465            rank as isize + 1 + axis
466        } else {
467            axis
468        };
469        anyhow::ensure!(
470            normalized >= 0 && normalized <= rank as isize,
471            "{op}: axis {axis} is out of bounds for inserting into rank {rank}"
472        );
473        Ok(normalized as usize)
474    }
475
476    fn is_diag_axis_classes(axis_classes: &[usize]) -> bool {
477        axis_classes.len() >= 2 && axis_classes.iter().all(|&class_id| class_id == 0)
478    }
479
480    fn einsum_subscripts_from_usize_ids(
481        inputs: &[Vec<usize>],
482        output: &[usize],
483    ) -> Result<EinsumSubscripts> {
484        let input_labels = inputs
485            .iter()
486            .map(|ids| {
487                ids.iter()
488                    .map(|&id| {
489                        u32::try_from(id)
490                            .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
491                    })
492                    .collect::<Result<Vec<_>>>()
493            })
494            .collect::<Result<Vec<_>>>()?;
495        let output_labels = output
496            .iter()
497            .map(|&id| {
498                u32::try_from(id)
499                    .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
500            })
501            .collect::<Result<Vec<_>>>()?;
502        let input_refs = input_labels.iter().map(Vec::as_slice).collect::<Vec<_>>();
503        Ok(EinsumSubscripts::new(&input_refs, &output_labels))
504    }
505
506    fn build_binary_einsum_subscripts(
507        lhs_rank: usize,
508        axes_a: &[usize],
509        rhs_rank: usize,
510        axes_b: &[usize],
511    ) -> Result<EinsumSubscripts> {
512        anyhow::ensure!(
513            axes_a.len() == axes_b.len(),
514            "contract axis length mismatch: lhs {:?}, rhs {:?}",
515            axes_a,
516            axes_b
517        );
518
519        let mut lhs_ids = vec![usize::MAX; lhs_rank];
520        let mut rhs_ids = vec![usize::MAX; rhs_rank];
521        let mut next_id = 0usize;
522
523        let mut seen_lhs = vec![false; lhs_rank];
524        let mut seen_rhs = vec![false; rhs_rank];
525
526        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
527            anyhow::ensure!(
528                lhs_axis < lhs_rank,
529                "lhs contract axis {lhs_axis} out of range"
530            );
531            anyhow::ensure!(
532                rhs_axis < rhs_rank,
533                "rhs contract axis {rhs_axis} out of range"
534            );
535            anyhow::ensure!(
536                !seen_lhs[lhs_axis],
537                "duplicate lhs contract axis {lhs_axis}"
538            );
539            anyhow::ensure!(
540                !seen_rhs[rhs_axis],
541                "duplicate rhs contract axis {rhs_axis}"
542            );
543            seen_lhs[lhs_axis] = true;
544            seen_rhs[rhs_axis] = true;
545            lhs_ids[lhs_axis] = next_id;
546            rhs_ids[rhs_axis] = next_id;
547            next_id += 1;
548        }
549
550        let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
551        for id in &mut lhs_ids {
552            if *id == usize::MAX {
553                *id = next_id;
554                output_ids.push(next_id);
555                next_id += 1;
556            }
557        }
558        for id in &mut rhs_ids {
559            if *id == usize::MAX {
560                *id = next_id;
561                output_ids.push(next_id);
562                next_id += 1;
563            }
564        }
565
566        Self::einsum_subscripts_from_usize_ids(&[lhs_ids, rhs_ids], &output_ids)
567    }
568
569    fn binary_dot_general_config(axes_a: &[usize], axes_b: &[usize]) -> Result<DotGeneralConfig> {
570        anyhow::ensure!(
571            axes_a.len() == axes_b.len(),
572            "contract axis length mismatch: lhs {:?}, rhs {:?}",
573            axes_a,
574            axes_b
575        );
576        Ok(DotGeneralConfig {
577            lhs_contracting_dims: axes_a.to_vec(),
578            rhs_contracting_dims: axes_b.to_vec(),
579            lhs_batch_dims: vec![],
580            rhs_batch_dims: vec![],
581        })
582    }
583
584    fn binary_contraction_axis_classes(
585        lhs_axis_classes: &[usize],
586        axes_a: &[usize],
587        rhs_axis_classes: &[usize],
588        axes_b: &[usize],
589    ) -> Vec<usize> {
590        debug_assert_eq!(axes_a.len(), axes_b.len());
591
592        fn find(parent: &mut [usize], value: usize) -> usize {
593            if parent[value] != value {
594                parent[value] = find(parent, parent[value]);
595            }
596            parent[value]
597        }
598
599        fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
600            let lhs_root = find(parent, lhs);
601            let rhs_root = find(parent, rhs);
602            if lhs_root != rhs_root {
603                parent[rhs_root] = lhs_root;
604            }
605        }
606
607        let lhs_payload_rank = lhs_axis_classes
608            .iter()
609            .copied()
610            .max()
611            .map(|value| value + 1)
612            .unwrap_or(0);
613        let rhs_payload_rank = rhs_axis_classes
614            .iter()
615            .copied()
616            .max()
617            .map(|value| value + 1)
618            .unwrap_or(0);
619        let rhs_offset = lhs_payload_rank;
620        let mut parent: Vec<usize> = (0..lhs_payload_rank + rhs_payload_rank).collect();
621
622        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
623            union(
624                &mut parent,
625                lhs_axis_classes[lhs_axis],
626                rhs_offset + rhs_axis_classes[rhs_axis],
627            );
628        }
629
630        let mut lhs_contracted = vec![false; lhs_axis_classes.len()];
631        for &axis in axes_a {
632            lhs_contracted[axis] = true;
633        }
634        let mut rhs_contracted = vec![false; rhs_axis_classes.len()];
635        for &axis in axes_b {
636            rhs_contracted[axis] = true;
637        }
638
639        let mut root_to_class = std::collections::HashMap::new();
640        let mut next_class = 0usize;
641        let mut axis_classes = Vec::new();
642
643        for (axis, &class_id) in lhs_axis_classes.iter().enumerate() {
644            if !lhs_contracted[axis] {
645                let root = find(&mut parent, class_id);
646                let class = *root_to_class.entry(root).or_insert_with(|| {
647                    let value = next_class;
648                    next_class += 1;
649                    value
650                });
651                axis_classes.push(class);
652            }
653        }
654        for (axis, &class_id) in rhs_axis_classes.iter().enumerate() {
655            if !rhs_contracted[axis] {
656                let root = find(&mut parent, rhs_offset + class_id);
657                let class = *root_to_class.entry(root).or_insert_with(|| {
658                    let value = next_class;
659                    next_class += 1;
660                    value
661                });
662                axis_classes.push(class);
663            }
664        }
665
666        axis_classes
667    }
668
669    fn scale_subscripts(rank: usize) -> Result<EinsumSubscripts> {
670        let ids: Vec<usize> = (0..rank).collect();
671        Self::einsum_subscripts_from_usize_ids(&[ids.clone(), Vec::new()], &ids)
672    }
673
674    fn validate_indices(indices: &[DynIndex]) -> Result<()> {
675        let mut seen = HashSet::new();
676        for idx in indices {
677            anyhow::ensure!(
678                seen.insert(idx.clone()),
679                "Tensor indices must all be unique"
680            );
681        }
682        Ok(())
683    }
684
685    fn validate_diag_dims(dims: &[usize]) -> Result<()> {
686        if !dims.is_empty() {
687            let first_dim = dims[0];
688            for (i, &dim) in dims.iter().enumerate() {
689                anyhow::ensure!(
690                    dim == first_dim,
691                    "DiagTensor requires all indices to have the same dimension, but dims[{i}] = {dim} != dims[0] = {first_dim}"
692                );
693            }
694        }
695        Ok(())
696    }
697
    /// Converts `storage` into a native tensor for the given `dims` by
    /// delegating to `storage_to_native_tensor`.
    ///
    /// # Errors
    /// Propagates any error returned by `storage_to_native_tensor`.
    fn seed_native_payload(storage: &Storage, dims: &[usize]) -> Result<NativeTensor> {
        storage_to_native_tensor(storage, dims)
    }
701
702    fn empty_eager_cache() -> Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>> {
703        Arc::new(OnceLock::new())
704    }
705
706    fn eager_cache_with(
707        inner: EagerTensor<CpuBackend>,
708    ) -> Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>> {
709        let cache = Arc::new(OnceLock::new());
710        let _ = cache.set(Arc::new(inner));
711        cache
712    }
713
714    fn compact_payload_inner(&self) -> Result<EagerTensor<CpuBackend>> {
715        Ok(EagerTensor::from_tensor_in(
716            storage_payload_native(self.storage.materialize(self.indices.len())?.as_ref())?,
717            default_eager_ctx(),
718        ))
719    }
720
721    fn tracked_compact_payload_value(&self) -> Option<&StructuredAdValue> {
722        self.structured_ad.as_deref()
723    }
724
725    fn compact_payload_is_logical_dense(&self, payload_dims: &[usize]) -> bool {
726        self.storage.axis_classes() == Self::dense_axis_classes(self.indices.len())
727            && payload_dims == self.dims()
728    }
729
730    fn uses_tracked_compact_storage(&self) -> bool {
731        self.tracked_compact_payload_value()
732            .is_some_and(|value| !self.compact_payload_is_logical_dense(&value.payload_dims))
733    }
734
735    fn ensure_shape_packing_preserves_ad(&self, op_name: &str) -> Result<()> {
736        anyhow::ensure!(
737            !self.uses_tracked_compact_storage(),
738            "{op_name}: structured AD tensors with compact storage are not supported because materializing compact storage would detach gradients"
739        );
740        Ok(())
741    }
742
743    fn operand_indices_for_contraction(&self, conjugate: bool) -> Vec<DynIndex> {
744        if conjugate {
745            self.indices.iter().map(|index| index.conj()).collect()
746        } else {
747            self.indices.clone()
748        }
749    }
750
751    fn build_binary_contraction_labels(
752        lhs_rank: usize,
753        axes_a: &[usize],
754        rhs_rank: usize,
755        axes_b: &[usize],
756    ) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> {
757        anyhow::ensure!(
758            axes_a.len() == axes_b.len(),
759            "contract axis length mismatch: lhs {:?}, rhs {:?}",
760            axes_a,
761            axes_b
762        );
763
764        let mut lhs_ids = vec![usize::MAX; lhs_rank];
765        let mut rhs_ids = vec![usize::MAX; rhs_rank];
766        let mut next_id = 0usize;
767
768        let mut seen_lhs = vec![false; lhs_rank];
769        let mut seen_rhs = vec![false; rhs_rank];
770
771        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
772            anyhow::ensure!(
773                lhs_axis < lhs_rank,
774                "lhs contract axis {lhs_axis} out of range"
775            );
776            anyhow::ensure!(
777                rhs_axis < rhs_rank,
778                "rhs contract axis {rhs_axis} out of range"
779            );
780            anyhow::ensure!(
781                !seen_lhs[lhs_axis],
782                "duplicate lhs contract axis {lhs_axis}"
783            );
784            anyhow::ensure!(
785                !seen_rhs[rhs_axis],
786                "duplicate rhs contract axis {rhs_axis}"
787            );
788            seen_lhs[lhs_axis] = true;
789            seen_rhs[rhs_axis] = true;
790            lhs_ids[lhs_axis] = next_id;
791            rhs_ids[rhs_axis] = next_id;
792            next_id += 1;
793        }
794
795        let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
796        for id in &mut lhs_ids {
797            if *id == usize::MAX {
798                *id = next_id;
799                output_ids.push(next_id);
800                next_id += 1;
801            }
802        }
803        for id in &mut rhs_ids {
804            if *id == usize::MAX {
805                *id = next_id;
806                output_ids.push(next_id);
807                next_id += 1;
808            }
809        }
810
811        Ok((lhs_ids, rhs_ids, output_ids))
812    }
813
    /// Builds einsum subscripts from payload root labels.
    ///
    /// `input_roots` holds one label list per operand and `output_roots` the
    /// label list of the result; both are forwarded unchanged to
    /// `einsum_subscripts_from_usize_ids`.
    ///
    /// # Errors
    /// Propagates any error from `einsum_subscripts_from_usize_ids`.
    fn build_payload_einsum_subscripts(
        input_roots: &[Vec<usize>],
        output_roots: &[usize],
    ) -> Result<EinsumSubscripts> {
        Self::einsum_subscripts_from_usize_ids(input_roots, output_roots)
    }
820
821    fn normalize_eager_payload_for_roots(
822        payload: &EagerTensor<CpuBackend>,
823        roots: &[usize],
824    ) -> Result<(Option<EagerTensor<CpuBackend>>, Vec<usize>)> {
825        anyhow::ensure!(
826            payload.data().shape().len() == roots.len(),
827            "payload rank {} does not match root label count {}",
828            payload.data().shape().len(),
829            roots.len()
830        );
831
832        let mut current_payload = None;
833        let mut current_roots = roots.to_vec();
834        while let Some((axis_a, axis_b)) = Self::first_duplicate_pair(&current_roots) {
835            let source = current_payload.as_ref().unwrap_or(payload);
836            current_payload = Some(source.extract_diag(axis_a, axis_b)?);
837            current_roots.remove(axis_b);
838        }
839
840        Ok((current_payload, current_roots))
841    }
842
843    fn first_duplicate_pair(values: &[usize]) -> Option<(usize, usize)> {
844        let mut first_axis_by_value = std::collections::HashMap::new();
845        for (axis, &value) in values.iter().enumerate() {
846            if let Some(&first_axis) = first_axis_by_value.get(&value) {
847                return Some((first_axis, axis));
848            }
849            first_axis_by_value.insert(value, axis);
850        }
851        None
852    }
853
854    fn binary_structured_contraction_plan(
855        &self,
856        other: &Self,
857        axes_a: &[usize],
858        axes_b: &[usize],
859    ) -> Result<(StructuredContractionPlan, Vec<Vec<usize>>, Vec<usize>)> {
860        let (lhs_labels, rhs_labels, output_labels) = Self::build_binary_contraction_labels(
861            self.indices.len(),
862            axes_a,
863            other.indices.len(),
864            axes_b,
865        )?;
866        let operands = vec![
867            OperandLayout::new(self.dims(), self.storage.axis_classes().to_vec())?,
868            OperandLayout::new(other.dims(), other.storage.axis_classes().to_vec())?,
869        ];
870        let spec = StructuredContractionSpec {
871            input_labels: vec![lhs_labels, rhs_labels],
872            output_labels,
873            retained_labels: Default::default(),
874        };
875        let plan = StructuredContractionPlan::new(&operands, &spec)?;
876        Ok((plan, spec.input_labels, spec.output_labels))
877    }
878
879    fn from_structured_payload_inner(
880        indices: Vec<DynIndex>,
881        payload_inner: EagerTensor<CpuBackend>,
882        payload_dims: Vec<usize>,
883        axis_classes: Vec<usize>,
884    ) -> Result<Self> {
885        Self::validate_indices(&indices)?;
886        if payload_inner.data().shape() != payload_dims {
887            return Err(anyhow::anyhow!(
888                "structured payload dims {:?} do not match planned payload dims {:?}",
889                payload_inner.data().shape(),
890                payload_dims
891            ));
892        }
893        let storage = storage_from_payload_native(
894            payload_inner.data().clone(),
895            &payload_dims,
896            axis_classes.clone(),
897        )?;
898        Self::validate_storage_matches_indices(&indices, &storage)?;
899        Ok(Self {
900            indices,
901            storage: TensorDynLenStorage::from_storage(Arc::new(storage)),
902            structured_ad: Some(Arc::new(StructuredAdValue {
903                payload: Arc::new(payload_inner),
904                payload_dims,
905                axis_classes,
906            })),
907            eager_cache: Self::empty_eager_cache(),
908        })
909    }
910
911    fn contract_structured_payloads(
912        &self,
913        other: &Self,
914        result_indices: Vec<DynIndex>,
915        axes_a: &[usize],
916        axes_b: &[usize],
917    ) -> Result<Self> {
918        let (plan, _, _) = self.binary_structured_contraction_plan(other, axes_a, axes_b)?;
919        let lhs_roots = plan.operand_plans[0].class_roots.clone();
920        let rhs_roots = plan.operand_plans[1].class_roots.clone();
921        let scalar_multiply =
922            lhs_roots.is_empty() && rhs_roots.is_empty() && plan.output_payload_roots.is_empty();
923
924        if let (Some(lhs_ad), Some(rhs_ad)) = (
925            self.tracked_compact_payload_value(),
926            other.tracked_compact_payload_value(),
927        ) {
928            if lhs_ad.payload.data().dtype() != rhs_ad.payload.data().dtype() {
929                return Err(anyhow::anyhow!(
930                    "structured AD contraction requires matching payload dtypes"
931                ));
932            }
933            let (lhs_normalized, lhs_labels) =
934                Self::normalize_eager_payload_for_roots(lhs_ad.payload.as_ref(), &lhs_roots)?;
935            let (rhs_normalized, rhs_labels) =
936                Self::normalize_eager_payload_for_roots(rhs_ad.payload.as_ref(), &rhs_roots)?;
937            let lhs_payload = lhs_normalized
938                .as_ref()
939                .unwrap_or_else(|| lhs_ad.payload.as_ref());
940            let rhs_payload = rhs_normalized
941                .as_ref()
942                .unwrap_or_else(|| rhs_ad.payload.as_ref());
943            let payload = if scalar_multiply {
944                lhs_payload.mul(rhs_payload)?
945            } else {
946                let subscripts = Self::build_payload_einsum_subscripts(
947                    &[lhs_labels, rhs_labels],
948                    &plan.output_payload_roots,
949                )?;
950                eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
951            };
952            return Self::from_structured_payload_inner(
953                result_indices,
954                payload,
955                plan.output_payload_dims,
956                plan.output_axis_classes,
957            );
958        }
959
960        if self.tracked_compact_payload_value().is_some()
961            || other.tracked_compact_payload_value().is_some()
962        {
963            let lhs_owned = if self.tracked_compact_payload_value().is_some() {
964                None
965            } else {
966                Some(self.compact_payload_inner()?)
967            };
968            let rhs_owned = if other.tracked_compact_payload_value().is_some() {
969                None
970            } else {
971                Some(other.compact_payload_inner()?)
972            };
973            let lhs = if let Some(value) = self.tracked_compact_payload_value() {
974                value.payload.as_ref()
975            } else {
976                lhs_owned
977                    .as_ref()
978                    .ok_or_else(|| anyhow::anyhow!("missing untracked left compact payload"))?
979            };
980            let rhs = if let Some(value) = other.tracked_compact_payload_value() {
981                value.payload.as_ref()
982            } else {
983                rhs_owned
984                    .as_ref()
985                    .ok_or_else(|| anyhow::anyhow!("missing untracked right compact payload"))?
986            };
987            if lhs.data().dtype() != rhs.data().dtype() {
988                return Err(anyhow::anyhow!(
989                    "structured AD contraction requires matching payload dtypes"
990                ));
991            }
992            let (lhs_normalized, lhs_labels) =
993                Self::normalize_eager_payload_for_roots(lhs, &lhs_roots)?;
994            let (rhs_normalized, rhs_labels) =
995                Self::normalize_eager_payload_for_roots(rhs, &rhs_roots)?;
996            let lhs_payload = lhs_normalized.as_ref().unwrap_or(lhs);
997            let rhs_payload = rhs_normalized.as_ref().unwrap_or(rhs);
998            let payload = if scalar_multiply {
999                lhs_payload.mul(rhs_payload)?
1000            } else {
1001                let subscripts = Self::build_payload_einsum_subscripts(
1002                    &[lhs_labels, rhs_labels],
1003                    &plan.output_payload_roots,
1004                )?;
1005                eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
1006            };
1007            return Self::from_structured_payload_inner(
1008                result_indices,
1009                payload,
1010                plan.output_payload_dims,
1011                plan.output_axis_classes,
1012            );
1013        }
1014
1015        let lhs_storage = self.storage.materialize(self.indices.len())?;
1016        let rhs_storage = other.storage.materialize(other.indices.len())?;
1017        let lhs = storage_payload_native_read_input(lhs_storage.as_ref())?;
1018        let rhs = storage_payload_native_read_input(rhs_storage.as_ref())?;
1019        if lhs.dtype() != rhs.dtype() {
1020            return Err(anyhow::anyhow!(
1021                "structured payload contraction requires matching payload dtypes"
1022            ));
1023        }
1024        let (lhs, lhs_labels) = normalize_payload_read_for_roots(lhs, &lhs_roots)?;
1025        let (rhs, rhs_labels) = normalize_payload_read_for_roots(rhs, &rhs_roots)?;
1026        let payload = tensor4all_tensorbackend::einsum_native_tensor_reads(
1027            &[(&lhs, lhs_labels.as_slice()), (&rhs, rhs_labels.as_slice())],
1028            &plan.output_payload_roots,
1029        )?;
1030        let storage = storage_from_payload_native(
1031            payload,
1032            &plan.output_payload_dims,
1033            plan.output_axis_classes,
1034        )?;
1035        Self::from_storage(result_indices, Arc::new(storage))
1036    }
1037
1038    fn should_use_structured_payload_contract(&self, other: &Self) -> bool {
1039        let same_payload_dtype = self.storage.is_f64() == other.storage.is_f64()
1040            && self.storage.is_complex() == other.storage.is_complex();
1041        same_payload_dtype
1042            && (self.tracked_compact_payload_value().is_some()
1043                || other.tracked_compact_payload_value().is_some()
1044                || self.storage.axis_classes() != Self::dense_axis_classes(self.indices.len())
1045                || other.storage.axis_classes() != Self::dense_axis_classes(other.indices.len()))
1046    }
1047
1048    fn storage_from_native_with_axis_classes(
1049        native: &NativeTensor,
1050        axis_classes: &[usize],
1051        logical_rank: usize,
1052    ) -> Result<Storage> {
1053        if Self::is_diag_axis_classes(axis_classes) {
1054            match native.dtype() {
1055                DType::F32 | DType::F64 | DType::I64 => Storage::from_diag_col_major(
1056                    native_tensor_primal_to_diag_f64(native)?,
1057                    logical_rank,
1058                ),
1059                DType::C32 | DType::C64 => Storage::from_diag_col_major(
1060                    native_tensor_primal_to_diag_c64(native)?,
1061                    logical_rank,
1062                ),
1063            }
1064        } else {
1065            native_tensor_primal_to_storage(native)
1066        }
1067    }
1068
1069    fn dense_selected_diag_payload<T: TensorElement + Copy + Zero>(
1070        payload: Vec<T>,
1071        kept_dims: &[usize],
1072        selected_positions: &[usize],
1073    ) -> Vec<T> {
1074        let output_len = kept_dims.iter().product::<usize>();
1075        let mut data = vec![T::zero(); output_len];
1076        if output_len == 0 {
1077            return data;
1078        }
1079
1080        let Some((&first_position, rest)) = selected_positions.split_first() else {
1081            return data;
1082        };
1083        if rest.iter().any(|&position| position != first_position) {
1084            return data;
1085        }
1086
1087        let value = payload[first_position];
1088        if kept_dims.is_empty() {
1089            data[0] = value;
1090            return data;
1091        }
1092
1093        let mut offset = 0usize;
1094        let mut stride = 1usize;
1095        for &dim in kept_dims {
1096            offset += first_position * stride;
1097            stride *= dim;
1098        }
1099        data[offset] = value;
1100        data
1101    }
1102
1103    fn select_diag_indices(
1104        &self,
1105        kept_indices: Vec<DynIndex>,
1106        kept_dims: Vec<usize>,
1107        positions: &[usize],
1108    ) -> Result<Self> {
1109        if self.storage.is_f64() {
1110            let storage = self.storage.materialize(self.indices.len())?;
1111            let payload = storage
1112                .payload_f64_col_major_vec()
1113                .map_err(anyhow::Error::msg)?;
1114            let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1115            Self::from_dense(kept_indices, data)
1116        } else if self.storage.is_c64() {
1117            let storage = self.storage.materialize(self.indices.len())?;
1118            let payload = storage
1119                .payload_c64_col_major_vec()
1120                .map_err(anyhow::Error::msg)?;
1121            let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1122            Self::from_dense(kept_indices, data)
1123        } else {
1124            Err(anyhow::anyhow!("unsupported diagonal storage scalar type"))
1125        }
1126    }
1127
1128    fn col_major_strides(dims: &[usize]) -> Result<Vec<isize>> {
1129        let mut strides = Vec::with_capacity(dims.len());
1130        let mut stride = 1isize;
1131        for &dim in dims {
1132            strides.push(stride);
1133            let dim = isize::try_from(dim)
1134                .map_err(|_| anyhow::anyhow!("dimension does not fit in isize"))?;
1135            stride = stride
1136                .checked_mul(dim)
1137                .ok_or_else(|| anyhow::anyhow!("column-major stride overflow"))?;
1138        }
1139        Ok(strides)
1140    }
1141
1142    fn zero_structured_selection<T>(
1143        kept_indices: Vec<DynIndex>,
1144        kept_dims: &[usize],
1145    ) -> Result<Self>
1146    where
1147        T: TensorElement + Zero,
1148    {
1149        let output_len = checked_product(kept_dims)?;
1150        Self::from_dense(kept_indices, vec![T::zero(); output_len])
1151    }
1152
1153    fn select_structured_indices_typed<T>(
1154        &self,
1155        payload: Vec<T>,
1156        kept_axes: &[usize],
1157        kept_indices: Vec<DynIndex>,
1158        kept_dims: Vec<usize>,
1159        selected_axes: &[usize],
1160        positions: &[usize],
1161    ) -> Result<Self>
1162    where
1163        T: TensorElement + StorageScalar + Zero,
1164    {
1165        let payload_dims = self.storage.payload_dims();
1166        let axis_classes = self.storage.axis_classes();
1167        let payload_rank = payload_dims.len();
1168        let mut selected_class_positions = vec![None; payload_rank];
1169
1170        for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1171            let class_id = axis_classes[axis];
1172            if let Some(existing) = selected_class_positions[class_id] {
1173                if existing != position {
1174                    return Self::zero_structured_selection::<T>(kept_indices, &kept_dims);
1175                }
1176            } else {
1177                selected_class_positions[class_id] = Some(position);
1178            }
1179        }
1180
1181        let selected_class_kept = kept_axes
1182            .iter()
1183            .any(|&axis| selected_class_positions[axis_classes[axis]].is_some());
1184        if selected_class_kept {
1185            return self.select_structured_indices_dense(
1186                payload,
1187                kept_axes,
1188                kept_indices,
1189                kept_dims,
1190                &selected_class_positions,
1191            );
1192        }
1193
1194        let mut old_to_new_class = vec![None; payload_rank];
1195        let mut output_payload_dims = Vec::new();
1196        let mut output_axis_classes = Vec::with_capacity(kept_axes.len());
1197        for &axis in kept_axes {
1198            let class_id = axis_classes[axis];
1199            let new_class = match old_to_new_class[class_id] {
1200                Some(new_class) => new_class,
1201                None => {
1202                    let new_class = output_payload_dims.len();
1203                    old_to_new_class[class_id] = Some(new_class);
1204                    output_payload_dims.push(payload_dims[class_id]);
1205                    new_class
1206                }
1207            };
1208            output_axis_classes.push(new_class);
1209        }
1210
1211        let output_len = checked_product(&output_payload_dims)?;
1212        let mut output_payload = Vec::with_capacity(output_len);
1213        for linear in 0..output_len {
1214            let output_payload_index = decode_col_major_linear(linear, &output_payload_dims)?;
1215            let mut input_payload_index = vec![0usize; payload_rank];
1216            for class_id in 0..payload_rank {
1217                input_payload_index[class_id] =
1218                    if let Some(position) = selected_class_positions[class_id] {
1219                        position
1220                    } else if let Some(new_class) = old_to_new_class[class_id] {
1221                        output_payload_index[new_class]
1222                    } else {
1223                        return Err(anyhow::anyhow!(
1224                            "structured payload class {class_id} is neither selected nor kept"
1225                        ));
1226                    };
1227            }
1228            let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1229            output_payload.push(payload[input_linear]);
1230        }
1231
1232        let output_strides = Self::col_major_strides(&output_payload_dims)?;
1233        let storage = Storage::new_structured(
1234            output_payload,
1235            output_payload_dims,
1236            output_strides,
1237            output_axis_classes,
1238        )?;
1239        Self::from_storage(kept_indices, Arc::new(storage))
1240    }
1241
1242    fn select_structured_indices_dense<T>(
1243        &self,
1244        payload: Vec<T>,
1245        kept_axes: &[usize],
1246        kept_indices: Vec<DynIndex>,
1247        kept_dims: Vec<usize>,
1248        selected_class_positions: &[Option<usize>],
1249    ) -> Result<Self>
1250    where
1251        T: TensorElement + Zero,
1252    {
1253        let payload_dims = self.storage.payload_dims();
1254        let axis_classes = self.storage.axis_classes();
1255        let output_len = checked_product(&kept_dims)?;
1256        let mut output = Vec::with_capacity(output_len);
1257
1258        for linear in 0..output_len {
1259            let kept_position = decode_col_major_linear(linear, &kept_dims)?;
1260            let mut input_payload_index = selected_class_positions.to_vec();
1261            let mut is_structural_zero = false;
1262
1263            for (&axis, &position) in kept_axes.iter().zip(kept_position.iter()) {
1264                let class_id = axis_classes[axis];
1265                match input_payload_index[class_id] {
1266                    Some(existing) if existing != position => {
1267                        is_structural_zero = true;
1268                        break;
1269                    }
1270                    Some(_) => {}
1271                    None => input_payload_index[class_id] = Some(position),
1272                }
1273            }
1274
1275            if is_structural_zero {
1276                output.push(T::zero());
1277                continue;
1278            }
1279
1280            let input_payload_index = input_payload_index
1281                .into_iter()
1282                .enumerate()
1283                .map(|(class_id, position)| {
1284                    position.ok_or_else(|| {
1285                        anyhow::anyhow!(
1286                            "structured payload class {class_id} is neither selected nor kept"
1287                        )
1288                    })
1289                })
1290                .collect::<Result<Vec<_>>>()?;
1291            let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1292            output.push(payload[input_linear]);
1293        }
1294
1295        Self::from_dense(kept_indices, output)
1296    }
1297
1298    fn select_structured_indices(
1299        &self,
1300        kept_axes: &[usize],
1301        kept_indices: Vec<DynIndex>,
1302        kept_dims: Vec<usize>,
1303        selected_axes: &[usize],
1304        positions: &[usize],
1305    ) -> Result<Self> {
1306        if self.storage.is_f64() {
1307            let storage = self.storage.materialize(self.indices.len())?;
1308            let payload = storage
1309                .payload_f64_col_major_vec()
1310                .map_err(anyhow::Error::msg)?;
1311            self.select_structured_indices_typed(
1312                payload,
1313                kept_axes,
1314                kept_indices,
1315                kept_dims,
1316                selected_axes,
1317                positions,
1318            )
1319        } else if self.storage.is_c64() {
1320            let storage = self.storage.materialize(self.indices.len())?;
1321            let payload = storage
1322                .payload_c64_col_major_vec()
1323                .map_err(anyhow::Error::msg)?;
1324            self.select_structured_indices_typed(
1325                payload,
1326                kept_axes,
1327                kept_indices,
1328                kept_dims,
1329                selected_axes,
1330                positions,
1331            )
1332        } else {
1333            Err(anyhow::anyhow!(
1334                "unsupported structured storage scalar type"
1335            ))
1336        }
1337    }
1338
1339    fn validate_storage_matches_indices(indices: &[DynIndex], storage: &Storage) -> Result<()> {
1340        let dims = Self::expected_dims_from_indices(indices);
1341        let storage_dims = storage.logical_dims();
1342        if storage_dims != dims {
1343            return Err(anyhow::anyhow!(
1344                "storage logical dims {:?} do not match indices dims {:?}",
1345                storage_dims,
1346                dims
1347            ));
1348        }
1349        if storage.is_diag() {
1350            Self::validate_diag_dims(&dims)?;
1351        }
1352        Ok(())
1353    }
1354
1355    fn try_materialized_inner(&self) -> Result<&EagerTensor<CpuBackend>> {
1356        if let Some(value) = self.tracked_compact_payload_value() {
1357            if self.compact_payload_is_logical_dense(&value.payload_dims) {
1358                return Ok(value.payload.as_ref());
1359            }
1360        }
1361        if let Some(inner) = self.storage.eager() {
1362            return Ok(inner);
1363        }
1364        if self.eager_cache.get().is_none() {
1365            let dims = self.dims();
1366            let native = profile_pairwise_contract_section("materialize_storage_to_native", || {
1367                let storage = self.storage.materialize(self.indices.len())?;
1368                Self::seed_native_payload(storage.as_ref(), &dims)
1369            })
1370            .context("TensorDynLen materialization failed")?;
1371            record_pairwise_contract_profile_bytes(
1372                "materialize_storage_to_native",
1373                native_tensor_profile_bytes(&native),
1374            );
1375            let _ = self.eager_cache.set(Arc::new(EagerTensor::from_tensor_in(
1376                native,
1377                default_eager_ctx(),
1378            )));
1379        }
1380        self.eager_cache
1381            .get()
1382            .map(|inner| inner.as_ref())
1383            .ok_or_else(|| {
1384                anyhow::anyhow!("TensorDynLen materialization cache was not initialized")
1385            })
1386    }
1387
1388    pub(crate) fn as_inner(&self) -> Result<&EagerTensor<CpuBackend>> {
1389        self.try_materialized_inner()
1390    }
1391
1392    /// Compute dims from `indices` order.
1393    #[inline]
1394    fn expected_dims_from_indices(indices: &[DynIndex]) -> Vec<usize> {
1395        indices.iter().map(|idx| idx.dim()).collect()
1396    }
1397
1398    /// Get dims in the current `indices` order.
1399    ///
1400    /// This is computed on-demand from `indices` (single source of truth).
1401    ///
1402    /// # Examples
1403    ///
1404    /// ```
1405    /// use tensor4all_core::{DynIndex, TensorDynLen};
1406    ///
1407    /// let i = DynIndex::new_dyn(2);
1408    /// let j = DynIndex::new_dyn(3);
1409    /// let k = DynIndex::new_dyn(4);
1410    /// let t = TensorDynLen::from_dense(
1411    ///     vec![i, j, k],
1412    ///     vec![0.0; 24],
1413    /// ).unwrap();
1414    /// assert_eq!(t.dims(), vec![2, 3, 4]);
1415    /// ```
1416    pub fn dims(&self) -> Vec<usize> {
1417        Self::expected_dims_from_indices(&self.indices)
1418    }
1419
1420    /// Select fixed coordinates for tensor indices and drop those axes.
1421    ///
1422    /// The `selected_indices` slice identifies tensor axes by index identity,
1423    /// and `positions` gives the zero-based coordinate to take on each
1424    /// selected axis. Unselected indices are preserved in their original order.
1425    ///
1426    /// # Arguments
1427    ///
1428    /// * `selected_indices` - Indices to fix and remove from the result. Each
1429    ///   index must appear exactly once in this tensor.
1430    /// * `positions` - Coordinates for `selected_indices`. Each coordinate must
1431    ///   be less than the corresponding index dimension.
1432    ///
1433    /// # Returns
1434    ///
1435    /// A tensor over the unselected indices. Selecting no indices returns a
1436    /// clone of the original tensor. Selecting all indices returns a rank-0
1437    /// scalar tensor. Diagonal and structured tensors are sliced from their
1438    /// compact payload without materializing the original full tensor; the
1439    /// result keeps structured storage when the remaining logical axes can
1440    /// still be represented by axis classes.
1441    ///
1442    /// # Errors
1443    ///
1444    /// Returns an error if the argument lengths differ, a selected index is not
1445    /// present, a selected index is duplicated, or a coordinate is out of range.
1446    ///
1447    /// # Examples
1448    ///
1449    /// ```
1450    /// use tensor4all_core::{DynIndex, TensorDynLen};
1451    ///
1452    /// let i = DynIndex::new_dyn(2);
1453    /// let j = DynIndex::new_dyn(3);
1454    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1455    /// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data).unwrap();
1456    ///
1457    /// let selected = tensor.select_indices(&[j], &[1]).unwrap();
1458    /// assert_eq!(selected.dims(), vec![2]);
1459    /// assert_eq!(selected.to_vec::<f64>().unwrap(), vec![3.0, 4.0]);
1460    /// ```
1461    pub fn select_indices(
1462        &self,
1463        selected_indices: &[DynIndex],
1464        positions: &[usize],
1465    ) -> Result<Self> {
1466        if selected_indices.len() != positions.len() {
1467            return Err(anyhow::anyhow!(
1468                "selected_indices length {} does not match positions length {}",
1469                selected_indices.len(),
1470                positions.len()
1471            ));
1472        }
1473        if selected_indices.is_empty() {
1474            return Ok(self.clone());
1475        }
1476
1477        let mut selected_axes = Vec::with_capacity(selected_indices.len());
1478        let mut seen_axes = HashSet::with_capacity(selected_indices.len());
1479        for (selected, &position) in selected_indices.iter().zip(positions.iter()) {
1480            let axis = self
1481                .indices
1482                .iter()
1483                .position(|index| index == selected)
1484                .ok_or_else(|| anyhow::anyhow!("selected index is not present in tensor"))?;
1485            if !seen_axes.insert(axis) {
1486                return Err(anyhow::anyhow!("selected index appears more than once"));
1487            }
1488            let dim = self.indices[axis].dim();
1489            if position >= dim {
1490                return Err(anyhow::anyhow!(
1491                    "selected coordinate {position} is out of range for axis {axis} with dim {dim}"
1492                ));
1493            }
1494            selected_axes.push(axis);
1495        }
1496
1497        let kept_axes = self
1498            .indices
1499            .iter()
1500            .enumerate()
1501            .filter(|(axis, _)| !seen_axes.contains(axis))
1502            .map(|(axis, _)| axis)
1503            .collect::<Vec<_>>();
1504        let kept_indices = kept_axes
1505            .iter()
1506            .map(|&axis| self.indices[axis].clone())
1507            .collect::<Vec<_>>();
1508        let kept_dims = kept_axes
1509            .iter()
1510            .map(|&axis| self.indices[axis].dim())
1511            .collect::<Vec<_>>();
1512
1513        if self.storage.storage_kind() == StorageKind::Diagonal {
1514            return self.select_diag_indices(kept_indices, kept_dims, positions);
1515        }
1516        if self.storage.storage_kind() == StorageKind::Structured {
1517            return self.select_structured_indices(
1518                &kept_axes,
1519                kept_indices,
1520                kept_dims,
1521                &selected_axes,
1522                positions,
1523            );
1524        }
1525        if self.storage.storage_kind() != StorageKind::Dense {
1526            return Err(anyhow::anyhow!(
1527                "select_indices got unsupported storage kind {:?}",
1528                self.storage.storage_kind()
1529            ));
1530        }
1531
1532        let rank = self.indices.len();
1533        let mut starts = vec![0_i64; rank];
1534        let mut slice_sizes = self.dims();
1535        for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1536            starts[axis] = i64::try_from(position)
1537                .map_err(|_| anyhow::anyhow!("selected coordinate does not fit in i64"))?;
1538            slice_sizes[axis] = 1;
1539        }
1540
1541        let starts_tensor = EagerTensor::from_tensor_in(
1542            NativeTensor::from_vec(vec![rank], starts),
1543            default_eager_ctx(),
1544        );
1545        let sliced = self
1546            .try_materialized_inner()?
1547            .dynamic_slice(&starts_tensor, &slice_sizes)?;
1548        Self::from_inner(kept_indices, sliced.reshape(&kept_dims)?)
1549    }
1550
1551    /// Stack tensors along a newly inserted index.
1552    ///
1553    /// Each input must have exactly the same index order and dimensions. The
1554    /// `new_index` dimension must match the number of input tensors. The
1555    /// `axis` argument follows tenferro/PyTorch-style insertion semantics:
1556    /// `0` inserts before the first existing axis and `-1` appends a trailing
1557    /// axis. Use `axis = -1` for batched contractions because tenferro uses
1558    /// trailing batch dimensions as the canonical batched-GEMM layout.
1559    ///
1560    /// # Errors
1561    ///
1562    /// Returns an error if no tensors are provided, the new index dimension
1563    /// does not match the number of tensors, an input has a different index
1564    /// order, `axis` is outside the valid insertion range, or a tracked
1565    /// structured-AD tensor uses compact storage that would need dense
1566    /// materialization.
1567    ///
1568    /// # Examples
1569    ///
1570    /// ```
1571    /// use tensor4all_core::{DynIndex, TensorDynLen};
1572    ///
1573    /// let i = DynIndex::new_dyn(2);
1574    /// let batch = DynIndex::new_dyn(2);
1575    /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0_f64, 2.0]).unwrap();
1576    /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0_f64, 4.0]).unwrap();
1577    ///
1578    /// let stacked = TensorDynLen::stack_along_new_index(&[&a, &b], batch.clone(), -1).unwrap();
1579    ///
1580    /// assert_eq!(stacked.indices(), &[i, batch]);
1581    /// assert_eq!(stacked.to_vec::<f64>().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
1582    /// ```
1583    pub fn stack_along_new_index(
1584        tensors: &[&Self],
1585        new_index: DynIndex,
1586        axis: isize,
1587    ) -> Result<Self> {
1588        let first = tensors
1589            .first()
1590            .copied()
1591            .ok_or_else(|| anyhow::anyhow!("stack_along_new_index requires at least one tensor"))?;
1592        anyhow::ensure!(
1593            new_index.dim() == tensors.len(),
1594            "stack_along_new_index: new index dim {} does not match tensor count {}",
1595            new_index.dim(),
1596            tensors.len()
1597        );
1598
1599        let base_indices = first.indices.clone();
1600        for tensor in tensors.iter().copied().skip(1) {
1601            anyhow::ensure!(
1602                tensor.indices == base_indices,
1603                "stack_along_new_index: input tensors must have identical index order"
1604            );
1605        }
1606        for &tensor in tensors {
1607            tensor.ensure_shape_packing_preserves_ad("stack_along_new_index")?;
1608        }
1609
1610        let insert_axis =
1611            Self::normalize_insert_axis("stack_along_new_index", axis, base_indices.len())?;
1612        let mut result_indices = base_indices;
1613        result_indices.insert(insert_axis, new_index);
1614
1615        let inner_refs = tensors
1616            .iter()
1617            .map(|tensor| tensor.try_materialized_inner())
1618            .collect::<Result<Vec<_>>>()?;
1619        let stacked = EagerTensor::stack(&inner_refs, axis)?;
1620        Self::from_inner(result_indices, stacked)
1621    }
1622
1623    /// Select positions along one index and replace it with a new index.
1624    ///
1625    /// This is the retained-axis counterpart to [`Self::select_indices`]:
1626    /// instead of fixing one coordinate and removing the index, it gathers a
1627    /// list of positions and keeps the gathered axis under `target_index`.
1628    /// Repeated positions are allowed; reverse-mode AD accumulates repeated
1629    /// cotangents through tenferro's scatter-add gather transpose.
1630    ///
1631    /// # Errors
1632    ///
1633    /// Returns an error if `source_index` is not present, `target_index.dim()`
1634    /// differs from `positions.len()`, or any position is out of range for the
1635    /// source index. A tracked structured-AD tensor with compact storage is
1636    /// also rejected because dense materialization would detach gradients.
1637    ///
1638    /// # Examples
1639    ///
1640    /// ```
1641    /// use tensor4all_core::{DynIndex, TensorDynLen};
1642    ///
1643    /// let source = DynIndex::new_dyn(3);
1644    /// let target = DynIndex::new_dyn(2);
1645    /// let tensor = TensorDynLen::from_dense(
1646    ///     vec![source.clone()],
1647    ///     vec![10.0_f64, 20.0, 30.0],
1648    /// ).unwrap();
1649    ///
1650    /// let selected = tensor.index_select(&source, target.clone(), &[2, 0]).unwrap();
1651    ///
1652    /// assert_eq!(selected.indices(), &[target]);
1653    /// assert_eq!(selected.to_vec::<f64>().unwrap(), vec![30.0, 10.0]);
1654    /// ```
1655    pub fn index_select(
1656        &self,
1657        source_index: &DynIndex,
1658        target_index: DynIndex,
1659        positions: &[usize],
1660    ) -> Result<Self> {
1661        anyhow::ensure!(
1662            target_index.dim() == positions.len(),
1663            "index_select: target index dim {} does not match position count {}",
1664            target_index.dim(),
1665            positions.len()
1666        );
1667        let axis = self
1668            .indices
1669            .iter()
1670            .position(|index| index == source_index)
1671            .ok_or_else(|| anyhow::anyhow!("index_select: source index is not present"))?;
1672        let source_dim = self.indices[axis].dim();
1673        for &position in positions {
1674            anyhow::ensure!(
1675                position < source_dim,
1676                "index_select: position {position} is out of range for source dim {source_dim}"
1677            );
1678        }
1679        self.ensure_shape_packing_preserves_ad("index_select")?;
1680
1681        let axis = isize::try_from(axis)
1682            .map_err(|_| anyhow::anyhow!("index_select: axis does not fit in isize"))?;
1683        let selected = self
1684            .try_materialized_inner()?
1685            .index_select(axis, positions)?;
1686        let mut result_indices = self.indices.clone();
1687        result_indices[axis as usize] = target_index;
1688        Self::from_inner(result_indices, selected)
1689    }
1690
    /// Create a new tensor with dynamic rank.
    ///
    /// This is a thin alias: it forwards both arguments unchanged to
    /// [`TensorDynLen::from_storage`], which performs all validation.
    ///
    /// # Errors
    /// Returns an error if the storage logical dimensions do not match the
    /// supplied indices, if diagonal storage has unequal logical dimensions,
    /// or if duplicate indices are provided.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_core::{DynIndex, TensorDynLen};
    /// use tensor4all_tensorbackend::Storage;
    /// use std::sync::Arc;
    ///
    /// let i = DynIndex::new_dyn(3);
    /// let storage = Arc::new(Storage::new_dense::<f64>(3).unwrap());
    /// let t = TensorDynLen::new(vec![i], storage).unwrap();
    /// assert_eq!(t.dims(), vec![3]);
    /// ```
    pub fn new(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
        Self::from_storage(indices, storage)
    }
1713
    /// Create a new tensor with dynamic rank from indices and storage only.
    ///
    /// This convenience constructor takes no explicit dimension list; the
    /// dimensions come from the indices themselves. It forwards both arguments
    /// unchanged to [`TensorDynLen::new`].
    ///
    /// # Errors
    /// Returns an error if the storage logical dimensions do not match the
    /// supplied indices, if diagonal storage has unequal logical dimensions,
    /// or if duplicate indices are provided.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_core::{DynIndex, TensorDynLen};
    /// use tensor4all_tensorbackend::Storage;
    /// use std::sync::Arc;
    ///
    /// let i = DynIndex::new_dyn(4);
    /// let storage = Arc::new(Storage::new_dense::<f64>(4).unwrap());
    /// let t = TensorDynLen::from_indices(vec![i], storage).unwrap();
    /// assert_eq!(t.dims(), vec![4]);
    /// ```
    pub fn from_indices(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
        Self::new(indices, storage)
    }
1738
1739    /// Create a tensor from explicit compact storage.
1740    ///
1741    /// # Examples
1742    ///
1743    /// ```
1744    /// use tensor4all_core::{DynIndex, TensorDynLen};
1745    /// use tensor4all_tensorbackend::Storage;
1746    /// use std::sync::Arc;
1747    ///
1748    /// let i = DynIndex::new_dyn(2);
1749    /// let j = DynIndex::new_dyn(2);
1750    /// let storage = Arc::new(Storage::new_diag(vec![1.0_f64, 2.0]).unwrap());
1751    /// let t = TensorDynLen::from_storage(vec![i, j], storage).unwrap();
1752    /// assert_eq!(t.dims(), vec![2, 2]);
1753    /// ```
1754    pub fn from_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1755        Self::validate_indices(&indices)?;
1756        Self::validate_storage_matches_indices(&indices, storage.as_ref())?;
1757        Ok(Self {
1758            indices,
1759            storage: TensorDynLenStorage::from_storage(storage),
1760            structured_ad: None,
1761            eager_cache: Self::empty_eager_cache(),
1762        })
1763    }
1764
    /// Create a tensor from explicit structured storage.
    ///
    /// This is an alias for [`TensorDynLen::from_storage`] with a name that
    /// emphasizes that compact structured metadata is preserved. Both
    /// arguments are forwarded unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the storage logical dimensions do not match the
    /// supplied indices, or if duplicate indices are provided.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Arc;
    /// use tensor4all_core::{DynIndex, TensorDynLen};
    /// use tensor4all_tensorbackend::{Storage, StorageKind};
    ///
    /// let i = DynIndex::new_dyn(2);
    /// let j = DynIndex::new_dyn(2);
    /// let storage = Arc::new(Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap());
    /// let tensor = TensorDynLen::from_structured_storage(vec![i, j], storage).unwrap();
    /// assert_eq!(tensor.storage().storage_kind(), StorageKind::Diagonal);
    /// ```
    pub fn from_structured_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
        Self::from_storage(indices, storage)
    }
1791
1792    /// Create a tensor from a native tenferro payload.
1793    pub(crate) fn from_native(indices: Vec<DynIndex>, native: NativeTensor) -> Result<Self> {
1794        let axis_classes = Self::dense_axis_classes(indices.len());
1795        Self::from_native_with_axis_classes(indices, native, axis_classes)
1796    }
1797
1798    pub(crate) fn from_native_with_axis_classes(
1799        indices: Vec<DynIndex>,
1800        native: NativeTensor,
1801        axis_classes: Vec<usize>,
1802    ) -> Result<Self> {
1803        Self::from_inner_with_axis_classes(
1804            indices,
1805            EagerTensor::from_tensor_in(native, default_eager_ctx()),
1806            axis_classes,
1807        )
1808    }
1809
1810    pub(crate) fn from_inner(
1811        indices: Vec<DynIndex>,
1812        inner: EagerTensor<CpuBackend>,
1813    ) -> Result<Self> {
1814        let axis_classes = Self::dense_axis_classes(indices.len());
1815        Self::from_inner_with_axis_classes(indices, inner, axis_classes)
1816    }
1817
1818    pub(crate) fn from_diag_inner(
1819        indices: Vec<DynIndex>,
1820        payload_inner: EagerTensor<CpuBackend>,
1821    ) -> Result<Self> {
1822        let dims = Self::expected_dims_from_indices(&indices);
1823        Self::validate_indices(&indices)?;
1824        Self::validate_diag_dims(&dims)?;
1825        Self::validate_diag_payload_len(payload_inner.data().shape().iter().product(), &dims)?;
1826        let axis_classes = Self::diag_axis_classes(dims.len());
1827        let diag_inner = payload_inner.embed_diag(0, 1)?;
1828        Self::from_inner_with_axis_classes(indices, diag_inner, axis_classes)
1829    }
1830
1831    pub(crate) fn from_inner_with_axis_classes(
1832        indices: Vec<DynIndex>,
1833        inner: EagerTensor<CpuBackend>,
1834        axis_classes: Vec<usize>,
1835    ) -> Result<Self> {
1836        let dims = profile_pairwise_contract_section("from_inner_expected_dims", || {
1837            Self::expected_dims_from_indices(&indices)
1838        });
1839        profile_pairwise_contract_section("from_inner_validate_indices", || {
1840            Self::validate_indices(&indices)
1841        })?;
1842        if dims != inner.data().shape() {
1843            return Err(anyhow::anyhow!(
1844                "native payload dims {:?} do not match indices dims {:?}",
1845                inner.data().shape(),
1846                dims
1847            ));
1848        }
1849        if Self::is_diag_axis_classes(&axis_classes) {
1850            profile_pairwise_contract_section("from_inner_validate_diag_dims", || {
1851                Self::validate_diag_dims(&dims)
1852            })?;
1853        }
1854        let (storage, eager_cache) = if axis_classes == Self::dense_axis_classes(indices.len()) {
1855            (
1856                TensorDynLenStorage::from_eager_dense(inner, indices.len()),
1857                Self::empty_eager_cache(),
1858            )
1859        } else {
1860            let storage = profile_pairwise_contract_section("from_inner_storage_snapshot", || {
1861                Self::storage_from_native_with_axis_classes(
1862                    inner.data(),
1863                    &axis_classes,
1864                    indices.len(),
1865                )
1866            })?;
1867            record_pairwise_contract_profile_bytes(
1868                "from_inner_storage_snapshot",
1869                native_tensor_profile_bytes(inner.data()),
1870            );
1871            (
1872                TensorDynLenStorage::from_storage(Arc::new(storage)),
1873                profile_pairwise_contract_section("from_inner_eager_cache", || {
1874                    Self::eager_cache_with(inner)
1875                }),
1876            )
1877        };
1878        Ok(Self {
1879            indices,
1880            storage,
1881            structured_ad: None,
1882            eager_cache,
1883        })
1884    }
1885
1886    /// Borrow the indices.
1887    pub fn indices(&self) -> &[DynIndex] {
1888        &self.indices
1889    }
1890
1891    /// Borrow the native payload.
1892    pub(crate) fn as_native(&self) -> Result<&NativeTensor> {
1893        Ok(self.try_materialized_inner()?.data())
1894    }
1895
1896    /// Enable reverse-mode AD tracking on this tensor by creating a tracked leaf.
1897    pub fn enable_grad(self) -> Result<Self> {
1898        let materialized = self.storage.materialize(self.indices.len())?;
1899        let payload = storage_payload_native(materialized.as_ref())
1900            .context("TensorDynLen::enable_grad failed")?;
1901        let payload_dims = self.storage.payload_dims().to_vec();
1902        let axis_classes = self.storage.axis_classes().to_vec();
1903        Ok(Self {
1904            indices: self.indices,
1905            storage: self.storage,
1906            structured_ad: Some(Arc::new(StructuredAdValue {
1907                payload: Arc::new(EagerTensor::requires_grad_in(payload, default_eager_ctx())),
1908                payload_dims,
1909                axis_classes,
1910            })),
1911            eager_cache: Self::empty_eager_cache(),
1912        })
1913    }
1914
1915    /// Report whether this tensor participates in gradient tracking.
1916    pub fn tracks_grad(&self) -> bool {
1917        self.structured_ad
1918            .as_ref()
1919            .is_some_and(|value| value.payload.tracks_grad())
1920            || self.storage.eager().is_some_and(EagerTensor::tracks_grad)
1921            || self
1922                .eager_cache
1923                .get()
1924                .is_some_and(|inner| inner.tracks_grad())
1925    }
1926
1927    /// Return the accumulated gradient, if one has been stored.
1928    pub fn grad(&self) -> Result<Option<Self>> {
1929        if let Some(value) = self.tracked_compact_payload_value() {
1930            return value
1931                .payload
1932                .grad()
1933                .map(|grad| {
1934                    let storage = storage_from_payload_native(
1935                        grad.as_ref().clone(),
1936                        &value.payload_dims,
1937                        value.axis_classes.clone(),
1938                    )?;
1939                    Self::from_storage(self.indices.clone(), Arc::new(storage))
1940                })
1941                .transpose();
1942        }
1943        self.try_materialized_inner()?
1944            .grad()
1945            .map(|grad| {
1946                Self::from_native_with_axis_classes(
1947                    self.indices.clone(),
1948                    grad.as_ref().clone(),
1949                    self.storage.axis_classes().to_vec(),
1950                )
1951            })
1952            .transpose()
1953    }
1954
1955    /// Clear the accumulated gradient stored for this tensor.
1956    pub fn clear_grad(&self) -> Result<()> {
1957        if let Some(value) = self.tracked_compact_payload_value() {
1958            value.payload.clear_grad();
1959        }
1960        if let Some(inner) = self.storage.eager() {
1961            inner.clear_grad();
1962        }
1963        if let Some(inner) = self.eager_cache.get() {
1964            inner.clear_grad();
1965        }
1966        Ok(())
1967    }
1968
1969    /// Run reverse-mode autodiff from this scalar tensor.
1970    pub fn backward(&self) -> Result<()> {
1971        if let Some(value) = self.tracked_compact_payload_value() {
1972            return value
1973                .payload
1974                .backward()
1975                .map(|_| ())
1976                .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"));
1977        }
1978        self.try_materialized_inner()?
1979            .backward()
1980            .map(|_| ())
1981            .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"))
1982    }
1983
1984    /// Detach this tensor from the reverse graph.
1985    pub fn detach(&self) -> Result<Self> {
1986        if self.tracked_compact_payload_value().is_some() {
1987            return Self::from_storage(
1988                self.indices.clone(),
1989                self.storage.materialize(self.indices.len())?,
1990            );
1991        }
1992        Self::from_inner_with_axis_classes(
1993            self.indices.clone(),
1994            self.try_materialized_inner()?.detach(),
1995            self.storage.axis_classes().to_vec(),
1996        )
1997    }
1998
    /// Check if this tensor is already in canonical form.
    ///
    /// This implementation unconditionally returns `true`.
    pub fn is_simple(&self) -> bool {
        true
    }
2003
2004    /// Materialize the primal snapshot as storage.
2005    pub fn to_storage(&self) -> Result<Arc<Storage>> {
2006        self.storage.materialize(self.indices.len())
2007    }
2008
2009    /// Returns the authoritative compact storage.
2010    pub fn storage(&self) -> Arc<Storage> {
2011        self.storage
2012            .materialize(self.indices.len())
2013            .expect("TensorDynLen storage materialization failed")
2014    }
2015
2016    /// Sum all elements, returning `AnyScalar`.
2017    ///
2018    /// # Examples
2019    ///
2020    /// ```
2021    /// use tensor4all_core::{DynIndex, TensorDynLen};
2022    ///
2023    /// let i = DynIndex::new_dyn(3);
2024    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0, 3.0]).unwrap();
2025    /// let s = t.sum().unwrap();
2026    /// assert!((s.real() - 6.0).abs() < 1e-12);
2027    /// ```
2028    pub fn sum(&self) -> Result<AnyScalar> {
2029        if self.indices.is_empty() {
2030            return AnyScalar::from_tensor(self.clone());
2031        }
2032        let axes: Vec<usize> = (0..self.indices.len()).collect();
2033        let reduced = self.try_materialized_inner()?.reduce_sum(&axes)?;
2034        AnyScalar::from_tensor(Self::from_inner(Vec::new(), reduced)?)
2035    }
2036
2037    /// Extract the scalar value from a 0-dimensional tensor (or 1-element tensor).
2038    ///
2039    /// This is similar to Julia's `only()` function.
2040    ///
2041    /// # Panics
2042    ///
2043    /// Panics if the tensor has more than one element.
2044    ///
2045    /// # Example
2046    ///
2047    /// ```
2048    /// use tensor4all_core::{TensorDynLen, AnyScalar};
2049    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2050    ///
2051    /// // Create a scalar tensor (0 dimensions, 1 element)
2052    /// let indices: Vec<Index<DynId>> = vec![];
2053    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![42.0]).unwrap();
2054    ///
2055    /// assert_eq!(tensor.only().unwrap().real(), 42.0);
2056    /// ```
2057    pub fn only(&self) -> Result<AnyScalar> {
2058        let dims = self.dims();
2059        let total_size = checked_product(&dims)?;
2060        anyhow::ensure!(
2061            total_size == 1 || dims.is_empty(),
2062            "only() requires a scalar tensor (1 element), got {} elements with dims {:?}",
2063            if dims.is_empty() { 1 } else { total_size },
2064            dims
2065        );
2066        self.sum()
2067    }
2068
2069    /// Permute the tensor dimensions using the given new indices order.
2070    ///
2071    /// This is the main permutation method that takes the desired new indices
2072    /// and automatically computes the corresponding permutation of dimensions
2073    /// and data. The new indices must be a permutation of the original indices
2074    /// (matched by ID).
2075    ///
2076    /// # Arguments
2077    /// * `new_indices` - The desired new indices order. Must be a permutation
2078    ///   of `self.indices` (matched by ID).
2079    ///
2080    /// # Panics
2081    /// Panics if `new_indices.len() != self.indices.len()`, if any index ID
2082    /// doesn't match, or if there are duplicate indices.
2083    ///
2084    /// # Example
2085    /// ```
2086    /// use tensor4all_core::TensorDynLen;
2087    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2088    ///
2089    /// // Create a 2×3 tensor
2090    /// let i = Index::new_dyn(2);
2091    /// let j = Index::new_dyn(3);
2092    /// let indices = vec![i.clone(), j.clone()];
2093    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2094    ///
2095    /// // Permute to 3×2: swap the two dimensions by providing new indices order
2096    /// let permuted = tensor.permute_indices(&[j, i]).unwrap();
2097    /// assert_eq!(permuted.dims(), vec![3, 2]);
2098    /// ```
2099    pub fn permute_indices(&self, new_indices: &[DynIndex]) -> Result<Self> {
2100        // Compute permutation by matching IDs
2101        let perm = compute_permutation_from_indices(&self.indices, new_indices)?;
2102        if perm.iter().copied().eq(0..perm.len()) {
2103            return Ok(Self {
2104                indices: new_indices.to_vec(),
2105                storage: self.storage.clone(),
2106                structured_ad: self.structured_ad.clone(),
2107                eager_cache: Arc::clone(&self.eager_cache),
2108            });
2109        }
2110
2111        let permuted = self.try_materialized_inner()?.transpose(&perm)?;
2112        let axis_classes = self.permute_axis_classes(&perm);
2113        Self::from_inner_with_axis_classes(new_indices.to_vec(), permuted, axis_classes)
2114    }
2115
2116    /// Permute the tensor dimensions, returning a new tensor.
2117    ///
2118    /// This method reorders the indices, dimensions, and data according to the
2119    /// given permutation. The permutation specifies which old axis each new
2120    /// axis corresponds to: `new_axis[i] = old_axis[perm[i]]`.
2121    ///
2122    /// # Arguments
2123    /// * `perm` - The permutation: `perm[i]` is the old axis index for new axis `i`
2124    ///
2125    /// # Panics
2126    /// Panics if `perm.len() != self.indices.len()` or if the permutation is invalid.
2127    ///
2128    /// # Example
2129    /// ```
2130    /// use tensor4all_core::TensorDynLen;
2131    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2132    ///
2133    /// // Create a 2×3 tensor
2134    /// let indices = vec![
2135    ///     Index::new_dyn(2),
2136    ///     Index::new_dyn(3),
2137    /// ];
2138    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2139    ///
2140    /// // Permute to 3×2: swap the two dimensions
2141    /// let permuted = tensor.permute(&[1, 0]).unwrap();
2142    /// assert_eq!(permuted.dims(), vec![3, 2]);
2143    /// ```
2144    pub fn permute(&self, perm: &[usize]) -> Result<Self> {
2145        anyhow::ensure!(
2146            perm.len() == self.indices.len(),
2147            "permutation length must match tensor rank"
2148        );
2149        let mut seen = HashSet::new();
2150        for &axis in perm {
2151            anyhow::ensure!(
2152                axis < self.indices.len(),
2153                "permutation axis {axis} out of range"
2154            );
2155            anyhow::ensure!(seen.insert(axis), "duplicate axis {axis} in permutation");
2156        }
2157        if perm.iter().copied().eq(0..perm.len()) {
2158            return Ok(self.clone());
2159        }
2160
2161        // Permute indices
2162        let new_indices: Vec<DynIndex> = perm.iter().map(|&i| self.indices[i].clone()).collect();
2163        let permuted = self.try_materialized_inner()?.transpose(perm)?;
2164        let axis_classes = self.permute_axis_classes(perm);
2165        Self::from_inner_with_axis_classes(new_indices, permuted, axis_classes)
2166    }
2167
2168    pub(crate) fn try_contract_pairwise_default(&self, other: &Self) -> Result<Self> {
2169        self.try_contract_pairwise_default_with_options(other, PairwiseContractionOptions::new())
2170    }
2171
2172    pub(crate) fn try_contract_pairwise_default_with_options(
2173        &self,
2174        other: &Self,
2175        options: PairwiseContractionOptions,
2176    ) -> Result<Self> {
2177        let self_indices = profile_pairwise_contract_section("operand_indices", || {
2178            self.operand_indices_for_contraction(options.lhs_conj)
2179        });
2180        let other_indices = profile_pairwise_contract_section("operand_indices", || {
2181            other.operand_indices_for_contraction(options.rhs_conj)
2182        });
2183        let self_dims = profile_pairwise_contract_section("expected_dims", || {
2184            Self::expected_dims_from_indices(&self_indices)
2185        });
2186        let other_dims = profile_pairwise_contract_section("expected_dims", || {
2187            Self::expected_dims_from_indices(&other_indices)
2188        });
2189        let spec = profile_pairwise_contract_section("prepare_contraction", || {
2190            prepare_contraction(&self_indices, &self_dims, &other_indices, &other_dims)
2191        })
2192        .context("contraction preparation failed")?;
2193        let result_axis_classes = profile_pairwise_contract_section("result_axis_classes", || {
2194            Self::binary_contraction_axis_classes(
2195                self.storage.axis_classes(),
2196                &spec.axes_a,
2197                other.storage.axis_classes(),
2198                &spec.axes_b,
2199            )
2200        });
2201
2202        if profile_pairwise_contract_section("structured_check", || {
2203            self.should_use_structured_payload_contract(other)
2204        }) {
2205            if options.has_conj() {
2206                let lhs = if options.lhs_conj {
2207                    self.conj()
2208                } else {
2209                    self.clone()
2210                };
2211                let rhs = if options.rhs_conj {
2212                    other.conj()
2213                } else {
2214                    other.clone()
2215                };
2216                return profile_pairwise_contract_section("structured_conj_fallback", || {
2217                    lhs.try_contract_pairwise_default(&rhs)
2218                });
2219            }
2220            return profile_pairwise_contract_section("structured_payload_contract", || {
2221                self.contract_structured_payloads(
2222                    other,
2223                    spec.result_indices.into_vec(),
2224                    &spec.axes_a,
2225                    &spec.axes_b,
2226                )
2227            });
2228        }
2229
2230        if self.indices.is_empty() && other.indices.is_empty() {
2231            if options.has_conj() {
2232                let lhs = if options.lhs_conj {
2233                    self.conj()
2234                } else {
2235                    self.clone()
2236                };
2237                let rhs = if options.rhs_conj {
2238                    other.conj()
2239                } else {
2240                    other.clone()
2241                };
2242                return lhs.try_contract_pairwise_default(&rhs);
2243            }
2244            let result = profile_pairwise_contract_section("scalar_mul", || {
2245                Ok::<_, anyhow::Error>(
2246                    self.try_materialized_inner()?
2247                        .mul(other.try_materialized_inner()?)?,
2248                )
2249            })?;
2250            return profile_pairwise_contract_section("from_inner", || {
2251                Self::from_inner(spec.result_indices.into_vec(), result)
2252            });
2253        }
2254
2255        let self_native = profile_pairwise_contract_section("as_native", || self.as_native())?;
2256        let other_native = profile_pairwise_contract_section("as_native", || other.as_native())?;
2257        if self_native.dtype() != other_native.dtype() {
2258            if options.has_conj() {
2259                let lhs = if options.lhs_conj {
2260                    self.conj()
2261                } else {
2262                    self.clone()
2263                };
2264                let rhs = if options.rhs_conj {
2265                    other.conj()
2266                } else {
2267                    other.clone()
2268                };
2269                return lhs.try_contract_pairwise_default(&rhs);
2270            }
2271            let result_native = profile_pairwise_contract_section("native_contract", || {
2272                contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)
2273            })?;
2274            return profile_pairwise_contract_section("from_native", || {
2275                Self::from_native_with_axis_classes(
2276                    spec.result_indices.into_vec(),
2277                    result_native,
2278                    result_axis_classes,
2279                )
2280            });
2281        }
2282
2283        let config = profile_pairwise_contract_section("build_dot_general_config", || {
2284            Self::binary_dot_general_config(&spec.axes_a, &spec.axes_b)
2285        })?;
2286        let result = profile_pairwise_contract_section("dot_general_with_conj", || {
2287            let lhs = profile_pairwise_contract_section("lhs_try_materialized_inner", || {
2288                self.try_materialized_inner()
2289            })?;
2290            let rhs = profile_pairwise_contract_section("rhs_try_materialized_inner", || {
2291                other.try_materialized_inner()
2292            })?;
2293            profile_pairwise_contract_section("dot_general_execute", || {
2294                lhs.dot_general_with_conj(rhs, &config, options.lhs_conj, options.rhs_conj)
2295            })
2296            .map_err(anyhow::Error::from)
2297        })?;
2298        record_pairwise_contract_profile_bytes(
2299            "dot_general_output",
2300            native_tensor_profile_bytes(result.data()),
2301        );
2302        profile_pairwise_contract_section("from_inner_axis_classes", || {
2303            Self::from_inner_with_axis_classes(
2304                spec.result_indices.into_vec(),
2305                result,
2306                result_axis_classes,
2307            )
2308        })
2309    }
2310
2311    pub(crate) fn try_tensordot_pairwise_explicit(
2312        &self,
2313        other: &Self,
2314        pairs: &[(DynIndex, DynIndex)],
2315    ) -> Result<Self> {
2316        use crate::index_ops::ContractionError;
2317
2318        let self_dims = Self::expected_dims_from_indices(&self.indices);
2319        let other_dims = Self::expected_dims_from_indices(&other.indices);
2320        let spec = prepare_contraction_pairs(
2321            &self.indices,
2322            &self_dims,
2323            &other.indices,
2324            &other_dims,
2325            pairs,
2326        )
2327        .map_err(|e| match e {
2328            ContractionError::NoCommonIndices => {
2329                anyhow::anyhow!("tensordot: No pairs specified for contraction")
2330            }
2331            ContractionError::BatchContractionNotImplemented => anyhow::anyhow!(
2332                "tensordot: Common index found but not in contraction pairs. \
2333                         Batch contraction is not yet implemented."
2334            ),
2335            ContractionError::IndexNotFound { tensor } => {
2336                anyhow::anyhow!("tensordot: Index not found in {} tensor", tensor)
2337            }
2338            ContractionError::DimensionMismatch {
2339                pos_a,
2340                pos_b,
2341                dim_a,
2342                dim_b,
2343            } => anyhow::anyhow!(
2344                "tensordot: Dimension mismatch: self[{}]={} != other[{}]={}",
2345                pos_a,
2346                dim_a,
2347                pos_b,
2348                dim_b
2349            ),
2350            ContractionError::DuplicateAxis { tensor, pos } => {
2351                anyhow::anyhow!("tensordot: Duplicate axis {} in {} tensor", pos, tensor)
2352            }
2353        })?;
2354        let result_axis_classes = Self::binary_contraction_axis_classes(
2355            self.storage.axis_classes(),
2356            &spec.axes_a,
2357            other.storage.axis_classes(),
2358            &spec.axes_b,
2359        );
2360
2361        if self.should_use_structured_payload_contract(other) {
2362            return self.contract_structured_payloads(
2363                other,
2364                spec.result_indices.into_vec(),
2365                &spec.axes_a,
2366                &spec.axes_b,
2367            );
2368        }
2369
2370        if self.indices.is_empty() && other.indices.is_empty() {
2371            let result = self
2372                .try_materialized_inner()?
2373                .mul(other.try_materialized_inner()?)
2374                .map_err(|e| anyhow::anyhow!("tensordot scalar multiply failed: {e}"))?;
2375            return Self::from_inner(spec.result_indices.into_vec(), result);
2376        }
2377
2378        let self_native = self.as_native()?;
2379        let other_native = other.as_native()?;
2380        if self_native.dtype() != other_native.dtype() {
2381            let result_native =
2382                contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)?;
2383            return Self::from_native_with_axis_classes(
2384                spec.result_indices.into_vec(),
2385                result_native,
2386                result_axis_classes,
2387            );
2388        }
2389
2390        let subscripts = Self::build_binary_einsum_subscripts(
2391            self.indices.len(),
2392            &spec.axes_a,
2393            other.indices.len(),
2394            &spec.axes_b,
2395        )?;
2396        let result = eager_einsum_ad(
2397            &[
2398                self.try_materialized_inner()?,
2399                other.try_materialized_inner()?,
2400            ],
2401            &subscripts,
2402        )
2403        .map_err(|e| anyhow::anyhow!("tensordot failed: {e}"))?;
2404        Self::from_inner_with_axis_classes(
2405            spec.result_indices.into_vec(),
2406            result,
2407            result_axis_classes,
2408        )
2409    }
2410
2411    pub(crate) fn try_outer_product_pairwise(&self, other: &Self) -> Result<Self> {
2412        use anyhow::Context;
2413
2414        // Check for common indices - outer product should have none
2415        let common_positions = common_ind_positions(&self.indices, &other.indices);
2416        if !common_positions.is_empty() {
2417            let common_ids: Vec<_> = common_positions
2418                .iter()
2419                .map(|(pos_a, _)| self.indices[*pos_a].id())
2420                .collect();
2421            return Err(anyhow::anyhow!(
2422                "outer_product: tensors have common indices {:?}. \
2423                 Use tensordot to contract common indices, or use sim() to replace \
2424                 indices with fresh IDs before computing outer product.",
2425                common_ids
2426            ))
2427            .context("outer_product: common indices found");
2428        }
2429
2430        // Build result indices and dimensions
2431        let mut result_indices = self.indices.clone();
2432        result_indices.extend(other.indices.iter().cloned());
2433        let result_axis_classes = Self::binary_contraction_axis_classes(
2434            self.storage.axis_classes(),
2435            &[],
2436            other.storage.axis_classes(),
2437            &[],
2438        );
2439        if self.should_use_structured_payload_contract(other) {
2440            return self.contract_structured_payloads(other, result_indices, &[], &[]);
2441        }
2442        let self_native = self.as_native()?;
2443        let other_native = other.as_native()?;
2444        if self_native.dtype() != other_native.dtype() {
2445            let result_native = contract_native_tensor(self_native, &[], other_native, &[])?;
2446            return Self::from_native_with_axis_classes(
2447                result_indices,
2448                result_native,
2449                result_axis_classes,
2450            );
2451        }
2452
2453        let subscripts = Self::build_binary_einsum_subscripts(
2454            self.indices.len(),
2455            &[],
2456            other.indices.len(),
2457            &[],
2458        )?;
2459        let result = eager_einsum_ad(
2460            &[
2461                self.try_materialized_inner()?,
2462                other.try_materialized_inner()?,
2463            ],
2464            &subscripts,
2465        )
2466        .map_err(|e| anyhow::anyhow!("outer_product failed: {e}"))?;
2467        Self::from_inner_with_axis_classes(result_indices, result, result_axis_classes)
2468    }
2469}
2470
2471// ============================================================================
2472// Random tensor generation
2473// ============================================================================
2474
2475impl TensorDynLen {
2476    /// Create a random tensor with values from standard normal distribution (generic over scalar type).
2477    ///
2478    /// For `f64`, each element is drawn from the standard normal distribution.
2479    /// For `Complex64`, both real and imaginary parts are drawn independently.
2480    ///
2481    /// # Type Parameters
2482    /// * `T` - The scalar element type (must implement [`RandomScalar`])
2483    /// * `R` - The random number generator type
2484    ///
2485    /// # Arguments
2486    /// * `rng` - Random number generator
2487    /// * `indices` - The indices for the tensor
2488    ///
2489    /// # Example
2490    /// ```
2491    /// use tensor4all_core::TensorDynLen;
2492    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2493    /// use rand::SeedableRng;
2494    /// use rand_chacha::ChaCha8Rng;
2495    ///
2496    /// let mut rng = ChaCha8Rng::seed_from_u64(42);
2497    /// let i = Index::new_dyn(2);
2498    /// let j = Index::new_dyn(3);
2499    /// let tensor: TensorDynLen = TensorDynLen::random::<f64, _>(&mut rng, vec![i, j]).unwrap();
2500    /// assert_eq!(tensor.dims(), vec![2, 3]);
2501    /// ```
2502    pub fn random<T: RandomScalar, R: Rng>(rng: &mut R, indices: Vec<DynIndex>) -> Result<Self> {
2503        let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
2504        let size = checked_product(&dims)?;
2505        let data: Vec<T> = (0..size).map(|_| T::random_value(rng)).collect();
2506        Self::from_dense(indices, data)
2507    }
2508}
2509
2510impl TensorDynLen {
2511    /// Add two tensors element-wise.
2512    ///
2513    /// The tensors must have the same index set (matched by ID). If the indices
2514    /// are in a different order, the other tensor will be permuted to match `self`.
2515    ///
2516    /// # Arguments
2517    /// * `other` - The tensor to add
2518    ///
2519    /// # Returns
2520    /// A new tensor representing `self + other`, or an error if:
2521    /// - The tensors have different index sets
2522    /// - The dimensions don't match
2523    /// - Storage types are incompatible
2524    ///
2525    /// # Example
2526    /// ```
2527    /// use tensor4all_core::TensorDynLen;
2528    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2529    ///
2530    /// let i = Index::new_dyn(2);
2531    /// let j = Index::new_dyn(3);
2532    ///
2533    /// let indices_a = vec![i.clone(), j.clone()];
2534    /// let data_a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
2535    /// let tensor_a: TensorDynLen = TensorDynLen::from_dense(indices_a, data_a).unwrap();
2536    ///
2537    /// let indices_b = vec![i.clone(), j.clone()];
2538    /// let data_b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
2539    /// let tensor_b: TensorDynLen = TensorDynLen::from_dense(indices_b, data_b).unwrap();
2540    ///
2541    /// let sum = tensor_a.add(&tensor_b).unwrap();
2542    /// // sum = [[2, 3, 4], [5, 6, 7]]
2543    /// ```
2544    pub fn add(&self, other: &Self) -> Result<Self> {
2545        // Validate that both tensors have the same number of indices
2546        if self.indices.len() != other.indices.len() {
2547            return Err(anyhow::anyhow!(
2548                "Index count mismatch: self has {} indices, other has {}",
2549                self.indices.len(),
2550                other.indices.len()
2551            ));
2552        }
2553
2554        // Validate that both tensors have the same set of indices
2555        let self_set: HashSet<_> = self.indices.iter().collect();
2556        let other_set: HashSet<_> = other.indices.iter().collect();
2557
2558        if self_set != other_set {
2559            return Err(anyhow::anyhow!(
2560                "Index set mismatch: tensors must have the same indices"
2561            ));
2562        }
2563
2564        // Permute other to match self's index order (no-op if already aligned)
2565        let other_aligned = other.permute_indices(&self.indices)?;
2566
2567        // Validate dimensions match after alignment
2568        let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2569        let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2570        if self_expected_dims != other_expected_dims {
2571            use crate::TagSetLike;
2572            let fmt = |indices: &[DynIndex]| -> Vec<String> {
2573                indices
2574                    .iter()
2575                    .map(|idx| {
2576                        let tags: Vec<String> = idx.tags().iter().collect();
2577                        format!("{:?}(dim={},tags={:?})", idx.id(), idx.dim(), tags)
2578                    })
2579                    .collect()
2580            };
2581            return Err(anyhow::anyhow!(
2582                "Dimension mismatch after alignment.\n\
2583                 self: dims={:?}, indices(order)={:?}\n\
2584                 other_aligned: dims={:?}, indices(order)={:?}",
2585                self_expected_dims,
2586                fmt(&self.indices),
2587                other_expected_dims,
2588                fmt(&other_aligned.indices)
2589            ));
2590        }
2591
2592        self.axpby(
2593            AnyScalar::new_real(1.0),
2594            &other_aligned,
2595            AnyScalar::new_real(1.0),
2596        )
2597    }
2598
2599    /// Compute a linear combination: `a * self + b * other`.
2600    ///
2601    /// Both tensors must have the same set of indices (matched by ID).
2602    /// If indices are in a different order, `other` is automatically permuted
2603    /// to match `self`.
2604    ///
2605    /// # Examples
2606    ///
2607    /// ```
2608    /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen};
2609    ///
2610    /// let i = DynIndex::new_dyn(2);
2611    /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
2612    /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![3.0, 4.0]).unwrap();
2613    ///
2614    /// // 2*a + 3*b = [2+9, 4+12] = [11, 16]
2615    /// let result = a.axpby(AnyScalar::new_real(2.0), &b, AnyScalar::new_real(3.0)).unwrap();
2616    /// let data = result.to_vec::<f64>().unwrap();
2617    /// assert!((data[0] - 11.0).abs() < 1e-12);
2618    /// assert!((data[1] - 16.0).abs() < 1e-12);
2619    /// ```
2620    pub fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
2621        // Validate that both tensors have the same number of indices.
2622        if self.indices.len() != other.indices.len() {
2623            return Err(anyhow::anyhow!(
2624                "Index count mismatch: self has {} indices, other has {}",
2625                self.indices.len(),
2626                other.indices.len()
2627            ));
2628        }
2629
2630        // Validate that both tensors have the same set of indices.
2631        let self_set: HashSet<_> = self.indices.iter().collect();
2632        let other_set: HashSet<_> = other.indices.iter().collect();
2633        if self_set != other_set {
2634            return Err(anyhow::anyhow!(
2635                "Index set mismatch: tensors must have the same indices"
2636            ));
2637        }
2638
2639        // Align other tensor axis order to self.
2640        let other_aligned = other.permute_indices(&self.indices)?;
2641
2642        // Validate dimensions match after alignment.
2643        let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2644        let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2645        if self_expected_dims != other_expected_dims {
2646            return Err(anyhow::anyhow!(
2647                "Dimension mismatch after alignment: self={:?}, other_aligned={:?}",
2648                self_expected_dims,
2649                other_expected_dims
2650            ));
2651        }
2652
2653        let axis_classes = if self.storage.axis_classes() == other_aligned.storage.axis_classes() {
2654            self.storage.axis_classes().to_vec()
2655        } else {
2656            Self::dense_axis_classes(self.indices.len())
2657        };
2658
2659        let same_compact_layout = self.storage.payload_dims()
2660            == other_aligned.storage.payload_dims()
2661            && self.storage.payload_strides_vec() == other_aligned.storage.payload_strides_vec()
2662            && self.storage.axis_classes() == other_aligned.storage.axis_classes();
2663        if same_compact_layout
2664            && !self.tracks_grad()
2665            && !other_aligned.tracks_grad()
2666            && !a.tracks_grad()
2667            && !b.tracks_grad()
2668        {
2669            let lhs_storage = self.storage.materialize(self.indices.len())?;
2670            let rhs_storage = other_aligned
2671                .storage
2672                .materialize(other_aligned.indices.len())?;
2673            let combined = lhs_storage
2674                .axpby(
2675                    &a.to_backend_scalar(),
2676                    rhs_storage.as_ref(),
2677                    &b.to_backend_scalar(),
2678                )
2679                .map_err(|e| anyhow::anyhow!("storage axpby failed: {e}"))?;
2680            return Self::from_storage(self.indices.clone(), Arc::new(combined));
2681        }
2682
2683        let self_native = self.as_native()?;
2684        let other_native = other_aligned.as_native()?;
2685        let a_native = a.as_tensor()?.as_native()?;
2686        let b_native = b.as_tensor()?.as_native()?;
2687        if self_native.dtype() != other_native.dtype()
2688            || self_native.dtype() != a_native.dtype()
2689            || other_native.dtype() != b_native.dtype()
2690        {
2691            let combined = axpby_native_tensor(
2692                self_native,
2693                &a.to_backend_scalar(),
2694                other_native,
2695                &b.to_backend_scalar(),
2696            )?;
2697            return Self::from_native_with_axis_classes(
2698                self.indices.clone(),
2699                combined,
2700                axis_classes,
2701            );
2702        }
2703
2704        let lhs = self.scale(a)?;
2705        let rhs = other_aligned.scale(b)?;
2706        let combined = lhs
2707            .try_materialized_inner()?
2708            .add(rhs.try_materialized_inner()?)
2709            .map_err(|e| anyhow::anyhow!("tensor addition failed: {e}"))?;
2710        Self::from_inner_with_axis_classes(self.indices.clone(), combined, axis_classes)
2711    }
2712
2713    /// Scalar multiplication.
2714    ///
2715    /// Multiplies every element by `scalar`.
2716    ///
2717    /// # Examples
2718    ///
2719    /// ```
2720    /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen};
2721    ///
2722    /// let i = DynIndex::new_dyn(3);
2723    /// let t = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0, 3.0]).unwrap();
2724    /// let scaled = t.scale(AnyScalar::new_real(2.0)).unwrap();
2725    /// assert_eq!(scaled.to_vec::<f64>().unwrap(), vec![2.0, 4.0, 6.0]);
2726    /// ```
2727    pub fn scale(&self, scalar: AnyScalar) -> Result<Self> {
2728        if !self.tracks_grad() && !scalar.tracks_grad() {
2729            let scaled = self.storage.scale(&scalar.to_backend_scalar())?;
2730            return Self::from_storage(self.indices.clone(), Arc::new(scaled));
2731        }
2732
2733        let self_native = self.as_native()?;
2734        let scalar_native = scalar.as_tensor()?.as_native()?;
2735        if self_native.dtype() != scalar_native.dtype() {
2736            let scaled = scale_native_tensor(self_native, &scalar.to_backend_scalar())?;
2737            return Self::from_native_with_axis_classes(
2738                self.indices.clone(),
2739                scaled,
2740                self.storage.axis_classes().to_vec(),
2741            );
2742        }
2743
2744        let scaled = if self.indices.is_empty() {
2745            self.try_materialized_inner()?
2746                .mul(scalar.as_tensor()?.try_materialized_inner()?)
2747                .map_err(|e| anyhow::anyhow!("scalar multiplication failed: {e}"))?
2748        } else {
2749            let subscripts = Self::scale_subscripts(self.indices.len())?;
2750            eager_einsum_ad(
2751                &[
2752                    self.try_materialized_inner()?,
2753                    scalar.as_tensor()?.try_materialized_inner()?,
2754                ],
2755                &subscripts,
2756            )
2757            .map_err(|e| anyhow::anyhow!("tensor scaling failed: {e}"))?
2758        };
2759        Self::from_inner_with_axis_classes(
2760            self.indices.clone(),
2761            scaled,
2762            self.storage.axis_classes().to_vec(),
2763        )
2764    }
2765
2766    /// Inner product (dot product) of two tensors.
2767    ///
2768    /// Computes `⟨self, other⟩ = Σ conj(self)_i * other_i`.
2769    ///
2770    /// # Examples
2771    ///
2772    /// ```
2773    /// use tensor4all_core::{DynIndex, TensorDynLen};
2774    ///
2775    /// let i = DynIndex::new_dyn(3);
2776    /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0, 3.0]).unwrap();
2777    /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![4.0, 5.0, 6.0]).unwrap();
2778    ///
2779    /// // <a, b> = 1*4 + 2*5 + 3*6 = 32
2780    /// let ip = a.inner_product(&b).unwrap();
2781    /// assert!((ip.real() - 32.0).abs() < 1e-12);
2782    /// ```
2783    pub fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
2784        if self.indices.len() == other.indices.len() {
2785            let self_set: HashSet<_> = self.indices.iter().collect();
2786            let other_set: HashSet<_> = other.indices.iter().collect();
2787            if self_set == other_set {
2788                let other_aligned = other.permute_indices(&self.indices)?;
2789                let result = super::contract::contract_pair_with_operand_options(
2790                    self,
2791                    &other_aligned,
2792                    PairwiseContractionOptions::new().with_lhs_conj(true),
2793                )?;
2794                return result.sum();
2795            }
2796        }
2797
2798        // Contract self.conj() with other over all indices
2799        let result = super::contract::contract_pair_with_operand_options(
2800            self,
2801            other,
2802            PairwiseContractionOptions::new().with_lhs_conj(true),
2803        )?;
2804        // Result should be a scalar (no indices)
2805        result.sum()
2806    }
2807}
2808
2809// ============================================================================
2810// Index Replacement Methods
2811// ============================================================================
2812
2813impl TensorDynLen {
2814    /// Replace an index in the tensor with a new index.
2815    ///
2816    /// This replaces the index matching `old_index` by ID with `new_index`.
2817    /// The storage data is not modified, only the index metadata is changed.
2818    ///
2819    /// # Arguments
2820    /// * `old_index` - The index to replace (matched by ID)
2821    /// * `new_index` - The new index to use
2822    ///
2823    /// # Returns
2824    /// A new tensor with the index replaced. If no index matches `old_index`,
2825    /// returns a clone of the original tensor.
2826    ///
2827    /// # Errors
2828    /// Returns an error if the replacement index has a different dimension.
2829    ///
2830    /// # Example
2831    /// ```
2832    /// use tensor4all_core::TensorDynLen;
2833    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2834    ///
2835    /// let i = Index::new_dyn(2);
2836    /// let j = Index::new_dyn(3);
2837    /// let new_i = Index::new_dyn(2);  // Same dimension, different ID
2838    ///
2839    /// let indices = vec![i.clone(), j.clone()];
2840    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2841    ///
2842    /// // Replace index i with new_i
2843    /// let replaced = tensor.replaceind(&i, &new_i).unwrap();
2844    /// assert_eq!(replaced.indices[0].id, new_i.id);
2845    /// assert_eq!(replaced.indices[1].id, j.id);
2846    /// ```
2847    pub fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
2848        // Validate dimension match
2849        if old_index.dim() != new_index.dim() {
2850            return Err(anyhow::anyhow!(
2851                "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2852                old_index.dim(),
2853                new_index.dim()
2854            ));
2855        }
2856
2857        let new_indices: Vec<_> = self
2858            .indices
2859            .iter()
2860            .map(|idx| {
2861                if *idx == *old_index {
2862                    new_index.clone()
2863                } else {
2864                    idx.clone()
2865                }
2866            })
2867            .collect();
2868
2869        Ok(Self {
2870            indices: new_indices,
2871            storage: self.storage.clone(),
2872            structured_ad: self.structured_ad.clone(),
2873            eager_cache: Arc::clone(&self.eager_cache),
2874        })
2875    }
2876
2877    /// Replace multiple indices in the tensor.
2878    ///
2879    /// This replaces each index in `old_indices` (matched by ID) with the corresponding
2880    /// index in `new_indices`. The storage data is not modified.
2881    ///
2882    /// # Arguments
2883    /// * `old_indices` - The indices to replace (matched by ID)
2884    /// * `new_indices` - The new indices to use
2885    ///
2886    /// # Returns
2887    /// A new tensor with the indices replaced. Indices not found in `old_indices`
2888    /// are kept unchanged.
2889    ///
2890    /// # Errors
2891    /// Returns an error if `old_indices` and `new_indices` have different
2892    /// lengths or if any replacement index has a different dimension.
2893    ///
2894    /// # Example
2895    /// ```
2896    /// use tensor4all_core::TensorDynLen;
2897    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2898    ///
2899    /// let i = Index::new_dyn(2);
2900    /// let j = Index::new_dyn(3);
2901    /// let new_i = Index::new_dyn(2);
2902    /// let new_j = Index::new_dyn(3);
2903    ///
2904    /// let indices = vec![i.clone(), j.clone()];
2905    /// let tensor: TensorDynLen = TensorDynLen::from_dense(indices, vec![0.0; 6]).unwrap();
2906    ///
2907    /// // Replace both indices
2908    /// let replaced = tensor
2909    ///     .replaceinds(&[i.clone(), j.clone()], &[new_i.clone(), new_j.clone()])
2910    ///     .unwrap();
2911    /// assert_eq!(replaced.indices[0].id, new_i.id);
2912    /// assert_eq!(replaced.indices[1].id, new_j.id);
2913    /// ```
2914    pub fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
2915        anyhow::ensure!(
2916            old_indices.len() == new_indices.len(),
2917            "old_indices and new_indices must have the same length"
2918        );
2919
2920        // Validate dimension matches for all replacements
2921        for (old, new) in old_indices.iter().zip(new_indices.iter()) {
2922            if old.dim() != new.dim() {
2923                return Err(anyhow::anyhow!(
2924                    "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2925                    old.dim(),
2926                    new.dim()
2927                ));
2928            }
2929        }
2930
2931        // Build a map from old indices to new indices
2932        let replacement_map: std::collections::HashMap<_, _> =
2933            old_indices.iter().zip(new_indices.iter()).collect();
2934
2935        let new_indices_vec: Vec<_> = self
2936            .indices
2937            .iter()
2938            .map(|idx| {
2939                if let Some(new_idx) = replacement_map.get(idx) {
2940                    (*new_idx).clone()
2941                } else {
2942                    idx.clone()
2943                }
2944            })
2945            .collect();
2946
2947        Ok(Self {
2948            indices: new_indices_vec,
2949            storage: self.storage.clone(),
2950            structured_ad: self.structured_ad.clone(),
2951            eager_cache: Arc::clone(&self.eager_cache),
2952        })
2953    }
2954}
2955
2956// ============================================================================
2957// Complex Conjugation
2958// ============================================================================
2959
2960impl TensorDynLen {
2961    /// Complex conjugate of all tensor elements.
2962    ///
2963    /// For real (f64) tensors, returns a copy (conjugate of real is identity).
2964    /// For complex (Complex64) tensors, conjugates each element.
2965    ///
2966    /// The indices and dimensions remain unchanged.
2967    ///
2968    /// This is inspired by the `conj` operation in ITensorMPS.jl.
2969    ///
2970    /// # Example
2971    /// ```
2972    /// use tensor4all_core::TensorDynLen;
2973    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
2974    /// use num_complex::Complex64;
2975    ///
2976    /// let i = Index::new_dyn(2);
2977    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -4.0)];
2978    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i], data).unwrap();
2979    ///
2980    /// let conj_tensor = tensor.conj();
2981    /// // Elements are now conjugated: 1-2i, 3+4i
2982    /// ```
2983    pub fn conj(&self) -> Self {
2984        // Conjugate tensor: conjugate storage data and map indices via IndexLike::conj()
2985        // For default undirected indices, conj() is a no-op, so this is future-proof
2986        // for QSpace-compatible directed indices where conj() flips Ket <-> Bra
2987        let new_indices: Vec<DynIndex> = self.indices.iter().map(|idx| idx.conj()).collect();
2988        let structured_ad = self.tracked_compact_payload_value().and_then(|value| {
2989            value.payload.conj().ok().map(|payload| {
2990                Arc::new(StructuredAdValue {
2991                    payload: Arc::new(payload),
2992                    payload_dims: value.payload_dims.clone(),
2993                    axis_classes: value.axis_classes.clone(),
2994                })
2995            })
2996        });
2997        let eager_cache = self
2998            .eager_cache
2999            .get()
3000            .and_then(|inner| inner.conj().ok())
3001            .map(Self::eager_cache_with)
3002            .unwrap_or_else(Self::empty_eager_cache);
3003        Self {
3004            indices: new_indices,
3005            storage: self.storage.conj().unwrap_or_else(|_| {
3006                TensorDynLenStorage::from_storage(Arc::new(self.storage().conj()))
3007            }),
3008            structured_ad,
3009            eager_cache,
3010        }
3011    }
3012}
3013
3014// ============================================================================
3015// Norm Computation
3016// ============================================================================
3017
3018impl TensorDynLen {
3019    /// Compute the squared Frobenius norm of the tensor: ||T||² = Σ|T_ijk...|²
3020    ///
3021    /// For real tensors: sum of squares of all elements.
3022    /// For complex tensors: sum of |z|² = z * conj(z) for all elements.
3023    ///
3024    /// # Example
3025    /// ```
3026    /// use tensor4all_core::TensorDynLen;
3027    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3028    ///
3029    /// let i = Index::new_dyn(2);
3030    /// let j = Index::new_dyn(3);
3031    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];  // 1² + 2² + ... + 6² = 91
3032    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i, j], data).unwrap();
3033    ///
3034    /// assert!((tensor.norm_squared() - 91.0).abs() < 1e-10);
3035    /// ```
3036    pub fn norm_squared(&self) -> f64 {
3037        self.try_norm_squared().unwrap_or(f64::NAN)
3038    }
3039
3040    /// Try to compute the squared Frobenius norm of the tensor.
3041    ///
3042    /// # Errors
3043    /// Returns an error if conjugation, contraction, or scalar extraction fails.
3044    pub fn try_norm_squared(&self) -> Result<f64> {
3045        // Special case: scalar tensor (no indices)
3046        if self.indices.is_empty() {
3047            // For a scalar, ||T||² = |value|²
3048            let value = self.sum()?;
3049            let abs_val = value.abs();
3050            return Ok(abs_val * abs_val);
3051        }
3052
3053        // Contract tensor with its conjugate over all indices → scalar
3054        // ||T||² = Σ T_ijk... * conj(T_ijk...) = Σ |T_ijk...|²
3055        let conj = self.conj();
3056        let scalar = super::contract::contract_pair(self, &conj)?;
3057        // The mathematical result is nonnegative and real. Clamp tiny negative
3058        // roundoff so downstream `sqrt` stays well-defined for complex tensors.
3059        Ok(scalar.sum()?.real().max(0.0))
3060    }
3061
3062    /// Compute the Frobenius norm of the tensor: ||T|| = sqrt(Σ|T_ijk...|²)
3063    ///
3064    /// # Example
3065    /// ```
3066    /// use tensor4all_core::TensorDynLen;
3067    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3068    ///
3069    /// let i = Index::new_dyn(2);
3070    /// let data = vec![3.0, 4.0];  // sqrt(9 + 16) = 5
3071    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i], data).unwrap();
3072    ///
3073    /// assert!((tensor.norm() - 5.0).abs() < 1e-10);
3074    /// ```
3075    pub fn norm(&self) -> f64 {
3076        self.norm_squared().sqrt()
3077    }
3078
3079    /// Maximum absolute value of all elements (L-infinity norm).
3080    ///
3081    /// # Examples
3082    ///
3083    /// ```
3084    /// use tensor4all_core::{DynIndex, TensorDynLen};
3085    ///
3086    /// let i = DynIndex::new_dyn(4);
3087    /// let t = TensorDynLen::from_dense(vec![i], vec![-5.0, 1.0, 3.0, -2.0]).unwrap();
3088    /// assert!((t.maxabs() - 5.0).abs() < 1e-12);
3089    /// ```
3090    pub fn maxabs(&self) -> f64 {
3091        self.storage.max_abs().unwrap_or(0.0)
3092    }
3093
3094    /// Element-wise subtraction with index alignment.
3095    ///
3096    /// This computes `self - other` using the same vector-space semantics as
3097    /// [`TensorVectorSpace`](crate::TensorVectorSpace).
3098    ///
3099    /// # Errors
3100    /// Returns an error if the tensors cannot be aligned or subtracted.
3101    pub fn sub(&self, other: &Self) -> Result<Self> {
3102        self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0))
3103    }
3104
3105    /// Negate all elements.
3106    ///
3107    /// # Errors
3108    /// Returns an error if scalar multiplication fails for the tensor storage.
3109    pub fn neg(&self) -> Result<Self> {
3110        self.scale(AnyScalar::new_real(-1.0))
3111    }
3112
3113    /// Approximate equality check using Julia `isapprox`-style semantics.
3114    ///
3115    /// Returns `true` when `||self - other|| <= max(atol, rtol *
3116    /// max(||self||, ||other||))`.
3117    pub fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool {
3118        let diff = match self.sub(other) {
3119            Ok(d) => d,
3120            Err(_) => return false,
3121        };
3122        let diff_norm = diff.norm();
3123        diff_norm <= atol.max(rtol * self.norm().max(other.norm()))
3124    }
3125
    /// Create a diagonal Kronecker-delta tensor for one input/output index pair.
    ///
    /// This inherent method forwards to
    /// `<Self as TensorConstructionLike>::diagonal`.
    ///
    /// # Arguments
    /// * `input_index` - The input index of the pair
    /// * `output_index` - The output index of the pair
    ///
    /// # Errors
    /// Returns an error if the two indices have different dimensions.
    pub fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
        <Self as TensorConstructionLike>::diagonal(input_index, output_index)
    }
3133
    /// Create a product of Kronecker-delta tensors for paired index lists.
    ///
    /// This inherent method forwards to
    /// `<Self as TensorConstructionLike>::delta`.
    ///
    /// # Arguments
    /// * `input_indices` - The input indices, paired by position with `output_indices`
    /// * `output_indices` - The output indices
    ///
    /// # Errors
    /// Returns an error if the index lists have different lengths or paired
    /// dimensions do not match.
    pub fn delta(input_indices: &[DynIndex], output_indices: &[DynIndex]) -> Result<Self> {
        <Self as TensorConstructionLike>::delta(input_indices, output_indices)
    }
3142
3143    /// Create a scalar tensor equal to one.
3144    ///
3145    /// # Errors
3146    /// Returns an error if dense scalar construction fails.
3147    pub fn scalar_one() -> Result<Self> {
3148        <Self as TensorConstructionLike>::scalar_one()
3149    }
3150
3151    /// Create a tensor filled with ones over the given indices.
3152    ///
3153    /// # Errors
3154    /// Returns an error if the tensor size overflows or dense construction fails.
3155    pub fn ones(indices: &[DynIndex]) -> Result<Self> {
3156        <Self as TensorConstructionLike>::ones(indices)
3157    }
3158
3159    /// Create a one-hot tensor with value one at the specified index positions.
3160    ///
3161    /// # Errors
3162    /// Returns an error if any coordinate is outside its index dimension.
3163    pub fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3164        <Self as TensorConstructionLike>::onehot(index_vals)
3165    }
3166
3167    /// Compute the relative distance between two tensors.
3168    ///
3169    /// Returns `||A - B|| / ||A||` (Frobenius norm).
3170    /// If `||A|| = 0`, returns `||B||` instead to avoid division by zero.
3171    ///
3172    /// This is the ITensor-style distance function useful for comparing tensors.
3173    ///
3174    /// # Arguments
3175    /// * `other` - The other tensor to compare with
3176    ///
3177    /// # Returns
3178    /// The relative distance as a f64 value.
3179    ///
3180    /// # Note
3181    /// The indices of both tensors must be permutable to each other.
3182    /// The result tensor (A - B) uses the index ordering from self.
3183    ///
3184    /// # Example
3185    /// ```
3186    /// use tensor4all_core::TensorDynLen;
3187    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3188    ///
3189    /// let i = Index::new_dyn(2);
3190    /// let data_a = vec![1.0, 0.0];
3191    /// let data_b = vec![1.0, 0.0];  // Same tensor
3192    /// let tensor_a: TensorDynLen = TensorDynLen::from_dense(vec![i.clone()], data_a).unwrap();
3193    /// let tensor_b: TensorDynLen = TensorDynLen::from_dense(vec![i.clone()], data_b).unwrap();
3194    ///
3195    /// assert!(tensor_a.distance(&tensor_b).unwrap() < 1e-10);  // Zero distance
3196    /// ```
3197    pub fn distance(&self, other: &Self) -> Result<f64> {
3198        let norm_self = self.norm();
3199
3200        // Compute A - B = A + (-1) * B
3201        let neg_other = other.scale(AnyScalar::new_real(-1.0))?;
3202        let diff = self.add(&neg_other)?;
3203        let norm_diff = diff.norm();
3204
3205        if norm_self > 0.0 {
3206            Ok(norm_diff / norm_self)
3207        } else {
3208            Ok(norm_diff)
3209        }
3210    }
3211}
3212
3213impl std::fmt::Debug for TensorDynLen {
3214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3215        f.debug_struct("TensorDynLen")
3216            .field("indices", &self.indices)
3217            .field("dims", &self.dims())
3218            .field("is_diag", &self.is_diag())
3219            .finish()
3220    }
3221}
3222
3223/// Create a diagonal tensor with dynamic rank from diagonal data.
3224///
3225/// # Arguments
3226/// * `indices` - The indices for the tensor (all must have the same dimension)
3227/// * `diag_data` - The diagonal elements (length must equal the dimension of indices)
3228///
3229/// The returned tensor preserves compact diagonal payload metadata; use
3230/// [`TensorDynLen::is_diag`] or [`TensorDynLen::storage`] to inspect that
3231/// representation.
3232///
/// # Errors
/// Returns an error if `indices` is empty, if the indices fail the diagonal
/// dimension validation, or if `diag_data.len()` does not match the diagonal dimension.
3235///
3236/// # Examples
3237///
3238/// ```
3239/// use tensor4all_core::{DynIndex, diag_tensor_dyn_len};
3240///
3241/// let i = DynIndex::new_dyn(3);
3242/// let j = DynIndex::new_dyn(3);
3243/// let t = diag_tensor_dyn_len(vec![i, j], vec![1.0, 2.0, 3.0]).unwrap();
3244/// assert_eq!(t.dims(), vec![3, 3]);
3245/// assert!(t.is_diag());
3246/// ```
3247pub fn diag_tensor_dyn_len(indices: Vec<DynIndex>, diag_data: Vec<f64>) -> Result<TensorDynLen> {
3248    TensorDynLen::from_diag(indices, diag_data)
3249}
3250
/// Tuple produced by `unfold_split_inner`:
/// `(matrix_tensor, left_len, m, n, left_indices, right_indices)`.
// The alias names a six-element tuple, hence the type-complexity allowance.
#[allow(clippy::type_complexity)]
pub(crate) type UnfoldSplitInnerResult = (
    // Eager tensor reshaped to `[m, n]`.
    EagerTensor<CpuBackend>,
    // Number of left indices.
    usize,
    // `m`: product of the left index dimensions.
    usize,
    // `n`: product of the right index dimensions.
    usize,
    // Left indices, in the order supplied by the caller.
    Vec<DynIndex>,
    // Remaining indices, in their original tensor order.
    Vec<DynIndex>,
);
3260
3261/// Unfold a tensor into a matrix by splitting indices into left and right groups.
3262///
3263/// This function validates the split, permutes the tensor so that left indices
3264/// come first, and returns a rank-2 native tenferro tensor along with metadata.
3265///
3266/// # Arguments
3267/// * `t` - Input tensor
3268/// * `left_inds` - Indices to place on the left (row) side of the matrix
3269///
3270/// # Returns
3271/// A tuple `(matrix_tensor, left_len, m, n, left_indices, right_indices)` where:
3272/// - `matrix_tensor` is a rank-2 `tenferro::Tensor` with shape `[m, n]`
3273/// - `left_len` is the number of left indices
3274/// - `m` is the product of left index dimensions
3275/// - `n` is the product of right index dimensions
3276/// - `left_indices` is the vector of left indices (cloned)
3277/// - `right_indices` is the vector of right indices (cloned)
3278///
3279/// # Errors
3280/// Returns an error if:
3281/// - The tensor rank is < 2
3282/// - `left_inds` is empty or contains all indices
3283/// - `left_inds` contains indices not in the tensor or duplicates
3284/// - Native reshape fails
3285///
3286/// # Examples
3287///
3288/// ```
3289/// use tensor4all_core::{DynIndex, TensorDynLen, unfold_split};
3290///
3291/// let i = DynIndex::new_dyn(2);
3292/// let j = DynIndex::new_dyn(3);
3293/// // 2x3 dense tensor with data [1..6]
3294/// let t = TensorDynLen::from_dense(
3295///     vec![i.clone(), j.clone()],
3296///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
3297/// ).unwrap();
3298///
3299/// let (matrix, left_len, m, n, left_indices, right_indices) =
3300///     unfold_split(&t, &[i]).unwrap();
3301/// assert_eq!(left_len, 1);
3302/// assert_eq!(m, 2);
3303/// assert_eq!(n, 3);
3304/// assert_eq!(left_indices.len(), 1);
3305/// assert_eq!(right_indices.len(), 1);
3306/// ```
3307#[allow(clippy::type_complexity)]
3308pub fn unfold_split(
3309    t: &TensorDynLen,
3310    left_inds: &[DynIndex],
3311) -> Result<(
3312    NativeTensor,
3313    usize,
3314    usize,
3315    usize,
3316    Vec<DynIndex>,
3317    Vec<DynIndex>,
3318)> {
3319    let (matrix_inner, left_len, m, n, left_indices, right_indices) =
3320        unfold_split_inner(t, left_inds)?;
3321
3322    Ok((
3323        matrix_inner.data().clone(),
3324        left_len,
3325        m,
3326        n,
3327        left_indices,
3328        right_indices,
3329    ))
3330}
3331
3332pub(crate) fn unfold_split_inner(
3333    t: &TensorDynLen,
3334    left_inds: &[DynIndex],
3335) -> Result<UnfoldSplitInnerResult> {
3336    let rank = t.indices.len();
3337
3338    // Validate rank
3339    anyhow::ensure!(rank >= 2, "Tensor must have rank >= 2, got rank {}", rank);
3340
3341    let left_len = left_inds.len();
3342
3343    // Validate split: must be a proper subset
3344    anyhow::ensure!(
3345        left_len > 0 && left_len < rank,
3346        "Left indices must be a non-empty proper subset of tensor indices (0 < left_len < rank), got left_len={}, rank={}",
3347        left_len,
3348        rank
3349    );
3350
3351    // Validate that all left_inds are in the tensor and there are no duplicates
3352    let tensor_set: HashSet<_> = t.indices.iter().collect();
3353    let mut left_set = HashSet::new();
3354
3355    for left_idx in left_inds {
3356        anyhow::ensure!(
3357            tensor_set.contains(left_idx),
3358            "Index in left_inds not found in tensor"
3359        );
3360        anyhow::ensure!(left_set.insert(left_idx), "Duplicate index in left_inds");
3361    }
3362
3363    // Build right_inds: all indices not in left_inds, in original order
3364    let mut right_inds = Vec::new();
3365    for idx in &t.indices {
3366        if !left_set.contains(idx) {
3367            right_inds.push(idx.clone());
3368        }
3369    }
3370
3371    // Build new_indices: left_inds first, then right_inds
3372    let mut new_indices = Vec::with_capacity(rank);
3373    new_indices.extend_from_slice(left_inds);
3374    new_indices.extend_from_slice(&right_inds);
3375
3376    // Permute tensor to have left indices first, then right indices
3377    let unfolded = t.permute_indices(&new_indices)?;
3378
3379    // Compute matrix dimensions
3380    let unfolded_dims = unfolded.dims();
3381    let m: usize = unfolded_dims[..left_len].iter().product();
3382    let n: usize = unfolded_dims[left_len..].iter().product();
3383
3384    let matrix_tensor = unfolded.try_materialized_inner()?.reshape(&[m, n])?;
3385
3386    Ok((
3387        matrix_tensor,
3388        left_len,
3389        m,
3390        n,
3391        left_inds.to_vec(),
3392        right_inds,
3393    ))
3394}
3395
3396// ============================================================================
3397// TensorIndex implementation for TensorDynLen
3398// ============================================================================
3399
3400use crate::tensor_index::TensorIndex;
3401
3402impl TensorIndex for TensorDynLen {
3403    type Index = DynIndex;
3404
3405    fn external_indices(&self) -> Vec<DynIndex> {
3406        // For TensorDynLen, all indices are external.
3407        self.indices.clone()
3408    }
3409
3410    fn num_external_indices(&self) -> usize {
3411        self.indices.len()
3412    }
3413
3414    fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
3415        // Delegate to the inherent method
3416        TensorDynLen::replaceind(self, old_index, new_index)
3417    }
3418
3419    fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
3420        // Delegate to the inherent method
3421        TensorDynLen::replaceinds(self, old_indices, new_indices)
3422    }
3423}
3424
3425// ============================================================================
3426// TensorLike implementation for TensorDynLen
3427// ============================================================================
3428
3429use crate::tensor_like::{
3430    FactorizeError, FactorizeOptions, FactorizeResult, TensorConstructionLike,
3431    TensorContractionLike, TensorFactorizationLike, TensorVectorSpace,
3432};
3433
3434impl TensorVectorSpace for TensorDynLen {
3435    fn norm_squared(&self) -> f64 {
3436        TensorDynLen::norm_squared(self)
3437    }
3438
3439    fn maxabs(&self) -> f64 {
3440        TensorDynLen::maxabs(self)
3441    }
3442
3443    fn axpby(&self, a: crate::AnyScalar, other: &Self, b: crate::AnyScalar) -> Result<Self> {
3444        TensorDynLen::axpby(self, a, other, b)
3445    }
3446
3447    fn scale(&self, scalar: crate::AnyScalar) -> Result<Self> {
3448        TensorDynLen::scale(self, scalar)
3449    }
3450
3451    fn inner_product(&self, other: &Self) -> Result<crate::AnyScalar> {
3452        TensorDynLen::inner_product(self, other)
3453    }
3454}
3455
3456impl TensorFactorizationLike for TensorDynLen {
3457    fn factorize(
3458        &self,
3459        left_inds: &[DynIndex],
3460        options: &FactorizeOptions,
3461    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3462        crate::factorize::factorize(self, left_inds, options)
3463    }
3464
3465    fn factorize_full_rank(
3466        &self,
3467        left_inds: &[DynIndex],
3468        alg: crate::FactorizeAlg,
3469        canonical: crate::Canonical,
3470    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3471        crate::factorize::factorize_full_rank(self, left_inds, alg, canonical)
3472    }
3473}
3474
3475impl TensorContractionLike for TensorDynLen {
3476    fn conj(&self) -> Self {
3477        // Delegate to the inherent method (complex conjugate for dense tensors)
3478        TensorDynLen::conj(self)
3479    }
3480
3481    fn direct_sum(
3482        &self,
3483        other: &Self,
3484        pairs: &[(DynIndex, DynIndex)],
3485    ) -> Result<crate::tensor_like::DirectSumResult<Self>> {
3486        let (tensor, new_indices) = crate::direct_sum::direct_sum(self, other, pairs)?;
3487        Ok(crate::tensor_like::DirectSumResult {
3488            tensor,
3489            new_indices,
3490        })
3491    }
3492
3493    fn outer_product(&self, other: &Self) -> Result<Self> {
3494        super::contract::outer_product(self, other)
3495    }
3496
3497    fn permuteinds(&self, new_order: &[DynIndex]) -> Result<Self> {
3498        // Delegate to the inherent method
3499        TensorDynLen::permute_indices(self, new_order)
3500    }
3501
3502    fn fuse_indices(
3503        &self,
3504        old_indices: &[DynIndex],
3505        new_index: DynIndex,
3506        order: LinearizationOrder,
3507    ) -> Result<Self> {
3508        TensorDynLen::fuse_indices(self, old_indices, new_index, order)
3509    }
3510
3511    fn contract(tensors: &[&Self]) -> Result<Self> {
3512        super::contract::contract(tensors)
3513    }
3514
3515    fn contract_pair(&self, other: &Self) -> Result<Self> {
3516        super::contract::contract_pair(self, other)
3517    }
3518}
3519
3520impl TensorConstructionLike for TensorDynLen {
3521    fn select_indices(&self, selected_indices: &[DynIndex], positions: &[usize]) -> Result<Self> {
3522        TensorDynLen::select_indices(self, selected_indices, positions)
3523    }
3524
3525    fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3526        let dim = input_index.dim();
3527        if dim != output_index.dim() {
3528            return Err(anyhow::anyhow!(
3529                "Dimension mismatch: input index has dim {}, output has dim {}",
3530                dim,
3531                output_index.dim(),
3532            ));
3533        }
3534
3535        TensorDynLen::from_diag(
3536            vec![input_index.clone(), output_index.clone()],
3537            vec![1.0_f64; dim],
3538        )
3539    }
3540
3541    fn scalar_one() -> Result<Self> {
3542        TensorDynLen::from_dense(vec![], vec![1.0_f64])
3543    }
3544
3545    fn ones(indices: &[DynIndex]) -> Result<Self> {
3546        if indices.is_empty() {
3547            return Self::scalar_one();
3548        }
3549        let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3550        let total_size = checked_total_size(&dims)?;
3551        TensorDynLen::from_dense(indices.to_vec(), vec![1.0_f64; total_size])
3552    }
3553
3554    fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3555        if index_vals.is_empty() {
3556            return Self::scalar_one();
3557        }
3558        let indices: Vec<DynIndex> = index_vals.iter().map(|(idx, _)| idx.clone()).collect();
3559        let vals: Vec<usize> = index_vals.iter().map(|(_, v)| *v).collect();
3560        let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3561
3562        for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3563            if v >= d {
3564                return Err(anyhow::anyhow!(
3565                    "onehot: value {} at position {} is >= dimension {}",
3566                    v,
3567                    k,
3568                    d
3569                ));
3570            }
3571        }
3572
3573        let total_size = checked_total_size(&dims)?;
3574        let mut data = vec![0.0_f64; total_size];
3575
3576        let offset = column_major_offset(&dims, &vals)?;
3577        data[offset] = 1.0;
3578
3579        Self::from_dense(indices, data)
3580    }
3581
3582    // delta() uses the default implementation via diagonal() and outer_product()
3583}
3584
3585fn checked_total_size(dims: &[usize]) -> Result<usize> {
3586    dims.iter().try_fold(1_usize, |acc, &d| {
3587        if d == 0 {
3588            return Err(anyhow::anyhow!("invalid dimension 0"));
3589        }
3590        acc.checked_mul(d)
3591            .ok_or_else(|| anyhow::anyhow!("tensor size overflow"))
3592    })
3593}
3594
3595fn column_major_offset(dims: &[usize], vals: &[usize]) -> Result<usize> {
3596    if dims.len() != vals.len() {
3597        return Err(anyhow::anyhow!(
3598            "column_major_offset: dims.len() != vals.len()"
3599        ));
3600    }
3601    checked_total_size(dims)?;
3602
3603    let mut offset = 0usize;
3604    let mut stride = 1usize;
3605    for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3606        if d == 0 {
3607            return Err(anyhow::anyhow!("invalid dimension 0 at position {}", k));
3608        }
3609        if v >= d {
3610            return Err(anyhow::anyhow!(
3611                "column_major_offset: value {} at position {} is >= dimension {}",
3612                v,
3613                k,
3614                d
3615            ));
3616        }
3617        let term = v
3618            .checked_mul(stride)
3619            .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3620        offset = offset
3621            .checked_add(term)
3622            .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3623        stride = stride
3624            .checked_mul(d)
3625            .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3626    }
3627    Ok(offset)
3628}
3629
3630// ============================================================================
3631// High-level API for tensor construction (avoids direct Storage access)
3632// ============================================================================
3633
3634impl TensorDynLen {
3635    fn any_scalar_payload_to_complex(data: Vec<AnyScalar>) -> Vec<Complex64> {
3636        data.into_iter()
3637            .map(|value| {
3638                value
3639                    .as_c64()
3640                    .unwrap_or_else(|| Complex64::new(value.real(), 0.0))
3641            })
3642            .collect()
3643    }
3644
3645    fn any_scalar_payload_to_real(data: Vec<AnyScalar>) -> Vec<f64> {
3646        data.into_iter().map(|value| value.real()).collect()
3647    }
3648
3649    fn validate_dense_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3650        let expected_len = checked_total_size(dims)?;
3651        anyhow::ensure!(
3652            data_len == expected_len,
3653            "dense payload length {} does not match dims {:?} (expected {})",
3654            data_len,
3655            dims,
3656            expected_len
3657        );
3658        Ok(())
3659    }
3660
3661    fn validate_diag_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3662        anyhow::ensure!(
3663            !dims.is_empty(),
3664            "diagonal tensor construction requires at least one index"
3665        );
3666        Self::validate_diag_dims(dims)?;
3667        anyhow::ensure!(
3668            data_len == dims[0],
3669            "diagonal payload length {} does not match diagonal dimension {}",
3670            data_len,
3671            dims[0]
3672        );
3673        Ok(())
3674    }
3675
3676    /// Create a tensor from dense data with explicit indices.
3677    ///
3678    /// This is the recommended high-level API for creating tensors from raw data.
3679    /// It avoids direct access to `Storage` internals.
3680    ///
3681    /// # Type Parameters
3682    /// * `T` - Scalar type (`f64` or `Complex64`)
3683    ///
3684    /// # Arguments
3685    /// * `indices` - Vector of indices for the tensor
3686    /// * `data` - Tensor data in column-major order
3687    ///
    /// # Errors
    /// Returns an error if the data length doesn't match the product of index
    /// dimensions, or if index validation or native tensor construction fails.
3690    ///
3691    /// # Example
3692    /// ```
3693    /// use tensor4all_core::TensorDynLen;
3694    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3695    ///
3696    /// let i = Index::new_dyn(2);
3697    /// let j = Index::new_dyn(3);
3698    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
3699    /// let tensor: TensorDynLen = TensorDynLen::from_dense(vec![i, j], data).unwrap();
3700    /// assert_eq!(tensor.dims(), vec![2, 3]);
3701    /// ```
3702    pub fn from_dense<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3703        let dims = Self::expected_dims_from_indices(&indices);
3704        Self::validate_indices(&indices)?;
3705        Self::validate_dense_payload_len(data.len(), &dims)?;
3706        let native = dense_native_tensor_from_col_major(&data, &dims)?;
3707        Self::from_native(indices, native)
3708    }
3709
3710    /// Create a tensor from dense payload data provided as [`AnyScalar`] values.
3711    ///
3712    /// This is the preferred public API when the caller only knows the scalar
3713    /// type at runtime.
3714    ///
3715    /// # Examples
3716    /// ```
3717    /// use tensor4all_core::{AnyScalar, TensorDynLen};
3718    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3719    ///
3720    /// let i = Index::new_dyn(2);
3721    /// let j = Index::new_dyn(2);
3722    /// let tensor = TensorDynLen::from_dense_any(
3723    ///     vec![i, j],
3724    ///     vec![
3725    ///         AnyScalar::new_real(1.0),
3726    ///         AnyScalar::new_complex(0.0, 1.0),
3727    ///         AnyScalar::new_real(2.0),
3728    ///         AnyScalar::new_real(3.0),
3729    ///     ],
3730    /// ).unwrap();
3731    ///
3732    /// assert!(tensor.is_complex());
3733    /// assert_eq!(tensor.dims(), vec![2, 2]);
3734    /// ```
3735    pub fn from_dense_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3736        if data.iter().any(AnyScalar::is_complex) {
3737            Self::from_dense(indices, Self::any_scalar_payload_to_complex(data))
3738        } else {
3739            Self::from_dense(indices, Self::any_scalar_payload_to_real(data))
3740        }
3741    }
3742
3743    /// Create a diagonal tensor from diagonal payload data with explicit indices.
3744    ///
3745    /// All indices must have the same dimension, and `data.len()` must equal
3746    /// that dimension. The resulting tensor has nonzero entries only on
3747    /// the multi-index diagonal (`T[i,i,...,i] = data[i]`).
3748    ///
3749    /// The returned tensor preserves compact diagonal payload metadata; use
3750    /// [`TensorDynLen::is_diag`] or [`TensorDynLen::storage`] to inspect that
3751    /// representation.
3752    ///
3753    /// # Examples
3754    ///
3755    /// ```
3756    /// use tensor4all_core::{DynIndex, TensorDynLen};
3757    ///
3758    /// let i = DynIndex::new_dyn(3);
3759    /// let j = DynIndex::new_dyn(3);
3760    /// let diag = TensorDynLen::from_diag(vec![i, j], vec![1.0, 2.0, 3.0]).unwrap();
3761    /// assert!(diag.is_diag());
3762    ///
3763    /// let data = diag.to_vec::<f64>().unwrap();
3764    /// // 3x3 identity-like: [1,0,0, 0,2,0, 0,0,3] in column-major
3765    /// assert!((data[0] - 1.0).abs() < 1e-12);
3766    /// assert!((data[4] - 2.0).abs() < 1e-12);
3767    /// assert!((data[8] - 3.0).abs() < 1e-12);
3768    /// assert!((data[1]).abs() < 1e-12);  // off-diagonal is zero
3769    /// ```
3770    pub fn from_diag<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3771        let dims = Self::expected_dims_from_indices(&indices);
3772        Self::validate_indices(&indices)?;
3773        Self::validate_diag_payload_len(data.len(), &dims)?;
3774        let native = diag_native_tensor_from_col_major(&data, dims.len())?;
3775        Self::from_native_with_axis_classes(indices, native, Self::diag_axis_classes(dims.len()))
3776    }
3777
3778    /// Create a diagonal tensor from diagonal payload data provided as
3779    /// [`AnyScalar`] values.
3780    ///
3781    /// This is the preferred public API when the caller only knows the scalar
3782    /// type at runtime.
3783    ///
3784    /// # Examples
3785    /// ```
3786    /// use tensor4all_core::{AnyScalar, TensorDynLen};
3787    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3788    ///
3789    /// let i = Index::new_dyn(2);
3790    /// let j = Index::new_dyn(2);
3791    /// let tensor = TensorDynLen::from_diag_any(
3792    ///     vec![i, j],
3793    ///     vec![AnyScalar::new_real(1.0), AnyScalar::new_complex(2.0, -1.0)],
3794    /// ).unwrap();
3795    ///
3796    /// assert!(tensor.is_complex());
3797    /// assert_eq!(tensor.dims(), vec![2, 2]);
3798    /// ```
3799    pub fn from_diag_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3800        if data.iter().any(AnyScalar::is_complex) {
3801            Self::from_diag(indices, Self::any_scalar_payload_to_complex(data))
3802        } else {
3803            Self::from_diag(indices, Self::any_scalar_payload_to_real(data))
3804        }
3805    }
3806
3807    /// Create a copy tensor whose nonzero entries are `value` on the diagonal.
3808    ///
3809    /// For indices `[i, j, k]`, the returned tensor satisfies
3810    /// `T[i, j, k] = value` when `i = j = k`, and zero otherwise.
3811    ///
3812    /// # Examples
3813    /// ```
3814    /// use tensor4all_core::{AnyScalar, TensorDynLen};
3815    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
3816    ///
3817    /// let i = Index::new_dyn(2);
3818    /// let j = Index::new_dyn(2);
3819    /// let k = Index::new_dyn(2);
3820    /// let tensor = TensorDynLen::copy_tensor(
3821    ///     vec![i, j, k],
3822    ///     AnyScalar::new_real(1.0),
3823    /// ).unwrap();
3824    ///
3825    /// assert_eq!(tensor.dims(), vec![2, 2, 2]);
3826    /// ```
3827    pub fn copy_tensor(indices: Vec<DynIndex>, value: AnyScalar) -> Result<Self> {
3828        if indices.is_empty() {
3829            return Self::from_dense_any(vec![], vec![value]);
3830        }
3831        let dim = indices[0].dim();
3832        let data = vec![value; dim];
3833        Self::from_diag_any(indices, data)
3834    }
3835
3836    /// Replace multiple tensor indices with one fused index using an exact local reshape.
3837    ///
3838    /// The indices in `old_indices` identify the axes to fuse by ID and also
3839    /// define the coordinate order used inside `new_index`. The new fused index
3840    /// is inserted at the earliest axis position among the fused axes; all
3841    /// other axes keep their original relative order. Use
3842    /// [`LinearizationOrder::ColumnMajor`] to match tensor4all's dense vector
3843    /// layout, or [`LinearizationOrder::RowMajor`] when interoperating with
3844    /// row-major fused coordinates.
3845    ///
3846    /// # Arguments
3847    /// * `old_indices` - Non-empty list of existing tensor indices to replace.
3848    ///   Each index is matched by ID, must appear exactly once in the tensor,
3849    ///   must have the same dimension as the matched tensor axis, and must not
3850    ///   be duplicated in this list.
3851    /// * `new_index` - Replacement index whose dimension must equal the product
3852    ///   of the dimensions in `old_indices`.
3853    /// * `order` - Linearization convention used to encode the old coordinates
3854    ///   into the single coordinate of `new_index`.
3855    ///
3856    /// # Returns
3857    /// A tensor with the same element type and values, but with `old_indices`
3858    /// replaced by `new_index`.
3859    ///
3860    /// # Errors
3861    /// Returns an error if `old_indices` is empty, contains duplicate IDs,
3862    /// references an index not present in the tensor, if the fused dimension
3863    /// does not match the product of the old dimensions, if the replacement
3864    /// would duplicate a kept index, or if the dense reshape cannot be
3865    /// represented without overflow.
3866    ///
3867    /// # Examples
3868    /// ```
3869    /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen};
3870    ///
3871    /// let i = DynIndex::new_dyn(2);
3872    /// let j = DynIndex::new_dyn(2);
3873    /// let fused = DynIndex::new_link(4).unwrap();
3874    /// let tensor = TensorDynLen::from_dense(
3875    ///     vec![i.clone(), j.clone()],
3876    ///     vec![1.0, 2.0, 3.0, 4.0],
3877    /// ).unwrap();
3878    ///
3879    /// let fused_tensor = tensor
3880    ///     .fuse_indices(&[i.clone(), j.clone()], fused.clone(), LinearizationOrder::ColumnMajor)
3881    ///     .unwrap();
3882    /// assert_eq!(fused_tensor.dims(), vec![4]);
3883    ///
3884    /// let roundtrip = fused_tensor
3885    ///     .unfuse_index(&fused, &[i, j], LinearizationOrder::ColumnMajor)
3886    ///     .unwrap();
3887    /// assert!(roundtrip.isapprox(&tensor, 1e-12, 0.0));
3888    /// ```
    pub fn fuse_indices(
        &self,
        old_indices: &[DynIndex],
        new_index: DynIndex,
        order: LinearizationOrder,
    ) -> Result<Self> {
        anyhow::ensure!(
            !old_indices.is_empty(),
            "fuse_indices requires at least one index to fuse"
        );

        // Resolve every requested index to its axis, rejecting duplicates,
        // indices missing from the tensor, and dimension mismatches.
        // `old_axes` keeps the caller's order, which is the coordinate order
        // used when encoding the fused index below.
        let old_dims = self.dims();
        let mut seen_indices = HashSet::new();
        let mut old_axes = Vec::with_capacity(old_indices.len());
        for old_index in old_indices {
            anyhow::ensure!(
                seen_indices.insert(old_index),
                "duplicate index in old_indices"
            );
            let axis = self
                .indices
                .iter()
                .position(|idx| idx == old_index)
                .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
            anyhow::ensure!(
                old_index.dim() == old_dims[axis],
                "old index dimension does not match tensor axis dimension"
            );
            old_axes.push(axis);
        }

        // The replacement index must have exactly as many values as the fused axes combined.
        let fused_dims: Vec<usize> = old_axes.iter().map(|&axis| old_dims[axis]).collect();
        let fused_product = checked_product(&fused_dims)?;
        anyhow::ensure!(
            fused_product == new_index.dim(),
            "product of old index dimensions must match the replacement index dimension"
        );

        // The fused index takes the position of the earliest fused axis. The
        // error branch is defensive: `old_axes` is non-empty after the check at the top.
        let insertion_axis =
            old_axes.iter().copied().min().ok_or_else(|| {
                anyhow::anyhow!("fuse_indices requires at least one index to fuse")
            })?;
        let old_axis_set: HashSet<usize> = old_axes.iter().copied().collect();

        // Result index list: kept indices in their original relative order,
        // with `new_index` inserted at `insertion_axis`.
        let mut result_indices =
            Vec::with_capacity(self.indices.len() - old_indices.len() + 1usize);
        for (axis, index) in self.indices.iter().enumerate() {
            if axis == insertion_axis {
                result_indices.push(new_index.clone());
            }
            if !old_axis_set.contains(&axis) {
                result_indices.push(index.clone());
            }
        }
        // `new_index` must not collide with any index that is kept.
        let mut result_seen = HashSet::new();
        for index in &result_indices {
            anyhow::ensure!(
                result_seen.insert(index),
                "fuse_indices result would contain duplicate index"
            );
        }
        Self::validate_indices(&result_indices)?;

        // Result shape, built with the same insertion rule as `result_indices`.
        let mut new_dims = Vec::with_capacity(old_dims.len() - old_indices.len() + 1usize);
        for (axis, dim) in old_dims.iter().copied().enumerate() {
            if axis == insertion_axis {
                new_dims.push(new_index.dim());
            }
            if !old_axis_set.contains(&axis) {
                new_dims.push(dim);
            }
        }

        // Move every element to its new position: decode the old linear
        // position into per-axis coordinates, fold the fused coordinates into
        // one coordinate using `order`, then re-encode against the new shape.
        // NOTE(review): assumes `to_vec_any` yields one value per element in
        // column-major order, matching `decode_col_major_linear` — confirm.
        let old_data = self.to_vec_any()?;
        let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
        for (old_linear, value) in old_data.into_iter().enumerate() {
            let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
            let fused_multi: Vec<usize> = old_axes.iter().map(|&axis| old_multi[axis]).collect();
            let fused_linear = encode_linear_with_order(&fused_multi, &fused_dims, order)?;

            let mut new_multi = Vec::with_capacity(new_dims.len());
            for (axis, old_coord) in old_multi.iter().copied().enumerate() {
                if axis == insertion_axis {
                    new_multi.push(fused_linear);
                }
                if !old_axis_set.contains(&axis) {
                    new_multi.push(old_coord);
                }
            }
            let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
            new_data[new_linear] = value;
        }

        // The result is rebuilt through the dense `AnyScalar` constructor.
        Self::from_dense_any(result_indices, new_data)
    }
3984
3985    /// Replace one fused index with multiple indices using an exact reshape.
3986    ///
3987    /// The caller must specify how the old fused index should be decoded into
3988    /// the new indices via `order`.
3989    ///
3990    /// # Examples
3991    /// ```
3992    /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen};
3993    ///
3994    /// let fused = DynIndex::new_dyn(4);
3995    /// let i = DynIndex::new_dyn(2);
3996    /// let j = DynIndex::new_dyn(2);
3997    /// let tensor = TensorDynLen::from_dense(vec![fused.clone()], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
3998    ///
3999    /// let unfused = tensor
4000    ///     .unfuse_index(&fused, &[i.clone(), j.clone()], LinearizationOrder::ColumnMajor)
4001    ///     .unwrap();
4002    ///
4003    /// let expected = TensorDynLen::from_dense(vec![i, j], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
4004    /// assert!(unfused.isapprox(&expected, 1e-12, 0.0));
4005    /// ```
4006    pub fn unfuse_index(
4007        &self,
4008        old_index: &DynIndex,
4009        new_indices: &[DynIndex],
4010        order: LinearizationOrder,
4011    ) -> Result<Self> {
4012        anyhow::ensure!(
4013            !new_indices.is_empty(),
4014            "unfuse_index requires at least one replacement index"
4015        );
4016
4017        let axis = self
4018            .indices
4019            .iter()
4020            .position(|idx| idx == old_index)
4021            .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
4022
4023        let replacement_dims: Vec<usize> = new_indices.iter().map(DynIndex::dim).collect();
4024        let replacement_product = checked_product(&replacement_dims)?;
4025        anyhow::ensure!(
4026            replacement_product == old_index.dim(),
4027            "product of new index dimensions must match the replaced index dimension"
4028        );
4029
4030        let mut result_indices =
4031            Vec::with_capacity(self.indices.len() - 1usize + new_indices.len());
4032        result_indices.extend_from_slice(&self.indices[..axis]);
4033        result_indices.extend(new_indices.iter().cloned());
4034        result_indices.extend_from_slice(&self.indices[axis + 1..]);
4035        Self::validate_indices(&result_indices)?;
4036
4037        let old_dims = self.dims();
4038        let mut new_dims = Vec::with_capacity(old_dims.len() - 1usize + replacement_dims.len());
4039        new_dims.extend_from_slice(&old_dims[..axis]);
4040        new_dims.extend_from_slice(&replacement_dims);
4041        new_dims.extend_from_slice(&old_dims[axis + 1..]);
4042
4043        let old_data = self.to_vec_any()?;
4044        let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
4045        for (old_linear, value) in old_data.into_iter().enumerate() {
4046            let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
4047            let split_multi = decode_linear_with_order(old_multi[axis], &replacement_dims, order)?;
4048            let mut new_multi = Vec::with_capacity(new_dims.len());
4049            new_multi.extend_from_slice(&old_multi[..axis]);
4050            new_multi.extend_from_slice(&split_multi);
4051            new_multi.extend_from_slice(&old_multi[axis + 1..]);
4052            let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
4053            new_data[new_linear] = value;
4054        }
4055
4056        Self::from_dense_any(result_indices, new_data)
4057    }
4058
4059    /// Create a scalar (0-dimensional) tensor from a supported element value.
4060    ///
4061    /// # Example
4062    /// ```
4063    /// use tensor4all_core::TensorDynLen;
4064    ///
4065    /// let scalar = TensorDynLen::scalar(42.0).unwrap();
4066    /// assert_eq!(scalar.dims(), Vec::<usize>::new());
4067    /// assert_eq!(scalar.only().unwrap().real(), 42.0);
4068    /// ```
4069    pub fn scalar<T: TensorElement>(value: T) -> Result<Self> {
4070        Self::from_dense(vec![], vec![value])
4071    }
4072
4073    /// Create a tensor filled with zeros of a supported element type.
4074    ///
4075    /// # Example
4076    /// ```
4077    /// use tensor4all_core::TensorDynLen;
4078    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
4079    ///
4080    /// let i = Index::new_dyn(2);
4081    /// let j = Index::new_dyn(3);
4082    /// let tensor = TensorDynLen::zeros::<f64>(vec![i, j]).unwrap();
4083    /// assert_eq!(tensor.dims(), vec![2, 3]);
4084    /// ```
4085    pub fn zeros<T: TensorElement + Zero + Clone>(indices: Vec<DynIndex>) -> Result<Self> {
4086        let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
4087        let size: usize = dims.iter().product();
4088        Self::from_dense(indices, vec![T::zero(); size])
4089    }
4090}
4091
4092// ============================================================================
4093// High-level API for data extraction (avoids direct .storage() access)
4094// ============================================================================
4095
4096impl TensorDynLen {
4097    /// Extract tensor data as a column-major `Vec<T>`.
4098    ///
4099    /// # Type Parameters
4100    /// * `T` - The scalar element type (`f64` or `Complex64`).
4101    ///
4102    /// # Returns
4103    /// A vector of the tensor data in column-major order.
4104    ///
4105    /// # Errors
4106    /// Returns an error if the tensor's scalar type does not match `T`.
4107    ///
4108    /// # Example
4109    /// ```
4110    /// use tensor4all_core::TensorDynLen;
4111    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
4112    ///
4113    /// let i = Index::new_dyn(2);
4114    /// let tensor = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
4115    /// let data = tensor.to_vec::<f64>().unwrap();
4116    /// assert_eq!(data, &[1.0, 2.0]);
4117    /// ```
4118    pub fn to_vec<T: TensorElement>(&self) -> Result<Vec<T>> {
4119        native_tensor_primal_to_dense_col_major(self.as_native()?)
4120    }
4121
4122    /// Consume the tensor and return its indices with dense column-major values.
4123    ///
4124    /// Use this when a caller needs to move index metadata and dense payload
4125    /// values across an API boundary. The returned values are ordered with the
4126    /// first tensor index varying fastest. Compact diagonal or structured
4127    /// storage is materialized into dense logical values.
4128    ///
4129    /// # Type Parameters
4130    /// * `T` - The scalar element type to extract. Use `f64` for real tensors
4131    ///   and `Complex64` for complex tensors.
4132    ///
4133    /// # Returns
4134    /// The tensor's original indices and dense column-major flat data.
4135    ///
4136    /// # Errors
4137    /// Returns an error if the tensor has tracked autodiff state, if the
4138    /// requested scalar type does not match the tensor payload, or if dense
4139    /// materialization fails.
4140    ///
4141    /// # Examples
4142    /// ```
4143    /// use tensor4all_core::{DynIndex, TensorDynLen};
4144    ///
4145    /// let i = DynIndex::new_dyn(2);
4146    /// let j = DynIndex::new_dyn(2);
4147    /// let tensor = TensorDynLen::from_dense(
4148    ///     vec![i.clone(), j.clone()],
4149    ///     vec![1.0_f64, 2.0, 3.0, 4.0],
4150    /// ).unwrap();
4151    ///
4152    /// let (indices, data) = tensor.into_dense_col_major_parts::<f64>().unwrap();
4153    ///
4154    /// assert_eq!(indices, vec![i, j]);
4155    /// assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
4156    /// ```
4157    pub fn into_dense_col_major_parts<T: TensorElement>(self) -> Result<(Vec<DynIndex>, Vec<T>)> {
4158        anyhow::ensure!(
4159            self.structured_ad.is_none() && !self.tracks_grad(),
4160            "TensorDynLen::into_dense_col_major_parts cannot consume tensors with tracked autodiff state"
4161        );
4162        let data = self.to_vec::<T>()?;
4163        Ok((self.indices, data))
4164    }
4165
4166    fn to_vec_any(&self) -> Result<Vec<AnyScalar>> {
4167        if self.is_complex() {
4168            self.to_vec::<Complex64>().map(|data| {
4169                data.into_iter()
4170                    .map(|value| AnyScalar::new_complex(value.re, value.im))
4171                    .collect()
4172            })
4173        } else {
4174            self.to_vec::<f64>()
4175                .map(|data| data.into_iter().map(AnyScalar::new_real).collect())
4176        }
4177    }
4178
4179    /// Extract tensor data as a column-major `Vec<f64>`.
4180    ///
4181    /// Prefer the generic [`to_vec::<f64>()`](Self::to_vec) method.
4182    /// This wrapper is kept for C API compatibility.
4183    pub fn as_slice_f64(&self) -> Result<Vec<f64>> {
4184        self.to_vec::<f64>()
4185    }
4186
4187    /// Extract tensor data as a column-major `Vec<Complex64>`.
4188    ///
4189    /// Prefer the generic [`to_vec::<Complex64>()`](Self::to_vec) method.
4190    /// This wrapper is kept for C API compatibility.
4191    pub fn as_slice_c64(&self) -> Result<Vec<Complex64>> {
4192        self.to_vec::<Complex64>()
4193    }
4194
4195    /// Check if the tensor has f64 storage.
4196    ///
4197    /// # Example
4198    /// ```
4199    /// use tensor4all_core::TensorDynLen;
4200    /// use tensor4all_core::index::{DefaultIndex as Index, DynId};
4201    ///
4202    /// let i = Index::new_dyn(2);
4203    /// let tensor = TensorDynLen::from_dense(vec![i], vec![1.0, 2.0]).unwrap();
4204    /// assert!(tensor.is_f64());
4205    /// assert!(!tensor.is_complex());
4206    /// ```
4207    pub fn is_f64(&self) -> bool {
4208        self.storage.is_f64()
4209    }
4210
4211    /// Check whether the tensor carries diagonal logical axis metadata.
4212    ///
4213    /// # Examples
4214    ///
4215    /// ```
4216    /// use tensor4all_core::{DynIndex, TensorDynLen};
4217    /// use tensor4all_tensorbackend::Storage;
4218    ///
4219    /// // Tensors from `from_dense` use dense storage
4220    /// let i = DynIndex::new_dyn(2);
4221    /// let j = DynIndex::new_dyn(2);
4222    /// let dense = TensorDynLen::from_dense(vec![i, j], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
4223    /// assert!(!dense.is_diag());
4224    ///
4225    /// // Diagonal metadata is preserved when constructing from diagonal storage.
4226    /// let k = DynIndex::new_dyn(2);
4227    /// let l = DynIndex::new_dyn(2);
4228    /// let diag = TensorDynLen::from_storage(
4229    ///     vec![k, l],
4230    ///     Storage::from_diag_col_major(vec![1.0, 2.0], 2)
4231    ///         .map(std::sync::Arc::new)
4232    ///         .unwrap(),
4233    /// )
4234    /// .unwrap();
4235    /// assert!(diag.is_diag());
4236    /// ```
4237    pub fn is_diag(&self) -> bool {
4238        self.storage.is_diag()
4239    }
4240
4241    /// Check if the tensor has complex storage (C64).
4242    ///
4243    /// # Examples
4244    ///
4245    /// ```
4246    /// use tensor4all_core::{DynIndex, TensorDynLen};
4247    /// use num_complex::Complex64;
4248    ///
4249    /// let i = DynIndex::new_dyn(2);
4250    /// let real_t = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap();
4251    /// assert!(!real_t.is_complex());
4252    ///
4253    /// let complex_t = TensorDynLen::from_dense(
4254    ///     vec![i],
4255    ///     vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
4256    /// ).unwrap();
4257    /// assert!(complex_t.is_complex());
4258    /// ```
4259    pub fn is_complex(&self) -> bool {
4260        self.storage.is_complex()
4261    }
4262}
4263
4264fn checked_product(dims: &[usize]) -> Result<usize> {
4265    dims.iter().try_fold(1usize, |acc, &dim| {
4266        acc.checked_mul(dim)
4267            .ok_or_else(|| anyhow::anyhow!("dimension product overflow"))
4268    })
4269}
4270
4271fn decode_col_major_linear(linear: usize, dims: &[usize]) -> Result<Vec<usize>> {
4272    let total = checked_product(dims)?;
4273    anyhow::ensure!(
4274        linear < total,
4275        "linear offset {} out of bounds for dims {:?}",
4276        linear,
4277        dims
4278    );
4279    let mut remaining = linear;
4280    let mut out = Vec::with_capacity(dims.len());
4281    for &dim in dims {
4282        out.push(remaining % dim);
4283        remaining /= dim;
4284    }
4285    Ok(out)
4286}
4287
4288fn encode_col_major_linear(indices: &[usize], dims: &[usize]) -> Result<usize> {
4289    anyhow::ensure!(
4290        indices.len() == dims.len(),
4291        "index rank {} does not match dims {:?}",
4292        indices.len(),
4293        dims
4294    );
4295    let mut linear = 0usize;
4296    let mut stride = 1usize;
4297    for (&index, &dim) in indices.iter().zip(dims.iter()) {
4298        anyhow::ensure!(
4299            index < dim,
4300            "index {} out of bounds for dimension {}",
4301            index,
4302            dim
4303        );
4304        linear += index * stride;
4305        stride = stride
4306            .checked_mul(dim)
4307            .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4308    }
4309    Ok(linear)
4310}
4311
4312fn decode_linear_with_order(
4313    linear: usize,
4314    dims: &[usize],
4315    order: LinearizationOrder,
4316) -> Result<Vec<usize>> {
4317    let total = checked_product(dims)?;
4318    anyhow::ensure!(
4319        linear < total,
4320        "linear offset {} out of bounds for dims {:?}",
4321        linear,
4322        dims
4323    );
4324
4325    let mut remaining = linear;
4326    let mut out = vec![0usize; dims.len()];
4327    match order {
4328        LinearizationOrder::ColumnMajor => {
4329            for (slot, &dim) in out.iter_mut().zip(dims.iter()) {
4330                *slot = remaining % dim;
4331                remaining /= dim;
4332            }
4333        }
4334        LinearizationOrder::RowMajor => {
4335            for (slot, &dim) in out.iter_mut().rev().zip(dims.iter().rev()) {
4336                *slot = remaining % dim;
4337                remaining /= dim;
4338            }
4339        }
4340    }
4341    Ok(out)
4342}
4343
4344fn encode_linear_with_order(
4345    indices: &[usize],
4346    dims: &[usize],
4347    order: LinearizationOrder,
4348) -> Result<usize> {
4349    match order {
4350        LinearizationOrder::ColumnMajor => encode_col_major_linear(indices, dims),
4351        LinearizationOrder::RowMajor => {
4352            anyhow::ensure!(
4353                indices.len() == dims.len(),
4354                "index rank {} does not match dims {:?}",
4355                indices.len(),
4356                dims
4357            );
4358            let mut linear = 0usize;
4359            let mut stride = 1usize;
4360            for (&index, &dim) in indices.iter().rev().zip(dims.iter().rev()) {
4361                anyhow::ensure!(
4362                    index < dim,
4363                    "index {} out of bounds for dimension {}",
4364                    index,
4365                    dim
4366                );
4367                linear += index * stride;
4368                stride = stride
4369                    .checked_mul(dim)
4370                    .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4371            }
4372            Ok(linear)
4373        }
4374    }
4375}