Skip to main content

tenferro_cpu/
backend.rs

1use std::cmp::Reverse;
2use std::collections::HashMap;
3use std::env;
4use std::fmt;
5use std::sync::{Arc, Mutex, OnceLock};
6use std::time::{Duration, Instant};
7
8use crate::buffer_pool::{BufferPool, BufferPoolStats, PoolScalar};
9use crate::{
10    Buffer, CacheStats, Tensor, TensorRank, TensorRead, TensorValue, TypedTensor, TypedTensorView,
11    TypedTensorViewMut,
12};
13use tenferro_tensor::{
14    BackendCachedDot, BackendRuntimeCache, BackendSession, BackendSessionHost, TensorAnalytic,
15    TensorBackend, TensorBuffer, TensorDeviceTransfer, TensorDot, TensorElementwise, TensorFusion,
16    TensorIndexing, TensorReduction, TensorStructural, TensorViewCanonicalization,
17};
18use tenferro_tensor::{
19    CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
20};
21
22use super::exec_session::CpuExecSession;
23use super::{
24    analytic, elementwise, gemm, indexing, materialize_tensor_read, reduction, structural,
25    CpuContext,
26};
27
/// Accumulated timing for one named CPU-session profile section.
#[derive(Clone, Debug, Default)]
struct CpuSessionProfileEntry {
    /// Number of samples recorded for the section.
    calls: usize,
    /// Sum of the elapsed time over every recorded sample.
    total_time: Duration,
}
33
/// Report whether CPU-session profiling is switched on.
///
/// Profiling is enabled when the `TENFERRO_PROFILE_CPU_SESSION` environment
/// variable is set to any value. The environment is read once; the answer is
/// cached for the rest of the process.
fn cpu_session_profile_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    // `var_os` only asks "is it set?"; `env::var` would additionally fail on a
    // value that is not valid Unicode and silently disable profiling.
    *ENABLED.get_or_init(|| env::var_os("TENFERRO_PROFILE_CPU_SESSION").is_some())
}
38
/// Return the profile print period, if one is configured.
///
/// The period is read once from `TENFERRO_PROFILE_CPU_SESSION_PRINT_EVERY`
/// and cached. A missing, unparsable, or zero value yields `None`.
fn cpu_session_profile_print_every() -> Option<usize> {
    static PRINT_EVERY: OnceLock<Option<usize>> = OnceLock::new();
    *PRINT_EVERY.get_or_init(|| {
        let raw = env::var("TENFERRO_PROFILE_CPU_SESSION_PRINT_EVERY").ok()?;
        match raw.parse::<usize>() {
            Ok(period) if period > 0 => Some(period),
            _ => None,
        }
    })
}
48
49fn cpu_session_profile_state() -> &'static Mutex<HashMap<&'static str, CpuSessionProfileEntry>> {
50    static STATE: OnceLock<Mutex<HashMap<&'static str, CpuSessionProfileEntry>>> = OnceLock::new();
51    STATE.get_or_init(|| Mutex::new(HashMap::new()))
52}
53
54fn record_cpu_session_profile(section: &'static str, elapsed: Duration) {
55    if !cpu_session_profile_enabled() {
56        return;
57    }
58    let Ok(mut state) = cpu_session_profile_state().lock() else {
59        return;
60    };
61    let entry = state.entry(section).or_default();
62    entry.calls += 1;
63    entry.total_time += elapsed;
64}
65
66fn profile_cpu_session_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
67    if !cpu_session_profile_enabled() {
68        return f();
69    }
70    let started = Instant::now();
71    let result = f();
72    record_cpu_session_profile(section, started.elapsed());
73    result
74}
75
76fn maybe_print_cpu_session_profile() {
77    let Some(print_every) = cpu_session_profile_print_every() else {
78        return;
79    };
80    let should_print = {
81        let Ok(state) = cpu_session_profile_state().lock() else {
82            return;
83        };
84        state
85            .get("with_backend_session_cached.total")
86            .is_some_and(|entry| entry.calls % print_every == 0)
87    };
88    if !should_print {
89        return;
90    }
91    let mut entries = {
92        let Ok(mut state) = cpu_session_profile_state().lock() else {
93            return;
94        };
95        let entries = state
96            .iter()
97            .map(|(section, entry)| (*section, entry.clone()))
98            .collect::<Vec<_>>();
99        state.clear();
100        entries
101    };
102    entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
103    eprintln!("=== tenferro CPU session profile ===");
104    for (section, entry) in entries {
105        eprintln!(
106            "{section}: calls={} total={:.6}ms per_call={:.3}us",
107            entry.calls,
108            entry.total_time.as_secs_f64() * 1.0e3,
109            entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64,
110        );
111    }
112}
113
114struct BufferPoolLoan<'a> {
115    target: &'a mut BufferPool,
116    buffers: Option<BufferPool>,
117}
118
119impl<'a> BufferPoolLoan<'a> {
120    fn new(target: &'a mut BufferPool) -> Self {
121        Self {
122            buffers: Some(std::mem::take(target)),
123            target,
124        }
125    }
126
127    fn get_mut(&mut self) -> &mut BufferPool {
128        self.buffers
129            .as_mut()
130            .expect("buffer pool loan already restored")
131    }
132}
133
134impl Drop for BufferPoolLoan<'_> {
135    fn drop(&mut self) {
136        if let Some(buffers) = self.buffers.take() {
137            *self.target = buffers;
138        }
139    }
140}
141
/// Runtime selector for the CPU provider used by a [`CpuBackend`].
///
/// Provider features are additive at compile time, so more than one provider
/// may be compiled in. This value picks which compiled provider a particular
/// backend instance uses for provider-owned kernels such as GEMM.
///
/// # Examples
///
/// ```
/// use tenferro_cpu::CpuBackendKind;
///
/// let kind = CpuBackendKind::default_compiled();
/// assert!(matches!(kind, CpuBackendKind::Faer | CpuBackendKind::Blas));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CpuBackendKind {
    /// CPU kernels backed by faer.
    Faer,
    /// CPU kernels backed by BLAS/LAPACK.
    Blas,
}
163
impl CpuBackendKind {
    /// Return the default compiled CPU provider.
    ///
    /// BLAS is preferred when both BLAS and faer are compiled in because an
    /// application that links a BLAS/LAPACK provider normally expects
    /// provider-backed kernels to use it by default.
    ///
    /// The result is fixed at compile time by the `cpu-blas` and `cpu-faer`
    /// features: `Blas` whenever `cpu-blas` is enabled, otherwise `Faer` when
    /// `cpu-faer` is enabled.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_cpu::CpuBackendKind;
    ///
    /// let _kind = CpuBackendKind::default_compiled();
    /// ```
    pub fn default_compiled() -> Self {
        // Exactly one of the two cfg-gated blocks survives and becomes the
        // function's tail expression. With neither feature enabled both
        // blocks are removed and the empty body no longer type-checks.
        #[cfg(feature = "cpu-blas")]
        {
            Self::Blas
        }
        #[cfg(all(not(feature = "cpu-blas"), feature = "cpu-faer"))]
        {
            Self::Faer
        }
    }

