1use crate::contiguous::{alloc_col_major_uninit, ContiguousOperand, ContiguousOperandMut};
8use crate::util::{try_fuse_group, MultiIndex};
9use faer::linalg::matmul::matmul_with_conj;
10use faer::mat::{MatMut, MatRef};
11use faer::{Accum, Conj, Par};
12use faer_traits::ComplexField;
13use strided_view::{StridedArray, StridedView, StridedViewMut};
14
15pub fn bgemm_strided_into<T>(
21 c: &mut StridedViewMut<T>,
22 a: &StridedView<T>,
23 b: &StridedView<T>,
24 _n_batch: usize,
25 n_lo: usize,
26 n_ro: usize,
27 n_sum: usize,
28 alpha: T,
29 beta: T,
30 conj_a: bool,
31 conj_b: bool,
32) -> strided_view::Result<()>
33where
34 T: ComplexField
35 + Copy
36 + strided_view::ElementOpApply
37 + Send
38 + Sync
39 + std::ops::Mul<Output = T>
40 + std::ops::Add<Output = T>
41 + num_traits::Zero
42 + num_traits::One
43 + PartialEq,
44{
45 let a_dims = a.dims();
46 let b_dims = b.dims();
47 let a_strides = a.strides();
48 let b_strides = b.strides();
49 let c_strides = c.strides();
50
51 let lo_dims = &a_dims[..n_lo];
54 let sum_dims = &a_dims[n_lo..n_lo + n_sum];
55 let batch_dims = &a_dims[n_lo + n_sum..];
56 let ro_dims = &b_dims[n_sum..n_sum + n_ro];
57
58 let m: usize = lo_dims.iter().product::<usize>().max(1);
60 let k: usize = sum_dims.iter().product::<usize>().max(1);
61 let n: usize = ro_dims.iter().product::<usize>().max(1);
62
63 let a_lo_strides = &a_strides[..n_lo];
65 let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
66 let b_sum_strides = &b_strides[..n_sum];
67 let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
68 let c_lo_strides = &c_strides[..n_lo];
69 let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
70
71 let fused_a_lo = try_fuse_group(lo_dims, a_lo_strides);
73 let fused_a_sum = try_fuse_group(sum_dims, a_sum_strides);
74 let fused_b_sum = try_fuse_group(sum_dims, b_sum_strides);
75 let fused_b_ro = try_fuse_group(ro_dims, b_ro_strides);
76 let fused_c_lo = try_fuse_group(lo_dims, c_lo_strides);
77 let fused_c_ro = try_fuse_group(ro_dims, c_ro_strides);
78
79 let a_needs_copy = fused_a_lo.is_none() || fused_a_sum.is_none();
80 let b_needs_copy = fused_b_sum.is_none() || fused_b_ro.is_none();
81 let c_needs_copy = fused_c_lo.is_none() || fused_c_ro.is_none();
82
83 let n_a_inner = n_lo + n_sum;
84 let n_b_inner = n_sum + n_ro;
85 let n_c_inner = n_lo + n_ro;
86
87 let a_contig_buf: Option<StridedArray<T>>;
89 let (a_ptr, a_row_stride, a_col_stride);
90 if a_needs_copy {
91 let mut buf = alloc_col_major_uninit(a.dims());
92 strided_kernel::copy_into(&mut buf.view_mut(), a)?;
93 a_ptr = buf.view().ptr();
94 a_row_stride = if m == 0 { 0 } else { 1isize };
96 a_col_stride = m as isize;
97 a_contig_buf = Some(buf);
98 } else {
99 let (_, rs) = fused_a_lo.unwrap();
100 let (_, cs) = fused_a_sum.unwrap();
101 a_ptr = a.ptr();
102 a_row_stride = rs;
103 a_col_stride = cs;
104 a_contig_buf = None;
105 }
106 let a_batch_strides: &[isize] = match a_contig_buf.as_ref() {
107 Some(buf) => &buf.strides()[n_a_inner..],
108 None => &a_strides[n_a_inner..],
109 };
110
111 let b_contig_buf: Option<StridedArray<T>>;
113 let (b_ptr, b_row_stride, b_col_stride);
114 if b_needs_copy {
115 let mut buf = alloc_col_major_uninit(b.dims());
116 strided_kernel::copy_into(&mut buf.view_mut(), b)?;
117 b_ptr = buf.view().ptr();
118 b_row_stride = if k == 0 { 0 } else { 1isize };
120 b_col_stride = k as isize;
121 b_contig_buf = Some(buf);
122 } else {
123 let (_, rs) = fused_b_sum.unwrap();
124 let (_, cs) = fused_b_ro.unwrap();
125 b_ptr = b.ptr();
126 b_row_stride = rs;
127 b_col_stride = cs;
128 b_contig_buf = None;
129 }
130 let b_batch_strides: &[isize] = match b_contig_buf.as_ref() {
131 Some(buf) => &buf.strides()[n_b_inner..],
132 None => &b_strides[n_b_inner..],
133 };
134
135 let c_contig_buf: Option<StridedArray<T>>;
137 let (c_ptr, c_row_stride, c_col_stride);
138 if c_needs_copy {
139 let mut buf = alloc_col_major_uninit(c.dims());
140 if beta != T::zero() {
141 strided_kernel::copy_into(&mut buf.view_mut(), &c.as_view())?;
142 }
143 c_ptr = buf.view_mut().as_mut_ptr();
144 c_row_stride = if m == 0 { 0 } else { 1isize };
146 c_col_stride = m as isize;
147 c_contig_buf = Some(buf);
148 } else {
149 let (_, rs) = fused_c_lo.unwrap();
150 let (_, cs) = fused_c_ro.unwrap();
151 c_ptr = c.as_mut_ptr();
152 c_row_stride = rs;
153 c_col_stride = cs;
154 c_contig_buf = None;
155 }
156 let c_batch_strides: &[isize] = match c_contig_buf.as_ref() {
157 Some(buf) => &buf.strides()[n_c_inner..],
158 None => &c_strides[n_c_inner..],
159 };
160
161 let is_beta_zero = beta == T::zero();
162 let is_beta_one = beta == T::one();
163
164 let accum = if is_beta_zero {
166 Accum::Replace
167 } else {
168 Accum::Add
169 };
170
171 let cj_a = if conj_a { Conj::Yes } else { Conj::No };
172 let cj_b = if conj_b { Conj::Yes } else { Conj::No };
173
174 let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
176 if !is_beta_zero && !is_beta_one {
178 let c_base = unsafe { c_ptr.offset(c_batch_off) };
179 for i in 0..m {
180 for j in 0..n {
181 let offset = i as isize * c_row_stride + j as isize * c_col_stride;
182 unsafe {
183 let elem = c_base.offset(offset);
184 *elem = beta * *elem;
185 }
186 }
187 }
188 }
189
190 unsafe {
191 let a_mat: MatRef<'_, T> =
192 MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
193 let b_mat: MatRef<'_, T> =
194 MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
195 let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
196 c_ptr.offset(c_batch_off),
197 m,
198 n,
199 c_row_stride,
200 c_col_stride,
201 );
202
203 matmul_with_conj(c_mat, accum, a_mat, cj_a, b_mat, cj_b, alpha, Par::rayon(0));
204 }
205 };
206
207 let fused_a = try_fuse_group(batch_dims, a_batch_strides);
210 let fused_b = try_fuse_group(batch_dims, b_batch_strides);
211 let fused_c = try_fuse_group(batch_dims, c_batch_strides);
212
213 if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
214 (fused_a, fused_b, fused_c)
215 {
216 let mut a_off = 0isize;
217 let mut b_off = 0isize;
218 let mut c_off = 0isize;
219 for _ in 0..total {
220 do_batch(a_off, b_off, c_off);
221 a_off += a_step;
222 b_off += b_step;
223 c_off += c_step;
224 }
225 } else {
226 let mut batch_iter = MultiIndex::new(batch_dims);
227 while batch_iter.next().is_some() {
228 let a_batch_off = batch_iter.offset(a_batch_strides);
229 let b_batch_off = batch_iter.offset(b_batch_strides);
230 let c_batch_off = batch_iter.offset(c_batch_strides);
231 do_batch(a_batch_off, b_batch_off, c_batch_off);
232 }
233 }
234
235 if let Some(ref c_buf) = c_contig_buf {
237 strided_kernel::copy_into(c, &c_buf.view())?;
238 }
239
240 Ok(())
241}
242
243pub fn bgemm_contiguous_into<T>(
253 c: &mut ContiguousOperandMut<T>,
254 a: &ContiguousOperand<T>,
255 b: &ContiguousOperand<T>,
256 batch_dims: &[usize],
257 m: usize,
258 n: usize,
259 k: usize,
260 alpha: T,
261 beta: T,
262) -> strided_view::Result<()>
263where
264 T: ComplexField
265 + Copy
266 + strided_view::ElementOpApply
267 + Send
268 + Sync
269 + std::ops::Mul<Output = T>
270 + std::ops::Add<Output = T>
271 + num_traits::Zero
272 + num_traits::One
273 + PartialEq,
274{
275 let is_beta_zero = beta == T::zero();
276 let is_beta_one = beta == T::one();
277
278 let accum = if is_beta_zero {
279 Accum::Replace
280 } else {
281 Accum::Add
282 };
283
284 let a_batch_strides = a.batch_strides();
285 let b_batch_strides = b.batch_strides();
286 let c_batch_strides = c.batch_strides();
287
288 let a_ptr = a.ptr();
289 let b_ptr = b.ptr();
290 let c_ptr = c.ptr();
291 let a_row_stride = a.row_stride();
292 let a_col_stride = a.col_stride();
293 let b_row_stride = b.row_stride();
294 let b_col_stride = b.col_stride();
295 let c_row_stride = c.row_stride();
296 let c_col_stride = c.col_stride();
297
298 let conj_a = if a.conj() { Conj::Yes } else { Conj::No };
299 let conj_b = if b.conj() { Conj::Yes } else { Conj::No };
300
301 let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
303 if !is_beta_zero && !is_beta_one {
305 let c_base = unsafe { c_ptr.offset(c_batch_off) };
306 for i in 0..m {
307 for j in 0..n {
308 let offset = i as isize * c_row_stride + j as isize * c_col_stride;
309 unsafe {
310 let elem = c_base.offset(offset);
311 *elem = beta * *elem;
312 }
313 }
314 }
315 }
316
317 unsafe {
318 let a_mat: MatRef<'_, T> =
319 MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
320 let b_mat: MatRef<'_, T> =
321 MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
322 let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
323 c_ptr.offset(c_batch_off),
324 m,
325 n,
326 c_row_stride,
327 c_col_stride,
328 );
329
330 matmul_with_conj(
331 c_mat,
332 accum,
333 a_mat,
334 conj_a,
335 b_mat,
336 conj_b,
337 alpha,
338 Par::rayon(0),
339 );
340 }
341 };
342
343 let fused_a = try_fuse_group(batch_dims, a_batch_strides);
346 let fused_b = try_fuse_group(batch_dims, b_batch_strides);
347 let fused_c = try_fuse_group(batch_dims, c_batch_strides);
348
349 if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
350 (fused_a, fused_b, fused_c)
351 {
352 let mut a_off = 0isize;
353 let mut b_off = 0isize;
354 let mut c_off = 0isize;
355 for _ in 0..total {
356 do_batch(a_off, b_off, c_off);
357 a_off += a_step;
358 b_off += b_step;
359 c_off += c_step;
360 }
361 } else {
362 let mut batch_iter = MultiIndex::new(batch_dims);
363 while batch_iter.next().is_some() {
364 let a_batch_off = batch_iter.offset(a_batch_strides);
365 let b_batch_off = batch_iter.offset(b_batch_strides);
366 let c_batch_off = batch_iter.offset(c_batch_strides);
367 do_batch(a_batch_off, b_batch_off, c_batch_off);
368 }
369 }
370
371 Ok(())
372}
373
374use crate::backend::{BgemmBackend, FaerBackend};
375
376impl<T> BgemmBackend<T> for FaerBackend
377where
378 T: crate::Scalar + ComplexField,
379{
380 fn bgemm_contiguous_into(
381 c: &mut ContiguousOperandMut<T>,
382 a: &ContiguousOperand<T>,
383 b: &ContiguousOperand<T>,
384 batch_dims: &[usize],
385 m: usize,
386 n: usize,
387 k: usize,
388 alpha: T,
389 beta: T,
390 ) -> strided_view::Result<()> {
391 bgemm_contiguous_into(c, a, b, batch_dims, m, n, k, alpha, beta)
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contiguous::{prepare_input_view, prepare_output_view};
    use strided_view::StridedArray;

    // ----------------------------------------------------------------------------
    // bgemm_strided_into
    // ----------------------------------------------------------------------------

    /// Plain 2x2 product: [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]].
    #[test]
    fn test_faer_bgemm_2x2() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// Rectangular product (2x3) * (3x4); every output element is checked.
    #[test]
    fn test_faer_bgemm_rect() {
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let b =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[2, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // a = [[1,2,3],[4,5,6]], b = [[1,2,3,4],[5,6,7,8],[9,10,11,12]].
        assert_eq!(c.get(&[0, 0]), 38.0);
        assert_eq!(c.get(&[0, 1]), 44.0);
        assert_eq!(c.get(&[0, 2]), 50.0);
        assert_eq!(c.get(&[0, 3]), 56.0);
        assert_eq!(c.get(&[1, 0]), 83.0);
        assert_eq!(c.get(&[1, 1]), 98.0);
        assert_eq!(c.get(&[1, 2]), 113.0);
        assert_eq!(c.get(&[1, 3]), 128.0);
    }

    /// One trailing batch axis of length 2; both batches are checked in full.
    #[test]
    fn test_faer_bgemm_batched() {
        // a[i, k, batch], b[k, j, batch], c[i, j, batch].
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            1,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // Batch 0: A = [[1,2,3],[4,5,6]], B = [[1,2],[3,4],[5,6]].
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
        assert_eq!(c.get(&[0, 1, 0]), 28.0);
        assert_eq!(c.get(&[1, 0, 0]), 49.0);
        assert_eq!(c.get(&[1, 1, 0]), 64.0);
        // Batch 1: A = [[7,8,9],[10,11,12]], B = [[7,8],[9,10],[11,12]].
        assert_eq!(c.get(&[0, 0, 1]), 220.0);
        assert_eq!(c.get(&[0, 1, 1]), 244.0);
        assert_eq!(c.get(&[1, 0, 1]), 301.0);
        assert_eq!(c.get(&[1, 1, 1]), 334.0);
    }

    /// With beta = 0 the previous contents of C must be discarded, not accumulated.
    #[test]
    fn test_faer_bgemm_beta_zero() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[100.0, 200.0], [300.0, 400.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// With beta = 1: C = I * B + C.
    #[test]
    fn test_faer_bgemm_beta_one() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            1.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 11.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 33.0);
        assert_eq!(c.get(&[1, 1]), 44.0);
    }

    /// General alpha/beta: C = 2 * I * B + 3 * C.
    #[test]
    fn test_faer_bgemm_alpha_beta() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 32.0);
        assert_eq!(c.get(&[0, 1]), 64.0);
        assert_eq!(c.get(&[1, 0]), 96.0);
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    /// No contracted axes (n_sum = 0): the result is the outer product a ⊗ b.
    #[test]
    fn test_faer_bgemm_outer_product() {
        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[3, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            0,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        for i in 0..3 {
            for j in 0..4 {
                assert_eq!(c.get(&[i, j]), ((i + 1) * (j + 1)) as f64);
            }
        }
    }

    /// Same as the 2x2 case, but in single precision.
    #[test]
    fn test_faer_bgemm_f32() {
        let a = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0f32, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0f32, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f32>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0f32,
            0.0f32,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0f32);
        assert_eq!(c.get(&[0, 1]), 22.0f32);
        assert_eq!(c.get(&[1, 0]), 43.0f32);
        assert_eq!(c.get(&[1, 1]), 50.0f32);
    }

    /// A column-major `a` ([[1,2],[3,4]] stored as [1,3,2,4]) gives the same result.
    #[test]
    fn test_faer_bgemm_col_major_input() {
        let a_data = vec![1.0, 3.0, 2.0, 4.0];
        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();

        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// A column-major output view is written correctly.
    #[test]
    fn test_faer_bgemm_col_major_output() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::col_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// Column-major C ([[10,20],[30,40]] stored as [10,30,20,40]) with alpha = 2, beta = 3.
    #[test]
    fn test_faer_bgemm_col_major_with_beta() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let c_data = vec![10.0, 30.0, 20.0, 40.0];
        let mut c = StridedArray::<f64>::from_parts(c_data, &[2, 2], &[1, 2], 0).unwrap();

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 32.0);
        assert_eq!(c.get(&[0, 1]), 64.0);
        assert_eq!(c.get(&[1, 0]), 96.0);
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    // ----------------------------------------------------------------------------
    // bgemm_contiguous_into (operands prepared via prepare_input_view / prepare_output_view)
    // ----------------------------------------------------------------------------

    /// Plain 2x2 product through the prepared-operand path.
    #[test]
    fn test_bgemm_contiguous_2x2() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        let a_op = prepare_input_view(&a.view(), 0, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 0, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 0, 1, 1, 0.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// One batch axis of length 2 through the prepared-operand path.
    #[test]
    fn test_bgemm_contiguous_batched() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        let a_op = prepare_input_view(&a.view(), 1, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 1, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 1, 1, 1, 0.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[2], 2, 2, 3, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        // Batch 0.
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
        assert_eq!(c.get(&[0, 1, 0]), 28.0);
        assert_eq!(c.get(&[1, 0, 0]), 49.0);
        assert_eq!(c.get(&[1, 1, 0]), 64.0);
        // Batch 1.
        assert_eq!(c.get(&[0, 0, 1]), 220.0);
        assert_eq!(c.get(&[0, 1, 1]), 244.0);
        assert_eq!(c.get(&[1, 0, 1]), 301.0);
        assert_eq!(c.get(&[1, 1, 1]), 334.0);
    }

    /// General alpha/beta (C = 2 * I * B + 3 * C) through the prepared-operand path.
    #[test]
    fn test_bgemm_contiguous_with_beta() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        let a_op = prepare_input_view(&a.view(), 0, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 0, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 0, 1, 1, 3.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 2.0, 3.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        assert_eq!(c.get(&[0, 0]), 32.0);
        assert_eq!(c.get(&[0, 1]), 64.0);
        assert_eq!(c.get(&[1, 0]), 96.0);
        assert_eq!(c.get(&[1, 1]), 128.0);
    }

    /// A column-major `a` goes through prepare_input_view and still gives the right product.
    #[test]
    fn test_bgemm_contiguous_non_contiguous_input() {
        let a_data = vec![1.0, 3.0, 2.0, 4.0];
        let a = StridedArray::<f64>::from_parts(a_data, &[2, 2], &[1, 2], 0).unwrap();

        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        let a_op = prepare_input_view(&a.view(), 0, 1, 1, false).unwrap();
        let b_op = prepare_input_view(&b.view(), 0, 1, 1, false).unwrap();
        let mut c_view = c.view_mut();
        let mut c_op = prepare_output_view(&mut c_view, 0, 1, 1, 0.0).unwrap();

        bgemm_contiguous_into(&mut c_op, &a_op, &b_op, &[], 2, 2, 2, 1.0, 0.0).unwrap();

        c_op.finalize_into(&mut c_view).unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }
}