tensor4all_core/
col_major_array.rs

//! N-dimensional column-major array types.
//!
//! Column-major layout: the element at multi-index `[i0, i1, i2, ...]` is stored
//! at flat offset `i0 + shape[0] * (i1 + shape[1] * (i2 + ...))`, so the first
//! index varies fastest in memory.
//!
//! Three flavors are provided:
//! - [`ColMajorArrayRef`] — borrowed data and shape (read-only)
//! - [`ColMajorArrayMut`] — mutably borrowed data, borrowed shape
//! - [`ColMajorArray`] — fully owned data and shape

11/// Errors that can occur when constructing or modifying a column-major array.
12#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
13pub enum ColMajorArrayError {
14    /// The length of the data does not match the product of the shape dimensions.
15    #[error("Shape mismatch: shape {shape:?} requires {expected} elements, but got {actual}")]
16    ShapeMismatch {
17        /// The requested shape.
18        shape: Vec<usize>,
19        /// Number of elements implied by the shape.
20        expected: usize,
21        /// Number of elements actually provided.
22        actual: usize,
23    },
24
25    /// The column length does not match `nrows`.
26    #[error("Column length mismatch: expected {expected} elements, but got {actual}")]
27    ColumnLengthMismatch {
28        /// Expected number of rows.
29        expected: usize,
30        /// Actual number of elements in the column.
31        actual: usize,
32    },
33
34    /// A 2D operation was called on an array that is not 2-dimensional.
35    #[error("Expected a 2D array, but ndim = {ndim}")]
36    Not2D {
37        /// The actual number of dimensions.
38        ndim: usize,
39    },
40
41    /// The product of shape dimensions overflows `usize`.
42    #[error("Shape product overflow: shape {shape:?} overflows usize")]
43    ShapeOverflow {
44        /// The shape that caused the overflow.
45        shape: Vec<usize>,
46    },
47
48    /// Incrementing the column count would overflow `usize`.
49    #[error("Column count overflow")]
50    ColumnCountOverflow,
51}
52
// ---------------------------------------------------------------------------
// Helpers: element counts and flat offsets for a shape
// ---------------------------------------------------------------------------

/// Total number of elements described by `shape`, or `None` if that count
/// does not fit in a `usize`.
///
/// An empty shape (a 0-dimensional array) describes exactly one element, so it
/// yields `Some(1)`. Any zero extent makes the count `Some(0)`, whatever the
/// other extents are.
fn checked_shape_numel(shape: &[usize]) -> Option<usize> {
    // A zero extent makes the true product 0. Check for it up front so that an
    // intermediate overflow among the *other* extents (e.g. `[usize::MAX, 2, 0]`)
    // is not misreported as overflow. Otherwise the result would depend on where
    // the zero appears in the shape.
    if shape.contains(&0) {
        return Some(0);
    }
    shape
        .iter()
        .try_fold(1usize, |acc, &extent| acc.checked_mul(extent))
}

64fn shape_numel(shape: &[usize]) -> usize {
65    checked_shape_numel(shape).expect("shape product overflows usize")
66}
67
/// Translate a multi-index into its position in column-major storage.
///
/// The offset is `i0 + shape[0] * (i1 + shape[1] * (i2 + ...))`. It is
/// evaluated Horner-style from the last axis inward: multiply the running
/// value by the current extent, then add the current index. All arithmetic
/// is checked.
///
/// Returns `None` if `index` has a different length than `shape`, if any
/// component is not strictly below its axis extent, or if the arithmetic
/// overflows `usize`.
fn flat_offset(shape: &[usize], index: &[usize]) -> Option<usize> {
    if shape.len() != index.len() {
        return None;
    }
    shape
        .iter()
        .zip(index)
        .rev()
        .try_fold(0usize, |acc, (&extent, &i)| {
            if i < extent {
                acc.checked_mul(extent)?.checked_add(i)
            } else {
                None
            }
        })
}

// ===========================================================================
// ColMajorArrayRef
// ===========================================================================

