1use crate::contiguous::{alloc_col_major_uninit, ContiguousOperand, ContiguousOperandMut};
8use crate::util::{try_fuse_group, MultiIndex};
9use faer::linalg::matmul::matmul_with_conj;
10use faer::mat::{MatMut, MatRef};
11use faer::{Accum, Conj, Par};
12use faer_traits::ComplexField;
13use strided_view::{StridedArray, StridedView, StridedViewMut};
14
15pub fn bgemm_strided_into<T>(
21 c: &mut StridedViewMut<T>,
22 a: &StridedView<T>,
23 b: &StridedView<T>,
24 _n_batch: usize,
25 n_lo: usize,
26 n_ro: usize,
27 n_sum: usize,
28 alpha: T,
29 beta: T,
30 conj_a: bool,
31 conj_b: bool,
32) -> strided_view::Result<()>
33where
34 T: ComplexField
35 + Copy
36 + strided_view::ElementOpApply
37 + Send
38 + Sync
39 + std::ops::Mul<Output = T>
40 + std::ops::Add<Output = T>
41 + num_traits::Zero
42 + num_traits::One
43 + PartialEq,
44{
45 let a_dims = a.dims();
46 let b_dims = b.dims();
47 let a_strides = a.strides();
48 let b_strides = b.strides();
49 let c_strides = c.strides();
50
51 let lo_dims = &a_dims[..n_lo];
54 let sum_dims = &a_dims[n_lo..n_lo + n_sum];
55 let batch_dims = &a_dims[n_lo + n_sum..];
56 let ro_dims = &b_dims[n_sum..n_sum + n_ro];
57
58 let m: usize = lo_dims.iter().product::<usize>().max(1);
60 let k: usize = sum_dims.iter().product::<usize>().max(1);
61 let n: usize = ro_dims.iter().product::<usize>().max(1);
62
63 let a_lo_strides = &a_strides[..n_lo];
65 let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
66 let b_sum_strides = &b_strides[..n_sum];
67 let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
68 let c_lo_strides = &c_strides[..n_lo];
69 let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
70
71 let fused_a_lo = try_fuse_group(lo_dims, a_lo_strides);
73 let fused_a_sum = try_fuse_group(sum_dims, a_sum_strides);
74 let fused_b_sum = try_fuse_group(sum_dims, b_sum_strides);
75 let fused_b_ro = try_fuse_group(ro_dims, b_ro_strides);
76 let fused_c_lo = try_fuse_group(lo_dims, c_lo_strides);
77 let fused_c_ro = try_fuse_group(ro_dims, c_ro_strides);
78
79 let a_needs_copy = fused_a_lo.is_none() || fused_a_sum.is_none();
80 let b_needs_copy = fused_b_sum.is_none() || fused_b_ro.is_none();
81 let c_needs_copy = fused_c_lo.is_none() || fused_c_ro.is_none();
82
83 let n_a_inner = n_lo + n_sum;
84 let n_b_inner = n_sum + n_ro;
85 let n_c_inner = n_lo + n_ro;
86
87 let a_contig_buf: Option<StridedArray<T>>;
89 let (a_ptr, a_row_stride, a_col_stride);
90 if a_needs_copy {
91 let mut buf = alloc_col_major_uninit(a.dims());
92 strided_kernel::copy_into(&mut buf.view_mut(), a)?;
93 a_ptr = buf.view().ptr();
94 a_row_stride = if m == 0 { 0 } else { 1isize };
96 a_col_stride = m as isize;
97 a_contig_buf = Some(buf);
98 } else {
99 let (_, rs) = fused_a_lo.unwrap();
100 let (_, cs) = fused_a_sum.unwrap();
101 a_ptr = a.ptr();
102 a_row_stride = rs;
103 a_col_stride = cs;
104 a_contig_buf = None;
105 }
106 let a_batch_strides: &[isize] = match a_contig_buf.as_ref() {
107 Some(buf) => &buf.strides()[n_a_inner..],
108 None => &a_strides[n_a_inner..],
109 };
110
111 let b_contig_buf: Option<StridedArray<T>>;
113 let (b_ptr, b_row_stride, b_col_stride);
114 if b_needs_copy {
115 let mut buf = alloc_col_major_uninit(b.dims());
116 strided_kernel::copy_into(&mut buf.view_mut(), b)?;
117 b_ptr = buf.view().ptr();
118 b_row_stride = if k == 0 { 0 } else { 1isize };
120 b_col_stride = k as isize;
121 b_contig_buf = Some(buf);
122 } else {
123 let (_, rs) = fused_b_sum.unwrap();
124 let (_, cs) = fused_b_ro.unwrap();
125 b_ptr = b.ptr();
126 b_row_stride = rs;
127 b_col_stride = cs;
128 b_contig_buf = None;
129 }
130 let b_batch_strides: &[isize] = match b_contig_buf.as_ref() {
131 Some(buf) => &buf.strides()[n_b_inner..],
132 None => &b_strides[n_b_inner..],
133 };
134
135 let c_contig_buf: Option<StridedArray<T>>;
137 let (c_ptr, c_row_stride, c_col_stride);
138 if c_needs_copy {
139 let mut buf = alloc_col_major_uninit(c.dims());
140 if beta != T::zero() {
141 strided_kernel::copy_into(&mut buf.view_mut(), &c.as_view())?;
142 }
143 c_ptr = buf.view_mut().as_mut_ptr();
144 c_row_stride = if m == 0 { 0 } else { 1isize };
146 c_col_stride = m as isize;
147 c_contig_buf = Some(buf);
148 } else {
149 let (_, rs) = fused_c_lo.unwrap();
150 let (_, cs) = fused_c_ro.unwrap();
151 c_ptr = c.as_mut_ptr();
152 c_row_stride = rs;
153 c_col_stride = cs;
154 c_contig_buf = None;
155 }
156 let c_batch_strides: &[isize] = match c_contig_buf.as_ref() {
157 Some(buf) => &buf.strides()[n_c_inner..],
158 None => &c_strides[n_c_inner..],
159 };
160
161 let is_beta_zero = beta == T::zero();
162 let is_beta_one = beta == T::one();
163
164 let accum = if is_beta_zero {
166 Accum::Replace
167 } else {
168 Accum::Add
169 };
170
171 let cj_a = if conj_a { Conj::Yes } else { Conj::No };
172 let cj_b = if conj_b { Conj::Yes } else { Conj::No };
173
174 let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
176 if !is_beta_zero && !is_beta_one {
178 let c_base = unsafe { c_ptr.offset(c_batch_off) };
179 for i in 0..m {
180 for j in 0..n {
181 let offset = i as isize * c_row_stride + j as isize * c_col_stride;
182 unsafe {
183 let elem = c_base.offset(offset);
184 *elem = beta * *elem;
185 }
186 }
187 }
188 }
189
190 unsafe {
191 let a_mat: MatRef<'_, T> =
192 MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
193 let b_mat: MatRef<'_, T> =
194 MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
195 let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
196 c_ptr.offset(c_batch_off),
197 m,
198 n,
199 c_row_stride,
200 c_col_stride,
201 );
202
203 matmul_with_conj(c_mat, accum, a_mat, cj_a, b_mat, cj_b, alpha, Par::rayon(0));
204 }
205 };
206
207 let fused_a = try_fuse_group(batch_dims, a_batch_strides);
210 let fused_b = try_fuse_group(batch_dims, b_batch_strides);
211 let fused_c = try_fuse_group(batch_dims, c_batch_strides);
212
213 if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
214 (fused_a, fused_b, fused_c)
215 {
216 let mut a_off = 0isize;
217 let mut b_off = 0isize;
218 let mut c_off = 0isize;
219 for _ in 0..total {
220 do_batch(a_off, b_off, c_off);
221 a_off += a_step;
222 b_off += b_step;
223 c_off += c_step;
224 }
225 } else {
226 let mut batch_iter = MultiIndex::new(batch_dims);
227 while batch_iter.next().is_some() {
228 let a_batch_off = batch_iter.offset(a_batch_strides);
229 let b_batch_off = batch_iter.offset(b_batch_strides);
230 let c_batch_off = batch_iter.offset(c_batch_strides);
231 do_batch(a_batch_off, b_batch_off, c_batch_off);
232 }
233 }
234
235 if let Some(ref c_buf) = c_contig_buf {
237 strided_kernel::copy_into(c, &c_buf.view())?;
238 }
239
240 Ok(())
241}
242
243pub fn bgemm_contiguous_into<T>(
253 c: &mut ContiguousOperandMut<T>,
254 a: &ContiguousOperand<T>,
255 b: &ContiguousOperand<T>,
256 batch_dims: &[usize],
257 m: usize,
258 n: usize,
259 k: usize,
260 alpha: T,
261 beta: T,
262) -> strided_view::Result<()>
263where
264 T: ComplexField
265 + Copy
266 + strided_view::ElementOpApply
267 + Send
268 + Sync
269 + std::ops::Mul<Output = T>
270 + std::ops::Add<Output = T>
271 + num_traits::Zero
272 + num_traits::One
273 + PartialEq,
274{
275 let is_beta_zero = beta == T::zero();
276 let is_beta_one = beta == T::one();
277
278 let accum = if is_beta_zero {
279 Accum::Replace
280 } else {
281 Accum::Add
282 };
283
284 let a_batch_strides = a.batch_strides();
285 let b_batch_strides = b.batch_strides();
286 let c_batch_strides = c.batch_strides();
287
288 let a_ptr = a.ptr();
289 let b_ptr = b.ptr();
290 let c_ptr = c.ptr();
291 let a_row_stride = a.row_stride();
292 let a_col_stride = a.col_stride();
293 let b_row_stride = b.row_stride();
294 let b_col_stride = b.col_stride();
295 let c_row_stride = c.row_stride();
296 let c_col_stride = c.col_stride();
297
298 let conj_a = if a.conj() { Conj::Yes } else { Conj::No };
299 let conj_b = if b.conj() { Conj::Yes } else { Conj::No };
300
301 let do_batch = |a_batch_off: isize, b_batch_off: isize, c_batch_off: isize| {
303 if !is_beta_zero && !is_beta_one {
305 let c_base = unsafe { c_ptr.offset(c_batch_off) };
306 for i in 0..m {
307 for j in 0..n {
308 let offset = i as isize * c_row_stride + j as isize * c_col_stride;
309 unsafe {
310 let elem = c_base.offset(offset);
311 *elem = beta * *elem;
312 }
313 }
314 }
315 }
316
317 unsafe {
318 let a_mat: MatRef<'_, T> =
319 MatRef::from_raw_parts(a_ptr.offset(a_batch_off), m, k, a_row_stride, a_col_stride);
320 let b_mat: MatRef<'_, T> =
321 MatRef::from_raw_parts(b_ptr.offset(b_batch_off), k, n, b_row_stride, b_col_stride);
322 let c_mat: MatMut<'_, T> = MatMut::from_raw_parts_mut(
323 c_ptr.offset(c_batch_off),
324 m,
325 n,
326 c_row_stride,
327 c_col_stride,
328 );
329
330 matmul_with_conj(
331 c_mat,
332 accum,
333 a_mat,
334 conj_a,
335 b_mat,
336 conj_b,
337 alpha,
338 Par::rayon(0),
339 );
340 }
341 };
342
343 let fused_a = try_fuse_group(batch_dims, a_batch_strides);
346 let fused_b = try_fuse_group(batch_dims, b_batch_strides);
347 let fused_c = try_fuse_group(batch_dims, c_batch_strides);
348
349 if let (Some((total, a_step)), Some((_, b_step)), Some((_, c_step))) =
350 (fused_a, fused_b, fused_c)
351 {
352 let mut a_off = 0isize;
353 let mut b_off = 0isize;
354 let mut c_off = 0isize;
355 for _ in 0..total {
356 do_batch(a_off, b_off, c_off);
357 a_off += a_step;
358 b_off += b_step;
359 c_off += c_step;
360 }
361 } else {
362 let mut batch_iter = MultiIndex::new(batch_dims);
363 while batch_iter.next().is_some() {
364 let a_batch_off = batch_iter.offset(a_batch_strides);
365 let b_batch_off = batch_iter.offset(b_batch_strides);
366 let c_batch_off = batch_iter.offset(c_batch_strides);
367 do_batch(a_batch_off, b_batch_off, c_batch_off);
368 }
369 }
370
371 Ok(())
372}
373
374use crate::backend::{Backend, FaerBackend};
375
376impl<T> Backend<T> for FaerBackend
377where
378 T: crate::Scalar + ComplexField,
379{
380 const MATERIALIZES_CONJ: bool = false;
381 const REQUIRES_UNIT_STRIDE: bool = false;
382
383 fn bgemm_contiguous_into(
384 c: &mut ContiguousOperandMut<T>,
385 a: &ContiguousOperand<T>,
386 b: &ContiguousOperand<T>,
387 batch_dims: &[usize],
388 m: usize,
389 n: usize,
390 k: usize,
391 alpha: T,
392 beta: T,
393 ) -> strided_view::Result<()> {
394 bgemm_contiguous_into(c, a, b, batch_dims, m, n, k, alpha, beta)
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{ActiveBackend, Backend};
    use crate::contiguous::{prepare_input_view, prepare_output_view};
    use strided_view::StridedArray;

    /// Unit-stride requirement of the active backend, forwarded to the
    /// operand-preparation helpers.
    const US: bool = <ActiveBackend as Backend<f64>>::REQUIRES_UNIT_STRIDE;

    // Shared 2x2 fixtures, written in row-major reading order.
    const SEQ_1234: [[f64; 2]; 2] = [[1.0, 2.0], [3.0, 4.0]];
    const SEQ_5678: [[f64; 2]; 2] = [[5.0, 6.0], [7.0, 8.0]];
    const IDENTITY: [[f64; 2]; 2] = [[1.0, 0.0], [0.0, 1.0]];
    const TENS: [[f64; 2]; 2] = [[10.0, 20.0], [30.0, 40.0]];
    const HUNDREDS: [[f64; 2]; 2] = [[100.0, 200.0], [300.0, 400.0]];

    /// Runs `bgemm_strided_into` on owned `f64` arrays without conjugation.
    /// `groups` is `(n_batch, n_lo, n_ro, n_sum)`.
    fn run_strided(
        out: &mut StridedArray<f64>,
        lhs: &StridedArray<f64>,
        rhs: &StridedArray<f64>,
        groups: (usize, usize, usize, usize),
        alpha: f64,
        beta: f64,
    ) {
        let (n_batch, n_lo, n_ro, n_sum) = groups;
        bgemm_strided_into(
            &mut out.view_mut(),
            &lhs.view(),
            &rhs.view(),
            n_batch,
            n_lo,
            n_ro,
            n_sum,
            alpha,
            beta,
            false,
            false,
        )
        .unwrap();
    }

    /// Prepares the operands, runs `bgemm_contiguous_into`, and finalizes the
    /// result back into `out`. `shape` is `(m, n, k)`.
    fn run_contiguous(
        out: &mut StridedArray<f64>,
        lhs: &StridedArray<f64>,
        rhs: &StridedArray<f64>,
        batch_dims: &[usize],
        shape: (usize, usize, usize),
        alpha: f64,
        beta: f64,
    ) {
        let (m, n, k) = shape;
        let lhs_op = prepare_input_view(&lhs.view(), 1, 1, false, US, true, None).unwrap();
        let rhs_op = prepare_input_view(&rhs.view(), 1, 1, false, US, true, None).unwrap();
        let mut out_view = out.view_mut();
        let mut out_op = prepare_output_view(&mut out_view, 1, 1, beta, US, true).unwrap();

        bgemm_contiguous_into(&mut out_op, &lhs_op, &rhs_op, batch_dims, m, n, k, alpha, beta)
            .unwrap();

        out_op.finalize_into(&mut out_view).unwrap();
    }

    /// Plain 2x2 product with beta = 0.
    #[test]
    fn test_faer_bgemm_2x2() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_5678[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::row_major(&[2, 2]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 19.0);
        assert_eq!(out.get(&[0, 1]), 22.0);
        assert_eq!(out.get(&[1, 0]), 43.0);
        assert_eq!(out.get(&[1, 1]), 50.0);
    }

    /// Rectangular 2x3 by 3x4 product.
    #[test]
    fn test_faer_bgemm_rect() {
        let lhs =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let rhs =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
        let mut out = StridedArray::<f64>::row_major(&[2, 4]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 38.0);
        assert_eq!(out.get(&[1, 3]), 128.0);
    }

    /// Two batches; the batch index is the trailing dimension.
    #[test]
    fn test_faer_bgemm_batched() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let rhs = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut out = StridedArray::<f64>::row_major(&[2, 2, 2]);

        run_strided(&mut out, &lhs, &rhs, (1, 1, 1, 1), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0, 0]), 22.0);
    }

    /// beta = 0 must discard whatever was stored in C beforehand.
    #[test]
    fn test_faer_bgemm_beta_zero() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_5678[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| HUNDREDS[idx[0]][idx[1]]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 19.0);
        assert_eq!(out.get(&[1, 1]), 50.0);
    }

    /// beta = 1 adds the product onto the existing C.
    #[test]
    fn test_faer_bgemm_beta_one() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| IDENTITY[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| TENS[idx[0]][idx[1]]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 1.0, 1.0);

        assert_eq!(out.get(&[0, 0]), 11.0);
        assert_eq!(out.get(&[1, 1]), 44.0);
    }

    /// General alpha and beta: C = 2 * I * B + 3 * C.
    #[test]
    fn test_faer_bgemm_alpha_beta() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| IDENTITY[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| TENS[idx[0]][idx[1]]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 2.0, 3.0);

        assert_eq!(out.get(&[0, 0]), 32.0);
        assert_eq!(out.get(&[1, 1]), 128.0);
    }

    /// No summed dimension: the result is the outer product of two vectors.
    #[test]
    fn test_faer_bgemm_outer_product() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
        let mut out = StridedArray::<f64>::row_major(&[3, 4]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 0), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 1.0);
        assert_eq!(out.get(&[2, 3]), 12.0);
    }

    /// Same 2x2 product with single-precision elements.
    #[test]
    fn test_faer_bgemm_f32() {
        const LHS: [[f32; 2]; 2] = [[1.0f32, 2.0], [3.0, 4.0]];
        const RHS: [[f32; 2]; 2] = [[5.0f32, 6.0], [7.0, 8.0]];

        let lhs = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| LHS[idx[0]][idx[1]]);
        let rhs = StridedArray::<f32>::from_fn_row_major(&[2, 2], |idx| RHS[idx[0]][idx[1]]);
        let mut out = StridedArray::<f32>::row_major(&[2, 2]);

        let status = bgemm_strided_into(
            &mut out.view_mut(),
            &lhs.view(),
            &rhs.view(),
            0,
            1,
            1,
            1,
            1.0f32,
            0.0f32,
            false,
            false,
        );
        status.unwrap();

        assert_eq!(out.get(&[0, 0]), 19.0f32);
        assert_eq!(out.get(&[1, 1]), 50.0f32);
    }

    /// A is stored with strides [1, 2] (column-major) and holds [[1, 2], [3, 4]].
    #[test]
    fn test_faer_bgemm_col_major_input() {
        let lhs =
            StridedArray::<f64>::from_parts(vec![1.0, 3.0, 2.0, 4.0], &[2, 2], &[1, 2], 0).unwrap();
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_5678[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::row_major(&[2, 2]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 19.0);
        assert_eq!(out.get(&[0, 1]), 22.0);
        assert_eq!(out.get(&[1, 0]), 43.0);
        assert_eq!(out.get(&[1, 1]), 50.0);
    }

    /// Row-major inputs written into a column-major output.
    #[test]
    fn test_faer_bgemm_col_major_output() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_5678[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::col_major(&[2, 2]);

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 19.0);
        assert_eq!(out.get(&[0, 1]), 22.0);
        assert_eq!(out.get(&[1, 0]), 43.0);
        assert_eq!(out.get(&[1, 1]), 50.0);
    }

    /// Column-major C holding [[10, 20], [30, 40]], combined with alpha and beta.
    #[test]
    fn test_faer_bgemm_col_major_with_beta() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| IDENTITY[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let mut out =
            StridedArray::<f64>::from_parts(vec![10.0, 30.0, 20.0, 40.0], &[2, 2], &[1, 2], 0)
                .unwrap();

        run_strided(&mut out, &lhs, &rhs, (0, 1, 1, 1), 2.0, 3.0);

        assert_eq!(out.get(&[0, 0]), 32.0);
        assert_eq!(out.get(&[1, 1]), 128.0);
    }

    /// Prepared-operand path: plain 2x2 product.
    #[test]
    fn test_bgemm_contiguous_2x2() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_5678[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::row_major(&[2, 2]);

        run_contiguous(&mut out, &lhs, &rhs, &[], (2, 2, 2), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 19.0);
        assert_eq!(out.get(&[0, 1]), 22.0);
        assert_eq!(out.get(&[1, 0]), 43.0);
        assert_eq!(out.get(&[1, 1]), 50.0);
    }

    /// Prepared-operand path: two batches of a 2x3 by 3x2 product.
    #[test]
    fn test_bgemm_contiguous_batched() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let rhs = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut out = StridedArray::<f64>::row_major(&[2, 2, 2]);

        run_contiguous(&mut out, &lhs, &rhs, &[2], (2, 2, 3), 1.0, 0.0);

        // First batch.
        assert_eq!(out.get(&[0, 0, 0]), 22.0);
        assert_eq!(out.get(&[0, 1, 0]), 28.0);
        assert_eq!(out.get(&[1, 0, 0]), 49.0);
        assert_eq!(out.get(&[1, 1, 0]), 64.0);

        // Second batch.
        assert_eq!(out.get(&[0, 0, 1]), 220.0);
    }

    /// Prepared-operand path with general alpha and beta.
    #[test]
    fn test_bgemm_contiguous_with_beta() {
        let lhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| IDENTITY[idx[0]][idx[1]]);
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_1234[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| TENS[idx[0]][idx[1]]);

        run_contiguous(&mut out, &lhs, &rhs, &[], (2, 2, 2), 2.0, 3.0);

        assert_eq!(out.get(&[0, 0]), 32.0);
        assert_eq!(out.get(&[0, 1]), 64.0);
        assert_eq!(out.get(&[1, 0]), 96.0);
        assert_eq!(out.get(&[1, 1]), 128.0);
    }

    /// Prepared-operand path with A stored using strides [1, 2].
    #[test]
    fn test_bgemm_contiguous_non_contiguous_input() {
        let lhs =
            StridedArray::<f64>::from_parts(vec![1.0, 3.0, 2.0, 4.0], &[2, 2], &[1, 2], 0).unwrap();
        let rhs = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| SEQ_5678[idx[0]][idx[1]]);
        let mut out = StridedArray::<f64>::row_major(&[2, 2]);

        run_contiguous(&mut out, &lhs, &rhs, &[], (2, 2, 2), 1.0, 0.0);

        assert_eq!(out.get(&[0, 0]), 19.0);
        assert_eq!(out.get(&[0, 1]), 22.0);
        assert_eq!(out.get(&[1, 0]), 43.0);
        assert_eq!(out.get(&[1, 1]), 50.0);
    }
}