tensor4all_tensorbackend/matrix.rs
1//! Dense column-major matrix type and utility functions.
2//!
3//! [`Matrix<T>`] is a simple dense 2D matrix in column-major layout, indexed
4//! by `m[[row, col]]`. It is the shared dense matrix boundary for tensor4all
5//! crates that need flat buffers and backend-backed matrix multiplication.
6//!
7//! # Examples
8//!
9//! ```
10//! use tensor4all_tensorbackend::{from_vec2d, Matrix};
11//!
12//! let m = from_vec2d(vec![
13//! vec![1.0_f64, 2.0],
14//! vec![3.0, 4.0],
15//! ]);
16//! assert_eq!(m.nrows(), 2);
17//! assert_eq!(m.ncols(), 2);
18//! assert_eq!(m[[0, 1]], 2.0);
19//! assert_eq!(m[[1, 0]], 3.0);
20//! ```
21
22use anyhow::{ensure, Context, Result};
23use num_complex::{Complex32, Complex64};
24use num_traits::{One, Zero};
25use std::ops::{Index, IndexMut};
26
/// A dense 2D matrix in column-major layout.
///
/// Access elements with `m[[row, col]]` syntax. Data is stored contiguously
/// in column-major order, so element `(row, col)` lives at flat offset
/// `row + nrows * col`. Indexing with `row >= nrows` or `col >= ncols` panics.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::Matrix;
///
/// let mut m = Matrix::zeros(2, 3);
/// m[[0, 1]] = 5.0_f64;
/// assert_eq!(m[[0, 1]], 5.0);
/// assert_eq!(m[[0, 0]], 0.0);
/// assert_eq!(m.nrows(), 2);
/// assert_eq!(m.ncols(), 3);
/// ```
#[derive(Debug, Clone)]
pub struct Matrix<T> {
    /// Column-major element buffer holding `nrows * ncols` values.
    data: Vec<T>,
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
}

impl<T> Matrix<T> {
    /// Create a matrix from raw column-major data.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`, or if
    /// `data.len() != nrows * ncols`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tensorbackend::Matrix;
    ///
    /// let m = Matrix::from_col_major_vec(2, 2, vec![1.0, 3.0, 2.0, 4.0]);
    /// assert_eq!(m[[0, 0]], 1.0);
    /// assert_eq!(m[[0, 1]], 2.0);
    /// assert_eq!(m[[1, 0]], 3.0);
    /// assert_eq!(m[[1, 1]], 4.0);
    /// ```
    pub fn from_col_major_vec(nrows: usize, ncols: usize, data: Vec<T>) -> Self {
        // `checked_mul` so an overflowing shape cannot wrap around in release
        // builds and then be "matched" by a short buffer.
        let expected = nrows
            .checked_mul(ncols)
            .unwrap_or_else(|| panic!("matrix shape {}x{} overflows usize", nrows, ncols));
        assert_eq!(
            data.len(),
            expected,
            "column-major buffer length does not match matrix shape {}x{}",
            nrows,
            ncols
        );
        Self { data, nrows, ncols }
    }

    /// View the underlying column-major data as a contiguous slice.
    ///
    /// # Examples
    ///
    /// ```
    /// use tensor4all_tensorbackend::Matrix;
    ///
    /// let m = Matrix::from_col_major_vec(2, 2, vec![1, 3, 2, 4]);
    /// assert_eq!(m.as_col_major_slice(), &[1, 3, 2, 4]);
    /// ```
    pub fn as_col_major_slice(&self) -> &[T] {
        &self.data
    }

