tensor4all_tensorbackend/storage.rs
1use anyhow::{anyhow, ensure, Result};
2use num_complex::{Complex64, ComplexFloat};
3use std::ops::Mul;
4use std::sync::Arc;
5
6/// Trait for scalar types that can be stored in [`Storage`].
7///
8/// This enables generic constructors such as [`Storage::from_dense_col_major`]
9/// and [`Storage::from_diag_col_major`]. Implemented for `f64` and `Complex64`.
10///
11/// # Examples
12///
13/// ```
14/// use tensor4all_tensorbackend::{Storage, StorageScalar};
15///
16/// // Using the generic constructor -- scalar type is inferred from data
17/// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
18/// assert!(s.is_f64());
19///
20/// use num_complex::Complex64;
21/// let c = Storage::from_dense_col_major(
22/// vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
23/// &[2],
24/// ).unwrap();
25/// assert!(c.is_c64());
26/// ```
27pub trait StorageScalar: Clone + Send + Sync + 'static {
28 /// Build a dense [`Storage`] from column-major data.
29 fn build_dense_storage(data: Vec<Self>, logical_dims: &[usize]) -> Result<Storage>;
30 /// Build a diagonal [`Storage`] from diagonal payload data.
31 fn build_diag_storage(diag_data: Vec<Self>, logical_rank: usize) -> Result<Storage>;
32 /// Build a structured [`Storage`] from explicit payload metadata.
33 fn build_structured_storage(
34 data: Vec<Self>,
35 payload_dims: Vec<usize>,
36 strides: Vec<isize>,
37 axis_classes: Vec<usize>,
38 ) -> Result<Storage>;
39}
40
41impl StorageScalar for f64 {
42 fn build_dense_storage(data: Vec<Self>, logical_dims: &[usize]) -> Result<Storage> {
43 Storage::validate_dense_len(&data, logical_dims, "dense f64 payload")?;
44 Ok(Storage::from_repr(StorageRepr::F64(
45 StructuredStorage::from_dense_col_major(data, logical_dims)?,
46 )))
47 }
48 fn build_diag_storage(diag_data: Vec<Self>, logical_rank: usize) -> Result<Storage> {
49 Ok(Storage::from_repr(StorageRepr::F64(
50 StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
51 )))
52 }
53 fn build_structured_storage(
54 data: Vec<Self>,
55 payload_dims: Vec<usize>,
56 strides: Vec<isize>,
57 axis_classes: Vec<usize>,
58 ) -> Result<Storage> {
59 Ok(Storage::from_repr(StorageRepr::F64(
60 StructuredStorage::new(data, payload_dims, strides, axis_classes)?,
61 )))
62 }
63}
64
65impl StorageScalar for Complex64 {
66 fn build_dense_storage(data: Vec<Self>, logical_dims: &[usize]) -> Result<Storage> {
67 Storage::validate_dense_len(&data, logical_dims, "dense c64 payload")?;
68 Ok(Storage::from_repr(StorageRepr::C64(
69 StructuredStorage::from_dense_col_major(data, logical_dims)?,
70 )))
71 }
72 fn build_diag_storage(diag_data: Vec<Self>, logical_rank: usize) -> Result<Storage> {
73 Ok(Storage::from_repr(StorageRepr::C64(
74 StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
75 )))
76 }
77 fn build_structured_storage(
78 data: Vec<Self>,
79 payload_dims: Vec<usize>,
80 strides: Vec<isize>,
81 axis_classes: Vec<usize>,
82 ) -> Result<Storage> {
83 Ok(Storage::from_repr(StorageRepr::C64(
84 StructuredStorage::new(data, payload_dims, strides, axis_classes)?,
85 )))
86 }
87}
88
89pub(crate) fn col_major_strides(dims: &[usize]) -> Result<Vec<isize>> {
90 let mut strides = Vec::with_capacity(dims.len());
91 let mut stride = 1isize;
92 for &dim in dims {
93 strides.push(stride);
94 let dim = isize::try_from(dim)
95 .map_err(|_| anyhow!("column-major stride overflow for dims {dims:?}"))?;
96 stride = stride
97 .checked_mul(dim)
98 .ok_or_else(|| anyhow!("column-major stride overflow for dims {dims:?}"))?;
99 }
100 Ok(strides)
101}
102
103fn validate_canonical_axis_classes(axis_classes: &[usize]) -> Result<()> {
104 let mut next_class = 0usize;
105 for &class_id in axis_classes {
106 ensure!(
107 class_id <= next_class,
108 "axis_classes must be canonical first-appearance labels, got {axis_classes:?}"
109 );
110 if class_id == next_class {
111 next_class += 1;
112 }
113 }
114 Ok(())
115}
116
117fn required_storage_len(dims: &[usize], strides: &[isize]) -> Result<usize> {
118 if dims.is_empty() {
119 return Ok(1);
120 }
121 if dims.contains(&0) {
122 return Ok(0);
123 }
124 ensure!(
125 dims.len() == strides.len(),
126 "payload dims {:?} and strides {:?} must have the same rank",
127 dims,
128 strides
129 );
130
131 let mut max_offset = 0usize;
132 for (&dim, &stride) in dims.iter().zip(strides.iter()) {
133 ensure!(
134 stride >= 0,
135 "negative strides are not supported in StructuredStorage: {strides:?}"
136 );
137 if dim > 1 {
138 max_offset = max_offset
139 .checked_add((dim - 1) * usize::try_from(stride).unwrap_or(usize::MAX))
140 .ok_or_else(|| anyhow!("payload stride overflow for dims {dims:?}"))?;
141 }
142 }
143 Ok(max_offset + 1)
144}
145
/// Expands payload dims to logical dims: logical axis `k` has length
/// `payload_dims[axis_classes[k]]`.
///
/// # Panics
///
/// Panics if any entry of `axis_classes` is out of range for `payload_dims`.
fn logical_dims_from_axis_classes(payload_dims: &[usize], axis_classes: &[usize]) -> Vec<usize> {
    let mut logical = Vec::with_capacity(axis_classes.len());
    for &payload_axis in axis_classes {
        logical.push(payload_dims[payload_axis]);
    }
    logical
}
152
/// Decomposes a column-major linear offset into a multi-index over `dims`.
///
/// The first axis varies fastest. A zero-sized axis gets coordinate `0` and
/// consumes nothing from the remaining offset.
fn col_major_multi_index(linear: usize, dims: &[usize]) -> Vec<usize> {
    let mut remainder = linear;
    dims.iter()
        .map(|&extent| {
            if extent == 0 {
                return 0;
            }
            let digit = remainder % extent;
            remainder /= extent;
            digit
        })
        .collect()
}
165
/// Computes the buffer offset `sum(index[k] * strides[k])` for a multi-index.
///
/// Only pairs present in both slices contribute. A negative stride is treated
/// as `usize::MAX`. Callers are expected to have rejected negative strides
/// already.
fn offset_from_strides(index: &[usize], strides: &[isize]) -> usize {
    index
        .iter()
        .zip(strides)
        .fold(0usize, |acc, (&coord, &stride)| {
            acc + coord * usize::try_from(stride).unwrap_or(usize::MAX)
        })
}
173
/// Structured tensor snapshot storage.
///
/// The payload tensor is described by `data`, `payload_dims` and `strides`.
/// `axis_classes` records which payload axis each logical axis reads from.
/// Logical flat buffers use column-major order.
///
/// With `axis_classes = [0, 1, ..., rank-1]` every logical axis has its own
/// payload axis, so the tensor is **dense**. With `axis_classes = [0, 0, ..., 0]`
/// every logical axis shares a single payload axis, so the tensor is
/// **diagonal** and only its diagonal entries are stored.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::StructuredStorage;
///
/// // Dense 2x3 storage, column-major: [[1,3,5],[2,4,6]]
/// let dense = StructuredStorage::from_dense_col_major(
///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
///     &[2, 3],
/// )
/// .unwrap();
/// assert!(dense.is_dense());
/// assert!(!dense.is_diag());
/// assert_eq!(dense.logical_rank(), 2);
/// assert_eq!(dense.logical_dims(), vec![2, 3]);
///
/// // Diagonal 3x3 storage keeps only the three diagonal values.
/// let diag = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 2).unwrap();
/// assert!(diag.is_diag());
/// assert_eq!(diag.logical_dims(), vec![3, 3]);
/// assert_eq!(diag.len(), 3);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct StructuredStorage<T> {
    /// Payload element buffer, addressed through `strides`.
    data: Vec<T>,
    /// Extent of each payload axis.
    payload_dims: Vec<usize>,
    /// Buffer stride of each payload axis, in elements.
    strides: Vec<isize>,
    /// Payload axis used by each logical axis (canonical first-appearance labels).
    axis_classes: Vec<usize>,
}
212
213impl<T> StructuredStorage<T> {
214 /// Creates a structured payload snapshot from explicit payload metadata.
215 ///
216 /// `payload_dims` and `strides` describe the compressed payload tensor,
217 /// while `axis_classes` maps logical axes onto payload axes in canonical
218 /// first-appearance order.
219 ///
220 /// # Errors
221 ///
222 /// Returns an error if:
223 /// - `axis_classes` is not in canonical first-appearance form
224 /// - `payload_dims` rank does not match `axis_classes`
225 /// - `strides` rank does not match `payload_dims`
226 /// - `data` length does not match the required storage length
227 ///
228 /// # Examples
229 ///
230 /// ```
231 /// use tensor4all_tensorbackend::StructuredStorage;
232 ///
233 /// // Dense 2x3 with explicit column-major strides
234 /// let s = StructuredStorage::new(
235 /// vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
236 /// vec![2, 3], // payload_dims
237 /// vec![1, 2], // column-major strides
238 /// vec![0, 1], // axis_classes: each axis is independent
239 /// ).unwrap();
240 /// assert!(s.is_dense());
241 /// assert_eq!(s.len(), 6);
242 /// ```
243 pub fn new(
244 data: Vec<T>,
245 payload_dims: Vec<usize>,
246 strides: Vec<isize>,
247 axis_classes: Vec<usize>,
248 ) -> Result<Self> {
249 validate_canonical_axis_classes(&axis_classes)?;
250 ensure!(
251 payload_dims.len()
252 == axis_classes
253 .iter()
254 .copied()
255 .max()
256 .map(|value| value + 1)
257 .unwrap_or(0),
258 "payload rank {} does not match axis_classes {:?}",
259 payload_dims.len(),
260 axis_classes
261 );
262 ensure!(
263 strides.len() == payload_dims.len(),
264 "payload dims {:?} and strides {:?} must have the same rank",
265 payload_dims,
266 strides
267 );
268 let required_len = required_storage_len(&payload_dims, &strides)?;
269 ensure!(
270 data.len() == required_len,
271 "payload storage len {} does not match required len {} for dims {:?} and strides {:?}",
272 data.len(),
273 required_len,
274 payload_dims,
275 strides
276 );
277 Ok(Self {
278 data,
279 payload_dims,
280 strides,
281 axis_classes,
282 })
283 }
284
285 /// Creates a dense structured snapshot from column-major logical data.
286 ///
287 /// # Errors
288 ///
289 /// Returns an error if `data.len()` does not equal the product of
290 /// `logical_dims`, or if column-major stride computation overflows.
291 ///
292 /// # Examples
293 ///
294 /// ```
295 /// use tensor4all_tensorbackend::StructuredStorage;
296 ///
297 /// let s = StructuredStorage::from_dense_col_major(vec![10.0, 20.0, 30.0, 40.0], &[2, 2]).unwrap();
298 /// assert!(s.is_dense());
299 /// assert_eq!(s.data(), &[10.0, 20.0, 30.0, 40.0]);
300 /// ```
301 pub fn from_dense_col_major(data: Vec<T>, logical_dims: &[usize]) -> Result<Self> {
302 let payload_dims = logical_dims.to_vec();
303 let strides = col_major_strides(&payload_dims)?;
304 let axis_classes = (0..logical_dims.len()).collect();
305 Self::new(data, payload_dims, strides, axis_classes)
306 }
307
308 /// Creates a diagonal structured snapshot from column-major diagonal data.
309 ///
310 /// The resulting tensor has `logical_rank` axes, each of size `diag_data.len()`.
311 /// Only the diagonal entries are stored.
312 ///
313 /// # Errors
314 ///
315 /// Returns an error if `logical_rank` is zero and the data does not contain
316 /// exactly one scalar value, or if column-major stride computation overflows.
317 ///
318 /// # Examples
319 ///
320 /// ```
321 /// use tensor4all_tensorbackend::StructuredStorage;
322 ///
323 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 2).unwrap();
324 /// assert!(d.is_diag());
325 /// assert_eq!(d.logical_dims(), vec![3, 3]);
326 /// assert_eq!(d.data(), &[1.0, 2.0, 3.0]);
327 /// ```
328 pub fn from_diag_col_major(diag_data: Vec<T>, logical_rank: usize) -> Result<Self> {
329 let payload_dims = if logical_rank == 0 {
330 vec![]
331 } else {
332 vec![diag_data.len()]
333 };
334 let strides = col_major_strides(&payload_dims)?;
335 let axis_classes = vec![0; logical_rank];
336 Self::new(diag_data, payload_dims, strides, axis_classes)
337 }
338
339 /// Returns the payload data buffer as a slice.
340 ///
341 /// # Examples
342 ///
343 /// ```
344 /// use tensor4all_tensorbackend::StructuredStorage;
345 ///
346 /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0], &[2]).unwrap();
347 /// assert_eq!(s.data(), &[1.0, 2.0]);
348 /// ```
349 pub fn data(&self) -> &[T] {
350 &self.data
351 }
352
353 /// Returns the payload tensor dimensions.
354 ///
355 /// For dense tensors, this equals the logical dimensions. For diagonal
356 /// tensors, this is a single-element slice `[n]` where `n` is the diagonal
357 /// length.
358 ///
359 /// # Examples
360 ///
361 /// ```
362 /// use tensor4all_tensorbackend::StructuredStorage;
363 ///
364 /// let s = StructuredStorage::from_dense_col_major(vec![0.0; 6], &[2, 3]).unwrap();
365 /// assert_eq!(s.payload_dims(), &[2, 3]);
366 ///
367 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 3).unwrap();
368 /// assert_eq!(d.payload_dims(), &[2]);
369 /// ```
370 pub fn payload_dims(&self) -> &[usize] {
371 &self.payload_dims
372 }
373
374 /// Returns the payload tensor strides.
375 ///
376 /// # Examples
377 ///
378 /// ```
379 /// use tensor4all_tensorbackend::StructuredStorage;
380 ///
381 /// // Column-major 2x3: strides are [1, 2]
382 /// let s = StructuredStorage::from_dense_col_major(vec![0.0; 6], &[2, 3]).unwrap();
383 /// assert_eq!(s.strides(), &[1, 2]);
384 /// ```
385 pub fn strides(&self) -> &[isize] {
386 &self.strides
387 }
388
389 /// Returns the canonical logical-to-payload axis classes.
390 ///
391 /// Each entry maps a logical axis to a payload axis index. Repeated values
392 /// indicate axes that share the same payload dimension (e.g., diagonal).
393 ///
394 /// # Examples
395 ///
396 /// ```
397 /// use tensor4all_tensorbackend::StructuredStorage;
398 ///
399 /// let dense = StructuredStorage::from_dense_col_major(vec![0.0; 4], &[2, 2]).unwrap();
400 /// assert_eq!(dense.axis_classes(), &[0, 1]);
401 ///
402 /// let diag = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
403 /// assert_eq!(diag.axis_classes(), &[0, 0]);
404 /// ```
405 pub fn axis_classes(&self) -> &[usize] {
406 &self.axis_classes
407 }
408
409 /// Returns the logical dimensions derived from `payload_dims` and `axis_classes`.
410 ///
411 /// # Examples
412 ///
413 /// ```
414 /// use tensor4all_tensorbackend::StructuredStorage;
415 ///
416 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 3).unwrap();
417 /// assert_eq!(d.logical_dims(), vec![3, 3, 3]);
418 /// ```
419 pub fn logical_dims(&self) -> Vec<usize> {
420 logical_dims_from_axis_classes(&self.payload_dims, &self.axis_classes)
421 }
422
423 /// Returns the logical rank (number of logical axes).
424 ///
425 /// # Examples
426 ///
427 /// ```
428 /// use tensor4all_tensorbackend::StructuredStorage;
429 ///
430 /// let s = StructuredStorage::from_dense_col_major(vec![0.0; 6], &[2, 3]).unwrap();
431 /// assert_eq!(s.logical_rank(), 2);
432 /// ```
433 pub fn logical_rank(&self) -> usize {
434 self.axis_classes.len()
435 }
436
437 /// Returns `true` when the logical tensor is dense (each logical axis maps
438 /// to a unique payload axis).
439 ///
440 /// # Examples
441 ///
442 /// ```
443 /// use tensor4all_tensorbackend::StructuredStorage;
444 ///
445 /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0], &[2]).unwrap();
446 /// assert!(s.is_dense());
447 ///
448 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
449 /// assert!(!d.is_dense());
450 /// ```
451 pub fn is_dense(&self) -> bool {
452 self.axis_classes
453 .iter()
454 .copied()
455 .eq(0..self.axis_classes.len())
456 }
457
458 /// Returns `true` when the logical tensor is diagonal (rank >= 2 and all
459 /// logical axes map to the same payload axis).
460 ///
461 /// # Examples
462 ///
463 /// ```
464 /// use tensor4all_tensorbackend::StructuredStorage;
465 ///
466 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
467 /// assert!(d.is_diag());
468 ///
469 /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0], &[2]).unwrap();
470 /// assert!(!s.is_diag());
471 /// ```
472 pub fn is_diag(&self) -> bool {
473 self.logical_rank() >= 2 && self.axis_classes.iter().all(|&class_id| class_id == 0)
474 }
475
476 /// Returns the payload buffer length.
477 ///
478 /// # Examples
479 ///
480 /// ```
481 /// use tensor4all_tensorbackend::StructuredStorage;
482 ///
483 /// let dense = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
484 /// assert_eq!(dense.len(), 3);
485 ///
486 /// let diag = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
487 /// assert_eq!(diag.len(), 2);
488 /// ```
489 pub fn len(&self) -> usize {
490 self.data.len()
491 }
492
493 /// Returns `true` when the payload buffer is empty.
494 ///
495 /// # Examples
496 ///
497 /// ```
498 /// use tensor4all_tensorbackend::StructuredStorage;
499 ///
500 /// let empty = StructuredStorage::from_dense_col_major(Vec::<f64>::new(), &[0]).unwrap();
501 /// assert!(empty.is_empty());
502 ///
503 /// let non_empty = StructuredStorage::from_dense_col_major(vec![1.0], &[1]).unwrap();
504 /// assert!(!non_empty.is_empty());
505 /// ```
506 pub fn is_empty(&self) -> bool {
507 self.data.is_empty()
508 }
509
510 /// Returns a borrowed view when the logical tensor is dense and the
511 /// payload is already stored contiguously in column-major order.
512 ///
513 /// Returns `None` for diagonal or non-contiguous payloads.
514 ///
515 /// # Examples
516 ///
517 /// ```
518 /// use tensor4all_tensorbackend::StructuredStorage;
519 ///
520 /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
521 /// assert_eq!(s.dense_col_major_view_if_contiguous(), Some(&[1.0, 2.0, 3.0][..]));
522 ///
523 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
524 /// assert_eq!(d.dense_col_major_view_if_contiguous(), None);
525 /// ```
526 pub fn dense_col_major_view_if_contiguous(&self) -> Option<&[T]> {
527 if self.is_dense()
528 && matches!(col_major_strides(&self.payload_dims), Ok(strides) if strides == self.strides)
529 {
530 Some(&self.data)
531 } else {
532 None
533 }
534 }
535
536 /// Returns a borrowed compact-payload view when the payload is already
537 /// stored contiguously in column-major order.
538 pub fn payload_col_major_view_if_contiguous(&self) -> Option<&[T]> {
539 if matches!(col_major_strides(&self.payload_dims), Ok(strides) if strides == self.strides) {
540 Some(&self.data)
541 } else {
542 None
543 }
544 }
545}
546
547impl<T: Clone> StructuredStorage<T> {
548 /// Materializes the payload tensor as a contiguous column-major buffer.
549 ///
550 /// If the payload is already column-major, returns a clone.
551 ///
552 /// # Examples
553 ///
554 /// ```
555 /// use tensor4all_tensorbackend::StructuredStorage;
556 ///
557 /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
558 /// assert_eq!(s.payload_col_major_vec(), vec![1.0, 2.0, 3.0, 4.0]);
559 /// ```
560 pub fn payload_col_major_vec(&self) -> Vec<T> {
561 let payload_len: usize = self.payload_dims.iter().product();
562 if payload_len == 0 {
563 return Vec::new();
564 }
565 if matches!(col_major_strides(&self.payload_dims), Ok(strides) if strides == self.strides) {
566 return self.data.clone();
567 }
568
569 (0..payload_len)
570 .map(|linear| {
571 let index = col_major_multi_index(linear, &self.payload_dims);
572 let offset = offset_from_strides(&index, &self.strides);
573 self.data[offset].clone()
574 })
575 .collect()
576 }
577
578 /// Returns a copy of the storage with logical axes permuted.
579 ///
580 /// # Errors
581 ///
582 /// Returns an error if `perm` is not a valid permutation of the logical axes.
583 ///
584 /// # Examples
585 ///
586 /// ```
587 /// use tensor4all_tensorbackend::StructuredStorage;
588 ///
589 /// // Diagonal 3x3x3 tensor; permute axes (identity for diag is always valid)
590 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0, 3.0], 3).unwrap();
591 /// let p = d.permute_logical_axes(&[2, 0, 1]).unwrap();
592 /// // Diagonal: all axes share the same dimension, so dims stay the same
593 /// assert_eq!(p.logical_dims(), vec![3, 3, 3]);
594 /// assert!(p.is_diag());
595 /// ```
596 pub fn permute_logical_axes(&self, perm: &[usize]) -> Result<Self> {
597 ensure!(
598 perm.len() == self.axis_classes.len(),
599 "logical permutation length {} must match logical rank {}",
600 perm.len(),
601 self.axis_classes.len()
602 );
603 let mut seen = vec![false; self.axis_classes.len()];
604 let axis_classes = perm
605 .iter()
606 .map(|&index| {
607 ensure!(
608 index < self.axis_classes.len(),
609 "logical permutation axis {index} is out of range for rank {}",
610 self.axis_classes.len()
611 );
612 ensure!(!seen[index], "logical permutation repeats axis {index}");
613 seen[index] = true;
614 Ok(self.axis_classes[index])
615 })
616 .collect::<Result<Vec<_>>>()?;
617 Self::new(
618 self.data.clone(),
619 self.payload_dims.clone(),
620 self.strides.clone(),
621 axis_classes,
622 )
623 }
624}
625
626impl<T: Copy> StructuredStorage<T> {
627 /// Maps payload elements while preserving payload metadata and axis classes.
628 ///
629 /// # Examples
630 ///
631 /// ```
632 /// use tensor4all_tensorbackend::StructuredStorage;
633 ///
634 /// let s = StructuredStorage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
635 /// let doubled = s.map_copy(|x| x * 2.0);
636 /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
637 /// ```
638 pub fn map_copy<U>(&self, mut f: impl FnMut(T) -> U) -> StructuredStorage<U> {
639 StructuredStorage {
640 data: self.data.iter().copied().map(&mut f).collect(),
641 payload_dims: self.payload_dims.clone(),
642 strides: self.strides.clone(),
643 axis_classes: self.axis_classes.clone(),
644 }
645 }
646}
647
648impl<T: Copy + Default> StructuredStorage<T> {
649 /// Materializes the logical tensor as a contiguous column-major dense buffer.
650 ///
651 /// Repeated entries in `axis_classes` encode equality constraints between
652 /// logical axes. Logical indices that violate those constraints are
653 /// structural zeros in the dense materialization.
654 ///
655 /// # Examples
656 ///
657 /// ```
658 /// use tensor4all_tensorbackend::StructuredStorage;
659 ///
660 /// // Diagonal [1, 2] in 2x2 becomes [1, 0, 0, 2] column-major
661 /// let d = StructuredStorage::from_diag_col_major(vec![1.0, 2.0], 2).unwrap();
662 /// assert_eq!(d.logical_dense_col_major_vec(), vec![1.0, 0.0, 0.0, 2.0]);
663 /// ```
664 pub fn logical_dense_col_major_vec(&self) -> Vec<T> {
665 let logical_dims = self.logical_dims();
666 let logical_len: usize = logical_dims.iter().product();
667 if logical_len == 0 {
668 return Vec::new();
669 }
670 if let Some(view) = self.dense_col_major_view_if_contiguous() {
671 return view.to_vec();
672 }
673 if self.is_dense() {
674 return self.payload_col_major_vec();
675 }
676
677 let payload_rank = self.payload_dims.len();
678 (0..logical_len)
679 .map(|linear| {
680 let logical_index = col_major_multi_index(linear, &logical_dims);
681 let mut payload_index = vec![0usize; payload_rank];
682 let mut seen = vec![false; payload_rank];
683 for (&value, &class_id) in logical_index.iter().zip(self.axis_classes.iter()) {
684 if seen[class_id] {
685 if payload_index[class_id] != value {
686 return T::default();
687 }
688 } else {
689 payload_index[class_id] = value;
690 seen[class_id] = true;
691 }
692 }
693 let offset = offset_from_strides(&payload_index, &self.strides);
694 self.data[offset]
695 })
696 .collect()
697 }
698}
699
700/// Storage backend for tensor data.
701///
702/// Public callers interact with this opaque wrapper through constructors and
703/// high-level query/materialization methods.
704///
705/// # Examples
706///
707/// ```
708/// use tensor4all_tensorbackend::Storage;
709///
710/// // Dense 2x3 matrix stored column-major: [[1,2,3],[4,5,6]]
711/// let data = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
712/// let s = Storage::from_dense_col_major(data, &[2, 3]).unwrap();
713/// assert!(s.is_f64());
714/// assert!(!s.is_complex());
715///
716/// // Diagonal storage: 2x2 identity-like diagonal
717/// let diag = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
718/// assert!(diag.is_f64());
719/// ```
720#[derive(Debug, Clone)]
721pub struct Storage(pub(crate) StorageRepr);
722
/// Compact layout category of a [`Storage`].
///
/// Lets callers tell dense logical payloads, diagonal/copy payloads and
/// general structured payloads apart without exposing the internal storage enum.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::{Storage, StorageKind};
///
/// let dense = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
/// assert_eq!(dense.storage_kind(), StorageKind::Dense);
///
/// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
/// assert_eq!(diag.storage_kind(), StorageKind::Diagonal);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Dense logical payload layout.
    Dense,
    /// Diagonal (copy-tensor) payload layout.
    Diagonal,
    /// Any other structured payload layout with repeated axis classes.
    Structured,
}
748
749/// Errors returned by storage payload and elementwise operations.
750///
751/// Use this to distinguish scalar-kind mismatches, length mismatches, and
752/// invalid structured-storage metadata from general backend failures.
753///
754/// # Examples
755///
756/// ```
757/// use tensor4all_tensorbackend::{Storage, StorageError};
758///
759/// let storage = Storage::from_dense_col_major(vec![1.0_f64], &[1]).unwrap();
760/// let err = storage.payload_c64_col_major_vec().unwrap_err();
761/// assert!(matches!(err, StorageError::ScalarKindMismatch { .. }));
762/// ```
763#[derive(Debug, thiserror::Error)]
764pub enum StorageError {
765 /// The storage scalar kind did not match the requested operation.
766 #[error("expected {expected} storage when {operation}, got {actual}")]
767 ScalarKindMismatch {
768 /// The scalar kind that the caller requested.
769 expected: &'static str,
770 /// The scalar kind actually stored.
771 actual: &'static str,
772 /// Human-readable operation description.
773 operation: &'static str,
774 },
775 /// Two storages had different payload lengths for an elementwise operation.
776 #[error("storage lengths must match for {operation}: {left} != {right}")]
777 LengthMismatch {
778 /// Name of the operation being performed.
779 operation: &'static str,
780 /// Left-hand payload length.
781 left: usize,
782 /// Right-hand payload length.
783 right: usize,
784 },
785 /// Structured storage metadata was invalid after an operation.
786 #[error("invalid structured storage: {0}")]
787 InvalidStructuredStorage(String),
788 /// The requested operation does not support the provided storage kinds.
789 #[error("storage types are not supported for {operation}: {left} vs {right}")]
790 OperationNotSupported {
791 /// Name of the operation being performed.
792 operation: &'static str,
793 /// Left-hand storage kind.
794 left: &'static str,
795 /// Right-hand storage kind.
796 right: &'static str,
797 },
798 /// The requested operation requires real scalars but at least one scalar was complex.
799 #[error("expected real scalars in {operation} branch: a={a}, b={b}")]
800 RealScalarRequired {
801 /// Name of the operation being performed.
802 operation: &'static str,
803 /// Left scalar display string.
804 a: String,
805 /// Right scalar display string.
806 b: String,
807 },
808}
809
810/// Result type returned by storage methods that can fail with [`StorageError`].
811pub type StorageResult<T> = std::result::Result<T, StorageError>;
812
813#[derive(Debug, Clone)]
814pub(crate) enum StorageRepr {
815 /// Storage with f64 elements.
816 F64(StructuredStorage<f64>),
817 /// Storage with Complex64 elements.
818 C64(StructuredStorage<Complex64>),
819}
820
821fn storage_scalar_kind(repr: &StorageRepr) -> &'static str {
822 match repr {
823 StorageRepr::F64(_) => "f64",
824 StorageRepr::C64(_) => "Complex64",
825 }
826}
827
828/// Types that can be computed as the result of a reduction over `Storage`.
829///
830/// This lets callers write `let s: T = tensor.sum();` without matching on
831/// the storage variant. Implemented for `f64` and `Complex64`.
832///
833/// # Examples
834///
835/// ```
836/// use tensor4all_tensorbackend::{Storage, SumFromStorage};
837///
838/// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
839/// let total: f64 = f64::sum_from_storage(&s);
840/// assert!((total - 6.0).abs() < 1e-10);
841/// ```
842pub trait SumFromStorage: Sized {
843 /// Compute the sum of all elements in the storage.
844 fn sum_from_storage(storage: &Storage) -> Self;
845}
846
847impl SumFromStorage for f64 {
848 fn sum_from_storage(storage: &Storage) -> Self {
849 match &storage.0 {
850 StorageRepr::F64(v) => v.data().iter().copied().sum(),
851 StorageRepr::C64(v) => v.data().iter().map(|z| z.re).sum(),
852 }
853 }
854}
855
856impl SumFromStorage for Complex64 {
857 fn sum_from_storage(storage: &Storage) -> Self {
858 match &storage.0 {
859 StorageRepr::F64(v) => Complex64::new(v.data().iter().copied().sum(), 0.0),
860 StorageRepr::C64(v) => v.data().iter().copied().sum(),
861 }
862 }
863}
864
865// AnyScalar is now in its own module
866pub use crate::any_scalar::AnyScalar;
867
868impl Storage {
869 pub(crate) fn from_repr(repr: StorageRepr) -> Self {
870 Self(repr)
871 }
872
873 fn invalid_storage_error(err: anyhow::Error) -> StorageError {
874 StorageError::InvalidStructuredStorage(err.to_string())
875 }
876
877 #[cfg(test)]
878 pub(crate) fn repr(&self) -> &StorageRepr {
879 &self.0
880 }
881
882 fn validate_dense_len<T>(data: &[T], logical_dims: &[usize], label: &str) -> Result<()> {
883 let expected_len: usize = logical_dims.iter().product();
884 ensure!(
885 data.len() == expected_len,
886 "{label} len {} does not match logical dims {:?} (expected {})",
887 data.len(),
888 logical_dims,
889 expected_len
890 );
891 Ok(())
892 }
893
894 /// Create dense storage from column-major logical values (generic over scalar type).
895 ///
896 /// The scalar type is inferred from the `data` argument.
897 ///
898 /// # Errors
899 ///
900 /// Returns an error if the requested dense metadata overflows.
901 ///
902 /// # Examples
903 ///
904 /// ```
905 /// use tensor4all_tensorbackend::Storage;
906 ///
907 /// // 2x2 matrix, column-major: [[1,3],[2,4]]
908 /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
909 /// assert!(s.is_f64());
910 /// assert!(s.is_dense());
911 /// assert_eq!(s.len(), 4);
912 /// ```
913 pub fn from_dense_col_major<T: StorageScalar>(
914 data: Vec<T>,
915 logical_dims: &[usize],
916 ) -> Result<Self> {
917 T::build_dense_storage(data, logical_dims)
918 }
919
920 /// Create diagonal storage from column-major diagonal payload values (generic over scalar type).
921 ///
922 /// Creates a rank-2 diagonal storage by default. The scalar type is
923 /// inferred from `diag_data`.
924 ///
925 /// # Errors
926 ///
927 /// Currently infallible for valid data, but returns `Result` for consistency.
928 ///
929 /// # Examples
930 ///
931 /// ```
932 /// use tensor4all_tensorbackend::Storage;
933 ///
934 /// let s = Storage::from_diag_col_major(vec![1.0_f64, 2.0, 3.0], 2).unwrap();
935 /// assert!(s.is_diag());
936 /// assert!(s.is_f64());
937 /// assert_eq!(s.len(), 3);
938 /// ```
939 pub fn from_diag_col_major<T: StorageScalar>(
940 diag_data: Vec<T>,
941 logical_rank: usize,
942 ) -> Result<Self> {
943 T::build_diag_storage(diag_data, logical_rank)
944 }
945
946 /// Create a new 1D zero-initialized dense storage (generic over scalar type).
947 ///
948 /// # Examples
949 ///
950 /// ```
951 /// use tensor4all_tensorbackend::Storage;
952 ///
953 /// let s = Storage::new_dense::<f64>(5).unwrap();
954 /// assert!(s.is_dense());
955 /// assert_eq!(s.len(), 5);
956 /// assert!((s.max_abs()).abs() < 1e-10);
957 /// ```
958 pub fn new_dense<T: StorageScalar + Default>(size: usize) -> StorageResult<Self> {
959 Self::from_dense_col_major(vec![T::default(); size], &[size])
960 .map_err(Self::invalid_storage_error)
961 }
962
963 /// Create a new diagonal storage with the given diagonal data (generic over scalar type).
964 ///
965 /// # Errors
966 ///
967 /// Returns an error if diagonal metadata is invalid.
968 ///
969 /// # Examples
970 ///
971 /// ```
972 /// use tensor4all_tensorbackend::Storage;
973 ///
974 /// let s = Storage::new_diag(vec![1.0_f64, 2.0, 3.0]).unwrap();
975 /// assert!(s.is_diag());
976 /// assert!(s.is_f64());
977 /// ```
978 pub fn new_diag<T: StorageScalar>(diag_data: Vec<T>) -> StorageResult<Self> {
979 Self::from_diag_col_major(diag_data, 2).map_err(Self::invalid_storage_error)
980 }
981
982 /// Create a new structured storage (generic over scalar type).
983 ///
984 /// # Errors
985 ///
986 /// Returns an error if the structured metadata is inconsistent (see
987 /// [`StructuredStorage::new`] for details).
988 ///
989 /// # Examples
990 ///
991 /// ```
992 /// use tensor4all_tensorbackend::Storage;
993 ///
994 /// // Diagonal-like structured storage: axis_classes = [0, 0]
995 /// let s = Storage::new_structured(
996 /// vec![1.0_f64, 2.0],
997 /// vec![2], // payload_dims
998 /// vec![1], // strides
999 /// vec![0, 0], // axis_classes: both axes map to payload axis 0
1000 /// ).unwrap();
1001 /// assert!(s.is_diag());
1002 /// ```
1003 pub fn new_structured<T: StorageScalar>(
1004 data: Vec<T>,
1005 payload_dims: Vec<usize>,
1006 strides: Vec<isize>,
1007 axis_classes: Vec<usize>,
1008 ) -> Result<Self> {
1009 T::build_structured_storage(data, payload_dims, strides, axis_classes)
1010 }
1011
1012 /// Create dense f64 storage from column-major logical values.
1013 ///
1014 /// # Errors
1015 ///
1016 /// Returns an error if `data.len()` does not match the product of `logical_dims`.
1017 ///
1018 /// # Examples
1019 ///
1020 /// ```
1021 /// use tensor4all_tensorbackend::Storage;
1022 ///
1023 /// let s = Storage::from_dense_f64_col_major(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1024 /// assert!(s.is_f64());
1025 /// assert!(s.is_dense());
1026 /// ```
1027 pub fn from_dense_f64_col_major(data: Vec<f64>, logical_dims: &[usize]) -> Result<Self> {
1028 Self::validate_dense_len(&data, logical_dims, "dense f64 payload")?;
1029 Ok(Self::from_repr(StorageRepr::F64(
1030 StructuredStorage::from_dense_col_major(data, logical_dims)?,
1031 )))
1032 }
1033
1034 /// Create dense Complex64 storage from column-major logical values.
1035 ///
1036 /// # Errors
1037 ///
1038 /// Returns an error if `data.len()` does not match the product of `logical_dims`.
1039 ///
1040 /// # Examples
1041 ///
1042 /// ```
1043 /// use tensor4all_tensorbackend::Storage;
1044 /// use num_complex::Complex64;
1045 ///
1046 /// let data = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];
1047 /// let s = Storage::from_dense_c64_col_major(data, &[2]).unwrap();
1048 /// assert!(s.is_c64());
1049 /// assert!(s.is_dense());
1050 /// ```
1051 pub fn from_dense_c64_col_major(data: Vec<Complex64>, logical_dims: &[usize]) -> Result<Self> {
1052 Self::validate_dense_len(&data, logical_dims, "dense c64 payload")?;
1053 Ok(Self::from_repr(StorageRepr::C64(
1054 StructuredStorage::from_dense_col_major(data, logical_dims)?,
1055 )))
1056 }
1057
1058 /// Create diagonal f64 storage from column-major diagonal payload values.
1059 ///
1060 /// # Examples
1061 ///
1062 /// ```
1063 /// use tensor4all_tensorbackend::Storage;
1064 ///
1065 /// let s = Storage::from_diag_f64_col_major(vec![1.0, 2.0], 2).unwrap();
1066 /// assert!(s.is_diag());
1067 /// assert!(s.is_f64());
1068 /// ```
1069 pub fn from_diag_f64_col_major(diag_data: Vec<f64>, logical_rank: usize) -> Result<Self> {
1070 Ok(Self::from_repr(StorageRepr::F64(
1071 StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
1072 )))
1073 }
1074
1075 /// Create diagonal Complex64 storage from column-major diagonal payload values.
1076 ///
1077 /// # Examples
1078 ///
1079 /// ```
1080 /// use tensor4all_tensorbackend::Storage;
1081 /// use num_complex::Complex64;
1082 ///
1083 /// let data = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];
1084 /// let s = Storage::from_diag_c64_col_major(data, 2).unwrap();
1085 /// assert!(s.is_diag());
1086 /// assert!(s.is_c64());
1087 /// ```
1088 pub fn from_diag_c64_col_major(diag_data: Vec<Complex64>, logical_rank: usize) -> Result<Self> {
1089 Ok(Self::from_repr(StorageRepr::C64(
1090 StructuredStorage::from_diag_col_major(diag_data, logical_rank)?,
1091 )))
1092 }
1093
1094 /// Check if this storage is logically dense.
1095 ///
1096 /// # Examples
1097 ///
1098 /// ```
1099 /// use tensor4all_tensorbackend::Storage;
1100 ///
1101 /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1102 /// assert!(s.is_dense());
1103 ///
1104 /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1105 /// assert!(!d.is_dense());
1106 /// ```
1107 pub fn is_dense(&self) -> bool {
1108 match &self.0 {
1109 StorageRepr::F64(value) => value.is_dense(),
1110 StorageRepr::C64(value) => value.is_dense(),
1111 }
1112 }
1113
1114 /// Check if this storage is a Diag storage type.
1115 ///
1116 /// # Examples
1117 ///
1118 /// ```
1119 /// use tensor4all_tensorbackend::Storage;
1120 ///
1121 /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1122 /// assert!(d.is_diag());
1123 /// ```
1124 pub fn is_diag(&self) -> bool {
1125 match &self.0 {
1126 StorageRepr::F64(value) => value.is_diag(),
1127 StorageRepr::C64(value) => value.is_diag(),
1128 }
1129 }
1130
1131 /// Returns the compact layout class for this storage.
1132 ///
1133 /// The return value is metadata-only and never materializes dense logical
1134 /// values. Use it to choose whether to read compact payload metadata or
1135 /// dense logical values.
1136 ///
1137 /// # Examples
1138 ///
1139 /// ```
1140 /// use tensor4all_tensorbackend::{Storage, StorageKind};
1141 ///
1142 /// let structured = Storage::new_structured(
1143 /// vec![1.0_f64, 2.0],
1144 /// vec![2],
1145 /// vec![1],
1146 /// vec![0, 0],
1147 /// ).unwrap();
1148 /// assert_eq!(structured.storage_kind(), StorageKind::Diagonal);
1149 /// ```
1150 pub fn storage_kind(&self) -> StorageKind {
1151 if self.is_dense() {
1152 StorageKind::Dense
1153 } else if self.is_diag() {
1154 StorageKind::Diagonal
1155 } else {
1156 StorageKind::Structured
1157 }
1158 }
1159
1160 /// Returns the logical tensor dimensions represented by this storage.
1161 ///
1162 /// The dimensions are derived from payload dimensions and `axis_classes`.
1163 ///
1164 /// # Examples
1165 ///
1166 /// ```
1167 /// use tensor4all_tensorbackend::Storage;
1168 ///
1169 /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1170 /// assert_eq!(diag.logical_dims(), vec![2, 2]);
1171 /// ```
1172 pub fn logical_dims(&self) -> Vec<usize> {
1173 match &self.0 {
1174 StorageRepr::F64(value) => value.logical_dims(),
1175 StorageRepr::C64(value) => value.logical_dims(),
1176 }
1177 }
1178
1179 /// Returns the logical tensor rank represented by this storage.
1180 ///
1181 /// This equals `axis_classes().len()`, not necessarily `payload_dims().len()`.
1182 ///
1183 /// # Examples
1184 ///
1185 /// ```
1186 /// use tensor4all_tensorbackend::Storage;
1187 ///
1188 /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 3).unwrap();
1189 /// assert_eq!(diag.logical_rank(), 3);
1190 /// assert_eq!(diag.payload_dims(), &[2]);
1191 /// ```
1192 pub fn logical_rank(&self) -> usize {
1193 match &self.0 {
1194 StorageRepr::F64(value) => value.logical_rank(),
1195 StorageRepr::C64(value) => value.logical_rank(),
1196 }
1197 }
1198
1199 /// Returns the compact payload dimensions.
1200 ///
1201 /// For dense storage these match logical dimensions. For diagonal storage
1202 /// this is rank-1 even when the logical tensor has multiple axes.
1203 ///
1204 /// # Examples
1205 ///
1206 /// ```
1207 /// use tensor4all_tensorbackend::Storage;
1208 ///
1209 /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1210 /// assert_eq!(diag.payload_dims(), &[2]);
1211 /// ```
1212 pub fn payload_dims(&self) -> &[usize] {
1213 match &self.0 {
1214 StorageRepr::F64(value) => value.payload_dims(),
1215 StorageRepr::C64(value) => value.payload_dims(),
1216 }
1217 }
1218
1219 /// Returns the compact payload strides.
1220 ///
1221 /// Strides are measured in stored scalar elements and describe the compact
1222 /// payload buffer, not the logical dense tensor.
1223 ///
1224 /// # Examples
1225 ///
1226 /// ```
1227 /// use tensor4all_tensorbackend::Storage;
1228 ///
1229 /// let dense = Storage::from_dense_col_major(vec![0.0_f64; 6], &[2, 3]).unwrap();
1230 /// assert_eq!(dense.payload_strides(), &[1, 2]);
1231 /// ```
1232 pub fn payload_strides(&self) -> &[isize] {
1233 match &self.0 {
1234 StorageRepr::F64(value) => value.strides(),
1235 StorageRepr::C64(value) => value.strides(),
1236 }
1237 }
1238
1239 /// Returns logical-axis equivalence classes for this storage.
1240 ///
1241 /// Repeated class labels mean the corresponding logical axes share one
1242 /// payload axis. Dense storage has `[0, 1, ...]`; diagonal storage has
1243 /// repeated zero labels.
1244 ///
1245 /// # Examples
1246 ///
1247 /// ```
1248 /// use tensor4all_tensorbackend::Storage;
1249 ///
1250 /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1251 /// assert_eq!(diag.axis_classes(), &[0, 0]);
1252 /// ```
1253 pub fn axis_classes(&self) -> &[usize] {
1254 match &self.0 {
1255 StorageRepr::F64(value) => value.axis_classes(),
1256 StorageRepr::C64(value) => value.axis_classes(),
1257 }
1258 }
1259
1260 /// Returns the number of stored compact payload elements.
1261 ///
1262 /// For dense storage this equals the logical dense length. For diagonal and
1263 /// structured storage this is the compact payload length.
1264 ///
1265 /// # Examples
1266 ///
1267 /// ```
1268 /// use tensor4all_tensorbackend::Storage;
1269 ///
1270 /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1271 /// assert_eq!(diag.payload_len(), 2);
1272 /// ```
1273 pub fn payload_len(&self) -> usize {
1274 self.len()
1275 }
1276
1277 /// Copies the compact `f64` payload in column-major payload order.
1278 ///
1279 /// This does not materialize logical dense values. For diagonal storage the
1280 /// returned vector contains only diagonal payload values.
1281 ///
1282 /// # Errors
1283 ///
1284 /// Returns an error if the storage scalar type is not `f64`.
1285 ///
1286 /// # Examples
1287 ///
1288 /// ```
1289 /// use tensor4all_tensorbackend::Storage;
1290 ///
1291 /// let diag = Storage::from_diag_col_major(vec![1.0_f64, 2.0], 2).unwrap();
1292 /// assert_eq!(diag.payload_f64_col_major_vec().unwrap(), vec![1.0, 2.0]);
1293 /// ```
1294 pub fn payload_f64_col_major_vec(&self) -> StorageResult<Vec<f64>> {
1295 match &self.0 {
1296 StorageRepr::F64(value) => Ok(value.payload_col_major_vec()),
1297 StorageRepr::C64(_) => Err(StorageError::ScalarKindMismatch {
1298 expected: "f64",
1299 actual: storage_scalar_kind(&self.0),
1300 operation: "copying f64 payload",
1301 }),
1302 }
1303 }
1304
1305 /// Borrows the compact `f64` payload when it is already contiguous in
1306 /// column-major payload order.
1307 pub fn payload_f64_col_major_view_if_contiguous(&self) -> StorageResult<Option<&[f64]>> {
1308 match &self.0 {
1309 StorageRepr::F64(value) => Ok(value.payload_col_major_view_if_contiguous()),
1310 StorageRepr::C64(_) => Err(StorageError::ScalarKindMismatch {
1311 expected: "f64",
1312 actual: storage_scalar_kind(&self.0),
1313 operation: "borrowing f64 payload",
1314 }),
1315 }
1316 }
1317
1318 /// Copies the compact `Complex64` payload in column-major payload order.
1319 ///
1320 /// This does not materialize logical dense values. Complex payloads are
1321 /// returned as native Rust `Complex64` values.
1322 ///
1323 /// # Errors
1324 ///
1325 /// Returns an error if the storage scalar type is not `Complex64`.
1326 ///
1327 /// # Examples
1328 ///
1329 /// ```
1330 /// use num_complex::Complex64;
1331 /// use tensor4all_tensorbackend::Storage;
1332 ///
1333 /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1334 /// let diag = Storage::from_diag_col_major(data.clone(), 2).unwrap();
1335 /// assert_eq!(diag.payload_c64_col_major_vec().unwrap(), data);
1336 /// ```
1337 pub fn payload_c64_col_major_vec(&self) -> StorageResult<Vec<Complex64>> {
1338 match &self.0 {
1339 StorageRepr::C64(value) => Ok(value.payload_col_major_vec()),
1340 StorageRepr::F64(_) => Err(StorageError::ScalarKindMismatch {
1341 expected: "Complex64",
1342 actual: storage_scalar_kind(&self.0),
1343 operation: "copying c64 payload",
1344 }),
1345 }
1346 }
1347
1348 /// Borrows the compact `Complex64` payload when it is already contiguous in
1349 /// column-major payload order.
1350 pub fn payload_c64_col_major_view_if_contiguous(&self) -> StorageResult<Option<&[Complex64]>> {
1351 match &self.0 {
1352 StorageRepr::C64(value) => Ok(value.payload_col_major_view_if_contiguous()),
1353 StorageRepr::F64(_) => Err(StorageError::ScalarKindMismatch {
1354 expected: "Complex64",
1355 actual: storage_scalar_kind(&self.0),
1356 operation: "borrowing c64 payload",
1357 }),
1358 }
1359 }
1360
1361 /// Check if this storage uses f64 scalar type.
1362 ///
1363 /// # Examples
1364 ///
1365 /// ```
1366 /// use tensor4all_tensorbackend::Storage;
1367 ///
1368 /// let s = Storage::from_dense_col_major(vec![1.0_f64], &[1]).unwrap();
1369 /// assert!(s.is_f64());
1370 /// assert!(!s.is_c64());
1371 /// ```
1372 pub fn is_f64(&self) -> bool {
1373 matches!(&self.0, StorageRepr::F64(_))
1374 }
1375
1376 /// Check if this storage uses Complex64 scalar type.
1377 ///
1378 /// # Examples
1379 ///
1380 /// ```
1381 /// use tensor4all_tensorbackend::Storage;
1382 /// use num_complex::Complex64;
1383 ///
1384 /// let s = Storage::from_dense_col_major(
1385 /// vec![Complex64::new(1.0, 0.0)], &[1],
1386 /// ).unwrap();
1387 /// assert!(s.is_c64());
1388 /// ```
1389 pub fn is_c64(&self) -> bool {
1390 matches!(&self.0, StorageRepr::C64(_))
1391 }
1392
1393 /// Check if this storage uses complex scalar type.
1394 ///
1395 /// This is an alias for [`is_c64()`](Self::is_c64).
1396 ///
1397 /// # Examples
1398 ///
1399 /// ```
1400 /// use tensor4all_tensorbackend::Storage;
1401 /// use num_complex::Complex64;
1402 ///
1403 /// let s = Storage::from_dense_col_major(
1404 /// vec![Complex64::new(1.0, 0.0)], &[1],
1405 /// ).unwrap();
1406 /// assert!(s.is_complex());
1407 ///
1408 /// let r = Storage::from_dense_col_major(vec![1.0_f64], &[1]).unwrap();
1409 /// assert!(!r.is_complex());
1410 /// ```
1411 pub fn is_complex(&self) -> bool {
1412 self.is_c64()
1413 }
1414
1415 /// Get the length of the storage payload (number of stored elements).
1416 ///
1417 /// For dense storage this equals the product of logical dimensions.
1418 /// For diagonal storage this equals the diagonal length.
1419 ///
1420 /// # Examples
1421 ///
1422 /// ```
1423 /// use tensor4all_tensorbackend::Storage;
1424 ///
1425 /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
1426 /// assert_eq!(s.len(), 3);
1427 ///
1428 /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1429 /// assert_eq!(d.len(), 2);
1430 /// ```
1431 pub fn len(&self) -> usize {
1432 match &self.0 {
1433 StorageRepr::F64(v) => v.len(),
1434 StorageRepr::C64(v) => v.len(),
1435 }
1436 }
1437
1438 /// Check if the storage is empty.
1439 ///
1440 /// # Examples
1441 ///
1442 /// ```
1443 /// use tensor4all_tensorbackend::Storage;
1444 ///
1445 /// let s = Storage::new_dense::<f64>(0).unwrap();
1446 /// assert!(s.is_empty());
1447 ///
1448 /// let s2 = Storage::new_dense::<f64>(3).unwrap();
1449 /// assert!(!s2.is_empty());
1450 /// ```
1451 pub fn is_empty(&self) -> bool {
1452 self.len() == 0
1453 }
1454
1455 /// Sum all elements, converting to type `T`.
1456 ///
1457 /// # Example
1458 /// ```
1459 /// use tensor4all_tensorbackend::Storage;
1460 /// let s = Storage::from_dense_col_major(vec![1.0, 2.0, 3.0], &[3]).unwrap();
1461 /// assert_eq!(s.sum::<f64>(), 6.0);
1462 /// ```
1463 pub fn sum<T: SumFromStorage>(&self) -> T {
1464 T::sum_from_storage(self)
1465 }
1466
1467 /// Maximum absolute value over all stored elements.
1468 ///
1469 /// For real storage this is `max(|x|)`, and for complex storage this is
1470 /// `max(norm(z))`.
1471 ///
1472 /// # Examples
1473 ///
1474 /// ```
1475 /// use tensor4all_tensorbackend::Storage;
1476 ///
1477 /// let s = Storage::from_dense_col_major(vec![-3.0_f64, 1.0, 2.0], &[3]).unwrap();
1478 /// assert!((s.max_abs() - 3.0).abs() < 1e-10);
1479 /// ```
1480 pub fn max_abs(&self) -> f64 {
1481 match &self.0 {
1482 StorageRepr::F64(v) => v.data().iter().map(|x| x.abs()).fold(0.0_f64, f64::max),
1483 StorageRepr::C64(v) => v.data().iter().map(|z| z.norm()).fold(0.0_f64, f64::max),
1484 }
1485 }
1486
1487 /// Materialize dense logical values as a column-major `f64` buffer.
1488 ///
1489 /// For diagonal storage, off-diagonal entries are filled with zero.
1490 ///
1491 /// # Errors
1492 ///
1493 /// Returns an error if the storage is complex or `logical_dims` does not
1494 /// match the stored logical dimensions.
1495 ///
1496 /// # Examples
1497 ///
1498 /// ```
1499 /// use tensor4all_tensorbackend::Storage;
1500 ///
1501 /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1502 /// let dense = s.to_dense_f64_col_major_vec(&[2, 2]).unwrap();
1503 /// assert_eq!(dense, vec![1.0, 2.0, 3.0, 4.0]);
1504 /// ```
1505 pub fn to_dense_f64_col_major_vec(&self, logical_dims: &[usize]) -> StorageResult<Vec<f64>> {
1506 match &self.0 {
1507 StorageRepr::F64(v) => {
1508 let structured_dims = v.logical_dims();
1509 if structured_dims != logical_dims {
1510 return Err(StorageError::InvalidStructuredStorage(format!(
1511 "logical dims {:?} do not match StructuredF64 logical dims {:?}",
1512 logical_dims, structured_dims
1513 )));
1514 }
1515 Ok(v.logical_dense_col_major_vec())
1516 }
1517 StorageRepr::C64(_) => Err(StorageError::ScalarKindMismatch {
1518 expected: "f64",
1519 actual: storage_scalar_kind(&self.0),
1520 operation: "materializing dense f64 values",
1521 }),
1522 }
1523 }
1524
1525 /// Materialize dense logical values as a column-major `Complex64` buffer.
1526 ///
1527 /// # Errors
1528 ///
1529 /// Returns an error if the storage is real or `logical_dims` does not
1530 /// match the stored logical dimensions.
1531 ///
1532 /// # Examples
1533 ///
1534 /// ```
1535 /// use tensor4all_tensorbackend::Storage;
1536 /// use num_complex::Complex64;
1537 ///
1538 /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1539 /// let s = Storage::from_dense_col_major(data.clone(), &[2]).unwrap();
1540 /// let dense = s.to_dense_c64_col_major_vec(&[2]).unwrap();
1541 /// assert_eq!(dense, data);
1542 /// ```
1543 pub fn to_dense_c64_col_major_vec(
1544 &self,
1545 logical_dims: &[usize],
1546 ) -> StorageResult<Vec<Complex64>> {
1547 match &self.0 {
1548 StorageRepr::C64(v) => {
1549 let structured_dims = v.logical_dims();
1550 if structured_dims != logical_dims {
1551 return Err(StorageError::InvalidStructuredStorage(format!(
1552 "logical dims {:?} do not match StructuredC64 logical dims {:?}",
1553 logical_dims, structured_dims
1554 )));
1555 }
1556 Ok(v.logical_dense_col_major_vec())
1557 }
1558 StorageRepr::F64(_) => Err(StorageError::ScalarKindMismatch {
1559 expected: "Complex64",
1560 actual: storage_scalar_kind(&self.0),
1561 operation: "materializing dense c64 values",
1562 }),
1563 }
1564 }
1565
1566 /// Convert this storage to dense storage.
1567 ///
1568 /// For Diag storage, creates a Dense storage with diagonal elements set
1569 /// and off-diagonal elements as zero. For Dense storage, returns a copy.
1570 ///
1571 /// # Errors
1572 ///
1573 /// Returns an error if `dims` does not match the stored logical dimensions
1574 /// or if dense storage construction fails.
1575 ///
1576 /// # Examples
1577 ///
1578 /// ```
1579 /// use tensor4all_tensorbackend::Storage;
1580 ///
1581 /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1582 /// let dense = d.to_dense_storage(&[2, 2]).unwrap();
1583 /// assert!(dense.is_dense());
1584 /// let vals = dense.to_dense_f64_col_major_vec(&[2, 2]).unwrap();
1585 /// assert_eq!(vals, vec![1.0, 0.0, 0.0, 2.0]);
1586 /// ```
1587 pub fn to_dense_storage(&self, dims: &[usize]) -> StorageResult<Storage> {
1588 if self.is_f64() {
1589 let values = self.to_dense_f64_col_major_vec(dims)?;
1590 Storage::from_dense_col_major(values, dims).map_err(Self::invalid_storage_error)
1591 } else {
1592 let values = self.to_dense_c64_col_major_vec(dims)?;
1593 Storage::from_dense_col_major(values, dims).map_err(Self::invalid_storage_error)
1594 }
1595 }
1596
1597 /// Permute the storage data according to the given permutation.
1598 ///
1599 /// The `_dims` parameter is currently unused (reserved for future use).
1600 ///
1601 /// # Examples
1602 ///
1603 /// ```
1604 /// use tensor4all_tensorbackend::Storage;
1605 ///
1606 /// // Diagonal 2x2 tensor, permute axes (identity perm for diag is valid)
1607 /// let d = Storage::new_diag(vec![1.0_f64, 2.0]).unwrap();
1608 /// let t = d.permute_storage(&[2, 2], &[1, 0]).unwrap();
1609 /// assert!(t.is_diag());
1610 /// ```
1611 pub fn permute_storage(&self, _dims: &[usize], perm: &[usize]) -> StorageResult<Storage> {
1612 match &self.0 {
1613 StorageRepr::F64(v) => Ok(Storage::from_repr(StorageRepr::F64(
1614 v.permute_logical_axes(perm)
1615 .map_err(Self::invalid_storage_error)?,
1616 ))),
1617 StorageRepr::C64(v) => Ok(Storage::from_repr(StorageRepr::C64(
1618 v.permute_logical_axes(perm)
1619 .map_err(Self::invalid_storage_error)?,
1620 ))),
1621 }
1622 }
1623
1624 /// Extract real part from Complex64 storage as f64 storage.
1625 /// For f64 storage, returns a copy (clone).
1626 ///
1627 /// # Examples
1628 ///
1629 /// ```
1630 /// use tensor4all_tensorbackend::Storage;
1631 /// use num_complex::Complex64;
1632 ///
1633 /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1634 /// let s = Storage::from_dense_col_major(data, &[2]).unwrap();
1635 /// let re = s.extract_real_part();
1636 /// assert!(re.is_f64());
1637 /// assert_eq!(re.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![1.0, 3.0]);
1638 /// ```
1639 pub fn extract_real_part(&self) -> Storage {
1640 match &self.0 {
1641 StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.clone())),
1642 StorageRepr::C64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|z| z.re))),
1643 }
1644 }
1645
1646 /// Extract imaginary part from Complex64 storage as f64 storage.
1647 /// For f64 storage, returns zero storage.
1648 ///
1649 /// # Examples
1650 ///
1651 /// ```
1652 /// use tensor4all_tensorbackend::Storage;
1653 /// use num_complex::Complex64;
1654 ///
1655 /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
1656 /// let s = Storage::from_dense_col_major(data, &[2]).unwrap();
1657 /// let im = s.extract_imag_part(&[2]);
1658 /// assert!(im.is_f64());
1659 /// assert_eq!(im.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![2.0, 4.0]);
1660 /// ```
1661 pub fn extract_imag_part(&self, _dims: &[usize]) -> Storage {
1662 match &self.0 {
1663 StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|_| 0.0))),
1664 StorageRepr::C64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|z| z.im))),
1665 }
1666 }
1667
1668 /// Convert f64 storage to Complex64 storage (real part only, imaginary part is zero).
1669 /// For Complex64 storage, returns a clone.
1670 ///
1671 /// # Examples
1672 ///
1673 /// ```
1674 /// use tensor4all_tensorbackend::Storage;
1675 ///
1676 /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1677 /// let c = s.to_complex_storage();
1678 /// assert!(c.is_c64());
1679 /// ```
1680 pub fn to_complex_storage(&self) -> Storage {
1681 match &self.0 {
1682 StorageRepr::F64(v) => {
1683 Storage::from_repr(StorageRepr::C64(v.map_copy(|x| Complex64::new(x, 0.0))))
1684 }
1685 StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(v.clone())),
1686 }
1687 }
1688
1689 /// Complex conjugate of all elements.
1690 ///
1691 /// For real (f64) storage, returns a clone (conjugate of real is identity).
1692 /// For complex (Complex64) storage, conjugates each element.
1693 ///
1694 /// This is inspired by the `conj` operation in ITensorMPS.jl.
1695 ///
1696 /// # Examples
1697 /// ```
1698 /// use tensor4all_tensorbackend::Storage;
1699 /// use num_complex::Complex64;
1700 ///
1701 /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -4.0)];
1702 /// let storage = Storage::from_dense_col_major(data, &[2]).unwrap();
1703 /// let conj_storage = storage.conj();
1704 ///
1705 /// let result = conj_storage.to_dense_c64_col_major_vec(&[2]).unwrap();
1706 /// assert_eq!(result[0], Complex64::new(1.0, -2.0));
1707 /// assert_eq!(result[1], Complex64::new(3.0, 4.0));
1708 /// ```
1709 pub fn conj(&self) -> Self {
1710 match &self.0 {
1711 StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.clone())),
1712 StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(v.map_copy(|z| z.conj()))),
1713 }
1714 }
1715
1716 /// Combine two f64 storages into Complex64 storage.
1717 ///
1718 /// `real_storage` becomes the real part, `imag_storage` becomes the imaginary part.
1719 /// Formula: `real + i * imag`.
1720 ///
1721 /// # Errors
1722 ///
1723 /// Returns an error if either storage is not `f64`, if their payload lengths
1724 /// differ, or if the result metadata is invalid.
1725 ///
1726 /// # Examples
1727 ///
1728 /// ```
1729 /// use tensor4all_tensorbackend::Storage;
1730 /// use num_complex::Complex64;
1731 ///
1732 /// let re = Storage::from_dense_col_major(vec![1.0_f64, 3.0], &[2]).unwrap();
1733 /// let im = Storage::from_dense_col_major(vec![2.0_f64, 4.0], &[2]).unwrap();
1734 /// let c = Storage::combine_to_complex(&re, &im).unwrap();
1735 /// assert!(c.is_c64());
1736 /// let vals = c.to_dense_c64_col_major_vec(&[2]).unwrap();
1737 /// assert_eq!(vals[0], Complex64::new(1.0, 2.0));
1738 /// assert_eq!(vals[1], Complex64::new(3.0, 4.0));
1739 /// ```
1740 pub fn combine_to_complex(
1741 real_storage: &Storage,
1742 imag_storage: &Storage,
1743 ) -> StorageResult<Storage> {
1744 match (&real_storage.0, &imag_storage.0) {
1745 (StorageRepr::F64(real), StorageRepr::F64(imag)) => {
1746 if real.len() != imag.len() {
1747 return Err(StorageError::LengthMismatch {
1748 operation: "combine_to_complex",
1749 left: real.len(),
1750 right: imag.len(),
1751 });
1752 }
1753 let complex_vec: Vec<Complex64> = real
1754 .data()
1755 .iter()
1756 .zip(imag.data().iter())
1757 .map(|(&r, &i)| Complex64::new(r, i))
1758 .collect();
1759 Ok(Storage::from_repr(StorageRepr::C64(
1760 StructuredStorage::new(
1761 complex_vec,
1762 real.payload_dims().to_vec(),
1763 real.strides().to_vec(),
1764 real.axis_classes().to_vec(),
1765 )
1766 .map_err(Self::invalid_storage_error)?,
1767 )))
1768 }
1769 _ => Err(StorageError::OperationNotSupported {
1770 operation: "combine_to_complex",
1771 left: storage_scalar_kind(&real_storage.0),
1772 right: storage_scalar_kind(&imag_storage.0),
1773 }),
1774 }
1775 }
1776
1777 /// Add two storages element-wise, returning `Result` on error instead of panicking.
1778 ///
1779 /// Both storages must have the same type and length.
1780 ///
1781 /// # Errors
1782 ///
1783 /// Returns an error if storage types or lengths don't match.
1784 ///
1785 /// # Examples
1786 ///
1787 /// ```
1788 /// use tensor4all_tensorbackend::Storage;
1789 ///
1790 /// let a = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1791 /// let b = Storage::from_dense_col_major(vec![3.0_f64, 4.0], &[2]).unwrap();
1792 /// let c = a.try_add(&b).unwrap();
1793 /// assert_eq!(c.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![4.0, 6.0]);
1794 /// ```
1795 pub fn try_add(&self, other: &Storage) -> StorageResult<Storage> {
1796 match (&self.0, &other.0) {
1797 (StorageRepr::F64(a), StorageRepr::F64(b)) => {
1798 if a.len() != b.len() {
1799 return Err(StorageError::LengthMismatch {
1800 operation: "addition",
1801 left: a.len(),
1802 right: b.len(),
1803 });
1804 }
1805 let sum_vec: Vec<f64> = a
1806 .data()
1807 .iter()
1808 .zip(b.data().iter())
1809 .map(|(&x, &y)| x + y)
1810 .collect();
1811 Ok(Storage::from_repr(StorageRepr::F64(
1812 StructuredStorage::new(
1813 sum_vec,
1814 a.payload_dims().to_vec(),
1815 a.strides().to_vec(),
1816 a.axis_classes().to_vec(),
1817 )
1818 .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1819 )))
1820 }
1821 (StorageRepr::C64(a), StorageRepr::C64(b)) => {
1822 if a.len() != b.len() {
1823 return Err(StorageError::LengthMismatch {
1824 operation: "addition",
1825 left: a.len(),
1826 right: b.len(),
1827 });
1828 }
1829 let sum_vec: Vec<Complex64> = a
1830 .data()
1831 .iter()
1832 .zip(b.data().iter())
1833 .map(|(&x, &y)| x + y)
1834 .collect();
1835 Ok(Storage::from_repr(StorageRepr::C64(
1836 StructuredStorage::new(
1837 sum_vec,
1838 a.payload_dims().to_vec(),
1839 a.strides().to_vec(),
1840 a.axis_classes().to_vec(),
1841 )
1842 .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1843 )))
1844 }
1845 _ => Err(StorageError::OperationNotSupported {
1846 operation: "addition",
1847 left: storage_scalar_kind(&self.0),
1848 right: storage_scalar_kind(&other.0),
1849 }),
1850 }
1851 }
1852
1853 /// Try to subtract two storages element-wise.
1854 ///
1855 /// # Errors
1856 ///
1857 /// Returns an error if the storages have different types or lengths.
1858 ///
1859 /// # Examples
1860 ///
1861 /// ```
1862 /// use tensor4all_tensorbackend::Storage;
1863 ///
1864 /// let a = Storage::from_dense_col_major(vec![5.0_f64, 7.0], &[2]).unwrap();
1865 /// let b = Storage::from_dense_col_major(vec![1.0_f64, 3.0], &[2]).unwrap();
1866 /// let c = a.try_sub(&b).unwrap();
1867 /// assert_eq!(c.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![4.0, 4.0]);
1868 /// ```
1869 pub fn try_sub(&self, other: &Storage) -> StorageResult<Storage> {
1870 match (&self.0, &other.0) {
1871 (StorageRepr::F64(a), StorageRepr::F64(b)) => {
1872 if a.len() != b.len() {
1873 return Err(StorageError::LengthMismatch {
1874 operation: "subtraction",
1875 left: a.len(),
1876 right: b.len(),
1877 });
1878 }
1879 let diff_vec: Vec<f64> = a
1880 .data()
1881 .iter()
1882 .zip(b.data().iter())
1883 .map(|(&x, &y)| x - y)
1884 .collect();
1885 Ok(Storage::from_repr(StorageRepr::F64(
1886 StructuredStorage::new(
1887 diff_vec,
1888 a.payload_dims().to_vec(),
1889 a.strides().to_vec(),
1890 a.axis_classes().to_vec(),
1891 )
1892 .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1893 )))
1894 }
1895 (StorageRepr::C64(a), StorageRepr::C64(b)) => {
1896 if a.len() != b.len() {
1897 return Err(StorageError::LengthMismatch {
1898 operation: "subtraction",
1899 left: a.len(),
1900 right: b.len(),
1901 });
1902 }
1903 let diff_vec: Vec<Complex64> = a
1904 .data()
1905 .iter()
1906 .zip(b.data().iter())
1907 .map(|(&x, &y)| x - y)
1908 .collect();
1909 Ok(Storage::from_repr(StorageRepr::C64(
1910 StructuredStorage::new(
1911 diff_vec,
1912 a.payload_dims().to_vec(),
1913 a.strides().to_vec(),
1914 a.axis_classes().to_vec(),
1915 )
1916 .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
1917 )))
1918 }
1919 _ => Err(StorageError::OperationNotSupported {
1920 operation: "subtraction",
1921 left: storage_scalar_kind(&self.0),
1922 right: storage_scalar_kind(&other.0),
1923 }),
1924 }
1925 }
1926
1927 /// Scale storage by a scalar value.
1928 ///
1929 /// If the scalar is complex but the storage is real, the storage is promoted to complex.
1930 ///
1931 /// # Examples
1932 ///
1933 /// ```
1934 /// use tensor4all_tensorbackend::{AnyScalar, Storage};
1935 ///
1936 /// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();
1937 /// let scaled = s.scale(&AnyScalar::new_real(2.0));
1938 /// assert_eq!(scaled.to_dense_f64_col_major_vec(&[3]).unwrap(), vec![2.0, 4.0, 6.0]);
1939 /// ```
1940 pub fn scale(&self, scalar: &crate::AnyScalar) -> Storage {
1941 self * scalar.clone()
1942 }
1943
1944 /// Compute linear combination: `a * self + b * other`.
1945 ///
1946 /// # Errors
1947 ///
1948 /// Returns an error if the storages have different types or lengths.
1949 /// If any scalar is complex, the result is promoted to complex.
1950 ///
1951 /// # Examples
1952 ///
1953 /// ```
1954 /// use tensor4all_tensorbackend::{AnyScalar, Storage};
1955 ///
1956 /// let x = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
1957 /// let y = Storage::from_dense_col_major(vec![3.0_f64, 4.0], &[2]).unwrap();
1958 /// let a = AnyScalar::new_real(2.0);
1959 /// let b = AnyScalar::new_real(3.0);
1960 /// // result = 2*[1,2] + 3*[3,4] = [11, 16]
1961 /// let result = x.axpby(&a, &y, &b).unwrap();
1962 /// assert_eq!(result.to_dense_f64_col_major_vec(&[2]).unwrap(), vec![11.0, 16.0]);
1963 /// ```
1964 pub fn axpby(
1965 &self,
1966 a: &crate::AnyScalar,
1967 other: &Storage,
1968 b: &crate::AnyScalar,
1969 ) -> StorageResult<Storage> {
1970 // First check lengths match
1971 if self.len() != other.len() {
1972 return Err(StorageError::LengthMismatch {
1973 operation: "axpby",
1974 left: self.len(),
1975 right: other.len(),
1976 });
1977 }
1978
1979 // Determine if we need complex output
1980 let needs_complex = a.is_complex()
1981 || b.is_complex()
1982 || matches!(&self.0, StorageRepr::C64(_))
1983 || matches!(&other.0, StorageRepr::C64(_));
1984
1985 if needs_complex {
1986 // Promote everything to complex
1987 let a_c: Complex64 = a.clone().into();
1988 let b_c: Complex64 = b.clone().into();
1989
1990 let (result, payload_dims, strides, axis_classes): (
1991 Vec<Complex64>,
1992 Vec<usize>,
1993 Vec<isize>,
1994 Vec<usize>,
1995 ) = match (&self.0, &other.0) {
1996 (StorageRepr::F64(x), StorageRepr::F64(y)) => (
1997 x.data()
1998 .iter()
1999 .zip(y.data().iter())
2000 .map(|(&xi, &yi)| {
2001 a_c * Complex64::new(xi, 0.0) + b_c * Complex64::new(yi, 0.0)
2002 })
2003 .collect(),
2004 x.payload_dims().to_vec(),
2005 x.strides().to_vec(),
2006 x.axis_classes().to_vec(),
2007 ),
2008 (StorageRepr::F64(x), StorageRepr::C64(y)) => (
2009 x.data()
2010 .iter()
2011 .zip(y.data().iter())
2012 .map(|(&xi, &yi)| a_c * Complex64::new(xi, 0.0) + b_c * yi)
2013 .collect(),
2014 x.payload_dims().to_vec(),
2015 x.strides().to_vec(),
2016 x.axis_classes().to_vec(),
2017 ),
2018 (StorageRepr::C64(x), StorageRepr::F64(y)) => (
2019 x.data()
2020 .iter()
2021 .zip(y.data().iter())
2022 .map(|(&xi, &yi)| a_c * xi + b_c * Complex64::new(yi, 0.0))
2023 .collect(),
2024 x.payload_dims().to_vec(),
2025 x.strides().to_vec(),
2026 x.axis_classes().to_vec(),
2027 ),
2028 (StorageRepr::C64(x), StorageRepr::C64(y)) => (
2029 x.data()
2030 .iter()
2031 .zip(y.data().iter())
2032 .map(|(&xi, &yi)| a_c * xi + b_c * yi)
2033 .collect(),
2034 x.payload_dims().to_vec(),
2035 x.strides().to_vec(),
2036 x.axis_classes().to_vec(),
2037 ),
2038 };
2039 Ok(Storage::from_repr(StorageRepr::C64(
2040 StructuredStorage::new(result, payload_dims, strides, axis_classes)
2041 .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
2042 )))
2043 } else {
2044 // All real
2045 if !a.is_real() || !b.is_real() {
2046 return Err(StorageError::RealScalarRequired {
2047 operation: "real axpby",
2048 a: a.to_string(),
2049 b: b.to_string(),
2050 });
2051 }
2052 let a_f = a.real();
2053 let b_f = b.real();
2054
2055 match (&self.0, &other.0) {
2056 (StorageRepr::F64(x), StorageRepr::F64(y)) => {
2057 let result: Vec<f64> = x
2058 .data()
2059 .iter()
2060 .zip(y.data().iter())
2061 .map(|(&xi, &yi)| a_f * xi + b_f * yi)
2062 .collect();
2063 Ok(Storage::from_repr(StorageRepr::F64(
2064 StructuredStorage::new(
2065 result,
2066 x.payload_dims().to_vec(),
2067 x.strides().to_vec(),
2068 x.axis_classes().to_vec(),
2069 )
2070 .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))?,
2071 )))
2072 }
2073 _ => Err(StorageError::OperationNotSupported {
2074 operation: "axpby",
2075 left: storage_scalar_kind(&self.0),
2076 right: storage_scalar_kind(&other.0),
2077 }),
2078 }
2079 }
2080 }
2081}
2082
2083/// Helper to get a mutable reference to storage, cloning if needed (COW).
2084///
2085/// Uses `Arc::make_mut` semantics: if the `Arc` has only one strong reference,
2086/// returns a mutable reference to the existing allocation. Otherwise clones
2087/// the inner value first.
2088///
2089/// # Examples
2090///
2091/// ```
2092/// use std::sync::Arc;
2093/// use tensor4all_tensorbackend::{make_mut_storage, Storage};
2094///
2095/// let s = Storage::from_dense_col_major(vec![1.0_f64, 2.0], &[2]).unwrap();
2096/// let mut arc = Arc::new(s);
2097/// let s_mut = make_mut_storage(&mut arc);
2098/// // s_mut is now a mutable reference to Storage
2099/// assert!(s_mut.is_f64());
2100/// ```
2101pub fn make_mut_storage(arc: &mut Arc<Storage>) -> &mut Storage {
2102 Arc::make_mut(arc)
2103}
2104
/// Return the smallest entry of `dims`, or `1` when `dims` is empty.
///
/// This is intended for diagonal tensors (`DiagTensor`), where every index is
/// expected to share a single dimension.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::mindim;
///
/// assert_eq!(mindim(&[2, 3, 4]), 2);
/// assert_eq!(mindim(&[5, 5, 5]), 5);
/// assert_eq!(mindim(&[]), 1);
/// ```
pub fn mindim(dims: &[usize]) -> usize {
    let Some(&smallest) = dims.iter().min() else {
        // No dimensions at all: fall back to 1, as documented above.
        return 1;
    };
    smallest
}
2122
2123/// Contract two storage tensors along specified axes.
2124///
2125/// All storage is StructuredStorage; contraction is delegated to the native
2126/// tenferro backend. This is the primary tensor contraction entry point at
2127/// the storage layer.
2128///
2129/// # Arguments
2130///
2131/// * `storage_a` - First tensor storage
2132/// * `dims_a` - Dimensions of the first tensor
2133/// * `axes_a` - Axes of the first tensor to contract
2134/// * `storage_b` - Second tensor storage
2135/// * `dims_b` - Dimensions of the second tensor
2136/// * `axes_b` - Axes of the second tensor to contract
2137/// * `result_dims` - Dimensions of the result tensor (empty for scalar result)
2138///
2139/// # Returns
2140/// A new `Storage` containing the contracted result.
2141///
2142/// # Errors
2143///
2144/// Returns an error if axes are invalid, contracted dimensions do not match, or
2145/// the native backend rejects the contraction.
2146///
2147/// # Examples
2148///
2149/// ```
2150/// use tensor4all_tensorbackend::{contract_storage, Storage};
2151///
2152/// // Matrix-vector multiply: A(2x3) * v(3) -> result(2)
2153/// let a = Storage::from_dense_col_major(
2154/// vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3],
2155/// ).unwrap();
2156/// let v = Storage::from_dense_col_major(vec![1.0, 1.0, 1.0], &[3]).unwrap();
2157/// let result = contract_storage(&a, &[2, 3], &[1], &v, &[3], &[0], &[2]).unwrap();
2158/// // Row sums: [1+3+5, 2+4+6] = [9, 12]
2159/// let vals = result.to_dense_f64_col_major_vec(&[2]).unwrap();
2160/// assert!((vals[0] - 9.0).abs() < 1e-10);
2161/// assert!((vals[1] - 12.0).abs() < 1e-10);
2162/// ```
2163pub fn contract_storage(
2164 storage_a: &Storage,
2165 dims_a: &[usize],
2166 axes_a: &[usize],
2167 storage_b: &Storage,
2168 dims_b: &[usize],
2169 axes_b: &[usize],
2170 result_dims: &[usize],
2171) -> StorageResult<Storage> {
2172 try_contract_storage(
2173 storage_a,
2174 dims_a,
2175 axes_a,
2176 storage_b,
2177 dims_b,
2178 axes_b,
2179 result_dims,
2180 )
2181}
2182
2183fn try_contract_storage(
2184 storage_a: &Storage,
2185 dims_a: &[usize],
2186 axes_a: &[usize],
2187 storage_b: &Storage,
2188 dims_b: &[usize],
2189 axes_b: &[usize],
2190 result_dims: &[usize],
2191) -> StorageResult<Storage> {
2192 if axes_a.len() != axes_b.len() {
2193 return Err(StorageError::InvalidStructuredStorage(format!(
2194 "contract axes lengths must match: {} != {}",
2195 axes_a.len(),
2196 axes_b.len()
2197 )));
2198 }
2199
2200 for (&a_axis, &b_axis) in axes_a.iter().zip(axes_b.iter()) {
2201 let Some(&a_dim) = dims_a.get(a_axis) else {
2202 return Err(StorageError::InvalidStructuredStorage(format!(
2203 "contract axis {a_axis} is out of range for left dims {dims_a:?}"
2204 )));
2205 };
2206 let Some(&b_dim) = dims_b.get(b_axis) else {
2207 return Err(StorageError::InvalidStructuredStorage(format!(
2208 "contract axis {b_axis} is out of range for right dims {dims_b:?}"
2209 )));
2210 };
2211 if a_dim != b_dim {
2212 return Err(StorageError::InvalidStructuredStorage(format!(
2213 "contracted dimensions must match: dims_a[{a_axis}] = {a_dim} != dims_b[{b_axis}] = {b_dim}"
2214 )));
2215 }
2216 }
2217
2218 crate::tenferro_bridge::contract_storage_native(
2219 storage_a,
2220 dims_a,
2221 axes_a,
2222 storage_b,
2223 dims_b,
2224 axes_b,
2225 result_dims,
2226 )
2227 .map_err(|err| StorageError::InvalidStructuredStorage(err.to_string()))
2228}
2229
2230/// Multiply storage by a scalar (f64).
2231/// For Complex64 storage, multiplies each element by the scalar (treated as real).
2232impl Mul<f64> for &Storage {
2233 type Output = Storage;
2234
2235 fn mul(self, scalar: f64) -> Self::Output {
2236 match &self.0 {
2237 StorageRepr::F64(v) => Storage::from_repr(StorageRepr::F64(v.map_copy(|x| x * scalar))),
2238 StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(
2239 v.map_copy(|z| z * Complex64::new(scalar, 0.0)),
2240 )),
2241 }
2242 }
2243}
2244
2245/// Multiply storage by a scalar (Complex64).
2246impl Mul<Complex64> for &Storage {
2247 type Output = Storage;
2248
2249 fn mul(self, scalar: Complex64) -> Self::Output {
2250 match &self.0 {
2251 StorageRepr::F64(v) => Storage::from_repr(StorageRepr::C64(
2252 v.map_copy(|x| Complex64::new(x, 0.0) * scalar),
2253 )),
2254 StorageRepr::C64(v) => Storage::from_repr(StorageRepr::C64(v.map_copy(|z| z * scalar))),
2255 }
2256 }
2257}
2258
2259/// Multiply storage by a scalar (AnyScalar).
2260/// May promote f64 storage to Complex64 when scalar is complex.
2261impl Mul<AnyScalar> for &Storage {
2262 type Output = Storage;
2263
2264 fn mul(self, scalar: AnyScalar) -> Self::Output {
2265 if scalar.is_complex() {
2266 let z: Complex64 = scalar.into();
2267 self * z
2268 } else {
2269 self * scalar.real()
2270 }
2271 }
2272}
2273
2274#[cfg(test)]
2275mod tests;