strided_view/
view.rs

1//! Julia-like dynamic-rank strided view types.
2//!
3//! This module provides the canonical view types for strided operations,
4//! matching Julia's StridedViews.jl data model:
5//!
6//! - [`StridedView`]: Immutable dynamic-rank strided view with lazy element operations
7//! - [`StridedViewMut`]: Mutable dynamic-rank strided view (Identity op only)
8//! - [`StridedArray`]: Owned strided multidimensional array
9
10use std::marker::PhantomData;
11use std::ops::{Index, IndexMut};
12use std::sync::Arc;
13
14use crate::element_op::{ComposableElementOp, ElementOp, ElementOpApply, Identity};
15use crate::{Result, StridedError};
16
17// ============================================================================
18// Validation helpers
19// ============================================================================
20
21/// Validate that all accessed offsets stay within `[0, len)`.
22fn validate_bounds(len: usize, dims: &[usize], strides: &[isize], offset: isize) -> Result<()> {
23    if dims.len() != strides.len() {
24        return Err(StridedError::StrideLengthMismatch);
25    }
26    // Empty array - no access needed
27    if dims.iter().any(|&d| d == 0) {
28        return Ok(());
29    }
30    // Compute min and max offsets
31    let mut min_offset = offset;
32    let mut max_offset = offset;
33    for (&dim, &stride) in dims.iter().zip(strides.iter()) {
34        if dim > 1 {
35            let end = stride
36                .checked_mul(dim as isize - 1)
37                .ok_or(StridedError::OffsetOverflow)?;
38            if end >= 0 {
39                max_offset = max_offset
40                    .checked_add(end)
41                    .ok_or(StridedError::OffsetOverflow)?;
42            } else {
43                min_offset = min_offset
44                    .checked_add(end)
45                    .ok_or(StridedError::OffsetOverflow)?;
46            }
47        }
48    }
49    if min_offset < 0 || max_offset < 0 {
50        return Err(StridedError::OffsetOverflow);
51    }
52    if max_offset as usize >= len {
53        return Err(StridedError::OffsetOverflow);
54    }
55    Ok(())
56}
57
/// Compute column-major strides (Julia default: first index varies fastest).
///
/// The first axis gets stride 1 and each later axis gets the product of all
/// preceding extents. An empty `dims` yields an empty vector.
pub fn col_major_strides(dims: &[usize]) -> Vec<isize> {
    let mut strides = Vec::with_capacity(dims.len());
    let mut step = 1isize;
    for axis in 0..dims.len() {
        // The step for axis `k` is the step for axis `k - 1` times that axis' extent.
        if axis > 0 {
            step *= dims[axis - 1] as isize;
        }
        strides.push(step);
    }
    strides
}
70
/// Compute row-major strides (C default: last index varies fastest).
///
/// The last axis gets stride 1 and each earlier axis gets the product of all
/// following extents. An empty `dims` yields an empty vector.
pub fn row_major_strides(dims: &[usize]) -> Vec<isize> {
    let rank = dims.len();
    let mut strides = Vec::with_capacity(rank);
    let mut step = 1isize;
    // Walk from the innermost axis outwards, then flip into axis order.
    for axis in (0..rank).rev() {
        if axis + 1 < rank {
            step *= dims[axis + 1] as isize;
        }
        strides.push(step);
    }
    strides.reverse();
    strides
}
83
84// ============================================================================
85// StridedView
86// ============================================================================
87
88/// Dynamic-rank immutable strided view with lazy element operations.
89///
90/// This is the Julia-equivalent `StridedView` type with:
91/// - Dynamic rank (dims/strides are heap-allocated)
92/// - Lazy element operations via the `Op` type parameter
93/// - Zero-copy transformations (permute, transpose, adjoint, conj)
94///
95/// # Type Parameters
96/// - `'a`: Lifetime of the underlying data
97/// - `T`: Element type
98/// - `Op`: Element operation applied lazily on access (default: `Identity`)
99pub struct StridedView<'a, T, Op = Identity> {
100    ptr: *const T,
101    data: &'a [T],
102    dims: Arc<[usize]>,
103    strides: Arc<[isize]>,
104    offset: isize,
105    _op: PhantomData<Op>,
106}
107
108unsafe impl<T: Send, Op: Send> Send for StridedView<'_, T, Op> {}
109unsafe impl<T: Sync, Op: Sync> Sync for StridedView<'_, T, Op> {}
110
111impl<T, Op> Clone for StridedView<'_, T, Op> {
112    fn clone(&self) -> Self {
113        Self {
114            ptr: self.ptr,
115            data: self.data,
116            dims: self.dims.clone(),
117            strides: self.strides.clone(),
118            offset: self.offset,
119            _op: PhantomData,
120        }
121    }
122}
123
124impl<T: std::fmt::Debug, Op> std::fmt::Debug for StridedView<'_, T, Op> {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.debug_struct("StridedView")
127            .field("dims", &self.dims)
128            .field("strides", &self.strides)
129            .field("offset", &self.offset)
130            .finish()
131    }
132}
133
134impl<'a, T, Op> StridedView<'a, T, Op> {
135    /// Create a new immutable strided view from a borrowed slice.
136    pub fn new(data: &'a [T], dims: &[usize], strides: &[isize], offset: isize) -> Result<Self> {
137        validate_bounds(data.len(), dims, strides, offset)?;
138        let ptr = unsafe { data.as_ptr().offset(offset) };
139        Ok(Self {
140            ptr,
141            data,
142            dims: Arc::from(dims),
143            strides: Arc::from(strides),
144            offset,
145            _op: PhantomData,
146        })
147    }
148
149    /// Create a view without bounds checking.
150    ///
151    /// # Safety
152    /// The caller must ensure all index combinations stay within bounds.
153    pub unsafe fn new_unchecked(
154        data: &'a [T],
155        dims: &[usize],
156        strides: &[isize],
157        offset: isize,
158    ) -> Self {
159        let ptr = data.as_ptr().offset(offset);
160        Self {
161            ptr,
162            data,
163            dims: Arc::from(dims),
164            strides: Arc::from(strides),
165            offset,
166            _op: PhantomData,
167        }
168    }
169
170    /// Returns the shape (dimension sizes) of this view.
171    #[inline]
172    pub fn dims(&self) -> &[usize] {
173        &self.dims
174    }
175
176    /// Returns the strides (in units of `T`) for each dimension.
177    #[inline]
178    pub fn strides(&self) -> &[isize] {
179        &self.strides
180    }
181
182    /// Returns the byte offset into the backing data.
183    #[inline]
184    pub fn offset(&self) -> isize {
185        self.offset
186    }
187
188    /// Returns the number of dimensions (rank).
189    #[inline]
190    pub fn ndim(&self) -> usize {
191        self.dims.len()
192    }
193
194    /// Returns the total number of elements.
195    #[inline]
196    pub fn len(&self) -> usize {
197        self.dims.iter().product()
198    }
199
200    /// Returns `true` if any dimension is zero.
201    #[inline]
202    pub fn is_empty(&self) -> bool {
203        self.dims.iter().any(|&d| d == 0)
204    }
205
206    /// Returns a reference to the backing data slice.
207    #[inline]
208    pub fn data(&self) -> &'a [T] {
209        self.data
210    }
211
212    /// Raw const pointer to element at the view's base offset.
213    #[inline]
214    pub fn ptr(&self) -> *const T {
215        self.ptr
216    }
217
218    /// Permute dimensions.
219    pub fn permute(&self, perm: &[usize]) -> Result<StridedView<'a, T, Op>> {
220        let rank = self.dims.len();
221        if perm.len() != rank {
222            return Err(StridedError::RankMismatch(perm.len(), rank));
223        }
224        let mut seen = vec![false; rank];
225        for &p in perm {
226            if p >= rank {
227                return Err(StridedError::InvalidAxis { axis: p, rank });
228            }
229            if seen[p] {
230                return Err(StridedError::InvalidAxis { axis: p, rank });
231            }
232            seen[p] = true;
233        }
234        let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
235        let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
236        Ok(StridedView {
237            ptr: self.ptr,
238            data: self.data,
239            dims: Arc::from(new_dims),
240            strides: Arc::from(new_strides),
241            offset: self.offset,
242            _op: PhantomData,
243        })
244    }
245
246    /// Create a diagonal view by fusing repeated axis pairs via stride trick (zero-copy).
247    ///
248    /// For each pair `(a, b)`:
249    /// - New stride = `strides[a] + strides[b]`
250    /// - New dim = `min(dims[a], dims[b])`
251    /// - The higher-numbered axis is removed
252    /// - Pairs use **original** axis numbering
253    ///
254    /// # Example
255    /// `A[i,i,j]` shape=`[n,n,m]` strides=`[s0,s1,s2]` -> shape=`[n,m]` strides=`[s0+s1, s2]`
256    pub fn diagonal_view(&self, axis_pairs: &[(usize, usize)]) -> Result<StridedView<'a, T, Op>> {
257        let ndim = self.ndim();
258        let mut dims: Vec<usize> = self.dims().to_vec();
259        let mut strides: Vec<isize> = self.strides().to_vec();
260
261        let mut axes_to_remove = Vec::new();
262        for &(a, b) in axis_pairs {
263            let (lo, hi) = if a < b { (a, b) } else { (b, a) };
264            if lo >= ndim || hi >= ndim {
265                return Err(StridedError::InvalidAxis {
266                    axis: hi,
267                    rank: ndim,
268                });
269            }
270            if lo == hi {
271                return Err(StridedError::InvalidAxis {
272                    axis: lo,
273                    rank: ndim,
274                });
275            }
276            strides[lo] += strides[hi];
277            dims[lo] = dims[lo].min(dims[hi]);
278            axes_to_remove.push(hi);
279        }
280
281        axes_to_remove.sort_unstable();
282        axes_to_remove.dedup();
283        for &ax in axes_to_remove.iter().rev() {
284            dims.remove(ax);
285            strides.remove(ax);
286        }
287
288        unsafe {
289            Ok(StridedView::new_unchecked(
290                self.data(),
291                &dims,
292                &strides,
293                self.offset(),
294            ))
295        }
296    }
297
298    /// Broadcast this view to a target shape.
299    ///
300    /// Size-1 dimensions are expanded (stride set to 0) to match target.
301    pub fn broadcast(&self, target_dims: &[usize]) -> Result<StridedView<'a, T, Op>> {
302        if self.dims.len() != target_dims.len() {
303            return Err(StridedError::RankMismatch(
304                self.dims.len(),
305                target_dims.len(),
306            ));
307        }
308        let mut new_strides = Vec::with_capacity(self.dims.len());
309        for i in 0..self.dims.len() {
310            if self.dims[i] == target_dims[i] {
311                new_strides.push(self.strides[i]);
312            } else if self.dims[i] == 1 {
313                new_strides.push(0);
314            } else {
315                return Err(StridedError::ShapeMismatch(
316                    self.dims.to_vec(),
317                    target_dims.to_vec(),
318                ));
319            }
320        }
321        Ok(StridedView {
322            ptr: self.ptr,
323            data: self.data,
324            dims: Arc::from(target_dims),
325            strides: Arc::from(new_strides),
326            offset: self.offset,
327            _op: PhantomData,
328        })
329    }
330}
331
332/// Composition methods: require `T: ElementOpApply` and `Op: ComposableElementOp<T>`.
333impl<'a, T: Copy + ElementOpApply, Op: ComposableElementOp<T>> StridedView<'a, T, Op> {
334    /// Transpose a 2D view: reverses dimensions and composes the Transpose element op.
335    ///
336    /// Julia equivalent: `Base.transpose(a::AbstractStridedView{<:Any, 2})`
337    pub fn transpose_2d(&self) -> Result<StridedView<'a, T, Op::ComposeTranspose>> {
338        if self.dims.len() != 2 {
339            return Err(StridedError::RankMismatch(self.dims.len(), 2));
340        }
341        Ok(StridedView {
342            ptr: self.ptr,
343            data: self.data,
344            dims: Arc::new([self.dims[1], self.dims[0]]),
345            strides: Arc::new([self.strides[1], self.strides[0]]),
346            offset: self.offset,
347            _op: PhantomData,
348        })
349    }
350
351    /// Adjoint (conjugate transpose) of a 2D view.
352    ///
353    /// Julia equivalent: `Base.adjoint(a::AbstractStridedView{<:Any, 2})`
354    pub fn adjoint_2d(&self) -> Result<StridedView<'a, T, Op::ComposeAdjoint>> {
355        if self.dims.len() != 2 {
356            return Err(StridedError::RankMismatch(self.dims.len(), 2));
357        }
358        Ok(StridedView {
359            ptr: self.ptr,
360            data: self.data,
361            dims: Arc::new([self.dims[1], self.dims[0]]),
362            strides: Arc::new([self.strides[1], self.strides[0]]),
363            offset: self.offset,
364            _op: PhantomData,
365        })
366    }
367
368    /// Complex conjugate (compose Conj without changing dims/strides).
369    ///
370    /// Julia equivalent: `Base.conj(a::AbstractStridedView)`
371    pub fn conj(&self) -> StridedView<'a, T, Op::ComposeConj> {
372        StridedView {
373            ptr: self.ptr,
374            data: self.data,
375            dims: self.dims.clone(),
376            strides: self.strides.clone(),
377            offset: self.offset,
378            _op: PhantomData,
379        }
380    }
381}
382
383/// Element access: requires `Op: ElementOp<T>`.
384impl<'a, T: Copy, Op: ElementOp<T>> StridedView<'a, T, Op> {
385    /// Get an element with the element operation applied.
386    pub fn get(&self, indices: &[usize]) -> T {
387        assert_eq!(indices.len(), self.dims.len(), "wrong number of indices");
388        let mut idx = 0isize;
389        for (i, &index) in indices.iter().enumerate() {
390            assert!(
391                index < self.dims[i],
392                "index {} out of bounds for dim {}",
393                index,
394                self.dims[i]
395            );
396            idx += index as isize * self.strides[i];
397        }
398        Op::apply(unsafe { *self.ptr.offset(idx) })
399    }
400
401    /// Get an element without bounds checking.
402    ///
403    /// # Safety
404    /// Caller must ensure indices are within bounds.
405    #[inline]
406    pub unsafe fn get_unchecked(&self, indices: &[usize]) -> T {
407        let mut idx = 0isize;
408        for (i, &index) in indices.iter().enumerate() {
409            idx += index as isize * self.strides[i];
410        }
411        Op::apply(*self.ptr.offset(idx))
412    }
413}
414
415// ============================================================================
416// StridedViewMut
417// ============================================================================
418
/// Dynamic-rank mutable strided view.
///
/// Always uses `Identity` element operation for write simplicity.
/// Julia typically applies ops on the read side.
///
/// # Type Parameters
/// - `'a`: Lifetime of the exclusively borrowed backing data
/// - `T`: Element type
pub struct StridedViewMut<'a, T> {
    /// Pointer to the view's first element, i.e. `data.as_mut_ptr()` advanced by `offset`.
    ptr: *mut T,
    /// Exclusively borrowed backing storage; ties the view to the lifetime `'a`.
    data: &'a mut [T],
    /// Extent of each dimension.
    dims: Arc<[usize]>,
    /// Step between consecutive indices along each dimension, in elements of `T`.
    strides: Arc<[isize]>,
    /// Position of the view's first element within `data`, in elements of `T`.
    offset: isize,
}