93/// A borrowed, read-only view of an N-dimensional column-major array.
94#[derive(Debug, Clone, Copy)]
95pub struct ColMajorArrayRef<'a, T> {
96    data: &'a [T],
97    shape: &'a [usize],
98}
99
100impl<'a, T> ColMajorArrayRef<'a, T> {
101    /// Create a new borrowed array view.
102    ///
103    /// # Panics
104    ///
105    /// Panics if `data.len() != shape.iter().product()`.
106    pub fn new(data: &'a [T], shape: &'a [usize]) -> Self {
107        let expected = shape_numel(shape);
108        assert_eq!(
109            data.len(),
110            expected,
111            "ColMajorArrayRef::new: data length {} != shape product {}",
112            data.len(),
113            expected,
114        );
115        Self { data, shape }
116    }
117
118    /// Number of dimensions.
119    pub fn ndim(&self) -> usize {
120        self.shape.len()
121    }
122
123    /// Shape of the array.
124    pub fn shape(&self) -> &[usize] {
125        self.shape
126    }
127
128    /// Total number of elements.
129    pub fn len(&self) -> usize {
130        self.data.len()
131    }
132
133    /// Whether the array is empty (zero elements).
134    pub fn is_empty(&self) -> bool {
135        self.data.is_empty()
136    }
137
138    /// Flat (contiguous) data slice.
139    pub fn data(&self) -> &[T] {
140        self.data
141    }
142
143    /// Get a reference to the element at the given multi-index, or `None` if
144    /// out of bounds.
145    pub fn get(&self, index: &[usize]) -> Option<&T> {
146        let off = flat_offset(self.shape, index)?;
147        self.data.get(off)
148    }
149}
150
// ===========================================================================
// ColMajorArrayMut
// ===========================================================================

155/// A mutably borrowed view of an N-dimensional column-major array.
156#[derive(Debug)]
157pub struct ColMajorArrayMut<'a, T> {
158    data: &'a mut [T],
159    shape: &'a [usize],
160}
161
162impl<'a, T> ColMajorArrayMut<'a, T> {
163    /// Create a new mutable borrowed array view.
164    ///
165    /// # Panics
166    ///
167    /// Panics if `data.len() != shape.iter().product()`.
168    pub fn new(data: &'a mut [T], shape: &'a [usize]) -> Self {
169        let expected = shape_numel(shape);
170        assert_eq!(
171            data.len(),
172            expected,
173            "ColMajorArrayMut::new: data length {} != shape product {}",
174            data.len(),
175            expected,
176        );
177        Self { data, shape }
178    }
179
180    /// Number of dimensions.
181    pub fn ndim(&self) -> usize {
182        self.shape.len()
183    }
184
185    /// Shape of the array.
186    pub fn shape(&self) -> &[usize] {
187        self.shape
188    }
189
190    /// Total number of elements.
191    pub fn len(&self) -> usize {
192        self.data.len()
193    }
194
195    /// Whether the array is empty (zero elements).
196    pub fn is_empty(&self) -> bool {
197        self.data.is_empty()
198    }
199
200    /// Flat (contiguous) data slice (read-only).
201    pub fn data(&self) -> &[T] {
202        self.data
203    }
204
205    /// Flat (contiguous) data slice (mutable).
206    pub fn data_mut(&mut self) -> &mut [T] {
207        self.data
208    }
209
210    /// Get a reference to the element at the given multi-index, or `None` if
211    /// out of bounds.
212    pub fn get(&self, index: &[usize]) -> Option<&T> {
213        let off = flat_offset(self.shape, index)?;
214        self.data.get(off)
215    }
216
217    /// Get a mutable reference to the element at the given multi-index, or
218    /// `None` if out of bounds.
219    pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
220        let off = flat_offset(self.shape, index)?;
221        self.data.get_mut(off)
222    }
223}
224
// ===========================================================================
// ColMajorArray (owned)
// ===========================================================================

