tensor4all_tcicore/matrix.rs
1//! Dense row-major matrix type and utility functions.
2//!
3//! [`Matrix<T>`] is a simple dense 2D matrix in row-major layout, indexed
4//! by `m[[row, col]]`. It is used throughout the TCI infrastructure for
5//! pivot block computations, cross interpolation factors, and dense
6//! submatrix extraction.
7//!
8//! # Examples
9//!
10//! ```
11//! use tensor4all_tcicore::{Matrix, from_vec2d, matrix};
12//!
13//! let m = from_vec2d(vec![
14//! vec![1.0_f64, 2.0],
15//! vec![3.0, 4.0],
16//! ]);
17//! assert_eq!(m.nrows(), 2);
18//! assert_eq!(m.ncols(), 2);
19//! assert_eq!(m[[0, 1]], 2.0);
20//! assert_eq!(m[[1, 0]], 3.0);
21//! ```
22
23use crate::scalar::Scalar;
24use num_traits::{One, Zero};
25use rand::seq::SliceRandom;
26use rand::Rng;
27use std::collections::HashSet;
28use std::ops::{Index, IndexMut};
29
30/// A dense 2D matrix in row-major layout.
31///
32/// Access elements with `m[[row, col]]` syntax. Data is stored contiguously
33/// in row-major order.
34///
35/// # Examples
36///
37/// ```
38/// use tensor4all_tcicore::Matrix;
39///
40/// let mut m = Matrix::zeros(2, 3);
41/// m[[0, 1]] = 5.0_f64;
42/// assert_eq!(m[[0, 1]], 5.0);
43/// assert_eq!(m[[0, 0]], 0.0);
44/// assert_eq!(m.nrows(), 2);
45/// assert_eq!(m.ncols(), 3);
46/// ```
47#[derive(Debug, Clone)]
48pub struct Matrix<T> {
49 data: Vec<T>,
50 nrows: usize,
51 ncols: usize,
52}
53
54impl<T> Matrix<T> {
55 /// Create a matrix from raw row-major data.
56 ///
57 /// # Panics
58 ///
59 /// Panics if `data.len() != nrows * ncols`.
60 ///
61 /// # Examples
62 ///
63 /// ```
64 /// use tensor4all_tcicore::Matrix;
65 ///
66 /// let m = Matrix::from_raw_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
67 /// assert_eq!(m[[0, 0]], 1.0);
68 /// assert_eq!(m[[0, 1]], 2.0);
69 /// assert_eq!(m[[1, 0]], 3.0);
70 /// assert_eq!(m[[1, 1]], 4.0);
71 /// ```
72 pub fn from_raw_vec(nrows: usize, ncols: usize, data: Vec<T>) -> Self {
73 assert_eq!(data.len(), nrows * ncols);
74 Self { data, nrows, ncols }
75 }
76
77 /// View the underlying row-major data as a contiguous slice.
78 ///
79 /// # Examples
80 ///
81 /// ```
82 /// use tensor4all_tcicore::Matrix;
83 ///
84 /// let m = Matrix::from_raw_vec(1, 3, vec![10, 20, 30]);
85 /// assert_eq!(m.as_slice(), &[10, 20, 30]);
86 /// ```
87 pub fn as_slice(&self) -> &[T] {
88 &self.data
89 }
90
91 /// Number of rows
92 pub fn nrows(&self) -> usize {
93 self.nrows
94 }
95
96 /// Number of columns
97 pub fn ncols(&self) -> usize {
98 self.ncols
99 }
100}
101
102impl<T: Clone> Matrix<T> {
103 /// Create a new matrix filled with a constant value.
104 ///
105 /// # Examples
106 ///
107 /// ```
108 /// use tensor4all_tcicore::Matrix;
109 ///
110 /// let m = Matrix::from_elem(2, 3, 7.0);
111 /// assert_eq!(m[[0, 0]], 7.0);
112 /// assert_eq!(m[[1, 2]], 7.0);
113 /// ```
114 pub fn from_elem(nrows: usize, ncols: usize, elem: T) -> Self {
115 Self {
116 data: vec![elem; nrows * ncols],
117 nrows,
118 ncols,
119 }
120 }
121}
122
123impl<T: Clone + Zero> Matrix<T> {
124 /// Create a zeros matrix
125 ///
126 /// # Examples
127 ///
128 /// ```
129 /// use tensor4all_tcicore::Matrix;
130 ///
131 /// let m = Matrix::<f64>::zeros(2, 3);
132 /// assert_eq!(m.nrows(), 2);
133 /// assert_eq!(m.ncols(), 3);
134 /// assert_eq!(m[[0, 0]], 0.0);
135 /// assert_eq!(m[[1, 2]], 0.0);
136 /// ```
137 pub fn zeros(nrows: usize, ncols: usize) -> Self {
138 Self {
139 data: vec![T::zero(); nrows * ncols],
140 nrows,
141 ncols,
142 }
143 }
144}
145
146impl<T> Index<[usize; 2]> for Matrix<T> {
147 type Output = T;
148
149 fn index(&self, idx: [usize; 2]) -> &Self::Output {
150 &self.data[idx[0] * self.ncols + idx[1]]
151 }
152}
153
154impl<T> IndexMut<[usize; 2]> for Matrix<T> {
155 fn index_mut(&mut self, idx: [usize; 2]) -> &mut Self::Output {
156 &mut self.data[idx[0] * self.ncols + idx[1]]
157 }
158}
159
160/// Create a zeros matrix with given dimensions.
161///
162/// # Examples
163///
164/// ```
165/// use tensor4all_tcicore::matrix::zeros;
166///
167/// let m: tensor4all_tcicore::Matrix<f64> = zeros(2, 3);
168/// assert_eq!(m[[0, 0]], 0.0);
169/// assert_eq!(m.nrows(), 2);
170/// assert_eq!(m.ncols(), 3);
171/// ```
172pub fn zeros<T: Clone + Zero>(nrows: usize, ncols: usize) -> Matrix<T> {
173 Matrix::zeros(nrows, ncols)
174}
175
176/// Create an `n x n` identity matrix.
177///
178/// # Examples
179///
180/// ```
181/// use tensor4all_tcicore::matrix::eye;
182///
183/// let m: tensor4all_tcicore::Matrix<f64> = eye(3);
184/// assert_eq!(m[[0, 0]], 1.0);
185/// assert_eq!(m[[1, 1]], 1.0);
186/// assert_eq!(m[[0, 1]], 0.0);
187/// assert_eq!(m[[2, 0]], 0.0);
188/// ```
189pub fn eye<T: Clone + Zero + One>(n: usize) -> Matrix<T> {
190 let mut m = zeros(n, n);
191 for i in 0..n {
192 m[[i, i]] = T::one();
193 }
194 m
195}
196
197/// Create a matrix from a 2D vector (row-major).
198///
199/// Each inner `Vec` is one row.
200///
201/// # Examples
202///
203/// ```
204/// use tensor4all_tcicore::from_vec2d;
205///
206/// let m = from_vec2d(vec![
207/// vec![1.0, 2.0],
208/// vec![3.0, 4.0],
209/// ]);
210/// assert_eq!(m.nrows(), 2);
211/// assert_eq!(m.ncols(), 2);
212/// assert_eq!(m[[0, 1]], 2.0);
213/// assert_eq!(m[[1, 0]], 3.0);
214/// ```
215pub fn from_vec2d<T: Clone + Zero>(data: Vec<Vec<T>>) -> Matrix<T> {
216 let nrows = data.len();
217 let ncols = if nrows > 0 { data[0].len() } else { 0 };
218 let mut m = zeros(nrows, ncols);
219 for i in 0..nrows {
220 for j in 0..ncols {
221 m[[i, j]] = data[i][j].clone();
222 }
223 }
224 m
225}
226
227/// Get number of rows.
228///
229/// # Examples
230///
231/// ```
232/// use tensor4all_tcicore::{from_vec2d, matrix::nrows};
233///
234/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
235/// assert_eq!(nrows(&m), 3);
236/// ```
237pub fn nrows<T>(m: &Matrix<T>) -> usize {
238 m.nrows
239}
240
241/// Get number of columns.
242///
243/// # Examples
244///
245/// ```
246/// use tensor4all_tcicore::{from_vec2d, matrix::ncols};
247///
248/// let m = from_vec2d(vec![vec![1.0, 2.0, 3.0]]);
249/// assert_eq!(ncols(&m), 3);
250/// ```
251pub fn ncols<T>(m: &Matrix<T>) -> usize {
252 m.ncols
253}
254
255/// Get a row as a vector.
256///
257/// # Examples
258///
259/// ```
260/// use tensor4all_tcicore::{from_vec2d, matrix::get_row};
261///
262/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
263/// assert_eq!(get_row(&m, 0), vec![1.0, 2.0]);
264/// assert_eq!(get_row(&m, 1), vec![3.0, 4.0]);
265/// ```
266pub fn get_row<T: Clone>(m: &Matrix<T>, i: usize) -> Vec<T> {
267 (0..m.ncols).map(|j| m[[i, j]].clone()).collect()
268}
269
270/// Get a column as a vector.
271///
272/// # Examples
273///
274/// ```
275/// use tensor4all_tcicore::{from_vec2d, matrix::get_col};
276///
277/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
278/// assert_eq!(get_col(&m, 0), vec![1.0, 3.0]);
279/// assert_eq!(get_col(&m, 1), vec![2.0, 4.0]);
280/// ```
281pub fn get_col<T: Clone>(m: &Matrix<T>, j: usize) -> Vec<T> {
282 (0..m.nrows).map(|i| m[[i, j]].clone()).collect()
283}
284
285/// Get a submatrix by selecting specific rows and columns.
286///
287/// # Examples
288///
289/// ```
290/// use tensor4all_tcicore::{from_vec2d, matrix::submatrix};
291///
292/// let m = from_vec2d(vec![
293/// vec![1.0, 2.0, 3.0],
294/// vec![4.0, 5.0, 6.0],
295/// vec![7.0, 8.0, 9.0],
296/// ]);
297/// let sub = submatrix(&m, &[0, 2], &[1, 2]);
298/// assert_eq!(sub.nrows(), 2);
299/// assert_eq!(sub.ncols(), 2);
300/// assert_eq!(sub[[0, 0]], 2.0); // m[0, 1]
301/// assert_eq!(sub[[1, 1]], 9.0); // m[2, 2]
302/// ```
303pub fn submatrix<T: Clone + Zero>(m: &Matrix<T>, rows: &[usize], cols: &[usize]) -> Matrix<T> {
304 let mut result = zeros(rows.len(), cols.len());
305 for (ri, &r) in rows.iter().enumerate() {
306 for (ci, &c) in cols.iter().enumerate() {
307 result[[ri, ci]] = m[[r, c]].clone();
308 }
309 }
310 result
311}
312
313/// Append a column to the right of a matrix.
314///
315/// # Panics
316///
317/// Panics if `col.len() != m.nrows()`.
318///
319/// # Examples
320///
321/// ```
322/// use tensor4all_tcicore::{from_vec2d, matrix::append_col};
323///
324/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
325/// let m2 = append_col(&m, &[5.0, 6.0]);
326/// assert_eq!(m2.ncols(), 3);
327/// assert_eq!(m2[[0, 2]], 5.0);
328/// assert_eq!(m2[[1, 2]], 6.0);
329/// ```
330pub fn append_col<T: Clone + Zero>(m: &Matrix<T>, col: &[T]) -> Matrix<T> {
331 let nr = m.nrows;
332 let nc = m.ncols;
333 assert_eq!(col.len(), nr);
334
335 let mut result = zeros(nr, nc + 1);
336 for i in 0..nr {
337 for j in 0..nc {
338 result[[i, j]] = m[[i, j]].clone();
339 }
340 result[[i, nc]] = col[i].clone();
341 }
342 result
343}
344
345/// Append a row to the bottom of a matrix.
346///
347/// # Panics
348///
349/// Panics if `row.len() != m.ncols()`.
350///
351/// # Examples
352///
353/// ```
354/// use tensor4all_tcicore::{from_vec2d, matrix::append_row};
355///
356/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
357/// let m2 = append_row(&m, &[5.0, 6.0]);
358/// assert_eq!(m2.nrows(), 3);
359/// assert_eq!(m2[[2, 0]], 5.0);
360/// assert_eq!(m2[[2, 1]], 6.0);
361/// ```
362pub fn append_row<T: Clone + Zero>(m: &Matrix<T>, row: &[T]) -> Matrix<T> {
363 let nr = m.nrows;
364 let nc = m.ncols;
365 assert_eq!(row.len(), nc);
366
367 let mut result = zeros(nr + 1, nc);
368 for i in 0..nr {
369 for j in 0..nc {
370 result[[i, j]] = m[[i, j]].clone();
371 }
372 }
373 for j in 0..nc {
374 result[[nr, j]] = row[j].clone();
375 }
376 result
377}
378
379/// Swap two rows in a matrix in-place.
380///
381/// No-op if `a == b`.
382///
383/// # Examples
384///
385/// ```
386/// use tensor4all_tcicore::{from_vec2d, matrix::swap_rows};
387///
388/// let mut m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
389/// swap_rows(&mut m, 0, 1);
390/// assert_eq!(m[[0, 0]], 3.0);
391/// assert_eq!(m[[1, 0]], 1.0);
392/// ```
393pub fn swap_rows<T>(m: &mut Matrix<T>, a: usize, b: usize) {
394 if a == b {
395 return;
396 }
397 for j in 0..m.ncols {
398 let idx_a = a * m.ncols + j;
399 let idx_b = b * m.ncols + j;
400 m.data.swap(idx_a, idx_b);
401 }
402}
403
404/// Swap two columns in a matrix in-place.
405///
406/// No-op if `a == b`.
407///
408/// # Examples
409///
410/// ```
411/// use tensor4all_tcicore::{from_vec2d, matrix::swap_cols};
412///
413/// let mut m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
414/// swap_cols(&mut m, 0, 1);
415/// assert_eq!(m[[0, 0]], 2.0);
416/// assert_eq!(m[[0, 1]], 1.0);
417/// ```
418pub fn swap_cols<T>(m: &mut Matrix<T>, a: usize, b: usize) {
419 if a == b {
420 return;
421 }
422 for i in 0..m.nrows {
423 let idx_a = i * m.ncols + a;
424 let idx_b = i * m.ncols + b;
425 m.data.swap(idx_a, idx_b);
426 }
427}
428
429/// Transpose the matrix.
430///
431/// # Examples
432///
433/// ```
434/// use tensor4all_tcicore::{from_vec2d, matrix::transpose};
435///
436/// let m = from_vec2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
437/// let mt = transpose(&m);
438/// assert_eq!(mt.nrows(), 3);
439/// assert_eq!(mt.ncols(), 2);
440/// assert_eq!(mt[[0, 0]], 1.0);
441/// assert_eq!(mt[[2, 1]], 6.0);
442/// ```
443pub fn transpose<T: Clone + Zero>(m: &Matrix<T>) -> Matrix<T> {
444 let mut result = zeros(m.ncols, m.nrows);
445 for i in 0..m.nrows {
446 for j in 0..m.ncols {
447 result[[j, i]] = m[[i, j]].clone();
448 }
449 }
450 result
451}
452
// The `Scalar` trait used by the numeric routines below is defined in `crate::scalar`.
454
455/// Calculates A * B^{-1} using Gaussian elimination.
456///
457/// # Panics
458///
459/// Panics if the number of columns of `a` does not match the dimensions of `b`,
460/// or if `b` is not square.
461///
462/// # Examples
463///
464/// ```
465/// use tensor4all_tcicore::{from_vec2d, matrix::a_times_b_inv};
466///
467/// let a = from_vec2d(vec![vec![2.0_f64, 0.0], vec![0.0, 4.0]]);
468/// let b = from_vec2d(vec![vec![1.0, 0.0], vec![0.0, 2.0]]);
469/// let result = a_times_b_inv(&a, &b);
470/// assert!((result[[0, 0]] - 2.0).abs() < 1e-10);
471/// assert!((result[[1, 1]] - 2.0).abs() < 1e-10);
472/// ```
473pub fn a_times_b_inv<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
474 let n = ncols(a);
475 assert_eq!(nrows(b), n);
476 assert_eq!(ncols(b), n);
477
478 // Solve XB = A by solving B'X' = A'
479 let bt = transpose(b);
480 let at = transpose(a);
481 let xt = solve_linear_system(&bt, &at);
482 transpose(&xt)
483}
484
485/// Calculates A^{-1} * B using Gaussian elimination.
486///
487/// # Examples
488///
489/// ```
490/// use tensor4all_tcicore::{from_vec2d, matrix::a_inv_times_b};
491///
492/// let a = from_vec2d(vec![vec![2.0_f64, 0.0], vec![0.0, 4.0]]);
493/// let b = from_vec2d(vec![vec![6.0, 0.0], vec![0.0, 8.0]]);
494/// let result = a_inv_times_b(&a, &b);
495/// assert!((result[[0, 0]] - 3.0).abs() < 1e-10);
496/// assert!((result[[1, 1]] - 2.0).abs() < 1e-10);
497/// ```
498pub fn a_inv_times_b<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
499 let bt = transpose(b);
500 let at = transpose(a);
501 let result = a_times_b_inv(&bt, &at);
502 transpose(&result)
503}
504
505/// Solve linear system AX = B using Gaussian elimination with partial pivoting
506#[allow(clippy::needless_range_loop)]
507fn solve_linear_system<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
508 let n = nrows(a);
509 assert_eq!(ncols(a), n);
510 assert_eq!(nrows(b), n);
511 let m = ncols(b);
512
513 // Create augmented matrix [A | B]
514 let mut aug: Vec<Vec<T>> = (0..n)
515 .map(|i| {
516 let mut row = Vec::with_capacity(n + m);
517 for j in 0..n {
518 row.push(a[[i, j]]);
519 }
520 for j in 0..m {
521 row.push(b[[i, j]]);
522 }
523 row
524 })
525 .collect();
526
527 // Forward elimination with partial pivoting
528 for k in 0..n {
529 // Find pivot
530 let mut max_idx = k;
531 let mut max_val: f64 = aug[k][k].abs_sq();
532 for i in (k + 1)..n {
533 let val: f64 = aug[i][k].abs_sq();
534 if val > max_val {
535 max_val = val;
536 max_idx = i;
537 }
538 }
539
540 // Swap rows
541 if max_idx != k {
542 aug.swap(k, max_idx);
543 }
544
545 let pivot = aug[k][k];
546 if pivot.abs_sq() < T::epsilon() {
547 continue;
548 }
549
550 // Eliminate below
551 for i in (k + 1)..n {
552 let factor = aug[i][k] / pivot;
553 for j in k..(n + m) {
554 aug[i][j] = aug[i][j] - factor * aug[k][j];
555 }
556 }
557 }
558
559 // Back substitution
560 let mut x: Vec<Vec<T>> = vec![vec![T::zero(); m]; n];
561 for i in (0..n).rev() {
562 for j in 0..m {
563 let mut sum = aug[i][n + j];
564 for k in (i + 1)..n {
565 sum = sum - aug[i][k] * x[k][j];
566 }
567 let diag = aug[i][i];
568 if diag.abs_sq() > T::epsilon() {
569 x[i][j] = sum / diag;
570 }
571 }
572 }
573
574 from_vec2d(x)
575}
576
577/// Find the position and value of the maximum absolute value in a submatrix.
578///
579/// Searches within the rectangular region defined by `rows x cols` ranges.
580/// Returns `(row, col, value)` of the element with the largest `|value|^2`.
581///
582/// # Panics
583///
584/// Panics if either range is empty.
585///
586/// # Examples
587///
588/// ```
589/// use tensor4all_tcicore::{from_vec2d, matrix::submatrix_argmax};
590///
591/// let m = from_vec2d(vec![
592/// vec![1.0_f64, 2.0, 3.0],
593/// vec![4.0, 9.0, 6.0],
594/// vec![7.0, 8.0, 5.0],
595/// ]);
596/// let (row, col, val) = submatrix_argmax(&m, 0..3, 0..3);
597/// assert_eq!(row, 1);
598/// assert_eq!(col, 1);
599/// assert_eq!(val, 9.0);
600/// ```
601pub fn submatrix_argmax<T: Scalar>(
602 a: &Matrix<T>,
603 rows: std::ops::Range<usize>,
604 cols: std::ops::Range<usize>,
605) -> (usize, usize, T) {
606 assert!(!rows.is_empty(), "rows must not be empty");
607 assert!(!cols.is_empty(), "cols must not be empty");
608
609 let mut max_val: f64 = a[[rows.start, cols.start]].abs_sq();
610 let mut max_row = rows.start;
611 let mut max_col = cols.start;
612
613 for r in rows {
614 for c in cols.clone() {
615 let val: f64 = a[[r, c]].abs_sq();
616 if val > max_val {
617 max_val = val;
618 max_row = r;
619 max_col = c;
620 }
621 }
622 }
623
624 (max_row, max_col, a[[max_row, max_col]])
625}
626
627/// Select a random subset of up to `n` elements from a slice.
628///
629/// If `n >= set.len()`, returns at most `set.len()` elements (a shuffled
630/// subset). Returns an empty vector when the set is empty or `n` is zero.
631///
632/// # Examples
633///
634/// ```
635/// use tensor4all_tcicore::matrix::random_subset;
636/// use rand::SeedableRng;
637///
638/// let mut rng = rand::rngs::StdRng::seed_from_u64(42);
639/// let items = vec![10, 20, 30, 40, 50];
640/// let sub = random_subset(&items, 3, &mut rng);
641/// assert_eq!(sub.len(), 3);
642/// // All selected elements come from the original set
643/// for &x in &sub {
644/// assert!(items.contains(&x));
645/// }
646/// // Requesting more than available returns at most set.len()
647/// let all = random_subset(&items, 100, &mut rng);
648/// assert_eq!(all.len(), 5);
649/// ```
650pub fn random_subset<T: Clone, R: Rng>(set: &[T], n: usize, rng: &mut R) -> Vec<T> {
651 let n = n.min(set.len());
652 if n == 0 {
653 return Vec::new();
654 }
655
656 let mut indices: Vec<usize> = (0..set.len()).collect();
657 indices.shuffle(rng);
658 indices.truncate(n);
659 indices.into_iter().map(|i| set[i].clone()).collect()
660}
661
662/// Set difference: elements in `set` that are not in `exclude`.
663///
664/// Preserves the order of elements in `set`.
665///
666/// # Examples
667///
668/// ```
669/// use tensor4all_tcicore::matrix::set_diff;
670///
671/// let result = set_diff(&[0, 1, 2, 3, 4], &[1, 3]);
672/// assert_eq!(result, vec![0, 2, 4]);
673/// ```
674pub fn set_diff(set: &[usize], exclude: &[usize]) -> Vec<usize> {
675 let exclude_set: HashSet<usize> = exclude.iter().copied().collect();
676 set.iter()
677 .copied()
678 .filter(|x| !exclude_set.contains(x))
679 .collect()
680}
681
682/// Dot product of two vectors.
683///
684/// # Panics
685///
686/// Panics if `a.len() != b.len()`.
687///
688/// # Examples
689///
690/// ```
691/// use tensor4all_tcicore::matrix::dot;
692///
693/// let a = [1.0_f64, 2.0, 3.0];
694/// let b = [4.0, 5.0, 6.0];
695/// assert!((dot(&a, &b) - 32.0).abs() < 1e-10);
696/// ```
697pub fn dot<T: Scalar>(a: &[T], b: &[T]) -> T {
698 assert_eq!(a.len(), b.len());
699 a.iter()
700 .zip(b.iter())
701 .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
702}
703
704/// BLAS-backed matrix multiplication dispatch.
705///
706/// Implemented for all scalar types supported by tenferro einsum
707/// (f64, f32, Complex64, Complex32). This trait is sealed — external
708/// types cannot implement it.
709pub trait BlasMul: Sized {
710 #[doc(hidden)]
711 fn blas_mat_mul(a: &Matrix<Self>, b: &Matrix<Self>) -> Matrix<Self>;
712}
713
714fn row_major_to_col_major<T: Copy>(data: &[T], nrows: usize, ncols: usize) -> Vec<T> {
715 let mut out = Vec::with_capacity(data.len());
716 for col in 0..ncols {
717 for row in 0..nrows {
718 out.push(data[row * ncols + col]);
719 }
720 }
721 out
722}
723
724fn col_major_to_row_major<T: Copy>(data: &[T], nrows: usize, ncols: usize) -> Vec<T> {
725 let mut out = Vec::with_capacity(data.len());
726 for row in 0..nrows {
727 for col in 0..ncols {
728 out.push(data[col * nrows + row]);
729 }
730 }
731 out
732}
733
734macro_rules! impl_blas_mul {
735 ($($t:ty),*) => {
736 $(
737 impl BlasMul for $t {
738 fn blas_mat_mul(a: &Matrix<Self>, b: &Matrix<Self>) -> Matrix<Self> {
739 use tenferro_einsum::typed_eager_einsum;
740 use tenferro_tensor::TypedTensor;
741 use tensor4all_tensorbackend::with_default_backend;
742
743 let m = a.nrows();
744 let k = a.ncols();
745 let n = b.ncols();
746 assert_eq!(b.nrows(), k);
747
748 let a_tensor = TypedTensor::<$t>::from_vec(
749 vec![m, k],
750 row_major_to_col_major(a.as_slice(), m, k),
751 );
752 let b_tensor = TypedTensor::<$t>::from_vec(
753 vec![k, n],
754 row_major_to_col_major(b.as_slice(), k, n),
755 );
756 let c = with_default_backend(|backend| {
757 typed_eager_einsum(backend, &[&a_tensor, &b_tensor], "ij,jk->ik")
758 })
759 .expect("einsum failed");
760 let c_data = col_major_to_row_major(c.as_slice(), m, n);
761 Matrix::from_raw_vec(m, n, c_data)
762 }
763 }
764 )*
765 };
766}
767
768impl_blas_mul!(f64, f32, num_complex::Complex64, num_complex::Complex32);
769
770/// Matrix multiplication: A * B.
771///
772/// Uses BLAS-backed einsum via tenferro for high performance.
773///
774/// # Panics
775///
776/// Panics if `a.ncols() != b.nrows()`.
777///
778/// # Examples
779///
780/// ```
781/// use tensor4all_tcicore::{from_vec2d, matrix::mat_mul};
782///
783/// let a = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
784/// let b = from_vec2d(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
785/// let c = mat_mul(&a, &b);
786/// assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
787/// assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
788/// assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
789/// assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
790/// ```
791pub fn mat_mul<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
792 T::blas_mat_mul(a, b)
793}
794
795#[cfg(test)]
796mod tests;