// tenferro_tensor/cpu/backend.rs
1use std::any::Any;
2use std::panic::{catch_unwind, AssertUnwindSafe};
3use std::sync::Arc;
4
5use tenferro_algebra::Semiring;
6
7use crate::backend::{SemiringBackend, TensorBackend, TensorExec};
8use crate::buffer_pool::{BufferPool, PoolScalar};
9use crate::config::{
10    CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
11};
12use crate::types::flat_to_multi;
13use crate::validate::validate_nonsingular_u;
14use crate::{Buffer, Tensor, TypedTensor};
15
16use super::exec_session::CpuExecSession;
17use super::{analytic, elementwise, gemm, indexing, linalg, reduction, structural, CpuContext};
18
/// CPU execution backend.
///
/// Owns a shared [`CpuContext`] (the rayon thread pool every op runs in)
/// and a [`BufferPool`] used to recycle host allocations between operations.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::cpu::CpuBackend;
///
/// let backend = CpuBackend::new();
/// ```
pub struct CpuBackend {
    // Shared CPU execution context; `Arc`-cloned into closures that run
    // inside the rayon pool (see `dot_general`, `with_exec_session`).
    pub(crate) ctx: Arc<CpuContext>,
    // Retained typed host buffers, refilled via `reclaim_buffer` and drained
    // by ops that need scratch/output storage.
    pub(crate) buffers: BufferPool,
}
32
33impl CpuBackend {
34    /// Create a CPU backend using the environment-driven CPU context.
35    ///
36    /// # Examples
37    ///
38    /// ```ignore
39    /// use tenferro_tensor::cpu::CpuBackend;
40    ///
41    /// let backend = CpuBackend::new();
42    /// ```
43    pub fn new() -> Self {
44        Self::from_context(Arc::new(CpuContext::from_env()))
45    }
46
47    /// Try to create a CPU backend using `RAYON_NUM_THREADS`.
48    ///
49    /// # Examples
50    ///
51    /// ```ignore
52    /// use tenferro_tensor::cpu::CpuBackend;
53    ///
54    /// let backend = CpuBackend::try_new().unwrap();
55    /// let _ = backend.num_threads();
56    /// ```
57    pub fn try_new() -> crate::Result<Self> {
58        CpuContext::try_from_env().map(|ctx| Self::from_context(Arc::new(ctx)))
59    }
60
61    /// Create a CPU backend from an existing context.
62    ///
63    /// # Examples
64    ///
65    /// ```ignore
66    /// use std::sync::Arc;
67    /// use tenferro_tensor::cpu::{CpuBackend, CpuContext};
68    ///
69    /// let ctx = Arc::new(CpuContext::with_threads(2));
70    /// let backend = CpuBackend::from_context(ctx);
71    /// assert_eq!(backend.num_threads(), 2);
72    /// ```
73    pub fn from_context(ctx: Arc<CpuContext>) -> Self {
74        Self {
75            ctx,
76            buffers: BufferPool::new(),
77        }
78    }
79
80    /// Create a CPU backend with a custom thread count.
81    ///
82    /// # Examples
83    ///
84    /// ```ignore
85    /// use tenferro_tensor::cpu::CpuBackend;
86    ///
87    /// let backend = CpuBackend::with_threads(2);
88    /// assert_eq!(backend.num_threads(), 2);
89    /// ```
90    pub fn with_threads(num_threads: usize) -> Self {
91        assert!(num_threads >= 1, "thread count must be >= 1");
92        Self::from_context(Arc::new(CpuContext::with_threads(num_threads)))
93    }
94
95    /// Return the number of threads in this backend's CPU context.
96    ///
97    /// # Examples
98    ///
99    /// ```ignore
100    /// use tenferro_tensor::cpu::CpuBackend;
101    ///
102    /// let backend = CpuBackend::with_threads(2);
103    /// assert_eq!(backend.num_threads(), 2);
104    /// ```
105    pub fn num_threads(&self) -> usize {
106        self.ctx.num_threads()
107    }
108
109    /// Number of retained typed host buffers currently held by this backend.
110    ///
111    /// # Examples
112    ///
113    /// ```ignore
114    /// use tenferro_tensor::cpu::CpuBackend;
115    ///
116    /// let backend = CpuBackend::new();
117    /// assert_eq!(backend.buffer_pool_len(), 0);
118    /// ```
119    pub fn buffer_pool_len(&self) -> usize {
120        self.buffers.len()
121    }
122
123    /// Run a closure inside this backend's shared rayon thread pool.
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// use tenferro_tensor::cpu::CpuBackend;
129    ///
130    /// let backend = CpuBackend::with_threads(1);
131    /// let value = backend.install(|| 1 + 1);
132    /// assert_eq!(value, 2);
133    /// ```
134    pub fn install<R>(&self, op: impl FnOnce() -> R + Send) -> R
135    where
136        R: Send,
137    {
138        self.ctx.install(op)
139    }
140
141    fn install_with_pool<R>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R
142    where
143        R: Send,
144    {
145        let mut buffers = std::mem::take(&mut self.buffers);
146        let (result, buffers) = self.ctx.install(|| {
147            let result = op(&mut buffers);
148            (result, buffers)
149        });
150        self.buffers = buffers;
151        result
152    }
153}
154
// `TensorBackend` implementation for the CPU.
//
// Convention: every op is dispatched through `self.install(..)` (or
// `install_with_pool` when the op needs scratch buffers), so the kernel runs
// inside this backend's shared rayon thread pool. Decomposition kernels are
// feature-gated: `cpu-faer` variants take the `CpuContext`, `cpu-blas`
// variants do not.
//
// NOTE(review): for each feature-gated op the `cpu-faer` and `cpu-blas`
// match arms cover the same patterns; this assumes the two features are not
// enabled together (both enabled would make the later arms unreachable) —
// confirm the feature gating in Cargo.toml.
impl TensorBackend for CpuBackend {
    // ---- Elementwise arithmetic (kernels in `elementwise`) ----

    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::add(lhs, rhs))
    }

    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::mul(lhs, rhs))
    }

    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::neg(input))
    }

    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::conj(input))
    }

    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::div(lhs, rhs))
    }

    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::abs(input))
    }

    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::sign(input))
    }

    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::maximum(lhs, rhs))
    }

    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::minimum(lhs, rhs))
    }

    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor> {
        self.install(|| elementwise::compare(lhs, rhs, dir))
    }

    fn select(
        &mut self,
        pred: &Tensor,
        on_true: &Tensor,
        on_false: &Tensor,
    ) -> crate::Result<Tensor> {
        self.install(|| elementwise::select(pred, on_true, on_false))
    }

    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor> {
        self.install(|| elementwise::clamp(input, lower, upper))
    }

    // ---- Analytic (transcendental) functions (kernels in `analytic`) ----

    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::exp(input))
    }

    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::log(input))
    }

    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::sin(input))
    }

    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::cos(input))
    }

    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::tanh(input))
    }

    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::sqrt(input))
    }

    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::rsqrt(input))
    }

    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::pow(lhs, rhs))
    }

    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::expm1(input))
    }

    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        self.install(|| analytic::log1p(input))
    }

    // ---- Structural ops (kernels in `structural`) ----

    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
        self.install(|| structural::transpose(input, perm))
    }

    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
        self.install(|| structural::reshape(input, shape))
    }

    fn broadcast_in_dim(
        &mut self,
        input: &Tensor,
        shape: &[usize],
        dims: &[usize],
    ) -> crate::Result<Tensor> {
        self.install(|| structural::broadcast_in_dim(input, shape, dims))
    }

    // `structural::convert` is infallible (returns a `Tensor`, not a
    // `Result`), hence the explicit `Ok(..)` wrapper here.
    fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor> {
        Ok(self.install(|| structural::convert(input, to)))
    }

    fn extract_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor> {
        self.install(|| structural::extract_diagonal(input, axis_a, axis_b))
    }

    fn embed_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor> {
        self.install(|| structural::embed_diagonal(input, axis_a, axis_b))
    }

    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
        self.install(|| structural::tril(input, k))
    }

    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
        self.install(|| structural::triu(input, k))
    }

    // ---- Reductions over a set of axes (kernels in `reduction`) ----

    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_sum(input, axes))
    }

    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_prod(input, axes))
    }

    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_max(input, axes))
    }

    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_min(input, axes))
    }

    // Generalized batched matrix multiply. Uses the buffer pool for GEMM
    // workspace, so it goes through `install_with_pool`; `ctx` is Arc-cloned
    // up front because the closure cannot borrow `self` while the pool is
    // detached. Both operands must share the same dtype (F32/F64/C32/C64);
    // anything else is a `DTypeMismatch`.
    fn dot_general(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match (lhs, rhs) {
            #[cfg(feature = "cpu-faer")]
            (Tensor::F32(a), Tensor::F32(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::F32)
            }
            #[cfg(feature = "cpu-faer")]
            (Tensor::F64(a), Tensor::F64(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::F64)
            }
            #[cfg(feature = "cpu-faer")]
            (Tensor::C32(a), Tensor::C32(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::C32)
            }
            #[cfg(feature = "cpu-faer")]
            (Tensor::C64(a), Tensor::C64(b)) => {
                gemm::dot_general(buffers, ctx.as_ref(), a, b, config).map(Tensor::C64)
            }
            #[cfg(feature = "cpu-blas")]
            (Tensor::F32(a), Tensor::F32(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::F32)
            }
            #[cfg(feature = "cpu-blas")]
            (Tensor::F64(a), Tensor::F64(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::F64)
            }
            #[cfg(feature = "cpu-blas")]
            (Tensor::C32(a), Tensor::C32(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::C32)
            }
            #[cfg(feature = "cpu-blas")]
            (Tensor::C64(a), Tensor::C64(b)) => {
                gemm::dot_general(buffers, a, b, config).map(Tensor::C64)
            }
            // Mixed (or otherwise unhandled) dtype combinations.
            _ => Err(crate::Error::DTypeMismatch {
                op: "dot_general",
                lhs: lhs.dtype(),
                rhs: rhs.dtype(),
            }),
        })
    }

    // ---- Indexing ops (kernels in `indexing`) ----
    //
    // The kernels below that are wrapped in `catch_backend_panic` may panic
    // internally (e.g. on out-of-range indices); the wrapper converts such
    // panics into `Error::BackendFailure` instead of unwinding the caller.

    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor> {
        self.install(|| {
            catch_backend_panic("gather", || {
                indexing::gather(operand, start_indices, config)
            })
        })
    }

    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor> {
        self.install(|| {
            catch_backend_panic("scatter", || {
                indexing::scatter(operand, scatter_indices, updates, config)
            })
        })
    }

    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor> {
        self.install(|| indexing::try_slice(input, config))
    }

    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor> {
        self.install(|| {
            catch_backend_panic("dynamic_slice", || {
                indexing::dynamic_slice(input, starts, slice_sizes)
            })
        })
    }

    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
        self.install(|| indexing::try_pad(input, config))
    }

    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
        self.install(|| indexing::try_concatenate(inputs, axis))
    }

    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| catch_backend_panic("reverse", || indexing::reverse(input, axes)))
    }

    // ---- Decompositions (kernels in `linalg`, feature-gated) ----

    fn cholesky(&mut self, input: &Tensor) -> crate::Result<Tensor> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => {
                // The faer kernel is itself fallible, so `catch_backend_panic`
                // yields a nested `Result`; `.and_then(|result| result)`
                // flattens it.
                catch_backend_panic("cholesky", || linalg::cholesky(ctx.as_ref(), buffers, t))
                    .and_then(|result| result)
                    .map(Tensor::F64)
            }
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => {
                catch_backend_panic("cholesky", || linalg::cholesky(buffers, t)).map(Tensor::F64)
            }
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => {
                catch_backend_panic("cholesky", || linalg::cholesky(ctx.as_ref(), buffers, t))
                    .and_then(|result| result)
                    .map(Tensor::C64)
            }
            _ => Err(unsupported_dtype("cholesky", input.dtype())),
        })
    }

    fn triangular_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> crate::Result<Tensor> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match (a, b) {
            #[cfg(feature = "cpu-faer")]
            (Tensor::F64(a), Tensor::F64(b)) => catch_backend_panic("triangular_solve", || {
                Tensor::F64(linalg::triangular_solve(
                    ctx.as_ref(),
                    buffers,
                    a,
                    b,
                    left_side,
                    lower,
                    transpose_a,
                    unit_diagonal,
                ))
            }),
            #[cfg(feature = "cpu-blas")]
            (Tensor::F64(a), Tensor::F64(b)) => catch_backend_panic("triangular_solve", || {
                Tensor::F64(linalg::triangular_solve(
                    buffers,
                    a,
                    b,
                    left_side,
                    lower,
                    transpose_a,
                    unit_diagonal,
                ))
            }),
            #[cfg(feature = "cpu-faer")]
            (Tensor::C64(a), Tensor::C64(b)) => catch_backend_panic("triangular_solve", || {
                Tensor::C64(linalg::triangular_solve(
                    ctx.as_ref(),
                    buffers,
                    a,
                    b,
                    left_side,
                    lower,
                    transpose_a,
                    unit_diagonal,
                ))
            }),
            _ => {
                // Distinguish "operands disagree" from "dtype not supported
                // by any kernel" so callers get a precise error.
                if a.dtype() != b.dtype() {
                    Err(crate::Error::DTypeMismatch {
                        op: "triangular_solve",
                        lhs: a.dtype(),
                        rhs: b.dtype(),
                    })
                } else {
                    Err(unsupported_dtype("triangular_solve", a.dtype()))
                }
            }
        })
    }

    // Each factorization below returns its factors as a `Vec` of typed
    // tensors, re-wrapped into the matching `Tensor` variant.

    fn lu(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("lu", || {
                linalg::lu(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("lu", || {
                linalg::lu(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("lu", || {
                linalg::lu(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("lu", input.dtype())),
        })
    }

    fn svd(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("svd", || {
                linalg::svd(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("svd", || {
                linalg::svd(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("svd", || {
                linalg::svd(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("svd", input.dtype())),
        })
    }

    fn qr(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("qr", || {
                linalg::qr(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("qr", || {
                linalg::qr(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("qr", || {
                linalg::qr(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("qr", input.dtype())),
        })
    }

    fn eigh(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| match input {
            #[cfg(feature = "cpu-faer")]
            Tensor::F64(t) => catch_backend_panic("eigh", || {
                linalg::eigh(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-blas")]
            Tensor::F64(t) => catch_backend_panic("eigh", || {
                linalg::eigh(buffers, t)
                    .into_iter()
                    .map(Tensor::F64)
                    .collect()
            }),
            #[cfg(feature = "cpu-faer")]
            Tensor::C64(t) => catch_backend_panic("eigh", || {
                linalg::eigh(ctx.as_ref(), buffers, t)
                    .into_iter()
                    .map(Tensor::C64)
                    .collect()
            }),
            _ => Err(unsupported_dtype("eigh", input.dtype())),
        })
    }

    // `linalg::eig` handles dtype dispatch itself (takes the untyped
    // `Tensor`), so a single feature-gated call suffices here.
    fn eig(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>> {
        let ctx = Arc::clone(&self.ctx);
        self.install_with_pool(|buffers| {
            catch_backend_panic("eig", || {
                #[cfg(feature = "cpu-faer")]
                {
                    linalg::eig(ctx.as_ref(), buffers, input)
                }
                #[cfg(feature = "cpu-blas")]
                {
                    linalg::eig(buffers, input)
                }
            })
        })
    }

    // Solve `a x = b` via LU with partial pivoting:
    //   1. Empty operands short-circuit to a zero tensor shaped like `b`.
    //   2. A vector (or batched-vector) rhs is reshaped to a matrix rhs so the
    //      triangular solves see 2-D pages; the original shape is restored at
    //      the end.
    //   3. Factor `a` into `[p, l, u]` (expected order from `lu`), reject a
    //      singular `u`, then solve L z = P·b (unit lower) followed by
    //      U x = z (upper).
    fn solve(&mut self, a: &Tensor, b: &Tensor) -> crate::Result<Tensor> {
        if has_zero_dim(a.shape()) || has_zero_dim(b.shape()) {
            return Ok(zeros_like_tensor(b));
        }

        let (rhs, restore_shape) = if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
            (
                self.reshape(b, &matrix_rhs_shape)?,
                Some(b.shape().to_vec()),
            )
        } else {
            (b.clone(), None)
        };

        let outputs = self.lu(a)?;
        let p = &outputs[0];
        let l = &outputs[1];
        let u = &outputs[2];
        validate_nonsingular_u(u)?;

        // Apply the permutation, then forward- and back-substitute.
        let pb = matmul_preserve_trailing_batch(self, p, &rhs)?;
        let z = self.triangular_solve(l, &pb, true, true, false, true)?;
        let x = self.triangular_solve(u, &z, true, false, false, false)?;
        if let Some(shape) = restore_shape {
            self.reshape(&x, &shape)
        } else {
            Ok(x)
        }
    }

    // Run `f` against a `CpuExecSession` that borrows this backend's context
    // and buffer pool. The pool is detached for the duration (the session
    // holds `&mut` to it) and reattached before returning.
    fn with_exec_session<R: Send>(&mut self, f: impl FnOnce(&mut dyn TensorExec) -> R + Send) -> R {
        let mut buffers = std::mem::take(&mut self.buffers);
        let ctx = Arc::clone(&self.ctx);
        let result = ctx.install(|| {
            let mut session = CpuExecSession {
                ctx: ctx.as_ref(),
                buffers: &mut buffers,
            };
            f(&mut session)
        });
        self.buffers = buffers;
        result
    }

    // Return a consumed tensor's host storage to the buffer pool so future
    // ops of the same scalar type can reuse the allocation.
    fn reclaim_buffer(&mut self, tensor: Tensor) {
        match tensor {
            Tensor::F32(t) => reclaim_typed(&mut self.buffers, t),
            Tensor::F64(t) => reclaim_typed(&mut self.buffers, t),
            Tensor::C32(t) => reclaim_typed(&mut self.buffers, t),
            Tensor::C64(t) => reclaim_typed(&mut self.buffers, t),
        }
    }
}
683
/// True when any extent of `shape` is zero (i.e. the tensor has no elements).
fn has_zero_dim(shape: &[usize]) -> bool {
    shape.iter().any(|&dim| dim == 0)
}
687
688fn batched_vector_rhs_shape(a: &Tensor, b: &Tensor) -> Option<Vec<usize>> {
689    if b.shape().len() == 1 {
690        return Some(vec![b.shape()[0], 1]);
691    }
692
693    let is_batched_vector_rhs = a.shape().len() == b.shape().len() + 1
694        && !b.shape().is_empty()
695        && b.shape()[0] == a.shape()[0]
696        && b.shape()[1..] == a.shape()[2..];
697    if !is_batched_vector_rhs {
698        return None;
699    }
700
701    let mut rhs_shape = vec![b.shape()[0], 1];
702    rhs_shape.extend_from_slice(&b.shape()[1..]);
703    Some(rhs_shape)
704}
705
706fn matmul_preserve_trailing_batch(
707    backend: &mut CpuBackend,
708    lhs: &Tensor,
709    rhs: &Tensor,
710) -> crate::Result<Tensor> {
711    let rank = lhs.shape().len();
712    let batch_dims: Vec<usize> = (2..rank).collect();
713    backend.dot_general(
714        lhs,
715        rhs,
716        &DotGeneralConfig {
717            lhs_contracting_dims: vec![1],
718            rhs_contracting_dims: vec![0],
719            lhs_batch_dims: batch_dims.clone(),
720            rhs_batch_dims: batch_dims,
721            lhs_rank: rank,
722            rhs_rank: rank,
723        },
724    )
725}
726
/// Return a typed tensor's storage to the buffer pool, when possible.
///
/// Only `Buffer::Host` data is recycled; backend-owned buffers are simply
/// dropped. A GPU (`Buffer::Cubecl`) tensor reaching this CPU path is a
/// programming error and panics with a remediation hint.
pub(crate) fn reclaim_typed<T: PoolScalar>(pool: &mut BufferPool, typed: TypedTensor<T>) {
    match typed.buffer {
        Buffer::Host(data) => T::pool_release(pool, data),
        Buffer::Backend(_) => {}
        #[cfg(feature = "cubecl")]
        Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
    }
}
735
/// Build an all-zeros tensor with the same dtype and shape as `input`.
/// Used by `solve` to short-circuit empty (zero-dim) systems.
fn zeros_like_tensor(input: &Tensor) -> Tensor {
    match input {
        Tensor::F32(t) => Tensor::F32(TypedTensor::zeros(t.shape.clone())),
        Tensor::F64(t) => Tensor::F64(TypedTensor::zeros(t.shape.clone())),
        Tensor::C32(t) => Tensor::C32(TypedTensor::zeros(t.shape.clone())),
        Tensor::C64(t) => Tensor::C64(TypedTensor::zeros(t.shape.clone())),
    }
}
744
/// Extract a human-readable message from a panic payload.
///
/// `panic!` payloads are `&str` for literal messages and `String` for
/// formatted ones; anything else collapses to the generic "backend panic".
fn panic_payload_message(payload: Box<dyn Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|msg| (*msg).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("backend panic"))
}
754
755pub(crate) fn catch_backend_panic<R>(op: &'static str, f: impl FnOnce() -> R) -> crate::Result<R> {
756    catch_unwind(AssertUnwindSafe(f)).map_err(|payload| crate::Error::BackendFailure {
757        op,
758        message: panic_payload_message(payload),
759    })
760}
761
762pub(crate) fn unsupported_dtype(op: &'static str, dtype: crate::DType) -> crate::Error {
763    crate::Error::BackendFailure {
764        op,
765        message: format!("unsupported dtype {dtype:?}"),
766    }
767}
768
769fn validate_axis_role_conflicts(
770    op: &'static str,
771    first_role: &'static str,
772    first_axes: &[usize],
773    second_role: &'static str,
774    second_axes: &[usize],
775) -> crate::Result<()> {
776    for &axis in first_axes {
777        if second_axes.contains(&axis) {
778            return Err(crate::Error::AxisRoleConflict {
779                op,
780                axis,
781                first_role,
782                second_role,
783            });
784        }
785    }
786    Ok(())
787}
788
789fn validate_axis_list(
790    op: &'static str,
791    role: &'static str,
792    axes: &[usize],
793    rank: usize,
794) -> crate::Result<()> {
795    let mut seen = vec![false; rank];
796    for &axis in axes {
797        if axis >= rank {
798            return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
799        }
800        if seen[axis] {
801            return Err(crate::Error::DuplicateAxis { op, axis, role });
802        }
803        seen[axis] = true;
804    }
805    Ok(())
806}
807
/// Validate a `DotGeneralConfig` against the operand shapes before running
/// the generic semiring GEMM.
///
/// Checks, in order:
/// 1. lhs/rhs contracting-dim lists have equal length; likewise batch dims.
/// 2. Each of the four axis lists is in-bounds and duplicate-free for its
///    operand's rank.
/// 3. No axis serves as both a contracting and a batch dim on the same side.
/// 4. Paired contracting axes and paired batch axes have matching extents
///    (reported as a single-extent `ShapeMismatch` on failure).
///
/// Returns the first violation found, or `Ok(())` for a consistent config.
fn validate_semiring_batched_gemm_config<T>(
    lhs: &TypedTensor<T>,
    rhs: &TypedTensor<T>,
    config: &DotGeneralConfig,
) -> crate::Result<()> {
    const OP: &str = "batched_gemm";

    if config.lhs_contracting_dims.len() != config.rhs_contracting_dims.len() {
        return Err(crate::Error::InvalidConfig {
            op: OP,
            message: "contracting dim count mismatch".into(),
        });
    }
    if config.lhs_batch_dims.len() != config.rhs_batch_dims.len() {
        return Err(crate::Error::InvalidConfig {
            op: OP,
            message: "batch dim count mismatch".into(),
        });
    }

    let lhs_rank = lhs.shape.len();
    let rhs_rank = rhs.shape.len();
    validate_axis_list(
        OP,
        "lhs_contracting_dims",
        &config.lhs_contracting_dims,
        lhs_rank,
    )?;
    validate_axis_list(
        OP,
        "rhs_contracting_dims",
        &config.rhs_contracting_dims,
        rhs_rank,
    )?;
    validate_axis_list(OP, "lhs_batch_dims", &config.lhs_batch_dims, lhs_rank)?;
    validate_axis_list(OP, "rhs_batch_dims", &config.rhs_batch_dims, rhs_rank)?;
    validate_axis_role_conflicts(
        OP,
        "lhs_contracting_dims",
        &config.lhs_contracting_dims,
        "lhs_batch_dims",
        &config.lhs_batch_dims,
    )?;
    validate_axis_role_conflicts(
        OP,
        "rhs_contracting_dims",
        &config.rhs_contracting_dims,
        "rhs_batch_dims",
        &config.rhs_batch_dims,
    )?;

    // Paired contracting axes must agree in extent.
    for (&lhs_axis, &rhs_axis) in config
        .lhs_contracting_dims
        .iter()
        .zip(&config.rhs_contracting_dims)
    {
        if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
            return Err(crate::Error::ShapeMismatch {
                op: OP,
                lhs: vec![lhs.shape[lhs_axis]],
                rhs: vec![rhs.shape[rhs_axis]],
            });
        }
    }

    // Paired batch axes must agree in extent.
    for (&lhs_axis, &rhs_axis) in config.lhs_batch_dims.iter().zip(&config.rhs_batch_dims) {
        if lhs.shape[lhs_axis] != rhs.shape[rhs_axis] {
            return Err(crate::Error::ShapeMismatch {
                op: OP,
                lhs: vec![lhs.shape[lhs_axis]],
                rhs: vec![rhs.shape[rhs_axis]],
            });
        }
    }

    Ok(())
}
885
// Generic semiring GEMM: a reference implementation that contracts with the
// algebra's `add`/`mul`/`zero` instead of hardware arithmetic. Naive nested
// loops — O(out_elements x contract_elements) — but works for any `Semiring`.
impl<Alg: Semiring> SemiringBackend<Alg> for CpuBackend {
    /// Batched tensor contraction over the semiring `Alg`.
    ///
    /// Output axis layout (as built below): lhs free dims, then rhs free
    /// dims, then batch dims LAST — note this differs from conventions that
    /// put batch dims first.
    fn batched_gemm(
        &mut self,
        lhs: &TypedTensor<Alg::Scalar>,
        rhs: &TypedTensor<Alg::Scalar>,
        config: &DotGeneralConfig,
    ) -> crate::Result<TypedTensor<Alg::Scalar>> {
        // All config errors are surfaced up front; the kernel itself is
        // infallible afterwards (hence the `Ok(..)` wrapper).
        validate_semiring_batched_gemm_config(lhs, rhs, config)?;
        Ok(self.install(|| {
            let lhs_rank = lhs.shape.len();
            let rhs_rank = rhs.shape.len();
            // Free dims = everything that is neither contracted nor batched.
            let lhs_free: Vec<usize> = (0..lhs_rank)
                .filter(|d| {
                    !config.lhs_contracting_dims.contains(d) && !config.lhs_batch_dims.contains(d)
                })
                .collect();
            let rhs_free: Vec<usize> = (0..rhs_rank)
                .filter(|d| {
                    !config.rhs_contracting_dims.contains(d) && !config.rhs_batch_dims.contains(d)
                })
                .collect();

            // Extents for each axis group (batch/contract extents were
            // validated to match between operands, so lhs is authoritative).
            let batch_shape: Vec<usize> = config
                .lhs_batch_dims
                .iter()
                .map(|&d| lhs.shape[d])
                .collect();
            let lhs_free_shape: Vec<usize> = lhs_free.iter().map(|&d| lhs.shape[d]).collect();
            let rhs_free_shape: Vec<usize> = rhs_free.iter().map(|&d| rhs.shape[d]).collect();
            let contract_shape: Vec<usize> = config
                .lhs_contracting_dims
                .iter()
                .map(|&d| lhs.shape[d])
                .collect();

            // Output layout: [lhs_free..., rhs_free..., batch...].
            let mut out_shape = Vec::new();
            out_shape.extend_from_slice(&lhs_free_shape);
            out_shape.extend_from_slice(&rhs_free_shape);
            out_shape.extend_from_slice(&batch_shape);

            let out_n: usize = out_shape.iter().product();
            let contract_n: usize = contract_shape.iter().product();

            let mut result = TypedTensor::zeros(out_shape.clone());
            let n_lhs_free = lhs_free_shape.len();
            let n_rhs_free = rhs_free_shape.len();

            // Scratch index vectors, reused across iterations to avoid
            // per-element allocation.
            let mut out_idx = vec![0usize; out_shape.len()];
            let mut lhs_idx = vec![0usize; lhs_rank];
            let mut rhs_idx = vec![0usize; rhs_rank];
            let mut contract_idx = vec![0usize; contract_shape.len()];

            for flat_out in 0..out_n {
                flat_to_multi(flat_out, &out_shape, &mut out_idx);

                // Slice the output multi-index back into its three groups.
                let lhs_free_vals = &out_idx[..n_lhs_free];
                let rhs_free_vals = &out_idx[n_lhs_free..n_lhs_free + n_rhs_free];
                let batch_vals = &out_idx[n_lhs_free + n_rhs_free..];

                // Scatter the group values into each operand's full index.
                for (bi, &ld) in config.lhs_batch_dims.iter().enumerate() {
                    lhs_idx[ld] = batch_vals[bi];
                }
                for (bi, &rd) in config.rhs_batch_dims.iter().enumerate() {
                    rhs_idx[rd] = batch_vals[bi];
                }
                for (fi, &ld) in lhs_free.iter().enumerate() {
                    lhs_idx[ld] = lhs_free_vals[fi];
                }
                for (fi, &rd) in rhs_free.iter().enumerate() {
                    rhs_idx[rd] = rhs_free_vals[fi];
                }

                // Inner contraction: fold with the semiring's add/mul,
                // starting from its additive identity.
                let mut acc = Alg::zero();
                for flat_k in 0..contract_n {
                    flat_to_multi(flat_k, &contract_shape, &mut contract_idx);
                    for (ci, &ld) in config.lhs_contracting_dims.iter().enumerate() {
                        lhs_idx[ld] = contract_idx[ci];
                    }
                    for (ci, &rd) in config.rhs_contracting_dims.iter().enumerate() {
                        rhs_idx[rd] = contract_idx[ci];
                    }
                    acc = Alg::add(acc, Alg::mul(*lhs.get(&lhs_idx), *rhs.get(&rhs_idx)));
                }

                *result.get_mut(&out_idx) = acc;
            }

            result
        }))
    }
}
977
978impl Default for CpuBackend {
979    fn default() -> Self {
980        Self::new()
981    }
982}