1use std::any::Any;
2use std::panic::{catch_unwind, AssertUnwindSafe};
3use std::sync::Arc;
4
5use tenferro_algebra::Semiring;
6
7use crate::backend::{SemiringBackend, TensorBackend, TensorExec};
8use crate::buffer_pool::{BufferPool, PoolScalar};
9use crate::config::{
10 CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
11};
12use crate::types::flat_to_multi;
13use crate::validate::validate_nonsingular_u;
14use crate::{Buffer, Tensor, TypedTensor};
15
16use super::exec_session::CpuExecSession;
17use super::{analytic, elementwise, gemm, indexing, linalg, reduction, structural, CpuContext};
18
/// CPU execution backend: a shared execution context plus a pool of reusable
/// host buffers that ops draw from to avoid repeated allocation.
pub struct CpuBackend {
    // Shared execution context (thread count / `install` scope); cloned cheaply via Arc.
    pub(crate) ctx: Arc<CpuContext>,
    // Reusable host allocations; ops that need scratch space borrow from here.
    pub(crate) buffers: BufferPool,
}
32
impl CpuBackend {
    /// Builds a backend with context settings taken from the environment.
    pub fn new() -> Self {
        Self::from_context(Arc::new(CpuContext::from_env()))
    }

    /// Fallible variant of [`CpuBackend::new`]; surfaces context-creation errors.
    pub fn try_new() -> crate::Result<Self> {
        CpuContext::try_from_env().map(|ctx| Self::from_context(Arc::new(ctx)))
    }

    /// Wraps an existing shared context together with a fresh, empty buffer pool.
    pub fn from_context(ctx: Arc<CpuContext>) -> Self {
        Self {
            ctx,
            buffers: BufferPool::new(),
        }
    }

    /// Builds a backend with an explicit worker-thread count.
    ///
    /// # Panics
    /// Panics if `num_threads` is zero.
    pub fn with_threads(num_threads: usize) -> Self {
        assert!(num_threads >= 1, "thread count must be >= 1");
        Self::from_context(Arc::new(CpuContext::with_threads(num_threads)))
    }

    /// Number of worker threads the context reports.
    pub fn num_threads(&self) -> usize {
        self.ctx.num_threads()
    }

    /// Number of entries currently held in the buffer pool.
    pub fn buffer_pool_len(&self) -> usize {
        self.buffers.len()
    }

    /// Runs `op` inside the backend's execution context (its thread-pool scope).
    pub fn install<R>(&self, op: impl FnOnce() -> R + Send) -> R
    where
        R: Send,
    {
        self.ctx.install(op)
    }