// SAFETY: the view owns an exclusive borrow of its backing slice and `ptr`
// points into that same slice, so with respect to `T` it behaves like
// `&'a mut [T]`, which is `Send` when `T: Send`.
unsafe impl<T: Send> Send for StridedViewMut<'_, T> {}
432
433impl<T: std::fmt::Debug> std::fmt::Debug for StridedViewMut<'_, T> {
434    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435        f.debug_struct("StridedViewMut")
436            .field("dims", &self.dims)
437            .field("strides", &self.strides)
438            .field("offset", &self.offset)
439            .finish()
440    }
441}
442
443impl<'a, T> StridedViewMut<'a, T> {
444    /// Create a new mutable strided view.
445    pub fn new(
446        data: &'a mut [T],
447        dims: &[usize],
448        strides: &[isize],
449        offset: isize,
450    ) -> Result<Self> {
451        validate_bounds(data.len(), dims, strides, offset)?;
452        let ptr = unsafe { data.as_mut_ptr().offset(offset) };
453        Ok(Self {
454            ptr,
455            data,
456            dims: Arc::from(dims),
457            strides: Arc::from(strides),
458            offset,
459        })
460    }
461
462    /// Create without bounds checking.
463    ///
464    /// # Safety
465    /// Caller must ensure all index combinations stay within bounds.
466    pub unsafe fn new_unchecked(
467        data: &'a mut [T],
468        dims: &[usize],
469        strides: &[isize],
470        offset: isize,
471    ) -> Self {
472        let ptr = data.as_mut_ptr().offset(offset);
473        Self {
474            ptr,
475            data,
476            dims: Arc::from(dims),
477            strides: Arc::from(strides),
478            offset,
479        }
480    }
481
482    /// Returns the shape (dimension sizes) of this view.
483    #[inline]
484    pub fn dims(&self) -> &[usize] {
485        &self.dims
486    }
487
488    /// Returns the strides (in units of `T`) for each dimension.
489    #[inline]
490    pub fn strides(&self) -> &[isize] {
491        &self.strides
492    }
493
494    /// Returns the byte offset into the backing data.
495    #[inline]
496    pub fn offset(&self) -> isize {
497        self.offset
498    }
499
500    /// Returns the number of dimensions (rank).
501    #[inline]
502    pub fn ndim(&self) -> usize {
503        self.dims.len()
504    }
505
506    /// Returns the total number of elements.
507    #[inline]
508    pub fn len(&self) -> usize {
509        self.dims.iter().product()
510    }
511
512    /// Returns `true` if any dimension is zero.
513    #[inline]
514    pub fn is_empty(&self) -> bool {
515        self.dims.iter().any(|&d| d == 0)
516    }
517
518    /// Raw const pointer to element at the view's base offset.
519    #[inline]
520    pub fn ptr(&self) -> *const T {
521        self.ptr as *const T
522    }
523
524    /// Raw mutable pointer to element at the view's base offset.
525    #[inline]
526    pub fn as_mut_ptr(&self) -> *mut T {
527        self.ptr
528    }
529
530    /// Permute dimensions, consuming the mutable view.
531    ///
532    /// Returns a new mutable view with reordered dimensions and strides.
533    /// Takes `self` by value to prevent aliasing of mutable views.
534    pub fn permute(self, perm: &[usize]) -> Result<StridedViewMut<'a, T>> {
535        let rank = self.dims.len();
536        if perm.len() != rank {
537            return Err(StridedError::RankMismatch(perm.len(), rank));
538        }
539        let mut seen = vec![false; rank];
540        for &p in perm {
541            if p >= rank {
542                return Err(StridedError::InvalidAxis { axis: p, rank });
543            }
544            if seen[p] {
545                return Err(StridedError::InvalidAxis { axis: p, rank });
546            }
547            seen[p] = true;
548        }
549        let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
550        let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
551        Ok(StridedViewMut {
552            ptr: self.ptr,
553            data: self.data,
554            dims: Arc::from(new_dims),
555            strides: Arc::from(new_strides),
556            offset: self.offset,
557        })
558    }
559
560    /// Reborrow as an immutable view.
561    pub fn as_view(&self) -> StridedView<'_, T, Identity> {
562        StridedView {
563            ptr: self.ptr as *const T,
564            data: unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.data.len()) },
565            dims: self.dims.clone(),
566            strides: self.strides.clone(),
567            offset: self.offset,
568            _op: PhantomData,
569        }
570    }
571}
572
573impl<'a, T: Copy> StridedViewMut<'a, T> {
574    /// Get an element.
575    pub fn get(&self, indices: &[usize]) -> T {
576        assert_eq!(indices.len(), self.dims.len());
577        let mut idx = 0isize;
578        for (i, &index) in indices.iter().enumerate() {
579            assert!(index < self.dims[i]);
580            idx += index as isize * self.strides[i];
581        }
582        unsafe { *self.ptr.offset(idx) }
583    }
584
585    /// Set an element.
586    pub fn set(&mut self, indices: &[usize], value: T) {
587        assert_eq!(indices.len(), self.dims.len());
588        let mut idx = 0isize;
589        for (i, &index) in indices.iter().enumerate() {
590            assert!(index < self.dims[i]);
591            idx += index as isize * self.strides[i];
592        }
593        unsafe {
594            *self.ptr.offset(idx) = value;
595        }
596    }
597}
598
599// ============================================================================
600// StridedArray
601// ============================================================================
602
/// Owned strided multidimensional array.
///
/// Supports both column-major (Julia default) and row-major (C default) layouts.
pub struct StridedArray<T> {
    /// Owned backing storage.
    data: Vec<T>,
    /// Extent of each dimension.
    dims: Arc<[usize]>,
    /// Step between consecutive indices along each dimension, in elements of `T`.
    strides: Arc<[isize]>,
    /// Position of the first element within `data`, in elements of `T`.
    offset: isize,
}
612
613impl<T: std::fmt::Debug> std::fmt::Debug for StridedArray<T> {
614    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
615        f.debug_struct("StridedArray")
616            .field("dims", &self.dims)
617            .field("strides", &self.strides)
618            .field("offset", &self.offset)
619            .finish()
620    }
621}
622
623impl<T: Clone> Clone for StridedArray<T> {
624    fn clone(&self) -> Self {
625        Self {
626            data: self.data.clone(),
627            dims: self.dims.clone(),
628            strides: self.strides.clone(),
629            offset: self.offset,
630        }
631    }
632}
633
634impl<T: Clone + Default> StridedArray<T> {
635    /// Create a column-major (Julia default) tensor filled with Default values.
636    pub fn col_major(dims: &[usize]) -> Self {
637        let total: usize = dims.iter().product();
638        let data = vec![T::default(); total];
639        let strides = col_major_strides(dims);
640        Self {
641            data,
642            dims: Arc::from(dims),
643            strides: Arc::from(strides),
644            offset: 0,
645        }
646    }
647
648    /// Create a row-major (C default) tensor filled with Default values.
649    pub fn row_major(dims: &[usize]) -> Self {
650        let total: usize = dims.iter().product();
651        let data = vec![T::default(); total];
652        let strides = row_major_strides(dims);
653        Self {
654            data,
655            dims: Arc::from(dims),
656            strides: Arc::from(strides),
657            offset: 0,
658        }
659    }
660
661    /// Create a column-major tensor with values produced by a function.
662    ///
663    /// The function is called with indices in column-major iteration order.
664    pub fn from_fn_col_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
665        let total: usize = dims.iter().product();
666        let strides = col_major_strides(dims);
667        let rank = dims.len();
668        let mut data = Vec::with_capacity(total);
669        let mut idx = vec![0usize; rank];
670        for _ in 0..total {
671            data.push(f(&idx));
672            for d in 0..rank {
673                idx[d] += 1;
674                if idx[d] < dims[d] {
675                    break;
676                }
677                idx[d] = 0;
678            }
679        }
680        Self {
681            data,
682            dims: Arc::from(dims),
683            strides: Arc::from(strides),
684            offset: 0,
685        }
686    }
687
688    /// Create a row-major tensor with values produced by a function.
689    ///
690    /// The function is called with indices in row-major iteration order.
691    pub fn from_fn_row_major(dims: &[usize], mut f: impl FnMut(&[usize]) -> T) -> Self {
692        let total: usize = dims.iter().product();
693        let strides = row_major_strides(dims);
694        let rank = dims.len();
695        let mut data = Vec::with_capacity(total);
696        let mut idx = vec![0usize; rank];
697        for _ in 0..total {
698            data.push(f(&idx));
699            for d in (0..rank).rev() {
700                idx[d] += 1;
701                if idx[d] < dims[d] {
702                    break;
703                }
704                idx[d] = 0;
705            }
706        }
707        Self {
708            data,
709            dims: Arc::from(dims),
710            strides: Arc::from(strides),
711            offset: 0,
712        }
713    }
714}
715
716impl<T> StridedArray<T> {
717    /// Create from raw parts.
718    pub fn from_parts(
719        data: Vec<T>,
720        dims: &[usize],
721        strides: &[isize],
722        offset: isize,
723    ) -> Result<Self> {
724        validate_bounds(data.len(), dims, strides, offset)?;
725        Ok(Self {
726            data,
727            dims: Arc::from(dims),
728            strides: Arc::from(strides),
729            offset,
730        })
731    }
732
733    /// Returns the shape (dimension sizes) of this array.
734    #[inline]
735    pub fn dims(&self) -> &[usize] {
736        &self.dims
737    }
738
739    /// Returns the strides (in units of `T`) for each dimension.
740    #[inline]
741    pub fn strides(&self) -> &[isize] {
742        &self.strides
743    }
744
745    /// Returns the number of dimensions (rank).
746    #[inline]
747    pub fn ndim(&self) -> usize {
748        self.dims.len()
749    }
750
751    /// Returns the total number of elements.
752    #[inline]
753    pub fn len(&self) -> usize {
754        self.dims.iter().product()
755    }
756
757    /// Returns `true` if any dimension is zero.
758    #[inline]
759    pub fn is_empty(&self) -> bool {
760        self.dims.iter().any(|&d| d == 0)
761    }
762
763    /// Returns a reference to the backing data slice.
764    #[inline]
765    pub fn data(&self) -> &[T] {
766        &self.data
767    }
768
769    /// Returns a mutable reference to the backing data slice.
770    #[inline]
771    pub fn data_mut(&mut self) -> &mut [T] {
772        &mut self.data
773    }
774
775    /// Create an immutable view over this tensor.
776    pub fn view(&self) -> StridedView<'_, T> {
777        let ptr = unsafe { self.data.as_ptr().offset(self.offset) };
778        StridedView {
779            ptr,
780            data: &self.data,
781            dims: self.dims.clone(),
782            strides: self.strides.clone(),
783            offset: self.offset,
784            _op: PhantomData,
785        }
786    }
787
788    /// Create a mutable view over this tensor.
789    pub fn view_mut(&mut self) -> StridedViewMut<'_, T> {
790        let ptr = unsafe { self.data.as_mut_ptr().offset(self.offset) };
791        StridedViewMut {
792            ptr,
793            data: &mut self.data,
794            dims: self.dims.clone(),
795            strides: self.strides.clone(),
796            offset: self.offset,
797        }
798    }
799
800    /// Permute dimensions (metadata-only reorder, no data copy).
801    ///
802    /// Returns a new array with reordered dims and strides.
803    /// The underlying data buffer is not touched.
804    pub fn permuted(self, perm: &[usize]) -> Result<Self> {
805        let rank = self.dims.len();
806        if perm.len() != rank {
807            return Err(StridedError::RankMismatch(perm.len(), rank));
808        }
809        let mut seen = vec![false; rank];
810        for &p in perm {
811            if p >= rank {
812                return Err(StridedError::InvalidAxis { axis: p, rank });
813            }
814            if seen[p] {
815                return Err(StridedError::InvalidAxis { axis: p, rank });
816            }
817            seen[p] = true;
818        }
819        let new_dims: Vec<usize> = perm.iter().map(|&p| self.dims[p]).collect();
820        let new_strides: Vec<isize> = perm.iter().map(|&p| self.strides[p]).collect();
821        Ok(Self {
822            data: self.data,
823            dims: Arc::from(new_dims),
824            strides: Arc::from(new_strides),
825            offset: self.offset,
826        })
827    }
828
829    /// Consume the array and return the backing `Vec<T>`.
830    pub fn into_data(self) -> Vec<T> {
831        self.data
832    }
833
834    /// Iterate over all elements in memory order.
835    pub fn iter(&self) -> std::slice::Iter<'_, T> {
836        self.data.iter()
837    }
838
839    /// Mutable iteration over all elements in memory order.
840    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
841        self.data.iter_mut()
842    }
843}
844
845impl<T: Default> StridedArray<T> {
846    /// Create a column-major tensor reusing an existing buffer.
847    ///
848    /// If `buf` has at least `product(dims)` elements, it is truncated and
849    /// zeroed.  Otherwise a fresh buffer is allocated.
850    pub fn col_major_from_buffer(mut buf: Vec<T>, dims: &[usize]) -> Self {
851        let total: usize = dims.iter().product();
852        if buf.len() >= total {
853            buf.truncate(total);
854        } else {
855            buf.resize_with(total, T::default);
856        }
857        // Zero the reused region
858        for v in buf.iter_mut() {
859            *v = T::default();
860        }
861        let strides = col_major_strides(dims);
862        Self {
863            data: buf,
864            dims: Arc::from(dims),
865            strides: Arc::from(strides),
866            offset: 0,
867        }
868    }
869}
870
impl<T: Copy> StridedArray<T> {
    /// Create a column-major tensor with **uninitialized** data.
    ///
    /// Allocates capacity for `product(dims)` elements and sets the vector
    /// length to that count without writing any element.
    ///
    /// # Safety
    /// Caller must write every element before reading.
    ///
    /// NOTE(review): `Vec::set_len` documents that the elements up to the new
    /// length must already be initialized; this constructor relies on the
    /// caller's promise instead — confirm this is acceptable for every `T`
    /// used with it.
    pub unsafe fn col_major_uninit(dims: &[usize]) -> Self {
        let total: usize = dims.iter().product();
        let mut data = Vec::with_capacity(total);
        // The capacity reserved above equals `total`, so the new length never
        // exceeds the allocation.
        data.set_len(total);
        let strides = col_major_strides(dims);
        Self {
            data,
            dims: Arc::from(dims),
            strides: Arc::from(strides),
            offset: 0,
        }
    }

