1#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
13pub enum ColMajorArrayError {
14 #[error("Shape mismatch: shape {shape:?} requires {expected} elements, but got {actual}")]
16 ShapeMismatch {
17 shape: Vec<usize>,
19 expected: usize,
21 actual: usize,
23 },
24
25 #[error("Column length mismatch: expected {expected} elements, but got {actual}")]
27 ColumnLengthMismatch {
28 expected: usize,
30 actual: usize,
32 },
33
34 #[error("Expected a 2D array, but ndim = {ndim}")]
36 Not2D {
37 ndim: usize,
39 },
40
41 #[error("Shape product overflow: shape {shape:?} overflows usize")]
43 ShapeOverflow {
44 shape: Vec<usize>,
46 },
47
48 #[error("Column count overflow")]
50 ColumnCountOverflow,
51}
52
/// Returns the number of elements described by `shape`, or `None` if that
/// count does not fit in a `usize`.
///
/// An empty shape describes a single (scalar) element and yields `Some(1)`.
///
/// A shape containing any zero-length dimension describes zero elements no
/// matter how large the other dimensions are, so it yields `Some(0)` up
/// front. Without this check the result would depend on dimension order:
/// `[0, usize::MAX, 2]` folds to `0` early, while `[usize::MAX, 2, 0]`
/// overflows before the zero is reached.
fn checked_shape_numel(shape: &[usize]) -> Option<usize> {
    if shape.contains(&0) {
        return Some(0);
    }
    shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
}
63
64fn shape_numel(shape: &[usize]) -> usize {
65 checked_shape_numel(shape).expect("shape product overflows usize")
66}
67
/// Converts a multi-dimensional `index` into a position in column-major
/// storage for an array of the given `shape`.
///
/// Returns `None` when `index` has a different rank than `shape`, when any
/// component is not smaller than its dimension, or when the arithmetic
/// overflows `usize`.
///
/// The offset `i0 + d0 * (i1 + d1 * (i2 + ...))` is evaluated from the last
/// axis inward, one checked multiply-add per axis.
fn flat_offset(shape: &[usize], index: &[usize]) -> Option<usize> {
    if shape.len() != index.len() {
        return None;
    }
    index
        .iter()
        .zip(shape)
        .rev()
        .try_fold(0usize, |acc, (&i, &extent)| {
            if i < extent {
                acc.checked_mul(extent)?.checked_add(i)
            } else {
                None
            }
        })
}
88
89#[derive(Debug, Clone, Copy)]
95pub struct ColMajorArrayRef<'a, T> {
96 data: &'a [T],
97 shape: &'a [usize],
98}
99
100impl<'a, T> ColMajorArrayRef<'a, T> {
101 pub fn new(data: &'a [T], shape: &'a [usize]) -> Self {
107 let expected = shape_numel(shape);
108 assert_eq!(
109 data.len(),
110 expected,
111 "ColMajorArrayRef::new: data length {} != shape product {}",
112 data.len(),
113 expected,
114 );
115 Self { data, shape }
116 }
117
118 pub fn ndim(&self) -> usize {
120 self.shape.len()
121 }
122
123 pub fn shape(&self) -> &[usize] {
125 self.shape
126 }
127
128 pub fn len(&self) -> usize {
130 self.data.len()
131 }
132
133 pub fn is_empty(&self) -> bool {
135 self.data.is_empty()
136 }
137
138 pub fn data(&self) -> &[T] {
140 self.data
141 }
142
143 pub fn get(&self, index: &[usize]) -> Option<&T> {
146 let off = flat_offset(self.shape, index)?;
147 self.data.get(off)
148 }
149}
150
151#[derive(Debug)]
157pub struct ColMajorArrayMut<'a, T> {
158 data: &'a mut [T],
159 shape: &'a [usize],
160}
161
162impl<'a, T> ColMajorArrayMut<'a, T> {
163 pub fn new(data: &'a mut [T], shape: &'a [usize]) -> Self {
169 let expected = shape_numel(shape);
170 assert_eq!(
171 data.len(),
172 expected,
173 "ColMajorArrayMut::new: data length {} != shape product {}",
174 data.len(),
175 expected,
176 );
177 Self { data, shape }
178 }
179
180 pub fn ndim(&self) -> usize {
182 self.shape.len()
183 }
184
185 pub fn shape(&self) -> &[usize] {
187 self.shape
188 }
189
190 pub fn len(&self) -> usize {
192 self.data.len()
193 }
194
195 pub fn is_empty(&self) -> bool {
197 self.data.is_empty()
198 }
199
200 pub fn data(&self) -> &[T] {
202 self.data
203 }
204
205 pub fn data_mut(&mut self) -> &mut [T] {
207 self.data
208 }
209
210 pub fn get(&self, index: &[usize]) -> Option<&T> {
213 let off = flat_offset(self.shape, index)?;
214 self.data.get(off)
215 }
216
217 pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
220 let off = flat_offset(self.shape, index)?;
221 self.data.get_mut(off)
222 }
223}
224
225#[derive(Debug, Clone, PartialEq, Eq)]
231pub struct ColMajorArray<T> {
232 data: Vec<T>,
233 shape: Vec<usize>,
234}
235
236impl<T> ColMajorArray<T> {
237 pub fn new(data: Vec<T>, shape: Vec<usize>) -> Result<Self, ColMajorArrayError> {
242 let expected =
243 checked_shape_numel(&shape).ok_or_else(|| ColMajorArrayError::ShapeOverflow {
244 shape: shape.clone(),
245 })?;
246 if data.len() != expected {
247 return Err(ColMajorArrayError::ShapeMismatch {
248 shape,
249 expected,
250 actual: data.len(),
251 });
252 }
253 Ok(Self { data, shape })
254 }
255
256 pub fn ndim(&self) -> usize {
258 self.shape.len()
259 }
260
261 pub fn shape(&self) -> &[usize] {
263 &self.shape
264 }
265
266 pub fn len(&self) -> usize {
268 self.data.len()
269 }
270
271 pub fn is_empty(&self) -> bool {
273 self.data.is_empty()
274 }
275
276 pub fn data(&self) -> &[T] {
278 &self.data
279 }
280
281 pub fn data_mut(&mut self) -> &mut [T] {
283 &mut self.data
284 }
285
286 pub fn get(&self, index: &[usize]) -> Option<&T> {
289 let off = flat_offset(&self.shape, index)?;
290 self.data.get(off)
291 }
292
293 pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
296 let off = flat_offset(&self.shape, index)?;
297 self.data.get_mut(off)
298 }
299
300 pub fn into_data(self) -> Vec<T> {
302 self.data
303 }
304
305 pub fn as_ref(&self) -> ColMajorArrayRef<'_, T> {
307 ColMajorArrayRef {
308 data: &self.data,
309 shape: &self.shape,
310 }
311 }
312
313 pub fn as_mut(&mut self) -> ColMajorArrayMut<'_, T> {
315 ColMajorArrayMut {
316 data: &mut self.data,
317 shape: &self.shape,
318 }
319 }
320
321 pub fn nrows(&self) -> usize {
325 assert_eq!(self.ndim(), 2, "nrows() requires a 2D array");
326 self.shape[0]
327 }
328
329 pub fn ncols(&self) -> usize {
331 assert_eq!(self.ndim(), 2, "ncols() requires a 2D array");
332 self.shape[1]
333 }
334
335 pub fn column(&self, j: usize) -> Option<&[T]> {
338 assert_eq!(self.ndim(), 2, "column() requires a 2D array");
339 let nrows = self.shape[0];
340 if j >= self.shape[1] {
341 return None;
342 }
343 let start = nrows.checked_mul(j)?;
344 let end = start.checked_add(nrows)?;
345 Some(&self.data[start..end])
346 }
347
348 pub fn push_column(&mut self, col: &[T]) -> Result<(), ColMajorArrayError>
356 where
357 T: Clone,
358 {
359 if self.ndim() != 2 {
360 return Err(ColMajorArrayError::Not2D { ndim: self.ndim() });
361 }
362 let nrows = self.shape[0];
363 if col.len() != nrows {
364 return Err(ColMajorArrayError::ColumnLengthMismatch {
365 expected: nrows,
366 actual: col.len(),
367 });
368 }
369 self.data.extend_from_slice(col);
370 self.shape[1] = self.shape[1]
371 .checked_add(1)
372 .ok_or(ColMajorArrayError::ColumnCountOverflow)?;
373 Ok(())
374 }
375}
376
377impl<T: Clone> ColMajorArray<T> {
380 pub fn filled(shape: Vec<usize>, value: T) -> Result<Self, ColMajorArrayError> {
384 let n = checked_shape_numel(&shape).ok_or_else(|| ColMajorArrayError::ShapeOverflow {
385 shape: shape.clone(),
386 })?;
387 Ok(Self {
388 data: vec![value; n],
389 shape,
390 })
391 }
392}
393
394impl<T: Default + Clone> ColMajorArray<T> {
395 pub fn zeros(shape: Vec<usize>) -> Result<Self, ColMajorArrayError> {
400 let n = checked_shape_numel(&shape).ok_or_else(|| ColMajorArrayError::ShapeOverflow {
401 shape: shape.clone(),
402 })?;
403 Ok(Self {
404 data: vec![T::default(); n],
405 shape,
406 })
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_1d_creation_and_get() {
        let arr = ColMajorArray::new(vec![10, 20, 30], vec![3]).unwrap();
        assert_eq!(arr.ndim(), 1);
        assert_eq!(arr.shape(), &[3]);
        assert_eq!(arr.len(), 3);
        assert!(!arr.is_empty());

        assert_eq!(arr.get(&[0]), Some(&10));
        assert_eq!(arr.get(&[1]), Some(&20));
        assert_eq!(arr.get(&[2]), Some(&30));
    }

    #[test]
    fn test_2d_creation_and_get() {
        // Columns are contiguous: [1, 2] is column 0, [3, 4] column 1, ...
        let arr = ColMajorArray::new(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        assert_eq!(arr.ndim(), 2);
        assert_eq!(arr.shape(), &[2, 3]);
        assert_eq!(arr.len(), 6);

        assert_eq!(arr.get(&[0, 0]), Some(&1));
        assert_eq!(arr.get(&[1, 0]), Some(&2));
        assert_eq!(arr.get(&[0, 1]), Some(&3));
        assert_eq!(arr.get(&[1, 1]), Some(&4));
        assert_eq!(arr.get(&[0, 2]), Some(&5));
        assert_eq!(arr.get(&[1, 2]), Some(&6));
    }

    #[test]
    fn test_3d_creation_and_get() {
        let data: Vec<i32> = (0..12).collect();
        let arr = ColMajorArray::new(data, vec![2, 3, 2]).unwrap();
        assert_eq!(arr.ndim(), 3);
        assert_eq!(arr.len(), 12);

        for i2 in 0..2 {
            for i1 in 0..3 {
                for i0 in 0..2 {
                    let expected_offset = i0 + 2 * (i1 + 3 * i2);
                    assert_eq!(
                        arr.get(&[i0, i1, i2]),
                        Some(&(expected_offset as i32)),
                        "Mismatch at [{i0}, {i1}, {i2}]"
                    );
                }
            }
        }
    }

    #[test]
    fn test_scalar_shape() {
        // An empty shape describes exactly one element, addressed by `[]`.
        let arr = ColMajorArray::new(vec![7], vec![]).unwrap();
        assert_eq!(arr.ndim(), 0);
        assert_eq!(arr.len(), 1);
        assert_eq!(arr.get(&[]), Some(&7));
        assert_eq!(arr.get(&[0]), None);
    }

    #[test]
    fn test_column_major_order_2d() {
        let nrows = 3;
        let ncols = 4;
        let data: Vec<i32> = (0..(nrows * ncols) as i32).collect();
        let arr = ColMajorArray::new(data.clone(), vec![nrows, ncols]).unwrap();

        for j in 0..ncols {
            for i in 0..nrows {
                assert_eq!(arr.get(&[i, j]), Some(&data[i + nrows * j]));
            }
        }
    }

    #[test]
    fn test_get_mut() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        if let Some(v) = arr.get_mut(&[1, 0]) {
            *v = 42;
        }
        assert_eq!(arr.get(&[1, 0]), Some(&42));
        assert_eq!(arr.get(&[0, 0]), Some(&1));
        assert_eq!(arr.get(&[0, 1]), Some(&3));
        assert_eq!(arr.get(&[1, 1]), Some(&4));
    }

    #[test]
    fn test_push_column() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert_eq!(arr.ncols(), 2);

        arr.push_column(&[5, 6]).unwrap();
        assert_eq!(arr.ncols(), 3);
        assert_eq!(arr.shape(), &[2, 3]);
        assert_eq!(arr.len(), 6);
        assert_eq!(arr.get(&[0, 2]), Some(&5));
        assert_eq!(arr.get(&[1, 2]), Some(&6));
    }

    #[test]
    fn test_push_column_wrong_length() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let err = arr.push_column(&[5, 6, 7]).unwrap_err();
        assert_eq!(
            err,
            ColMajorArrayError::ColumnLengthMismatch {
                expected: 2,
                actual: 3,
            }
        );
    }

    #[test]
    fn test_push_column_error_leaves_array_unchanged() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert!(arr.push_column(&[5]).is_err());
        assert_eq!(arr.shape(), &[2, 2]);
        assert_eq!(arr.data(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_push_column_onto_empty_2d() {
        let mut arr: ColMajorArray<i32> = ColMajorArray::new(vec![], vec![2, 0]).unwrap();
        arr.push_column(&[1, 2]).unwrap();
        assert_eq!(arr.shape(), &[2, 1]);
        assert_eq!(arr.column(0), Some([1, 2].as_slice()));
    }

    #[test]
    fn test_push_column_not_2d() {
        let mut arr = ColMajorArray::new(vec![1, 2, 3], vec![3]).unwrap();
        let err = arr.push_column(&[4]).unwrap_err();
        assert_eq!(err, ColMajorArrayError::Not2D { ndim: 1 });
    }

    #[test]
    fn test_column_access() {
        let arr = ColMajorArray::new(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        assert_eq!(arr.column(0), Some([1, 2].as_slice()));
        assert_eq!(arr.column(1), Some([3, 4].as_slice()));
        assert_eq!(arr.column(2), Some([5, 6].as_slice()));
        assert_eq!(arr.column(3), None);
    }

    #[test]
    fn test_column_of_zero_row_array() {
        let arr: ColMajorArray<i32> = ColMajorArray::new(vec![], vec![0, 3]).unwrap();
        let empty: &[i32] = &[];
        assert_eq!(arr.column(1), Some(empty));
        assert_eq!(arr.column(3), None);
    }

    #[test]
    #[should_panic(expected = "column() requires a 2D array")]
    fn test_column_panics_on_non_2d() {
        let arr = ColMajorArray::new(vec![1, 2, 3], vec![3]).unwrap();
        let _ = arr.column(0);
    }

    #[test]
    fn test_zeros() {
        let arr: ColMajorArray<f64> = ColMajorArray::zeros(vec![3, 2]).unwrap();
        assert_eq!(arr.len(), 6);
        assert!(arr.data().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_filled() {
        let arr = ColMajorArray::filled(vec![2, 3], 7i32).unwrap();
        assert_eq!(arr.len(), 6);
        assert!(arr.data().iter().all(|&v| v == 7));
    }

    #[test]
    fn test_filled_with_zero_dimension() {
        let arr = ColMajorArray::filled(vec![2, 0], 1i32).unwrap();
        assert!(arr.is_empty());
        assert_eq!(arr.shape(), &[2, 0]);
    }

    #[test]
    fn test_shape_mismatch() {
        let result = ColMajorArray::new(vec![1, 2, 3], vec![2, 2]);
        assert_eq!(
            result.unwrap_err(),
            ColMajorArrayError::ShapeMismatch {
                shape: vec![2, 2],
                expected: 4,
                actual: 3,
            }
        );
    }

    #[test]
    fn test_out_of_bounds() {
        let arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        // Component out of range for its dimension.
        assert_eq!(arr.get(&[2, 0]), None);
        assert_eq!(arr.get(&[0, 2]), None);
        // Wrong rank.
        assert_eq!(arr.get(&[0]), None);
        assert_eq!(arr.get(&[0, 0, 0]), None);
    }

    #[test]
    fn test_as_ref() {
        let arr = ColMajorArray::new(vec![10, 20, 30, 40], vec![2, 2]).unwrap();
        let view = arr.as_ref();
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.get(&[1, 1]), Some(&40));
        assert_eq!(view.data(), arr.data());
    }

    #[test]
    fn test_ref_view_is_copy() {
        let arr = ColMajorArray::new(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let view = arr.as_ref();
        let copy = view;
        // `view` is still usable after being copied.
        assert_eq!(view.get(&[1, 1]), copy.get(&[1, 1]));
    }

    #[test]
    fn test_as_mut() {
        let mut arr = ColMajorArray::new(vec![10, 20, 30, 40], vec![2, 2]).unwrap();
        {
            let mut view = arr.as_mut();
            if let Some(v) = view.get_mut(&[0, 1]) {
                *v = 99;
            }
        }
        assert_eq!(arr.get(&[0, 1]), Some(&99));
    }

    #[test]
    fn test_into_data() {
        let arr = ColMajorArray::new(vec![1, 2, 3], vec![3]).unwrap();
        let data = arr.into_data();
        assert_eq!(data, vec![1, 2, 3]);
    }

    #[test]
    fn test_empty_array() {
        let arr: ColMajorArray<i32> = ColMajorArray::new(vec![], vec![0]).unwrap();
        assert!(arr.is_empty());
        assert_eq!(arr.len(), 0);
        assert_eq!(arr.ndim(), 1);
    }

    #[test]
    fn test_empty_2d_array() {
        let arr: ColMajorArray<i32> = ColMajorArray::new(vec![], vec![3, 0]).unwrap();
        assert!(arr.is_empty());
        assert_eq!(arr.len(), 0);
        assert_eq!(arr.nrows(), 3);
        assert_eq!(arr.ncols(), 0);
    }

    #[test]
    fn test_ref_new() {
        let data = [1, 2, 3, 4, 5, 6];
        let shape = [2, 3];
        let view = ColMajorArrayRef::new(&data, &shape);
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.len(), 6);
        assert_eq!(view.get(&[1, 2]), Some(&6));
    }

    #[test]
    #[should_panic(expected = "ColMajorArrayRef::new: data length 3 != shape product 4")]
    fn test_ref_new_panics_on_length_mismatch() {
        let data = [1, 2, 3];
        let shape = [2, 2];
        let _ = ColMajorArrayRef::new(&data, &shape);
    }

    #[test]
    fn test_mut_new() {
        let mut data = [1, 2, 3, 4, 5, 6];
        let shape = [2, 3];
        let mut view = ColMajorArrayMut::new(&mut data, &shape);
        assert_eq!(view.ndim(), 2);
        assert_eq!(view.len(), 6);
        *view.get_mut(&[0, 0]).unwrap() = 100;
        assert_eq!(view.get(&[0, 0]), Some(&100));
    }

    #[test]
    fn test_mut_view_data_mut() {
        let mut data = [0; 4];
        let shape = [2, 2];
        let mut view = ColMajorArrayMut::new(&mut data, &shape);
        view.data_mut().copy_from_slice(&[1, 2, 3, 4]);
        assert_eq!(view.get(&[1, 1]), Some(&4));
        assert_eq!(view.shape(), &[2, 2]);
    }

    #[test]
    fn test_new_rejects_overflow_shape() {
        let result = ColMajorArray::<u8>::new(vec![], vec![usize::MAX, 2]);
        assert!(
            matches!(result, Err(ColMajorArrayError::ShapeOverflow { .. })),
            "expected ShapeOverflow, got {:?}",
            result
        );
    }

    #[test]
    fn test_filled_rejects_overflow_shape() {
        let result = ColMajorArray::filled(vec![usize::MAX, 2], 0u8);
        assert!(
            matches!(result, Err(ColMajorArrayError::ShapeOverflow { .. })),
            "expected ShapeOverflow, got {:?}",
            result
        );
    }

    #[test]
    fn test_zeros_rejects_overflow_shape() {
        let result = ColMajorArray::<u8>::zeros(vec![usize::MAX, 2]);
        assert!(
            matches!(result, Err(ColMajorArrayError::ShapeOverflow { .. })),
            "expected ShapeOverflow, got {:?}",
            result
        );
    }
}