    /// Like [`CpuBackend::install`], but also hands `op` mutable access to the
    /// buffer pool.
    ///
    /// The pool is moved out of `self` (leaving a default/empty pool behind) so
    /// it can be sent into the `Send` closure, then moved back afterwards.
    /// NOTE(review): if `op` panics, the pooled buffers are dropped rather than
    /// restored — the pool simply starts over empty.
    fn install_with_pool<R>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R
    where
        R: Send,
    {
        let mut buffers = std::mem::take(&mut self.buffers);
        let (result, buffers) = self.ctx.install(|| {
            let result = op(&mut buffers);
            (result, buffers)
        });
        self.buffers = buffers;
        result
    }
}
154
impl TensorBackend for CpuBackend {
    // --- Elementwise ops: each delegates to `elementwise::*` inside the
    // context's `install` scope so any internal parallelism uses our pool. ---

    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::add(lhs, rhs))
    }

    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::mul(lhs, rhs))
    }

    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::neg(input))
    }

    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::conj(input))
    }

    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::div(lhs, rhs))
    }

    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::abs(input))
    }

    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::sign(input))
    }

    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::maximum(lhs, rhs))
    }

    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::minimum(lhs, rhs))
    }

    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor> {
        self.install(|| elementwise::compare(lhs, rhs, dir))
    }

    fn select(
        &mut self,
        pred: &Tensor,
        on_true: &Tensor,
        on_false: &Tensor,
    ) -> crate::Result<Tensor> {
        self.install(|| elementwise::select(pred, on_true, on_false))
    }

    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::clamp(input, lower, upper))
    }

    // --- Analytic (transcendental) ops, delegated to `analytic::*`. ---

    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::exp(input))
    }

    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::log(input))
    }

    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::sin(input))
    }

    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::cos(input))
    }

    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::tanh(input))
    }

    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::sqrt(input))
    }

    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::rsqrt(input))
    }

    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::pow(lhs, rhs))
    }

    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::expm1(input))
    }

    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::log1p(input))
    }

    // --- Structural ops (layout/shape changes), delegated to `structural::*`. ---

    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
        self.install(|| structural::transpose(input, perm))
    }

    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
        self.install(|| structural::reshape(input, shape))
    }

    fn broadcast_in_dim(
        &mut self,
        input: &Tensor,
        shape: &[usize],
        dims: &[usize],
    ) -> crate::Result<Tensor> {
        self.install(|| structural::broadcast_in_dim(input, shape, dims))
    }

    fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor> {
        // `structural::convert` returns a plain `Tensor` (infallible), so wrap
        // it in `Ok` here to satisfy the trait's `Result` signature.
        Ok(self.install(|| structural::convert(input, to)))
    }

    fn extract_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor> {
        self.install(|| structural::extract_diagonal(input, axis_a, axis_b))
    }

    fn embed_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor> {
        self.install(|| structural::embed_diagonal(input, axis_a, axis_b))
    }

    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
        self.install(|| structural::tril(input, k))
    }

    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
        self.install(|| structural::triu(input, k))
    }

    // --- Reductions over the given axes, delegated to `reduction::*`. ---

    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_sum(input, axes))
    }

    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_prod(input, axes))
    }

    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_max(input, axes))
    }

    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_min(input, axes))
    }

    /// Generalized contraction (einsum-style dot) for matching float/complex
    /// dtype pairs. One set of arms per gemm provider feature; the features
    /// are presumably mutually exclusive — if both `cpu-faer` and `cpu-blas`
    /// are enabled, the faer arms shadow the blas arms (unreachable patterns).
    /// Mixed dtypes are rejected with `DTypeMismatch`.
    fn dot_general(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor> {
        // Clone the Arc so the closure doesn't borrow `self` (which is
        // mutably borrowed by `install_with_pool`).
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match (lhs, rhs) {
            #[cfg(feature = "cpu-faer")]
            (Tensor::F32(a), Tensor::F32(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::F32)
            }
            #[cfg(feature = "cpu-faer")]
            (Tensor::F64(a), Tensor::F64(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::F64)
            }
            #[cfg(feature = "cpu-faer")]
            (Tensor::C32(a), Tensor::C32(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::C32)
            }
            #[cfg(feature = "cpu-faer")]
            (Tensor::C64(a), Tensor::C64(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::C64)
            }
            // The blas gemm entry point takes no context argument.
            #[cfg(feature = "cpu-blas")]
            (Tensor::F32(a), Tensor::F32(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::F32)
            }
            #[cfg(feature = "cpu-blas")]
            (Tensor::F64(a), Tensor::F64(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::F64)
            }
            #[cfg(feature = "cpu-blas")]
            (Tensor::C32(a), Tensor::C32(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::C32)
            }
            #[cfg(feature = "cpu-blas")]
            (Tensor::C64(a), Tensor::C64(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::C64)
            }
            _ => Err(crate::Error::DTypeMismatch {
                op: "dot_general",
                lhs: lhs.dtype(),
                rhs: rhs.dtype(),
            }),
        })
    }

    // --- Indexing ops. Some kernels may panic on bad indices, so those are
    // wrapped in `catch_backend_panic` to surface `BackendFailure` instead. ---

    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor> {
        self.install(|| {
            catch_backend_panic("gather", || {
                indexing::gather(operand, start_indices, config)
            })
        })
    }

    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor> {
        self.install(|| {
            catch_backend_panic("scatter", || {
                indexing::scatter(operand, scatter_indices, updates, config)
            })
        })
    }

    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor> {
        self.install(|| indexing::try_slice(input, config))
    }

    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor> {
        self.install(|| {
            catch_backend_panic("dynamic_slice", || {
                indexing::dynamic_slice(input, starts, slice_sizes)
            })
        })
    }

    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
        self.install(|| indexing::try_pad(input, config))
    }

    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
        self.install(|| indexing::try_concatenate(inputs, axis))
    }

    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| catch_backend_panic("reverse", || indexing::reverse(input, axes)))
    }

    // --- Linear algebra. Each op dispatches on dtype to the feature-selected
    // provider; faer paths take the context, blas paths don't. Panics inside
    // the kernels become `BackendFailure` errors. ---

    fn cholesky(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => {
                // faer's cholesky itself returns a Result, so the panic-catch
                // produces Result<Result<..>>; `and_then` flattens it.
                catch_backend_panic("cholesky", || linalg::cholesky(ctx.as_ref(), buffers, t))
                    .and_then(|result| result)
                    .map(Tensor::F64)
            }
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => {
                catch_backend_panic("cholesky", || linalg::cholesky(buffers, t)).map(Tensor::F64)
            }
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => {
                catch_backend_panic("cholesky", || linalg::cholesky(ctx.as_ref(), buffers, t))
                    .and_then(|result| result)
                    .map(Tensor::C64)
            }
            _ => Err(unsupported_dtype("cholesky", input.dtype())),
        })
    }

    fn triangular_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> crate::Result<Tensor> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match (a, b) {
            #[cfg(feature = "cpu-faer")]
            (Tensor::F64(a), Tensor::F64(b)) => catch_backend_panic("triangular_solve", || {
                Tensor::F64(linalg::triangular_solve(
                    ctx.as_ref(),
                    buffers,
                    a,
                    b,
                    left_side,
                    lower,
                    transpose_a,
                    unit_diagonal,
                ))
            }),
            #[cfg(feature = "cpu-blas")]
            (Tensor::F64(a), Tensor::F64(b)) => catch_backend_panic("triangular_solve", || {
                Tensor::F64(linalg::triangular_solve(
                    buffers,
                    a,
                    b,
                    left_side,
                    lower,
                    transpose_a,
                    unit_diagonal,
                ))
            }),
            #[cfg(feature = "cpu-faer")]
            (Tensor::C64(a), Tensor::C64(b)) => catch_backend_panic("triangular_solve", || {
                Tensor::C64(linalg::triangular_solve(
                    ctx.as_ref(),
                    buffers,
                    a,
                    b,
                    left_side,
                    lower,
                    transpose_a,
                    unit_diagonal,
                ))
            }),
            _ => {
                // Distinguish "dtypes disagree" from "dtype not supported"
                // so callers get the more precise error.
                if a.dtype() != b.dtype() {
                    Err(crate::Error::DTypeMismatch {
                        op: "triangular_solve",
                        lhs: a.dtype(),
                        rhs: b.dtype(),
                    })
                } else {
                    Err(unsupported_dtype("triangular_solve", a.dtype()))
                }
            }
        })
    }

    fn lu(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("lu", || {
                linalg::lu(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("lu", || {
                linalg::lu(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("lu", || {
                linalg::lu(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("lu", input.dtype())),
        })
    }

    fn svd(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("svd", || {
                linalg::svd(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("svd", || {
                linalg::svd(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("svd", || {
                linalg::svd(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("svd", input.dtype())),
        })
    }

    fn qr(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("qr", || {
                linalg::qr(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("qr", || {
                linalg::qr(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("qr", || {
                linalg::qr(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("qr", input.dtype())),
        })
    }

    fn eigh(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("eigh", || {
                linalg::eigh(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("eigh", || {
                linalg::eigh(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("eigh", || {
                linalg::eigh(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("eigh", input.dtype())),
        })
    }

    fn eig(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| {
            catch_backend_panic("eig", || {
                // `linalg::eig` handles dtype dispatch itself, so no match here.
                #[cfg(feature = "cpu-faer")]
                {
                    linalg::eig(ctx.as_ref(), buffers, input)
                }
                #[cfg(feature = "cpu-blas")]
                {
                    linalg::eig(buffers, input)
                }
            })
        })
    }

    /// Solves `a x = b` via LU factorization: with `P a = L U`, compute
    /// `z = L^{-1} (P b)` then `x = U^{-1} z` using triangular solves.
    /// A vector (or batched-vector) `b` is temporarily reshaped to a column
    /// matrix and the original shape restored afterwards.
    fn solve(&mut self, a: &Tensor, b: &Tensor) -> crate::Result<Tensor> {
        // Degenerate (zero-sized) systems: return zeros shaped like `b`.
        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
            return Ok(zeros_like_tensor(b));
        }

        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
            (
                self.reshape(b, &matrix_rhs_shape)?,
                Some(b.shape().to_vec()),
            )
        } else {
            (b.clone(), None)
        };

        // `lu` returns [P, L, U] in that order (indexed below).
        let outputs = self.lu(a)?;
        let p = &outputs[0];
        let l = &outputs[1];
        let u = &outputs[2];
        // Reject singular systems (zero on U's diagonal) before solving.
        validate_nonsingular_u(u)?;

        let pb = matmul_preserve_trailing_batch(self, p, &rhs)?;
        let z = self.triangular_solve(l, &pb, true, true, false, true)?;
        let x = self.triangular_solve(u, &z, true, false, false, false)?;
        if let Some(shape) = restore_shape {
            self.reshape(&x, &shape)
        } else {
            Ok(x)
        }
    }

    /// Runs `f` with a session that exposes the context and buffer pool.
    /// Same take/restore pool dance as `install_with_pool`.
    fn with_exec_session<R: Send>(&mut self, f: impl FnOnce(&mut dyn TensorExec) -> R + Send) -> R {
        let mut buffers = std::mem::take(&mut self.buffers);
        let ctx = Arc::clone(&self.ctx);
        let result = ctx.install(|| {
            let mut session = CpuExecSession {
                ctx: ctx.as_ref(),
                buffers: &mut buffers,
            };
            f(&mut session)
        });
        self.buffers = buffers;
        result
    }

    /// Returns a tensor's host allocation to the buffer pool for reuse.
    fn reclaim_buffer(&mut self, tensor: Tensor) {
        match tensor {
            Tensor::F32(t) => reclaim_typed(&mut self.buffers, t),
            Tensor::F64(t) => reclaim_typed(&mut self.buffers, t),
            Tensor::C32(t) => reclaim_typed(&mut self.buffers, t),
            Tensor::C64(t) => reclaim_typed(&mut self.buffers, t),
        }
    }
}
683
/// True when any dimension of `shape` is zero (the tensor holds no elements).
fn has_zero_dim(shape: &[usize]) -> bool {
    shape.iter().any(|&dim| dim == 0)
}
687
688fn batched_vector_rhs_shape(a: &Tensor, b: &Tensor) -> Option<Vec<usize>> {
689 if b.shape().len() == 1 {
690 return Some(vec![b.shape()[0], 1]);
691 }
692
693 let is_batched_vector_rhs = a.shape().len() == b.shape().len() + 1
694 && !b.shape().is_empty()
695 && b.shape()[0] == a.shape()[0]
696 && b.shape()[1..] == a.shape()[2..];
697 if !is_batched_vector_rhs {
698 return None;
699 }
700
701 let mut rhs_shape = vec![b.shape()[0], 1];
702 rhs_shape.extend_from_slice(&b.shape()[1..]);
703 Some(rhs_shape)
704}
705
706fn matmul_preserve_trailing_batch(
707 backend: &mut CpuBackend,
708 lhs: &Tensor,
709 rhs: &Tensor,
710) -> crate::Result<Tensor> {
711 let rank = lhs.shape().len();
712 let batch_dims: Vec<usize> = (2..rank).collect();
713 backend.dot_general(
714 lhs,
715 rhs,
716 &DotGeneralConfig {
717 lhs_contracting_dims: vec![1],
718 rhs_contracting_dims: vec![0],
719 lhs_batch_dims: batch_dims.clone(),
720 rhs_batch_dims: batch_dims,
721 lhs_rank: rank,
722 rhs_rank: rank,
723 },
724 )
725}
726
727pub(crate) fn reclaim_typed<T: PoolScalar>(pool: &mut BufferPool, typed: TypedTensor<T>) {
728 match typed.buffer {
729 Buffer::Host(data) => T::pool_release(pool, data),
730 Buffer::Backend(_) => {}
731 #[cfg(feature = "cubecl")]
732 Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
733 }
734}
735
736fn zeros_like_tensor(input: &Tensor) -> Tensor {
737 match input {
738 Tensor::F32(t) => Tensor::F32(TypedTensor::zeros(t.shape.clone())),
739 Tensor::F64(t) => Tensor::F64(TypedTensor::zeros(t.shape.clone())),
740 Tensor::C32(t) => Tensor::C32(TypedTensor::zeros(t.shape.clone())),
741 Tensor::C64(t) => Tensor::C64(TypedTensor::zeros(t.shape.clone())),
742 }
743}
744
/// Extracts a human-readable message from a panic payload.
///
/// `panic!` payloads are `&str` or `String` in practice; anything else falls
/// back to a generic marker.
fn panic_payload_message(payload: Box<dyn Any + Send>) -> String {
    // Try the owned `String` case first (takes ownership, no clone needed),
    // then fall back to `&str`, then to the generic message.
    match payload.downcast::<String>() {
        Ok(message) => *message,
        Err(payload) => payload
            .downcast_ref::<&str>()
            .map(|message| (*message).to_string())
            .unwrap_or_else(|| "backend panic".into()),
    }
}
754
755pub(crate) fn catch_backend_panic<R>(op: &'static str, f: impl FnOnce() -> R) -> crate::Result<R> {
756 catch_unwind(AssertUnwindSafe(f)).map_err(|payload| crate::Error::BackendFailure {
757 op,
758 message: panic_payload_message(payload),
759 })
760}
761
762pub(crate) fn unsupported_dtype(op: &'static str, dtype: crate::DType) -> crate::Error {
763 crate::Error::BackendFailure {
764 op,
765 message: format!("unsupported dtype {dtype:?}"),
766 }
767}
768
769fn validate_axis_role_conflicts(
770 op: &'static str,
771 first_role: &'static str,
772 first_axes: &[usize],
773 second_role: &'static str,
774 second_axes: &[usize],
775) -> crate::Result<()> {
776 for &axis in first_axes {
777 if second_axes.contains(&axis) {
778 return Err(crate::Error::AxisRoleConflict {
779 op,
780 axis,
781 first_role,
782 second_role,
783 });
784 }
785 }
786 Ok(())
787}
788
789fn validate_axis_list(
790 op: &'static str,
791 role: &'static str,
792 axes: &[usize],
793 rank: usize,
794) -> crate::Result<()> {
795 let mut seen = vec![false; rank];
796 for &axis in axes {
797 if axis >= rank {
798 return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
799 }
800 if seen[axis] {
801 return Err(crate::Error::DuplicateAxis { op, axis, role });
802 }
803 seen[axis] = true;
804 }
805 Ok(())
806}
807
808fn validate_semiring_batched_gemm_config<T>(
809 lhs: &TypedTensor<T>,
810 rhs: &TypedTensor<T>,
811 config: &DotGeneralConfig,
812) -> crate::Result<()> {
813 const OP: &str = "batched_gemm";
814
815 if config.lhs_contracting_dims.len() != config.rhs_contracting_dims.len() {
816 return Err(crate::Error::InvalidConfig {
817 op: OP,
818 message: "contracting dim count mismatch".into(),
819 });
820 }
821 if config.lhs_batch_dims.len() != config.rhs_batch_dims.len() {
822 return Err(crate::Error::InvalidConfig {
823 op: OP,
824 message: "batch dim count mismatch".into(),
825 });
826 }
827
828 let lhs_rank = lhs.shape.len();
829 let rhs_rank = rhs.shape.len();
830 validate_axis_list(
831 OP,
832 "lhs_contracting_dims",
833 &config.lhs_contracting_dims,
834 lhs_rank,
835 )?;
836 validate_axis_list(
837 OP,
838 "rhs_contracting_dims",
839 &config.rhs_contracting_dims,
840 rhs_rank,
841 )?;
842 validate_axis_list(OP, "lhs_batch_dims", &config.lhs_batch_dims, lhs_rank)?;
843 validate_axis_list(OP, "rhs_batch_dims", &config.rhs_batch_dims, rhs_rank)?;
844 validate_axis_role_conflicts(
845 OP,
846 "lhs_contracting_dims",
847 &config.lhs_contracting_dims,
848 "lhs_batch_dims",
849 &config.lhs_batch_dims,
850 )?;
851 validate_axis_role_conflicts(
852 OP,
853 "rhs_contracting_dims",
854 &config.rhs_contracting_dims,
855 "rhs_batch_dims",
856 &config.rhs_batch_dims,
857 )?;
858
859 for (&lhs_axis, &rhs_axis) in config
860 .lhs_contracting_dims
861 .iter()
862 .zip(&config.rhs_contracting_dims)
863 {
864 if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
865 return Err(crate::Error::ShapeMismatch {
866 op: OP,
867 lhs: vec![lhs.shape[lhs_axis]],
868 rhs: vec![rhs.shape[rhs_axis]],
869 });
870 }
871 }
872
873 for (&lhs_axis, &rhs_axis) in config.lhs_batch_dims.iter().zip(&config.rhs_batch_dims) {
874 if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
875 return Err(crate::Error::ShapeMismatch {
876 op: OP,
877 lhs: vec![lhs.shape[lhs_axis]],
878 rhs: vec![rhs.shape[rhs_axis]],
879 });
880 }
881 }
882
883 Ok(())
884}
885
impl<Alg: Semiring> SemiringBackend<Alg> for CpuBackend {
    /// Reference (naive) batched contraction over an arbitrary semiring.
    ///
    /// Output axis order is `[lhs_free..., rhs_free..., batch...]`. Every
    /// output element is an `Alg::add`-fold of `Alg::mul` products over the
    /// contracted axes — O(out_n * contract_n) scalar ops; correctness over
    /// speed.
    fn batched_gemm(
        &mut self,
        lhs: &TypedTensor<Alg::Scalar>,
        rhs: &TypedTensor<Alg::Scalar>,
        config: &DotGeneralConfig,
    ) -> crate::Result<TypedTensor<Alg::Scalar>> {
        validate_semiring_batched_gemm_config(lhs, rhs, config)?;
        Ok(self.install(|| {
            let lhs_rank = lhs.shape.len();
            let rhs_rank = rhs.shape.len();
            // Free axes = everything that is neither contracted nor batched.
            let lhs_free: Vec<usize> = (0..lhs_rank)
                .filter(|d| {
                    !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d)
                })
                .collect();
            let rhs_free: Vec<usize> = (0..rhs_rank)
                .filter(|d| {
                    !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d)
                })
                .collect();

            // Batch/contract extents are read from lhs; validation already
            // guaranteed rhs agrees along the paired axes.
            let batch_shape: Vec<usize> = config
                .lhs_batch_dims
                .iter()
                .map(|&d| lhs.shape[d])
                .collect();
            let lhs_free_shape: Vec<usize> = lhs_free.iter().map(|&d| lhs.shape[d]).collect();
            let rhs_free_shape: Vec<usize> = rhs_free.iter().map(|&d| rhs.shape[d]).collect();
            let contract_shape: Vec<usize> = config
                .lhs_contracting_dims
                .iter()
                .map(|&d| lhs.shape[d])
                .collect();

            // Output layout: lhs free axes, then rhs free axes, then batch axes.
            let mut out_shape = Vec::new();
            out_shape.extend_from_slice(&lhs_free_shape);
            out_shape.extend_from_slice(&rhs_free_shape);
            out_shape.extend_from_slice(&batch_shape);

            let out_n: usize = out_shape.iter().product();
            let contract_n: usize = contract_shape.iter().product();

            let mut result = TypedTensor::zeros(out_shape.clone());
            let n_lhs_free = lhs_free_shape.len();
            let n_rhs_free = rhs_free_shape.len();

            // Scratch multi-indices, reused across iterations to avoid
            // reallocating per element.
            let mut out_idx = vec![0usize; out_shape.len()];
            let mut lhs_idx = vec![0usize; lhs_rank];
            let mut rhs_idx = vec![0usize; rhs_rank];
            let mut contract_idx = vec![0usize; contract_shape.len()];

            for flat_out in 0..out_n {
                // Decode the flat output position into a multi-index, then
                // split it into its free/free/batch segments.
                flat_to_multi(flat_out, &out_shape, &mut out_idx);

                let lhs_free_vals = &out_idx[..n_lhs_free];
                let rhs_free_vals = &out_idx[n_lhs_free..n_lhs_free + n_rhs_free];
                let batch_vals = &out_idx[n_lhs_free + n_rhs_free..];

                // Scatter the batch and free coordinates into the operand
                // multi-indices at their original axis positions.
                for (bi, &ld) in config.lhs_batch_dims.iter().enumerate() {
                    lhs_idx[ld] = batch_vals[bi];
                }
                for (bi, &rd) in config.rhs_batch_dims.iter().enumerate() {
                    rhs_idx[rd] = batch_vals[bi];
                }
                for (fi, &ld) in lhs_free.iter().enumerate() {
                    lhs_idx[ld] = lhs_free_vals[fi];
                }
                for (fi, &rd) in rhs_free.iter().enumerate() {
                    rhs_idx[rd] = rhs_free_vals[fi];
                }

                // Fold over the contracted axes with the semiring's ops.
                let mut acc = Alg::zero();
                for flat_k in 0..contract_n {
                    flat_to_multi(flat_k, &contract_shape, &mut contract_idx);
                    for (ci, &ld) in config.lhs_contracting_dims.iter().enumerate() {
                        lhs_idx[ld] = contract_idx[ci];
                    }
                    for (ci, &rd) in config.rhs_contracting_dims.iter().enumerate() {
                        rhs_idx[rd] = contract_idx[ci];
                    }
                    acc = Alg::add(acc, Alg::mul(*lhs.get(&lhs_idx), *rhs.get(&rhs_idx)));
                }

                *result.get_mut(&out_idx) = acc;
            }

            result
        }))
    }
}
977
978impl Default for CpuBackend {
979 fn default() -> Self {
980 Self::new()
981 }
982}