use num_traits::{One, Zero};
use smallvec::SmallVec;

use crate::buffer_pool::{BufferPool, PoolScalar};
use crate::config::DotGeneralConfig;
use crate::cpu::structural::typed_transpose;
use crate::types::{col_major_strides, Buffer, TypedTensor};
use crate::Error;

#[cfg(feature = "cpu-blas")]
mod blas_gemm;
#[cfg(feature = "cpu-faer")]
mod faer_gemm;

// The BLAS entry points are compiled out when `cpu-faer` is also enabled
// (see `dot_general` below), so their trait import is gated the same way.
#[cfg(all(feature = "cpu-blas", not(feature = "cpu-faer")))]
use blas_gemm::BlasGemm;
#[cfg(feature = "cpu-faer")]
use faer_gemm::FaerGemm;

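/// Fused strides and sizes for a batched GEMM `C = A * B`, where A is
/// `m x k`, B is `k x n`, and C is `m x n`, repeated `batch_total` times.
///
/// `a_rs`/`a_cs` are A's element strides along its row and column axes (the
/// same scheme applies to B and C), `*_bs` is the stride between consecutive
/// batch slices, and `out_shape` is the unfused output shape, ordered as
/// lhs free dims, then rhs free dims, then batch dims.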
struct GemmDims {
    m: usize,
    n: usize,
    k: usize,
    batch_total: usize,
    a_rs: isize,
    a_cs: isize,
    a_bs: isize,
    b_rs: isize,
    b_cs: isize,
    b_bs: isize,
    c_rs: isize,
    c_cs: isize,
    c_bs: isize,
    out_shape: SmallVec<[usize; 8]>,
}

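/// Checks that every axis in `axes` is within `rank` and that no axis is
/// repeated within the role.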
fn validate_axis_list(
    op: &'static str,
    role: &'static str,
    axes: &[usize],
    rank: usize,
) -> crate::Result<()> {
    let mut seen: SmallVec<[bool; 8]> = smallvec::smallvec![false; rank];
    for &axis in axes {
        if axis >= rank {
            return Err(Error::AxisOutOfBounds { op, axis, rank });
        }
        if seen[axis] {
            return Err(Error::DuplicateAxis { op, axis, role });
        }
        seen[axis] = true;
    }
    Ok(())
}

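/// Checks that no axis is assigned to both roles (e.g. both contracting and
/// batch) on the same operand.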
fn validate_role_disjoint(
    op: &'static str,
    first_role: &'static str,
    first_axes: &[usize],
    second_role: &'static str,
    second_axes: &[usize],
) -> crate::Result<()> {
    for &axis in first_axes {
        if second_axes.contains(&axis) {
            return Err(Error::AxisRoleConflict {
                op,
                axis,
                first_role,
                second_role,
            });
        }
    }
    Ok(())
}

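/// Validates a `dot_general` configuration against both operand shapes:
/// equal contracting/batch dim counts, in-bounds and non-conflicting axis
/// roles, and matching sizes for every paired contracting and batch axis.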
fn validate_dot_general<T>(
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> crate::Result<()> {
    const OP: &str = "dot_general";

    if config.lhs_contracting_dims.len() != config.rhs_contracting_dims.len() {
        return Err(Error::InvalidConfig {
            op: OP,
            message: "lhs/rhs contracting dim counts differ".into(),
        });
    }
    if config.lhs_batch_dims.len() != config.rhs_batch_dims.len() {
        return Err(Error::InvalidConfig {
            op: OP,
            message: "lhs/rhs batch dim counts differ".into(),
        });
    }

    let lhs_rank = lhs.shape.len();
    let rhs_rank = rhs.shape.len();
    validate_axis_list(
        OP,
        "lhs_contracting",
        &config.lhs_contracting_dims,
        lhs_rank,
    )?;
    validate_axis_list(
        OP,
        "rhs_contracting",
        &config.rhs_contracting_dims,
        rhs_rank,
    )?;
    validate_axis_list(OP, "lhs_batch", &config.lhs_batch_dims, lhs_rank)?;
    validate_axis_list(OP, "rhs_batch", &config.rhs_batch_dims, rhs_rank)?;
    validate_role_disjoint(
        OP,
        "lhs_contracting",
        &config.lhs_contracting_dims,
        "lhs_batch",
        &config.lhs_batch_dims,
    )?;
    validate_role_disjoint(
        OP,
        "rhs_contracting",
        &config.rhs_contracting_dims,
        "rhs_batch",
        &config.rhs_batch_dims,
    )?;

    for (&lhs_axis, &rhs_axis) in config
        .lhs_contracting_dims
        .iter()
        .zip(&config.rhs_contracting_dims)
    {
        if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
            return Err(Error::InvalidConfig {
                op: OP,
                message: format!(
                    "contracting dim size mismatch: lhs axis {lhs_axis}={} rhs axis {rhs_axis}={}",
                    lhs.shape[lhs_axis], rhs.shape[rhs_axis]
                ),
            });
        }
    }
    for (&lhs_axis, &rhs_axis) in config.lhs_batch_dims.iter().zip(&config.rhs_batch_dims) {
        if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
            return Err(Error::InvalidConfig {
                op: OP,
                message: format!(
                    "batch dim size mismatch: lhs axis {lhs_axis}={} rhs axis {rhs_axis}={}",
                    lhs.shape[lhs_axis], rhs.shape[rhs_axis]
                ),
            });
        }
    }

    Ok(())
}

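/// Attempts to collapse a multi-dimensional index space into a single
/// `(length, stride)` pair. Succeeds only when, after sorting by absolute
/// stride, each dimension tiles the next exactly, i.e. the dims describe
/// one contiguous strided run.
///
/// For example, shapes `[2, 3]` with strides `[1, 2]` fuse to `(6, 1)`,
/// while strides `[1, 3]` leave a gap and return `None`. An empty dim list
/// fuses to `(1, 0)`.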
fn try_fuse_dims(shapes: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
    if shapes.is_empty() {
        return Some((1, 0));
    }
    if shapes.len() == 1 {
        return Some((shapes[0], strides[0]));
    }
    let mut dims: SmallVec<[(usize, isize); 8]> = shapes
        .iter()
        .copied()
        .zip(strides.iter().copied())
        .collect();
    dims.sort_by_key(|&(_, stride)| stride.unsigned_abs());
    let base_stride = dims[0].1;
    let mut expected = base_stride;
    for (shape, stride) in dims {
        if stride != expected {
            return None;
        }
        expected = stride.checked_mul(shape as isize)?;
    }
    Some((shapes.iter().product(), base_stride))
}

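/// Returns the axis permutation that orders `strides` by ascending absolute
/// value (fastest-varying axis first).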
fn stride_sort_order(strides: &[isize]) -> SmallVec<[usize; 8]> {
    let mut order: SmallVec<[usize; 8]> = (0..strides.len()).collect();
    order.sort_by_key(|&idx| strides[idx].unsigned_abs());
    order
}

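/// True when `order` is the identity permutation `[0, 1, ..., n - 1]`.
/// Used both for stride-order checks and for skipping no-op transposes.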
fn is_identity_order(order: &[usize]) -> bool {
    order.iter().enumerate().all(|(idx, &value)| idx == value)
}

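/// Computes transpose permutations that bring both operands into a canonical
/// GEMM layout, plus the `DotGeneralConfig` describing the permuted tensors:
///
/// - lhs axes ordered as `[free..., contracting..., batch...]`
/// - rhs axes ordered as `[contracting..., free..., batch...]`
///
/// Once a dense col-major tensor is transposed into this layout, every axis
/// group is contiguous, so `analyse_gemm` can fuse its strides.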
fn canonical_gemm_layout(
    config: &DotGeneralConfig,
    lhs_rank: usize,
    rhs_rank: usize,
) -> (SmallVec<[usize; 8]>, SmallVec<[usize; 8]>, DotGeneralConfig) {
    let lhs_free: SmallVec<[usize; 8]> = (0..lhs_rank)
        .filter(|d| !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d))
        .collect();
    let rhs_free: SmallVec<[usize; 8]> = (0..rhs_rank)
        .filter(|d| !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d))
        .collect();

    let mut lhs_perm = SmallVec::<[usize; 8]>::with_capacity(lhs_rank);
    lhs_perm.extend_from_slice(&lhs_free);
    lhs_perm.extend_from_slice(&config.lhs_contracting_dims);
    lhs_perm.extend_from_slice(&config.lhs_batch_dims);

    let mut rhs_perm = SmallVec::<[usize; 8]>::with_capacity(rhs_rank);
    rhs_perm.extend_from_slice(&config.rhs_contracting_dims);
    rhs_perm.extend_from_slice(&rhs_free);
    rhs_perm.extend_from_slice(&config.rhs_batch_dims);

    let nf_lhs = lhs_free.len();
    let nc = config.lhs_contracting_dims.len();
    let nb = config.lhs_batch_dims.len();
    let nf_rhs = rhs_free.len();

    let new_config = DotGeneralConfig {
        lhs_contracting_dims: (nf_lhs..nf_lhs + nc).collect(),
        rhs_contracting_dims: (0..nc).collect(),
        lhs_batch_dims: (nf_lhs + nc..nf_lhs + nc + nb).collect(),
        rhs_batch_dims: (nc + nf_rhs..nc + nf_rhs + nb).collect(),
    };

    (lhs_perm, rhs_perm, new_config)
}
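/// Tries to express `dot_general(lhs, rhs)` as a single batched GEMM over
/// the existing buffers. Each axis group (free, contracting, batch) must
/// fuse into one stride and appear in a consistent memory order; otherwise
/// returns `None` and the caller retries in the canonical layout.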
fn analyse_gemm<T>(
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> Option<GemmDims> {
    let lhs_rank = lhs.shape.len();
    let rhs_rank = rhs.shape.len();

    let lhs_free: SmallVec<[usize; 8]> = (0..lhs_rank)
        .filter(|d| !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d))
        .collect();
    let rhs_free: SmallVec<[usize; 8]> = (0..rhs_rank)
        .filter(|d| !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d))
        .collect();

    let lhs_strides: SmallVec<[isize; 8]> = col_major_strides(&lhs.shape).into_iter().collect();
    let rhs_strides: SmallVec<[isize; 8]> = col_major_strides(&rhs.shape).into_iter().collect();

    let batch_shapes: SmallVec<[usize; 8]> = config
        .lhs_batch_dims
        .iter()
        .map(|&d| lhs.shape[d])
        .collect();
    let batch_total: usize = batch_shapes.iter().product();

    let lhs_free_shapes: SmallVec<[usize; 8]> = lhs_free.iter().map(|&d| lhs.shape[d]).collect();
    let rhs_free_shapes: SmallVec<[usize; 8]> = rhs_free.iter().map(|&d| rhs.shape[d]).collect();
    let contract_shapes: SmallVec<[usize; 8]> = config
        .lhs_contracting_dims
        .iter()
        .map(|&d| lhs.shape[d])
        .collect();

    let m: usize = lhs_free_shapes.iter().product();
    let n: usize = rhs_free_shapes.iter().product();
    let k: usize = contract_shapes.iter().product();

    // Zero-sized dims can defeat stride fusion (dense strides collapse to
    // zero), but callers never run GEMM on a degenerate problem: they only
    // need the output shape, so report it here with placeholder strides.
    if m == 0 || n == 0 || k == 0 || batch_total == 0 {
        let mut out_shape = SmallVec::<[usize; 8]>::new();
        out_shape.extend_from_slice(&lhs_free_shapes);
        out_shape.extend_from_slice(&rhs_free_shapes);
        out_shape.extend_from_slice(&batch_shapes);
        return Some(GemmDims {
            m,
            n,
            k,
            batch_total,
            a_rs: 0,
            a_cs: 0,
            a_bs: 0,
            b_rs: 0,
            b_cs: 0,
            b_bs: 0,
            c_rs: 0,
            c_cs: 0,
            c_bs: 0,
            out_shape,
        });
    }

    let lhs_free_strides: SmallVec<[isize; 8]> = lhs_free.iter().map(|&d| lhs_strides[d]).collect();
    let rhs_free_strides: SmallVec<[isize; 8]> = rhs_free.iter().map(|&d| rhs_strides[d]).collect();
    let lhs_contract_strides: SmallVec<[isize; 8]> = config
        .lhs_contracting_dims
        .iter()
        .map(|&d| lhs_strides[d])
        .collect();
    let rhs_contract_strides: SmallVec<[isize; 8]> = config
        .rhs_contracting_dims
        .iter()
        .map(|&d| rhs_strides[d])
        .collect();
    let lhs_batch_strides: SmallVec<[isize; 8]> = config
        .lhs_batch_dims
        .iter()
        .map(|&d| lhs_strides[d])
        .collect();
    let rhs_batch_strides: SmallVec<[isize; 8]> = config
        .rhs_batch_dims
        .iter()
        .map(|&d| rhs_strides[d])
        .collect();

    // Each axis group must already sit in ascending-stride order, and the
    // contracting groups must agree on their memory order across operands.
    if !is_identity_order(&stride_sort_order(&lhs_free_strides))
        || !is_identity_order(&stride_sort_order(&rhs_free_strides))
        || !is_identity_order(&stride_sort_order(&lhs_batch_strides))
        || !is_identity_order(&stride_sort_order(&rhs_batch_strides))
        || stride_sort_order(&lhs_contract_strides) != stride_sort_order(&rhs_contract_strides)
    {
        return None;
    }

    let (_, a_rs) = try_fuse_dims(&lhs_free_shapes, &lhs_free_strides)?;
    let (_, a_cs) = try_fuse_dims(&contract_shapes, &lhs_contract_strides)?;
    let (_, b_rs) = try_fuse_dims(&contract_shapes, &rhs_contract_strides)?;
    let (_, b_cs) = try_fuse_dims(&rhs_free_shapes, &rhs_free_strides)?;
    let (_, a_bs) = try_fuse_dims(&batch_shapes, &lhs_batch_strides)?;
    let (_, b_bs) = try_fuse_dims(&batch_shapes, &rhs_batch_strides)?;

    let mut out_shape = SmallVec::<[usize; 8]>::new();
    out_shape.extend_from_slice(&lhs_free_shapes);
    out_shape.extend_from_slice(&rhs_free_shapes);
    out_shape.extend_from_slice(&batch_shapes);

    let out_strides: SmallVec<[isize; 8]> = col_major_strides(&out_shape).into_iter().collect();
    let nm = lhs_free_shapes.len();
    let nn = rhs_free_shapes.len();
    let out_m_shapes = &out_shape[..nm];
    let out_m_strides = &out_strides[..nm];
    let out_n_shapes = &out_shape[nm..nm + nn];
    let out_n_strides = &out_strides[nm..nm + nn];
    let out_b_shapes = &out_shape[nm + nn..];
    let out_b_strides = &out_strides[nm + nn..];

    let (_, c_rs) = try_fuse_dims(out_m_shapes, out_m_strides)?;
    let (_, c_cs) = try_fuse_dims(out_n_shapes, out_n_strides)?;
    let (_, c_bs) = try_fuse_dims(out_b_shapes, out_b_strides)?;

    Some(GemmDims {
        m,
        n,
        k,
        batch_total,
        a_rs,
        a_cs,
        a_bs,
        b_rs,
        b_cs,
        b_bs,
        c_rs,
        c_cs,
        c_bs,
        out_shape,
    })
}

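/// Generalized tensor contraction (XLA-style `dot_general`) on the CPU via
/// faer. The fast path hands the operands to the GEMM kernel in place;
/// otherwise both operands are transposed into the canonical layout first,
/// paying at most one materializing copy per operand.
///
/// A minimal sketch of a plain 2-D matmul through this entry point. How
/// `buffers`, `ctx`, and the host-backed tensors `a` and `b` are built is
/// elided, and the `collect()`/`Default` construction of the config is
/// illustrative only:
///
/// ```ignore
/// // Contract lhs axis 1 with rhs axis 0: C[i, j] = sum over k of A[i, k] * B[k, j].
/// let cfg = DotGeneralConfig {
///     lhs_contracting_dims: std::iter::once(1).collect(),
///     rhs_contracting_dims: std::iter::once(0).collect(),
///     lhs_batch_dims: Default::default(),
///     rhs_batch_dims: Default::default(),
/// };
/// let c = dot_general(&mut buffers, &ctx, &a, &b, &cfg)?;
/// assert_eq!(c.shape, vec![a.shape[0], b.shape[1]]);
/// ```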
#[cfg(feature = "cpu-faer")]
pub(crate) fn dot_general<T>(
    buffers: &mut BufferPool,
    ctx: &crate::cpu::CpuContext,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> crate::Result<TypedTensor<T>>
where
    T: FaerGemm + PoolScalar + Copy + Zero + One + PartialEq,
{
    validate_dot_general(lhs, rhs, config)?;
    if let Some(result) = typed_faer_gemm(buffers, ctx, lhs, rhs, config) {
        return Ok(result);
    }
    let (lhs_perm, rhs_perm, new_config) =
        canonical_gemm_layout(config, lhs.shape.len(), rhs.shape.len());
    let lhs_canon = if is_identity_order(&lhs_perm) {
        std::borrow::Cow::Borrowed(lhs)
    } else {
        std::borrow::Cow::Owned(typed_transpose(lhs, &lhs_perm)?)
    };
    let rhs_canon = if is_identity_order(&rhs_perm) {
        std::borrow::Cow::Borrowed(rhs)
    } else {
        std::borrow::Cow::Owned(typed_transpose(rhs, &rhs_perm)?)
    };
    typed_faer_gemm(buffers, ctx, &lhs_canon, &rhs_canon, &new_config).ok_or_else(|| {
        Error::BackendFailure {
            op: "dot_general",
            message: "CPU GEMM requires host-backed canonical inputs".into(),
        }
    })
}

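/// Runs the batched faer GEMM when `analyse_gemm` can fuse the operands and
/// both buffers are host-backed; returns `None` otherwise.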
#[cfg(feature = "cpu-faer")]
fn typed_faer_gemm<T>(
    buffers: &mut BufferPool,
    ctx: &crate::cpu::CpuContext,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> Option<TypedTensor<T>>
where
    T: FaerGemm + PoolScalar + Copy + Zero + One + PartialEq,
{
    let dims = analyse_gemm(lhs, rhs, config)?;
    let out_n: usize = dims.out_shape.iter().product();
    // Degenerate contractions produce an all-zero output without running GEMM.
    if dims.m == 0 || dims.n == 0 || dims.k == 0 || dims.batch_total == 0 {
        return Some(TypedTensor {
            buffer: Buffer::Host(vec![T::zero(); out_n]),
            shape: dims.out_shape.into_vec(),
            placement: lhs.placement.clone(),
        });
    }

    let a_data = match &lhs.buffer {
        Buffer::Host(v) => v.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    };
    let b_data = match &rhs.buffer {
        Buffer::Host(v) => v.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    };

    // The GEMM below runs with beta = 0 and its fused strides cover every
    // output element exactly once, so the pooled buffer need not be zeroed.
    let mut out_data: Vec<T> = unsafe { T::pool_acquire(buffers, out_n) };
    let c_ptr = out_data.as_mut_ptr();

    // One GEMM per batch slice, offsetting each operand by its batch stride.
    for batch in 0..dims.batch_total {
        let a_off = batch as isize * dims.a_bs;
        let b_off = batch as isize * dims.b_bs;
        let c_off = batch as isize * dims.c_bs;
        unsafe {
            T::strided_gemm(
                ctx,
                T::one(),
                a_data.offset(a_off),
                dims.m,
                dims.k,
                dims.a_rs,
                dims.a_cs,
                b_data.offset(b_off),
                dims.n,
                dims.b_rs,
                dims.b_cs,
                T::zero(),
                c_ptr.offset(c_off),
                dims.c_rs,
                dims.c_cs,
            );
        }
    }

    Some(TypedTensor {
        buffer: Buffer::Host(out_data),
        shape: dims.out_shape.into_vec(),
        placement: lhs.placement.clone(),
    })
}

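/// BLAS fallback for `dot_general`, following the same strategy as the faer
/// path: try the operands as-is, then retry after transposing both into the
/// canonical GEMM layout.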
// Gated on `not(cpu-faer)` so enabling both CPU backends does not define
// `dot_general` twice; faer is assumed to take precedence when both are on.
#[cfg(all(feature = "cpu-blas", not(feature = "cpu-faer")))]
pub(crate) fn dot_general<T>(
    buffers: &mut BufferPool,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> crate::Result<TypedTensor<T>>
where
    T: BlasGemm + PoolScalar + Copy + Zero + One,
{
    validate_dot_general(lhs, rhs, config)?;
    if let Some(result) = typed_blas_gemm(buffers, lhs, rhs, config) {
        return Ok(result);
    }
    let (lhs_perm, rhs_perm, new_config) =
        canonical_gemm_layout(config, lhs.shape.len(), rhs.shape.len());
    let lhs_canon = if is_identity_order(&lhs_perm) {
        std::borrow::Cow::Borrowed(lhs)
    } else {
        std::borrow::Cow::Owned(typed_transpose(lhs, &lhs_perm)?)
    };
    let rhs_canon = if is_identity_order(&rhs_perm) {
        std::borrow::Cow::Borrowed(rhs)
    } else {
        std::borrow::Cow::Owned(typed_transpose(rhs, &rhs_perm)?)
    };
    typed_blas_gemm(buffers, &lhs_canon, &rhs_canon, &new_config).ok_or_else(|| {
        Error::BackendFailure {
            op: "dot_general",
            message: "CPU GEMM requires host-backed canonical inputs".into(),
        }
    })
}

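/// Runs the batched BLAS GEMM when the fused operands are host-backed and
/// the stride requirements of BLAS are met; returns `None` otherwise.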
#[cfg(all(feature = "cpu-blas", not(feature = "cpu-faer")))]
fn typed_blas_gemm<T>(
    buffers: &mut BufferPool,
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> Option<TypedTensor<T>>
where
    T: BlasGemm + PoolScalar + Copy + Zero + One,
{
    let dims = analyse_gemm(lhs, rhs, config)?;
    let out_n: usize = dims.out_shape.iter().product();
    // Degenerate contractions produce an all-zero output without running GEMM.
    if dims.m == 0 || dims.n == 0 || dims.k == 0 || dims.batch_total == 0 {
        return Some(TypedTensor {
            buffer: Buffer::Host(vec![T::zero(); out_n]),
            shape: dims.out_shape.into_vec(),
            placement: lhs.placement.clone(),
        });
    }

    // BLAS requires each input to be row- or column-major (one unit stride)
    // and the output to be column-major.
    let a_ok = dims.a_rs == 1 || dims.a_cs == 1;
    let b_ok = dims.b_rs == 1 || dims.b_cs == 1;
    let c_ok = dims.c_rs == 1;
    if !a_ok || !b_ok || !c_ok {
        return None;
    }

    let a_data = match &lhs.buffer {
        Buffer::Host(v) => v.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    };
    let b_data = match &rhs.buffer {
        Buffer::Host(v) => v.as_ptr(),
        Buffer::Backend(_) => return None,
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!(
            "GPU tensor (Buffer::Cubecl) passed to CPU backend. \
             Use cubecl::download_tensor() to transfer to CPU first."
        ),
    };

    // beta = 0 and exact stride coverage make the uninitialized pooled
    // buffer safe to hand to GEMM.
    let mut out: Vec<T> = unsafe { T::pool_acquire(buffers, out_n) };
    let c_ptr = out.as_mut_ptr();

    for batch in 0..dims.batch_total {
        let a_off = batch as isize * dims.a_bs;
        let b_off = batch as isize * dims.b_bs;
        let c_off = batch as isize * dims.c_bs;
        unsafe {
            T::strided_gemm(
                T::one(),
                a_data.offset(a_off),
                dims.m,
                dims.k,
                dims.a_rs,
                dims.a_cs,
                b_data.offset(b_off),
                dims.n,
                dims.b_rs,
                dims.b_cs,
                T::zero(),
                c_ptr.offset(c_off),
                dims.c_rs,
                dims.c_cs,
            );
        }
    }

    Some(TypedTensor {
        buffer: Buffer::Host(out),
        shape: dims.out_shape.into_vec(),
        placement: lhs.placement.clone(),
    })
}

#[cfg(test)]
mod tests;