    /// Lowercase provider name (`"faer"` or `"blas"`) for diagnostics.
    // Used by feature-specific diagnostics; some feature combinations leave
    // the formatter path inactive.
    #[allow(dead_code)]
    pub(crate) fn name(self) -> &'static str {
        match self {
            Self::Faer => "faer",
            Self::Blas => "blas",
        }
    }
}
199
200fn ensure_cpu_backend_kind_available(kind: CpuBackendKind, op: &'static str) -> crate::Result<()> {
201    let _ = op;
202    match kind {
203        CpuBackendKind::Faer => {
204            #[cfg(feature = "cpu-faer")]
205            {
206                Ok(())
207            }
208            #[cfg(not(feature = "cpu-faer"))]
209            {
210                Err(crate::Error::InvalidConfig {
211                    op,
212                    message: "CpuBackendKind::Faer requires the cpu-faer feature".to_string(),
213                })
214            }
215        }
216        CpuBackendKind::Blas => {
217            #[cfg(feature = "cpu-blas")]
218            {
219                Ok(())
220            }
221            #[cfg(not(feature = "cpu-blas"))]
222            {
223                Err(crate::Error::InvalidConfig {
224                    op,
225                    message: "CpuBackendKind::Blas requires the cpu-blas feature".to_string(),
226                })
227            }
228        }
229    }
230}
231
232// Used by feature-disabled backend paths; a given feature build may compile no
233// direct call site for one provider.
234#[allow(dead_code)]
235pub(super) fn unavailable_cpu_backend_kind(kind: CpuBackendKind, op: &'static str) -> crate::Error {
236    crate::Error::InvalidConfig {
237        op,
238        message: format!("CPU backend kind {} is not compiled in", kind.name()),
239    }
240}
241
242/// CPU execution backend.
243///
244/// # Examples
245///
246/// ```
247/// use tenferro_cpu::CpuBackend;
248///
249/// let backend = CpuBackend::new();
250/// ```
251pub struct CpuBackend {
252    pub(crate) ctx: Arc<CpuContext>,
253    pub(crate) buffers: BufferPool,
254    kind: CpuBackendKind,
255}
256
257impl fmt::Debug for CpuBackend {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        f.debug_struct("CpuBackend")
260            .field("kind", &self.kind)
261            .field("num_threads", &self.num_threads())
262            .field("buffer_pool_cache_stats", &self.buffer_pool_cache_stats())
263            .field("buffer_pool_limit_bytes", &self.buffer_pool_limit_bytes())
264            .finish_non_exhaustive()
265    }
266}
267
268impl CpuBackend {
269    /// Create a CPU backend using the environment-driven CPU context.
270    ///
271    /// # Examples
272    ///
273    /// ```
274    /// use tenferro_cpu::CpuBackend;
275    ///
276    /// let backend = CpuBackend::new();
277    /// ```
278    pub fn new() -> Self {
279        Self::from_context(Arc::new(CpuContext::from_env()))
280    }
281
282    /// Create a CPU backend using the selected compiled provider.
283    ///
284    /// # Examples
285    ///
286    /// ```
287    /// use tenferro_cpu::{CpuBackend, CpuBackendKind};
288    ///
289    /// let backend = CpuBackend::with_kind(CpuBackendKind::default_compiled()).unwrap();
290    /// assert_eq!(backend.kind(), CpuBackendKind::default_compiled());
291    /// ```
292    pub fn with_kind(kind: CpuBackendKind) -> crate::Result<Self> {
293        Self::try_from_context_with_kind(Arc::new(CpuContext::from_env()), kind)
294    }
295
296    /// Try to create a CPU backend using `RAYON_NUM_THREADS`.
297    ///
298    /// # Examples
299    ///
300    /// ```
301    /// use tenferro_cpu::CpuBackend;
302    ///
303    /// let backend = CpuBackend::try_new()
304    ///     .unwrap_or_else(|_| CpuBackend::with_threads(1).unwrap());
305    /// let _ = backend.num_threads();
306    /// ```
307    pub fn try_new() -> crate::Result<Self> {
308        CpuContext::try_from_env().map(|ctx| Self::from_context(Arc::new(ctx)))
309    }
310
311    /// Create a CPU backend from an existing context.
312    ///
313    /// # Examples
314    ///
315    /// ```
316    /// use std::sync::Arc;
317    /// use tenferro_cpu::{CpuBackend, CpuContext};
318    ///
319    /// let ctx = Arc::new(CpuContext::with_threads(2).unwrap());
320    /// let backend = CpuBackend::from_context(ctx);
321    /// assert_eq!(backend.num_threads(), 2);
322    /// ```
323    pub fn from_context(ctx: Arc<CpuContext>) -> Self {
324        Self {
325            ctx,
326            buffers: BufferPool::new(),
327            kind: CpuBackendKind::default_compiled(),
328        }
329    }
330
331    fn try_from_context_with_kind(
332        ctx: Arc<CpuContext>,
333        kind: CpuBackendKind,
334    ) -> crate::Result<Self> {
335        ensure_cpu_backend_kind_available(kind, "CpuBackend::with_kind")?;
336        Ok(Self {
337            ctx,
338            buffers: BufferPool::new(),
339            kind,
340        })
341    }
342
343    /// Create a CPU backend from an existing context and buffer-pool retention cap.
344    ///
345    /// The cap is measured in retained vector capacity bytes. A cap of zero
346    /// disables buffer retention.
347    ///
348    /// # Examples
349    ///
350    /// ```
351    /// use std::sync::Arc;
352    /// use tenferro_cpu::{CpuBackend, CpuContext};
353    ///
354    /// let ctx = Arc::new(CpuContext::with_threads(1).unwrap());
355    /// let backend = CpuBackend::from_context_with_buffer_pool_limit(ctx, 0);
356    /// assert_eq!(backend.buffer_pool_limit_bytes(), 0);
357    /// ```
358    pub fn from_context_with_buffer_pool_limit(
359        ctx: Arc<CpuContext>,
360        max_retained_capacity_bytes: usize,
361    ) -> Self {
362        Self::from_context_with_buffer_pool_limit_and_kind(
363            ctx,
364            max_retained_capacity_bytes,
365            CpuBackendKind::default_compiled(),
366        )
367    }
368
369    fn from_context_with_buffer_pool_limit_and_kind(
370        ctx: Arc<CpuContext>,
371        max_retained_capacity_bytes: usize,
372        kind: CpuBackendKind,
373    ) -> Self {
374        Self {
375            ctx,
376            buffers: BufferPool::with_max_retained_capacity_bytes(max_retained_capacity_bytes),
377            kind,
378        }
379    }
380
381    /// Create a CPU backend with a custom thread count.
382    ///
383    /// # Examples
384    ///
385    /// ```
386    /// use tenferro_cpu::CpuBackend;
387    ///
388    /// let backend = CpuBackend::with_threads(2).unwrap();
389    /// assert_eq!(backend.num_threads(), 2);
390    /// ```
391    ///
392    /// # Errors
393    ///
394    /// Returns an error when `num_threads` is zero or Rayon rejects the pool.
395    pub fn with_threads(num_threads: usize) -> crate::Result<Self> {
396        CpuContext::with_threads(num_threads)
397            .map(|ctx| Self::from_context(Arc::new(ctx)))
398            .map_err(|err| match err {
399                crate::Error::InvalidConfig { message, .. } => crate::Error::InvalidConfig {
400                    op: "CpuBackend::with_threads",
401                    message,
402                },
403                crate::Error::BackendFailure { message, .. } => {
404                    crate::Error::backend_failure("CpuBackend::with_threads", message)
405                }
406                err => err,
407            })
408    }
409
410    /// Create a CPU backend with a custom thread count and provider.
411    ///
412    /// # Examples
413    ///
414    /// ```
415    /// use tenferro_cpu::{CpuBackend, CpuBackendKind};
416    ///
417    /// let backend = CpuBackend::with_threads_and_kind(
418    ///     1,
419    ///     CpuBackendKind::default_compiled(),
420    /// )?;
421    /// assert_eq!(backend.num_threads(), 1);
422    /// # Ok::<(), tenferro_tensor::Error>(())
423    /// ```
424    ///
425    /// # Errors
426    ///
427    /// Returns an error when `num_threads` is zero, Rayon rejects the pool, or
428    /// the selected provider is unavailable.
429    pub fn with_threads_and_kind(num_threads: usize, kind: CpuBackendKind) -> crate::Result<Self> {
430        ensure_cpu_backend_kind_available(kind, "CpuBackend::with_threads_and_kind")?;
431        CpuContext::with_threads(num_threads)
432            .map(|ctx| Self {
433                ctx: Arc::new(ctx),
434                buffers: BufferPool::new(),
435                kind,
436            })
437            .map_err(|err| match err {
438                crate::Error::InvalidConfig { message, .. } => crate::Error::InvalidConfig {
439                    op: "CpuBackend::with_threads_and_kind",
440                    message,
441                },
442                crate::Error::BackendFailure { message, .. } => {
443                    crate::Error::backend_failure("CpuBackend::with_threads_and_kind", message)
444                }
445                err => err,
446            })
447    }
448
449    /// Return the runtime CPU provider selected by this backend.
450    ///
451    /// # Examples
452    ///
453    /// ```
454    /// use tenferro_cpu::{CpuBackend, CpuBackendKind};
455    ///
456    /// let backend = CpuBackend::new();
457    /// assert_eq!(backend.kind(), CpuBackendKind::default_compiled());
458    /// ```
459    pub fn kind(&self) -> CpuBackendKind {
460        self.kind
461    }
462
463    /// Return the number of threads in this backend's CPU context.
464    ///
465    /// # Examples
466    ///
467    /// ```
468    /// use tenferro_cpu::CpuBackend;
469    ///
470    /// let backend = CpuBackend::with_threads(2).unwrap();
471    /// assert_eq!(backend.num_threads(), 2);
472    /// ```
473    pub fn num_threads(&self) -> usize {
474        self.ctx.num_threads()
475    }
476
477    /// Number of retained typed host buffers currently held by this backend.
478    ///
479    /// # Examples
480    ///
481    /// ```
482    /// use tenferro_cpu::CpuBackend;
483    ///
484    /// let backend = CpuBackend::new();
485    /// assert_eq!(backend.buffer_pool_len(), 0);
486    /// ```
487    pub fn buffer_pool_len(&self) -> usize {
488        self.buffers.len()
489    }
490
491    /// Snapshot reusable typed host buffers currently retained by this backend.
492    ///
493    /// # Examples
494    ///
495    /// ```
496    /// use tenferro_cpu::CpuBackend;
497    ///
498    /// let backend = CpuBackend::new();
499    /// let stats = backend.buffer_pool_stats();
500    /// assert_eq!(stats.buffers, 0);
501    /// assert_eq!(stats.capacity_bytes, 0);
502    /// ```
503    pub fn buffer_pool_stats(&self) -> BufferPoolStats {
504        self.buffers.stats()
505    }
506
507    /// Return cache-style stats for the CPU buffer pool.
508    ///
509    /// # Examples
510    ///
511    /// ```
512    /// use tenferro_cpu::CpuBackend;
513    ///
514    /// let backend = CpuBackend::new();
515    /// let stats = backend.buffer_pool_cache_stats();
516    /// assert_eq!(stats.entries, 0);
517    /// assert_eq!(stats.retained_bytes, 0);
518    /// ```
519    pub fn buffer_pool_cache_stats(&self) -> CacheStats {
520        self.buffers.cache_stats()
521    }
522
523    /// Current CPU buffer-pool retention limit in bytes.
524    ///
525    /// # Examples
526    ///
527    /// ```
528    /// use std::sync::Arc;
529    /// use tenferro_cpu::{CpuBackend, CpuContext};
530    ///
531    /// let backend = CpuBackend::from_context_with_buffer_pool_limit(
532    ///     Arc::new(CpuContext::with_threads(1).unwrap()),
533    ///     4096,
534    /// );
535    /// assert_eq!(backend.buffer_pool_limit_bytes(), 4096);
536    /// ```
537    pub fn buffer_pool_limit_bytes(&self) -> usize {
538        self.buffers.max_retained_capacity_bytes()
539    }
540
541    /// Update the CPU buffer-pool retention limit in bytes.
542    ///
543    /// Shrinking the limit evicts retained buffers immediately. A limit of zero
544    /// disables buffer retention.
545    ///
546    /// # Examples
547    ///
548    /// ```
549    /// use tenferro_cpu::CpuBackend;
550    ///
551    /// let mut backend = CpuBackend::new();
552    /// backend.set_buffer_pool_limit_bytes(0);
553    /// assert_eq!(backend.buffer_pool_limit_bytes(), 0);
554    /// assert_eq!(backend.buffer_pool_len(), 0);
555    /// ```
556    pub fn set_buffer_pool_limit_bytes(&mut self, max_retained_capacity_bytes: usize) {
557        self.buffers
558            .set_max_retained_capacity_bytes(max_retained_capacity_bytes);
559    }
560
561    /// Reset reusable typed host buffers currently retained by this backend.
562    ///
563    /// This releases pool-owned vectors to the process allocator. Operating
564    /// system RSS may not fall immediately because allocators can retain freed
565    /// pages for future allocations.
566    ///
567    /// # Examples
568    ///
569    /// ```
570    /// use tenferro_cpu::CpuBackend;
571    ///
572    /// let mut backend = CpuBackend::new();
573    /// backend.reset_buffer_pool();
574    /// assert_eq!(backend.buffer_pool_len(), 0);
575    /// ```
576    pub fn reset_buffer_pool(&mut self) {
577        self.buffers.clear();
578    }
579
580    /// Run a closure in this backend's CPU execution scope.
581    ///
582    /// # Examples
583    ///
584    /// ```
585    /// use tenferro_cpu::CpuBackend;
586    ///
587    /// let backend = CpuBackend::with_threads(1).unwrap();
588    /// let value = backend.install(|| 1 + 1);
589    /// assert_eq!(value, 2);
590    /// ```
591    pub fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
592        self.ctx.install(op)
593    }
594
595    fn install_with_pool<R: Send>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R {
596        let mut buffers = BufferPoolLoan::new(&mut self.buffers);
597        let ctx = Arc::clone(&self.ctx);
598        ctx.install(|| op(buffers.get_mut()))
599    }
600
601    // Selected when the BLAS provider is active; default Faer-only builds keep
602    // it dormant.
603    #[allow(dead_code)]
604    fn run_with_pool<R>(&mut self, op: impl FnOnce(&mut BufferPool) -> R) -> R {
605        let mut buffers = BufferPoolLoan::new(&mut self.buffers);
606        op(buffers.get_mut())
607    }
608
609    fn linalg_with_pool<R: Send>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R {
610        match self.kind {
611            CpuBackendKind::Faer => self.install_with_pool(op),
612            CpuBackendKind::Blas => self.run_with_pool(op),
613        }
614    }
615
616    /// Run an external linalg implementation with this backend's buffer pool.
617    ///
618    /// This is exposed for operation-family crates that own their backend
619    /// implementation while still sharing the CPU backend's allocation pool.
620    #[doc(hidden)]
621    pub fn with_linalg_pool<R: Send>(&mut self, op: impl FnOnce(&mut BufferPool) -> R + Send) -> R {
622        self.linalg_with_pool(op)
623    }
624
625    /// Clone the CPU context used by external linalg implementations.
626    #[cfg(feature = "cpu-faer")]
627    #[doc(hidden)]
628    pub fn linalg_context(&self) -> Arc<CpuContext> {
629        Arc::clone(&self.ctx)
630    }
631
632    // Selected when the Faer provider handles cached GEMM execution; some
633    // feature combinations compile only the uncached or BLAS path.
634    #[allow(dead_code)]
635    fn install_with_pool_and_gemm_cache<R: Send>(
636        &mut self,
637        gemm_analysis_cache: &mut gemm::GemmAnalysisCache,
638        op: impl FnOnce(&mut BufferPool, &mut gemm::GemmAnalysisCache) -> R + Send,
639    ) -> R {
640        let mut buffers = BufferPoolLoan::new(&mut self.buffers);
641        let ctx = Arc::clone(&self.ctx);
642        ctx.install(|| op(buffers.get_mut(), gemm_analysis_cache))
643    }
644
645    // Selected when the BLAS provider handles cached GEMM execution; default
646    // Faer-only builds keep it dormant.
647    #[allow(dead_code)]
648    fn run_with_pool_and_gemm_cache<R>(
649        &mut self,
650        gemm_analysis_cache: &mut gemm::GemmAnalysisCache,
651        op: impl FnOnce(&mut BufferPool, &mut gemm::GemmAnalysisCache) -> R,
652    ) -> R {
653        let mut buffers = BufferPoolLoan::new(&mut self.buffers);
654        op(buffers.get_mut(), gemm_analysis_cache)
655    }
656}
657
// The runtime cache associated with the CPU backend is the GEMM analysis
// cache from the `gemm` module.
impl BackendRuntimeCache for CpuBackend {
    type RuntimeCache = gemm::GemmAnalysisCache;
}
661
662impl TensorElementwise for CpuBackend {
663    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
664        self.install_with_pool(|buffers| elementwise::add_with_pool(buffers, lhs, rhs))
665    }
666
667    fn add_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
668        self.install_with_pool(|buffers| elementwise::add_read_with_pool(buffers, lhs, rhs))
669    }
670
671    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
672        self.install_with_pool(|buffers| elementwise::mul_with_pool(buffers, lhs, rhs))
673    }
674
675    fn mul_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
676        self.install_with_pool(|buffers| elementwise::mul_read_with_pool(buffers, lhs, rhs))
677    }
678
679    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor> {
680        self.install_with_pool(|buffers| elementwise::neg_with_pool(buffers, input))
681    }
682
683    fn neg_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
684        self.install_with_pool(|buffers| elementwise::neg_read_with_pool(buffers, input))
685    }
686
687    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor> {
688        self.install_with_pool(|buffers| elementwise::conj_with_pool(buffers, input))
689    }
690
691    fn conj_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
692        self.install_with_pool(|buffers| elementwise::conj_read_with_pool(buffers, input))
693    }
694
695    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
696        self.install_with_pool(|buffers| elementwise::div_with_pool(buffers, lhs, rhs))
697    }
698
699    fn div_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
700        self.install_with_pool(|buffers| elementwise::div_read_with_pool(buffers, lhs, rhs))
701    }
702
703    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor> {
704        self.install_with_pool(|buffers| elementwise::abs_with_pool(buffers, input))
705    }
706
707    fn abs_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
708        self.install_with_pool(|buffers| elementwise::abs_read_with_pool(buffers, input))
709    }
710
711    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor> {
712        self.install_with_pool(|buffers| elementwise::sign_with_pool(buffers, input))
713    }
714
715    fn sign_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
716        self.install_with_pool(|buffers| elementwise::sign_read_with_pool(buffers, input))
717    }
718
719    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
720        self.install_with_pool(|buffers| elementwise::maximum_with_pool(buffers, lhs, rhs))
721    }
722
723    fn maximum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
724        self.install_with_pool(|buffers| elementwise::maximum_read_with_pool(buffers, lhs, rhs))
725    }
726
727    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
728        self.install_with_pool(|buffers| elementwise::minimum_with_pool(buffers, lhs, rhs))
729    }
730
731    fn minimum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
732        self.install_with_pool(|buffers| elementwise::minimum_read_with_pool(buffers, lhs, rhs))
733    }
734
735    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor> {
736        self.install_with_pool(|buffers| elementwise::compare_with_pool(buffers, lhs, rhs, dir))
737    }
738
739    fn compare_read(
740        &mut self,
741        lhs: TensorRead<'_>,
742        rhs: TensorRead<'_>,
743        dir: &CompareDir,
744    ) -> crate::Result<Tensor> {
745        self.install_with_pool(|buffers| {
746            elementwise::compare_read_with_pool(buffers, lhs, rhs, dir)
747        })
748    }
749
750    fn select(
751        &mut self,
752        pred: &Tensor,
753        on_true: &Tensor,
754        on_false: &Tensor,
755    ) -> crate::Result<Tensor> {
756        self.install_with_pool(|buffers| {
757            elementwise::select_with_pool(buffers, pred, on_true, on_false)
758        })
759    }
760
761    fn select_read(
762        &mut self,
763        pred: TensorRead<'_>,
764        on_true: TensorRead<'_>,
765        on_false: TensorRead<'_>,
766    ) -> crate::Result<Tensor> {
767        self.install_with_pool(|buffers| {
768            elementwise::select_read_with_pool(buffers, pred, on_true, on_false)
769        })
770    }
771
772    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor> {
773        self.install_with_pool(|buffers| elementwise::clamp_with_pool(buffers, input, lower, upper))
774    }
775
776    fn clamp_read(
777        &mut self,
778        input: TensorRead<'_>,
779        lower: TensorRead<'_>,
780        upper: TensorRead<'_>,
781    ) -> crate::Result<Tensor> {
782        self.install_with_pool(|buffers| {
783            elementwise::clamp_read_with_pool(buffers, input, lower, upper)
784        })
785    }
786}
787
788impl TensorAnalytic for CpuBackend {
789    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor> {
790        self.install_with_pool(|buffers| analytic::exp_with_pool(buffers, input))
791    }
792
793    fn exp_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
794        self.install_with_pool(|buffers| analytic::exp_read_with_pool(buffers, input))
795    }
796
797    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor> {
798        self.install_with_pool(|buffers| analytic::log_with_pool(buffers, input))
799    }
800
801    fn log_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
802        self.install_with_pool(|buffers| analytic::log_read_with_pool(buffers, input))
803    }
804
805    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor> {
806        self.install_with_pool(|buffers| analytic::sin_with_pool(buffers, input))
807    }
808
809    fn sin_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
810        self.install_with_pool(|buffers| analytic::sin_read_with_pool(buffers, input))
811    }
812
813    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor> {
814        self.install_with_pool(|buffers| analytic::cos_with_pool(buffers, input))
815    }
816
817    fn cos_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
818        self.install_with_pool(|buffers| analytic::cos_read_with_pool(buffers, input))
819    }
820
821    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor> {
822        self.install_with_pool(|buffers| analytic::tanh_with_pool(buffers, input))
823    }
824
825    fn tanh_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
826        self.install_with_pool(|buffers| analytic::tanh_read_with_pool(buffers, input))
827    }
828
829    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
830        self.install_with_pool(|buffers| analytic::sqrt_with_pool(buffers, input))
831    }
832
833    fn sqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
834        self.install_with_pool(|buffers| analytic::sqrt_read_with_pool(buffers, input))
835    }
836
837    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor> {
838        self.install_with_pool(|buffers| analytic::rsqrt_with_pool(buffers, input))
839    }
840
841    fn rsqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
842        self.install_with_pool(|buffers| analytic::rsqrt_read_with_pool(buffers, input))
843    }
844
845    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor> {
846        self.install_with_pool(|buffers| analytic::pow_with_pool(buffers, lhs, rhs))
847    }
848
849    fn pow_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
850        self.install_with_pool(|buffers| analytic::pow_read_with_pool(buffers, lhs, rhs))
851    }
852
853    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor> {
854        self.install_with_pool(|buffers| analytic::expm1_with_pool(buffers, input))
855    }
856
857    fn expm1_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
858        self.install_with_pool(|buffers| analytic::expm1_read_with_pool(buffers, input))
859    }
860
861    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor> {
862        self.install_with_pool(|buffers| analytic::log1p_with_pool(buffers, input))
863    }
864
865    fn log1p_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
866        self.install_with_pool(|buffers| analytic::log1p_read_with_pool(buffers, input))
867    }
868}
869
870impl TensorStructural for CpuBackend {
871    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor> {
872        self.install_with_pool(|buffers| structural::transpose_with_pool(buffers, input, perm))
873    }
874
875    fn transpose_read(&mut self, input: TensorRead<'_>, perm: &[usize]) -> crate::Result<Tensor> {
876        if let Some(input) = input.as_tensor() {
877            return self.transpose(input, perm);
878        }
879
880        let input = materialize_tensor_read("transpose", input)?;
881        self.transpose(&input, perm)
882    }
883
    /// Forward to `structural::reshape` inside the backend's execution scope.
    /// Unlike the neighbouring methods, this one does not borrow the buffer
    /// pool.
    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor> {
        self.install(|| structural::reshape(input, shape))
    }
887
888    fn reshape_read(&mut self, input: TensorRead<'_>, shape: &[usize]) -> crate::Result<Tensor> {
889        if let Some(input) = input.as_tensor() {
890            return self.reshape(input, shape);
891        }
892
893        let input = materialize_tensor_read("reshape", input)?;
894        self.reshape(&input, shape)
895    }
896
897    fn broadcast_in_dim(
898        &mut self,
899        input: &Tensor,
900        shape: &[usize],
901        dims: &[usize],
902    ) -> crate::Result<Tensor> {
903        self.install_with_pool(|buffers| {
904            structural::broadcast_in_dim_with_pool(buffers, input, shape, dims)
905        })
906    }
907
908    fn broadcast_in_dim_read(
909        &mut self,
910        input: TensorRead<'_>,
911        shape: &[usize],
912        dims: &[usize],
913    ) -> crate::Result<Tensor> {
914        if let Some(input) = input.as_tensor() {
915            return self.broadcast_in_dim(input, shape, dims);
916        }
917
918        let input = materialize_tensor_read("broadcast_in_dim", input)?;
919        self.broadcast_in_dim(&input, shape, dims)
920    }
921
922    fn cast(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor> {
923        self.install_with_pool(|buffers| structural::cast_with_pool(buffers, input, to))
924    }
925
926    fn extract_diagonal(
927        &mut self,
928        input: &Tensor,
929        axis_a: usize,
930        axis_b: usize,
931    ) -> crate::Result<Tensor> {
932        self.install_with_pool(|buffers| {
933            structural::extract_diagonal_with_pool(buffers, input, axis_a, axis_b)
934        })
935    }
936
937    fn embed_diagonal(
938        &mut self,
939        input: &Tensor,
940        axis_a: usize,
941        axis_b: usize,
942    ) -> crate::Result<Tensor> {
943        self.install_with_pool(|buffers| {
944            structural::embed_diagonal_with_pool(buffers, input, axis_a, axis_b)
945        })
946    }
947
948    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
949        self.install_with_pool(|buffers| structural::tril_with_pool(buffers, input, k))
950    }
951
952    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor> {
953        self.install_with_pool(|buffers| structural::triu_with_pool(buffers, input, k))
954    }
955}
956
/// Reduction operations for the CPU backend.
///
/// Every method wraps the same-named function from the `reduction` module in
/// `self.install(..)`; unlike most other ops in this file, no pool argument is
/// passed to the kernel.
impl TensorReduction for CpuBackend {
    /// Forwards `input` and `axes` to `reduction::reduce_sum`.
    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_sum(input, axes))
    }

    /// `TensorRead` variant; forwards to `reduction::reduce_sum_read`.
    fn reduce_sum_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_sum_read(input, axes))
    }

    /// Forwards `input` and `axes` to `reduction::reduce_prod`.
    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_prod(input, axes))
    }

    /// `TensorRead` variant; forwards to `reduction::reduce_prod_read`.
    fn reduce_prod_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_prod_read(input, axes))
    }

    /// Forwards `input` and `axes` to `reduction::reduce_max`.
    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_max(input, axes))
    }

    /// `TensorRead` variant; forwards to `reduction::reduce_max_read`.
    fn reduce_max_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_max_read(input, axes))
    }

    /// Forwards `input` and `axes` to `reduction::reduce_min`.
    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_min(input, axes))
    }

    /// `TensorRead` variant; forwards to `reduction::reduce_min_read`.
    fn reduce_min_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
        self.install(|| reduction::reduce_min_read(input, axes))
    }
}
990
991impl TensorDot for CpuBackend {
992    fn dot_general(
993        &mut self,
994        lhs: &Tensor,
995        rhs: &Tensor,
996        config: &DotGeneralConfig,
997    ) -> crate::Result<Tensor> {
998        let mut cache = gemm::GemmAnalysisCache::default();
999        BackendCachedDot::dot_general_cached(self, &mut cache, None, lhs, rhs, config)
1000    }
1001
1002    fn dot_general_read(
1003        &mut self,
1004        lhs: TensorRead<'_>,
1005        rhs: TensorRead<'_>,
1006        config: &DotGeneralConfig,
1007    ) -> crate::Result<Tensor> {
1008        let mut cache = gemm::GemmAnalysisCache::default();
1009        let direct = match self.kind {
1010            CpuBackendKind::Faer => {
1011                #[cfg(feature = "cpu-faer")]
1012                {
1013                    let ctx = Arc::clone(&self.ctx);
1014                    self.install_with_pool_and_gemm_cache(&mut cache, |buffers, cache| {
1015                        gemm::dot_general_faer_read_cached(
1016                            buffers,
1017                            cache,
1018                            None,
1019                            ctx.as_ref(),
1020                            lhs.clone(),
1021                            rhs.clone(),
1022                            config,
1023                        )
1024                    })?
1025                }
1026                #[cfg(not(feature = "cpu-faer"))]
1027                {
1028                    return Err(unavailable_cpu_backend_kind(self.kind, "dot_general"));
1029                }
1030            }
1031            CpuBackendKind::Blas => {
1032                #[cfg(feature = "cpu-blas")]
1033                {
1034                    self.run_with_pool_and_gemm_cache(&mut cache, |buffers, cache| {
1035                        gemm::dot_general_blas_read_cached(
1036                            buffers,
1037                            cache,
1038                            None,
1039                            lhs.clone(),
1040                            rhs.clone(),
1041                            config,
1042                        )
1043                    })?
1044                }
1045                #[cfg(not(feature = "cpu-blas"))]
1046                {
1047                    return Err(unavailable_cpu_backend_kind(self.kind, "dot_general"));
1048                }
1049            }
1050        };
1051        if let Some(result) = direct {
1052            return Ok(result);
1053        }
1054
1055        let lhs = materialize_tensor_read("dot_general", lhs)?;
1056        let rhs = materialize_tensor_read("dot_general", rhs)?;
1057        BackendCachedDot::dot_general_cached(self, &mut cache, None, &lhs, &rhs, config)
1058    }
1059
1060    fn dot_general_with_conj(
1061        &mut self,
1062        lhs: &Tensor,
1063        rhs: &Tensor,
1064        config: &DotGeneralConfig,
1065        lhs_conj: bool,
1066        rhs_conj: bool,
1067    ) -> crate::Result<Tensor> {
1068        let mut cache = gemm::GemmAnalysisCache::default();
1069        BackendCachedDot::dot_general_with_conj_cached(
1070            self, &mut cache, None, lhs, rhs, config, lhs_conj, rhs_conj,
1071        )
1072    }
1073}
1074
1075impl BackendCachedDot for CpuBackend {
1076    fn dot_general_cached(
1077        &mut self,
1078        cache: &mut Self::RuntimeCache,
1079        cache_slot: Option<usize>,
1080        lhs: &Tensor,
1081        rhs: &Tensor,
1082        config: &DotGeneralConfig,
1083    ) -> crate::Result<Tensor> {
1084        match self.kind {
1085            CpuBackendKind::Faer => {
1086                #[cfg(feature = "cpu-faer")]
1087                {
1088                    let ctx = Arc::clone(&self.ctx);
1089                    self.install_with_pool_and_gemm_cache(cache, |buffers, cache| {
1090                        match (lhs, rhs) {
1091                            (Tensor::F32(a), Tensor::F32(b)) => gemm::dot_general_faer_cached(
1092                                buffers,
1093                                cache,
1094                                cache_slot,
1095                                ctx.as_ref(),
1096                                a,
1097                                b,
1098                                config,
1099                            )
1100                            .map(Tensor::F32),
1101                            (Tensor::F64(a), Tensor::F64(b)) => gemm::dot_general_faer_cached(
1102                                buffers,
1103                                cache,
1104                                cache_slot,
1105                                ctx.as_ref(),
1106                                a,
1107                                b,
1108                                config,
1109                            )
1110                            .map(Tensor::F64),
1111                            (Tensor::C32(a), Tensor::C32(b)) => gemm::dot_general_faer_cached(
1112                                buffers,
1113                                cache,
1114                                cache_slot,
1115                                ctx.as_ref(),
1116                                a,
1117                                b,
1118                                config,
1119                            )
1120                            .map(Tensor::C32),
1121                            (Tensor::C64(a), Tensor::C64(b)) => gemm::dot_general_faer_cached(
1122                                buffers,
1123                                cache,
1124                                cache_slot,
1125                                ctx.as_ref(),
1126                                a,
1127                                b,
1128                                config,
1129                            )
1130                            .map(Tensor::C64),
1131                            _ => Err(crate::Error::DTypeMismatch {
1132                                op: "dot_general",
1133                                lhs: lhs.dtype(),
1134                                rhs: rhs.dtype(),
1135                            }),
1136                        }
1137                    })
1138                }
1139                #[cfg(not(feature = "cpu-faer"))]
1140                {
1141                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
1142                }
1143            }
1144            CpuBackendKind::Blas => {
1145                #[cfg(feature = "cpu-blas")]
1146                {
1147                    self.run_with_pool_and_gemm_cache(cache, |buffers, cache| match (lhs, rhs) {
1148                        (Tensor::F32(a), Tensor::F32(b)) => {
1149                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
1150                                .map(Tensor::F32)
1151                        }
1152                        (Tensor::F64(a), Tensor::F64(b)) => {
1153                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
1154                                .map(Tensor::F64)
1155                        }
1156                        (Tensor::C32(a), Tensor::C32(b)) => {
1157                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
1158                                .map(Tensor::C32)
1159                        }
1160                        (Tensor::C64(a), Tensor::C64(b)) => {
1161                            gemm::dot_general_blas_cached(buffers, cache, cache_slot, a, b, config)
1162                                .map(Tensor::C64)
1163                        }
1164                        _ => Err(crate::Error::DTypeMismatch {
1165                            op: "dot_general",
1166                            lhs: lhs.dtype(),
1167                            rhs: rhs.dtype(),
1168                        }),
1169                    })
1170                }
1171                #[cfg(not(feature = "cpu-blas"))]
1172                {
1173                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
1174                }
1175            }
1176        }
1177    }
1178
1179    fn dot_general_with_conj_cached(
1180        &mut self,
1181        cache: &mut Self::RuntimeCache,
1182        cache_slot: Option<usize>,
1183        lhs: &Tensor,
1184        rhs: &Tensor,
1185        config: &DotGeneralConfig,
1186        lhs_conj: bool,
1187        rhs_conj: bool,
1188    ) -> crate::Result<Tensor> {
1189        match self.kind {
1190            CpuBackendKind::Faer => {
1191                #[cfg(feature = "cpu-faer")]
1192                {
1193                    let ctx = Arc::clone(&self.ctx);
1194                    self.install_with_pool_and_gemm_cache(cache, |buffers, cache| {
1195                        match (lhs, rhs) {
1196                            (Tensor::F32(a), Tensor::F32(b)) => {
1197                                gemm::dot_general_faer_with_conj_cached(
1198                                    buffers,
1199                                    cache,
1200                                    cache_slot,
1201                                    ctx.as_ref(),
1202                                    a,
1203                                    b,
1204                                    config,
1205                                    lhs_conj,
1206                                    rhs_conj,
1207                                )
1208                                .map(Tensor::F32)
1209                            }
1210                            (Tensor::F64(a), Tensor::F64(b)) => {
1211                                gemm::dot_general_faer_with_conj_cached(
1212                                    buffers,
1213                                    cache,
1214                                    cache_slot,
1215                                    ctx.as_ref(),
1216                                    a,
1217                                    b,
1218                                    config,
1219                                    lhs_conj,
1220                                    rhs_conj,
1221                                )
1222                                .map(Tensor::F64)
1223                            }
1224                            (Tensor::C32(a), Tensor::C32(b)) => {
1225                                gemm::dot_general_faer_with_conj_cached(
1226                                    buffers,
1227                                    cache,
1228                                    cache_slot,
1229                                    ctx.as_ref(),
1230                                    a,
1231                                    b,
1232                                    config,
1233                                    lhs_conj,
1234                                    rhs_conj,
1235                                )
1236                                .map(Tensor::C32)
1237                            }
1238                            (Tensor::C64(a), Tensor::C64(b)) => {
1239                                gemm::dot_general_faer_with_conj_cached(
1240                                    buffers,
1241                                    cache,
1242                                    cache_slot,
1243                                    ctx.as_ref(),
1244                                    a,
1245                                    b,
1246                                    config,
1247                                    lhs_conj,
1248                                    rhs_conj,
1249                                )
1250                                .map(Tensor::C64)
1251                            }
1252                            _ => Err(crate::Error::DTypeMismatch {
1253                                op: "dot_general",
1254                                lhs: lhs.dtype(),
1255                                rhs: rhs.dtype(),
1256                            }),
1257                        }
1258                    })
1259                }
1260                #[cfg(not(feature = "cpu-faer"))]
1261                {
1262                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
1263                }
1264            }
1265            CpuBackendKind::Blas => {
1266                #[cfg(feature = "cpu-blas")]
1267                {
1268                    self.run_with_pool_and_gemm_cache(cache, |buffers, cache| match (lhs, rhs) {
1269                        (Tensor::F32(a), Tensor::F32(b)) => {
1270                            gemm::dot_general_blas_with_conj_cached(
1271                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
1272                            )
1273                            .map(Tensor::F32)
1274                        }
1275                        (Tensor::F64(a), Tensor::F64(b)) => {
1276                            gemm::dot_general_blas_with_conj_cached(
1277                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
1278                            )
1279                            .map(Tensor::F64)
1280                        }
1281                        (Tensor::C32(a), Tensor::C32(b)) => {
1282                            gemm::dot_general_blas_with_conj_cached(
1283                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
1284                            )
1285                            .map(Tensor::C32)
1286                        }
1287                        (Tensor::C64(a), Tensor::C64(b)) => {
1288                            gemm::dot_general_blas_with_conj_cached(
1289                                buffers, cache, cache_slot, a, b, config, lhs_conj, rhs_conj,
1290                            )
1291                            .map(Tensor::C64)
1292                        }
1293                        _ => Err(crate::Error::DTypeMismatch {
1294                            op: "dot_general",
1295                            lhs: lhs.dtype(),
1296                            rhs: rhs.dtype(),
1297                        }),
1298                    })
1299                }
1300                #[cfg(not(feature = "cpu-blas"))]
1301                {
1302                    Err(unavailable_cpu_backend_kind(self.kind, "dot_general"))
1303                }
1304            }
1305        }
1306    }
1307}
1308
1309impl TensorIndexing for CpuBackend {
1310    fn gather(
1311        &mut self,
1312        operand: &Tensor,
1313        start_indices: &Tensor,
1314        config: &GatherConfig,
1315    ) -> crate::Result<Tensor> {
1316        self.install_with_pool(|buffers| {
1317            indexing::gather_with_pool(buffers, operand, start_indices, config)
1318        })
1319    }
1320
1321    fn scatter(
1322        &mut self,
1323        operand: &Tensor,
1324        scatter_indices: &Tensor,
1325        updates: &Tensor,
1326        config: &ScatterConfig,
1327    ) -> crate::Result<Tensor> {
1328        self.install_with_pool(|buffers| {
1329            indexing::scatter_with_pool(buffers, operand, scatter_indices, updates, config)
1330        })
1331    }
1332
1333    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor> {
1334        self.install_with_pool(|buffers| indexing::try_slice_with_pool(buffers, input, config))
1335    }
1336
1337    fn dynamic_slice(
1338        &mut self,
1339        input: &Tensor,
1340        starts: &Tensor,
1341        slice_sizes: &[usize],
1342    ) -> crate::Result<Tensor> {
1343        self.install_with_pool(|buffers| {
1344            indexing::dynamic_slice_with_pool(buffers, input, starts, slice_sizes)
1345        })
1346    }
1347
1348    fn dynamic_update_slice(
1349        &mut self,
1350        operand: &Tensor,
1351        update: &Tensor,
1352        starts: &Tensor,
1353    ) -> crate::Result<Tensor> {
1354        self.install_with_pool(|buffers| {
1355            indexing::dynamic_update_slice_with_pool(buffers, operand, update, starts)
1356        })
1357    }
1358
1359    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor> {
1360        self.install_with_pool(|buffers| indexing::try_pad_with_pool(buffers, input, config))
1361    }
1362
1363    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor> {
1364        self.install_with_pool(|buffers| indexing::try_concatenate_with_pool(buffers, inputs, axis))
1365    }
1366
1367    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor> {
1368        self.install_with_pool(|buffers| indexing::reverse_with_pool(buffers, input, axes))
1369    }
1370}
1371
1372impl BackendSessionHost for CpuBackend {
1373    fn with_backend_session<R: Send>(
1374        &mut self,
1375        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
1376    ) -> R {
1377        let mut cache = profile_cpu_session_section("with_backend_session.cache_default", || {
1378            gemm::GemmAnalysisCache::default()
1379        });
1380        self.with_backend_session_cached(&mut cache, f)
1381    }
1382
1383    fn with_backend_session_cached<R: Send>(
1384        &mut self,
1385        cache: &mut Self::RuntimeCache,
1386        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
1387    ) -> R {
1388        if !cpu_session_profile_enabled() {
1389            let mut buffers = BufferPoolLoan::new(&mut self.buffers);
1390            let ctx = Arc::clone(&self.ctx);
1391            let kind = self.kind;
1392            return ctx.install(|| {
1393                let mut session = CpuExecSession {
1394                    ctx: ctx.as_ref(),
1395                    buffers: buffers.get_mut(),
1396                    gemm_analysis_cache: cache,
1397                    kind,
1398                };
1399                f(&mut session)
1400            });
1401        }
1402
1403        let total_started = Instant::now();
1404        let mut buffers =
1405            profile_cpu_session_section("with_backend_session_cached.take_buffers", || {
1406                BufferPoolLoan::new(&mut self.buffers)
1407            });
1408        let ctx = Arc::clone(&self.ctx);
1409        let kind = self.kind;
1410        let result =
1411            profile_cpu_session_section("with_backend_session_cached.exec_session", || {
1412                ctx.install(|| {
1413                    let session_started = Instant::now();
1414                    let mut session = CpuExecSession {
1415                        ctx: ctx.as_ref(),
1416                        buffers: buffers.get_mut(),
1417                        gemm_analysis_cache: cache,
1418                        kind,
1419                    };
1420                    record_cpu_session_profile(
1421                        "with_backend_session_cached.session_construct",
1422                        session_started.elapsed(),
1423                    );
1424
1425                    let exec_started = Instant::now();
1426                    let result = f(&mut session);
1427                    record_cpu_session_profile(
1428                        "with_backend_session_cached.exec_body",
1429                        exec_started.elapsed(),
1430                    );
1431                    result
1432                })
1433            });
1434        profile_cpu_session_section("with_backend_session_cached.restore_buffers", || {
1435            drop(buffers);
1436        });
1437        record_cpu_session_profile("with_backend_session_cached.total", total_started.elapsed());
1438        maybe_print_cpu_session_profile();
1439        result
1440    }
1441}
1442
1443impl TensorBuffer for CpuBackend {
1444    fn reclaim_buffer(&mut self, tensor: Tensor) {
1445        match tensor {
1446            Tensor::F32(t) => reclaim_typed(&mut self.buffers, t),
1447            Tensor::F64(t) => reclaim_typed(&mut self.buffers, t),
1448            Tensor::I32(t) => reclaim_typed(&mut self.buffers, t),
1449            Tensor::I64(t) => reclaim_typed(&mut self.buffers, t),
1450            Tensor::Bool(t) => reclaim_typed(&mut self.buffers, t),
1451            Tensor::C32(t) => reclaim_typed(&mut self.buffers, t),
1452            Tensor::C64(t) => reclaim_typed(&mut self.buffers, t),
1453        }
1454    }
1455}
1456
1457impl<T, R> TensorViewCanonicalization<T, R> for CpuBackend
1458where
1459    T: Clone + 'static,
1460    R: TensorRank,
1461{
1462    fn to_contiguous(
1463        &mut self,
1464        view: &TypedTensorView<'_, T, R>,
1465    ) -> crate::Result<TypedTensor<T, R>> {
1466        if view.backend_buffer().is_some() {
1467            return Err(crate::Error::backend_failure(
1468                "CpuBackend::to_contiguous",
1469                "CPU backend received a backend tensor view; download the tensor to host before CPU view canonicalization",
1470            ));
1471        }
1472        view.to_contiguous()
1473    }
1474
1475    fn copy_from_contiguous(
1476        &mut self,
1477        src: &TypedTensor<T, R>,
1478        dst: &mut TypedTensorViewMut<'_, T, R>,
1479    ) -> crate::Result<()> {
1480        if matches!(src.buffer(), Buffer::Backend(_)) {
1481            return Err(crate::Error::backend_failure(
1482                "CpuBackend::copy_from_contiguous",
1483                "CPU backend received a backend source tensor; download the tensor to host before CPU view copy-back",
1484            ));
1485        }
1486        if dst.backend_buffer().is_some() {
1487            return Err(crate::Error::backend_failure(
1488                "CpuBackend::copy_from_contiguous",
1489                "CPU backend received a backend destination view; download the tensor to host before CPU view copy-back",
1490            ));
1491        }
1492        dst.copy_from_contiguous(src)
1493    }
1494}
1495
1496impl TensorFusion for CpuBackend {
1497    fn execute_broadcast_multiply(
1498        &mut self,
1499        lhs: TensorRead<'_>,
1500        lhs_shape: &[usize],
1501        lhs_dims: &[usize],
1502        rhs: TensorRead<'_>,
1503        rhs_shape: &[usize],
1504        rhs_dims: &[usize],
1505    ) -> crate::Result<Option<Tensor>> {
1506        self.install_with_pool(|buffers| {
1507            elementwise::broadcast_multiply_read_with_pool(
1508                buffers, lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims,
1509            )
1510        })
1511    }
1512
1513    fn execute_broadcast_multiply_value(
1514        &mut self,
1515        lhs: TensorRead<'_>,
1516        lhs_shape: &[usize],
1517        lhs_dims: &[usize],
1518        rhs: TensorRead<'_>,
1519        rhs_shape: &[usize],
1520        rhs_dims: &[usize],
1521    ) -> crate::Result<Option<TensorValue>> {
1522        self.install_with_pool(|buffers| {
1523            elementwise::broadcast_multiply_value_with_pool(
1524                buffers, lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims,
1525            )
1526        })
1527    }
1528}
1529
1530impl TensorDeviceTransfer for CpuBackend {
1531    fn download_to_host(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
1532        if tensor.is_backend_buffer() {
1533            return Err(crate::Error::backend_failure(
1534                "CpuBackend::download_to_host",
1535                "CPU backend received a backend buffer; download the tensor to host with its owning backend before CPU execution",
1536            ));
1537        }
1538        Ok(tensor.clone())
1539    }
1540
1541    fn upload_host_tensor(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
1542        if tensor.is_backend_buffer() {
1543            return Err(crate::Error::backend_failure(
1544                "CpuBackend::upload_host_tensor",
1545                "CPU backend upload_host_tensor expects a host tensor; download backend buffers to host before CPU execution",
1546            ));
1547        }
1548        Ok(tensor.clone())
1549    }
1550}
1551
// Empty impl: `CpuBackend` defines no items of its own for `TensorBackend`.
impl TensorBackend for CpuBackend {}
1553
1554pub(crate) fn reclaim_typed<T: PoolScalar>(pool: &mut BufferPool, typed: TypedTensor<T>) {
1555    let (buffer, _, _) = typed.into_parts();
1556    match buffer {
1557        Buffer::Host(data) => T::pool_release(pool, data),
1558        Buffer::Backend(_) => {}
1559    }
1560}
1561
1562impl Default for CpuBackend {
1563    fn default() -> Self {
1564        Self::new()
1565    }
1566}
1567
1568#[cfg(test)]
1569mod tests;