Skip to main content

tensor4all_tcicore/
matrix.rs

1//! Dense row-major matrix type and utility functions.
2//!
3//! [`Matrix<T>`] is a simple dense 2D matrix in row-major layout, indexed
4//! by `m[[row, col]]`. It is used throughout the TCI infrastructure for
5//! pivot block computations, cross interpolation factors, and dense
6//! submatrix extraction.
7//!
8//! # Examples
9//!
10//! ```
11//! use tensor4all_tcicore::{Matrix, from_vec2d, matrix};
12//!
13//! let m = from_vec2d(vec![
14//!     vec![1.0_f64, 2.0],
15//!     vec![3.0, 4.0],
16//! ]);
17//! assert_eq!(m.nrows(), 2);
18//! assert_eq!(m.ncols(), 2);
19//! assert_eq!(m[[0, 1]], 2.0);
20//! assert_eq!(m[[1, 0]], 3.0);
21//! ```
22
23use crate::scalar::Scalar;
24use num_traits::{One, Zero};
25use rand::seq::SliceRandom;
26use rand::Rng;
27use std::collections::HashSet;
28use std::ops::{Index, IndexMut};
29
/// A dense 2D matrix in row-major layout.
///
/// Access elements with `m[[row, col]]` syntax. Data is stored contiguously
/// in row-major order, so element `(row, col)` lives at flat offset
/// `row * ncols + col` of the backing vector.
///
/// # Examples
///
/// ```
/// use tensor4all_tcicore::Matrix;
///
/// let mut m = Matrix::zeros(2, 3);
/// m[[0, 1]] = 5.0_f64;
/// assert_eq!(m[[0, 1]], 5.0);
/// assert_eq!(m[[0, 0]], 0.0);
/// assert_eq!(m.nrows(), 2);
/// assert_eq!(m.ncols(), 3);
/// ```
#[derive(Debug, Clone)]
pub struct Matrix<T> {
    /// Elements in row-major order; the constructors in this file create it
    /// with `nrows * ncols` entries.
    data: Vec<T>,
    /// Number of rows.
    nrows: usize,
    /// Number of columns; also the stride between consecutive rows in `data`.
    ncols: usize,
}
53
54impl<T> Matrix<T> {
55    /// Create a matrix from raw row-major data.
56    ///
57    /// # Panics
58    ///
59    /// Panics if `data.len() != nrows * ncols`.
60    ///
61    /// # Examples
62    ///
63    /// ```
64    /// use tensor4all_tcicore::Matrix;
65    ///
66    /// let m = Matrix::from_raw_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
67    /// assert_eq!(m[[0, 0]], 1.0);
68    /// assert_eq!(m[[0, 1]], 2.0);
69    /// assert_eq!(m[[1, 0]], 3.0);
70    /// assert_eq!(m[[1, 1]], 4.0);
71    /// ```
72    pub fn from_raw_vec(nrows: usize, ncols: usize, data: Vec<T>) -> Self {
73        assert_eq!(data.len(), nrows * ncols);
74        Self { data, nrows, ncols }
75    }
76
77    /// View the underlying row-major data as a contiguous slice.
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// use tensor4all_tcicore::Matrix;
83    ///
84    /// let m = Matrix::from_raw_vec(1, 3, vec![10, 20, 30]);
85    /// assert_eq!(m.as_slice(), &[10, 20, 30]);
86    /// ```
87    pub fn as_slice(&self) -> &[T] {
88        &self.data
89    }
90
91    /// Number of rows
92    pub fn nrows(&self) -> usize {
93        self.nrows
94    }
95
96    /// Number of columns
97    pub fn ncols(&self) -> usize {
98        self.ncols
99    }
100}
101
102impl<T: Clone> Matrix<T> {
103    /// Create a new matrix filled with a constant value.
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// use tensor4all_tcicore::Matrix;
109    ///
110    /// let m = Matrix::from_elem(2, 3, 7.0);
111    /// assert_eq!(m[[0, 0]], 7.0);
112    /// assert_eq!(m[[1, 2]], 7.0);
113    /// ```
114    pub fn from_elem(nrows: usize, ncols: usize, elem: T) -> Self {
115        Self {
116            data: vec![elem; nrows * ncols],
117            nrows,
118            ncols,
119        }
120    }
121}
122
123impl<T: Clone + Zero> Matrix<T> {
124    /// Create a zeros matrix
125    ///
126    /// # Examples
127    ///
128    /// ```
129    /// use tensor4all_tcicore::Matrix;
130    ///
131    /// let m = Matrix::<f64>::zeros(2, 3);
132    /// assert_eq!(m.nrows(), 2);
133    /// assert_eq!(m.ncols(), 3);
134    /// assert_eq!(m[[0, 0]], 0.0);
135    /// assert_eq!(m[[1, 2]], 0.0);
136    /// ```
137    pub fn zeros(nrows: usize, ncols: usize) -> Self {
138        Self {
139            data: vec![T::zero(); nrows * ncols],
140            nrows,
141            ncols,
142        }
143    }
144}
145
146impl<T> Index<[usize; 2]> for Matrix<T> {
147    type Output = T;
148
149    fn index(&self, idx: [usize; 2]) -> &Self::Output {
150        &self.data[idx[0] * self.ncols + idx[1]]
151    }
152}
153
154impl<T> IndexMut<[usize; 2]> for Matrix<T> {
155    fn index_mut(&mut self, idx: [usize; 2]) -> &mut Self::Output {
156        &mut self.data[idx[0] * self.ncols + idx[1]]
157    }
158}
159
160/// Create a zeros matrix with given dimensions.
161///
162/// # Examples
163///
164/// ```
165/// use tensor4all_tcicore::matrix::zeros;
166///
167/// let m: tensor4all_tcicore::Matrix<f64> = zeros(2, 3);
168/// assert_eq!(m[[0, 0]], 0.0);
169/// assert_eq!(m.nrows(), 2);
170/// assert_eq!(m.ncols(), 3);
171/// ```
172pub fn zeros<T: Clone + Zero>(nrows: usize, ncols: usize) -> Matrix<T> {
173    Matrix::zeros(nrows, ncols)
174}
175
176/// Create an `n x n` identity matrix.
177///
178/// # Examples
179///
180/// ```
181/// use tensor4all_tcicore::matrix::eye;
182///
183/// let m: tensor4all_tcicore::Matrix<f64> = eye(3);
184/// assert_eq!(m[[0, 0]], 1.0);
185/// assert_eq!(m[[1, 1]], 1.0);
186/// assert_eq!(m[[0, 1]], 0.0);
187/// assert_eq!(m[[2, 0]], 0.0);
188/// ```
189pub fn eye<T: Clone + Zero + One>(n: usize) -> Matrix<T> {
190    let mut m = zeros(n, n);
191    for i in 0..n {
192        m[[i, i]] = T::one();
193    }
194    m
195}
196
197/// Create a matrix from a 2D vector (row-major).
198///
199/// Each inner `Vec` is one row.
200///
201/// # Examples
202///
203/// ```
204/// use tensor4all_tcicore::from_vec2d;
205///
206/// let m = from_vec2d(vec![
207///     vec![1.0, 2.0],
208///     vec![3.0, 4.0],
209/// ]);
210/// assert_eq!(m.nrows(), 2);
211/// assert_eq!(m.ncols(), 2);
212/// assert_eq!(m[[0, 1]], 2.0);
213/// assert_eq!(m[[1, 0]], 3.0);
214/// ```
215pub fn from_vec2d<T: Clone + Zero>(data: Vec<Vec<T>>) -> Matrix<T> {
216    let nrows = data.len();
217    let ncols = if nrows > 0 { data[0].len() } else { 0 };
218    let mut m = zeros(nrows, ncols);
219    for i in 0..nrows {
220        for j in 0..ncols {
221            m[[i, j]] = data[i][j].clone();
222        }
223    }
224    m
225}
226
227/// Get number of rows.
228///
229/// # Examples
230///
231/// ```
232/// use tensor4all_tcicore::{from_vec2d, matrix::nrows};
233///
234/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
235/// assert_eq!(nrows(&m), 3);
236/// ```
237pub fn nrows<T>(m: &Matrix<T>) -> usize {
238    m.nrows
239}
240
241/// Get number of columns.
242///
243/// # Examples
244///
245/// ```
246/// use tensor4all_tcicore::{from_vec2d, matrix::ncols};
247///
248/// let m = from_vec2d(vec![vec![1.0, 2.0, 3.0]]);
249/// assert_eq!(ncols(&m), 3);
250/// ```
251pub fn ncols<T>(m: &Matrix<T>) -> usize {
252    m.ncols
253}
254
255/// Get a row as a vector.
256///
257/// # Examples
258///
259/// ```
260/// use tensor4all_tcicore::{from_vec2d, matrix::get_row};
261///
262/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
263/// assert_eq!(get_row(&m, 0), vec![1.0, 2.0]);
264/// assert_eq!(get_row(&m, 1), vec![3.0, 4.0]);
265/// ```
266pub fn get_row<T: Clone>(m: &Matrix<T>, i: usize) -> Vec<T> {
267    (0..m.ncols).map(|j| m[[i, j]].clone()).collect()
268}
269
270/// Get a column as a vector.
271///
272/// # Examples
273///
274/// ```
275/// use tensor4all_tcicore::{from_vec2d, matrix::get_col};
276///
277/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
278/// assert_eq!(get_col(&m, 0), vec![1.0, 3.0]);
279/// assert_eq!(get_col(&m, 1), vec![2.0, 4.0]);
280/// ```
281pub fn get_col<T: Clone>(m: &Matrix<T>, j: usize) -> Vec<T> {
282    (0..m.nrows).map(|i| m[[i, j]].clone()).collect()
283}
284
285/// Get a submatrix by selecting specific rows and columns.
286///
287/// # Examples
288///
289/// ```
290/// use tensor4all_tcicore::{from_vec2d, matrix::submatrix};
291///
292/// let m = from_vec2d(vec![
293///     vec![1.0, 2.0, 3.0],
294///     vec![4.0, 5.0, 6.0],
295///     vec![7.0, 8.0, 9.0],
296/// ]);
297/// let sub = submatrix(&m, &[0, 2], &[1, 2]);
298/// assert_eq!(sub.nrows(), 2);
299/// assert_eq!(sub.ncols(), 2);
300/// assert_eq!(sub[[0, 0]], 2.0); // m[0, 1]
301/// assert_eq!(sub[[1, 1]], 9.0); // m[2, 2]
302/// ```
303pub fn submatrix<T: Clone + Zero>(m: &Matrix<T>, rows: &[usize], cols: &[usize]) -> Matrix<T> {
304    let mut result = zeros(rows.len(), cols.len());
305    for (ri, &r) in rows.iter().enumerate() {
306        for (ci, &c) in cols.iter().enumerate() {
307            result[[ri, ci]] = m[[r, c]].clone();
308        }
309    }
310    result
311}
312
313/// Append a column to the right of a matrix.
314///
315/// # Panics
316///
317/// Panics if `col.len() != m.nrows()`.
318///
319/// # Examples
320///
321/// ```
322/// use tensor4all_tcicore::{from_vec2d, matrix::append_col};
323///
324/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
325/// let m2 = append_col(&m, &[5.0, 6.0]);
326/// assert_eq!(m2.ncols(), 3);
327/// assert_eq!(m2[[0, 2]], 5.0);
328/// assert_eq!(m2[[1, 2]], 6.0);
329/// ```
330pub fn append_col<T: Clone + Zero>(m: &Matrix<T>, col: &[T]) -> Matrix<T> {
331    let nr = m.nrows;
332    let nc = m.ncols;
333    assert_eq!(col.len(), nr);
334
335    let mut result = zeros(nr, nc + 1);
336    for i in 0..nr {
337        for j in 0..nc {
338            result[[i, j]] = m[[i, j]].clone();
339        }
340        result[[i, nc]] = col[i].clone();
341    }
342    result
343}
344
345/// Append a row to the bottom of a matrix.
346///
347/// # Panics
348///
349/// Panics if `row.len() != m.ncols()`.
350///
351/// # Examples
352///
353/// ```
354/// use tensor4all_tcicore::{from_vec2d, matrix::append_row};
355///
356/// let m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
357/// let m2 = append_row(&m, &[5.0, 6.0]);
358/// assert_eq!(m2.nrows(), 3);
359/// assert_eq!(m2[[2, 0]], 5.0);
360/// assert_eq!(m2[[2, 1]], 6.0);
361/// ```
362pub fn append_row<T: Clone + Zero>(m: &Matrix<T>, row: &[T]) -> Matrix<T> {
363    let nr = m.nrows;
364    let nc = m.ncols;
365    assert_eq!(row.len(), nc);
366
367    let mut result = zeros(nr + 1, nc);
368    for i in 0..nr {
369        for j in 0..nc {
370            result[[i, j]] = m[[i, j]].clone();
371        }
372    }
373    for j in 0..nc {
374        result[[nr, j]] = row[j].clone();
375    }
376    result
377}
378
379/// Swap two rows in a matrix in-place.
380///
381/// No-op if `a == b`.
382///
383/// # Examples
384///
385/// ```
386/// use tensor4all_tcicore::{from_vec2d, matrix::swap_rows};
387///
388/// let mut m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
389/// swap_rows(&mut m, 0, 1);
390/// assert_eq!(m[[0, 0]], 3.0);
391/// assert_eq!(m[[1, 0]], 1.0);
392/// ```
393pub fn swap_rows<T>(m: &mut Matrix<T>, a: usize, b: usize) {
394    if a == b {
395        return;
396    }
397    for j in 0..m.ncols {
398        let idx_a = a * m.ncols + j;
399        let idx_b = b * m.ncols + j;
400        m.data.swap(idx_a, idx_b);
401    }
402}
403
404/// Swap two columns in a matrix in-place.
405///
406/// No-op if `a == b`.
407///
408/// # Examples
409///
410/// ```
411/// use tensor4all_tcicore::{from_vec2d, matrix::swap_cols};
412///
413/// let mut m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
414/// swap_cols(&mut m, 0, 1);
415/// assert_eq!(m[[0, 0]], 2.0);
416/// assert_eq!(m[[0, 1]], 1.0);
417/// ```
418pub fn swap_cols<T>(m: &mut Matrix<T>, a: usize, b: usize) {
419    if a == b {
420        return;
421    }
422    for i in 0..m.nrows {
423        let idx_a = i * m.ncols + a;
424        let idx_b = i * m.ncols + b;
425        m.data.swap(idx_a, idx_b);
426    }
427}
428
429/// Transpose the matrix.
430///
431/// # Examples
432///
433/// ```
434/// use tensor4all_tcicore::{from_vec2d, matrix::transpose};
435///
436/// let m = from_vec2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
437/// let mt = transpose(&m);
438/// assert_eq!(mt.nrows(), 3);
439/// assert_eq!(mt.ncols(), 2);
440/// assert_eq!(mt[[0, 0]], 1.0);
441/// assert_eq!(mt[[2, 1]], 6.0);
442/// ```
443pub fn transpose<T: Clone + Zero>(m: &Matrix<T>) -> Matrix<T> {
444    let mut result = zeros(m.ncols, m.nrows);
445    for i in 0..m.nrows {
446        for j in 0..m.ncols {
447            result[[j, i]] = m[[i, j]].clone();
448        }
449    }
450    result
451}
452
453// Scalar trait is now defined in crate::scalar module
454
455/// Calculates A * B^{-1} using Gaussian elimination.
456///
457/// # Panics
458///
459/// Panics if the number of columns of `a` does not match the dimensions of `b`,
460/// or if `b` is not square.
461///
462/// # Examples
463///
464/// ```
465/// use tensor4all_tcicore::{from_vec2d, matrix::a_times_b_inv};
466///
467/// let a = from_vec2d(vec![vec![2.0_f64, 0.0], vec![0.0, 4.0]]);
468/// let b = from_vec2d(vec![vec![1.0, 0.0], vec![0.0, 2.0]]);
469/// let result = a_times_b_inv(&a, &b);
470/// assert!((result[[0, 0]] - 2.0).abs() < 1e-10);
471/// assert!((result[[1, 1]] - 2.0).abs() < 1e-10);
472/// ```
473pub fn a_times_b_inv<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
474    let n = ncols(a);
475    assert_eq!(nrows(b), n);
476    assert_eq!(ncols(b), n);
477
478    // Solve XB = A by solving B'X' = A'
479    let bt = transpose(b);
480    let at = transpose(a);
481    let xt = solve_linear_system(&bt, &at);
482    transpose(&xt)
483}
484
485/// Calculates A^{-1} * B using Gaussian elimination.
486///
487/// # Examples
488///
489/// ```
490/// use tensor4all_tcicore::{from_vec2d, matrix::a_inv_times_b};
491///
492/// let a = from_vec2d(vec![vec![2.0_f64, 0.0], vec![0.0, 4.0]]);
493/// let b = from_vec2d(vec![vec![6.0, 0.0], vec![0.0, 8.0]]);
494/// let result = a_inv_times_b(&a, &b);
495/// assert!((result[[0, 0]] - 3.0).abs() < 1e-10);
496/// assert!((result[[1, 1]] - 2.0).abs() < 1e-10);
497/// ```
498pub fn a_inv_times_b<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
499    let bt = transpose(b);
500    let at = transpose(a);
501    let result = a_times_b_inv(&bt, &at);
502    transpose(&result)
503}
504
505/// Solve linear system AX = B using Gaussian elimination with partial pivoting
506#[allow(clippy::needless_range_loop)]
507fn solve_linear_system<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
508    let n = nrows(a);
509    assert_eq!(ncols(a), n);
510    assert_eq!(nrows(b), n);
511    let m = ncols(b);
512
513    // Create augmented matrix [A | B]
514    let mut aug: Vec<Vec<T>> = (0..n)
515        .map(|i| {
516            let mut row = Vec::with_capacity(n + m);
517            for j in 0..n {
518                row.push(a[[i, j]]);
519            }
520            for j in 0..m {
521                row.push(b[[i, j]]);
522            }
523            row
524        })
525        .collect();
526
527    // Forward elimination with partial pivoting
528    for k in 0..n {
529        // Find pivot
530        let mut max_idx = k;
531        let mut max_val: f64 = aug[k][k].abs_sq();
532        for i in (k + 1)..n {
533            let val: f64 = aug[i][k].abs_sq();
534            if val > max_val {
535                max_val = val;
536                max_idx = i;
537            }
538        }
539
540        // Swap rows
541        if max_idx != k {
542            aug.swap(k, max_idx);
543        }
544
545        let pivot = aug[k][k];
546        if pivot.abs_sq() < T::epsilon() {
547            continue;
548        }
549
550        // Eliminate below
551        for i in (k + 1)..n {
552            let factor = aug[i][k] / pivot;
553            for j in k..(n + m) {
554                aug[i][j] = aug[i][j] - factor * aug[k][j];
555            }
556        }
557    }
558
559    // Back substitution
560    let mut x: Vec<Vec<T>> = vec![vec![T::zero(); m]; n];
561    for i in (0..n).rev() {
562        for j in 0..m {
563            let mut sum = aug[i][n + j];
564            for k in (i + 1)..n {
565                sum = sum - aug[i][k] * x[k][j];
566            }
567            let diag = aug[i][i];
568            if diag.abs_sq() > T::epsilon() {
569                x[i][j] = sum / diag;
570            }
571        }
572    }
573
574    from_vec2d(x)
575}
576
577/// Find the position and value of the maximum absolute value in a submatrix.
578///
579/// Searches within the rectangular region defined by `rows x cols` ranges.
580/// Returns `(row, col, value)` of the element with the largest `|value|^2`.
581///
582/// # Panics
583///
584/// Panics if either range is empty.
585///
586/// # Examples
587///
588/// ```
589/// use tensor4all_tcicore::{from_vec2d, matrix::submatrix_argmax};
590///
591/// let m = from_vec2d(vec![
592///     vec![1.0_f64, 2.0, 3.0],
593///     vec![4.0, 9.0, 6.0],
594///     vec![7.0, 8.0, 5.0],
595/// ]);
596/// let (row, col, val) = submatrix_argmax(&m, 0..3, 0..3);
597/// assert_eq!(row, 1);
598/// assert_eq!(col, 1);
599/// assert_eq!(val, 9.0);
600/// ```
601pub fn submatrix_argmax<T: Scalar>(
602    a: &Matrix<T>,
603    rows: std::ops::Range<usize>,
604    cols: std::ops::Range<usize>,
605) -> (usize, usize, T) {
606    assert!(!rows.is_empty(), "rows must not be empty");
607    assert!(!cols.is_empty(), "cols must not be empty");
608
609    let mut max_val: f64 = a[[rows.start, cols.start]].abs_sq();
610    let mut max_row = rows.start;
611    let mut max_col = cols.start;
612
613    for r in rows {
614        for c in cols.clone() {
615            let val: f64 = a[[r, c]].abs_sq();
616            if val > max_val {
617                max_val = val;
618                max_row = r;
619                max_col = c;
620            }
621        }
622    }
623
624    (max_row, max_col, a[[max_row, max_col]])
625}
626
627/// Select a random subset of up to `n` elements from a slice.
628///
629/// If `n >= set.len()`, returns at most `set.len()` elements (a shuffled
630/// subset). Returns an empty vector when the set is empty or `n` is zero.
631///
632/// # Examples
633///
634/// ```
635/// use tensor4all_tcicore::matrix::random_subset;
636/// use rand::SeedableRng;
637///
638/// let mut rng = rand::rngs::StdRng::seed_from_u64(42);
639/// let items = vec![10, 20, 30, 40, 50];
640/// let sub = random_subset(&items, 3, &mut rng);
641/// assert_eq!(sub.len(), 3);
642/// // All selected elements come from the original set
643/// for &x in &sub {
644///     assert!(items.contains(&x));
645/// }
646/// // Requesting more than available returns at most set.len()
647/// let all = random_subset(&items, 100, &mut rng);
648/// assert_eq!(all.len(), 5);
649/// ```
650pub fn random_subset<T: Clone, R: Rng>(set: &[T], n: usize, rng: &mut R) -> Vec<T> {
651    let n = n.min(set.len());
652    if n == 0 {
653        return Vec::new();
654    }
655
656    let mut indices: Vec<usize> = (0..set.len()).collect();
657    indices.shuffle(rng);
658    indices.truncate(n);
659    indices.into_iter().map(|i| set[i].clone()).collect()
660}
661
/// Return the elements of `set` that do not occur in `exclude`.
///
/// The relative order of `set` is kept, and repeated elements of `set` that
/// survive the filter are all retained.
///
/// # Examples
///
/// ```
/// use tensor4all_tcicore::matrix::set_diff;
///
/// let result = set_diff(&[0, 1, 2, 3, 4], &[1, 3]);
/// assert_eq!(result, vec![0, 2, 4]);
/// ```
pub fn set_diff(set: &[usize], exclude: &[usize]) -> Vec<usize> {
    // Hash the exclusion list once so each membership test is a single lookup.
    let banned: HashSet<usize> = exclude.iter().copied().collect();
    let mut kept = Vec::new();
    for &item in set {
        if !banned.contains(&item) {
            kept.push(item);
        }
    }
    kept
}
681
682/// Dot product of two vectors.
683///
684/// # Panics
685///
686/// Panics if `a.len() != b.len()`.
687///
688/// # Examples
689///
690/// ```
691/// use tensor4all_tcicore::matrix::dot;
692///
693/// let a = [1.0_f64, 2.0, 3.0];
694/// let b = [4.0, 5.0, 6.0];
695/// assert!((dot(&a, &b) - 32.0).abs() < 1e-10);
696/// ```
697pub fn dot<T: Scalar>(a: &[T], b: &[T]) -> T {
698    assert_eq!(a.len(), b.len());
699    a.iter()
700        .zip(b.iter())
701        .fold(T::zero(), |acc, (&x, &y)| acc + x * y)
702}
703
/// BLAS-backed matrix multiplication dispatch.
///
/// Implemented for all scalar types supported by tenferro einsum
/// (f64, f32, Complex64, Complex32); the implementations are generated by
/// `impl_blas_mul!` in this file, and [`mat_mul`] forwards to them.
///
/// NOTE(review): this trait is described as sealed (external types cannot
/// implement it), but no private supertrait or other sealing mechanism is
/// visible in this declaration — confirm it is enforced elsewhere.
pub trait BlasMul: Sized {
    /// Multiply `a * b` and return the product as a new row-major matrix.
    #[doc(hidden)]
    fn blas_mat_mul(a: &Matrix<Self>, b: &Matrix<Self>) -> Matrix<Self>;
}
713
/// Re-order a row-major `nrows x ncols` buffer into column-major order.
///
/// Reads exactly the first `nrows * ncols` elements of `data`.
fn row_major_to_col_major<T: Copy>(data: &[T], nrows: usize, ncols: usize) -> Vec<T> {
    (0..ncols)
        .flat_map(|col| (0..nrows).map(move |row| data[row * ncols + col]))
        .collect()
}
723
/// Re-order a column-major `nrows x ncols` buffer into row-major order.
///
/// Reads exactly the first `nrows * ncols` elements of `data`.
fn col_major_to_row_major<T: Copy>(data: &[T], nrows: usize, ncols: usize) -> Vec<T> {
    (0..nrows)
        .flat_map(|row| (0..ncols).map(move |col| data[col * nrows + row]))
        .collect()
}
733
// Generates a `BlasMul` implementation for each listed scalar type.
//
// Every generated `blas_mat_mul` copies both operands into column-major
// buffers, wraps them in `TypedTensor`s, contracts them with the einsum
// expression "ij,jk->ik" on the default backend, and copies the product back
// into a row-major `Matrix`.
macro_rules! impl_blas_mul {
    ($($t:ty),*) => {
        $(
        impl BlasMul for $t {
            fn blas_mat_mul(a: &Matrix<Self>, b: &Matrix<Self>) -> Matrix<Self> {
                use tenferro_einsum::typed_eager_einsum;
                use tenferro_tensor::TypedTensor;
                use tensor4all_tensorbackend::with_default_backend;

                // Shapes: `a` is m x k, `b` is k x n, the product is m x n.
                let m = a.nrows();
                let k = a.ncols();
                let n = b.ncols();
                assert_eq!(b.nrows(), k);

                // NOTE(review): the layout conversion presumes
                // `TypedTensor::from_vec` takes column-major data — confirm
                // against the tenferro documentation.
                let a_tensor = TypedTensor::<$t>::from_vec(
                    vec![m, k],
                    row_major_to_col_major(a.as_slice(), m, k),
                );
                let b_tensor = TypedTensor::<$t>::from_vec(
                    vec![k, n],
                    row_major_to_col_major(b.as_slice(), k, n),
                );
                // Contract over the shared index `j`; a backend failure is
                // treated as unrecoverable and panics.
                let c = with_default_backend(|backend| {
                    typed_eager_einsum(backend, &[&a_tensor, &b_tensor], "ij,jk->ik")
                })
                .expect("einsum failed");
                // Convert the result back to this module's row-major layout.
                let c_data = col_major_to_row_major(c.as_slice(), m, n);
                Matrix::from_raw_vec(m, n, c_data)
            }
        }
        )*
    };
}

// Instantiate the implementation for the real and complex floating-point types.
impl_blas_mul!(f64, f32, num_complex::Complex64, num_complex::Complex32);
769
/// Matrix multiplication: A * B.
///
/// Forwards to `T::blas_mat_mul`. For the scalar types implemented in this
/// file that routes through tenferro's einsum on the default backend.
///
/// # Panics
///
/// Panics if `a.ncols() != b.nrows()`, or if the underlying einsum call
/// reports an error.
///
/// # Examples
///
/// ```
/// use tensor4all_tcicore::{from_vec2d, matrix::mat_mul};
///
/// let a = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
/// let b = from_vec2d(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
/// let c = mat_mul(&a, &b);
/// assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
/// assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
/// assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
/// assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
/// ```
pub fn mat_mul<T: Scalar>(a: &Matrix<T>, b: &Matrix<T>) -> Matrix<T> {
    T::blas_mat_mul(a, b)
}
794
795#[cfg(test)]
796mod tests;