Skip to main content

strided_view/
view.rs

1//! Julia-like dynamic-rank strided view types.
2//!
3//! This module provides the canonical view types for strided operations,
4//! matching Julia's StridedViews.jl data model:
5//!
6//! - [`StridedView`]: Immutable dynamic-rank strided view with lazy element operations
7//! - [`StridedViewMut`]: Mutable dynamic-rank strided view (Identity op only)
8//! - [`StridedArray`]: Owned strided multidimensional array
9
10use std::marker::PhantomData;
11use std::ops::{Index, IndexMut};
12use std::sync::Arc;
13
14use crate::element_op::{ComposableElementOp, ElementOp, ElementOpApply, Identity};
15use crate::{Result, StridedError};
16
17// ============================================================================
18// Validation helpers
19// ============================================================================
20
21/// Validate that all accessed offsets stay within `[0, len)`.
22pub(crate) fn validate_bounds(
23    len: usize,
24    dims: &[usize],
25    strides: &[isize],
26    offset: isize,
27) -> Result<()> {
28    if dims.len() != strides.len() {
29        return Err(StridedError::StrideLengthMismatch);
30    }
31    // Empty array - no access needed
32    if dims.iter().any(|&d| d == 0) {
33        return Ok(());
34    }
35    // Compute min and max offsets
36    let mut min_offset = offset;
37    let mut max_offset = offset;
38    for (&dim, &stride) in dims.iter().zip(strides.iter()) {
39        if dim > 1 {
40            let end = stride
41                .checked_mul(dim as isize - 1)
42                .ok_or(StridedError::OffsetOverflow)?;
43            if end >= 0 {
44                max_offset = max_offset
45                    .checked_add(end)
46                    .ok_or(StridedError::OffsetOverflow)?;
47            } else {
48                min_offset = min_offset
49                    .checked_add(end)
50                    .ok_or(StridedError::OffsetOverflow)?;
51            }
52        }
53    }
54    if min_offset < 0 || max_offset < 0 {
55        return Err(StridedError::OffsetOverflow);
56    }
57    if max_offset as usize >= len {
58        return Err(StridedError::OffsetOverflow);
59    }
60    Ok(())
61}
62
/// Compute column-major strides (Julia default: first index varies fastest).
///
/// Axis 0 gets stride 1, and every later axis steps over the product of all
/// earlier extents. The extent of the last axis never contributes. Rank 0
/// yields an empty vector.
pub fn col_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides = Vec::with_capacity(dims.len());
    if let Some((_, leading)) = dims.split_last() {
        let mut step: isize = 1;
        strides.push(step);
        for &extent in leading {
            step *= extent as isize;
            strides.push(step);
        }
    }
    strides
}
75
/// Compute row-major strides (C default: last index varies fastest).
///
/// The last axis gets stride 1, and every earlier axis steps over the product
/// of all extents to its right. The extent of axis 0 never contributes. Rank 0
/// yields an empty vector.
pub fn row_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides = vec![1isize; dims.len()];
    let mut step: isize = 1;
    // Walk from the last axis toward the first, accumulating trailing extents.
    for axis in (1..dims.len()).rev() {
        step *= dims[axis] as isize;
        strides[axis - 1] = step;
    }
    strides
}
88
89// ============================================================================
90// StridedView
91// ============================================================================
92
93/// Dynamic-rank immutable strided view with lazy element operations.
94///
95/// This is the Julia-equivalent `StridedView` type with:
96/// - Dynamic rank (dims/strides are heap-allocated)
97/// - Lazy element operations via the `Op` type parameter
98/// - Zero-copy transformations (permute, transpose, adjoint, conj)
99///
100/// # Type Parameters
101/// - `'a`: Lifetime of the underlying data
102/// - `T`: Element type
103/// - `Op`: Element operation applied lazily on access (default: `Identity`)
104pub struct StridedView<'a, T, Op = Identity> {
105    ptr: *const T,
106    data: &'a [T],
107    dims: Arc<[usize]>,
108    strides: Arc<[isize]>,
109    offset: isize,
110    _op: PhantomData<Op>,
111}
112
113unsafe impl<T: Send, Op: Send> Send for StridedView<'_, T, Op> {}
114unsafe impl<T: Sync, Op: Sync> Sync for StridedView<'_, T, Op> {}
115
116impl<T, Op> Clone for StridedView<'_, T, Op> {
117    fn clone(&self) -> Self {
118        Self {
119            ptr: self.ptr,
120            data: self.data,
121            dims: self.dims.clone(),
122            strides: self.strides.clone(),
123            offset: self.offset,
124            _op: PhantomData,
125        }
126    }
127}
128
129impl<T: std::fmt::Debug, Op> std::fmt::Debug for StridedView<'_, T, Op> {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("StridedView")
132            .field("dims", &self.dims)
133            .field("strides", &self.strides)
134            .field("offset", &self.offset)
135            .finish()
136    }
137}
138
139impl<'a, T, Op> StridedView<'a, T, Op> {
140    /// Create a new immutable strided view from a borrowed slice.
141    pub fn new(data: &'a [T], dims: &[usize], strides: &[isize], offset: isize) -> Result<Self> {
142        validate_bounds(data.len(), dims, strides, offset)?;
143        let ptr = unsafe { data.as_ptr().offset(offset) };
144        Ok(Self {
145            ptr,
146            data,
147            dims: Arc::from(dims),
148            strides: Arc::from(strides),
149            offset,
150            _op: PhantomData,
151        })
152    }
153
154    /// Create a view without bounds checking.
155    ///
156    /// # Safety
157    /// The caller must ensure all index combinations stay within bounds.
158    pub unsafe fn new_unchecked(
159        data: &'a [T],
160        dims: &[usize],
161        strides: &[isize],
162        offset: isize,
163    ) -> Self {
164        let ptr = data.as_ptr().offset(offset);
165        Self {
166            ptr,
167            data,
168            dims: Arc::from(dims),
169            strides: Arc::from(strides),
170            offset,
171            _op: PhantomData,
172        }
173    }
174
175    /// Returns the shape (dimension sizes) of this view.
176    #[inline]
177    pub fn dims(&self) -> &[usize] {
178        &self.dims
179    }
180
181    /// Returns the strides (in units of `T`) for each dimension.
182    #[inline]
183    pub fn strides(&self) -> &[isize] {
184        &self.strides
185    }
186
187    /// Returns the byte offset into the backing data.
188    #[inline]
189    pub fn offset(&self) -> isize {
190        self.offset
191    }
192
193    /// Returns the number of dimensions (rank).
194    #[inline]
195    pub fn ndim(&self) -> usize {
196        self.dims.len()
197    }
198
199    /// Returns the total number of elements.
200    #[inline]
201    pub fn len(&self) -> usize {
202        self.dims.iter().product()
203    }
204
205    /// Returns `true` if any dimension is zero.
206    #[inline]
207    pub fn is_empty(&self) -> bool {
208        self.dims.iter().any(|&d| d == 0)
209    }
210
211    /// Returns a reference to the backing data slice.
212    #[inline]
213    pub fn data(&self) -> &'a [T] {
214        self.data
215    }
216
217    /// Raw const pointer to element at the view's base offset.
218    #[inline]
219    pub fn ptr(&self) -> *const T {
220        self.ptr
221    }
222
223    /// Permute dimensions.
224    pub fn permute(&self, perm: &[usize]) -> Result<StridedView<'a, T, Op>> {
225        let rank = self.dims.len();
226        if perm.len() != rank {
227            return Err(StridedError::RankMismatch(perm.len(), rank));
228        }
229        let mut seen = vec![false; rank];
230        for &p in perm {
231            if p >= rank {
232                return Err(StridedError::InvalidAxis { axis: p, rank });
233            }
234            if seen[p] {
235                return Err(StridedError::InvalidAxis { axis: p, rank });
236            }
237            seen[p] = true;
238        }
239        let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
240        let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
241        Ok(StridedView {
242            ptr: self.ptr,
243            data: self.data,
244            dims: Arc::from(new_dims),
245            strides: Arc::from(new_strides),
246            offset: self.offset,
247            _op: PhantomData,
248        })
249    }
250
251    /// Create a diagonal view by fusing repeated axis pairs via stride trick (zero-copy).
252    ///
253    /// For each pair `(a, b)`:
254    /// - New stride = `strides[a] + strides[b]`
255    /// - New dim = `min(dims[a], dims[b])`
256    /// - The higher-numbered axis is removed
257    /// - Pairs use **original** axis numbering
258    ///
259    /// # Example
260    /// `A[i,i,j]` shape=`[n,n,m]` strides=`[s0,s1,s2]` -> shape=`[n,m]` strides=`[s0+s1, s2]`
261    pub fn diagonal_view(&self, axis_pairs: &[(usize, usize)]) -> Result<StridedView<'a, T, Op>> {
262        let ndim = self.ndim();
263        let mut dims: Vec<usize> = self.dims().to_vec();
264        let mut strides: Vec<isize> = self.strides().to_vec();
265
266        let mut axes_to_remove = Vec::new();
267        for &(a, b) in axis_pairs {
268            let (lo, hi) = if a < b { (a, b) } else { (b, a) };
269            if lo >= ndim || hi >= ndim {
270                return Err(StridedError::InvalidAxis {
271                    axis: hi,
272                    rank: ndim,
273                });
274            }
275            if lo == hi {
276                return Err(StridedError::InvalidAxis {
277                    axis: lo,
278                    rank: ndim,
279                });
280            }
281            strides[lo] += strides[hi];
282            dims[lo] = dims[lo].min(dims[hi]);
283            axes_to_remove.push(hi);
284        }
285
286        axes_to_remove.sort_unstable();
287        axes_to_remove.dedup();
288        for &ax in axes_to_remove.iter().rev() {
289            dims.remove(ax);
290            strides.remove(ax);
291        }
292
293        unsafe {
294            Ok(StridedView::new_unchecked(
295                self.data(),
296                &dims,
297                &strides,
298                self.offset(),
299            ))
300        }
301    }
302
303    /// Broadcast this view to a target shape.
304    ///
305    /// Size-1 dimensions are expanded (stride set to 0) to match target.
306    pub fn broadcast(&self, target_dims: &[usize]) -> Result<StridedView<'a, T, Op>> {
307        if self.dims.len() != target_dims.len() {
308            return Err(StridedError::RankMismatch(
309                self.dims.len(),
310                target_dims.len(),
311            ));
312        }
313        let mut new_strides = Vec::with_capacity(self.dims.len());
314        for i in 0..self.dims.len() {
315            if self.dims[i] == target_dims[i] {
316                new_strides.push(self.strides[i]);
317            } else if self.dims[i] == 1 {
318                new_strides.push(0);
319            } else {
320                return Err(StridedError::ShapeMismatch(
321                    self.dims.to_vec(),
322                    target_dims.to_vec(),
323                ));
324            }
325        }
326        Ok(StridedView {
327            ptr: self.ptr,
328            data: self.data,
329            dims: Arc::from(target_dims),
330            strides: Arc::from(new_strides),
331            offset: self.offset,
332            _op: PhantomData,
333        })
334    }
335}
336
337/// Composition methods: require `T: ElementOpApply` and `Op: ComposableElementOp<T>`.
338impl<'a, T: Copy + ElementOpApply, Op: ComposableElementOp<T>> StridedView<'a, T, Op> {
339    /// Transpose a 2D view: reverses dimensions and composes the Transpose element op.
340    ///
341    /// Julia equivalent: `Base.transpose(a::AbstractStridedView{<:Any, 2})`
342    pub fn transpose_2d(&self) -> Result<StridedView<'a, T, Op::ComposeTranspose>> {
343        if self.dims.len() != 2 {
344            return Err(StridedError::RankMismatch(self.dims.len(), 2));
345        }
346        Ok(StridedView {
347            ptr: self.ptr,
348            data: self.data,
349            dims: Arc::new([self.dims[1], self.dims[0]]),
350            strides: Arc::new([self.strides[1], self.strides[0]]),
351            offset: self.offset,
352            _op: PhantomData,
353        })
354    }
355
356    /// Adjoint (conjugate transpose) of a 2D view.
357    ///
358    /// Julia equivalent: `Base.adjoint(a::AbstractStridedView{<:Any, 2})`
359    pub fn adjoint_2d(&self) -> Result<StridedView<'a, T, Op::ComposeAdjoint>> {
360        if self.dims.len() != 2 {
361            return Err(StridedError::RankMismatch(self.dims.len(), 2));
362        }
363        Ok(StridedView {
364            ptr: self.ptr,
365            data: self.data,
366            dims: Arc::new([self.dims[1], self.dims[0]]),
367            strides: Arc::new([self.strides[1], self.strides[0]]),
368            offset: self.offset,
369            _op: PhantomData,
370        })
371    }
372
373    /// Complex conjugate (compose Conj without changing dims/strides).
374    ///
375    /// Julia equivalent: `Base.conj(a::AbstractStridedView)`
376    pub fn conj(&self) -> StridedView<'a, T, Op::ComposeConj> {
377        StridedView {
378            ptr: self.ptr,
379            data: self.data,
380            dims: self.dims.clone(),
381            strides: self.strides.clone(),
382            offset: self.offset,
383            _op: PhantomData,
384        }
385    }
386}
387
388/// Element access: requires `Op: ElementOp<T>`.
389impl<'a, T: Copy, Op: ElementOp<T>> StridedView<'a, T, Op> {
390    /// Get an element with the element operation applied.
391    pub fn get(&self, indices: &[usize]) -> T {
392        assert_eq!(indices.len(), self.dims.len(), "wrong number of indices");
393        let mut idx = 0isize;
394        for (i, &index) in indices.iter().enumerate() {
395            assert!(
396                index < self.dims[i],
397                "index {} out of bounds for dim {}",
398                index,
399                self.dims[i]
400            );
401            idx += index as isize * self.strides[i];
402        }
403        Op::apply(unsafe { *self.ptr.offset(idx) })
404    }
405
406    /// Get an element without bounds checking.
407    ///
408    /// # Safety
409    /// Caller must ensure indices are within bounds.
410    #[inline]
411    pub unsafe fn get_unchecked(&self, indices: &[usize]) -> T {
412        let mut idx = 0isize;
413        for (i, &index) in indices.iter().enumerate() {
414            idx += index as isize * self.strides[i];
415        }
416        Op::apply(*self.ptr.offset(idx))
417    }
418}
419
420// ============================================================================
421// StridedViewMut
422// ============================================================================
423
/// Dynamic-rank mutable strided view.
///
/// Always uses `Identity` element operation for write simplicity.
/// Julia typically applies ops on the read side.
pub struct StridedViewMut<'a, T> {
    /// Pointer to the element at `offset` within `data` (the view's origin).
    ptr: *mut T,
    /// Exclusive borrow of the backing storage.
    data: &'a mut [T],
    /// Extent of each axis.
    dims: Arc<[usize]>,
    /// Step (in elements of `T`) along each axis.
    strides: Arc<[isize]>,
    /// Element offset of the view's origin within `data`.
    offset: isize,
}

