1use crate::backend::Backend;
9use crate::{contiguous, Scalar, ScalarBase};
10use strided_view::{Conj, ElementOp, ElementOpApply, RawStridedMut, RawStridedRef};
11
12#[allow(clippy::too_many_arguments)]
18pub fn bgemm_raw_strided_into<T>(
19 c: RawStridedMut<'_, T>,
20 a: RawStridedRef<'_, T>,
21 b: RawStridedRef<'_, T>,
22 n_batch: usize,
23 n_lo: usize,
24 n_ro: usize,
25 n_sum: usize,
26 alpha: T,
27 beta: T,
28 conj_a: bool,
29 conj_b: bool,
30) -> crate::Result<()>
31where
32 T: Scalar,
33 crate::backend::ActiveBackend: Backend<T>,
34{
35 validate_bgemm_shapes(&c, &a, &b, n_batch, n_lo, n_ro, n_sum)?;
36 unsafe {
37 bgemm_raw_strided_into_unchecked(
38 c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
39 )
40 }
41}
42
43#[allow(clippy::too_many_arguments)]
53pub unsafe fn bgemm_raw_strided_into_unchecked<T>(
54 c: RawStridedMut<'_, T>,
55 a: RawStridedRef<'_, T>,
56 b: RawStridedRef<'_, T>,
57 n_batch: usize,
58 n_lo: usize,
59 n_ro: usize,
60 n_sum: usize,
61 alpha: T,
62 beta: T,
63 conj_a: bool,
64 conj_b: bool,
65) -> crate::Result<()>
66where
67 T: Scalar,
68 crate::backend::ActiveBackend: Backend<T>,
69{
70 bgemm_raw_with_backend_into_unchecked::<T, crate::backend::ActiveBackend>(
71 c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
72 )
73}
74
75#[allow(clippy::too_many_arguments)]
82pub fn bgemm_raw_with_backend_into<T, B>(
83 c: RawStridedMut<'_, T>,
84 a: RawStridedRef<'_, T>,
85 b: RawStridedRef<'_, T>,
86 n_batch: usize,
87 n_lo: usize,
88 n_ro: usize,
89 n_sum: usize,
90 alpha: T,
91 beta: T,
92 conj_a: bool,
93 conj_b: bool,
94) -> crate::Result<()>
95where
96 T: ScalarBase + ElementOpApply,
97 B: Backend<T>,
98{
99 validate_bgemm_shapes(&c, &a, &b, n_batch, n_lo, n_ro, n_sum)?;
100 unsafe {
101 bgemm_raw_with_backend_into_unchecked::<T, B>(
102 c, a, b, n_batch, n_lo, n_ro, n_sum, alpha, beta, conj_a, conj_b,
103 )
104 }
105}
106
107#[allow(clippy::too_many_arguments)]
113pub unsafe fn bgemm_raw_with_backend_into_unchecked<T, B>(
114 mut c: RawStridedMut<'_, T>,
115 a: RawStridedRef<'_, T>,
116 b: RawStridedRef<'_, T>,
117 _n_batch: usize,
118 n_lo: usize,
119 n_ro: usize,
120 n_sum: usize,
121 alpha: T,
122 beta: T,
123 conj_a: bool,
124 conj_b: bool,
125) -> crate::Result<()>
126where
127 T: ScalarBase + ElementOpApply,
128 B: Backend<T>,
129{
130 let a_dims = a.dims();
131 let b_dims = b.dims();
132 let lo_dims = &a_dims[..n_lo];
133 let sum_dims = &a_dims[n_lo..n_lo + n_sum];
134 let batch_dims = &a_dims[n_lo + n_sum..];
135 let ro_dims = &b_dims[n_sum..n_sum + n_ro];
136
137 if c.dims().iter().any(|&dim| dim == 0) {
138 return Ok(());
139 }
140 if sum_dims.iter().any(|&dim| dim == 0) {
141 scale_or_zero_raw_mut(&mut c, beta);
142 return Ok(());
143 }
144
145 let use_pool = true;
146 let materialize = if B::MATERIALIZES_CONJ {
147 Some(Conj::apply as fn(T) -> T)
148 } else {
149 None
150 };
151
152 let a_op = contiguous::prepare_input_raw(
153 &a,
154 n_lo,
155 n_sum,
156 conj_a,
157 B::REQUIRES_UNIT_STRIDE,
158 use_pool,
159 materialize,
160 )?;
161 let b_op = contiguous::prepare_input_raw(
162 &b,
163 n_sum,
164 n_ro,
165 conj_b,
166 B::REQUIRES_UNIT_STRIDE,
167 use_pool,
168 materialize,
169 )?;
170 let mut c_op = contiguous::prepare_output_raw(
171 &mut c,
172 n_lo,
173 n_ro,
174 beta,
175 B::REQUIRES_UNIT_STRIDE,
176 use_pool,
177 )?;
178
179 let m: usize = lo_dims.iter().product::<usize>().max(1);
180 let k: usize = sum_dims.iter().product::<usize>().max(1);
181 let n: usize = ro_dims.iter().product::<usize>().max(1);
182
183 B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, batch_dims, m, n, k, alpha, beta)?;
184 c_op.finalize_raw_into(&mut c)?;
185
186 Ok(())
187}
188
189pub(crate) fn validate_bgemm_shapes<T>(
190 c: &RawStridedMut<'_, T>,
191 a: &RawStridedRef<'_, T>,
192 b: &RawStridedRef<'_, T>,
193 n_batch: usize,
194 n_lo: usize,
195 n_ro: usize,
196 n_sum: usize,
197) -> crate::Result<()> {
198 let a_rank = n_lo + n_sum + n_batch;
199 let b_rank = n_sum + n_ro + n_batch;
200 let c_rank = n_lo + n_ro + n_batch;
201 if a.dims().len() != a_rank {
202 return Err(strided_view::StridedError::RankMismatch(a_rank, a.dims().len()).into());
203 }
204 if b.dims().len() != b_rank {
205 return Err(strided_view::StridedError::RankMismatch(b_rank, b.dims().len()).into());
206 }
207 if c.dims().len() != c_rank {
208 return Err(strided_view::StridedError::RankMismatch(c_rank, c.dims().len()).into());
209 }
210
211 let lo_dims = &a.dims()[..n_lo];
212 let sum_dims = &a.dims()[n_lo..n_lo + n_sum];
213 let batch_dims = &a.dims()[n_lo + n_sum..];
214 let ro_dims = &b.dims()[n_sum..n_sum + n_ro];
215
216 if &b.dims()[..n_sum] != sum_dims {
217 return Err(strided_view::StridedError::ShapeMismatch(
218 sum_dims.to_vec(),
219 b.dims()[..n_sum].to_vec(),
220 )
221 .into());
222 }
223 if &b.dims()[n_sum + n_ro..] != batch_dims {
224 return Err(strided_view::StridedError::ShapeMismatch(
225 batch_dims.to_vec(),
226 b.dims()[n_sum + n_ro..].to_vec(),
227 )
228 .into());
229 }
230 if &c.dims()[..n_lo] != lo_dims {
231 return Err(strided_view::StridedError::ShapeMismatch(
232 lo_dims.to_vec(),
233 c.dims()[..n_lo].to_vec(),
234 )
235 .into());
236 }
237 if &c.dims()[n_lo..n_lo + n_ro] != ro_dims {
238 return Err(strided_view::StridedError::ShapeMismatch(
239 ro_dims.to_vec(),
240 c.dims()[n_lo..n_lo + n_ro].to_vec(),
241 )
242 .into());
243 }
244 if &c.dims()[n_lo + n_ro..] != batch_dims {
245 return Err(strided_view::StridedError::ShapeMismatch(
246 batch_dims.to_vec(),
247 c.dims()[n_lo + n_ro..].to_vec(),
248 )
249 .into());
250 }
251 Ok(())
252}
253
254pub(crate) fn scale_or_zero_raw_mut<T: ScalarBase>(c: &mut RawStridedMut<'_, T>, beta: T) {
255 if c.dims().iter().any(|&dim| dim == 0) {
256 return;
257 }
258
259 fn visit<T: ScalarBase>(
260 ptr: *mut T,
261 dims: &[usize],
262 strides: &[isize],
263 axis: usize,
264 offset: isize,
265 beta: T,
266 zero: T,
267 ) {
268 if axis == dims.len() {
269 unsafe {
270 let dst = ptr.offset(offset);
271 if beta == zero {
272 *dst = zero;
273 } else {
274 *dst = beta * *dst;
275 }
276 }
277 return;
278 }
279
280 for i in 0..dims[axis] {
281 visit(
282 ptr,
283 dims,
284 strides,
285 axis + 1,
286 offset + i as isize * strides[axis],
287 beta,
288 zero,
289 );
290 }
291 }
292
293 visit(c.as_mut_ptr(), c.dims(), c.strides(), 0, 0, beta, T::zero());
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a plain 2x2 GEMM (`alpha = one`, `beta = zero`) through the active
    /// backend and returns the row-major output buffer.
    fn raw_bgemm_2x2<T>(one: T, zero: T) -> Vec<T>
    where
        T: Scalar + From<f32>,
        crate::backend::ActiveBackend: Backend<T>,
    {
        let dims = [2, 2];
        let strides = [2, 1];
        let a_data = [T::from(1.0), T::from(2.0), T::from(3.0), T::from(4.0)];
        let b_data = [T::from(5.0), T::from(6.0), T::from(7.0), T::from(8.0)];
        let mut c_data = vec![zero; 4];
        let a = RawStridedRef::new(&a_data, &dims, &strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &dims, &strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &dims, &strides, 0).unwrap();
        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, one, zero, false, false).unwrap();
        c_data
    }

    #[test]
    fn raw_bgemm_active_backend_f64() {
        assert_eq!(raw_bgemm_2x2(1.0f64, 0.0), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn raw_bgemm_active_backend_f32() {
        assert_eq!(raw_bgemm_2x2(1.0f32, 0.0), vec![19.0f32, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn raw_bgemm_active_backend_complex_conj() {
        use num_complex::Complex64;

        let i = Complex64::i();
        let dims = [2, 2];
        let strides = [2, 1];
        let a_data = [
            Complex64::new(1.0, 0.0) + i,
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0) - i,
        ];
        let b_data = [
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
        ];
        let mut c_data = vec![Complex64::new(0.0, 0.0); 4];
        let a = RawStridedRef::new(&a_data, &dims, &strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &dims, &strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &dims, &strides, 0).unwrap();
        bgemm_raw_strided_into(
            c,
            a,
            b,
            0,
            1,
            1,
            1,
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            true,
            false,
        )
        .unwrap();
        assert_eq!(
            c_data,
            vec![
                Complex64::new(1.0, -1.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(4.0, 1.0),
            ]
        );
    }

    #[test]
    fn raw_bgemm_active_backend_checked_shape_mismatch() {
        let a_dims = [2, 2];
        let b_dims = [3, 2];
        let c_dims = [2, 2];
        let a_strides = [2, 1];
        let b_strides = [2, 1];
        let c_strides = [2, 1];
        let a_data = [1.0, 2.0, 3.0, 4.0];
        let b_data = [0.0; 6];
        let mut c_data = [0.0; 4];
        let a = RawStridedRef::new(&a_data, &a_dims, &a_strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &b_dims, &b_strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &c_dims, &c_strides, 0).unwrap();
        let err = bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false).unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::Strided(strided_view::StridedError::ShapeMismatch(_, _))
        ));
    }

    #[test]
    fn raw_bgemm_explicit_backend_checked_rank_mismatch() {
        let a_dims = [2, 2];
        let b_dims = [2, 2];
        let c_dims = [2];
        let strides = [2, 1];
        let c_strides = [1];
        let a_data = [1.0, 2.0, 3.0, 4.0];
        let b_data = [5.0, 6.0, 7.0, 8.0];
        let mut c_data = [0.0; 2];
        let a = RawStridedRef::new(&a_data, &a_dims, &strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &b_dims, &strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &c_dims, &c_strides, 0).unwrap();
        let err = bgemm_raw_with_backend_into::<f64, crate::backend::ActiveBackend>(
            c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false,
        )
        .unwrap_err();
        assert!(matches!(
            err,
            crate::EinsumError::Strided(strided_view::StridedError::RankMismatch(2, 1))
        ));
    }

    #[test]
    fn raw_bgemm_zero_sum_scales_destination() {
        let a_dims = [2, 0];
        let b_dims = [0, 2];
        let c_dims = [2, 2];
        let a_strides = [0, 0];
        let b_strides = [0, 0];
        let c_strides = [2, 1];
        let a_data = [0.0; 1];
        let b_data = [0.0; 1];
        let mut c_data = [1.0, 2.0, 3.0, 4.0];
        let a = RawStridedRef::new(&a_data, &a_dims, &a_strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &b_dims, &b_strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &c_dims, &c_strides, 0).unwrap();

        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 2.0, false, false).unwrap();

        assert_eq!(c_data, [2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn raw_bgemm_zero_sum_beta_zero_clears_destination() {
        let a_dims = [2, 0];
        let b_dims = [0, 2];
        let c_dims = [2, 2];
        let a_strides = [0, 0];
        let b_strides = [0, 0];
        let c_strides = [2, 1];
        let a_data = [0.0; 1];
        let b_data = [0.0; 1];
        let mut c_data = [1.0, 2.0, 3.0, 4.0];
        let a = RawStridedRef::new(&a_data, &a_dims, &a_strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &b_dims, &b_strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &c_dims, &c_strides, 0).unwrap();

        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false).unwrap();

        assert_eq!(c_data, [0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn raw_bgemm_empty_output_is_noop() {
        let a_dims = [0, 2];
        let b_dims = [2, 2];
        let c_dims = [0, 2];
        let a_strides = [2, 1];
        let b_strides = [2, 1];
        let c_strides = [2, 1];
        let a_data = [1.0, 2.0];
        let b_data = [3.0, 4.0, 5.0, 6.0];
        let mut c_data = [7.0, 8.0, 9.0, 10.0];
        let expected = c_data;
        let a = RawStridedRef::new(&a_data, &a_dims, &a_strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &b_dims, &b_strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &c_dims, &c_strides, 0).unwrap();

        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 1.0, false, false).unwrap();

        assert_eq!(c_data, expected);
    }

    #[test]
    fn raw_bgemm_noncontiguous_output_writes_back() {
        let a_dims = [2, 2];
        let b_dims = [2, 2];
        let c_dims = [2, 2];
        let a_strides = [2, 1];
        let b_strides = [2, 1];
        let c_strides = [1, 3];
        let a_data = [1.0, 2.0, 3.0, 4.0];
        let b_data = [5.0, 6.0, 7.0, 8.0];
        let mut c_data = [0.0; 8];
        let a = RawStridedRef::new(&a_data, &a_dims, &a_strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &b_dims, &b_strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &c_dims, &c_strides, 1).unwrap();

        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 1.0, 0.0, false, false).unwrap();

        assert_eq!(c_data[1], 19.0);
        assert_eq!(c_data[4], 22.0);
        assert_eq!(c_data[2], 43.0);
        assert_eq!(c_data[5], 50.0);
        assert_eq!(c_data[0], 0.0);
        assert_eq!(c_data[3], 0.0);
    }

    #[test]
    fn raw_bgemm_alpha_beta_accumulates_into_destination() {
        let dims = [2, 2];
        let strides = [2, 1];
        let a_data = [1.0, 2.0, 3.0, 4.0];
        let b_data = [5.0, 6.0, 7.0, 8.0];
        let mut c_data = [1.0, 2.0, 3.0, 4.0];
        let a = RawStridedRef::new(&a_data, &dims, &strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &dims, &strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &dims, &strides, 0).unwrap();

        // c = 2 * A·B + 3 * c, with A·B = [[19, 22], [43, 50]].
        bgemm_raw_strided_into(c, a, b, 0, 1, 1, 1, 2.0, 3.0, false, false).unwrap();

        assert_eq!(c_data, [41.0, 50.0, 95.0, 112.0]);
    }

    #[test]
    fn raw_bgemm_single_batch_axis() {
        // Layouts [lo, sum, batch] / [sum, ro, batch] / [lo, ro, batch], with the
        // batch axis outermost in memory: each batch is a row-major 2x2 block.
        let dims = [2, 2, 2];
        let strides = [2, 1, 4];
        // Batch 0: A = [[1, 2], [3, 4]]; batch 1: A = identity.
        let a_data = [1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0];
        // Batch 0: B = [[5, 6], [7, 8]]; batch 1: B = [[2, 3], [4, 5]].
        let b_data = [5.0, 6.0, 7.0, 8.0, 2.0, 3.0, 4.0, 5.0];
        let mut c_data = [0.0; 8];
        let a = RawStridedRef::new(&a_data, &dims, &strides, 0).unwrap();
        let b = RawStridedRef::new(&b_data, &dims, &strides, 0).unwrap();
        let c = RawStridedMut::new(&mut c_data, &dims, &strides, 0).unwrap();

        bgemm_raw_strided_into(c, a, b, 1, 1, 1, 1, 1.0, 0.0, false, false).unwrap();

        assert_eq!(c_data, [19.0, 22.0, 43.0, 50.0, 2.0, 3.0, 4.0, 5.0]);
    }
}