    /// Reuse an existing buffer as a column-major tensor **without zeroing**.
    ///
    /// The buffer's length is set to `product(dims)`. Elements that `buf`
    /// already held keep their old values; any elements beyond the previous
    /// length are uninitialized.
    ///
    /// # Safety
    /// Caller must write every element before reading.
    pub unsafe fn col_major_from_buffer_uninit(mut buf: Vec<T>, dims: &[usize]) -> Self {
        let total: usize = dims.iter().product();
        if buf.capacity() < total {
            // `len <= capacity < total` here, so the subtraction cannot
            // underflow; afterwards the capacity is at least `total`.
            buf.reserve(total - buf.len());
        }
        // The capacity is at least `total` on both paths.
        buf.set_len(total);
        let strides = col_major_strides(dims);
        Self {
            data: buf,
            dims: Arc::from(dims),
            strides: Arc::from(strides),
            offset: 0,
        }
    }
}
908
909impl<T: Copy> StridedArray<T> {
910    /// Get an element by multi-dimensional index.
911    pub fn get(&self, indices: &[usize]) -> T {
912        self.view().get(indices)
913    }
914
915    /// Set an element by multi-dimensional index.
916    pub fn set(&mut self, indices: &[usize], value: T) {
917        assert_eq!(indices.len(), self.dims.len());
918        let mut idx = self.offset;
919        for (i, &index) in indices.iter().enumerate() {
920            assert!(index < self.dims[i]);
921            idx += index as isize * self.strides[i];
922        }
923        self.data[idx as usize] = value;
924    }
925}
926
927impl<T: Copy> Index<&[usize]> for StridedArray<T> {
928    type Output = T;
929
930    fn index(&self, indices: &[usize]) -> &T {
931        let mut idx = self.offset;
932        for (i, &index) in indices.iter().enumerate() {
933            assert!(index < self.dims[i]);
934            idx += index as isize * self.strides[i];
935        }
936        &self.data[idx as usize]
937    }
938}
939
940impl<T: Copy> IndexMut<&[usize]> for StridedArray<T> {
941    fn index_mut(&mut self, indices: &[usize]) -> &mut T {
942        let mut idx = self.offset;
943        for (i, &index) in indices.iter().enumerate() {
944            assert!(index < self.dims[i]);
945            idx += index as isize * self.strides[i];
946        }
947        &mut self.data[idx as usize]
948    }
949}
950
951// ============================================================================
952// Tests
953// ============================================================================
954
#[cfg(test)]
mod tests {
    // Unit tests for the stride helpers, borrowed views and their transforms,
    // the owned array type, bounds validation, and diagonal views.
    use super::*;
    use num_complex::Complex64;

