1use crate::contiguous::{alloc_col_major_uninit, ContiguousOperand, ContiguousOperandMut};
8use crate::util::{try_fuse_group, MultiIndex};
9use faer::linalg::matmul::matmul_with_conj;
10use faer::mat::{MatMut, MatRef};
11use faer::{Accum, Conj, Par};
12use faer_traits::ComplexField;
13use strided_view::{RawStridedMut, RawStridedRef, StridedArray, StridedView, StridedViewMut};
14
15pub fn bgemm_strided_into<T>(
21 c: &mut StridedViewMut<T>,
22 a: &StridedView<T>,
23 b: &StridedView<T>,
24 _n_batch: usize,
25 n_lo: usize,
26 n_ro: usize,
27 n_sum: usize,
28 alpha: T,
29 beta: T,
30 conj_a: bool,
31 conj_b: bool,
32) -> strided_view::Result<()>
33where
34 T: ComplexField
35 + Copy
36 + strided_view::ElementOpApply
37 + Send
38 + Sync
39 + std::ops::Mul<Output = T>
40 + std::ops::Add<Output = T>
41 + num_traits::Zero
42 + num_traits::One
43 + PartialEq,
44{
45 let a = unsafe { RawStridedRef::new_unchecked(a.data(), a.dims(), a.strides(), a.offset()) };
46 let b = unsafe { RawStridedRef::new_unchecked(b.data(), b.dims(), b.strides(), b.offset()) };
47 let c_dims = c.dims().to_vec();
48 let c_strides = c.strides().to_vec();
49 let c_offset = c.offset();
50 let c = unsafe { RawStridedMut::new_unchecked(c.data_mut(), &c_dims, &c_strides, c_offset) };
51 bgemm_raw_strided_into(
52 c, a, b, _n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
53 )
54}
55
56#[allow(clippy::too_many_arguments)]
62pub(crate) fn bgemm_raw_strided_into<T>(
63 c: RawStridedMut<'_, T>,
64 a: RawStridedRef<'_, T>,
65 b: RawStridedRef<'_, T>,
66 n_batch: usize,
67 n_lo: usize,
68 n_ro: usize,
69 n_sum: usize,
70 alpha: T,
71 beta: T,
72 conj_a: bool,
73 conj_b: bool,
74) -> strided_view::Result<()>
75where
76 T: ComplexField
77 + Copy
78 + strided_view::ElementOpApply
79 + Send
80 + Sync
81 + std::ops::Mul<Output = T>
82 + std::ops::Add<Output = T>
83 + num_traits::Zero
84 + num_traits::One
85 + PartialEq,
86{
87 validate_bgemm_shapes(&c, &a, &b, n_batch, n_lo, n_ro, n_sum)?;
88 unsafe {
89 bgemm_raw_strided_into_unchecked(
90 c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
91 )
92 }
93}
94
95#[allow(clippy::too_many_arguments)]
105pub(crate) unsafe fn bgemm_raw_strided_into_unchecked<T>(
106 mut c: RawStridedMut<'_, T>,
107 a: RawStridedRef<'_, T>,
108 b: RawStridedRef<'_, T>,
109 _n_batch: usize,
110 n_lo: usize,
111 n_ro: usize,
112 n_sum: usize,
113 alpha: T,
114 beta: T,
115 conj_a: bool,
116 conj_b: bool,
117) -> strided_view::Result<()>
118where
119 T: ComplexField
120 + Copy
121 + strided_view::ElementOpApply
122 + Send
123 + Sync
124 + std::ops::Mul<Output = T>
125 + std::ops::Add<Output = T>
126 + num_traits::Zero
127 + num_traits::One
128 + PartialEq,
129{
130 let a_dims = a.dims();
131 let b_dims = b.dims();
132 let a_strides = a.strides();
133 let b_strides = b.strides();
134 let c_strides = c.strides();
135
136 let lo_dims = &a_dims[..n_lo];
139 let sum_dims = &a_dims[n_lo..n_lo + n_sum];
140 let batch_dims = &a_dims[n_lo + n_sum..];
141 let ro_dims = &b_dims[n_sum..n_sum + n_ro];
142
143 if c.dims().iter().any(|&dim| dim == 0) {
144 return Ok(());
145 }
146
147 let m: usize = lo_dims.iter().product::<usize>().max(1);
149 let k: usize = sum_dims.iter().product::<usize>().max(1);
150 let n: usize = ro_dims.iter().product::<usize>().max(1);
151
152 let a_lo_strides = &a_strides[..n_lo];
154 let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
155 let b_sum_strides = &b_strides[..n_sum];
156 let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
157 let c_lo_strides = &c_strides[..n_lo];
158 let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
159
160 let fused_a_lo = try_fuse_group(lo_dims, a_lo_strides);
162 let fused_a_sum = try_fuse_group(sum_dims, a_sum_strides);
163 let fused_b_sum = try_fuse_group(sum_dims, b_sum_strides);
164 let fused_b_ro = try_fuse_group(ro_dims, b_ro_strides);
165 let fused_c_lo = try_fuse_group(lo_dims, c_lo_strides);
166 let fused_c_ro = try_fuse_group(ro_dims, c_ro_strides);
167
168 let a_needs_copy = fused_a_lo.is_none() || fused_a_sum.is_none();
169 let b_needs_copy = fused_b_sum.is_none() || fused_b_ro.is_none();
170 let c_needs_copy = fused_c_lo.is_none() || fused_c_ro.is_none();
171
172 let n_a_inner = n_lo + n_sum;
173 let n_b_inner = n_sum + n_ro;
174 let n_c_inner = n_lo + n_ro;
175
176 let a_contig_buf: Option<StridedArray<T>>;
178 let (a_ptr, a_row_stride, a_col_stride);
179 if a_needs_copy {
180 let mut buf = alloc_col_major_uninit(a.dims());
181 strided_kernel::copy_into(&mut buf.view_mut(), &a.as_view())?;
182 a_ptr = buf.view().ptr();
183 a_row_stride = if m == 0 { 0 } else { 1isize };
185 a_col_stride = m as isize;
186 a_contig_buf = Some(buf);
187 } else {
188 let (_, rs) = fused_a_lo.unwrap();
189 let (_, cs) = fused_a_sum.unwrap();
190 a_ptr = a.ptr();
191 a_row_stride = rs;
192 a_col_stride = cs;
193 a_contig_buf = None;
194 }
195 let a_batch_strides: &[isize] = match a_contig_buf.as_ref() {
196 Some(buf) => &buf.strides()[n_a_inner..],
197 None => &a_strides[n_a_inner..],
198 };
199
200 let b_contig_buf: Option<StridedArray<T>>;
202 let (b_ptr, b_row_stride, b_col_stride);
203 if b_needs_copy {
204 let mut buf = alloc_col_major_uninit(b.dims());
205 strided_kernel::copy_into(&mut buf.view_mut(), &b.as_view())?;
206 b_ptr = buf.view().ptr();
207 b_row_stride = if k == 0 { 0 } else { 1isize };
209 b_col_stride = k as isize;
210 b_contig_buf = Some(buf);
211 } else {
212 let (_, rs) = fused_b_sum.unwrap();
213 let (_, cs) = fused_b_ro.unwrap();
214 b_ptr = b.ptr();
215 b_row_stride = rs;
216 b_col_stride = cs;
217 b_contig_buf = None;
218 }
219 let b_batch_strides: &[isize] = match b_contig_buf.as_ref() {
220 Some(buf) => &buf.strides()[n_b_inner..],
221 None => &b_strides[n_b_inner..],
222 };
223
224 let c_contig_buf: Option<StridedArray<T>>;
226 let (c_ptr, c_row_stride, c_col_stride);
227 if c_needs_copy {
228 let mut buf = alloc_col_major_uninit(c.dims());
229 if beta != T::zero() {
230 let c_view: StridedView<'_, T> = c.as_view();
231 strided_kernel::copy_into(&mut buf.view_mut(), &c_view)?;
232 }
233 c_ptr = buf.view_mut().as_mut_ptr();
234 c_row_stride = if m == 0 { 0 } else { 1isize };
236 c_col_stride = m as isize;
237 c_contig_buf = Some(buf);
238 } else {
239 let (_, rs) = fused_c_lo.unwrap();
240 let (_, cs) = fused_c_ro.unwrap();
241 c_ptr = c.as_mut_ptr();
242 c_row_stride = rs;
243 c_col_stride = cs;
244 c_contig_buf = None;
245 }
246 let c_batch_strides: &[isize] = match c_contig_buf.as_ref() {
247 Some(buf) => &buf.strides()[n_c_inner..],
248 None => &c_strides[n_c_inner..],
249 };
250
251 let is_beta_zero = beta == T::zero();
252 let is_beta_one = beta == T::one();
253
254 let accum = if is_beta_zero {
256 Accum::Replace
257 } else {
258 Accum::Add
259 };
260
261 let cj_a = if conj_a { Conj::Yes } else { Conj::No };
262 let cj_b = if conj_b { Conj::Yes } else { Conj::No };
263
264 let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
266 if !is_beta_zero && !is_beta_one {
268 let c_base = unsafe { c_ptr.offset(c_batch_off) };
269 for i in 0..m {
270 for j in 0..n {
271 let offset = i as isize * c_row_stride + j as isize * c_col_stride;
272 unsafe {
273 let elem = c_base.offset(offset);
274 *elem = beta * *elem;
275 }
276 }
277 }
278 }
279
280 unsafe {
281 let a_mat: MatRef<'_, T> =
282 MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
283 let b_mat: MatRef<'_, T> =
284 MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
285 let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
286 c_ptr.offset(c_batch_off),
287 m,
288 n,
289 c_row_stride,
290 c_col_stride,
291 );
292
293 matmul_with_conj(c_mat, accum, a_mat, cj_a, b_mat, cj_b, alpha, Par::rayon(0));
294 }
295 };
296
297 let fused_a = try_fuse_group(batch_dims, a_batch_strides);
300 let fused_b = try_fuse_group(batch_dims, b_batch_strides);
301 let fused_c = try_fuse_group(batch_dims, c_batch_strides);
302
303 if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
304 (fused_a, fused_b, fused_c)
305 {
306 let mut a_off = 0isize;
307 let mut b_off = 0isize;
308 let mut c_off = 0isize;
309 for _ in 0..total {
310 do_batch(a_off, b_off, c_off);
311 a_off += a_step;
312 b_off += b_step;
313 c_off += c_step;
314 }
315 } else {
316 let mut batch_iter = MultiIndex::new(batch_dims);
317 while batch_iter.next().is_some() {
318 let a_batch_off = batch_iter.offset(a_batch_strides);
319 let b_batch_off = batch_iter.offset(b_batch_strides);
320 let c_batch_off = batch_iter.offset(c_batch_strides);
321 do_batch(a_batch_off, b_batch_off, c_batch_off);
322 }
323 }
324
325 if let Some(ref c_buf) = c_contig_buf {
327 let mut c_view = c.as_view_mut();
328 strided_kernel::copy_into(&mut c_view, &c_buf.view())?;
329 }
330
331 Ok(())
332}
333
334fn validate_bgemm_shapes<T>(
335 c: &RawStridedMut<'_, T>,
336 a: &RawStridedRef<'_, T>,
337 b: &RawStridedRef<'_, T>,
338 n_batch: usize,
339 n_lo: usize,
340 n_ro: usize,
341 n_sum: usize,
342) -> strided_view::Result<()> {
343 let a_rank = n_lo + n_sum + n_batch;
344 let b_rank = n_sum + n_ro + n_batch;
345 let c_rank = n_lo + n_ro + n_batch;
346 if a.dims().len() != a_rank {
347 return Err(strided_view::StridedError::RankMismatch(
348 a_rank,
349 a.dims().len(),
350 ));
351 }
352 if b.dims().len() != b_rank {
353 return Err(strided_view::StridedError::RankMismatch(
354 b_rank,
355 b.dims().len(),
356 ));
357 }
358 if c.dims().len() != c_rank {
359 return Err(strided_view::StridedError::RankMismatch(
360 c_rank,
361 c.dims().len(),
362 ));
363 }
364
365 let lo_dims = &a.dims()[..n_lo];
366 let sum_dims = &a.dims()[n_lo..n_lo + n_sum];
367 let batch_dims = &a.dims()[n_lo + n_sum..];
368 let ro_dims = &b.dims()[n_sum..n_sum + n_ro];
369
370 if &b.dims()[..n_sum] != sum_dims {
371 return Err(strided_view::StridedError::ShapeMismatch(
372 sum_dims.to_vec(),
373 b.dims()[..n_sum].to_vec(),
374 ));
375 }
376 if &b.dims()[n_sum + n_ro..] != batch_dims {
377 return Err(strided_view::StridedError::ShapeMismatch(
378 batch_dims.to_vec(),
379 b.dims()[n_sum + n_ro..].to_vec(),
380 ));
381 }
382 if &c.dims()[..n_lo] != lo_dims {
383 return Err(strided_view::StridedError::ShapeMismatch(
384 lo_dims.to_vec(),
385 c.dims()[..n_lo].to_vec(),
386 ));
387 }
388 if &c.dims()[n_lo..n_lo + n_ro] != ro_dims {
389 return Err(strided_view::StridedError::ShapeMismatch(
390 ro_dims.to_vec(),
391 c.dims()[n_lo..n_lo + n_ro].to_vec(),
392 ));
393 }
394 if &c.dims()[n_lo + n_ro..] != batch_dims {
395 return Err(strided_view::StridedError::ShapeMismatch(
396 batch_dims.to_vec(),
397 c.dims()[n_lo + n_ro..].to_vec(),
398 ));
399 }
400 Ok(())
401}
402
403pub fn bgemm_contiguous_into<T>(
413 c: &mut ContiguousOperandMut<T>,
414 a: &ContiguousOperand<T>,
415 b: &ContiguousOperand<T>,
416 batch_dims: &[usize],
417 m: usize,
418 n: usize,
419 k: usize,
420 alpha: T,
421 beta: T,
422) -> strided_view::Result<()>
423where
424 T: ComplexField
425 + Copy
426 + strided_view::ElementOpApply
427 + Send
428 + Sync
429 + std::ops::Mul<Output = T>
430 + std::ops::Add<Output = T>
431 + num_traits::Zero
432 + num_traits::One
433 + PartialEq,
434{
435 let is_beta_zero = beta == T::zero();
436 let is_beta_one = beta == T::one();
437
438 let accum = if is_beta_zero {
439 Accum::Replace
440 } else {
441 Accum::Add
442 };
443
444 let a_batch_strides = a.batch_strides();
445 let b_batch_strides = b.batch_strides();
446 let c_batch_strides = c.batch_strides();
447
448 let a_ptr = a.ptr();
449 let b_ptr = b.ptr();
450 let c_ptr = c.ptr();
451 let a_row_stride = a.row_stride();
452 let a_col_stride = a.col_stride();
453 let b_row_stride = b.row_stride();
454 let b_col_stride = b.col_stride();
455 let c_row_stride = c.row_stride();
456 let c_col_stride = c.col_stride();
457
458 let conj_a = if a.conj() { Conj::Yes } else { Conj::No };
459 let conj_b = if b.conj() { Conj::Yes } else { Conj::No };
460
461 let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
463 if !is_beta_zero && !is_beta_one {
465 let c_base = unsafe { c_ptr.offset(c_batch_off) };
466 for i in 0..m {
467 for j in 0..n {
468 let offset = i as isize * c_row_stride + j as isize * c_col_stride;
469 unsafe {
470 let elem = c_base.offset(offset);
471 *elem = beta * *elem;
472 }
473 }
474 }
475 }
476
477 unsafe {
478 let a_mat: MatRef<'_, T> =
479 MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
480 let b_mat: MatRef<'_, T> =
481 MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
482 let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
483 c_ptr.offset(c_batch_off),
484 m,
485 n,
486 c_row_stride,
487 c_col_stride,
488 );
489
490 matmul_with_conj(
491 c_mat,
492 accum,
493 a_mat,
494 conj_a,
495 b_mat,
496 conj_b,
497 alpha,
498 Par::rayon(0),
499 );
500 }
501 };
502
503 let fused_a = try_fuse_group(batch_dims, a_batch_strides);
506 let fused_b = try_fuse_group(batch_dims, b_batch_strides);
507 let fused_c = try_fuse_group(batch_dims, c_batch_strides);
508
509 if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
510 (fused_a, fused_b, fused_c)
511 {
512 let mut a_off = 0isize;
513 let mut b_off = 0isize;
514 let mut c_off = 0isize;
515 for _ in 0..total {
516 do_batch(a_off, b_off, c_off);
517 a_off += a_step;
518 b_off += b_step;
519 c_off += c_step;
520 }
521 } else {
522 let mut batch_iter = MultiIndex::new(batch_dims);
523 while batch_iter.next().is_some() {
524 let a_batch_off = batch_iter.offset(a_batch_strides);
525 let b_batch_off = batch_iter.offset(b_batch_strides);
526 let c_batch_off = batch_iter.offset(c_batch_strides);
527 do_batch(a_batch_off, b_batch_off, c_batch_off);
528 }
529 }
530
531 Ok(())
532}
533
534use crate::backend::{Backend, FaerBackend};
535
536impl<T> Backend<T> for FaerBackend
537where
538 T: crate::ScalarBase + strided_view::ElementOpApply + ComplexField,
539{
540 const MATERIALIZES_CONJ: bool = false;
541 const REQUIRES_UNIT_STRIDE: bool = false;
542
543 fn bgemm_contiguous_into(
544 c: &mut ContiguousOperandMut<T>,
545 a: &ContiguousOperand<T>,
546 b: &ContiguousOperand<T>,
547 batch_dims: &[usize],
548 m: usize,
549 n: usize,
550 k: usize,
551 alpha: T,
552 beta: T,
553 ) -> strided_view::Result<()> {
554 bgemm_contiguous_into(c, a, b, batch_dims, m, n, k, alpha, beta)
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{ActiveBackend, Backend};
    use crate::contiguous::{prepare_input_view, prepare_output_view};
    use strided_view::StridedArray;

    /// `REQUIRES_UNIT_STRIDE` of the active backend, forwarded to the
    /// operand-preparation helpers in the contiguous-path tests.
    const US: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /// Row-major `[[1, 2], [3, 4]]`.
    fn mat_1234() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        })
    }

    /// `[[1, 2], [3, 4]]` stored with strides `[1, 2]`.
    fn mat_1234_col_major() -> StridedArray<f64> {
        StridedArray::<f64>::from_parts(vec![1.0, 3.0, 2.0, 4.0], &[2, 2], &[1, 2], 0).unwrap()
    }

    /// Row-major `[[5, 6], [7, 8]]`.
    fn mat_5678() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        })
    }

    /// Row-major 2x2 identity.
    fn identity_2x2() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
        })
    }

    /// Row-major `[[10, 20], [30, 40]]`, used as a pre-filled output.
    fn mat_tens() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        })
    }

    /// Batched left operand of shape `[2, 3, 2]` (lo, sum, batch).
    fn batched_lhs() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        })
    }

    /// Batched right operand of shape `[3, 2, 2]` (sum, ro, batch).
    fn batched_rhs() -> StridedArray<f64> {
        StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        })
    }

    // ------------------------------------------------------------------
    // Drivers
    // ------------------------------------------------------------------

    /// Calls `bgemm_strided_into` on whole `f64` arrays with one `lo` axis,
    /// one `ro` axis and no conjugation, and unwraps the result.
    fn strided_gemm(
        c: &mut StridedArray<f64>,
        a: &StridedArray<f64>,
        b: &StridedArray<f64>,
        n_batch: usize,
        n_sum: usize,
        alpha: f64,
        beta: f64,
    ) {
        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            n_batch,
            1,
            1,
            n_sum,
            alpha,
            beta,
            false,
            false,
        )
        .unwrap();
    }

    /// Prepares the operands, runs `bgemm_contiguous_into` and writes the
    /// result back into `c`. `mnk` is `(m, n, k)`.
    fn contiguous_gemm(
        c: &mut StridedArray<f64>,
        a: &StridedArray<f64>,
        b: &StridedArray<f64>,
        batch_dims: &[usize],
        mnk: (usize, usize, usize),
        alpha: f64,
        beta: f64,
    ) {
        let (m, n, k) = mnk;
        let a_op = prepare_input_view(&a.view(), 1, 1, false, US, true, None).unwrap();
        let b_op = prepare_input_view(&b.view(), 1, 1, false, US, true, None).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 1, 1, beta, US, true).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();
    }

    /// Asserts that `c` equals `[[1, 2], [3, 4]] * [[5, 6], [7, 8]]`.
    fn assert_product_1234_5678(c: &StridedArray<f64>) {
        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// Runs the checked raw kernel on two fixed row-major 2x2 matrices with
    /// `alpha = one`, `beta = zero` and returns the output buffer.
    fn raw_bgemm_2x2<T>(one: T, zero: T) -> Vec<T>
    where
        T: ComplexField
            + Copy
            + strided_view::ElementOpApply
            + Send
            + Sync
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>
            + num_traits::Zero
            + num_traits::One
            + PartialEq
            + From<f32>,
    {
        let shape = [2, 2];
        let row_major = [2, 1];
        let lhs_data = [T::from(1.0), T::from(2.0), T::from(3.0), T::from(4.0)];
        let rhs_data = [T::from(5.0), T::from(6.0), T::from(7.0), T::from(8.0)];
        let mut out_data = vec![zero; 4];

        let lhs = RawStridedRef::new(&lhs_data, &shape, &row_major, 0).unwrap();
        let rhs = RawStridedRef::new(&rhs_data, &shape, &row_major, 0).unwrap();
        let out = RawStridedMut::new(&mut out_data, &shape, &row_major, 0).unwrap();
        bgemm_raw_strided_into(out, lhs, rhs, 0, 1, 1, 1, one, zero, false, false).unwrap();

        out_data
    }

    // ------------------------------------------------------------------
    // Strided entry point
    // ------------------------------------------------------------------

    #[test]
    fn test_faer_bgemm_2x2() {
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);
        strided_gemm(&mut c, &mat_1234(), &mat_5678(), 0, 1, 1.0, 0.0);
        assert_product_1234_5678(&c);
    }

    #[test]
    fn test_faer_raw_bgemm_f64() {
        assert_eq!(raw_bgemm_2x2(1.0f64, 0.0), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_faer_raw_bgemm_f32() {
        assert_eq!(raw_bgemm_2x2(1.0f32, 0.0), vec![19.0f32, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_faer_raw_bgemm_complex_conj() {
        use num_complex::Complex64;

        let shape = [2, 2];
        let row_major = [2, 1];
        // A has imaginary parts on the diagonal; B is the identity, so the
        // result is A with conjugation applied.
        let lhs_data = [
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, -1.0),
        ];
        let rhs_data = [
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
        ];
        let mut out_data = vec![Complex64::new(0.0, 0.0); 4];

        let lhs = RawStridedRef::new(&lhs_data, &shape, &row_major, 0).unwrap();
        let rhs = RawStridedRef::new(&rhs_data, &shape, &row_major, 0).unwrap();
        let out = RawStridedMut::new(&mut out_data, &shape, &row_major, 0).unwrap();
        bgemm_raw_strided_into(
            out,
            lhs,
            rhs,
            0,
            1,
            1,
            1,
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            true,
            false,
        )
        .unwrap();

        assert_eq!(
            out_data,
            vec![
                Complex64::new(1.0, -1.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(4.0, 1.0),
            ]
        );
    }

    #[test]
    fn test_faer_raw_bgemm_checked_shape_mismatch() {
        // B's contracted dimension (3) disagrees with A's (2).
        let square = [2, 2];
        let tall = [3, 2];
        let row_major = [2, 1];
        let lhs_data = [1.0, 2.0, 3.0, 4.0];
        let rhs_data = [0.0; 6];
        let mut out_data = [0.0; 4];

        let lhs = RawStridedRef::new(&lhs_data, &square, &row_major, 0).unwrap();
        let rhs = RawStridedRef::new(&rhs_data, &tall, &row_major, 0).unwrap();
        let out = RawStridedMut::new(&mut out_data, &square, &row_major, 0).unwrap();
        let err =
            bgemm_raw_strided_into(out, lhs, rhs, 0, 1, 1, 1, 1.0, 0.0, false, false).unwrap_err();

        assert!(matches!(
            err,
            strided_view::StridedError::ShapeMismatch(_, _)
        ));
    }

    #[test]
    fn test_faer_bgemm_rect() {
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let b =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[2, 4]);

        strided_gemm(&mut c, &a, &b, 0, 1, 1.0, 0.0);

        assert_eq!(c.get(&[0, 0]), 38.0);
        assert_eq!(c.get(&[1, 3]), 128.0);
    }

    #[test]
    fn test_faer_bgemm_batched() {
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        strided_gemm(&mut c, &batched_lhs(), &batched_rhs(), 1, 1, 1.0, 0.0);

        assert_eq!(c.get(&[0, 0, 0]), 22.0);
    }

    #[test]
    fn test_faer_bgemm_beta_zero() {
        // The pre-existing output values must be overwritten, not accumulated.
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[100.0, 200.0], [300.0, 400.0]][idx[0]][idx[1]]
        });

        strided_gemm(&mut c, &mat_1234(), &mat_5678(), 0, 1, 1.0, 0.0);

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_faer_bgemm_beta_one() {
        let mut c = mat_tens();

        strided_gemm(&mut c, &identity_2x2(), &mat_1234(), 0, 1, 1.0, 1.0);

        assert_eq!(c.get(&[0, 0]), 11.0);
        assert_eq!(c.get(&[1, 1]), 44.0);
    }

    #[test]
    fn test_faer_bgemm_alpha_beta() {
        let mut c = mat_tens();

        strided_gemm(&mut c, &identity_2x2(), &mat_1234(), 0, 1, 2.0, 3.0);

        assert_eq!(c.get(&[0, 0]), 32.0);
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    #[test]
    fn test_faer_bgemm_outer_product() {
        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[3, 4]);

        // No contracted axes at all: n_sum = 0.
        strided_gemm(&mut c, &a, &b, 0, 0, 1.0, 0.0);

        assert_eq!(c.get(&[0, 0]), 1.0);
        assert_eq!(c.get(&[2, 3]), 12.0);
    }

    #[test]
    fn test_faer_bgemm_f32() {
        let a = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0f32, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0f32, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f32>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0f32,
            0.0f32,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0f32);
        assert_eq!(c.get(&[1, 1]), 50.0f32);
    }

    #[test]
    fn test_faer_bgemm_col_major_input() {
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        strided_gemm(&mut c, &mat_1234_col_major(), &mat_5678(), 0, 1, 1.0, 0.0);

        assert_product_1234_5678(&c);
    }

    #[test]
    fn test_faer_bgemm_col_major_output() {
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);

        strided_gemm(&mut c, &mat_1234(), &mat_5678(), 0, 1, 1.0, 0.0);

        assert_product_1234_5678(&c);
    }

    #[test]
    fn test_faer_bgemm_col_major_with_beta() {
        // `[[10, 20], [30, 40]]` stored with strides `[1, 2]`.
        let mut c =
            StridedArray::<f64>::from_parts(vec![10.0, 30.0, 20.0, 40.0], &[2, 2], &[1, 2], 0)
                .unwrap();

        strided_gemm(&mut c, &identity_2x2(), &mat_1234(), 0, 1, 2.0, 3.0);

        assert_eq!(c.get(&[0, 0]), 32.0);
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    // ------------------------------------------------------------------
    // Contiguous-operand entry point
    // ------------------------------------------------------------------

    #[test]
    fn test_bgemm_contiguous_2x2() {
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        contiguous_gemm(&mut c, &mat_1234(), &mat_5678(), &[], (2, 2, 2), 1.0, 0.0);

        assert_product_1234_5678(&c);
    }

    #[test]
    fn test_bgemm_contiguous_batched() {
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        contiguous_gemm(
            &mut c,
            &batched_lhs(),
            &batched_rhs(),
            &[2],
            (2, 2, 3),
            1.0,
            0.0,
        );

        // Batch 0.
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
        assert_eq!(c.get(&[0, 1, 0]), 28.0);
        assert_eq!(c.get(&[1, 0, 0]), 49.0);
        assert_eq!(c.get(&[1, 1, 0]), 64.0);

        // Batch 1.
        assert_eq!(c.get(&[0, 0, 1]), 220.0);
    }

    #[test]
    fn test_bgemm_contiguous_with_beta() {
        let mut c = mat_tens();

        contiguous_gemm(
            &mut c,
            &identity_2x2(),
            &mat_1234(),
            &[],
            (2, 2, 2),
            2.0,
            3.0,
        );

        assert_eq!(c.get(&[0, 0]), 32.0);
        assert_eq!(c.get(&[0, 1]), 64.0);
        assert_eq!(c.get(&[1, 0]), 96.0);
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    #[test]
    fn test_bgemm_contiguous_non_contiguous_input() {
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        contiguous_gemm(
            &mut c,
            &mat_1234_col_major(),
            &mat_5678(),
            &[],
            (2, 2, 2),
            1.0,
            0.0,
        );

        assert_product_1234_5678(&c);
    }
}