Skip to main content

tenferro_cpu/
buffer_pool.rs

1//! Typed host buffer pooling for reusable tensor allocations.
2//!
3//! # Examples
4//!
5//! ```rust
6//! use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
7//!
8//! let mut pool = BufferPool::new();
9//! let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 4) };
10//! buf.fill(1.0);
11//! <f64 as PoolScalar>::pool_release(&mut pool, buf);
12//! assert_eq!(pool.len(), 1);
13//! ```
14
15use std::collections::BTreeMap;
16use std::env;
17use std::fmt;
18use std::mem::size_of;
19
20use num_complex::{Complex32, Complex64};
21
22use crate::CacheStats;
23
/// Name of the environment variable that overrides the retention cap of
/// pools created through [`BufferPool::new`].
///
/// The value must be an unsigned integer number of bytes; `0` disables
/// retention. When the variable is unset or does not parse,
/// [`DEFAULT_MAX_RETAINED_CAPACITY_BYTES`] is used instead.
pub const BUFFER_POOL_MAX_RETAINED_BYTES_ENV: &'static str =
    "TENFERRO_BUFFER_POOL_MAX_RETAINED_BYTES";
29
/// Retention cap, in bytes, that [`BufferPool::new`] applies when
/// [`BUFFER_POOL_MAX_RETAINED_BYTES_ENV`] does not provide a valid override.
///
/// 100 MiB keeps a hot working set of buffers available for reuse while
/// bounding how much capacity of outgrown sizes can pile up in long-running
/// workloads whose tensor shapes keep growing.
pub const DEFAULT_MAX_RETAINED_CAPACITY_BYTES: usize = 100 << 20;
36
/// Point-in-time summary of the typed host buffers a [`BufferPool`] holds.
///
/// Both figures describe buffers still owned by the pool. `buffers` is the
/// number of retained `Vec` allocations. `capacity_bytes` is their combined
/// capacity (element capacity times element size). Memory that the process
/// allocator keeps after buffers are dropped is not included, so these
/// numbers are not a measure of operating-system RSS.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct BufferPoolStats {
    /// How many vector allocations the pool currently retains.
    pub buffers: usize,
    /// Sum of the retained vectors' capacities, in bytes.
    pub capacity_bytes: usize,
}
50
51/// Typed buffer pool keyed by element capacity and separated by scalar type.
52///
53/// Each supported dtype has an independent best-fit pool. Acquired buffers are
54/// returned without zero-initialization so kernels can avoid redundant writes
55/// when they fully overwrite the output. Use [`PoolScalar::pool_acquire_zeroed`]
56/// when the caller may read the buffer before writing every element.
57///
58/// # Examples
59///
60/// ```rust
61/// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
62///
63/// let mut pool = BufferPool::new();
64/// let buf = unsafe { <f32 as PoolScalar>::pool_acquire(&mut pool, 8) };
65/// <f32 as PoolScalar>::pool_release(&mut pool, buf);
66/// assert_eq!(pool.len(), 1);
67/// ```
68pub struct BufferPool {
69    f64_pool: BTreeMap<usize, Vec<Vec<f64>>>,
70    f32_pool: BTreeMap<usize, Vec<Vec<f32>>>,
71    i32_pool: BTreeMap<usize, Vec<Vec<i32>>>,
72    i64_pool: BTreeMap<usize, Vec<Vec<i64>>>,
73    bool_pool: BTreeMap<usize, Vec<Vec<bool>>>,
74    c64_pool: BTreeMap<usize, Vec<Vec<Complex64>>>,
75    c32_pool: BTreeMap<usize, Vec<Vec<Complex32>>>,
76    retained_capacity_bytes: usize,
77    max_retained_capacity_bytes: usize,
78}
79
80impl fmt::Debug for BufferPool {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.debug_struct("BufferPool")
83            .field("stats", &self.stats())
84            .field(
85                "max_retained_capacity_bytes",
86                &self.max_retained_capacity_bytes,
87            )
88            .finish_non_exhaustive()
89    }
90}
91
92/// Scalar types supported by [`BufferPool`].
93///
94/// The trait is sealed to the scalar dtypes that tenferro currently pools for
95/// CPU execution.
96///
97/// # Examples
98///
99/// ```rust
100/// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
101///
102/// let mut pool = BufferPool::new();
103/// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
104/// buf.copy_from_slice(&[3.0, 4.0]);
105/// <f64 as PoolScalar>::pool_release(&mut pool, buf);
106/// ```
107pub trait PoolScalar: Copy + Sized + Send + Sync + private::Sealed {
108    /// Zero value used to initialize acquired buffers.
109    fn pool_zero() -> Self;
110
111    /// Acquire a buffer with length `len`.
112    ///
113    /// The vector length is set without initializing its contents. Callers must
114    /// overwrite every element before any read.
115    ///
116    /// # Safety
117    ///
118    /// The returned vector may contain uninitialized or stale elements. Reading
119    /// any element before writing it is undefined behavior.
120    ///
121    /// # Examples
122    ///
123    /// ```rust
124    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
125    ///
126    /// let mut pool = BufferPool::new();
127    /// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
128    /// buf.copy_from_slice(&[1.0, 2.0]);
129    /// assert_eq!(buf, vec![1.0, 2.0]);
130    /// ```
131    unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self>;
132
133    /// Acquire a buffer with length `len` and every element set to zero.
134    ///
135    /// This is the safe path for callers that may read the buffer before every
136    /// element is overwritten. Prefer [`PoolScalar::pool_acquire`] for kernels
137    /// that perform a full overwrite.
138    ///
139    /// # Examples
140    ///
141    /// ```rust
142    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
143    ///
144    /// let mut pool = BufferPool::new();
145    /// let buf = <f64 as PoolScalar>::pool_acquire_zeroed(&mut pool, 2);
146    /// assert_eq!(buf, vec![0.0, 0.0]);
147    /// ```
148    fn pool_acquire_zeroed(pool: &mut BufferPool, len: usize) -> Vec<Self>;
149
150    /// Return a buffer to the typed pool for later reuse.
151    ///
152    /// Zero-capacity buffers are ignored.
153    ///
154    /// # Examples
155    ///
156    /// ```rust
157    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
158    ///
159    /// let mut pool = BufferPool::new();
160    /// let buf = vec![1.0_f32; 4];
161    /// <f32 as PoolScalar>::pool_release(&mut pool, buf);
162    /// assert_eq!(pool.len(), 1);
163    /// ```
164    fn pool_release(pool: &mut BufferPool, buf: Vec<Self>);
165}
166
167mod private {
168    pub trait Sealed {}
169
170    impl Sealed for f64 {}
171    impl Sealed for f32 {}
172    impl Sealed for i32 {}
173    impl Sealed for i64 {}
174    impl Sealed for bool {}
175    impl Sealed for num_complex::Complex64 {}
176    impl Sealed for num_complex::Complex32 {}
177}
178
/// Removes and returns one buffer from the smallest capacity bucket whose key
/// is at least `len`, dropping that bucket once it becomes empty.
///
/// Returns `None` when no bucket key is `>= len`, or when the first such
/// bucket holds no buffers.
fn take_best_fit<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>, len: usize) -> Option<Vec<T>> {
    let (&capacity, bucket) = pool.range_mut(len..).next()?;
    let taken = bucket.pop();
    if bucket.is_empty() {
        pool.remove(&capacity);
    }
    taken
}
190
/// Counts every buffer held across all capacity buckets of one typed pool.
fn pool_len<T>(pool: &BTreeMap<usize, Vec<Vec<T>>>) -> usize {
    pool.values().fold(0, |total, bucket| total + bucket.len())
}
194
/// Drops one buffer from the smallest-capacity bucket of a typed pool.
///
/// Returns the dropped buffer's capacity in bytes (bucket key times element
/// size, saturating). Returns `None` if the pool is empty or its first bucket
/// holds no buffers; in the latter case the empty bucket is left in place.
fn evict_one_from_pool<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>) -> Option<usize> {
    let mut smallest = pool.first_entry()?;
    let capacity = *smallest.key();
    drop(smallest.get_mut().pop()?);
    if smallest.get().is_empty() {
        smallest.remove();
    }
    Some(capacity.saturating_mul(size_of::<T>()))
}
204
/// Tag identifying one of the typed free lists inside a `BufferPool`, used to
/// dispatch eviction to the right map after comparing candidates.
#[derive(Copy, Clone)]
enum TypedPoolKind {
    /// `f64_pool`
    F64,
    /// `f32_pool`
    F32,
    /// `i32_pool`
    I32,
    /// `i64_pool`
    I64,
    /// `bool_pool`
    Bool,
    /// `c64_pool`
    C64,
    /// `c32_pool`
    C32,
}
215
216fn smallest_pool_candidate<T>(
217    pool: &BTreeMap<usize, Vec<Vec<T>>>,
218    kind: TypedPoolKind,
219) -> Option<(usize, TypedPoolKind)> {
220    pool.keys()
221        .next()
222        .map(|&capacity| (capacity.saturating_mul(size_of::<T>()), kind))
223}
224
// Implements `PoolScalar` for `$ty`, backed by the `BufferPool` field
// `$field`, with `$zero` as the value used by the zero-initializing path.
//
// Every path keeps `BufferPool::retained_capacity_bytes` in sync. A buffer
// leaving the pool subtracts its capacity in bytes; a retained buffer adds it.
macro_rules! impl_pool_scalar {
    ($ty:ty, $field:ident, $zero:expr) => {
        impl PoolScalar for $ty {
            fn pool_zero() -> Self {
                $zero
            }

            // The raw path deliberately exposes unwritten elements; callers
            // promise a full overwrite (see the trait's `# Safety`).
            #[allow(clippy::uninit_vec)]
            unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self> {
                let mut buf: Vec<Self> = match take_best_fit(&mut pool.$field, len) {
                    Some(reused) => {
                        // The buffer leaves the pool, so stop counting its capacity.
                        pool.retained_capacity_bytes = pool
                            .retained_capacity_bytes
                            .saturating_sub(reused.capacity().saturating_mul(size_of::<Self>()));
                        reused
                    }
                    None => Vec::with_capacity(len),
                };
                // SAFETY: best fit only returns buffers whose capacity is >= len,
                // and `Vec::with_capacity(len)` allocates at least `len`, so
                // `len <= capacity`. Callers of this raw acquire must overwrite
                // every element before reading any of them.
                unsafe { buf.set_len(len) };
                buf
            }

            fn pool_acquire_zeroed(pool: &mut BufferPool, len: usize) -> Vec<Self> {
                match take_best_fit(&mut pool.$field, len) {
                    Some(mut buf) => {
                        pool.retained_capacity_bytes = pool
                            .retained_capacity_bytes
                            .saturating_sub(buf.capacity().saturating_mul(size_of::<Self>()));
                        // Discard stale contents first so `resize` writes each
                        // zero exactly once (no resize-then-fill double write).
                        // Capacity >= len, so this never reallocates.
                        buf.clear();
                        buf.resize(len, Self::pool_zero());
                        buf
                    }
                    None => vec![Self::pool_zero(); len],
                }
            }

            fn pool_release(pool: &mut BufferPool, buf: Vec<Self>) {
                let cap = buf.capacity();
                if cap == 0 {
                    return;
                }
                let bytes = cap.saturating_mul(size_of::<Self>());
                if bytes > pool.max_retained_capacity_bytes {
                    // This buffer can never fit under the cap. Pushing it would
                    // make smallest-first eviction throw away every other
                    // retained buffer before evicting this one too, so drop
                    // it here and keep the rest of the pool warm.
                    return;
                }
                pool.retained_capacity_bytes = pool.retained_capacity_bytes.saturating_add(bytes);
                pool.$field.entry(cap).or_default().push(buf);
                pool.enforce_retention_limit();
            }
        }
    };
}
279
280impl_pool_scalar!(f64, f64_pool, 0.0);
281impl_pool_scalar!(f32, f32_pool, 0.0);
282impl_pool_scalar!(i32, i32_pool, 0);
283impl_pool_scalar!(i64, i64_pool, 0);
284impl_pool_scalar!(bool, bool_pool, false);
285impl_pool_scalar!(Complex64, c64_pool, Complex64::new(0.0, 0.0));
286impl_pool_scalar!(Complex32, c32_pool, Complex32::new(0.0, 0.0));
287
288impl BufferPool {
289    /// Create an empty typed buffer pool.
290    ///
291    /// # Examples
292    ///
293    /// ```rust
294    /// use tenferro_cpu::linalg_interop::BufferPool;
295    ///
296    /// let pool = BufferPool::new();
297    /// assert!(pool.is_empty());
298    /// ```
299    pub fn new() -> Self {
300        Self::with_max_retained_capacity_bytes(default_max_retained_capacity_bytes())
301    }
302
303    /// Create an empty typed buffer pool with a specific retention cap.
304    ///
305    /// A cap of zero disables retention. Use [`BufferPool::unbounded`] only for
306    /// diagnostics or workloads that are externally memory-limited.
307    ///
308    /// # Examples
309    ///
310    /// ```rust
311    /// use tenferro_cpu::linalg_interop::BufferPool;
312    ///
313    /// let pool = BufferPool::with_max_retained_capacity_bytes(1024);
314    /// assert_eq!(pool.max_retained_capacity_bytes(), 1024);
315    /// ```
316    pub fn with_max_retained_capacity_bytes(max_retained_capacity_bytes: usize) -> Self {
317        Self {
318            f64_pool: BTreeMap::new(),
319            f32_pool: BTreeMap::new(),
320            i32_pool: BTreeMap::new(),
321            i64_pool: BTreeMap::new(),
322            bool_pool: BTreeMap::new(),
323            c64_pool: BTreeMap::new(),
324            c32_pool: BTreeMap::new(),
325            retained_capacity_bytes: 0,
326            max_retained_capacity_bytes,
327        }
328    }
329
330    /// Create an empty typed buffer pool without a retention cap.
331    ///
332    /// This preserves the historical behavior and is mainly useful for
333    /// diagnostics or controlled benchmarks.
334    ///
335    /// # Examples
336    ///
337    /// ```rust
338    /// use tenferro_cpu::linalg_interop::BufferPool;
339    ///
340    /// let pool = BufferPool::unbounded();
341    /// assert_eq!(pool.max_retained_capacity_bytes(), usize::MAX);
342    /// ```
343    pub fn unbounded() -> Self {
344        Self::with_max_retained_capacity_bytes(usize::MAX)
345    }
346
347    /// Maximum retained typed host-buffer capacity in bytes.
348    ///
349    /// # Examples
350    ///
351    /// ```rust
352    /// use tenferro_cpu::linalg_interop::BufferPool;
353    ///
354    /// let pool = BufferPool::with_max_retained_capacity_bytes(4096);
355    /// assert_eq!(pool.max_retained_capacity_bytes(), 4096);
356    /// ```
357    pub fn max_retained_capacity_bytes(&self) -> usize {
358        self.max_retained_capacity_bytes
359    }
360
361    /// Update the maximum retained typed host-buffer capacity in bytes.
362    ///
363    /// Shrinking below the currently retained capacity immediately evicts
364    /// retained buffers until the new cap is satisfied. A cap of zero disables
365    /// retention.
366    ///
367    /// # Examples
368    ///
369    /// ```
370    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
371    ///
372    /// let mut pool = BufferPool::with_max_retained_capacity_bytes(1024);
373    /// <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(128));
374    /// pool.set_max_retained_capacity_bytes(0);
375    /// assert_eq!(pool.max_retained_capacity_bytes(), 0);
376    /// assert!(pool.is_empty());
377    /// ```
378    pub fn set_max_retained_capacity_bytes(&mut self, max_retained_capacity_bytes: usize) {
379        self.max_retained_capacity_bytes = max_retained_capacity_bytes;
380        self.enforce_retention_limit();
381    }
382
383    /// Number of retained buffers across all typed pools.
384    ///
385    /// # Examples
386    ///
387    /// ```rust
388    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
389    ///
390    /// let mut pool = BufferPool::new();
391    /// <f64 as PoolScalar>::pool_release(&mut pool, vec![0.0; 2]);
392    /// assert_eq!(pool.len(), 1);
393    /// ```
394    pub fn len(&self) -> usize {
395        self.stats().buffers
396    }
397
398    /// Total retained typed host-buffer capacity in bytes.
399    ///
400    /// This counts capacity that is still live in the pool. The operating
401    /// system RSS may remain high after clearing the pool because the process
402    /// allocator can keep freed pages for future allocations.
403    ///
404    /// # Examples
405    ///
406    /// ```rust
407    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
408    ///
409    /// let mut pool = BufferPool::new();
410    /// <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(2));
411    /// assert_eq!(pool.retained_capacity_bytes(), 16);
412    /// ```
413    pub fn retained_capacity_bytes(&self) -> usize {
414        self.stats().capacity_bytes
415    }
416
417    /// Snapshot retained-buffer count and capacity.
418    ///
419    /// # Examples
420    ///
421    /// ```rust
422    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
423    ///
424    /// let mut pool = BufferPool::new();
425    /// <f32 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(4));
426    /// let stats = pool.stats();
427    /// assert_eq!(stats.buffers, 1);
428    /// assert_eq!(stats.capacity_bytes, 16);
429    /// ```
430    pub fn stats(&self) -> BufferPoolStats {
431        BufferPoolStats {
432            buffers: pool_len(&self.f64_pool)
433                + pool_len(&self.f32_pool)
434                + pool_len(&self.i32_pool)
435                + pool_len(&self.i64_pool)
436                + pool_len(&self.bool_pool)
437                + pool_len(&self.c64_pool)
438                + pool_len(&self.c32_pool),
439            capacity_bytes: self.retained_capacity_bytes,
440        }
441    }
442
443    /// Return cache-style stats for the buffers retained by this pool.
444    ///
445    /// `entries` is the number of retained buffers, and `retained_bytes` is the
446    /// total retained vector capacity in bytes.
447    ///
448    /// # Examples
449    ///
450    /// ```
451    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
452    ///
453    /// let mut pool = BufferPool::new();
454    /// <f32 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(4));
455    /// let stats = pool.cache_stats();
456    /// assert_eq!(stats.entries, 1);
457    /// assert_eq!(stats.retained_bytes, 16);
458    /// ```
459    pub fn cache_stats(&self) -> CacheStats {
460        let stats = self.stats();
461        CacheStats {
462            entries: stats.buffers,
463            retained_bytes: stats.capacity_bytes,
464        }
465    }
466
467    /// Acquire a typed vector with length 0 and at least `cap` capacity.
468    ///
469    /// Returned buffers come from the typed pool when possible and are ready
470    /// for push-based population.
471    ///
472    /// # Examples
473    ///
474    /// ```rust
475    /// use tenferro_cpu::linalg_interop::BufferPool;
476    ///
477    /// let mut pool = BufferPool::new();
478    /// let mut buf = pool.acquire_with_capacity::<f64>(4);
479    /// buf.extend_from_slice(&[1.0, 2.0]);
480    /// assert_eq!(buf.len(), 2);
481    /// assert!(buf.capacity() >= 4);
482    /// ```
483    pub fn acquire_with_capacity<T: PoolScalar>(&mut self, cap: usize) -> Vec<T> {
484        if cap == 0 {
485            return Vec::new();
486        }
487
488        // SAFETY: this push-only capacity helper clears length before any element can be read.
489        let mut buf = unsafe { T::pool_acquire(self, cap) };
490        // SAFETY: shrinking length to zero does not read pooled `Copy` elements.
491        unsafe { buf.set_len(0) };
492        buf
493    }
494
495    /// Acquire a typed vector with length `len` initialized to zero.
496    ///
497    /// Use this only when the caller may read elements before overwriting the
498    /// entire buffer. Full-overwrite kernels should use
499    /// [`PoolScalar::pool_acquire`] to avoid the initialization cost.
500    ///
501    /// # Examples
502    ///
503    /// ```rust
504    /// use tenferro_cpu::linalg_interop::BufferPool;
505    ///
506    /// let mut pool = BufferPool::new();
507    /// let buf = pool.acquire_zeroed::<f32>(3);
508    /// assert_eq!(buf, vec![0.0, 0.0, 0.0]);
509    /// ```
510    pub fn acquire_zeroed<T: PoolScalar>(&mut self, len: usize) -> Vec<T> {
511        T::pool_acquire_zeroed(self, len)
512    }
513
514    /// Whether all typed pools are empty.
515    ///
516    /// # Examples
517    ///
518    /// ```rust
519    /// use tenferro_cpu::linalg_interop::BufferPool;
520    ///
521    /// let pool = BufferPool::new();
522    /// assert!(pool.is_empty());
523    /// ```
524    pub fn is_empty(&self) -> bool {
525        self.f64_pool.is_empty()
526            && self.f32_pool.is_empty()
527            && self.i32_pool.is_empty()
528            && self.i64_pool.is_empty()
529            && self.bool_pool.is_empty()
530            && self.c64_pool.is_empty()
531            && self.c32_pool.is_empty()
532    }
533
534    /// Drop all retained buffers from the pool.
535    ///
536    /// This releases the vectors owned by the pool. The process allocator may
537    /// still keep freed pages mapped for reuse, so operating-system RSS is not
538    /// guaranteed to fall immediately.
539    ///
540    /// # Examples
541    ///
542    /// ```rust
543    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
544    ///
545    /// let mut pool = BufferPool::new();
546    /// <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(8));
547    /// pool.clear();
548    /// assert!(pool.is_empty());
549    /// ```
550    pub fn clear(&mut self) {
551        self.f64_pool.clear();
552        self.f32_pool.clear();
553        self.i32_pool.clear();
554        self.i64_pool.clear();
555        self.bool_pool.clear();
556        self.c64_pool.clear();
557        self.c32_pool.clear();
558        self.retained_capacity_bytes = 0;
559    }
560
561    fn enforce_retention_limit(&mut self) {
562        while self.retained_capacity_bytes > self.max_retained_capacity_bytes {
563            let Some(evicted_bytes) = self.evict_smallest_retained_buffer() else {
564                self.retained_capacity_bytes = 0;
565                return;
566            };
567            self.retained_capacity_bytes =
568                self.retained_capacity_bytes.saturating_sub(evicted_bytes);
569        }
570    }
571
572    fn evict_smallest_retained_buffer(&mut self) -> Option<usize> {
573        let candidates = [
574            smallest_pool_candidate(&self.f64_pool, TypedPoolKind::F64),
575            smallest_pool_candidate(&self.f32_pool, TypedPoolKind::F32),
576            smallest_pool_candidate(&self.i32_pool, TypedPoolKind::I32),
577            smallest_pool_candidate(&self.i64_pool, TypedPoolKind::I64),
578            smallest_pool_candidate(&self.bool_pool, TypedPoolKind::Bool),
579            smallest_pool_candidate(&self.c64_pool, TypedPoolKind::C64),
580            smallest_pool_candidate(&self.c32_pool, TypedPoolKind::C32),
581        ];
582        let (_, kind) = candidates
583            .into_iter()
584            .flatten()
585            .min_by_key(|(bytes, _)| *bytes)?;
586        match kind {
587            TypedPoolKind::F64 => evict_one_from_pool(&mut self.f64_pool),
588            TypedPoolKind::F32 => evict_one_from_pool(&mut self.f32_pool),
589            TypedPoolKind::I32 => evict_one_from_pool(&mut self.i32_pool),
590            TypedPoolKind::I64 => evict_one_from_pool(&mut self.i64_pool),
591            TypedPoolKind::Bool => evict_one_from_pool(&mut self.bool_pool),
592            TypedPoolKind::C64 => evict_one_from_pool(&mut self.c64_pool),
593            TypedPoolKind::C32 => evict_one_from_pool(&mut self.c32_pool),
594        }
595    }
596}
597
598fn default_max_retained_capacity_bytes() -> usize {
599    env::var(BUFFER_POOL_MAX_RETAINED_BYTES_ENV)
600        .ok()
601        .and_then(|value| value.parse().ok())
602        .unwrap_or(DEFAULT_MAX_RETAINED_CAPACITY_BYTES)
603}
604
605impl Default for BufferPool {
606    fn default() -> Self {
607        Self::new()
608    }
609}
610
611#[cfg(test)]
612mod tests;