    // Column-major strides: the first axis is contiguous and each later
    // stride is the product of the preceding extents.
    #[test]
    fn test_col_major_strides() {
        assert_eq!(col_major_strides(&[3, 4]), vec![1, 3]);
        assert_eq!(col_major_strides(&[2, 3, 4]), vec![1, 2, 6]);
    }

    // Row-major strides: the last axis is contiguous and each earlier stride
    // is the product of the following extents.
    #[test]
    fn test_row_major_strides() {
        assert_eq!(row_major_strides(&[3, 4]), vec![4, 1]);
        assert_eq!(row_major_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    // A freshly constructed view reports the rank, dims, strides and element
    // count it was created with.
    #[test]
    fn test_strided_view_new() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.dims(), &[2, 3]);
        assert_eq!(view.strides(), &[3, 1]);
        assert_eq!(view.len(), 6);
    }

    // With strides [3, 1], element [i, j] is read from data[3*i + j].
    #[test]
    fn test_strided_view_get() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0);
        assert_eq!(view.get(&[0, 1]), 2.0);
        assert_eq!(view.get(&[0, 2]), 3.0);
        assert_eq!(view.get(&[1, 0]), 4.0);
        assert_eq!(view.get(&[1, 2]), 6.0);
    }

    // The same buffer read with strides [1, 2]: element [i, j] is
    // data[i + 2*j].
    #[test]
    fn test_strided_view_col_major() {
        // Column-major: strides [1, 2] for 2x3
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[1, 2], 0).unwrap();
        assert_eq!(view.get(&[0, 0]), 1.0); // data[0]
        assert_eq!(view.get(&[1, 0]), 2.0); // data[1]
        assert_eq!(view.get(&[0, 1]), 3.0); // data[2]
        assert_eq!(view.get(&[1, 1]), 4.0); // data[3]
        assert_eq!(view.get(&[0, 2]), 5.0); // data[4]
        assert_eq!(view.get(&[1, 2]), 6.0); // data[5]
    }

    // Permuting with [1, 0] swaps both dims and strides; reads then address
    // the source with the two coordinates exchanged.
    #[test]
    fn test_strided_view_permute() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.dims(), &[3, 2]);
        assert_eq!(perm.strides(), &[1, 3]);
        assert_eq!(perm.get(&[0, 0]), 1.0);
        assert_eq!(perm.get(&[1, 0]), 2.0);
        assert_eq!(perm.get(&[0, 1]), 4.0);
    }

    // transpose_2d on a 2x3 view yields a 3x2 view whose element [i, j]
    // equals the source's [j, i].
    #[test]
    fn test_strided_view_transpose_2d() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = StridedView::<f64>::new(&data, &[2, 3], &[3, 1], 0).unwrap();
        let t = view.transpose_2d().unwrap();
        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.get(&[0, 0]), 1.0);
        assert_eq!(t.get(&[1, 0]), 2.0);
        assert_eq!(t.get(&[0, 1]), 4.0);
    }

    // conj() returns a view whose reads yield the complex conjugate of each
    // stored element.
    #[test]
    fn test_strided_view_conj() {
        let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let view = StridedView::<Complex64>::new(&data, &[2], &[1], 0).unwrap();
        let c = view.conj();
        assert_eq!(c.get(&[0]), Complex64::new(1.0, -2.0));
        assert_eq!(c.get(&[1]), Complex64::new(3.0, -4.0));
    }

    // adjoint_2d combines transposition with conjugation: adj[i, j] is the
    // conjugate of view[j, i].
    #[test]
    fn test_strided_view_adjoint_2d() {
        let data = vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];
        // 2x2 row-major
        let view = StridedView::<Complex64>::new(&data, &[2, 2], &[2, 1], 0).unwrap();
        let adj = view.adjoint_2d().unwrap();
        assert_eq!(adj.dims(), &[2, 2]);
        // Adjoint: conj + transpose
        assert_eq!(adj.get(&[0, 0]), Complex64::new(1.0, -2.0));
        assert_eq!(adj.get(&[1, 0]), Complex64::new(3.0, -4.0));
        assert_eq!(adj.get(&[0, 1]), Complex64::new(5.0, -6.0));
    }

    // Broadcasting the size-1 leading axis to 4 makes every row read the
    // same three values.
    #[test]
    fn test_strided_view_broadcast() {
        let data = vec![1.0, 2.0, 3.0];
        let view = StridedView::<f64>::new(&data, &[1, 3], &[3, 1], 0).unwrap();
        let broad = view.broadcast(&[4, 3]).unwrap();
        assert_eq!(broad.dims(), &[4, 3]);
        for i in 0..4 {
            assert_eq!(broad.get(&[i, 0]), 1.0);
            assert_eq!(broad.get(&[i, 1]), 2.0);
            assert_eq!(broad.get(&[i, 2]), 3.0);
        }
    }

    // Writes through a mutable view land in the backing slice; the inner
    // scope ends the mutable borrow before `data` is inspected.
    #[test]
    fn test_strided_view_mut() {
        let mut data = vec![0.0; 6];
        {
            let mut view = StridedViewMut::<f64>::new(&mut data, &[2, 3], &[3, 1], 0).unwrap();
            view.set(&[0, 0], 1.0);
            view.set(&[1, 2], 6.0);
        }
        assert_eq!(data[0], 1.0);
        assert_eq!(data[5], 6.0);
    }

    // as_view() gives read access to the data behind a mutable view.
    #[test]
    fn test_strided_view_mut_as_view() {
        let mut data = vec![1.0, 2.0, 3.0];
        let vm = StridedViewMut::<f64>::new(&mut data, &[3], &[1], 0).unwrap();
        let v = vm.as_view();
        assert_eq!(v.get(&[0]), 1.0);
        assert_eq!(v.get(&[2]), 3.0);
    }

    // from_fn_col_major fills each position from its index and reports
    // column-major strides.
    #[test]
    fn test_strided_tensor_col_major() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[1, 2]); // column-major
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    // from_fn_row_major fills each position from its index and reports
    // row-major strides.
    #[test]
    fn test_strided_tensor_row_major() {
        let t = StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1]) as f64);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.strides(), &[3, 1]); // row-major
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 0]), 3.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }

    // view() on an owned array reads back the values produced by the fill
    // closure.
    #[test]
    fn test_strided_tensor_view() {
        let t =
            StridedArray::<f64>::from_fn_col_major(&[2, 3], |idx| (idx[0] * 10 + idx[1]) as f64);
        let v = t.view();
        assert_eq!(v.get(&[0, 0]), 0.0);
        assert_eq!(v.get(&[1, 0]), 10.0);
        assert_eq!(v.get(&[0, 2]), 2.0);
    }

    // A write through view_mut() is visible through the array's own get().
    #[test]
    fn test_strided_tensor_view_mut() {
        let mut t = StridedArray::<f64>::col_major(&[2, 3]);
        {
            let mut vm = t.view_mut();
            vm.set(&[1, 2], 42.0);
        }
        assert_eq!(t.get(&[1, 2]), 42.0);
    }

    // Index<&[usize]> reads elements by multi-dimensional position.
    #[test]
    fn test_strided_tensor_index() {
        let t =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 10 + idx[1]) as f64);
        assert_eq!(t[&[0usize, 0] as &[usize]], 0.0);
        assert_eq!(t[&[2usize, 3] as &[usize]], 23.0);
    }

    // IndexMut<&[usize]> assigns through the index expression.
    #[test]
    fn test_strided_tensor_index_mut() {
        let mut t = StridedArray::<f64>::row_major(&[2, 3]);
        t[&[1usize, 2] as &[usize]] = 99.0;
        assert_eq!(t.get(&[1, 2]), 99.0);
    }

    // Both the row-major and the column-major 2x3 layout fit exactly in a
    // buffer of length 6.
    #[test]
    fn test_validate_bounds_ok() {
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 0).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[1, 2], 0).is_ok());
    }

    // The 2x3 layout touches position 5, so a buffer of length 5 is rejected.
    #[test]
    fn test_validate_bounds_out_of_range() {
        assert!(validate_bounds(5, &[2, 3], &[3, 1], 0).is_err());
    }

    // A zero-sized dimension means nothing is accessed, so even an empty
    // buffer passes validation.
    #[test]
    fn test_validate_bounds_empty() {
        assert!(validate_bounds(0, &[0, 3], &[3, 1], 0).is_ok());
    }

    // A starting offset of 1 shifts the largest touched position to 6, which
    // needs a buffer of length 7.
    #[test]
    fn test_validate_bounds_with_offset() {
        assert!(validate_bounds(7, &[2, 3], &[3, 1], 1).is_ok());
        assert!(validate_bounds(6, &[2, 3], &[3, 1], 1).is_err());
    }

    // Three-dimensional column-major array: strides are [1, 2, 6] and every
    // index maps back to the value the fill closure produced for it.
    #[test]
    fn test_strided_tensor_3d() {
        let t = StridedArray::<f64>::from_fn_col_major(&[2, 3, 4], |idx| {
            (idx[0] * 100 + idx[1] * 10 + idx[2]) as f64
        });
        assert_eq!(t.ndim(), 3);
        assert_eq!(t.strides(), &[1, 2, 6]); // column-major
        assert_eq!(t.get(&[0, 0, 0]), 0.0);
        assert_eq!(t.get(&[1, 0, 0]), 100.0);
        assert_eq!(t.get(&[0, 1, 0]), 10.0);
        assert_eq!(t.get(&[0, 0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2, 3]), 123.0);
    }

    // Pairing both axes of a square matrix collapses them into one axis
    // whose stride is the sum of the two original strides.
    #[test]
    fn test_diagonal_view_2d() {
        // A[i,i] shape=[3,3] row-major strides=[3,1]
        // diagonal: shape=[3] strides=[4] (3+1)
        let data: Vec<f64> = (0..9).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[3, 3], &[3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[3]);
        assert_eq!(diag.strides(), &[4]);
        assert_eq!(diag.get(&[0]), 0.0); // A[0,0]
        assert_eq!(diag.get(&[1]), 4.0); // A[1,1]
        assert_eq!(diag.get(&[2]), 8.0); // A[2,2]
    }

    // Pairing two adjacent axes of a 3D view merges them and keeps the
    // remaining axis unchanged.
    #[test]
    fn test_diagonal_view_3d_adjacent() {
        // A[i,i,j] shape=[2,2,3] row-major strides=[6,3,1]
        // diagonal over (0,1): shape=[2,3] strides=[9,1] (6+3)
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 2, 3], &[6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 1)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[9, 1]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 2]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 9.0);
    }

    // Pairing the first and last axes of a 3D view merges them into the
    // first output axis; the middle axis keeps its stride.
    #[test]
    fn test_diagonal_view_3d_non_adjacent() {
        // A[i,j,i] shape=[2,3,2] row-major strides=[6,2,1]
        // diagonal over (0,2): shape=[2,3] strides=[7,2] (6+1)
        let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2], &[6, 2, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[7, 2]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[0, 1]), 2.0);
        assert_eq!(diag.get(&[1, 0]), 7.0);
        assert_eq!(diag.get(&[1, 2]), 11.0);
    }

    // Two axis pairs can be collapsed in one call; each output stride is the
    // sum of the strides of its pair.
    #[test]
    fn test_diagonal_view_two_pairs() {
        // A[i,j,i,j] shape=[2,3,2,3] -> A_diag[i,j] shape=[2,3]
        // strides=[18,6,3,1] -> [21,7] (18+3, 6+1)
        let data: Vec<f64> = (0..36).map(|x| x as f64).collect();
        let view = StridedView::<f64>::new(&data, &[2, 3, 2, 3], &[18, 6, 3, 1], 0).unwrap();
        let diag = view.diagonal_view(&[(0, 2), (1, 3)]).unwrap();
        assert_eq!(diag.dims(), &[2, 3]);
        assert_eq!(diag.strides(), &[21, 7]);
        assert_eq!(diag.get(&[0, 0]), 0.0);
        assert_eq!(diag.get(&[1, 1]), 28.0);
    }

    // Element types need not implement ElementOpApply as long as only the
    // default (Identity) view is used.
    #[test]
    fn test_custom_type_without_element_op_apply() {
        // Custom Copy type that does NOT implement ElementOpApply.
        // Should work with Identity views (StridedView<T> default).
        #[derive(Debug, Clone, Copy, PartialEq)]
        struct MyScalar(f64);

        impl Default for MyScalar {
            fn default() -> Self {
                MyScalar(0.0)
            }
        }

        // Create array with custom type
        let data = vec![MyScalar(1.0), MyScalar(2.0), MyScalar(3.0), MyScalar(4.0)];
        let arr = StridedArray::from_parts(data, &[2, 2], &[2, 1], 0).unwrap();

        // Access via Identity view (default)
        assert_eq!(arr.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(arr.get(&[1, 0]), MyScalar(3.0));

        // Create view and access elements
        let view: StridedView<MyScalar> = arr.view();
        assert_eq!(view.get(&[0, 1]), MyScalar(2.0));

        // Permute view (metadata-only, no ElementOpApply needed)
        let perm = view.permute(&[1, 0]).unwrap();
        assert_eq!(perm.get(&[0, 0]), MyScalar(1.0));
        assert_eq!(perm.get(&[0, 1]), MyScalar(3.0)); // transposed
    }
}