// SAFETY: the view is semantically an exclusive borrow `&'a mut [T]` (the raw
// pointer points into that slice), and `&mut [T]` is `Send` exactly when `T: Send`.
unsafe impl<T: Send> Send for StridedViewMut<'_, T> {}
437
438impl<T: std::fmt::Debug> std::fmt::Debug for StridedViewMut<'_, T> {
439    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440        f.debug_struct("StridedViewMut")
441            .field("dims", &self.dims)
442            .field("strides", &self.strides)
443            .field("offset", &self.offset)
444            .finish()
445    }
446}
447
448impl<'a, T> StridedViewMut<'a, T> {
449    /// Create a new mutable strided view.
450    pub fn new(
451        data: &'a mut [T],
452        dims: &[usize],
453        strides: &[isize],
454        offset: isize,
455    ) -> Result<Self> {
456        validate_bounds(data.len(), dims, strides, offset)?;
457        let ptr = unsafe { data.as_mut_ptr().offset(offset) };
458        Ok(Self {
459            ptr,
460            data,
461            dims: Arc::from(dims),
462            strides: Arc::from(strides),
463            offset,
464        })
465    }
466
467    /// Create without bounds checking.
468    ///
469    /// # Safety
470    /// Caller must ensure all index combinations stay within bounds.
471    pub unsafe fn new_unchecked(
472        data: &'a mut [T],
473        dims: &[usize],
474        strides: &[isize],
475        offset: isize,
476    ) -> Self {
477        let ptr = data.as_mut_ptr().offset(offset);
478        Self {
479            ptr,
480            data,
481            dims: Arc::from(dims),
482            strides: Arc::from(strides),
483            offset,
484        }
485    }
486
487    /// Returns the shape (dimension sizes) of this view.
488    #[inline]
489    pub fn dims(&self) -> &[usize] {
490        &self.dims
491    }
492
493    /// Returns the strides (in units of `T`) for each dimension.
494    #[inline]
495    pub fn strides(&self) -> &[isize] {
496        &self.strides
497    }
498
499    /// Returns the byte offset into the backing data.
500    #[inline]
501    pub fn offset(&self) -> isize {
502        self.offset
503    }
504
505    /// Returns the number of dimensions (rank).
506    #[inline]
507    pub fn ndim(&self) -> usize {
508        self.dims.len()
509    }
510
511    /// Returns the total number of elements.
512    #[inline]
513    pub fn len(&self) -> usize {
514        self.dims.iter().product()
515    }
516
517    /// Returns `true` if any dimension is zero.
518    #[inline]
519    pub fn is_empty(&self) -> bool {
520        self.dims.iter().any(|&d| d == 0)
521    }
522
523    /// Raw const pointer to element at the view's base offset.
524    #[inline]
525    pub fn ptr(&self) -> *const T {
526        self.ptr as *const T
527    }
528
529    /// Raw mutable pointer to element at the view's base offset.
530    #[inline]
531    pub fn as_mut_ptr(&self) -> *mut T {
532        self.ptr
533    }
534
535    /// Returns a reference to the backing data slice.
536    #[inline]
537    pub fn data(&self) -> &[T] {
538        &self.data
539    }
540
541    /// Returns a mutable reference to the backing data slice.
542    #[inline]
543    pub fn data_mut(&mut self) -> &mut [T] {
544        &mut self.data
545    }
546
547    /// Permute dimensions, consuming the mutable view.
548    ///
549    /// Returns a new mutable view with reordered dimensions and strides.
550    /// Takes `self` by value to prevent aliasing of mutable views.
551    pub fn permute(self, perm: &[usize]) -> Result<StridedViewMut<'a, T>> {
552        let rank = self.dims.len();
553        if perm.len() != rank {
554            return Err(StridedError::RankMismatch(perm.len(), rank));
555        }
556        let mut seen = vec![false; rank];
557        for &p in perm {
558            if p >= rank {
559                return Err(StridedError::InvalidAxis { axis: p, rank });
560            }
561            if seen[p] {
562                return Err(StridedError::InvalidAxis { axis: p, rank });
563            }
564            seen[p] = true;
565        }
566        let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
567        let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
568        Ok(StridedViewMut {
569            ptr: self.ptr,
570            data: self.data,
571            dims: Arc::from(new_dims),
572            strides: Arc::from(new_strides),
573            offset: self.offset,
574        })
575    }
576
577    /// Reborrow as an immutable view.
578    pub fn as_view(&self) -> StridedView<'_, T, Identity> {
579        StridedView {
580            ptr: self.ptr as *const T,
581            data: unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.data.len()) },
582            dims: self.dims.clone(),
583            strides: self.strides.clone(),
584            offset: self.offset,
585            _op: PhantomData,
586        }
587    }
588}
589
590impl<'a, T: Copy> StridedViewMut<'a, T> {
591    /// Get an element.
592    pub fn get(&self, indices: &[usize]) -> T {
593        assert_eq!(indices.len(), self.dims.len());
594        let mut idx = 0isize;
595        for (i, &index) in indices.iter().enumerate() {
596            assert!(index < self.dims[i]);
597            idx += index as isize * self.strides[i];
598        }
599        unsafe { *self.ptr.offset(idx) }
600    }
601
602    /// Set an element.
603    pub fn set(&mut self, indices: &[usize], value: T) {
604        assert_eq!(indices.len(), self.dims.len());
605        let mut idx = 0isize;
606        for (i, &index) in indices.iter().enumerate() {
607            assert!(index < self.dims[i]);
608            idx += index as isize * self.strides[i];
609        }
610        unsafe {
611            *self.ptr.offset(idx) = value;
612        }
613    }
614}
615
616// ============================================================================
617// StridedArray
618// ============================================================================
619
/// Owned strided multidimensional array.
///
/// Supports both column-major (Julia default) and row-major (C default) layouts.
pub struct StridedArray<T> {
    /// Owned backing storage.
    data: Vec<T>,
    /// Extent of each axis.
    dims: Arc<[usize]>,
    /// Step (in elements of `T`) along each axis.
    strides: Arc<[isize]>,
    /// Element offset of the all-zeros index within `data`.
    offset: isize,
}
629
630impl<T: std::fmt::Debug> std::fmt::Debug for StridedArray<T> {
631    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632        f.debug_struct("StridedArray")
633            .field("dims", &self.dims)
634            .field("strides", &self.strides)
635            .field("offset", &self.offset)
636            .finish()
637    }
638}
639
640impl<T: Clone> Clone for StridedArray<T> {
641    fn clone(&self) -> Self {
642        Self {
643            data: self.data.clone(),
644            dims: self.dims.clone(),
645            strides: self.strides.clone(),
646            offset: self.offset,
647        }
648    }
649}
650
651impl<T: Clone + Default> StridedArray<T> {
652    /// Create a column-major (Julia default) tensor filled with Default values.
653    pub fn col_major(dims: &[usize]) -> Self {
654        let total: usize = dims.iter().product();
655        let data = vec![T::default(); total];
656        let strides = col_major_strides(dims);
657        Self {
658            data,
659            dims: Arc::from(dims),
660            strides: Arc::from(strides),
661            offset: 0,
662        }
663    }
664
665    /// Create a row-major (C default) tensor filled with Default values.
666    pub fn row_major(dims: &[usize]) -> Self {
667        let total: usize = dims.iter().product();
668        let data = vec![T::default(); total];
669        let strides = row_major_strides(dims);
670        Self {
671            data,
672            dims: Arc::from(dims),
673            strides: Arc::from(strides),
674            offset: 0,
675        }
676    }
677
678    /// Create a column-major tensor with values produced by a function.
679    ///
680    /// The function is called with indices in column-major iteration order.
681    pub fn from_fn_col_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
682        let total: usize = dims.iter().product();
683        let strides = col_major_strides(dims);
684        let rank = dims.len();
685        let mut data = Vec::with_capacity(total);
686        let mut idx = vec![0usize; rank];
687        for _ in 0..total {
688            data.push(f(&idx));
689            for d in 0..rank {
690                idx[d] += 1;
691                if idx[d] < dims[d] {
692                    break;
693                }
694                idx[d] = 0;
695            }
696        }
697        Self {
698            data,
699            dims: Arc::from(dims),
700            strides: Arc::from(strides),
701            offset: 0,
702        }
703    }
704
705    /// Create a row-major tensor with values produced by a function.
706    ///
707    /// The function is called with indices in row-major iteration order.
708    pub fn from_fn_row_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
709        let total: usize = dims.iter().product();
710        let strides = row_major_strides(dims);
711        let rank = dims.len();
712        let mut data = Vec::with_capacity(total);
713        let mut idx = vec![0usize; rank];
714        for _ in 0..total {
715            data.push(f(&idx));
716            for d in (0..rank).rev() {
717                idx[d] += 1;
718                if idx[d] < dims[d] {
719                    break;
720                }
721                idx[d] = 0;
722            }
723        }
724        Self {
725            data,
726            dims: Arc::from(dims),
727            strides: Arc::from(strides),
728            offset: 0,
729        }
730    }
731}
732
733impl<T> StridedArray<T> {
734    /// Create from raw parts.
735    pub fn from_parts(
736        data: Vec<T>,
737        dims: &[usize],
738        strides: &[isize],
739        offset: isize,
740    ) -> Result<Self> {
741        validate_bounds(data.len(), dims, strides, offset)?;
742        Ok(Self {
743            data,
744            dims: Arc::from(dims),
745            strides: Arc::from(strides),
746            offset,
747        })
748    }
749
750    /// Returns the shape (dimension sizes) of this array.
751    #[inline]
752    pub fn dims(&self) -> &[usize] {
753        &self.dims
754    }
755
756    /// Returns the strides (in units of `T`) for each dimension.
757    #[inline]
758    pub fn strides(&self) -> &[isize] {
759        &self.strides
760    }
761
762    /// Returns the number of dimensions (rank).
763    #[inline]
764    pub fn ndim(&self) -> usize {
765        self.dims.len()
766    }
767
768    /// Returns the total number of elements.
769    #[inline]
770    pub fn len(&self) -> usize {
771        self.dims.iter().product()
772    }
773
774    /// Returns `true` if any dimension is zero.
775    #[inline]
776    pub fn is_empty(&self) -> bool {
777        self.dims.iter().any(|&d| d == 0)
778    }
779
780    /// Returns a reference to the backing data slice.
781    #[inline]
782    pub fn data(&self) -> &[T] {
783        &self.data
784    }
785
786    /// Returns a mutable reference to the backing data slice.
787    #[inline]
788    pub fn data_mut(&mut self) -> &mut [T] {
789        &mut self.data
790    }
791
792    /// Create an immutable view over this tensor.
793    pub fn view(&self) -> StridedView<'_, T> {
794        let ptr = unsafe { self.data.as_ptr().offset(self.offset) };
795        StridedView {
796            ptr,
797            data: &self.data,
798            dims: self.dims.clone(),
799            strides: self.strides.clone(),
800            offset: self.offset,
801            _op: PhantomData,
802        }
803    }
804
805    /// Create a mutable view over this tensor.
806    pub fn view_mut(&mut self) -> StridedViewMut<'_, T> {
807        let ptr = unsafe { self.data.as_mut_ptr().offset(self.offset) };
808        StridedViewMut {
809            ptr,
810            data: &mut self.data,
811            dims: self.dims.clone(),
812            strides: self.strides.clone(),
813            offset: self.offset,
814        }
815    }
816
817    /// Permute dimensions (metadata-only reorder, no data copy).
818    ///
819    /// Returns a new array with reordered dims and strides.
820    /// The underlying data buffer is not touched.
821    pub fn permuted(self, perm: &[usize]) -> Result<Self> {
822        let rank = self.dims.len();
823        if perm.len() != rank {
824            return Err(StridedError::RankMismatch(perm.len(), rank));
825        }
826        let mut seen = vec![false; rank];
827        for &p in perm {
828            if p >= rank {
829                return Err(StridedError::InvalidAxis { axis: p, rank });
830            }
831            if seen[p] {
832                return Err(StridedError::InvalidAxis { axis: p, rank });
833            }
834            seen[p] = true;
835        }
836        let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
837        let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
838        Ok(Self {
839            data: self.data,
840            dims: Arc::from(new_dims),
841            strides: Arc::from(new_strides),
842            offset: self.offset,
843        })
844    }
845
846    /// Consume the array and return the backing `Vec<T>`.
847    pub fn into_data(self) -> Vec<T> {
848        self.data
849    }
850
851    /// Iterate over all elements in memory order.
852    pub fn iter(&self) -> std::slice::Iter<'_, T> {
853        self.data.iter()
854    }
855
856    /// Mutable iteration over all elements in memory order.
857    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
858        self.data.iter_mut()
859    }
860}
861
862impl<T: Default> StridedArray<T> {
863    /// Create a column-major tensor reusing an existing buffer.
864    ///
865    /// If `buf` has at least `product(dims)` elements, it is truncated and
866    /// zeroed.  Otherwise a fresh buffer is allocated.
867    pub fn col_major_from_buffer(mut buf: Vec<T>, dims: &[usize]) -> Self {
868        let total: usize = dims.iter().product();
869        if buf.len() >= total {
870            buf.truncate(total);
871        } else {
872            buf.resize_with(total, T::default);
873        }
874        // Zero the reused region
875        for v in buf.iter_mut() {
876            *v = T::default();
877        }
878        let strides = col_major_strides(dims);
879        Self {
880            data: buf,
881            dims: Arc::from(dims),
882            strides: Arc::from(strides),
883            offset: 0,
884        }
885    }
886}
887
888impl<T: Copy> StridedArray<T> {
889    /// Create a column-major tensor with **uninitialized** data.
890    ///
891    /// # Safety
892    /// Caller must write every element before reading.
893    pub unsafe fn col_major_uninit(dims: &[usize]) -> Self {
894        let total: usize = dims.iter().product();
895        let mut data = Vec::with_capacity(total);
896        data.set_len(total);
897        let strides = col_major_strides(dims);
898        Self {
899            data,
900            dims: Arc::from(dims),
901            strides: Arc::from(strides),
902            offset: 0,
903        }
904    }
905
906    /// Reuse an existing buffer as a column-major tensor **without zeroing**.
907    ///
908    /// # Safety
909    /// Caller must write every element before reading.
910    pub unsafe fn col_major_from_buffer_uninit(mut buf: Vec<T>, dims: &[usize]) -> Self {
911        let total: usize = dims.iter().product();
912        if buf.capacity() < total {
913            buf.reserve(total - buf.len());
914        }
915        buf.set_len(total);
916        let strides = col_major_strides(dims);
917        Self {
918            data: buf,
919            dims: Arc::from(dims),
920            strides: Arc::from(strides),
921            offset: 0,
922        }
923    }
924}
925
926impl<T: Copy> StridedArray<T> {
927    /// Get an element by multi-dimensional index.
928    pub fn get(&self, indices: &[usize]) -> T {
929        self.view().get(indices)
930    }
931
932    /// Set an element by multi-dimensional index.
933    pub fn set(&mut self, indices: &[usize], value: T) {
934        assert_eq!(indices.len(), self.dims.len());
935        let mut idx = self.offset;
936        for (i, &index) in indices.iter().enumerate() {
937            assert!(index < self.dims[i]);
938            idx += index as isize * self.strides[i];
939        }
940        self.data[idx as usize] = value;
941    }
942}
943
944impl<T: Copy> Index<&[usize]> for StridedArray<T> {
945    type Output = T;
946
947    fn index(&self, indices: &[usize]) -> &T {
948        let mut idx = self.offset;
949        for (i, &index) in indices.iter().enumerate() {
950            assert!(index < self.dims[i]);
951            idx += index as isize * self.strides[i];
952        }
953        &self.data[idx as usize]
954    }
955}
956
957impl<T: Copy> IndexMut<&[usize]> for StridedArray<T> {
958    fn index_mut(&mut self, indices: &[usize]) -> &mut T {
959        let mut idx = self.offset;
960        for (i, &index) in indices.iter().enumerate() {
961            assert!(index < self.dims[i]);
962            idx += index as isize * self.strides[i];
963        }
964        &mut self.data[idx as usize]
965    }
966}
967
968// ============================================================================
969// Tests
970// ============================================================================
971
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    #[test]
    fn test_col_major_strides() {
        assert_eq!(col_major_strides(&[3, 4]), vec![1, 3]);
        assert_eq!(col_major_strides(&[2, 3, 4]), vec![1, 2, 6]);
    }

    #[test]
    fn test_row_major_strides() {
        assert_eq!(row_major_strides(&[3, 4]), vec![4, 1]);
        assert_eq!(row_major_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    #[test]
    fn test_strided_view_new() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.dims(), &[2, 3]);
        assert_eq!(view.strides(), &[3, 1]);
        assert_eq!(view.len(), 6);
    }

    #[test]
    fn test_strided_view_get() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0);
        assert_eq!(view.get(&[0, 1]), 2.0);
        assert_eq!(view.get(&[0, 2]), 3.0);
        assert_eq!(view.get(&[1, 0]), 4.0);
        assert_eq!(view.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_strided_view_col_major() {
        // Column-major: strides [1, 2] for 2x3
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[1, 2], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0); // data[0]
        assert_eq!(view.get(&[1, 0]), 2.0); // data[1]
        assert_eq!(view.get(&[0, 1]), 3.0); // data[2]
        assert_eq!(view.get(&[1, 1]), 4.0); // data[3]
        assert_eq!(view.get(&[0, 2]), 5.0); // data[4]
        assert_eq!(view.get(&[1, 2]), 6.0); // data[5]
    }

    #[test]
    fn test_strided_view_permute() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.dims(), &[3, 2]);
        assert_eq!(perm.strides(), &[1, 3]);
        assert_eq!(perm.get(&[0, 0]), 1.0);
        assert_eq!(perm.get(&[1, 0]), 2.0);
        assert_eq!(perm.get(&[0, 1]), 4.0);
    }

    #[test]
    fn test_strided_view_transpose_2d() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let t = view.transpose_2d().unwrap();
        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.get(&[0, 0]), 1.0);
        assert_eq!(t.get(&[1, 0]), 2.0);
        assert_eq!(t.get(&[0, 1]), 4.0);
    }

    #[test]
    fn test_strided_view_conj() {
        let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let view = StridedView::<Complex64>::new(&data, &[2], &[1], 0).unwrap();
        let c = view.conj();
        assert_eq!(c.get(&[0]), Complex64::new(1.0, -2.0));
        assert_eq!(c.get(&[1]), Complex64::new(3.0, -4.0));
    }

    #[test]
    fn test_strided_view_adjoint_2d() {
        let data = vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];
        // 2x2 row-major
        let view = StridedView::<Complex64>::new(&data, &[2, 2], &[2, 1], 0).unwrap();
        let adj = view.adjoint_2d().unwrap();
        assert_eq!(adj.dims(), &[2, 2]);
        // Adjoint: conj + transpose
        assert_eq!(adj.get(&[0, 0]), Complex64::new(1.0, -2.0));
        assert_eq!(adj.get(&[1, 0]), Complex64::new(3.0, -4.0));
        assert_eq!(adj.get(&[0, 1]), Complex64::new(5.0, -6.0));
    }

    #[test]
    fn test_strided_view_broadcast() {
        let data = vec![1.0, 2.0, 3.0];
        let view = StridedView::<f64>::new(&data, &[1, 3], &[3, 1], 0).unwrap();
        let broad = view.broadcast(&[4, 3]).unwrap();
        assert_eq!(broad.dims(), &[4, 3]);
        for i in 0..4 {
            assert_eq!(broad.get(&[i, 0]), 1.0);
            assert_eq!(broad.get(&[i, 1]), 2.0);
            assert_eq!(broad.get(&[i, 2]), 3.0);
        }
    }

    #[test]
    fn test_strided_view_mut() {
        let mut data = vec![0.0; 6];
        {
            let mut view = StridedViewMut::<f64>::new(&mut data, &[2, 3], &[3, 1], 0).unwrap();
            view.set(&[0, 0], 1.0);
            view.set(&[1, 2], 6.0);
        }
        assert_eq!(data[0], 1.0);
        assert_eq!(data[5], 6.0);
    }

    #[test]
    fn test_strided_view_mut_as_view() {
        let mut data = vec![1.0, 2.0, 3.0];
        let vm = StridedViewMut::<f64>::new(&mut data, &[3], &[1], 0).unwrap();
        let v = vm.as_view();
        assert_eq!(v.get(&[0]), 1.0);
        assert_eq!(v.get(&[2]), 3.0);
    }

    #[test]
    fn test_strided_tensor_col_major() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[1, 2]); // column-major
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    #[test]
    fn test_strided_tensor_row_major() {
        let t = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[3, 1]); // row-major
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    #[test]
    fn test_strided_tensor_view() {
        let t =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let v = t.view();
        assert_eq!(v.get(&[0, 0]), 0.0);
        assert_eq!(v.get(&[1, 0]), 10.0);
        assert_eq!(v.get(&[0, 2]), 2.0);
    }

    #[test]
    fn test_strided_tensor_view_mut() {
        let mut t = StridedArray::<f64>::col_major(&[2, 3]);
        {
            let mut vm = t.view_mut();
            vm.set(&[1, 2], 42.0);
        }
        assert_eq!(t.get(&[1, 2]), 42.0);
    }

    #[test]
    fn test_strided_tensor_index() {
        let t =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 10 + idx[1]) as f64);
        assert_eq!(t[&[0usize, 0] as &[usize]], 0.0);
        assert_eq!(t[&[2usize, 3] as &[usize]], 23.0);
    }

    #[test]
    fn test_strided_tensor_index_mut() {
        let mut t = StridedArray::<f64>::row_major(&[2, 3]);
        t[&[1usize, 2] as &[usize]] = 99.0;
        assert_eq!(t.get(&[1, 2]), 99.0);
    }

    #[test]
    fn test_col_major_from_buffer_uninit_shrinks_larger_buffer() {
        // A buffer longer than needed is cut down to exactly dims.product() elements.
        let buf = vec![0.0f64; 10];
        // SAFETY: every element of `buf` is already initialised (0.0).
        let mut t = unsafe { StridedArray::<f64>::col_major_from_buffer_uninit(buf, &[2, 3]) };
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[1, 2]); // column-major
        assert_eq!(t.view().len(), 6);
        t.set(&[1, 2], 12.0);
        assert_eq!(t.get(&[1, 2]), 12.0);
        assert_eq!(t.get(&[0, 0]), 0.0);
    }

    #[test]
    fn test_col_major_from_buffer_uninit_grows_small_buffer() {
        let buf: Vec<f64> = Vec::new();
        // SAFETY: every element is written below before any element is read.
        let mut t = unsafe { StridedArray::<f64>::col_major_from_buffer_uninit(buf, &[2, 2]) };
        assert_eq!(t.dims(), &[2, 2]);
        assert_eq!(t.strides(), &[1, 2]);
        for i in 0..2 {
            for j in 0..2 {
                t.set(&[i, j], (i * 10 + j) as f64);
            }
        }
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0]), 10.0);
        assert_eq!(t.get(&[1, 1]), 11.0);
    }

    #[test]
    fn test_validate_bounds_ok() {
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 0).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[1, 2], 0).is_ok());
    }

    #[test]
    fn test_validate_bounds_out_of_range() {
        assert!(validate_bounds(5, &[2, 3], &[3, 1], 0).is_err());
    }

    #[test]
    fn test_validate_bounds_empty() {
        assert!(validate_bounds(0, &[0, 3], &[3, 1], 0).is_ok());
    }

    #[test]
    fn test_validate_bounds_with_offset() {
        assert!(validate_bounds(7, &[2, 3], &[3, 1], 1).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 1).is_err());
    }

    #[test]
    fn test_strided_tensor_3d() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3, 4], |idx| {
            (idx[0] * 100 + idx[1] * 10 + idx[2]) as f64
        });
        assert_eq!(t.ndim(), 3);
        assert_eq!(t.strides(), &[1, 2, 6]); // column-major
        assert_eq!(t.get(&[0, 0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0, 0]), 100.0);
        assert_eq!(t.get(&[0, 1, 0]), 10.0);
        assert_eq!(t.get(&[0, 0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2, 3]), 123.0);
    }

    #[test]
    fn test_diagonal_view_2d() {
        // A[i,i] shape=[3,3] row-major strides=[3,1]
        // diagonal: shape=[3] strides=[4] (3+1)
        let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[3, 3], &[3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[3]);
        assert_eq!(diag.strides(), &[4]);
        assert_eq!(diag.get(&[0]), 0.0); // A[0,0]
        assert_eq!(diag.get(&[1]), 4.0); // A[1,1]
        assert_eq!(diag.get(&[2]), 8.0); // A[2,2]
    }

    #[test]
    fn test_diagonal_view_3d_adjacent() {
        // A[i,i,j] shape=[2,2,3] row-major strides=[6,3,1]
        // diagonal over (0,1): shape=[2,3] strides=[9,1] (6+3)
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 2, 3], &[6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[9, 1]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 2]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 9.0);
    }

    #[test]
    fn test_diagonal_view_3d_non_adjacent() {
        // A[i,j,i] shape=[2,3,2] row-major strides=[6,2,1]
        // diagonal over (0,2): shape=[2,3] strides=[7,2] (6+1)
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2], &[6, 2, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[7, 2]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 1]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 7.0);
        assert_eq!(diag.get(&[1, 2]), 11.0);
    }

    #[test]
    fn test_diagonal_view_two_pairs() {
        // A[i,j,i,j] shape=[2,3,2,3] -> A_diag[i,j] shape=[2,3]
        let data: Vec<f64> = (0..36).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2, 3], &[18, 6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2), (1, 3)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[21, 7]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[1, 1]), 28.0);
    }

    #[test]
    fn test_custom_type_without_element_op_apply() {
        // Custom Copy type that does NOT implement ElementOpApply.
        // Should work with Identity views (StridedView<T> default).
        // `Default` is derived: it yields MyScalar(0.0), same as a manual impl.
        #[derive(Debug, Clone, Copy, PartialEq, Default)]
        struct MyScalar(f64);

        // Create array with custom type
        let data = vec![MyScalar(1.0), MyScalar(2.0), MyScalar(3.0), MyScalar(4.0)];
        let arr = StridedArray::from_parts(data, &[2, 2], &[2, 1], 0).unwrap();

        // Access via Identity view (default)
        assert_eq!(arr.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(arr.get(&[1, 0]), MyScalar(3.0));

        // Create view and access elements
        let view: StridedView<MyScalar> = arr.view();
        assert_eq!(view.get(&[0, 1]), MyScalar(2.0));

        // Permute view (metadata-only, no ElementOpApply needed)
        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(perm.get(&[0, 1]), MyScalar(3.0)); // transposed
    }
}