229/// A fully owned N-dimensional column-major array.
230#[derive(Debug, Clone, PartialEq, Eq)]
231pub struct ColMajorArray<T> {
232    data: Vec<T>,
233    shape: Vec<usize>,
234}
235
236impl<T> ColMajorArray<T> {
237    /// Create a new owned array from data and shape.
238    ///
239    /// Returns an error if `data.len()` does not equal the product of the
240    /// shape dimensions.
241    pub fn new(data: Vec<T>, shape: Vec<usize>) -> Result<Self, ColMajorArrayError> {
242        let expected =
243            checked_shape_numel(&shape).ok_or_else(|| ColMajorArrayError::ShapeOverflow {
244                shape: shape.clone(),
245            })?;
246        if data.len() != expected {
247            return Err(ColMajorArrayError::ShapeMismatch {
248                shape,
249                expected,
250                actual: data.len(),
251            });
252        }
253        Ok(Self { data, shape })
254    }
255
256    /// Number of dimensions.
257    pub fn ndim(&self) -> usize {
258        self.shape.len()
259    }
260
261    /// Shape of the array.
262    pub fn shape(&self) -> &[usize] {
263        &self.shape
264    }
265
266    /// Total number of elements.
267    pub fn len(&self) -> usize {
268        self.data.len()
269    }
270
271    /// Whether the array is empty (zero elements).
272    pub fn is_empty(&self) -> bool {
273        self.data.is_empty()
274    }
275
276    /// Flat (contiguous) data slice (read-only).
277    pub fn data(&self) -> &[T] {
278        &self.data
279    }
280
281    /// Flat (contiguous) data slice (mutable).
282    pub fn data_mut(&mut self) -> &mut [T] {
283        &mut self.data
284    }
285
286    /// Get a reference to the element at the given multi-index, or `None` if
287    /// out of bounds.
288    pub fn get(&self, index: &[usize]) -> Option<&T> {
289        let off = flat_offset(&self.shape, index)?;
290        self.data.get(off)
291    }
292
293    /// Get a mutable reference to the element at the given multi-index, or
294    /// `None` if out of bounds.
295    pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
296        let off = flat_offset(&self.shape, index)?;
297        self.data.get_mut(off)
298    }
299
300    /// Consume the array and return the underlying data vector.
301    pub fn into_data(self) -> Vec<T> {
302        self.data
303    }
304
305    /// Borrow as a [`ColMajorArrayRef`].
306    pub fn as_ref(&self) -> ColMajorArrayRef<'_, T> {
307        ColMajorArrayRef {
308            data: &self.data,
309            shape: &self.shape,
310        }
311    }
312
313    /// Borrow as a [`ColMajorArrayMut`].
314    pub fn as_mut(&mut self) -> ColMajorArrayMut<'_, T> {
315        ColMajorArrayMut {
316            data: &mut self.data,
317            shape: &self.shape,
318        }
319    }
320
321    // -- 2D helpers ---------------------------------------------------------
322
323    /// Number of rows (panics if not 2D).
324    pub fn nrows(&self) -> usize {
325        assert_eq!(self.ndim(), 2, "nrows() requires a 2D array");
326        self.shape[0]
327    }
328
329    /// Number of columns (panics if not 2D).
330    pub fn ncols(&self) -> usize {
331        assert_eq!(self.ndim(), 2, "ncols() requires a 2D array");
332        self.shape[1]
333    }
334
335    /// Return a slice for column `j` of a 2D array, or `None` if `j` is out
336    /// of range. Panics if the array is not 2D.
337    pub fn column(&self, j: usize) -> Option<&[T]> {
338        assert_eq!(self.ndim(), 2, "column() requires a 2D array");
339        let nrows = self.shape[0];
340        if j >= self.shape[1] {
341            return None;
342        }
343        let start = nrows.checked_mul(j)?;
344        let end = start.checked_add(nrows)?;
345        Some(&self.data[start..end])
346    }
347
348    /// Append a column to a 2D array.
349    ///
350    /// The `col` slice must have length equal to `nrows()`. This extends
351    /// the internal data and increments `shape[1]`.
352    ///
353    /// Returns an error if the array is not 2D or if the column length does
354    /// not match `nrows()`.
355    pub fn push_column(&mut self, col: &[T]) -> Result<(), ColMajorArrayError>
356    where
357        T: Clone,
358    {
359        if self.ndim() != 2 {
360            return Err(ColMajorArrayError::Not2D { ndim: self.ndim() });
361        }
362        let nrows = self.shape[0];
363        if col.len() != nrows {
364            return Err(ColMajorArrayError::ColumnLengthMismatch {
365                expected: nrows,
366                actual: col.len(),
367            });
368        }
369        self.data.extend_from_slice(col);
370        self.shape[1] = self.shape[1]
371            .checked_add(1)
372            .ok_or(ColMajorArrayError::ColumnCountOverflow)?;
373        Ok(())
374    }
375}
376
// -- Factories (require trait bounds on T) ----------------------------------

