1use std::cmp::Reverse;
2use std::collections::HashMap;
3use std::env;
4use std::fmt;
5use std::sync::{Arc, Mutex, OnceLock};
6use std::time::{Duration, Instant};
7
8use crate::buffer_pool::{BufferPool, BufferPoolStats, PoolScalar};
9use crate::{
10 Buffer, CacheStats, Tensor, TensorRank, TensorRead, TensorValue, TypedTensor, TypedTensorView,
11 TypedTensorViewMut,
12};
13use tenferro_tensor::{
14 BackendCachedDot, BackendRuntimeCache, BackendSession, BackendSessionHost, TensorAnalytic,
15 TensorBackend, TensorBuffer, TensorDeviceTransfer, TensorDot, TensorElementwise, TensorFusion,
16 TensorIndexing, TensorReduction, TensorStructural, TensorViewCanonicalization,
17};
18use tenferro_tensor::{
19 CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
20};
21
22use super::exec_session::CpuExecSession;
23use super::{
24 analytic, elementwise, gemm, indexing, materialize_tensor_read, reduction, structural,
25 CpuContext,
26};
27
/// Accumulated timing for a single named profiling section of a CPU session.
#[derive(Clone, Debug, Default)]
struct CpuSessionProfileEntry {
    /// How many times this section has been recorded.
    calls: usize,
    /// Sum of all elapsed durations recorded for this section.
    total_time: Duration,
}
33
/// Returns `true` when the `TENFERRO_PROFILE_CPU_SESSION` environment variable is set.
///
/// Only presence matters (a value that is not valid Unicode counts as unset); the lookup is
/// performed once per process and cached in a `OnceLock`.
fn cpu_session_profile_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    let enabled = ENABLED.get_or_init(|| matches!(env::var("TENFERRO_PROFILE_CPU_SESSION"), Ok(_)));
    *enabled
}
38
/// Reads the reporting period from `TENFERRO_PROFILE_CPU_SESSION_PRINT_EVERY`.
///
/// Yields `Some(n)` only for a positive integer `n`; a missing, unparsable or zero value
/// disables periodic printing. The value is computed once and cached for the process lifetime.
fn cpu_session_profile_print_every() -> Option<usize> {
    static PRINT_EVERY: OnceLock<Option<usize>> = OnceLock::new();
    *PRINT_EVERY.get_or_init(|| {
        let raw = env::var("TENFERRO_PROFILE_CPU_SESSION_PRINT_EVERY").ok()?;
        match raw.parse::<usize>() {
            Ok(0) | Err(_) => None,
            Ok(period) => Some(period),
        }
    })
}
48
49fn cpu_session_profile_state() -> &'static Mutex<HashMap<&'static str, CpuSessionProfileEntry>> {
50 static STATE: OnceLock<Mutex<HashMap<&'static str, CpuSessionProfileEntry>>> = OnceLock::new();
51 STATE.get_or_init(|| Mutex::new(HashMap::new()))
52}
53
54fn record_cpu_session_profile(section: &'static str, elapsed: Duration) {
55 if !cpu_session_profile_enabled() {
56 return;
57 }
58 let Ok(mut state) = cpu_session_profile_state().lock() else {
59 return;
60 };
61 let entry = state.entry(section).or_default();
62 entry.calls += 1;
63 entry.total_time += elapsed;
64}
65
66fn profile_cpu_session_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
67 if !cpu_session_profile_enabled() {
68 return f();
69 }
70 let started = Instant::now();
71 let result = f();
72 record_cpu_session_profile(section, started.elapsed());
73 result
74}
75
76fn maybe_print_cpu_session_profile() {
77 let Some(print_every) = cpu_session_profile_print_every() else {
78 return;
79 };
80 let should_print = {
81 let Ok(state) = cpu_session_profile_state().lock() else {
82 return;
83 };
84 state
85 .get("with_backend_session_cached.total")
86 .is_some_and(|entry| entry.calls % print_every == 0)
87 };
88 if !should_print {
89 return;
90 }
91 let mut entries = {
92 let Ok(mut state) = cpu_session_profile_state().lock() else {
93 return;
94 };
95 let entries = state
96 .iter()
97 .map(|(section, entry)| (*section, entry.clone()))
98 .collect::<Vec<_>>();
99 state.clear();
100 entries
101 };
102 entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
103 eprintln!("=== tenferro CPU session profile ===");
104 for (section, entry) in entries {
105 eprintln!(
106 "{section}: calls={} total={:.6}ms per_call={:.3}us",
107 entry.calls,
108 entry.total_time.as_secs_f64() * 1.0e3,
109 entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64,
110 );
111 }
112}
113
114struct BufferPoolLoan<'a> {
115 target: &'a mut BufferPool,
116 buffers: Option<BufferPool>,
117}
118
119impl<'a> BufferPoolLoan<'a> {
120 fn new(target: &'a mut BufferPool) -> Self {
121 Self {
122 buffers: Some(std::mem::take(target)),
123 target,
124 }
125 }
126
127 fn get_mut(&mut self) -> &mut BufferPool {
128 self.buffers
129 .as_mut()
130 .expect("buffer pool loan already restored")
131 }
132}
133
134impl Drop for BufferPoolLoan<'_> {
135 fn drop(&mut self) {
136 if let Some(buffers) = self.buffers.take() {
137 *self.target = buffers;
138 }
139 }
140}
141
142#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
157pub enum CpuBackendKind {
158 Faer,
160 Blas,
162}
163
164impl CpuBackendKind {
165 pub fn default_compiled() -> Self {
179 #[cfg(feature = "cpu-blas")]
180 {
181 Self::Blas
182 }
183 #[cfg(all(not(feature = "cpu-blas"), feature = "cpu-faer"))]
184 {
185 Self::Faer
186 }
187 }
188
189 #[allow(dead_code)]
192 pub(crate) fn name(self) -> &'static str {
193 match self {
194 Self::Faer => "faer",
195 Self::Blas => "blas",
196 }
197 }
198}
199
200fn ensure_cpu_backend_kind_available(kind: CpuBackendKind, op: &'static str) -> crate::Result<()> {
201 let _ = op;
202 match kind {
203 CpuBackendKind::Faer => {
204 #[cfg(feature = "cpu-faer")]
205 {
206 Ok(())
207 }
208 #[cfg(not(feature = "cpu-faer"))]
209 {
210 Err(crate::Error::InvalidConfig {
211 op,
212 message: "CpuBackendKind::Faer requires the cpu-faer feature".to_string(),
213 })
214 }
215 }
216 CpuBackendKind::Blas => {
217 #[cfg(feature = "cpu-blas")]
218 {
219 Ok(())
220 }
221 #[cfg(not(feature = "cpu-blas"))]
222 {
223 Err(crate::Error::InvalidConfig {
224 op,
225 message: "CpuBackendKind::Blas requires the cpu-blas feature".to_string(),
226 })
227 }
228 }
229 }
230}
231
232#[allow(dead_code)]
235pub(super) fn unavailable_cpu_backend_kind(kind: CpuBackendKind, op: &'static str) -> crate::Error {
236 crate::Error::InvalidConfig {
237 op,
238 message: format!("CPU backend kind {} is not compiled in", kind.name()),
239 }
240}
241
/// CPU tensor backend.
///
/// Bundles a shared `CpuContext` (configured with a thread count; most operations run through
/// `ctx.install`), a `BufferPool` reused across operations, and the linear-algebra backend
/// kind that selects the `dot_general` implementation.
pub struct CpuBackend {
    /// Shared execution context; cloned per call and used to `install` work.
    pub(crate) ctx: Arc<CpuContext>,
    /// Buffer pool lent to kernels for the duration of each call via `BufferPoolLoan`.
    pub(crate) buffers: BufferPool,
    /// Selected CPU linalg implementation (faer or BLAS).
    kind: CpuBackendKind,
}
256
257impl fmt::Debug for CpuBackend {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 f.debug_struct("CpuBackend")
260 .field("kind", &self.kind)
261 .field("num_threads", &self.num_threads())
262 .field("buffer_pool_cache_stats", &self.buffer_pool_cache_stats())
263 .field("buffer_pool_limit_bytes", &self.buffer_pool_limit_bytes())
264 .finish_non_exhaustive()
265 }
266}
267
268impl CpuBackend {
269 pub fn new() -> Self {
279 Self::from_context(Arc::new(CpuContext::from_env()))
280 }
281
282 pub fn with_kind(kind: CpuBackendKind) -> crate::Result<Self> {
293 Self::try_from_context_with_kind(Arc::new(CpuContext::from_env()), kind)
294 }
295
296 pub fn try_new() -> crate::Result<Self> {
308 CpuContext::try_from_env().map(|ctx| Self::from_context(Arc::new(ctx)))
309 }
310
311 pub fn from_context(ctx: Arc<CpuContext>) -> Self {
324 Self {
325 ctx,
326 buffers: BufferPool::new(),
327 kind: CpuBackendKind::default_compiled(),
328 }
329 }
330
331 fn try_from_context_with_kind(
332 ctx: Arc<CpuContext>,
333 kind: CpuBackendKind,
334 ) -> crate::Result<Self> {
335 ensure_cpu_backend_kind_available(kind, "CpuBackend::with_kind")?;
336 Ok(Self {
337 ctx,
338 buffers: BufferPool::new(),
339 kind,
340 })
341 }
342
343 pub fn from_context_with_buffer_pool_limit(
359 ctx: Arc<CpuContext>,
360 max_retained_capacity_bytes: usize,
361 ) -> Self {
362 Self::from_context_with_buffer_pool_limit_and_kind(
363 ctx,
364 max_retained_capacity_bytes,
365 CpuBackendKind::default_compiled(),
366 )
367 }
368
369 fn from_context_with_buffer_pool_limit_and_kind(
370 ctx: Arc<CpuContext>,
371 max_retained_capacity_bytes: usize,
372 kind: CpuBackendKind,
373 ) -> Self {
374 Self {
375 ctx,
376 buffers: BufferPool::with_max_retained_capacity_bytes(max_retained_capacity_bytes),
377 kind,
378 }
379 }
380
381 pub fn with_threads(num_threads: usize) -> crate::Result<Self> {
396 CpuContext::with_threads(num_threads)
397 .map(|ctx| Self::from_context(Arc::new(ctx)))
398 .map_err(|err| match err {
399 crate::Error::InvalidConfig { message, .. } => crate::Error::InvalidConfig {
400 op: "CpuBackend::with_threads",
401 message,
402 },
403 crate::Error::BackendFailure { message, .. } => {
404 crate::Error::backend_failure("CpuBackend::with_threads", message)
405 }
406 err => err,
407 })
408 }
409
410 pub fn with_threads_and_kind(num_threads: usize, kind: CpuBackendKind) -> crate::Result<Self> {
430 ensure_cpu_backend_kind_available(kind, "CpuBackend::with_threads_and_kind")?;
431 CpuContext::with_threads(num_threads)
432 .map(|ctx| Self {
433 ctx: Arc::new(ctx),
434 buffers: BufferPool::new(),
435 kind,
436 })
437 .map_err(|err| match err {
438 crate::Error::InvalidConfig { message, .. } => crate::Error::InvalidConfig {
439 op: "CpuBackend::with_threads_and_kind",
440 message,
441 },
442 crate::Error::BackendFailure { message, .. } => {
443 crate::Error::backend_failure("CpuBackend::with_threads_and_kind", message)
444 }
445 err => err,
446 })
447 }
448
449 pub fn kind(&self) -> CpuBackendKind {
460 self.kind
461 }
462
463 pub fn num_threads(&self) -> usize {
474 self.ctx.num_threads()
475 }
476
477 pub fn buffer_pool_len(&self) -> usize {
488 self.buffers.len()
489 }
490
491 pub fn buffer_pool_stats(&self) -> BufferPoolStats {
504 self.buffers.stats()
505 }
506
507 pub fn buffer_pool_cache_stats(&self) -> CacheStats {
520 self.buffers.cache_stats()
521 }
522
523 pub fn buffer_pool_limit_bytes(&self) -> usize {
538 self.buffers.max_retained_capacity_bytes()
539 }
540
541 pub fn set_buffer_pool_limit_bytes(&mut self, max_retained_capacity_bytes: usize) {
557 self.buffers
558 .set_max_retained_capacity_bytes(max_retained_capacity_bytes);
559 }
560
561 pub fn reset_buffer_pool(&mut self) {
577 self.buffers.clear();
578 }
579
580 pub fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
592 self.ctx.install(op)
593 }
594
595 fn install_with_pool<R: Send>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R {
596 let mut buffers = BufferPoolLoan::new(&mut self.buffers);
597 let ctx = Arc::clone(&self.ctx);
598 ctx.install(|| op(buffers.get_mut()))
599 }
600
601 #[allow(dead_code)]
604 fn run_with_pool<R>(&mut self, op: impl FnOnce(&mut BufferPool) -> R) -> R {
605 let mut buffers = BufferPoolLoan::new(&mut self.buffers);
606 op(buffers.get_mut())
607 }
608
609 fn linalg_with_pool<R: Send>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R {
610 match self.kind {
611 CpuBackendKind::Faer => self.install_with_pool(op),
612 CpuBackendKind::Blas => self.run_with_pool(op),
613 }
614 }
615
616 #[doc(hidden)]
621 pub fn with_linalg_pool<R: Send>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R {
622 self.linalg_with_pool(op)
623 }
624
625 #[cfg(feature = "cpu-faer")]
627 #[doc(hidden)]
628 pub fn linalg_context(&self) -> Arc<CpuContext> {
629 Arc::clone(&self.ctx)
630 }
631
632 #[allow(dead_code)]
635 fn install_with_pool_and_gemm_cache<R: Send>(
636 &mut self,
637 gemm_analysis_cache: &mut gemm::GemmAnalysisCache,
638 op: impl FnOnce(&mut BufferPool, &mut gemm::GemmAnalysisCache) -> R + Send,
639 ) -> R {
640 let mut buffers = BufferPoolLoan::new(&mut self.buffers);
641 let ctx = Arc::clone(&self.ctx);
642 ctx.install(|| op(buffers.get_mut(), gemm_analysis_cache))
643 }
644
645 #[allow(dead_code)]
648 fn run_with_pool_and_gemm_cache<R>(
649 &mut self,
650 gemm_analysis_cache: &mut gemm::GemmAnalysisCache,
651 op: impl FnOnce(&mut BufferPool, &mut gemm::GemmAnalysisCache) -> R,
652 ) -> R {
653 let mut buffers = BufferPoolLoan::new(&mut self.buffers);
654 op(buffers.get_mut(), gemm_analysis_cache)
655 }
656}
657
/// The per-runtime cache for the CPU backend is the GEMM analysis cache, which is threaded into
/// the `*_cached` dot-general entry points and backend sessions.
impl BackendRuntimeCache for CpuBackend {
    type RuntimeCache = gemm::GemmAnalysisCache;
}
661
662impl TensorElementwise for CpuBackend {
663 fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
664 self.install_with_pool(|buffers| elementwise::add_with_pool(buffers, lhs, rhs))
665 }
666
667 fn add_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
668 self.install_with_pool(|buffers| elementwise::add_read_with_pool(buffers, lhs, rhs))
669 }
670
671 fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
672 self.install_with_pool(|buffers| elementwise::mul_with_pool(buffers, lhs, rhs))
673 }
674
675 fn mul_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
676 self.install_with_pool(|buffers| elementwise::mul_read_with_pool(buffers, lhs, rhs))
677 }
678
679 fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor> {
680 self.install_with_pool(|buffers| elementwise::neg_with_pool(buffers, input))
681 }
682
683 fn neg_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
684 self.install_with_pool(|buffers| elementwise::neg_read_with_pool(buffers, input))
685 }
686
687 fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor> {
688 self.install_with_pool(|buffers| elementwise::conj_with_pool(buffers, input))
689 }
690
691 fn conj_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
692 self.install_with_pool(|buffers| elementwise::conj_read_with_pool(buffers, input))
693 }
694
695 fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
696 self.install_with_pool(|buffers| elementwise::div_with_pool(buffers, lhs, rhs))
697 }
698
699 fn div_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
700 self.install_with_pool(|buffers| elementwise::div_read_with_pool(buffers, lhs, rhs))
701 }
702
703 fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor> {
704 self.install_with_pool(|buffers| elementwise::abs_with_pool(buffers, input))
705 }
706
707 fn abs_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
708 self.install_with_pool(|buffers| elementwise::abs_read_with_pool(buffers, input))
709 }
710
711 fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor> {
712 self.install_with_pool(|buffers| elementwise::sign_with_pool(buffers, input))
713 }
714
715 fn sign_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
716 self.install_with_pool(|buffers| elementwise::sign_read_with_pool(buffers, input))
717 }
718
719 fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
720 self.install_with_pool(|buffers| elementwise::maximum_with_pool(buffers, lhs, rhs))
721 }
722
723 fn maximum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
724 self.install_with_pool(|buffers| elementwise::maximum_read_with_pool(buffers, lhs, rhs))
725 }
726
727 fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
728 self.install_with_pool(|buffers| elementwise::minimum_with_pool(buffers, lhs, rhs))
729 }
730
731 fn minimum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
732 self.install_with_pool(|buffers| elementwise::minimum_read_with_pool(buffers, lhs, rhs))
733 }
734
735 fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor> {
736 self.install_with_pool(|buffers| elementwise::compare_with_pool(buffers, lhs, rhs, dir))
737 }
738
739 fn compare_read(
740 &mut self,
741 lhs: TensorRead<'_>,
742 rhs: TensorRead<'_>,
743 dir: &CompareDir,
744 ) -> crate::Result<Tensor> {
745 self.install_with_pool(|buffers| {
746 elementwise::compare_read_with_pool(buffers, lhs, rhs, dir)
747 })
748 }
749
750 fn select(
751 &mut self,
752 pred: &Tensor,
753 on_true: &Tensor,
754 on_false: &Tensor,
755 ) -> crate::Result<Tensor> {
756 self.install_with_pool(|buffers| {
757 elementwise::select_with_pool(buffers, pred, on_true, on_false)
758 })
759 }
760
761 fn select_read(
762 &mut self,
763 pred: TensorRead<'_>,
764 on_true: TensorRead<'_>,
765 on_false: TensorRead<'_>,
766 ) -> crate::Result<Tensor> {
767 self.install_with_pool(|buffers| {
768 elementwise::select_read_with_pool(buffers, pred, on_true, on_false)
769 })
770 }
771
772 fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor> {
773 self.install_with_pool(|buffers| elementwise::clamp_with_pool(buffers, input, lower, upper))
774 }
775
776 fn clamp_read(
777 &mut self,
778 input: TensorRead<'_>,
779 lower: TensorRead<'_>,
780 upper: TensorRead<'_>,
781 ) -> crate::Result<Tensor> {
782 self.install_with_pool(|buffers| {
783 elementwise::clamp_read_with_pool(buffers, input, lower, upper)
784 })
785 }
786}
787
788impl TensorAnalytic for CpuBackend {
789 fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor> {
790 self.install_with_pool(|buffers| analytic::exp_with_pool(buffers, input))
791 }
792
793 fn exp_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
794 self.install_with_pool(|buffers| analytic::exp_read_with_pool(buffers, input))
795 }
796
797 fn log(&mut self, input: &Tensor) -> crate::Result<Tensor> {
798 self.install_with_pool(|buffers| analytic::log_with_pool(buffers, input))
799 }
800
801 fn log_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
802 self.install_with_pool(|buffers| analytic::log_read_with_pool(buffers, input))
803 }
804
805 fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor> {
806 self.install_with_pool(|buffers| analytic::sin_with_pool(buffers, input))
807 }
808
809 fn sin_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
810 self.install_with_pool(|buffers| analytic::sin_read_with_pool(buffers, input))
811 }
812
813 fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor> {
814 self.install_with_pool(|buffers| analytic::cos_with_pool(buffers, input))
815 }
816
817 fn cos_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
818 self.install_with_pool(|buffers| analytic::cos_read_with_pool(buffers, input))
819 }
820
821 fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor> {
822 self.install_with_pool(|buffers| analytic::tanh_with_pool(buffers, input))
823 }
824
825 fn tanh_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
826 self.install_with_pool(|buffers| analytic::tanh_read_with_pool(buffers, input))
827 }
828
829 fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
830 self.install_with_pool(|buffers| analytic::sqrt_with_pool(buffers, input))
831 }
832
833 fn sqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
834 self.install_with_pool(|buffers| analytic::sqrt_read_with_pool(buffers, input))
835 }
836
837 fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
838 self.install_with_pool(|buffers| analytic::rsqrt_with_pool(buffers, input))
839 }
840
841 fn rsqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
842 self.install_with_pool(|buffers| analytic::rsqrt_read_with_pool(buffers, input))
843 }
844
845 fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
846 self.install_with_pool(|buffers| analytic::pow_with_pool(buffers, lhs, rhs))
847 }
848
849 fn pow_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
850 self.install_with_pool(|buffers| analytic::pow_read_with_pool(buffers, lhs, rhs))
851 }
852
853 fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor> {
854 self.install_with_pool(|buffers| analytic::expm1_with_pool(buffers, input))
855 }
856
857 fn expm1_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
858 self.install_with_pool(|buffers| analytic::expm1_read_with_pool(buffers, input))
859 }
860
861 fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor> {
862 self.install_with_pool(|buffers| analytic::log1p_with_pool(buffers, input))
863 }
864
865 fn log1p_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
866 self.install_with_pool(|buffers| analytic::log1p_read_with_pool(buffers, input))
867 }
868}
869
870impl TensorStructural for CpuBackend {
871 fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
872 self.install_with_pool(|buffers| structural::transpose_with_pool(buffers, input, perm))
873 }
874
875 fn transpose_read(&mut self, input: TensorRead<'_>, perm: &[usize]) -> crate::Result<Tensor> {
876 if let Some(input) = input.as_tensor() {
877 return self.transpose(input, perm);
878 }
879
880 let input = materialize_tensor_read("transpose", input)?;
881 self.transpose(&input, perm)
882 }
883
884 fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
885 self.install(|| structural::reshape(input, shape))
886 }
887
888 fn reshape_read(&mut self, input: TensorRead<'_>, shape: &[usize]) -> crate::Result<Tensor> {
889 if let Some(input) = input.as_tensor() {
890 return self.reshape(input, shape);
891 }
892
893 let input = materialize_tensor_read("reshape", input)?;
894 self.reshape(&input, shape)
895 }
896
897 fn broadcast_in_dim(
898 &mut self,
899 input: &Tensor,
900 shape: &[usize],
901 dims: &[usize],
902 ) -> crate::Result<Tensor> {
903 self.install_with_pool(|buffers| {
904 structural::broadcast_in_dim_with_pool(buffers, input, shape, dims)
905 })
906 }
907
908 fn broadcast_in_dim_read(
909 &mut self,
910 input: TensorRead<'_>,
911 shape: &[usize],
912 dims: &[usize],
913 ) -> crate::Result<Tensor> {
914 if let Some(input) = input.as_tensor() {
915 return self.broadcast_in_dim(input, shape, dims);
916 }
917
918 let input = materialize_tensor_read("broadcast_in_dim", input)?;
919 self.broadcast_in_dim(&input, shape, dims)
920 }
921
922 fn cast(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor> {
923 self.install_with_pool(|buffers| structural::cast_with_pool(buffers, input, to))
924 }
925
926 fn extract_diagonal(
927 &mut self,
928 input: &Tensor,
929 axis_a: usize,
930 axis_b: usize,
931 ) -> crate::Result<Tensor> {
932 self.install_with_pool(|buffers| {
933 structural::extract_diagonal_with_pool(buffers, input, axis_a, axis_b)
934 })
935 }
936
937 fn embed_diagonal(
938 &mut self,
939 input: &Tensor,
940 axis_a: usize,
941 axis_b: usize,
942 ) -> crate::Result<Tensor> {
943 self.install_with_pool(|buffers| {
944 structural::embed_diagonal_with_pool(buffers, input, axis_a, axis_b)
945 })
946 }
947
948 fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
949 self.install_with_pool(|buffers| structural::tril_with_pool(buffers, input, k))
950 }
951
952 fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
953 self.install_with_pool(|buffers| structural::triu_with_pool(buffers, input, k))
954 }
955}
956
957impl TensorReduction for CpuBackend {
958 fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
959 self.install(|| reduction::reduce_sum(input, axes))
960 }
961
962 fn reduce_sum_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
963 self.install(|| reduction::reduce_sum_read(input, axes))
964 }
965
966 fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
967 self.install(|| reduction::reduce_prod(input, axes))
968 }
969
970 fn reduce_prod_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
971 self.install(|| reduction::reduce_prod_read(input, axes))
972 }
973
974 fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
975 self.install(|| reduction::reduce_max(input, axes))
976 }
977
978 fn reduce_max_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
979 self.install(|| reduction::reduce_max_read(input, axes))
980 }
981
982 fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
983 self.install(|| reduction::reduce_min(input, axes))
984 }
985
986 fn reduce_min_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
987 self.install(|| reduction::reduce_min_read(input, axes))
988 }
989}
990
/// Uncached dot-general entry points.
///
/// Each call builds a fresh `GemmAnalysisCache`, so no GEMM analysis is reused between calls;
/// callers wanting reuse should go through `BackendCachedDot` with their own cache.
impl TensorDot for CpuBackend {
    /// Delegates to `dot_general_cached` with a throwaway cache and no cache slot.
    fn dot_general(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor> {
        let mut cache = gemm::GemmAnalysisCache::default();
        BackendCachedDot::dot_general_cached(self, &mut cache, None, lhs, rhs, config)
    }

    /// Dot-general on tensor reads.
    ///
    /// First tries the backend's direct read kernel. It yields `Option<Tensor>`; `None` means
    /// the direct path did not produce a result, and both reads are then materialized and
    /// routed through `dot_general_cached`, reusing the same per-call cache.
    ///
    /// # Errors
    /// `InvalidConfig` when the selected backend kind's feature is not compiled in, plus any
    /// error from the kernels or from materialization.
    fn dot_general_read(
        &mut self,
        lhs: TensorRead<'_>,
        rhs: TensorRead<'_>,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor> {
        let mut cache = gemm::GemmAnalysisCache::default();
        let direct = match self.kind {
            CpuBackendKind::Faer => {
                // Faer runs inside the context (`install`) with the lent buffer pool.
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = Arc::clone(&self.ctx);
                    // The reads are cloned so the originals stay available for the
                    // materializing fallback below.
                    self.install_with_pool_and_gemm_cache(&mut cache, |buffers, cache| {
                        gemm::dot_general_faer_read_cached(
                            buffers,
                            cache,
                            None,
                            ctx.as_ref(),
                            lhs.clone(),
                            rhs.clone(),
                            config,
                        )
                    })?
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    return Err(unavailable_cpu_backend_kind(self.kind, "dot_general"));
                }
            }
            CpuBackendKind::Blas => {
                // BLAS runs on the calling thread (no `install`).
                #[cfg(feature = "cpu-blas")]
                {
                    self.run_with_pool_and_gemm_cache(&mut cache, |buffers, cache| {
                        gemm::dot_general_blas_read_cached(
                            buffers,
                            cache,
                            None,
                            lhs.clone(),
                            rhs.clone(),
                            config,
                        )
                    })?
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    return Err(unavailable_cpu_backend_kind(self.kind, "dot_general"));
                }
            }
        };
        if let Some(result) = direct {
            return Ok(result);
        }

        // Fallback: turn both reads into owned tensors and use the tensor-based path.
        let lhs = materialize_tensor_read("dot_general", lhs)?;
        let rhs = materialize_tensor_read("dot_general", rhs)?;
        BackendCachedDot::dot_general_cached(self, &mut cache, None, &lhs, &rhs, config)
    }

