Skip to main content

tensor4all_tensorbackend/
storage.rs

1use anyhow::{anyhow, ensure, Result};
2use num_complex::{Complex64, ComplexFloat};
3use std::ops::Mul;
4use std::sync::Arc;
5
/// Trait for scalar types that can be stored in [`Storage`].
///
/// This enables generic constructors such as [`Storage::from_dense_col_major`]
/// and [`Storage::from_diag_col_major`]: the element type of the input `Vec`
/// selects which internal storage variant is built. Implemented for `f64` and
/// `Complex64`.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::{Storage, StorageScalar};
///
/// // Using the generic constructor -- scalar type is inferred from data
/// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
/// assert!(s.is_f64());
///
/// use num_complex::Complex64;
/// let c = Storage::from_dense_col_major(
///     vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
///     &[2],
/// ).unwrap();
/// assert!(c.is_c64());
/// ```
pub trait StorageScalar: Clone + Send + Sync + 'static {
    /// Build a dense [`Storage`] from column-major data.
    ///
    /// `data.len()` is expected to equal the product of `logical_dims`.
    fn build_dense_storage(data: Vec<Self>, logical_dims: &[usize]) -> Result<Storage>;
    /// Build a diagonal [`Storage`] from diagonal payload data.
    ///
    /// For `logical_rank >= 1` every logical axis shares one payload axis of
    /// length `diag_data.len()`; for `logical_rank == 0` the payload is a
    /// single scalar.
    fn build_diag_storage(diag_data: Vec<Self>, logical_rank: usize) -> Result<Storage>;
    /// Build a structured [`Storage`] from explicit payload metadata.
    ///
    /// The arguments have the same meaning as in [`StructuredStorage::new`].
    fn build_structured_storage(
        data: Vec<Self>,
        payload_dims: Vec<usize>,
        strides: Vec<isize>,
        axis_classes: Vec<usize>,
    ) -> Result<Storage>;
}
40
41impl StorageScalar for f64 {
42    fn build_dense_storage(data: Vec<Self>, logical_dims: &[usize]) -> Result<Storage> {
43        Storage::validate_dense_len(&data, logical_dims, "dense f64 payload")?;
44        Ok(Storage::from_repr(StorageRepr::F64(
45            StructuredStorage::from_dense_col_major(data, logical_dims)?,
46        )))
47    }
48    fn build_diag_storage(diag_data: Vec<Self>, logical_rank: usize) -> Result<Storage> {
49        Ok(Storage::from_repr(StorageRepr::F64(
50            StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
51        )))
52    }
53    fn build_structured_storage(
54        data: Vec<Self>,
55        payload_dims: Vec<usize>,
56        strides: Vec<isize>,
57        axis_classes: Vec<usize>,
58    ) -> Result<Storage> {
59        Ok(Storage::from_repr(StorageRepr::F64(
60            StructuredStorage::new(data, payload_dims, strides, axis_classes)?,
61        )))
62    }
63}
64
65impl StorageScalar for Complex64 {
66    fn build_dense_storage(data: Vec<Self>, logical_dims: &[usize]) -> Result<Storage> {
67        Storage::validate_dense_len(&data, logical_dims, "dense c64 payload")?;
68        Ok(Storage::from_repr(StorageRepr::C64(
69            StructuredStorage::from_dense_col_major(data, logical_dims)?,
70        )))
71    }
72    fn build_diag_storage(diag_data: Vec<Self>, logical_rank: usize) -> Result<Storage> {
73        Ok(Storage::from_repr(StorageRepr::C64(
74            StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
75        )))
76    }
77    fn build_structured_storage(
78        data: Vec<Self>,
79        payload_dims: Vec<usize>,
80        strides: Vec<isize>,
81        axis_classes: Vec<usize>,
82    ) -> Result<Storage> {
83        Ok(Storage::from_repr(StorageRepr::C64(
84            StructuredStorage::new(data, payload_dims, strides, axis_classes)?,
85        )))
86    }
87}
88
89pub(crate) fn col_major_strides(dims: &[usize]) -> Result<Vec<isize>> {
90    let mut strides = Vec::with_capacity(dims.len());
91    let mut stride = 1isize;
92    for &dim in dims {
93        strides.push(stride);
94        let dim = isize::try_from(dim)
95            .map_err(|_| anyhow!("column-major stride overflow for dims {dims:?}"))?;
96        stride = stride
97            .checked_mul(dim)
98            .ok_or_else(|| anyhow!("column-major stride overflow for dims {dims:?}"))?;
99    }
100    Ok(strides)
101}
102
103fn validate_canonical_axis_classes(axis_classes: &[usize]) -> Result<()> {
104    let mut next_class = 0usize;
105    for &class_id in axis_classes {
106        ensure!(
107            class_id <= next_class,
108            "axis_classes must be canonical first-appearance labels, got {axis_classes:?}"
109        );
110        if class_id == next_class {
111            next_class += 1;
112        }
113    }
114    Ok(())
115}
116
117fn required_storage_len(dims: &[usize], strides: &[isize]) -> Result<usize> {
118    if dims.is_empty() {
119        return Ok(1);
120    }
121    if dims.contains(&0) {
122        return Ok(0);
123    }
124    ensure!(
125        dims.len() == strides.len(),
126        "payload dims {:?} and strides {:?} must have the same rank",
127        dims,
128        strides
129    );
130
131    let mut max_offset = 0usize;
132    for (&dim, &stride) in dims.iter().zip(strides.iter()) {
133        ensure!(
134            stride >= 0,
135            "negative strides are not supported in StructuredStorage: {strides:?}"
136        );
137        if dim > 1 {
138            max_offset = max_offset
139                .checked_add((dim - 1) * usize::try_from(stride).unwrap_or(usize::MAX))
140                .ok_or_else(|| anyhow!("payload stride overflow for dims {dims:?}"))?;
141        }
142    }
143    Ok(max_offset + 1)
144}
145
/// Expands payload dimensions into logical dimensions.
///
/// Logical axis `k` has the extent of payload axis `axis_classes[k]`.
/// Panics if a class id is out of range for `payload_dims`.
fn logical_dims_from_axis_classes(payload_dims: &[usize], axis_classes: &[usize]) -> Vec<usize> {
    let mut logical = Vec::with_capacity(axis_classes.len());
    for &class_id in axis_classes {
        logical.push(payload_dims[class_id]);
    }
    logical
}
152
/// Decomposes a column-major linear index into a multi-index over `dims`.
///
/// The first axis varies fastest. Zero-sized axes yield index `0` and do not
/// consume any of `linear`.
fn col_major_multi_index(mut linear: usize, dims: &[usize]) -> Vec<usize> {
    dims.iter()
        .map(|&extent| {
            if extent == 0 {
                return 0;
            }
            let digit = linear % extent;
            linear /= extent;
            digit
        })
        .collect()
}
165
/// Computes the buffer offset of `index` under `strides`.
///
/// Callers are expected to pass non-negative strides (`StructuredStorage::new`
/// rejects negative ones); a negative stride would be treated as `usize::MAX`.
fn offset_from_strides(index: &[usize], strides: &[isize]) -> usize {
    let mut offset = 0usize;
    for (position, &stride) in index.iter().zip(strides) {
        let step = usize::try_from(stride).unwrap_or(usize::MAX);
        offset += position * step;
    }
    offset
}
173
/// Structured tensor snapshot storage.
///
/// `data` and `strides` describe the payload tensor, while `axis_classes`
/// describes how logical axes map onto payload axes. Logical flat-buffer
/// semantics are column-major.
///
/// A **dense** tensor has `axis_classes = [0, 1, ..., rank-1]` (each logical
/// axis maps to a distinct payload axis). A **diagonal** tensor has
/// `axis_classes = [0, 0, ..., 0]` (all logical axes share one payload axis),
/// storing only the diagonal entries.
///
/// The derived `PartialEq` compares payload buffers and metadata field by
/// field. Two storages with the same logical contents but different layouts
/// (e.g. different strides) therefore compare unequal.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::StructuredStorage;
///
/// // Dense 2x3 storage, column-major: [[1,3,5],[2,4,6]]
/// let dense = StructuredStorage::from_dense_col_major(
///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3],
/// ).unwrap();
/// assert!(dense.is_dense());
/// assert!(!dense.is_diag());
/// assert_eq!(dense.logical_rank(), 2);
/// assert_eq!(dense.logical_dims(), vec![2, 3]);
///
/// // Diagonal 3x3 storage
/// let diag = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 2).unwrap();
/// assert!(diag.is_diag());
/// assert_eq!(diag.logical_dims(), vec![3, 3]);
/// assert_eq!(diag.len(), 3);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct StructuredStorage<T> {
    /// Payload buffer; element positions are given by `strides`.
    data: Vec<T>,
    /// Extent of each payload axis.
    payload_dims: Vec<usize>,
    /// Buffer stride of each payload axis (same rank as `payload_dims`).
    strides: Vec<isize>,
    /// For each logical axis, the payload axis it reads from, in canonical
    /// first-appearance labelling.
    axis_classes: Vec<usize>,
}
212
213impl<T> StructuredStorage<T> {
214    /// Creates a structured payload snapshot from explicit payload metadata.
215    ///
216    /// `payload_dims` and `strides` describe the compressed payload tensor,
217    /// while `axis_classes` maps logical axes onto payload axes in canonical
218    /// first-appearance order.
219    ///
220    /// # Errors
221    ///
222    /// Returns an error if:
223    /// - `axis_classes` is not in canonical first-appearance form
224    /// - `payload_dims` rank does not match `axis_classes`
225    /// - `strides` rank does not match `payload_dims`
226    /// - `data` length does not match the required storage length
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use tensor4all_tensorbackend::StructuredStorage;
232    ///
233    /// // Dense 2x3 with explicit column-major strides
234    /// let s = StructuredStorage::new(
235    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
236    ///     vec![2, 3],     // payload_dims
237    ///     vec![1, 2],     // column-major strides
238    ///     vec![0, 1],     // axis_classes: each axis is independent
239    /// ).unwrap();
240    /// assert!(s.is_dense());
241    /// assert_eq!(s.len(), 6);
242    /// ```
243    pub fn new(
244        data: Vec<T>,
245        payload_dims: Vec<usize>,
246        strides: Vec<isize>,
247        axis_classes: Vec<usize>,
248    ) -> Result<Self> {
249        validate_canonical_axis_classes(&axis_classes)?;
250        ensure!(
251            payload_dims.len()
252                == axis_classes
253                    .iter()
254                    .copied()
255                    .max()
256                    .map(|value| value + 1)
257                    .unwrap_or(0),
258            "payload rank {} does not match axis_classes {:?}",
259            payload_dims.len(),
260            axis_classes
261        );
262        ensure!(
263            strides.len() == payload_dims.len(),
264            "payload dims {:?} and strides {:?} must have the same rank",
265            payload_dims,
266            strides
267        );
268        let required_len = required_storage_len(&payload_dims, &strides)?;
269        ensure!(
270            data.len() == required_len,
271            "payload storage len {} does not match required len {} for dims {:?} and strides {:?}",
272            data.len(),
273            required_len,
274            payload_dims,
275            strides
276        );
277        Ok(Self {
278            data,
279            payload_dims,
280            strides,
281            axis_classes,
282        })
283    }
284
285    /// Creates a dense structured snapshot from column-major logical data.
286    ///
287    /// # Errors
288    ///
289    /// Returns an error if `data.len()` does not equal the product of
290    /// `logical_dims`, or if column-major stride computation overflows.
291    ///
292    /// # Examples
293    ///
294    /// ```
295    /// use tensor4all_tensorbackend::StructuredStorage;
296    ///
297    /// let s = StructuredStorage::from_dense_col_major(vec![10.0, 20.0, 30.0, 40.0], &[2, 2]).unwrap();
298    /// assert!(s.is_dense());
299    /// assert_eq!(s.data(), &[10.0, 20.0, 30.0, 40.0]);
300    /// ```
301    pub fn from_dense_col_major(data: Vec<T>, logical_dims: &[usize]) -> Result<Self> {
302        let payload_dims = logical_dims.to_vec();
303        let strides = col_major_strides(&payload_dims)?;
304        let axis_classes = (0..logical_dims.len()).collect();
305        Self::new(data, payload_dims, strides, axis_classes)
306    }
307
308    /// Creates a diagonal structured snapshot from column-major diagonal data.
309    ///
310    /// The resulting tensor has `logical_rank` axes, each of size `diag_data.len()`.
311    /// Only the diagonal entries are stored.
312    ///
313    /// # Errors
314    ///
315    /// Returns an error if `logical_rank` is zero and the data does not contain
316    /// exactly one scalar value, or if column-major stride computation overflows.
317    ///
318    /// # Examples
319    ///
320    /// ```
321    /// use tensor4all_tensorbackend::StructuredStorage;
322    ///
323    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 2).unwrap();
324    /// assert!(d.is_diag());
325    /// assert_eq!(d.logical_dims(), vec![3, 3]);
326    /// assert_eq!(d.data(), &[1.0, 2.0, 3.0]);
327    /// ```
328    pub fn from_diag_col_major(diag_data: Vec<T>, logical_rank: usize) -> Result<Self> {
329        let payload_dims = if logical_rank == 0 {
330            vec![]
331        } else {
332            vec![diag_data.len()]
333        };
334        let strides = col_major_strides(&payload_dims)?;
335        let axis_classes = vec![0; logical_rank];
336        Self::new(diag_data, payload_dims, strides, axis_classes)
337    }
338
339    /// Returns the payload data buffer as a slice.
340    ///
341    /// # Examples
342    ///
343    /// ```
344    /// use tensor4all_tensorbackend::StructuredStorage;
345    ///
346    /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0], &[2]).unwrap();
347    /// assert_eq!(s.data(), &[1.0, 2.0]);
348    /// ```
349    pub fn data(&self) -> &[T] {
350        &self.data
351    }
352
353    /// Returns the payload tensor dimensions.
354    ///
355    /// For dense tensors, this equals the logical dimensions. For diagonal
356    /// tensors, this is a single-element slice `[n]` where `n` is the diagonal
357    /// length.
358    ///
359    /// # Examples
360    ///
361    /// ```
362    /// use tensor4all_tensorbackend::StructuredStorage;
363    ///
364    /// let s = StructuredStorage::from_dense_col_major(vec![0.0; 6], &[2, 3]).unwrap();
365    /// assert_eq!(s.payload_dims(), &[2, 3]);
366    ///
367    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 3).unwrap();
368    /// assert_eq!(d.payload_dims(), &[2]);
369    /// ```
370    pub fn payload_dims(&self) -> &[usize] {
371        &self.payload_dims
372    }
373
374    /// Returns the payload tensor strides.
375    ///
376    /// # Examples
377    ///
378    /// ```
379    /// use tensor4all_tensorbackend::StructuredStorage;
380    ///
381    /// // Column-major 2x3: strides are [1, 2]
382    /// let s = StructuredStorage::from_dense_col_major(vec![0.0; 6], &[2, 3]).unwrap();
383    /// assert_eq!(s.strides(), &[1, 2]);
384    /// ```
385    pub fn strides(&self) -> &[isize] {
386        &self.strides
387    }
388
389    /// Returns the canonical logical-to-payload axis classes.
390    ///
391    /// Each entry maps a logical axis to a payload axis index. Repeated values
392    /// indicate axes that share the same payload dimension (e.g., diagonal).
393    ///
394    /// # Examples
395    ///
396    /// ```
397    /// use tensor4all_tensorbackend::StructuredStorage;
398    ///
399    /// let dense = StructuredStorage::from_dense_col_major(vec![0.0; 4], &[2, 2]).unwrap();
400    /// assert_eq!(dense.axis_classes(), &[0, 1]);
401    ///
402    /// let diag = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
403    /// assert_eq!(diag.axis_classes(), &[0, 0]);
404    /// ```
405    pub fn axis_classes(&self) -> &[usize] {
406        &self.axis_classes
407    }
408
409    /// Returns the logical dimensions derived from `payload_dims` and `axis_classes`.
410    ///
411    /// # Examples
412    ///
413    /// ```
414    /// use tensor4all_tensorbackend::StructuredStorage;
415    ///
416    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 3).unwrap();
417    /// assert_eq!(d.logical_dims(), vec![3, 3, 3]);
418    /// ```
419    pub fn logical_dims(&self) -> Vec<usize> {
420        logical_dims_from_axis_classes(&self.payload_dims, &self.axis_classes)
421    }
422
423    /// Returns the logical rank (number of logical axes).
424    ///
425    /// # Examples
426    ///
427    /// ```
428    /// use tensor4all_tensorbackend::StructuredStorage;
429    ///
430    /// let s = StructuredStorage::from_dense_col_major(vec![0.0; 6], &[2, 3]).unwrap();
431    /// assert_eq!(s.logical_rank(), 2);
432    /// ```
433    pub fn logical_rank(&self) -> usize {
434        self.axis_classes.len()
435    }
436
437    /// Returns `true` when the logical tensor is dense (each logical axis maps
438    /// to a unique payload axis).
439    ///
440    /// # Examples
441    ///
442    /// ```
443    /// use tensor4all_tensorbackend::StructuredStorage;
444    ///
445    /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0], &[2]).unwrap();
446    /// assert!(s.is_dense());
447    ///
448    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
449    /// assert!(!d.is_dense());
450    /// ```
451    pub fn is_dense(&self) -> bool {
452        self.axis_classes
453            .iter()
454            .copied()
455            .eq(0..self.axis_classes.len())
456    }
457
458    /// Returns `true` when the logical tensor is diagonal (rank >= 2 and all
459    /// logical axes map to the same payload axis).
460    ///
461    /// # Examples
462    ///
463    /// ```
464    /// use tensor4all_tensorbackend::StructuredStorage;
465    ///
466    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
467    /// assert!(d.is_diag());
468    ///
469    /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0], &[2]).unwrap();
470    /// assert!(!s.is_diag());
471    /// ```
472    pub fn is_diag(&self) -> bool {
473        self.logical_rank() >= 2 && self.axis_classes.iter().all(|&class_id| class_id == 0)
474    }
475
476    /// Returns the payload buffer length.
477    ///
478    /// # Examples
479    ///
480    /// ```
481    /// use tensor4all_tensorbackend::StructuredStorage;
482    ///
483    /// let dense = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
484    /// assert_eq!(dense.len(), 3);
485    ///
486    /// let diag = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
487    /// assert_eq!(diag.len(), 2);
488    /// ```
489    pub fn len(&self) -> usize {
490        self.data.len()
491    }
492
493    /// Returns `true` when the payload buffer is empty.
494    ///
495    /// # Examples
496    ///
497    /// ```
498    /// use tensor4all_tensorbackend::StructuredStorage;
499    ///
500    /// let empty = StructuredStorage::from_dense_col_major(Vec::<f64>::new(), &[0]).unwrap();
501    /// assert!(empty.is_empty());
502    ///
503    /// let non_empty = StructuredStorage::from_dense_col_major(vec![1.0], &[1]).unwrap();
504    /// assert!(!non_empty.is_empty());
505    /// ```
506    pub fn is_empty(&self) -> bool {
507        self.data.is_empty()
508    }
509
510    /// Returns a borrowed view when the logical tensor is dense and the
511    /// payload is already stored contiguously in column-major order.
512    ///
513    /// Returns `None` for diagonal or non-contiguous payloads.
514    ///
515    /// # Examples
516    ///
517    /// ```
518    /// use tensor4all_tensorbackend::StructuredStorage;
519    ///
520    /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
521    /// assert_eq!(s.dense_col_major_view_if_contiguous(), Some(&[1.0, 2.0, 3.0][..]));
522    ///
523    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
524    /// assert_eq!(d.dense_col_major_view_if_contiguous(), None);
525    /// ```
526    pub fn dense_col_major_view_if_contiguous(&self) -> Option<&[T]> {
527        if self.is_dense()
528            && matches!(col_major_strides(&self.payload_dims), Ok(strides) if strides == self.strides)
529        {
530            Some(&self.data)
531        } else {
532            None
533        }
534    }
535
536    /// Returns a borrowed compact-payload view when the payload is already
537    /// stored contiguously in column-major order.
538    pub fn payload_col_major_view_if_contiguous(&self) -> Option<&[T]> {
539        if matches!(col_major_strides(&self.payload_dims), Ok(strides) if strides == self.strides) {
540            Some(&self.data)
541        } else {
542            None
543        }
544    }
545}
546
547impl<T: Clone> StructuredStorage<T> {
548    /// Materializes the payload tensor as a contiguous column-major buffer.
549    ///
550    /// If the payload is already column-major, returns a clone.
551    ///
552    /// # Examples
553    ///
554    /// ```
555    /// use tensor4all_tensorbackend::StructuredStorage;
556    ///
557    /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
558    /// assert_eq!(s.payload_col_major_vec(), vec![1.0, 2.0, 3.0, 4.0]);
559    /// ```
560    pub fn payload_col_major_vec(&self) -> Vec<T> {
561        let payload_len: usize = self.payload_dims.iter().product();
562        if payload_len == 0 {
563            return Vec::new();
564        }
565        if matches!(col_major_strides(&self.payload_dims), Ok(strides) if strides == self.strides) {
566            return self.data.clone();
567        }
568
569        (0..payload_len)
570            .map(|linear| {
571                let index = col_major_multi_index(linear, &self.payload_dims);
572                let offset = offset_from_strides(&index, &self.strides);
573                self.data[offset].clone()
574            })
575            .collect()
576    }
577
578    /// Returns a copy of the storage with logical axes permuted.
579    ///
580    /// # Errors
581    ///
582    /// Returns an error if `perm` is not a valid permutation of the logical axes.
583    ///
584    /// # Examples
585    ///
586    /// ```
587    /// use tensor4all_tensorbackend::StructuredStorage;
588    ///
589    /// // Diagonal 3x3x3 tensor; permute axes (identity for diag is always valid)
590    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 3).unwrap();
591    /// let p = d.permute_logical_axes(&[2, 0, 1]).unwrap();
592    /// // Diagonal: all axes share the same dimension, so dims stay the same
593    /// assert_eq!(p.logical_dims(), vec![3, 3, 3]);
594    /// assert!(p.is_diag());
595    /// ```
596    pub fn permute_logical_axes(&self, perm: &[usize]) -> Result<Self> {
597        ensure!(
598            perm.len() == self.axis_classes.len(),
599            "logical permutation length {} must match logical rank {}",
600            perm.len(),
601            self.axis_classes.len()
602        );
603        let mut seen = vec![false; self.axis_classes.len()];
604        let axis_classes = perm
605            .iter()
606            .map(|&index| {
607                ensure!(
608                    index < self.axis_classes.len(),
609                    "logical permutation axis {index} is out of range for rank {}",
610                    self.axis_classes.len()
611                );
612                ensure!(!seen[index], "logical permutation repeats axis {index}");
613                seen[index] = true;
614                Ok(self.axis_classes[index])
615            })
616            .collect::<Result<Vec<_>>>()?;
617        Self::new(
618            self.data.clone(),
619            self.payload_dims.clone(),
620            self.strides.clone(),
621            axis_classes,
622        )
623    }
624}
625
626impl<T: Copy> StructuredStorage<T> {
627    /// Maps payload elements while preserving payload metadata and axis classes.
628    ///
629    /// # Examples
630    ///
631    /// ```
632    /// use tensor4all_tensorbackend::StructuredStorage;
633    ///
634    /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
635    /// let doubled = s.map_copy(|x| x * 2.0);
636    /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
637    /// ```
638    pub fn map_copy<U>(&self, mut f: impl FnMut(T) -> U) -> StructuredStorage<U> {
639        StructuredStorage {
640            data: self.data.iter().copied().map(&mut f).collect(),
641            payload_dims: self.payload_dims.clone(),
642            strides: self.strides.clone(),
643            axis_classes: self.axis_classes.clone(),
644        }
645    }
646}
647
648impl<T: Copy + Default> StructuredStorage<T> {
649    /// Materializes the logical tensor as a contiguous column-major dense buffer.
650    ///
651    /// Repeated entries in `axis_classes` encode equality constraints between
652    /// logical axes. Logical indices that violate those constraints are
653    /// structural zeros in the dense materialization.
654    ///
655    /// # Examples
656    ///
657    /// ```
658    /// use tensor4all_tensorbackend::StructuredStorage;
659    ///
660    /// // Diagonal [1, 2] in 2x2 becomes [1, 0, 0, 2] column-major
661    /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
662    /// assert_eq!(d.logical_dense_col_major_vec(), vec![1.0, 0.0, 0.0, 2.0]);
663    /// ```
664    pub fn logical_dense_col_major_vec(&self) -> Vec<T> {
665        let logical_dims = self.logical_dims();
666        let logical_len: usize = logical_dims.iter().product();
667        if logical_len == 0 {
668            return Vec::new();
669        }
670        if let Some(view) = self.dense_col_major_view_if_contiguous() {
671            return view.to_vec();
672        }
673        if self.is_dense() {
674            return self.payload_col_major_vec();
675        }
676
677        let payload_rank = self.payload_dims.len();
678        (0..logical_len)
679            .map(|linear| {
680                let logical_index = col_major_multi_index(linear, &logical_dims);
681                let mut payload_index = vec![0usize; payload_rank];
682                let mut seen = vec![false; payload_rank];
683                for (&value, &class_id) in logical_index.iter().zip(self.axis_classes.iter()) {
684                    if seen[class_id] {
685                        if payload_index[class_id] != value {
686                            return T::default();
687                        }
688                    } else {
689                        payload_index[class_id] = value;
690                        seen[class_id] = true;
691                    }
692                }
693                let offset = offset_from_strides(&payload_index, &self.strides);
694                self.data[offset]
695            })
696            .collect()
697    }
698}
699
/// Storage backend for tensor data.
///
/// Public callers interact with this opaque wrapper through constructors and
/// high-level query/materialization methods. The wrapped representation is
/// crate-private and cannot be matched on outside this crate.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::Storage;
///
/// // Dense 2x3 matrix stored column-major: [[1,2,3],[4,5,6]]
/// let data = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
/// let s = Storage::from_dense_col_major(data, &[2, 3]).unwrap();
/// assert!(s.is_f64());
/// assert!(!s.is_complex());
///
/// // Diagonal storage built from the diagonal entries [1, 2]
/// let diag = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
/// assert!(diag.is_f64());
/// ```
#[derive(Debug, Clone)]
pub struct Storage(pub(crate) StorageRepr);
722
/// Classifies the compact layout used by [`Storage`].
///
/// Use this to distinguish dense logical payloads from diagonal/copy payloads
/// and general structured payloads without exposing the internal storage enum.
/// Obtain it with `Storage::storage_kind`.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::{Storage, StorageKind};
///
/// let dense = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
/// assert_eq!(dense.storage_kind(), StorageKind::Dense);
///
/// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
/// assert_eq!(diag.storage_kind(), StorageKind::Diagonal);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Logical dense payload layout (one payload axis per logical axis).
    Dense,
    /// Diagonal or copy-tensor payload layout (all logical axes share one
    /// payload axis).
    Diagonal,
    /// General structured payload layout with repeated axis classes.
    Structured,
}
748
/// Errors returned by storage payload and elementwise operations.
///
/// Use this to distinguish scalar-kind mismatches, length mismatches, and
/// invalid structured-storage metadata from general backend failures.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::{Storage, StorageError};
///
/// let storage = Storage::from_dense_col_major(vec![1.0_f64], &[1]).unwrap();
/// let err = storage.payload_c64_col_major_vec().unwrap_err();
/// assert!(matches!(err, StorageError::ScalarKindMismatch { .. }));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
    /// The storage scalar kind did not match the requested operation.
    #[error("expected {expected} storage when {operation}, got {actual}")]
    ScalarKindMismatch {
        /// The scalar kind that the caller requested.
        expected: &'static str,
        /// The scalar kind actually stored.
        actual: &'static str,
        /// Human-readable operation description, phrased to follow the word
        /// "when" in the rendered message.
        operation: &'static str,
    },
    /// Two storages had different payload lengths for an elementwise operation.
    #[error("storage lengths must match for {operation}: {left} != {right}")]
    LengthMismatch {
        /// Name of the operation being performed.
        operation: &'static str,
        /// Left-hand payload length.
        left: usize,
        /// Right-hand payload length.
        right: usize,
    },
    /// Structured storage metadata was invalid after an operation.
    ///
    /// Carries the display message of the underlying validation error.
    #[error("invalid structured storage: {0}")]
    InvalidStructuredStorage(String),
    /// The requested operation does not support the provided storage kinds.
    #[error("storage types are not supported for {operation}: {left} vs {right}")]
    OperationNotSupported {
        /// Name of the operation being performed.
        operation: &'static str,
        /// Left-hand storage kind.
        left: &'static str,
        /// Right-hand storage kind.
        right: &'static str,
    },
    /// The requested operation requires real scalars but at least one scalar was complex.
    #[error("expected real scalars in {operation} branch: a={a}, b={b}")]
    RealScalarRequired {
        /// Name of the operation being performed.
        operation: &'static str,
        /// Left scalar display string.
        a: String,
        /// Right scalar display string.
        b: String,
    },
}
809
/// Result type returned by storage methods that can fail with [`StorageError`].
// Spelled as `std::result::Result` because this module imports
// `anyhow::Result` under the bare name `Result`.
pub type StorageResult<T> = std::result::Result<T, StorageError>;
812
/// Crate-internal representation behind [`Storage`]: one variant per
/// supported scalar type, each holding a [`StructuredStorage`] payload.
#[derive(Debug, Clone)]
pub(crate) enum StorageRepr {
    /// Storage with f64 elements.
    F64(StructuredStorage<f64>),
    /// Storage with Complex64 elements.
    C64(StructuredStorage<Complex64>),
}
820
821fn storage_scalar_kind(repr: &StorageRepr) -> &'static str {
822    match repr {
823        StorageRepr::F64(_) => "f64",
824        StorageRepr::C64(_) => "Complex64",
825    }
826}
827
/// Types that can be computed as the result of a reduction over `Storage`.
///
/// This lets callers write `let s: T = tensor.sum();` without matching on
/// the storage variant. Implemented for `f64` and `Complex64`.
///
/// Summing complex storage into `f64` keeps only the real parts. Structural
/// zeros implied by the layout (e.g. off-diagonal entries of a diagonal
/// storage) are not stored and contribute nothing.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::{Storage, SumFromStorage};
///
/// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
/// let total: f64 = f64::sum_from_storage(&s);
/// assert!((total - 6.0).abs() < 1e-10);
/// ```
pub trait SumFromStorage: Sized {
    /// Compute the sum of all elements in the storage.
    fn sum_from_storage(storage: &Storage) -> Self;
}
846
847impl SumFromStorage for f64 {
848    fn sum_from_storage(storage: &Storage) -> Self {
849        match &storage.0 {
850            StorageRepr::F64(v) => v.data().iter().copied().sum(),
851            StorageRepr::C64(v) => v.data().iter().map(|z| z.re).sum(),
852        }
853    }
854}
855
856impl SumFromStorage for Complex64 {
857    fn sum_from_storage(storage: &Storage) -> Self {
858        match &storage.0 {
859            StorageRepr::F64(v) => Complex64::new(v.data().iter().copied().sum(), 0.0),
860            StorageRepr::C64(v) => v.data().iter().copied().sum(),
861        }
862    }
863}
864
865// AnyScalar is now in its own module
866pub use crate::any_scalar::AnyScalar;
867
868impl Storage {
869    pub(crate) fn from_repr(repr: StorageRepr) -> Self {
870        Self(repr)
871    }
872
873    fn invalid_storage_error(err: anyhow::Error) -> StorageError {
874        StorageError::InvalidStructuredStorage(err.to_string())
875    }
876
877    #[cfg(test)]
878    pub(crate) fn repr(&self) -> &StorageRepr {
879        &self.0
880    }
881
882    fn validate_dense_len<T>(data: &[T], logical_dims: &[usize], label: &str) -> Result<()> {
883        let expected_len: usize = logical_dims.iter().product();
884        ensure!(
885            data.len() == expected_len,
886            "{label} len {} does not match logical dims {:?} (expected {})",
887            data.len(),
888            logical_dims,
889            expected_len
890        );
891        Ok(())
892    }
893
894    /// Create dense storage from column-major logical values (generic over scalar type).
895    ///
896    /// The scalar type is inferred from the `data` argument.
897    ///
898    /// # Errors
899    ///
900    /// Returns an error if the requested dense metadata overflows.
901    ///
902    /// # Examples
903    ///
904    /// ```
905    /// use tensor4all_tensorbackend::Storage;
906    ///
907    /// // 2x2 matrix, column-major: [[1,3],[2,4]]
908    /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
909    /// assert!(s.is_f64());
910    /// assert!(s.is_dense());
911    /// assert_eq!(s.len(), 4);
912    /// ```
913    pub fn from_dense_col_major<T: StorageScalar>(
914        data: Vec<T>,
915        logical_dims: &[usize],
916    ) -> Result<Self> {
917        T::build_dense_storage(data, logical_dims)
918    }
919
920    /// Create diagonal storage from column-major diagonal payload values (generic over scalar type).
921    ///
922    /// Creates a rank-2 diagonal storage by default. The scalar type is
923    /// inferred from `diag_data`.
924    ///
925    /// # Errors
926    ///
927    /// Currently infallible for valid data, but returns `Result` for consistency.
928    ///
929    /// # Examples
930    ///
931    /// ```
932    /// use tensor4all_tensorbackend::Storage;
933    ///
934    /// let s = Storage::from_diag_col_major(vec![1.0_f64, 2.0, 3.0], 2).unwrap();
935    /// assert!(s.is_diag());
936    /// assert!(s.is_f64());
937    /// assert_eq!(s.len(), 3);
938    /// ```
939    pub fn from_diag_col_major<T: StorageScalar>(
940        diag_data: Vec<T>,
941        logical_rank: usize,
942    ) -> Result<Self> {
943        T::build_diag_storage(diag_data, logical_rank)
944    }
945
946    /// Create a new 1D zero-initialized dense storage (generic over scalar type).
947    ///
948    /// # Examples
949    ///
950    /// ```
951    /// use tensor4all_tensorbackend::Storage;
952    ///
953    /// let s = Storage::new_dense::<f64>(5).unwrap();
954    /// assert!(s.is_dense());
955    /// assert_eq!(s.len(), 5);
956    /// assert!((s.max_abs()).abs() < 1e-10);
957    /// ```
958    pub fn new_dense<T: StorageScalar + Default>(size: usize) -> StorageResult<Self> {
959        Self::from_dense_col_major(vec![T::default(); size], &[size])
960            .map_err(Self::invalid_storage_error)
961    }
962
963    /// Create a new diagonal storage with the given diagonal data (generic over scalar type).
964    ///
965    /// # Errors
966    ///
967    /// Returns an error if diagonal metadata is invalid.
968    ///
969    /// # Examples
970    ///
971    /// ```
972    /// use tensor4all_tensorbackend::Storage;
973    ///
974    /// let s = Storage::new_diag(vec![1.0_f64, 2.0, 3.0]).unwrap();
975    /// assert!(s.is_diag());
976    /// assert!(s.is_f64());
977    /// ```
978    pub fn new_diag<T: StorageScalar>(diag_data: Vec<T>) -> StorageResult<Self> {
979        Self::from_diag_col_major(diag_data, 2).map_err(Self::invalid_storage_error)
980    }
981
982    /// Create a new structured storage (generic over scalar type).
983    ///
984    /// # Errors
985    ///
986    /// Returns an error if the structured metadata is inconsistent (see
987    /// [`StructuredStorage::new`] for details).
988    ///
989    /// # Examples
990    ///
991    /// ```
992    /// use tensor4all_tensorbackend::Storage;
993    ///
994    /// // Diagonal-like structured storage: axis_classes = [0, 0]
995    /// let s = Storage::new_structured(
996    ///     vec![1.0_f64, 2.0],
997    ///     vec![2],         // payload_dims
998    ///     vec![1],         // strides
999    ///     vec![0, 0],      // axis_classes: both axes map to payload axis 0
1000    /// ).unwrap();
1001    /// assert!(s.is_diag());
1002    /// ```
1003    pub fn new_structured<T: StorageScalar>(
1004        data: Vec<T>,
1005        payload_dims: Vec<usize>,
1006        strides: Vec<isize>,
1007        axis_classes: Vec<usize>,
1008    ) -> Result<Self> {
1009        T::build_structured_storage(data, payload_dims, strides, axis_classes)
1010    }
1011
1012    /// Create dense f64 storage from column-major logical values.
1013    ///
1014    /// # Errors
1015    ///
1016    /// Returns an error if `data.len()` does not match the product of `logical_dims`.
1017    ///
1018    /// # Examples
1019    ///
1020    /// ```
1021    /// use tensor4all_tensorbackend::Storage;
1022    ///
1023    /// let s = Storage::from_dense_f64_col_major(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1024    /// assert!(s.is_f64());
1025    /// assert!(s.is_dense());
1026    /// ```
1027    pub fn from_dense_f64_col_major(data: Vec<f64>, logical_dims: &[usize]) -> Result<Self> {
1028        Self::validate_dense_len(&data, logical_dims, "dense f64 payload")?;
1029        Ok(Self::from_repr(StorageRepr::F64(
1030            StructuredStorage::from_dense_col_major(data, logical_dims)?,
1031        )))
1032    }
1033
1034    /// Create dense Complex64 storage from column-major logical values.
1035    ///
1036    /// # Errors
1037    ///
1038    /// Returns an error if `data.len()` does not match the product of `logical_dims`.
1039    ///
1040    /// # Examples
1041    ///
1042    /// ```
1043    /// use tensor4all_tensorbackend::Storage;
1044    /// use num_complex::Complex64;
1045    ///
1046    /// let data = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];
1047    /// let s = Storage::from_dense_c64_col_major(data, &[2]).unwrap();
1048    /// assert!(s.is_c64());
1049    /// assert!(s.is_dense());
1050    /// ```
1051    pub fn from_dense_c64_col_major(data: Vec<Complex64>, logical_dims: &[usize]) -> Result<Self> {
1052        Self::validate_dense_len(&data, logical_dims, "dense c64 payload")?;
1053        Ok(Self::from_repr(StorageRepr::C64(
1054            StructuredStorage::from_dense_col_major(data, logical_dims)?,
1055        )))
1056    }
1057
1058    /// Create diagonal f64 storage from column-major diagonal payload values.
1059    ///
1060    /// # Examples
1061    ///
1062    /// ```
1063    /// use tensor4all_tensorbackend::Storage;
1064    ///
1065    /// let s = Storage::from_diag_f64_col_major(vec![1.0, 2.0], 2).unwrap();
1066    /// assert!(s.is_diag());
1067    /// assert!(s.is_f64());
1068    /// ```
1069    pub fn from_diag_f64_col_major(diag_data: Vec<f64>, logical_rank: usize) -> Result<Self> {
1070        Ok(Self::from_repr(StorageRepr::F64(
1071            StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
1072        )))
1073    }
1074
1075    /// Create diagonal Complex64 storage from column-major diagonal payload values.
1076    ///
1077    /// # Examples
1078    ///
1079    /// ```
1080    /// use tensor4all_tensorbackend::Storage;
1081    /// use num_complex::Complex64;
1082    ///
1083    /// let data = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];
1084    /// let s = Storage::from_diag_c64_col_major(data, 2).unwrap();
1085    /// assert!(s.is_diag());
1086    /// assert!(s.is_c64());
1087    /// ```
1088    pub fn from_diag_c64_col_major(diag_data: Vec<Complex64>, logical_rank: usize) -> Result<Self> {
1089        Ok(Self::from_repr(StorageRepr::C64(
1090            StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
1091        )))
1092    }
1093
1094    /// Check if this storage is logically dense.
1095    ///
1096    /// # Examples
1097    ///
1098    /// ```
1099    /// use tensor4all_tensorbackend::Storage;
1100    ///
1101    /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1102    /// assert!(s.is_dense());
1103    ///
1104    /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1105    /// assert!(!d.is_dense());
1106    /// ```
1107    pub fn is_dense(&self) -> bool {
1108        match &self.0 {
1109            StorageRepr::F64(value) => value.is_dense(),
1110            StorageRepr::C64(value) => value.is_dense(),
1111        }
1112    }
1113
1114    /// Check if this storage is a Diag storage type.
1115    ///
1116    /// # Examples
1117    ///
1118    /// ```
1119    /// use tensor4all_tensorbackend::Storage;
1120    ///
1121    /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1122    /// assert!(d.is_diag());
1123    /// ```
1124    pub fn is_diag(&self) -> bool {
1125        match &self.0 {
1126            StorageRepr::F64(value) => value.is_diag(),
1127            StorageRepr::C64(value) => value.is_diag(),
1128        }
1129    }
1130
1131    /// Returns the compact layout class for this storage.
1132    ///
1133    /// The return value is metadata-only and never materializes dense logical
1134    /// values. Use it to choose whether to read compact payload metadata or
1135    /// dense logical values.
1136    ///
1137    /// # Examples
1138    ///
1139    /// ```
1140    /// use tensor4all_tensorbackend::{Storage, StorageKind};
1141    ///
1142    /// let structured = Storage::new_structured(
1143    ///     vec![1.0_f64, 2.0],
1144    ///     vec![2],
1145    ///     vec![1],
1146    ///     vec![0, 0],
1147    /// ).unwrap();
1148    /// assert_eq!(structured.storage_kind(), StorageKind::Diagonal);
1149    /// ```
1150    pub fn storage_kind(&self) -> StorageKind {
1151        if self.is_dense() {
1152            StorageKind::Dense
1153        } else if self.is_diag() {
1154            StorageKind::Diagonal
1155        } else {
1156            StorageKind::Structured
1157        }
1158    }
1159
1160    /// Returns the logical tensor dimensions represented by this storage.
1161    ///
1162    /// The dimensions are derived from payload dimensions and `axis_classes`.
1163    ///
1164    /// # Examples
1165    ///
1166    /// ```
1167    /// use tensor4all_tensorbackend::Storage;
1168    ///
1169    /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1170    /// assert_eq!(diag.logical_dims(), vec![2, 2]);
1171    /// ```
1172    pub fn logical_dims(&self) -> Vec<usize> {
1173        match &self.0 {
1174            StorageRepr::F64(value) => value.logical_dims(),
1175            StorageRepr::C64(value) => value.logical_dims(),
1176        }
1177    }
1178
1179    /// Returns the logical tensor rank represented by this storage.
1180    ///
1181    /// This equals `axis_classes().len()`, not necessarily `payload_dims().len()`.
1182    ///
1183    /// # Examples
1184    ///
1185    /// ```
1186    /// use tensor4all_tensorbackend::Storage;
1187    ///
1188    /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 3).unwrap();
1189    /// assert_eq!(diag.logical_rank(), 3);
1190    /// assert_eq!(diag.payload_dims(), &[2]);
1191    /// ```
1192    pub fn logical_rank(&self) -> usize {
1193        match &self.0 {
1194            StorageRepr::F64(value) => value.logical_rank(),
1195            StorageRepr::C64(value) => value.logical_rank(),
1196        }
1197    }
1198
1199    /// Returns the compact payload dimensions.
1200    ///
1201    /// For dense storage these match logical dimensions. For diagonal storage
1202    /// this is rank-1 even when the logical tensor has multiple axes.
1203    ///
1204    /// # Examples
1205    ///
1206    /// ```
1207    /// use tensor4all_tensorbackend::Storage;
1208    ///
1209    /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1210    /// assert_eq!(diag.payload_dims(), &[2]);
1211    /// ```
1212    pub fn payload_dims(&self) -> &[usize] {
1213        match &self.0 {
1214            StorageRepr::F64(value) => value.payload_dims(),
1215            StorageRepr::C64(value) => value.payload_dims(),
1216        }
1217    }
1218
1219    /// Returns the compact payload strides.
1220    ///
1221    /// Strides are measured in stored scalar elements and describe the compact
1222    /// payload buffer, not the logical dense tensor.
1223    ///
1224    /// # Examples
1225    ///
1226    /// ```
1227    /// use tensor4all_tensorbackend::Storage;
1228    ///
1229    /// let dense = Storage::from_dense_col_major(vec![0.0_f64; 6], &[2, 3]).unwrap();
1230    /// assert_eq!(dense.payload_strides(), &[1, 2]);
1231    /// ```
1232    pub fn payload_strides(&self) -> &[isize] {
1233        match &self.0 {
1234            StorageRepr::F64(value) => value.strides(),
1235            StorageRepr::C64(value) => value.strides(),
1236        }
1237    }
1238
1239    /// Returns logical-axis equivalence classes for this storage.
1240    ///
1241    /// Repeated class labels mean the corresponding logical axes share one
1242    /// payload axis. Dense storage has `[0, 1, ...]`; diagonal storage has
1243    /// repeated zero labels.
1244    ///
1245    /// # Examples
1246    ///
1247    /// ```
1248    /// use tensor4all_tensorbackend::Storage;
1249    ///
1250    /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1251    /// assert_eq!(diag.axis_classes(), &[0, 0]);
1252    /// ```
1253    pub fn axis_classes(&self) -> &[usize] {
1254        match &self.0 {
1255            StorageRepr::F64(value) => value.axis_classes(),
1256            StorageRepr::C64(value) => value.axis_classes(),
1257        }
1258    }
1259
1260    /// Returns the number of stored compact payload elements.
1261    ///
1262    /// For dense storage this equals the logical dense length. For diagonal and
1263    /// structured storage this is the compact payload length.
1264    ///
1265    /// # Examples
1266    ///
1267    /// ```
1268    /// use tensor4all_tensorbackend::Storage;
1269    ///
1270    /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1271    /// assert_eq!(diag.payload_len(), 2);
1272    /// ```
1273    pub fn payload_len(&self) -> usize {
1274        self.len()
1275    }
1276
1277    /// Copies the compact `f64` payload in column-major payload order.
1278    ///
1279    /// This does not materialize logical dense values. For diagonal storage the
1280    /// returned vector contains only diagonal payload values.
1281    ///
1282    /// # Errors
1283    ///
1284    /// Returns an error if the storage scalar type is not `f64`.
1285    ///
1286    /// # Examples
1287    ///
1288    /// ```
1289    /// use tensor4all_tensorbackend::Storage;
1290    ///
1291    /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1292    /// assert_eq!(diag.payload_f64_col_major_vec().unwrap(), vec![1.0, 2.0]);
1293    /// ```
1294    pub fn payload_f64_col_major_vec(&self) -> StorageResult<Vec<f64>> {
1295        match &self.0 {
1296            StorageRepr::F64(value) => Ok(value.payload_col_major_vec()),
1297            StorageRepr::C64(_) => Err(StorageError::ScalarKindMismatch {
1298                expected: "f64",
1299                actual: storage_scalar_kind(&self.0),
1300                operation: "copying f64 payload",
1301            }),
1302        }
1303    }
1304
1305    /// Borrows the compact `f64` payload when it is already contiguous in
1306    /// column-major payload order.
1307    pub fn payload_f64_col_major_view_if_contiguous(&self) -> StorageResult<Option<&[f64]>> {
1308        match &self.0 {
1309            StorageRepr::F64(value) => Ok(value.payload_col_major_view_if_contiguous()),
1310            StorageRepr::C64(_) => Err(StorageError::ScalarKindMismatch {
1311                expected: "f64",
1312                actual: storage_scalar_kind(&self.0),
1313                operation: "borrowing f64 payload",
1314            }),
1315        }
1316    }
1317
1318    /// Copies the compact `Complex64` payload in column-major payload order.
1319    ///
1320    /// This does not materialize logical dense values. Complex payloads are
1321    /// returned as native Rust `Complex64` values.
1322    ///
1323    /// # Errors
1324    ///
1325    /// Returns an error if the storage scalar type is not `Complex64`.
1326    ///
1327    /// # Examples
1328    ///
1329    /// ```
1330    /// use num_complex::Complex64;
1331    /// use tensor4all_tensorbackend::Storage;
1332    ///
1333    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1334    /// let diag = Storage::from_diag_col_major(data.clone(), 2).unwrap();
1335    /// assert_eq!(diag.payload_c64_col_major_vec().unwrap(), data);
1336    /// ```
1337    pub fn payload_c64_col_major_vec(&self) -> StorageResult<Vec<Complex64>> {
1338        match &self.0 {
1339            StorageRepr::C64(value) => Ok(value.payload_col_major_vec()),
1340            StorageRepr::F64(_) => Err(StorageError::ScalarKindMismatch {
1341                expected: "Complex64",
1342                actual: storage_scalar_kind(&self.0),
1343                operation: "copying c64 payload",
1344            }),
1345        }
1346    }
1347
1348    /// Borrows the compact `Complex64` payload when it is already contiguous in
1349    /// column-major payload order.
1350    pub fn payload_c64_col_major_view_if_contiguous(&self) -> StorageResult<Option<&[Complex64]>> {
1351        match &self.0 {
1352            StorageRepr::C64(value) => Ok(value.payload_col_major_view_if_contiguous()),
1353            StorageRepr::F64(_) => Err(StorageError::ScalarKindMismatch {
1354                expected: "Complex64",
1355                actual: storage_scalar_kind(&self.0),
1356                operation: "borrowing c64 payload",
1357            }),
1358        }
1359    }
1360
1361    /// Check if this storage uses f64 scalar type.
1362    ///
1363    /// # Examples
1364    ///
1365    /// ```
1366    /// use tensor4all_tensorbackend::Storage;
1367    ///
1368    /// let s = Storage::from_dense_col_major(vec![1.0_f64], &[1]).unwrap();
1369    /// assert!(s.is_f64());
1370    /// assert!(!s.is_c64());
1371    /// ```
1372    pub fn is_f64(&self) -> bool {
1373        matches!(&self.0, StorageRepr::F64(_))
1374    }
1375
1376    /// Check if this storage uses Complex64 scalar type.
1377    ///
1378    /// # Examples
1379    ///
1380    /// ```
1381    /// use tensor4all_tensorbackend::Storage;
1382    /// use num_complex::Complex64;
1383    ///
1384    /// let s = Storage::from_dense_col_major(
1385    ///     vec![Complex64::new(1.0, 0.0)], &[1],
1386    /// ).unwrap();
1387    /// assert!(s.is_c64());
1388    /// ```
1389    pub fn is_c64(&self) -> bool {
1390        matches!(&self.0, StorageRepr::C64(_))
1391    }
1392
1393    /// Check if this storage uses complex scalar type.
1394    ///
1395    /// This is an alias for [`is_c64()`](Self::is_c64).
1396    ///
1397    /// # Examples
1398    ///
1399    /// ```
1400    /// use tensor4all_tensorbackend::Storage;
1401    /// use num_complex::Complex64;
1402    ///
1403    /// let s = Storage::from_dense_col_major(
1404    ///     vec![Complex64::new(1.0, 0.0)], &[1],
1405    /// ).unwrap();
1406    /// assert!(s.is_complex());
1407    ///
1408    /// let r = Storage::from_dense_col_major(vec![1.0_f64], &[1]).unwrap();
1409    /// assert!(!r.is_complex());
1410    /// ```
1411    pub fn is_complex(&self) -> bool {
1412        self.is_c64()
1413    }
1414
1415    /// Get the length of the storage payload (number of stored elements).
1416    ///
1417    /// For dense storage this equals the product of logical dimensions.
1418    /// For diagonal storage this equals the diagonal length.
1419    ///
1420    /// # Examples
1421    ///
1422    /// ```
1423    /// use tensor4all_tensorbackend::Storage;
1424    ///
1425    /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
1426    /// assert_eq!(s.len(), 3);
1427    ///
1428    /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1429    /// assert_eq!(d.len(), 2);
1430    /// ```
1431    pub fn len(&self) -> usize {
1432        match &self.0 {
1433            StorageRepr::F64(v) => v.len(),
1434            StorageRepr::C64(v) => v.len(),
1435        }
1436    }
1437
1438    /// Check if the storage is empty.
1439    ///
1440    /// # Examples
1441    ///
1442    /// ```
1443    /// use tensor4all_tensorbackend::Storage;
1444    ///
1445    /// let s = Storage::new_dense::<f64>(0).unwrap();
1446    /// assert!(s.is_empty());
1447    ///
1448    /// let s2 = Storage::new_dense::<f64>(3).unwrap();
1449    /// assert!(!s2.is_empty());
1450    /// ```
1451    pub fn is_empty(&self) -> bool {
1452        self.len() == 0
1453    }
1454
1455    /// Sum all elements, converting to type `T`.
1456    ///
1457    /// # Example
1458    /// ```
1459    /// use tensor4all_tensorbackend::Storage;
1460    /// let s = Storage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
1461    /// assert_eq!(s.sum::<f64>(), 6.0);
1462    /// ```
1463    pub fn sum<T: SumFromStorage>(&self) -> T {
1464        T::sum_from_storage(self)
1465    }
1466
1467    /// Maximum absolute value over all stored elements.
1468    ///
1469    /// For real storage this is `max(|x|)`, and for complex storage this is
1470    /// `max(norm(z))`.
1471    ///
1472    /// # Examples
1473    ///
1474    /// ```
1475    /// use tensor4all_tensorbackend::Storage;
1476    ///
1477    /// let s = Storage::from_dense_col_major(vec![-3.0_f64, 1.0, 2.0], &[3]).unwrap();
1478    /// assert!((s.max_abs() - 3.0).abs() < 1e-10);
1479    /// ```
1480    pub fn max_abs(&self) -> f64 {
1481        match &self.0 {
1482            StorageRepr::F64(v) => v.data().iter().map(|x| x.abs()).fold(0.0_f64, f64::max),
1483            StorageRepr::C64(v) => v.data().iter().map(|z| z.norm()).fold(0.0_f64, f64::max),
1484        }
1485    }
1486
1487    /// Materialize dense logical values as a column-major `f64` buffer.
1488    ///
1489    /// For diagonal storage, off-diagonal entries are filled with zero.
1490    ///
1491    /// # Errors
1492    ///
1493    /// Returns an error if the storage is complex or `logical_dims` does not
1494    /// match the stored logical dimensions.
1495    ///
1496    /// # Examples
1497    ///
1498    /// ```
1499    /// use tensor4all_tensorbackend::Storage;
1500    ///
1501    /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1502    /// let dense = s.to_dense_f64_col_major_vec(&[2, 2]).unwrap();
1503    /// assert_eq!(dense, vec![1.0, 2.0, 3.0, 4.0]);
1504    /// ```
1505    pub fn to_dense_f64_col_major_vec(&self, logical_dims: &[usize]) -> StorageResult<Vec<f64>> {
1506        match &self.0 {
1507            StorageRepr::F64(v) => {
1508                let structured_dims = v.logical_dims();
1509                if structured_dims != logical_dims {
1510                    return Err(StorageError::InvalidStructuredStorage(format!(
1511                        "logical dims {:?} do not match StructuredF64 logical dims {:?}",
1512                        logical_dims, structured_dims
1513                    )));
1514                }
1515                Ok(v.logical_dense_col_major_vec())
1516            }
1517            StorageRepr::C64(_) => Err(StorageError::ScalarKindMismatch {
1518                expected: "f64",
1519                actual: storage_scalar_kind(&self.0),
1520                operation: "materializing dense f64 values",
1521            }),
1522        }
1523    }
1524
1525    /// Materialize dense logical values as a column-major `Complex64` buffer.
1526    ///
1527    /// # Errors
1528    ///
1529    /// Returns an error if the storage is real or `logical_dims` does not
1530    /// match the stored logical dimensions.
1531    ///
1532    /// # Examples
1533    ///
1534    /// ```
1535    /// use tensor4all_tensorbackend::Storage;
1536    /// use num_complex::Complex64;
1537    ///
1538    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1539    /// let s = Storage::from_dense_col_major(data.clone(), &[2]).unwrap();
1540    /// let dense = s.to_dense_c64_col_major_vec(&[2]).unwrap();
1541    /// assert_eq!(dense, data);
1542    /// ```
1543    pub fn to_dense_c64_col_major_vec(
1544        &self,
1545        logical_dims: &[usize],
1546    ) -> StorageResult<Vec<Complex64>> {
1547        match &self.0 {
1548            StorageRepr::C64(v) => {
1549                let structured_dims = v.logical_dims();
1550                if structured_dims != logical_dims {
1551                    return Err(StorageError::InvalidStructuredStorage(format!(
1552                        "logical dims {:?} do not match StructuredC64 logical dims {:?}",
1553                        logical_dims, structured_dims
1554                    )));
1555                }
1556                Ok(v.logical_dense_col_major_vec())
1557            }
1558            StorageRepr::F64(_) => Err(StorageError::ScalarKindMismatch {
1559                expected: "Complex64",
1560                actual: storage_scalar_kind(&self.0),
1561                operation: "materializing dense c64 values",
1562            }),
1563        }
1564    }
1565
1566    /// Convert this storage to dense storage.
1567    ///
1568    /// For Diag storage, creates a Dense storage with diagonal elements set
1569    /// and off-diagonal elements as zero. For Dense storage, returns a copy.
1570    ///
1571    /// # Errors
1572    ///
1573    /// Returns an error if `dims` does not match the stored logical dimensions
1574    /// or if dense storage construction fails.
1575    ///
1576    /// # Examples
1577    ///
1578    /// ```
1579    /// use tensor4all_tensorbackend::Storage;
1580    ///
1581    /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1582    /// let dense = d.to_dense_storage(&[2, 2]).unwrap();
1583    /// assert!(dense.is_dense());
1584    /// let vals = dense.to_dense_f64_col_major_vec(&[2, 2]).unwrap();
1585    /// assert_eq!(vals, vec![1.0, 0.0, 0.0, 2.0]);
1586    /// ```
1587    pub fn to_dense_storage(&self, dims: &[usize]) -> StorageResult<Storage> {
1588        if self.is_f64() {
1589            let values = self.to_dense_f64_col_major_vec(dims)?;
1590            Storage::from_dense_col_major(values, dims).map_err(Self::invalid_storage_error)
1591        } else {
1592            let values = self.to_dense_c64_col_major_vec(dims)?;
1593            Storage::from_dense_col_major(values, dims).map_err(Self::invalid_storage_error)
1594        }
1595    }
1596
1597    /// Permute the storage data according to the given permutation.
1598    ///
1599    /// The `_dims` parameter is currently unused (reserved for future use).
1600    ///
1601    /// # Examples
1602    ///
1603    /// ```
1604    /// use tensor4all_tensorbackend::Storage;
1605    ///
1606    /// // Diagonal 2x2 tensor, permute axes (identity perm for diag is valid)
1607    /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1608    /// let t = d.permute_storage(&[2, 2], &[1, 0]).unwrap();
1609    /// assert!(t.is_diag());
1610    /// ```
1611    pub fn permute_storage(&self, _dims: &[usize], perm: &[usize]) -> StorageResult<Storage> {
1612        match &self.0 {
1613            StorageRepr::F64(v) => Ok(Storage::from_repr(StorageRepr::F64(
1614                v.permute_logical_axes(perm)
1615                    .map_err(Self::invalid_storage_error)?,
1616            ))),
1617            StorageRepr::C64(v) => Ok(Storage::from_repr(StorageRepr::C64(
1618                v.permute_logical_axes(perm)
1619                    .map_err(Self::invalid_storage_error)?,
1620            ))),
1621        }
1622    }
1623
1624    /// Extract real part from Complex64 storage as f64 storage.
1625    /// For f64 storage, returns a copy (clone).
1626    ///
1627    /// # Examples
1628    ///
1629    /// ```
1630    /// use tensor4all_tensorbackend::Storage;
1631    /// use num_complex::Complex64;
1632    ///
1633    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1634    /// let s = Storage::from_dense_col_major(data, &[2]).unwrap();
1635    /// let re = s.extract_real_part();
1636    /// assert!(re.is_f64());
1637    /// assert_eq!(re.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![1.0, 3.0]);
1638    /// ```
1639    pub fn extract_real_part(&self) -> Storage {
1640        match &self.0 {
1641            StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.clone())),
1642            StorageRepr::C64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|z| z.re))),
1643        }
1644    }
1645
1646    /// Extract imaginary part from Complex64 storage as f64 storage.
1647    /// For f64 storage, returns zero storage.
1648    ///
1649    /// # Examples
1650    ///
1651    /// ```
1652    /// use tensor4all_tensorbackend::Storage;
1653    /// use num_complex::Complex64;
1654    ///
1655    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1656    /// let s = Storage::from_dense_col_major(data, &[2]).unwrap();
1657    /// let im = s.extract_imag_part(&[2]);
1658    /// assert!(im.is_f64());
1659    /// assert_eq!(im.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![2.0, 4.0]);
1660    /// ```
1661    pub fn extract_imag_part(&self, _dims: &[usize]) -> Storage {
1662        match &self.0 {
1663            StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|_| 0.0))),
1664            StorageRepr::C64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|z| z.im))),
1665        }
1666    }
1667
1668    /// Convert f64 storage to Complex64 storage (real part only, imaginary part is zero).
1669    /// For Complex64 storage, returns a clone.
1670    ///
1671    /// # Examples
1672    ///
1673    /// ```
1674    /// use tensor4all_tensorbackend::Storage;
1675    ///
1676    /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1677    /// let c = s.to_complex_storage();
1678    /// assert!(c.is_c64());
1679    /// ```
1680    pub fn to_complex_storage(&self) -> Storage {
1681        match &self.0 {
1682            StorageRepr::F64(v) => {
1683                Storage::from_repr(StorageRepr::C64(v.map_copy(|x| Complex64::new(x, 0.0))))
1684            }
1685            StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(v.clone())),
1686        }
1687    }
1688
1689    /// Complex conjugate of all elements.
1690    ///
1691    /// For real (f64) storage, returns a clone (conjugate of real is identity).
1692    /// For complex (Complex64) storage, conjugates each element.
1693    ///
1694    /// This is inspired by the `conj` operation in ITensorMPS.jl.
1695    ///
1696    /// # Examples
1697    /// ```
1698    /// use tensor4all_tensorbackend::Storage;
1699    /// use num_complex::Complex64;
1700    ///
1701    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -4.0)];
1702    /// let storage = Storage::from_dense_col_major(data, &[2]).unwrap();
1703    /// let conj_storage = storage.conj();
1704    ///
1705    /// let result = conj_storage.to_dense_c64_col_major_vec(&[2]).unwrap();
1706    /// assert_eq!(result[0], Complex64::new(1.0, -2.0));
1707    /// assert_eq!(result[1], Complex64::new(3.0, 4.0));
1708    /// ```
1709    pub fn conj(&self) -> Self {
1710        match &self.0 {
1711            StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.clone())),
1712            StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(v.map_copy(|z| z.conj()))),
1713        }
1714    }
1715
1716    /// Combine two f64 storages into Complex64 storage.
1717    ///
1718    /// `real_storage` becomes the real part, `imag_storage` becomes the imaginary part.
1719    /// Formula: `real + i * imag`.
1720    ///
1721    /// # Errors
1722    ///
1723    /// Returns an error if either storage is not `f64`, if their payload lengths
1724    /// differ, or if the result metadata is invalid.
1725    ///
1726    /// # Examples
1727    ///
1728    /// ```
1729    /// use tensor4all_tensorbackend::Storage;
1730    /// use num_complex::Complex64;
1731    ///
1732    /// let re = Storage::from_dense_col_major(vec![1.0_f64, 3.0], &[2]).unwrap();
1733    /// let im = Storage::from_dense_col_major(vec![2.0_f64, 4.0], &[2]).unwrap();
1734    /// let c = Storage::combine_to_complex(&re, &im).unwrap();
1735    /// assert!(c.is_c64());
1736    /// let vals = c.to_dense_c64_col_major_vec(&[2]).unwrap();
1737    /// assert_eq!(vals[0], Complex64::new(1.0, 2.0));
1738    /// assert_eq!(vals[1], Complex64::new(3.0, 4.0));
1739    /// ```
1740    pub fn combine_to_complex(
1741        real_storage: &Storage,
1742        imag_storage: &Storage,
1743    ) -> StorageResult<Storage> {
1744        match (&real_storage.0, &imag_storage.0) {
1745            (StorageRepr::F64(real), StorageRepr::F64(imag)) => {
1746                if real.len() != imag.len() {
1747                    return Err(StorageError::LengthMismatch {
1748                        operation: "combine_to_complex",
1749                        left: real.len(),
1750                        right: imag.len(),
1751                    });
1752                }
1753                let complex_vec: Vec<Complex64> = real
1754                    .data()
1755                    .iter()
1756                    .zip(imag.data().iter())
1757                    .map(|(&r, &i)| Complex64::new(r, i))
1758                    .collect();
1759                Ok(Storage::from_repr(StorageRepr::C64(
1760                    StructuredStorage::new(
1761                        complex_vec,
1762                        real.payload_dims().to_vec(),
1763                        real.strides().to_vec(),
1764                        real.axis_classes().to_vec(),
1765                    )
1766                    .map_err(Self::invalid_storage_error)?,
1767                )))
1768            }
1769            _ => Err(StorageError::OperationNotSupported {
1770                operation: "combine_to_complex",
1771                left: storage_scalar_kind(&real_storage.0),
1772                right: storage_scalar_kind(&imag_storage.0),
1773            }),
1774        }
1775    }
1776
1777    /// Add two storages element-wise, returning `Result` on error instead of panicking.
1778    ///
1779    /// Both storages must have the same type and length.
1780    ///
1781    /// # Errors
1782    ///
1783    /// Returns an error if storage types or lengths don't match.
1784    ///
1785    /// # Examples
1786    ///
1787    /// ```
1788    /// use tensor4all_tensorbackend::Storage;
1789    ///
1790    /// let a = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1791    /// let b = Storage::from_dense_col_major(vec![3.0_f64, 4.0], &[2]).unwrap();
1792    /// let c = a.try_add(&b).unwrap();
1793    /// assert_eq!(c.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![4.0, 6.0]);
1794    /// ```
1795    pub fn try_add(&self, other: &Storage) -> StorageResult<Storage> {
1796        match (&self.0, &other.0) {
1797            (StorageRepr::F64(a), StorageRepr::F64(b)) => {
1798                if a.len() != b.len() {
1799                    return Err(StorageError::LengthMismatch {
1800                        operation: "addition",
1801                        left: a.len(),
1802                        right: b.len(),
1803                    });
1804                }
1805                let sum_vec: Vec<f64> = a
1806                    .data()
1807                    .iter()
1808                    .zip(b.data().iter())
1809                    .map(|(&x, &y)| x + y)
1810                    .collect();
1811                Ok(Storage::from_repr(StorageRepr::F64(
1812                    StructuredStorage::new(
1813                        sum_vec,
1814                        a.payload_dims().to_vec(),
1815                        a.strides().to_vec(),
1816                        a.axis_classes().to_vec(),
1817                    )
1818                    .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1819                )))
1820            }
1821            (StorageRepr::C64(a), StorageRepr::C64(b)) => {
1822                if a.len() != b.len() {
1823                    return Err(StorageError::LengthMismatch {
1824                        operation: "addition",
1825                        left: a.len(),
1826                        right: b.len(),
1827                    });
1828                }
1829                let sum_vec: Vec<Complex64> = a
1830                    .data()
1831                    .iter()
1832                    .zip(b.data().iter())
1833                    .map(|(&x, &y)| x + y)
1834                    .collect();
1835                Ok(Storage::from_repr(StorageRepr::C64(
1836                    StructuredStorage::new(
1837                        sum_vec,
1838                        a.payload_dims().to_vec(),
1839                        a.strides().to_vec(),
1840                        a.axis_classes().to_vec(),
1841                    )
1842                    .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1843                )))
1844            }
1845            _ => Err(StorageError::OperationNotSupported {
1846                operation: "addition",
1847                left: storage_scalar_kind(&self.0),
1848                right: storage_scalar_kind(&other.0),
1849            }),
1850        }
1851    }
1852
1853    /// Try to subtract two storages element-wise.
1854    ///
1855    /// # Errors
1856    ///
1857    /// Returns an error if the storages have different types or lengths.
1858    ///
1859    /// # Examples
1860    ///
1861    /// ```
1862    /// use tensor4all_tensorbackend::Storage;
1863    ///
1864    /// let a = Storage::from_dense_col_major(vec![5.0_f64, 7.0], &[2]).unwrap();
1865    /// let b = Storage::from_dense_col_major(vec![1.0_f64, 3.0], &[2]).unwrap();
1866    /// let c = a.try_sub(&b).unwrap();
1867    /// assert_eq!(c.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![4.0, 4.0]);
1868    /// ```
1869    pub fn try_sub(&self, other: &Storage) -> StorageResult<Storage> {
1870        match (&self.0, &other.0) {
1871            (StorageRepr::F64(a), StorageRepr::F64(b)) => {
1872                if a.len() != b.len() {
1873                    return Err(StorageError::LengthMismatch {
1874                        operation: "subtraction",
1875                        left: a.len(),
1876                        right: b.len(),
1877                    });
1878                }
1879                let diff_vec: Vec<f64> = a
1880                    .data()
1881                    .iter()
1882                    .zip(b.data().iter())
1883                    .map(|(&x, &y)| x - y)
1884                    .collect();
1885                Ok(Storage::from_repr(StorageRepr::F64(
1886                    StructuredStorage::new(
1887                        diff_vec,
1888                        a.payload_dims().to_vec(),
1889                        a.strides().to_vec(),
1890                        a.axis_classes().to_vec(),
1891                    )
1892                    .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1893                )))
1894            }
1895            (StorageRepr::C64(a), StorageRepr::C64(b)) => {
1896                if a.len() != b.len() {
1897                    return Err(StorageError::LengthMismatch {
1898                        operation: "subtraction",
1899                        left: a.len(),
1900                        right: b.len(),
1901                    });
1902                }
1903                let diff_vec: Vec<Complex64> = a
1904                    .data()
1905                    .iter()
1906                    .zip(b.data().iter())
1907                    .map(|(&x, &y)| x - y)
1908                    .collect();
1909                Ok(Storage::from_repr(StorageRepr::C64(
1910                    StructuredStorage::new(
1911                        diff_vec,
1912                        a.payload_dims().to_vec(),
1913                        a.strides().to_vec(),
1914                        a.axis_classes().to_vec(),
1915                    )
1916                    .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1917                )))
1918            }
1919            _ => Err(StorageError::OperationNotSupported {
1920                operation: "subtraction",
1921                left: storage_scalar_kind(&self.0),
1922                right: storage_scalar_kind(&other.0),
1923            }),
1924        }
1925    }
1926
1927    /// Scale storage by a scalar value.
1928    ///
1929    /// If the scalar is complex but the storage is real, the storage is promoted to complex.
1930    ///
1931    /// # Examples
1932    ///
1933    /// ```
1934    /// use tensor4all_tensorbackend::{AnyScalar, Storage};
1935    ///
1936    /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
1937    /// let scaled = s.scale(&AnyScalar::new_real(2.0));
1938    /// assert_eq!(scaled.to_dense_f64_col_major_vec(&[3]).unwrap(), vec![2.0, 4.0, 6.0]);
1939    /// ```
1940    pub fn scale(&self, scalar: &crate::AnyScalar) -> Storage {
1941        self * scalar.clone()
1942    }
1943
1944    /// Compute linear combination: `a * self + b * other`.
1945    ///
1946    /// # Errors
1947    ///
1948    /// Returns an error if the storages have different types or lengths.
1949    /// If any scalar is complex, the result is promoted to complex.
1950    ///
1951    /// # Examples
1952    ///
1953    /// ```
1954    /// use tensor4all_tensorbackend::{AnyScalar, Storage};
1955    ///
1956    /// let x = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1957    /// let y = Storage::from_dense_col_major(vec![3.0_f64, 4.0], &[2]).unwrap();
1958    /// let a = AnyScalar::new_real(2.0);
1959    /// let b = AnyScalar::new_real(3.0);
1960    /// // result = 2*[1,2] + 3*[3,4] = [11, 16]
1961    /// let result = x.axpby(&a, &y, &b).unwrap();
1962    /// assert_eq!(result.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![11.0, 16.0]);
1963    /// ```
1964    pub fn axpby(
1965        &self,
1966        a: &crate::AnyScalar,
1967        other: &Storage,
1968        b: &crate::AnyScalar,
1969    ) -> StorageResult<Storage> {
1970        // First check lengths match
1971        if self.len() != other.len() {
1972            return Err(StorageError::LengthMismatch {
1973                operation: "axpby",
1974                left: self.len(),
1975                right: other.len(),
1976            });
1977        }
1978
1979        // Determine if we need complex output
1980        let needs_complex = a.is_complex()
1981            || b.is_complex()
1982            || matches!(&self.0, StorageRepr::C64(_))
1983            || matches!(&other.0, StorageRepr::C64(_));
1984
1985        if needs_complex {
1986            // Promote everything to complex
1987            let a_c: Complex64 = a.clone().into();
1988            let b_c: Complex64 = b.clone().into();
1989
1990            let (result, payload_dims, strides, axis_classes): (
1991                Vec<Complex64>,
1992                Vec<usize>,
1993                Vec<isize>,
1994                Vec<usize>,
1995            ) = match (&self.0, &other.0) {
1996                (StorageRepr::F64(x), StorageRepr::F64(y)) => (
1997                    x.data()
1998                        .iter()
1999                        .zip(y.data().iter())
2000                        .map(|(&xi, &yi)| {
2001                            a_c * Complex64::new(xi, 0.0) + b_c * Complex64::new(yi, 0.0)
2002                        })
2003                        .collect(),
2004                    x.payload_dims().to_vec(),
2005                    x.strides().to_vec(),
2006                    x.axis_classes().to_vec(),
2007                ),
2008                (StorageRepr::F64(x), StorageRepr::C64(y)) => (
2009                    x.data()
2010                        .iter()
2011                        .zip(y.data().iter())
2012                        .map(|(&xi, &yi)| a_c * Complex64::new(xi, 0.0) + b_c * yi)
2013                        .collect(),
2014                    x.payload_dims().to_vec(),
2015                    x.strides().to_vec(),
2016                    x.axis_classes().to_vec(),
2017                ),
2018                (StorageRepr::C64(x), StorageRepr::F64(y)) => (
2019                    x.data()
2020                        .iter()
2021                        .zip(y.data().iter())
2022                        .map(|(&xi, &yi)| a_c * xi + b_c * Complex64::new(yi, 0.0))
2023                        .collect(),
2024                    x.payload_dims().to_vec(),
2025                    x.strides().to_vec(),
2026                    x.axis_classes().to_vec(),
2027                ),
2028                (StorageRepr::C64(x), StorageRepr::C64(y)) => (
2029                    x.data()
2030                        .iter()
2031                        .zip(y.data().iter())
2032                        .map(|(&xi, &yi)| a_c * xi + b_c * yi)
2033                        .collect(),
2034                    x.payload_dims().to_vec(),
2035                    x.strides().to_vec(),
2036                    x.axis_classes().to_vec(),
2037                ),
2038            };
2039            Ok(Storage::from_repr(StorageRepr::C64(
2040                StructuredStorage::new(result, payload_dims, strides, axis_classes)
2041                    .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
2042            )))
2043        } else {
2044            // All real
2045            if !a.is_real() || !b.is_real() {
2046                return Err(StorageError::RealScalarRequired {
2047                    operation: "real axpby",
2048                    a: a.to_string(),
2049                    b: b.to_string(),
2050                });
2051            }
2052            let a_f = a.real();
2053            let b_f = b.real();
2054
2055            match (&self.0, &other.0) {
2056                (StorageRepr::F64(x), StorageRepr::F64(y)) => {
2057                    let result: Vec<f64> = x
2058                        .data()
2059                        .iter()
2060                        .zip(y.data().iter())
2061                        .map(|(&xi, &yi)| a_f * xi + b_f * yi)
2062                        .collect();
2063                    Ok(Storage::from_repr(StorageRepr::F64(
2064                        StructuredStorage::new(
2065                            result,
2066                            x.payload_dims().to_vec(),
2067                            x.strides().to_vec(),
2068                            x.axis_classes().to_vec(),
2069                        )
2070                        .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
2071                    )))
2072                }
2073                _ => Err(StorageError::OperationNotSupported {
2074                    operation: "axpby",
2075                    left: storage_scalar_kind(&self.0),
2076                    right: storage_scalar_kind(&other.0),
2077                }),
2078            }
2079        }
2080    }
2081}
2082
2083/// Helper to get a mutable reference to storage, cloning if needed (COW).
2084///
2085/// Uses `Arc::make_mut` semantics: if the `Arc` has only one strong reference,
2086/// returns a mutable reference to the existing allocation. Otherwise clones
2087/// the inner value first.
2088///
2089/// # Examples
2090///
2091/// ```
2092/// use std::sync::Arc;
2093/// use tensor4all_tensorbackend::{make_mut_storage, Storage};
2094///
2095/// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
2096/// let mut arc = Arc::new(s);
2097/// let s_mut = make_mut_storage(&mut arc);
2098/// // s_mut is now a mutable reference to Storage
2099/// assert!(s_mut.is_f64());
2100/// ```
2101pub fn make_mut_storage(arc: &mut Arc<Storage>) -> &mut Storage {
2102    Arc::make_mut(arc)
2103}
2104
/// Smallest entry of `dims`, or `1` when `dims` is empty.
///
/// Intended for diagonal tensors, whose indices are expected to share a
/// single dimension.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::mindim;
///
/// assert_eq!(mindim(&[2, 3, 4]), 2);
/// assert_eq!(mindim(&[5, 5, 5]), 5);
/// assert_eq!(mindim(&[]), 1);
/// ```
pub fn mindim(dims: &[usize]) -> usize {
    match dims.iter().min() {
        Some(&smallest) => smallest,
        // No dimensions at all: fall back to 1.
        None => 1,
    }
}
2122
2123/// Contract two storage tensors along specified axes.
2124///
2125/// All storage is StructuredStorage; contraction is delegated to the native
2126/// tenferro backend. This is the primary tensor contraction entry point at
2127/// the storage layer.
2128///
2129/// # Arguments
2130///
2131/// * `storage_a` - First tensor storage
2132/// * `dims_a` - Dimensions of the first tensor
2133/// * `axes_a` - Axes of the first tensor to contract
2134/// * `storage_b` - Second tensor storage
2135/// * `dims_b` - Dimensions of the second tensor
2136/// * `axes_b` - Axes of the second tensor to contract
2137/// * `result_dims` - Dimensions of the result tensor (empty for scalar result)
2138///
2139/// # Returns
2140/// A new `Storage` containing the contracted result.
2141///
2142/// # Errors
2143///
2144/// Returns an error if axes are invalid, contracted dimensions do not match, or
2145/// the native backend rejects the contraction.
2146///
2147/// # Examples
2148///
2149/// ```
2150/// use tensor4all_tensorbackend::{contract_storage, Storage};
2151///
2152/// // Matrix-vector multiply: A(2x3) * v(3) -> result(2)
2153/// let a = Storage::from_dense_col_major(
2154///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3],
2155/// ).unwrap();
2156/// let v = Storage::from_dense_col_major(vec![1.0, 1.0, 1.0], &[3]).unwrap();
2157/// let result = contract_storage(&a, &[2, 3], &[1], &v, &[3], &[0], &[2]).unwrap();
2158/// // Row sums: [1+3+5, 2+4+6] = [9, 12]
2159/// let vals = result.to_dense_f64_col_major_vec(&[2]).unwrap();
2160/// assert!((vals[0] - 9.0).abs() < 1e-10);
2161/// assert!((vals[1] - 12.0).abs() < 1e-10);
2162/// ```
2163pub fn contract_storage(
2164    storage_a: &Storage,
2165    dims_a: &[usize],
2166    axes_a: &[usize],
2167    storage_b: &Storage,
2168    dims_b: &[usize],
2169    axes_b: &[usize],
2170    result_dims: &[usize],
2171) -> StorageResult<Storage> {
2172    try_contract_storage(
2173        storage_a,
2174        dims_a,
2175        axes_a,
2176        storage_b,
2177        dims_b,
2178        axes_b,
2179        result_dims,
2180    )
2181}
2182
2183fn try_contract_storage(
2184    storage_a: &Storage,
2185    dims_a: &[usize],
2186    axes_a: &[usize],
2187    storage_b: &Storage,
2188    dims_b: &[usize],
2189    axes_b: &[usize],
2190    result_dims: &[usize],
2191) -> StorageResult<Storage> {
2192    if axes_a.len() != axes_b.len() {
2193        return Err(StorageError::InvalidStructuredStorage(format!(
2194            "contract axes lengths must match: {} != {}",
2195            axes_a.len(),
2196            axes_b.len()
2197        )));
2198    }
2199
2200    for (&a_axis, &b_axis) in axes_a.iter().zip(axes_b.iter()) {
2201        let Some(&a_dim) = dims_a.get(a_axis) else {
2202            return Err(StorageError::InvalidStructuredStorage(format!(
2203                "contract axis {a_axis} is out of range for left dims {dims_a:?}"
2204            )));
2205        };
2206        let Some(&b_dim) = dims_b.get(b_axis) else {
2207            return Err(StorageError::InvalidStructuredStorage(format!(
2208                "contract axis {b_axis} is out of range for right dims {dims_b:?}"
2209            )));
2210        };
2211        if a_dim != b_dim {
2212            return Err(StorageError::InvalidStructuredStorage(format!(
2213                "contracted dimensions must match: dims_a[{a_axis}] = {a_dim} != dims_b[{b_axis}] = {b_dim}"
2214            )));
2215        }
2216    }
2217
2218    crate::tenferro_bridge::contract_storage_native(
2219        storage_a,
2220        dims_a,
2221        axes_a,
2222        storage_b,
2223        dims_b,
2224        axes_b,
2225        result_dims,
2226    )
2227    .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))
2228}
2229
2230/// Multiply storage by a scalar (f64).
2231/// For Complex64 storage, multiplies each element by the scalar (treated as real).
2232impl Mul<f64> for &Storage {
2233    type Output = Storage;
2234
2235    fn mul(self, scalar: f64) -> Self::Output {
2236        match &self.0 {
2237            StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|x| x * scalar))),
2238            StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(
2239                v.map_copy(|z| z * Complex64::new(scalar, 0.0)),
2240            )),
2241        }
2242    }
2243}
2244
2245/// Multiply storage by a scalar (Complex64).
2246impl Mul<Complex64> for &Storage {
2247    type Output = Storage;
2248
2249    fn mul(self, scalar: Complex64) -> Self::Output {
2250        match &self.0 {
2251            StorageRepr::F64(v) => Storage::from_repr(StorageRepr::C64(
2252                v.map_copy(|x| Complex64::new(x, 0.0) * scalar),
2253            )),
2254            StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(v.map_copy(|z| z * scalar))),
2255        }
2256    }
2257}
2258
2259/// Multiply storage by a scalar (AnyScalar).
2260/// May promote f64 storage to Complex64 when scalar is complex.
2261impl Mul<AnyScalar> for &Storage {
2262    type Output = Storage;
2263
2264    fn mul(self, scalar: AnyScalar) -> Self::Output {
2265        if scalar.is_complex() {
2266            let z: Complex64 = scalar.into();
2267            self * z
2268        } else {
2269            self * scalar.real()
2270        }
2271    }
2272}
2273
// Unit tests for this module. They are compiled only under `cfg(test)`, and
// their source lives in a separate `tests` module file.
#[cfg(test)]
mod tests;