379impl<T: Clone> ColMajorArray<T> {
380    /// Create an array filled with a given value.
381    ///
382    /// Returns an error if the product of shape dimensions overflows `usize`.
383    pub fn filled(shape: Vec<usize>, value: T) -> Result<Self, ColMajorArrayError> {
384        let n = checked_shape_numel(&shape).ok_or_else(|| ColMajorArrayError::ShapeOverflow {
385            shape: shape.clone(),
386        })?;
387        Ok(Self {
388            data: vec![value; n],
389            shape,
390        })
391    }
392}
393
394impl<T: Default + Clone> ColMajorArray<T> {
395    /// Create an array filled with [`Default::default()`] (e.g., zeros for
396    /// numeric types).
397    ///
398    /// Returns an error if the product of shape dimensions overflows `usize`.
399    pub fn zeros(shape: Vec<usize>) -> Result<Self, ColMajorArrayError> {
400        let n = checked_shape_numel(&shape).ok_or_else(|| ColMajorArrayError::ShapeOverflow {
401            shape: shape.clone(),
402        })?;
403        Ok(Self {
404            data: vec![T::default(); n],
405            shape,
406        })
407    }
408}
409
// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // -- 0D (scalar) array --------------------------------------------------

    #[test]
    fn test_0d_array() {
        // An empty shape describes exactly one element, addressed by the
        // empty multi-index.
        let arr = ColMajorArray::new(vec![7], vec![]).unwrap();
        assert_eq!(arr.ndim(), 0);
        assert_eq!(arr.len(), 1);
        assert_eq!(arr.get(&[]), Some(&7));
        assert_eq!(arr.get(&[0]), None);
    }

    // -- 1D creation + get --------------------------------------------------

    #[test]
    fn test_1d_creation_and_get() {
        let arr = ColMajorArray::new(vec![10, 20, 30], vec![3]).unwrap();
        assert_eq!(arr.ndim(), 1);
        assert_eq!(arr.shape(), &[3]);
        assert_eq!(arr.len(), 3);
        assert!(!arr.is_empty());

        assert_eq!(arr.get(&[0]), Some(&10));
        assert_eq!(arr.get(&[1]), Some(&20));
        assert_eq!(arr.get(&[2]), Some(&30));
    }

    // -- 2D creation + get --------------------------------------------------

    #[test]
    fn test_2d_creation_and_get() {
        // 2x3 matrix in column-major:
        // Column 0: [1, 2], Column 1: [3, 4], Column 2: [5, 6]
        // Flat: [1, 2, 3, 4, 5, 6]
        let arr = ColMajorArray::new(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        assert_eq!(arr.ndim(), 2);
        assert_eq!(arr.shape(), &[2, 3]);
        assert_eq!(arr.len(), 6);

        // (row, col)
        assert_eq!(arr.get(&[0, 0]), Some(&1));
        assert_eq!(arr.get(&[1, 0]), Some(&2));
        assert_eq!(arr.get(&[0, 1]), Some(&3));
        assert_eq!(arr.get(&[1, 1]), Some(&4));
        assert_eq!(arr.get(&[0, 2]), Some(&5));
        assert_eq!(arr.get(&[1, 2]), Some(&6));
    }

    // -- 3D creation + get --------------------------------------------------

    #[test]
    fn test_3d_creation_and_get() {
        // Shape [2, 3, 2]: total 12 elements
        let data: Vec<i32> = (0..12).collect();
        let arr = ColMajorArray::new(data, vec![2, 3, 2]).unwrap();
        assert_eq!(arr.ndim(), 3);
        assert_eq!(arr.len(), 12);

        // Verify column-major offset: i0 + 2*(i1 + 3*i2)
        for i2 in 0..2 {
            for i1 in 0..3 {
                for i0 in 0..2 {
                    let expected_offset = i0 + 2 * (i1 + 3 * i2);
                    assert_eq!(
                        arr.get(&[i0, i1, i2]),
                        Some(&(expected_offset as i32)),
                        "Mismatch at [{i0}, {i1}, {i2}]"
                    );
                }
            }
        }
    }

    // -- Column-major order verification (2D) -------------------------------

    #[test]
    fn test_column_major_order_2d() {
        let nrows = 3;
        let ncols = 4;
        let data: Vec<i32> = (0..(nrows * ncols) as i32).collect();
        let arr = ColMajorArray::new(data.clone(), vec![nrows, ncols]).unwrap();

        // In column-major, data[i + nrows * j] == arr[(i, j)]
        for j in 0..ncols {
            for i in 0..nrows {
                assert_eq!(arr.get(&[i, j]), Some(&data[i + nrows * j]));
            }
        }
    }

    // -- get_mut ------------------------------------------------------------

    #[test]
    fn test_get_mut() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        // `expect` (not `if let`) so a regression in `get_mut` fails here
        // instead of silently skipping the write.
        *arr.get_mut(&[1, 0]).expect("[1, 0] is in bounds") = 42;
        assert_eq!(arr.get(&[1, 0]), Some(&42));
        // Other elements unchanged
        assert_eq!(arr.get(&[0, 0]), Some(&1));
        assert_eq!(arr.get(&[0, 1]), Some(&3));
        assert_eq!(arr.get(&[1, 1]), Some(&4));
        // Out-of-bounds mutable access yields None
        assert!(arr.get_mut(&[2, 0]).is_none());
    }

    // -- push_column --------------------------------------------------------

    #[test]
    fn test_push_column() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert_eq!(arr.ncols(), 2);

        arr.push_column(&[5, 6]).unwrap();
        assert_eq!(arr.ncols(), 3);
        assert_eq!(arr.shape(), &[2, 3]);
        assert_eq!(arr.len(), 6);
        assert_eq!(arr.get(&[0, 2]), Some(&5));
        assert_eq!(arr.get(&[1, 2]), Some(&6));
    }

    #[test]
    fn test_push_column_wrong_length() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let err = arr.push_column(&[5, 6, 7]).unwrap_err();
        assert_eq!(
            err,
            ColMajorArrayError::ColumnLengthMismatch {
                expected: 2,
                actual: 3,
            }
        );
    }

    #[test]
    fn test_push_column_not_2d() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3], vec![3]).unwrap();
        let err = arr.push_column(&[4]).unwrap_err();
        assert_eq!(err, ColMajorArrayError::Not2D { ndim: 1 });
    }

    #[test]
    fn test_push_column_count_overflow() {
        // With zero rows the column count can sit at usize::MAX without any
        // storage, so the next push must report overflow instead of wrapping.
        let mut arr: ColMajorArray<i32> =
            ColMajorArray::new(vec![], vec![0, usize::MAX]).unwrap();
        let err = arr.push_column(&[]).unwrap_err();
        assert_eq!(err, ColMajorArrayError::ColumnCountOverflow);
        // The failed push leaves the array untouched.
        assert_eq!(arr.shape(), &[0, usize::MAX]);
        assert!(arr.is_empty());
    }

    // -- column() slice access ----------------------------------------------

    #[test]
    fn test_column_access() {
        let arr = ColMajorArray::new(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        assert_eq!(arr.column(0), Some([1, 2].as_slice()));
        assert_eq!(arr.column(1), Some([3, 4].as_slice()));
        assert_eq!(arr.column(2), Some([5, 6].as_slice()));
        assert_eq!(arr.column(3), None); // out of bounds
    }

    #[test]
    fn test_column_with_zero_rows() {
        let arr: ColMajorArray<i32> = ColMajorArray::new(vec![], vec![0, 3]).unwrap();
        let empty: &[i32] = &[];
        assert_eq!(arr.column(1), Some(empty));
        assert_eq!(arr.column(3), None);
    }

    #[test]
    #[should_panic(expected = "column() requires a 2D array")]
    fn test_column_requires_2d() {
        let arr = ColMajorArray::new(vec![1, 2, 3], vec![3]).unwrap();
        let _ = arr.column(0);
    }

    // -- zeros, filled ------------------------------------------------------

    #[test]
    fn test_zeros() {
        let arr: ColMajorArray<f64> = ColMajorArray::zeros(vec![3, 2]).unwrap();
        assert_eq!(arr.len(), 6);
        assert!(arr.data().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_filled() {
        let arr = ColMajorArray::filled(vec![2, 3], 7i32).unwrap();
        assert_eq!(arr.len(), 6);
        assert!(arr.data().iter().all(|&v| v == 7));
    }

    // -- Shape mismatch error -----------------------------------------------

    #[test]
    fn test_shape_mismatch() {
        let result = ColMajorArray::new(vec![1, 2, 3], vec![2, 2]);
        assert_eq!(
            result.unwrap_err(),
            ColMajorArrayError::ShapeMismatch {
                shape: vec![2, 2],
                expected: 4,
                actual: 3,
            }
        );
    }

    // -- Out-of-bounds -> None ----------------------------------------------

    #[test]
    fn test_out_of_bounds() {
        let arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        // Index out of range
        assert_eq!(arr.get(&[2, 0]), None);
        assert_eq!(arr.get(&[0, 2]), None);
        // Wrong number of indices
        assert_eq!(arr.get(&[0]), None);
        assert_eq!(arr.get(&[0, 0, 0]), None);
    }

    // -- Ref and Mut views --------------------------------------------------

    #[test]
    fn test_as_ref() {
        let arr = ColMajorArray::new(vec![10, 20, 30, 40], vec![2, 2]).unwrap();
        let view = arr.as_ref();
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.get(&[1, 1]), Some(&40));
        assert_eq!(view.data(), arr.data());
    }

    #[test]
    fn test_as_mut() {
        let mut arr = ColMajorArray::new(vec![10, 20, 30, 40], vec![2, 2]).unwrap();
        {
            let mut view = arr.as_mut();
            *view.get_mut(&[0, 1]).expect("[0, 1] is in bounds") = 99;
        }
        assert_eq!(arr.get(&[0, 1]), Some(&99));
    }

    // -- into_data ----------------------------------------------------------

    #[test]
    fn test_into_data() {
        let arr = ColMajorArray::new(vec![1, 2, 3], vec![3]).unwrap();
        let data = arr.into_data();
        assert_eq!(data, vec![1, 2, 3]);
    }

    // -- Empty arrays -------------------------------------------------------

    #[test]
    fn test_empty_array() {
        let arr: ColMajorArray<i32> = ColMajorArray::new(vec![], vec![0]).unwrap();
        assert!(arr.is_empty());
        assert_eq!(arr.len(), 0);
        assert_eq!(arr.ndim(), 1);
    }

    #[test]
    fn test_empty_2d_array() {
        let arr: ColMajorArray<i32> = ColMajorArray::new(vec![], vec![3, 0]).unwrap();
        assert!(arr.is_empty());
        assert_eq!(arr.len(), 0);
        assert_eq!(arr.nrows(), 3);
        assert_eq!(arr.ncols(), 0);
    }

    // -- ColMajorArrayRef construction --------------------------------------

    #[test]
    fn test_ref_new() {
        let data = [1, 2, 3, 4, 5, 6];
        let shape = [2, 3];
        let view = ColMajorArrayRef::new(&data, &shape);
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.len(), 6);
        assert_eq!(view.get(&[1, 2]), Some(&6));
    }

    #[test]
    #[should_panic(expected = "ColMajorArrayRef::new: data length 3 != shape product 4")]
    fn test_ref_new_rejects_length_mismatch() {
        let data = [1, 2, 3];
        let shape = [2, 2];
        let _view = ColMajorArrayRef::new(&data, &shape);
    }

    // -- ColMajorArrayMut construction --------------------------------------

    #[test]
    fn test_mut_new() {
        let mut data = [1, 2, 3, 4, 5, 6];
        let shape = [2, 3];
        let mut view = ColMajorArrayMut::new(&mut data, &shape);
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.len(), 6);
        *view.get_mut(&[0, 0]).unwrap() = 100;
        assert_eq!(view.get(&[0, 0]), Some(&100));
    }

    #[test]
    #[should_panic(expected = "ColMajorArrayMut::new: data length 3 != shape product 4")]
    fn test_mut_new_rejects_length_mismatch() {
        let mut data = [1, 2, 3];
        let shape = [2, 2];
        let _view = ColMajorArrayMut::new(&mut data, &shape);
    }

    // -- Overflow detection ---------------------------------------------------

    #[test]
    fn test_new_rejects_overflow_shape() {
        let result = ColMajorArray::<u8>::new(vec![], vec![usize::MAX, 2]);
        assert!(
            matches!(result, Err(ColMajorArrayError::ShapeOverflow { .. })),
            "expected ShapeOverflow, got {:?}",
            result
        );
    }

    #[test]
    fn test_filled_rejects_overflow_shape() {
        let result = ColMajorArray::filled(vec![usize::MAX, 2], 0u8);
        assert!(
            matches!(result, Err(ColMajorArrayError::ShapeOverflow { .. })),
            "expected ShapeOverflow, got {:?}",
            result
        );
    }

    #[test]
    fn test_zeros_rejects_overflow_shape() {
        let result = ColMajorArray::<u8>::zeros(vec![usize::MAX, 2]);
        assert!(
            matches!(result, Err(ColMajorArrayError::ShapeOverflow { .. })),
            "expected ShapeOverflow, got {:?}",
            result
        );
    }
}