    /// Delegates to `dot_general_with_conj_cached` with a throwaway cache and no cache slot;
    /// `lhs_conj` / `rhs_conj` are forwarded unchanged.
    fn dot_general_with_conj(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
        lhs_conj: bool,
        rhs_conj: bool,
    ) -> crate::Result<Tensor> {
        let mut cache = gemm::GemmAnalysisCache::default();
        BackendCachedDot::dot_general_with_conj_cached(
            self, &mut cache, None, lhs, rhs, config, lhs_conj, rhs_conj,
        )
    }
}
1074
/// Dot-general entry points that share a caller-provided GEMM analysis cache.
///
/// Dispatch is by `self.kind`:
/// - Faer runs inside the context (`install_with_pool_and_gemm_cache`).
/// - BLAS runs on the calling thread (`run_with_pool_and_gemm_cache`).
///
/// Both operands must have the same dtype, one of F32, F64, C32 or C64. Any other combination
/// yields `DTypeMismatch`.
impl BackendCachedDot for CpuBackend {
    /// Computes `dot_general(lhs, rhs, config)` using `cache` (optionally keyed by `cache_slot`).
    ///
    /// # Errors
    /// `DTypeMismatch` for unsupported/mismatched dtypes, `InvalidConfig` when the selected
    /// kind's feature is not compiled in, or any kernel error.
    fn dot_general_cached(
        &mut self,
        cache: &mut Self::RuntimeCache,
        cache_slot: Option<usize>,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor> {
        match self.kind {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    // The context is cloned so it can be passed to the kernel while `self`
                    // is mutably borrowed by the pool loan.
                    let ctx = Arc::clone(&self.ctx);
                    self.install_with_pool_and_gemm_cache(cache, |buffers, cache| {
                        match (lhs, rhs) {
                            (Tensor::F32(a), Tensor::F32(b)) => gemm::dot_general_faer_cached(
                                buffers,
                                cache,
                                cache_slot,
                                ctx.as_ref(),
                                a,
                                b,
                                config,
                            )
                            .map(Tensor::F32),
                            (Tensor::F64(a), Tensor::F64(b)) => gemm::dot_general_faer_cached(
                                buffers,
                                cache,
                                cache_slot,
                                ctx.as_ref(),
                                a,
                                b,
                                config,
                            )
                            .map(Tensor::F64),
                            (Tensor::C32(a), Tensor::C32(b)) => gemm::dot_general_faer_cached(
                                buffers,
                                cache,
                                cache_slot,
                                ctx.as_ref(),
                                a,
                                b,
                                config,
                            )
                            .map(Tensor::C32),
                            (Tensor::C64(a), Tensor::C64(b)) => gemm::dot_general_faer_cached(
                                buffers,
                                cache,
                                cache_slot,
                                ctx.as_ref(),
                                a,
                                b,
                                config,
                            )
                            .map(Tensor::C64),
                            // Mixed or unsupported dtypes.
                            _ => Err(crate::Error::DTypeMismatch {
                                op: "dot_general",
                                lhs: lhs.dtype(),
                                rhs: rhs.dtype(),
                            }),
                        }
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.run_with_pool_and_gemm_cache(cache, |buffers, cache| match (lhs, rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => {
                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
                                .map(Tensor::F32)
                        }
                        (Tensor::F64(a), Tensor::F64(b)) => {
                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
                                .map(Tensor::F64)
                        }
                        (Tensor::C32(a), Tensor::C32(b)) => {
                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
                                .map(Tensor::C32)
                        }
                        (Tensor::C64(a), Tensor::C64(b)) => {
                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
                                .map(Tensor::C64)
                        }
                        // Mixed or unsupported dtypes.
                        _ => Err(crate::Error::DTypeMismatch {
                            op: "dot_general",
                            lhs: lhs.dtype(),
                            rhs: rhs.dtype(),
                        }),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
                }
            }
        }
    }

    /// Same as `dot_general_cached`, but `lhs_conj` / `rhs_conj` are forwarded to the
    /// `*_with_conj_cached` kernels.
    ///
    /// NOTE(review): errors from this path are tagged with op `"dot_general"`, not
    /// `"dot_general_with_conj"` — confirm this is intended.
    fn dot_general_with_conj_cached(
        &mut self,
        cache: &mut Self::RuntimeCache,
        cache_slot: Option<usize>,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
        lhs_conj: bool,
        rhs_conj: bool,
    ) -> crate::Result<Tensor> {
        match self.kind {
            CpuBackendKind::Faer => {
                #[cfg(feature = "cpu-faer")]
                {
                    let ctx = Arc::clone(&self.ctx);
                    self.install_with_pool_and_gemm_cache(cache, |buffers, cache| {
                        match (lhs, rhs) {
                            (Tensor::F32(a), Tensor::F32(b)) => {
                                gemm::dot_general_faer_with_conj_cached(
                                    buffers,
                                    cache,
                                    cache_slot,
                                    ctx.as_ref(),
                                    a,
                                    b,
                                    config,
                                    lhs_conj,
                                    rhs_conj,
                                )
                                .map(Tensor::F32)
                            }
                            (Tensor::F64(a), Tensor::F64(b)) => {
                                gemm::dot_general_faer_with_conj_cached(
                                    buffers,
                                    cache,
                                    cache_slot,
                                    ctx.as_ref(),
                                    a,
                                    b,
                                    config,
                                    lhs_conj,
                                    rhs_conj,
                                )
                                .map(Tensor::F64)
                            }
                            (Tensor::C32(a), Tensor::C32(b)) => {
                                gemm::dot_general_faer_with_conj_cached(
                                    buffers,
                                    cache,
                                    cache_slot,
                                    ctx.as_ref(),
                                    a,
                                    b,
                                    config,
                                    lhs_conj,
                                    rhs_conj,
                                )
                                .map(Tensor::C32)
                            }
                            (Tensor::C64(a), Tensor::C64(b)) => {
                                gemm::dot_general_faer_with_conj_cached(
                                    buffers,
                                    cache,
                                    cache_slot,
                                    ctx.as_ref(),
                                    a,
                                    b,
                                    config,
                                    lhs_conj,
                                    rhs_conj,
                                )
                                .map(Tensor::C64)
                            }
                            // Mixed or unsupported dtypes.
                            _ => Err(crate::Error::DTypeMismatch {
                                op: "dot_general",
                                lhs: lhs.dtype(),
                                rhs: rhs.dtype(),
                            }),
                        }
                    })
                }
                #[cfg(not(feature = "cpu-faer"))]
                {
                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
                }
            }
            CpuBackendKind::Blas => {
                #[cfg(feature = "cpu-blas")]
                {
                    self.run_with_pool_and_gemm_cache(cache, |buffers, cache| match (lhs, rhs) {
                        (Tensor::F32(a), Tensor::F32(b)) => {
                            gemm::dot_general_blas_with_conj_cached(
                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
                            )
                            .map(Tensor::F32)
                        }
                        (Tensor::F64(a), Tensor::F64(b)) => {
                            gemm::dot_general_blas_with_conj_cached(
                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
                            )
                            .map(Tensor::F64)
                        }
                        (Tensor::C32(a), Tensor::C32(b)) => {
                            gemm::dot_general_blas_with_conj_cached(
                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
                            )
                            .map(Tensor::C32)
                        }
                        (Tensor::C64(a), Tensor::C64(b)) => {
                            gemm::dot_general_blas_with_conj_cached(
                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
                            )
                            .map(Tensor::C64)
                        }
                        // Mixed or unsupported dtypes.
                        _ => Err(crate::Error::DTypeMismatch {
                            op: "dot_general",
                            lhs: lhs.dtype(),
                            rhs: rhs.dtype(),
                        }),
                    })
                }
                #[cfg(not(feature = "cpu-blas"))]
                {
                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
                }
            }
        }
    }
}
1308
1309impl TensorIndexing for CpuBackend {
1310 fn gather(
1311 &mut self,
1312 operand: &Tensor,
1313 start_indices: &Tensor,
1314 config: &GatherConfig,
1315 ) -> crate::Result<Tensor> {
1316 self.install_with_pool(|buffers| {
1317 indexing::gather_with_pool(buffers, operand, start_indices, config)
1318 })
1319 }
1320
1321 fn scatter(
1322 &mut self,
1323 operand: &Tensor,
1324 scatter_indices: &Tensor,
1325 updates: &Tensor,
1326 config: &ScatterConfig,
1327 ) -> crate::Result<Tensor> {
1328 self.install_with_pool(|buffers| {
1329 indexing::scatter_with_pool(buffers, operand, scatter_indices, updates, config)
1330 })
1331 }
1332
1333 fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor> {
1334 self.install_with_pool(|buffers| indexing::try_slice_with_pool(buffers, input, config))
1335 }
1336
1337 fn dynamic_slice(
1338 &mut self,
1339 input: &Tensor,
1340 starts: &Tensor,
1341 slice_sizes: &[usize],
1342 ) -> crate::Result<Tensor> {
1343 self.install_with_pool(|buffers| {
1344 indexing::dynamic_slice_with_pool(buffers, input, starts, slice_sizes)
1345 })
1346 }
1347
1348 fn dynamic_update_slice(
1349 &mut self,
1350 operand: &Tensor,
1351 update: &Tensor,
1352 starts: &Tensor,
1353 ) -> crate::Result<Tensor> {
1354 self.install_with_pool(|buffers| {
1355 indexing::dynamic_update_slice_with_pool(buffers, operand, update, starts)
1356 })
1357 }
1358
1359 fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
1360 self.install_with_pool(|buffers| indexing::try_pad_with_pool(buffers, input, config))
1361 }
1362
1363 fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
1364 self.install_with_pool(|buffers| indexing::try_concatenate_with_pool(buffers, inputs, axis))
1365 }
1366
1367 fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
1368 self.install_with_pool(|buffers| indexing::reverse_with_pool(buffers, input, axes))
1369 }
1370}
1371
impl BackendSessionHost for CpuBackend {
    /// Runs `f` against a `CpuExecSession` using a fresh, call-local GEMM analysis cache.
    ///
    /// The cache is created here (the creation is timed under the
    /// "with_backend_session.cache_default" profile section) and dropped when this call
    /// returns, so nothing is reused across calls. Callers that want a persistent cache
    /// should use `with_backend_session_cached` directly.
    fn with_backend_session<R: Send>(
        &mut self,
        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
    ) -> R {
        let mut cache = profile_cpu_session_section("with_backend_session.cache_default", || {
            gemm::GemmAnalysisCache::default()
        });
        self.with_backend_session_cached(&mut cache, f)
    }

