1use crate::util::MultiIndex;
9use strided_view::{ElementOp, ElementOpApply, StridedView, StridedViewMut};
10
11pub fn bgemm_strided_into<T>(
21 c: &mut StridedViewMut<T>,
22 a: &StridedView<T>,
23 b: &StridedView<T>,
24 _n_batch: usize,
25 n_lo: usize,
26 n_ro: usize,
27 n_sum: usize,
28 alpha: T,
29 beta: T,
30 conj_a: bool,
31 conj_b: bool,
32) -> strided_view::Result<()>
33where
34 T: Copy
35 + ElementOpApply
36 + std::ops::Mul<Output = T>
37 + std::ops::Add<Output = T>
38 + num_traits::Zero
39 + num_traits::One
40 + PartialEq,
41{
42 let a_dims = a.dims();
43 let b_dims = b.dims();
44 let c_dims = c.dims();
45 let a_strides = a.strides();
46 let b_strides = b.strides();
47 let c_strides = c.strides();
48
49 let lo_dims = &a_dims[..n_lo];
51 let sum_dims = &a_dims[n_lo..n_lo + n_sum];
52 let batch_dims = &a_dims[n_lo + n_sum..];
53 let ro_dims = &b_dims[n_sum..n_sum + n_ro];
54
55 let a_lo_strides = &a_strides[..n_lo];
57 let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
58 let a_batch_strides = &a_strides[n_lo + n_sum..];
59
60 let b_sum_strides = &b_strides[..n_sum];
61 let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
62 let b_batch_strides = &b_strides[n_sum + n_ro..];
63
64 let c_lo_strides = &c_strides[..n_lo];
65 let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
66 let c_batch_strides = &c_strides[n_lo + n_ro..];
67
68 debug_assert_eq!(&c_dims[..n_lo], lo_dims);
70 debug_assert_eq!(&c_dims[n_lo..n_lo + n_ro], ro_dims);
71 debug_assert_eq!(&c_dims[n_lo + n_ro..], batch_dims);
72 debug_assert_eq!(&b_dims[..n_sum], sum_dims);
73 debug_assert_eq!(&b_dims[n_sum + n_ro..], batch_dims);
74
75 let a_ptr = a.ptr();
76 let b_ptr = b.ptr();
77 let c_ptr = c.as_mut_ptr();
78
79 let is_beta_zero = beta == T::zero();
80 let is_alpha_one = alpha == T::one();
81
82 let mut batch_iter = MultiIndex::new(batch_dims);
83 let mut lo_iter = MultiIndex::new(lo_dims);
84 let mut ro_iter = MultiIndex::new(ro_dims);
85 let mut sum_iter = MultiIndex::new(sum_dims);
86 while batch_iter.next().is_some() {
87 let a_batch_off = batch_iter.offset(a_batch_strides);
88 let b_batch_off = batch_iter.offset(b_batch_strides);
89 let c_batch_off = batch_iter.offset(c_batch_strides);
90
91 lo_iter.reset();
92 while lo_iter.next().is_some() {
93 let a_lo_off = lo_iter.offset(a_lo_strides);
94 let c_lo_off = lo_iter.offset(c_lo_strides);
95
96 ro_iter.reset();
97 while ro_iter.next().is_some() {
98 let b_ro_off = ro_iter.offset(b_ro_strides);
99 let c_ro_off = ro_iter.offset(c_ro_strides);
100
101 let mut acc = T::zero();
103 sum_iter.reset();
104 while sum_iter.next().is_some() {
105 let a_sum_off = sum_iter.offset(a_sum_strides);
106 let b_sum_off = sum_iter.offset(b_sum_strides);
107
108 let a_raw = unsafe { *a_ptr.offset(a_batch_off + a_lo_off + a_sum_off) };
109 let b_raw = unsafe { *b_ptr.offset(b_batch_off + b_sum_off + b_ro_off) };
110 let a_val = if conj_a {
111 strided_view::Conj::apply(a_raw)
112 } else {
113 a_raw
114 };
115 let b_val = if conj_b {
116 strided_view::Conj::apply(b_raw)
117 } else {
118 b_raw
119 };
120 acc = acc + a_val * b_val;
121 }
122
123 let c_off = c_batch_off + c_lo_off + c_ro_off;
125 unsafe {
126 let c_elem = c_ptr.offset(c_off);
127 if is_beta_zero {
128 if is_alpha_one {
129 *c_elem = acc;
130 } else {
131 *c_elem = alpha * acc;
132 }
133 } else {
134 let old = *c_elem;
135 if is_alpha_one {
136 *c_elem = acc + beta * old;
137 } else {
138 *c_elem = alpha * acc + beta * old;
139 }
140 }
141 }
142 }
143 }
144 }
145
146 Ok(())
147}
148
149pub fn bgemm_strided_into_with_map<T, MapA, MapB>(
154 c: &mut StridedViewMut<T>,
155 a: &StridedView<T>,
156 b: &StridedView<T>,
157 _n_batch: usize,
158 n_lo: usize,
159 n_ro: usize,
160 n_sum: usize,
161 alpha: T,
162 beta: T,
163 map_a: MapA,
164 map_b: MapB,
165) -> strided_view::Result<()>
166where
167 T: Copy
168 + std::ops::Mul<Output = T>
169 + std::ops::Add<Output = T>
170 + num_traits::Zero
171 + num_traits::One
172 + PartialEq,
173 MapA: Fn(T) -> T,
174 MapB: Fn(T) -> T,
175{
176 let a_dims = a.dims();
177 let b_dims = b.dims();
178 let c_dims = c.dims();
179 let a_strides = a.strides();
180 let b_strides = b.strides();
181 let c_strides = c.strides();
182
183 let lo_dims = &a_dims[..n_lo];
184 let sum_dims = &a_dims[n_lo..n_lo + n_sum];
185 let batch_dims = &a_dims[n_lo + n_sum..];
186 let ro_dims = &b_dims[n_sum..n_sum + n_ro];
187
188 let a_lo_strides = &a_strides[..n_lo];
189 let a_sum_strides = &a_strides[n_lo..n_lo + n_sum];
190 let a_batch_strides = &a_strides[n_lo + n_sum..];
191
192 let b_sum_strides = &b_strides[..n_sum];
193 let b_ro_strides = &b_strides[n_sum..n_sum + n_ro];
194 let b_batch_strides = &b_strides[n_sum + n_ro..];
195
196 let c_lo_strides = &c_strides[..n_lo];
197 let c_ro_strides = &c_strides[n_lo..n_lo + n_ro];
198 let c_batch_strides = &c_strides[n_lo + n_ro..];
199
200 debug_assert_eq!(&c_dims[..n_lo], lo_dims);
201 debug_assert_eq!(&c_dims[n_lo..n_lo + n_ro], ro_dims);
202 debug_assert_eq!(&c_dims[n_lo + n_ro..], batch_dims);
203 debug_assert_eq!(&b_dims[..n_sum], sum_dims);
204 debug_assert_eq!(&b_dims[n_sum + n_ro..], batch_dims);
205
206 let a_ptr = a.ptr();
207 let b_ptr = b.ptr();
208 let c_ptr = c.as_mut_ptr();
209
210 let is_beta_zero = beta == T::zero();
211 let is_alpha_one = alpha == T::one();
212
213 let mut batch_iter = MultiIndex::new(batch_dims);
214 let mut lo_iter = MultiIndex::new(lo_dims);
215 let mut ro_iter = MultiIndex::new(ro_dims);
216 let mut sum_iter = MultiIndex::new(sum_dims);
217 while batch_iter.next().is_some() {
218 let a_batch_off = batch_iter.offset(a_batch_strides);
219 let b_batch_off = batch_iter.offset(b_batch_strides);
220 let c_batch_off = batch_iter.offset(c_batch_strides);
221
222 lo_iter.reset();
223 while lo_iter.next().is_some() {
224 let a_lo_off = lo_iter.offset(a_lo_strides);
225 let c_lo_off = lo_iter.offset(c_lo_strides);
226
227 ro_iter.reset();
228 while ro_iter.next().is_some() {
229 let b_ro_off = ro_iter.offset(b_ro_strides);
230 let c_ro_off = ro_iter.offset(c_ro_strides);
231
232 let mut acc = T::zero();
233 sum_iter.reset();
234 while sum_iter.next().is_some() {
235 let a_sum_off = sum_iter.offset(a_sum_strides);
236 let b_sum_off = sum_iter.offset(b_sum_strides);
237
238 let a_raw = unsafe { *a_ptr.offset(a_batch_off + a_lo_off + a_sum_off) };
239 let b_raw = unsafe { *b_ptr.offset(b_batch_off + b_sum_off + b_ro_off) };
240 acc = acc + map_a(a_raw) * map_b(b_raw);
241 }
242
243 let c_off = c_batch_off + c_lo_off + c_ro_off;
244 unsafe {
245 let c_elem = c_ptr.offset(c_off);
246 if is_beta_zero {
247 if is_alpha_one {
248 *c_elem = acc;
249 } else {
250 *c_elem = alpha * acc;
251 }
252 } else {
253 let old = *c_elem;
254 if is_alpha_one {
255 *c_elem = acc + beta * old;
256 } else {
257 *c_elem = alpha * acc + beta * old;
258 }
259 }
260 }
261 }
262 }
263 }
264
265 Ok(())
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use strided_view::StridedArray;

    /// Plain 2x2 matrix product:
    /// [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
    #[test]
    fn test_bgemm_2x2() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    /// Rectangular product, (2x3) * (3x4), with A = 1..=6 and B = 1..=12 in
    /// row-major order.
    #[test]
    fn test_bgemm_rect() {
        let a =
            StridedArray::<f64>::from_fn_row_major(&[2, 3], |idx| (idx[0] * 3 + idx[1] + 1) as f64);
        let b =
            StridedArray::<f64>::from_fn_row_major(&[3, 4], |idx| (idx[0] * 4 + idx[1] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[2, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // Row 0 of A is [1, 2, 3]; row 1 is [4, 5, 6].
        assert_eq!(c.get(&[0, 0]), 38.0); // 1*1 + 2*5 + 3*9
        assert_eq!(c.get(&[0, 3]), 56.0); // 1*4 + 2*8 + 3*12
        assert_eq!(c.get(&[1, 0]), 83.0); // 4*1 + 5*5 + 6*9
        assert_eq!(c.get(&[1, 3]), 128.0); // 4*4 + 5*8 + 6*12
    }

    /// Two independent (2x3) * (3x2) products with the batch axis last.
    #[test]
    fn test_bgemm_batched() {
        // a is [lo, sum, batch]; b is [sum, ro, batch]; c is [lo, ro, batch].
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 3, 2], |idx| {
            (idx[2] * 6 + idx[0] * 3 + idx[1] + 1) as f64
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[3, 2, 2], |idx| {
            (idx[2] * 6 + idx[0] * 2 + idx[1] + 1) as f64
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2, 2]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            1,
            1,
            1,
            1,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        // Batch 0: A row 0 = [1, 2, 3], B column 0 = [1, 3, 5].
        assert_eq!(c.get(&[0, 0, 0]), 22.0);
        // Batch 0: A row 1 = [4, 5, 6], B column 1 = [2, 4, 6].
        assert_eq!(c.get(&[1, 1, 0]), 64.0);
        // Batch 1 must use the second slab of both operands:
        // A row 0 = [7, 8, 9], B column 0 = [7, 9, 11].
        assert_eq!(c.get(&[0, 0, 1]), 220.0);
        // A row 1 = [10, 11, 12], B column 1 = [8, 10, 12].
        assert_eq!(c.get(&[1, 1, 1]), 334.0);
    }

    /// C = 2 * I * B + 3 * C_old, exercising both the alpha and beta paths.
    #[test]
    fn test_bgemm_alpha_beta() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 0.0], [0.0, 1.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[10.0, 20.0], [30.0, 40.0]][idx[0]][idx[1]]
        });

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            2.0,
            3.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 32.0); // 2*1 + 3*10
        assert_eq!(c.get(&[0, 1]), 64.0); // 2*2 + 3*20
        assert_eq!(c.get(&[1, 0]), 96.0); // 2*3 + 3*30
        assert_eq!(c.get(&[1, 1]), 128.0); // 2*4 + 3*40
    }

    /// With no contracted axes the kernel degenerates to an outer product.
    #[test]
    fn test_bgemm_outer_product() {
        let a = StridedArray::<f64>::from_fn_row_major(&[3], |idx| (idx[0] + 1) as f64);
        let b = StridedArray::<f64>::from_fn_row_major(&[4], |idx| (idx[0] + 1) as f64);
        let mut c = StridedArray::<f64>::row_major(&[3, 4]);

        bgemm_strided_into(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            0,
            1.0,
            0.0,
            false,
            false,
        )
        .unwrap();

        assert_eq!(c.get(&[0, 0]), 1.0);
        assert_eq!(c.get(&[1, 2]), 6.0);
        assert_eq!(c.get(&[2, 3]), 12.0);
    }

    /// The mapped variant applies `map_a` to elements of A and `map_b` to
    /// elements of B before multiplying. The two maps differ so that swapping
    /// them would change the result.
    #[test]
    fn test_bgemm_with_map() {
        let a = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[1.0, 2.0], [3.0, 4.0]][idx[0]][idx[1]]
        });
        let b = StridedArray::<f64>::from_fn_row_major(&[2, 2], |idx| {
            [[5.0, 6.0], [7.0, 8.0]][idx[0]][idx[1]]
        });
        let mut c = StridedArray::<f64>::row_major(&[2, 2]);

        bgemm_strided_into_with_map(
            &mut c.view_mut(),
            &a.view(),
            &b.view(),
            0,
            1,
            1,
            1,
            1.0,
            0.0,
            |x: f64| 2.0 * x,
            |x: f64| x + 1.0,
        )
        .unwrap();

        // Effective operands: [[2, 4], [6, 8]] * [[6, 7], [8, 9]].
        assert_eq!(c.get(&[0, 0]), 44.0);
        assert_eq!(c.get(&[0, 1]), 50.0);
        assert_eq!(c.get(&[1, 0]), 100.0);
        assert_eq!(c.get(&[1, 1]), 114.0);
    }
}