    /// Flat column-major offset of element `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`. Both coordinates must be
    /// checked: a row past the end would otherwise map onto a valid offset in
    /// the next column and silently read or write the wrong element.
    fn offset(&self, row: usize, col: usize) -> usize {
        assert!(
            row < self.nrows && col < self.ncols,
            "matrix index [{}, {}] out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        row + self.nrows * col
    }

    /// Number of rows.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    pub fn ncols(&self) -> usize {
        self.ncols
    }
}
102
103impl<T: Clone> Matrix<T> {
104 /// Create a new matrix filled with a constant value.
105 ///
106 /// # Examples
107 ///
108 /// ```
109 /// use tensor4all_tensorbackend::Matrix;
110 ///
111 /// let m = Matrix::from_elem(2, 3, 7.0);
112 /// assert_eq!(m[[0, 0]], 7.0);
113 /// assert_eq!(m[[1, 2]], 7.0);
114 /// ```
115 pub fn from_elem(nrows: usize, ncols: usize, elem: T) -> Self {
116 Self {
117 data: vec![elem; nrows * ncols],
118 nrows,
119 ncols,
120 }
121 }
122}
123
124impl<T: Clone + Zero> Matrix<T> {
125 /// Create a zeros matrix
126 ///
127 /// # Examples
128 ///
129 /// ```
130 /// use tensor4all_tensorbackend::Matrix;
131 ///
132 /// let m = Matrix::<f64>::zeros(2, 3);
133 /// assert_eq!(m.nrows(), 2);
134 /// assert_eq!(m.ncols(), 3);
135 /// assert_eq!(m[[0, 0]], 0.0);
136 /// assert_eq!(m[[1, 2]], 0.0);
137 /// ```
138 pub fn zeros(nrows: usize, ncols: usize) -> Self {
139 Self {
140 data: vec![T::zero(); nrows * ncols],
141 nrows,
142 ncols,
143 }
144 }
145}
146
147impl<T> Index<[usize; 2]> for Matrix<T> {
148 type Output = T;
149
150 fn index(&self, idx: [usize; 2]) -> &Self::Output {
151 &self.data[self.offset(idx[0], idx[1])]
152 }
153}
154
155impl<T> IndexMut<[usize; 2]> for Matrix<T> {
156 fn index_mut(&mut self, idx: [usize; 2]) -> &mut Self::Output {
157 let offset = self.offset(idx[0], idx[1]);
158 &mut self.data[offset]
159 }
160}
161
162/// Create a matrix from a 2D vector.
163///
164/// Each inner `Vec` is one row. The resulting matrix is stored internally in
165/// column-major order.
166///
167/// # Examples
168///
169/// ```
170/// use tensor4all_tensorbackend::from_vec2d;
171///
172/// let m = from_vec2d(vec![
173/// vec![1.0, 2.0],
174/// vec![3.0, 4.0],
175/// ]);
176/// assert_eq!(m.nrows(), 2);
177/// assert_eq!(m.ncols(), 2);
178/// assert_eq!(m[[0, 1]], 2.0);
179/// assert_eq!(m[[1, 0]], 3.0);
180/// ```
181pub fn from_vec2d<T: Clone + Zero>(data: Vec<Vec<T>>) -> Matrix<T> {
182 let nrows = data.len();
183 let ncols = if nrows > 0 { data[0].len() } else { 0 };
184 let mut m = Matrix::zeros(nrows, ncols);
185 for i in 0..nrows {
186 for j in 0..ncols {
187 m[[i, j]] = data[i][j].clone();
188 }
189 }
190 m
191}
192
193/// Get a submatrix by selecting specific rows and columns.
194///
195/// # Examples
196///
197/// ```
198/// use tensor4all_tensorbackend::{from_vec2d, submatrix};
199///
200/// let m = from_vec2d(vec![
201/// vec![1.0, 2.0, 3.0],
202/// vec![4.0, 5.0, 6.0],
203/// vec![7.0, 8.0, 9.0],
204/// ]);
205/// let sub = submatrix(&m, &[0, 2], &[1, 2]);
206/// assert_eq!(sub.nrows(), 2);
207/// assert_eq!(sub.ncols(), 2);
208/// assert_eq!(sub[[0, 0]], 2.0); // m[0, 1]
209/// assert_eq!(sub[[1, 1]], 9.0); // m[2, 2]
210/// ```
211pub fn submatrix<T: Clone + Zero>(m: &Matrix<T>, rows: &[usize], cols: &[usize]) -> Matrix<T> {
212 let mut result = Matrix::zeros(rows.len(), cols.len());
213 for (ri, &r) in rows.iter().enumerate() {
214 for (ci, &c) in cols.iter().enumerate() {
215 result[[ri, ci]] = m[[r, c]].clone();
216 }
217 }
218 result
219}
220
221/// Swap two rows in a matrix in-place.
222///
223/// No-op if `a == b`.
224///
225/// # Examples
226///
227/// ```
228/// use tensor4all_tensorbackend::{from_vec2d, swap_rows};
229///
230/// let mut m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
231/// swap_rows(&mut m, 0, 1);
232/// assert_eq!(m[[0, 0]], 3.0);
233/// assert_eq!(m[[1, 0]], 1.0);
234/// ```
235pub fn swap_rows<T>(m: &mut Matrix<T>, a: usize, b: usize) {
236 if a == b {
237 return;
238 }
239 for j in 0..m.ncols {
240 let idx_a = m.offset(a, j);
241 let idx_b = m.offset(b, j);
242 m.data.swap(idx_a, idx_b);
243 }
244}
245
246/// Swap two columns in a matrix in-place.
247///
248/// No-op if `a == b`.
249///
250/// # Examples
251///
252/// ```
253/// use tensor4all_tensorbackend::{from_vec2d, swap_cols};
254///
255/// let mut m = from_vec2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
256/// swap_cols(&mut m, 0, 1);
257/// assert_eq!(m[[0, 0]], 2.0);
258/// assert_eq!(m[[0, 1]], 1.0);
259/// ```
260pub fn swap_cols<T>(m: &mut Matrix<T>, a: usize, b: usize) {
261 if a == b {
262 return;
263 }
264 for i in 0..m.nrows {
265 let idx_a = m.offset(i, a);
266 let idx_b = m.offset(i, b);
267 m.data.swap(idx_a, idx_b);
268 }
269}
270
271/// Transpose the matrix.
272///
273/// # Examples
274///
275/// ```
276/// use tensor4all_tensorbackend::{from_vec2d, transpose};
277///
278/// let m = from_vec2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
279/// let mt = transpose(&m);
280/// assert_eq!(mt.nrows(), 3);
281/// assert_eq!(mt.ncols(), 2);
282/// assert_eq!(mt[[0, 0]], 1.0);
283/// assert_eq!(mt[[2, 1]], 6.0);
284/// ```
285pub fn transpose<T: Clone + Zero>(m: &Matrix<T>) -> Matrix<T> {
286 let mut result = Matrix::zeros(m.ncols, m.nrows);
287 for i in 0..m.nrows {
288 for j in 0..m.ncols {
289 result[[j, i]] = m[[i, j]].clone();
290 }
291 }
292 result
293}
294
295/// Find the position and value of the maximum absolute value in a submatrix.
296///
297/// Searches within the rectangular region defined by `rows x cols` ranges.
298/// Returns `(row, col, value)` of the element with the largest `|value|^2`.
299///
300/// # Panics
301///
302/// Panics if either range is empty.
303///
304/// # Examples
305///
306/// ```
307/// use tensor4all_tensorbackend::{from_vec2d, submatrix_argmax};
308///
309/// let m = from_vec2d(vec![
310/// vec![1.0_f64, 2.0, 3.0],
311/// vec![4.0, 9.0, 6.0],
312/// vec![7.0, 8.0, 5.0],
313/// ]);
314/// let (row, col, val) = submatrix_argmax(&m, 0..3, 0..3);
315/// assert_eq!(row, 1);
316/// assert_eq!(col, 1);
317/// assert_eq!(val, 9.0);
318/// ```
319pub fn submatrix_argmax<T: MatrixScalar>(
320 a: &Matrix<T>,
321 rows: std::ops::Range<usize>,
322 cols: std::ops::Range<usize>,
323) -> (usize, usize, T) {
324 assert!(!rows.is_empty(), "rows must not be empty");
325 assert!(!cols.is_empty(), "cols must not be empty");
326
327 let mut max_val: f64 = a[[rows.start, cols.start]].matrix_abs_sq();
328 let mut max_row = rows.start;
329 let mut max_col = cols.start;
330
331 for r in rows {
332 for c in cols.clone() {
333 let val: f64 = a[[r, c]].matrix_abs_sq();
334 if val > max_val {
335 max_val = val;
336 max_row = r;
337 max_col = c;
338 }
339 }
340 }
341
342 (max_row, max_col, a[[max_row, max_col]])
343}
344
345/// BLAS-backed matrix multiplication dispatch.
346///
347/// Implemented for all scalar types supported by tenferro einsum
348/// (f64, f32, Complex64, Complex32). This trait is sealed — external
349/// types cannot implement it.
350pub trait BlasMul: Sized {
351 #[doc(hidden)]
352 fn blas_mat_mul(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>>;
353}
354
355macro_rules! impl_blas_mul {
356 ($($t:ty),*) => {
357 $(
358 impl BlasMul for $t {
359 fn blas_mat_mul(a: &Matrix<Self>, b: &Matrix<Self>) -> Result<Matrix<Self>> {
360 use tenferro_einsum::typed_eager_einsum;
361 use tenferro_tensor::TypedTensor;
362 use crate::with_default_backend;
363
364 let m = a.nrows();
365 let k = a.ncols();
366 let n = b.ncols();
367 ensure!(
368 b.nrows() == k,
369 "matrix dimensions must agree for multiplication: left is {}x{}, right is {}x{}",
370 m,
371 k,
372 b.nrows(),
373 n
374 );
375
376 let a_tensor = TypedTensor::<$t>::from_vec(
377 vec![m, k],
378 a.as_col_major_slice().to_vec(),
379 );
380 let b_tensor = TypedTensor::<$t>::from_vec(
381 vec![k, n],
382 b.as_col_major_slice().to_vec(),
383 );
384 let c = with_default_backend(|backend| {
385 typed_eager_einsum(backend, &[&a_tensor, &b_tensor], "ij,jk->ik")
386 })
387 .context("matrix multiplication einsum failed")?;
388 let data = c.as_slice().to_vec();
389 ensure!(
390 data.len() == m * n,
391 "matrix multiplication returned {} values for expected shape {}x{}",
392 data.len(),
393 m,
394 n
395 );
396 Ok(Matrix {
397 data,
398 nrows: m,
399 ncols: n,
400 })
401 }
402 }
403 )*
404 };
405}
406
407impl_blas_mul!(f64, f32, num_complex::Complex64, num_complex::Complex32);
408
409/// Scalar bound for dense backend matrix utilities.
410///
411/// This is the storage/linalg-layer scalar trait. Higher-level crates may
412/// extend it with domain-specific methods, but matrix utilities only rely on
413/// these algebraic operations and absolute-value comparisons.
414pub trait MatrixScalar:
415 Clone
416 + Copy
417 + Zero
418 + One
419 + std::ops::Add<Output = Self>
420 + std::ops::Sub<Output = Self>
421 + std::ops::Mul<Output = Self>
422 + std::ops::Div<Output = Self>
423 + std::ops::Neg<Output = Self>
424 + Default
425 + Send
426 + Sync
427 + BlasMul
428 + 'static
429{
430 /// Squared absolute value as `f64`.
431 fn matrix_abs_sq(self) -> f64;
432}
433
434impl MatrixScalar for f64 {
435 fn matrix_abs_sq(self) -> f64 {
436 self * self
437 }
438}
439
440impl MatrixScalar for f32 {
441 fn matrix_abs_sq(self) -> f64 {
442 (self * self) as f64
443 }
444}
445
446impl MatrixScalar for Complex64 {
447 fn matrix_abs_sq(self) -> f64 {
448 self.norm_sqr()
449 }
450}
451
452impl MatrixScalar for Complex32 {
453 fn matrix_abs_sq(self) -> f64 {
454 self.norm_sqr() as f64
455 }
456}
457
458/// Matrix multiplication: A * B.
459///
460/// Uses BLAS-backed einsum via tenferro for high performance.
461///
462/// # Errors
463///
464/// Returns an error if `a.ncols() != b.nrows()` or the backend einsum fails.
465///
466/// # Examples
467///
468/// ```
469/// use tensor4all_tensorbackend::{from_vec2d, mat_mul};
470///
471/// let a = from_vec2d(vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]]);
472/// let b = from_vec2d(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
473/// let c = mat_mul(&a, &b).unwrap();
474/// assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
475/// assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
476/// assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
477/// assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
478/// ```
479pub fn mat_mul<T: BlasMul>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>> {
480 T::blas_mat_mul(a, b)
481}
482
483#[cfg(test)]
484mod tests;