    /// Runs `f` against a `CpuExecSession` that borrows this backend's buffer pool and
    /// the caller-supplied GEMM analysis `cache`.
    ///
    /// The session is built and `f` is executed inside `ctx.install(..)`.
    /// NOTE(review): `install` presumably enters the CPU context's thread pool; confirm
    /// against `CpuContext`.
    ///
    /// There are two paths:
    /// - profiling disabled (`TENFERRO_PROFILE_CPU_SESSION` unset): no timing overhead;
    /// - profiling enabled: each phase is recorded under a
    ///   "with_backend_session_cached.*" section, and the accumulated profile may be
    ///   printed afterwards via `maybe_print_cpu_session_profile`.
    fn with_backend_session_cached<R: Send>(
        &mut self,
        cache: &mut Self::RuntimeCache,
        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
    ) -> R {
        // Fast path: identical session setup, without any Instant/profile bookkeeping.
        if !cpu_session_profile_enabled() {
            // Loan the pool out of `self.buffers` for the session's lifetime.
            // NOTE(review): the profiled path labels dropping this loan
            // "restore_buffers", so dropping it presumably hands the pool back; confirm
            // in `BufferPoolLoan`.
            let mut buffers = BufferPoolLoan::new(&mut self.buffers);
            // Copy the context handle and kind into locals so the closure below does not
            // capture `self`. NOTE(review): presumably needed because `self.buffers` is
            // mutably loaned; confirm.
            let ctx = Arc::clone(&self.ctx);
            let kind = self.kind;
            return ctx.install(|| {
                let mut session = CpuExecSession {
                    ctx: ctx.as_ref(),
                    buffers: buffers.get_mut(),
                    gemm_analysis_cache: cache,
                    kind,
                };
                f(&mut session)
            });
        }

        // Profiled path. "total" covers everything from here to the final record call,
        // including taking and restoring the buffer pool.
        let total_started = Instant::now();
        let mut buffers =
            profile_cpu_session_section("with_backend_session_cached.take_buffers", || {
                BufferPoolLoan::new(&mut self.buffers)
            });
        let ctx = Arc::clone(&self.ctx);
        let kind = self.kind;
        // "exec_session" wraps the whole `install` call. The two inner sections split it
        // into session construction and the body `f`; whatever `install` itself costs is
        // included in "exec_session" but not in either inner section.
        let result =
            profile_cpu_session_section("with_backend_session_cached.exec_session", || {
                ctx.install(|| {
                    let session_started = Instant::now();
                    let mut session = CpuExecSession {
                        ctx: ctx.as_ref(),
                        buffers: buffers.get_mut(),
                        gemm_analysis_cache: cache,
                        kind,
                    };
                    record_cpu_session_profile(
                        "with_backend_session_cached.session_construct",
                        session_started.elapsed(),
                    );

                    let exec_started = Instant::now();
                    let result = f(&mut session);
                    record_cpu_session_profile(
                        "with_backend_session_cached.exec_body",
                        exec_started.elapsed(),
                    );
                    result
                })
            });
        // Drop the loan explicitly inside a profiled closure so the cost of releasing it
        // is measured as its own section instead of disappearing into scope exit.
        profile_cpu_session_section("with_backend_session_cached.restore_buffers", || {
            drop(buffers);
        });
        record_cpu_session_profile("with_backend_session_cached.total", total_started.elapsed());
        maybe_print_cpu_session_profile();
        result
    }
}
1442
1443impl TensorBuffer for CpuBackend {
1444 fn reclaim_buffer(&mut self, tensor: Tensor) {
1445 match tensor {
1446 Tensor::F32(t) => reclaim_typed(&mut self.buffers, t),
1447 Tensor::F64(t) => reclaim_typed(&mut self.buffers, t),
1448 Tensor::I32(t) => reclaim_typed(&mut self.buffers, t),
1449 Tensor::I64(t) => reclaim_typed(&mut self.buffers, t),
1450 Tensor::Bool(t) => reclaim_typed(&mut self.buffers, t),
1451 Tensor::C32(t) => reclaim_typed(&mut self.buffers, t),
1452 Tensor::C64(t) => reclaim_typed(&mut self.buffers, t),
1453 }
1454 }
1455}
1456
1457impl<T, R> TensorViewCanonicalization<T, R> for CpuBackend
1458where
1459 T: Clone + 'static,
1460 R: TensorRank,
1461{
1462 fn to_contiguous(
1463 &mut self,
1464 view: &TypedTensorView<'_, T, R>,
1465 ) -> crate::Result<TypedTensor<T, R>> {
1466 if view.backend_buffer().is_some() {
1467 return Err(crate::Error::backend_failure(
1468 "CpuBackend::to_contiguous",
1469 "CPU backend received a backend tensor view; download the tensor to host before CPU view canonicalization",
1470 ));
1471 }
1472 view.to_contiguous()
1473 }
1474
1475 fn copy_from_contiguous(
1476 &mut self,
1477 src: &TypedTensor<T, R>,
1478 dst: &mut TypedTensorViewMut<'_, T, R>,
1479 ) -> crate::Result<()> {
1480 if matches!(src.buffer(), Buffer::Backend(_)) {
1481 return Err(crate::Error::backend_failure(
1482 "CpuBackend::copy_from_contiguous",
1483 "CPU backend received a backend source tensor; download the tensor to host before CPU view copy-back",
1484 ));
1485 }
1486 if dst.backend_buffer().is_some() {
1487 return Err(crate::Error::backend_failure(
1488 "CpuBackend::copy_from_contiguous",
1489 "CPU backend received a backend destination view; download the tensor to host before CPU view copy-back",
1490 ));
1491 }
1492 dst.copy_from_contiguous(src)
1493 }
1494}
1495
1496impl TensorFusion for CpuBackend {
1497 fn execute_broadcast_multiply(
1498 &mut self,
1499 lhs: TensorRead<'_>,
1500 lhs_shape: &[usize],
1501 lhs_dims: &[usize],
1502 rhs: TensorRead<'_>,
1503 rhs_shape: &[usize],
1504 rhs_dims: &[usize],
1505 ) -> crate::Result<Option<Tensor>> {
1506 self.install_with_pool(|buffers| {
1507 elementwise::broadcast_multiply_read_with_pool(
1508 buffers, lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims,
1509 )
1510 })
1511 }
1512
1513 fn execute_broadcast_multiply_value(
1514 &mut self,
1515 lhs: TensorRead<'_>,
1516 lhs_shape: &[usize],
1517 lhs_dims: &[usize],
1518 rhs: TensorRead<'_>,
1519 rhs_shape: &[usize],
1520 rhs_dims: &[usize],
1521 ) -> crate::Result<Option<TensorValue>> {
1522 self.install_with_pool(|buffers| {
1523 elementwise::broadcast_multiply_value_with_pool(
1524 buffers, lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims,
1525 )
1526 })
1527 }
1528}
1529
1530impl TensorDeviceTransfer for CpuBackend {
1531 fn download_to_host(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
1532 if tensor.is_backend_buffer() {
1533 return Err(crate::Error::backend_failure(
1534 "CpuBackend::download_to_host",
1535 "CPU backend received a backend buffer; download the tensor to host with its owning backend before CPU execution",
1536 ));
1537 }
1538 Ok(tensor.clone())
1539 }
1540
1541 fn upload_host_tensor(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
1542 if tensor.is_backend_buffer() {
1543 return Err(crate::Error::backend_failure(
1544 "CpuBackend::upload_host_tensor",
1545 "CPU backend upload_host_tensor expects a host tensor; download backend buffers to host before CPU execution",
1546 ));
1547 }
1548 Ok(tensor.clone())
1549 }
1550}
1551
// Umbrella backend trait for `CpuBackend`. The body is empty, so every item of
// `TensorBackend` is either default-provided or comes from the capability traits
// implemented for `CpuBackend` in this file.
// NOTE(review): presumably those capability traits are `TensorBackend`'s supertraits;
// confirm in `tenferro_tensor`.
impl TensorBackend for CpuBackend {}
1553
1554pub(crate) fn reclaim_typed<T: PoolScalar>(pool: &mut BufferPool, typed: TypedTensor<T>) {
1555 let (buffer, _, _) = typed.into_parts();
1556 match buffer {
1557 Buffer::Host(data) => T::pool_release(pool, data),
1558 Buffer::Backend(_) => {}
1559 }
1560}
1561
1562impl Default for CpuBackend {
1563 fn default() -> Self {
1564 Self::new()
1565 }
1566}
1567
1568#[cfg(test)]
1569mod tests;