tenferro_cpu/buffer_pool.rs
1//! Typed host buffer pooling for reusable tensor allocations.
2//!
3//! # Examples
4//!
5//! ```rust
6//! use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
7//!
8//! let mut pool = BufferPool::new();
9//! let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 4) };
10//! buf.fill(1.0);
11//! <f64 as PoolScalar>::pool_release(&mut pool, buf);
12//! assert_eq!(pool.len(), 1);
13//! ```
14
15use std::collections::BTreeMap;
16use std::env;
17use std::fmt;
18use std::mem::size_of;
19
20use num_complex::{Complex32, Complex64};
21
22use crate::CacheStats;
23
/// Name of the environment variable that overrides the CPU buffer-pool
/// retention cap, in bytes.
///
/// The variable is consulted every time [`BufferPool::new`] builds a pool. Its
/// value is parsed as an unsigned integer; when the variable is unset or its
/// value is invalid, the cap falls back to
/// [`DEFAULT_MAX_RETAINED_CAPACITY_BYTES`].
pub const BUFFER_POOL_MAX_RETAINED_BYTES_ENV: &str = "TENFERRO_BUFFER_POOL_MAX_RETAINED_BYTES";
29
/// Default retained CPU buffer capacity per backend: 100 MiB
/// (`100 * 1024 * 1024` bytes).
///
/// The cap keeps long-running workloads from accumulating obsolete buffer
/// sizes as tensor shapes grow while still preserving reuse for hot working
/// sets. [`BufferPool::new`] uses this value whenever
/// [`BUFFER_POOL_MAX_RETAINED_BYTES_ENV`] does not supply a usable override.
pub const DEFAULT_MAX_RETAINED_CAPACITY_BYTES: usize = 100 * 1024 * 1024;
36
/// Snapshot of typed host buffers retained by a [`BufferPool`].
///
/// `buffers` counts retained `Vec` allocations, while `capacity_bytes` counts
/// their total element capacity in bytes. Allocators may keep freed memory in
/// process-local arenas after a pool is cleared, so this reports memory that is
/// still live in the pool rather than operating-system RSS.
///
/// The snapshot is a plain `Copy` value; it does not track later changes to
/// the pool it was taken from.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct BufferPoolStats {
    /// Number of retained vector allocations, summed over every dtype pool.
    pub buffers: usize,
    /// Total retained vector capacity in bytes (element capacity multiplied by
    /// the element size of each buffer's dtype).
    pub capacity_bytes: usize,
}
50
/// Typed buffer pool keyed by element capacity and separated by scalar type.
///
/// Each supported dtype has an independent best-fit pool. Acquired buffers are
/// returned without zero-initialization so kernels can avoid redundant writes
/// when they fully overwrite the output. Use [`PoolScalar::pool_acquire_zeroed`]
/// when the caller may read the buffer before writing every element.
///
/// The pool tracks the total capacity it retains, in bytes, across all dtypes.
/// Whenever that total exceeds the configured retention cap, the smallest
/// retained buffers are evicted first until the cap is satisfied again.
///
/// # Examples
///
/// ```rust
/// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
///
/// let mut pool = BufferPool::new();
/// let buf = unsafe { <f32 as PoolScalar>::pool_acquire(&mut pool, 8) };
/// <f32 as PoolScalar>::pool_release(&mut pool, buf);
/// assert_eq!(pool.len(), 1);
/// ```
pub struct BufferPool {
    // Every `*_pool` map goes from a buffer capacity (in elements) to the
    // stack of retained buffers that have exactly that capacity. The ordered
    // map is what makes "smallest capacity >= requested length" lookups and
    // "smallest capacity overall" eviction cheap.
    /// Retained `f64` buffers.
    f64_pool: BTreeMap<usize, Vec<Vec<f64>>>,
    /// Retained `f32` buffers.
    f32_pool: BTreeMap<usize, Vec<Vec<f32>>>,
    /// Retained `i32` buffers.
    i32_pool: BTreeMap<usize, Vec<Vec<i32>>>,
    /// Retained `i64` buffers.
    i64_pool: BTreeMap<usize, Vec<Vec<i64>>>,
    /// Retained `bool` buffers.
    bool_pool: BTreeMap<usize, Vec<Vec<bool>>>,
    /// Retained `Complex64` buffers.
    c64_pool: BTreeMap<usize, Vec<Vec<Complex64>>>,
    /// Retained `Complex32` buffers.
    c32_pool: BTreeMap<usize, Vec<Vec<Complex32>>>,
    /// Running total of retained capacity in bytes, across all dtype pools.
    retained_capacity_bytes: usize,
    /// Retention cap in bytes; exceeding it triggers eviction.
    max_retained_capacity_bytes: usize,
}
79
80impl fmt::Debug for BufferPool {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 f.debug_struct("BufferPool")
83 .field("stats", &self.stats())
84 .field(
85 "max_retained_capacity_bytes",
86 &self.max_retained_capacity_bytes,
87 )
88 .finish_non_exhaustive()
89 }
90}
91
/// Scalar types supported by [`BufferPool`].
///
/// The trait is sealed to the scalar dtypes that tenferro currently pools for
/// CPU execution: `f64`, `f32`, `i32`, `i64`, `bool`, `Complex64`, and
/// `Complex32`.
///
/// # Examples
///
/// ```rust
/// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
///
/// let mut pool = BufferPool::new();
/// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
/// buf.copy_from_slice(&[3.0, 4.0]);
/// <f64 as PoolScalar>::pool_release(&mut pool, buf);
/// ```
pub trait PoolScalar: Copy + Sized + Send + Sync + private::Sealed {
    /// Zero value of this dtype, used by [`PoolScalar::pool_acquire_zeroed`]
    /// to initialize acquired buffers.
    fn pool_zero() -> Self;

    /// Acquire a buffer with length `len`.
    ///
    /// The smallest retained buffer whose capacity is at least `len` is reused
    /// when one exists, so the returned vector's capacity may exceed `len`;
    /// otherwise a fresh buffer is allocated. The vector length is set without
    /// initializing its contents. Callers must overwrite every element before
    /// any read.
    ///
    /// # Safety
    ///
    /// The returned vector may contain uninitialized or stale elements. Reading
    /// any element before writing it is undefined behavior.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
    ///
    /// let mut pool = BufferPool::new();
    /// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
    /// buf.copy_from_slice(&[1.0, 2.0]);
    /// assert_eq!(buf, vec![1.0, 2.0]);
    /// ```
    unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self>;

    /// Acquire a buffer with length `len` and every element set to zero.
    ///
    /// This is the safe path for callers that may read the buffer before every
    /// element is overwritten. Prefer [`PoolScalar::pool_acquire`] for kernels
    /// that perform a full overwrite.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
    ///
    /// let mut pool = BufferPool::new();
    /// let buf = <f64 as PoolScalar>::pool_acquire_zeroed(&mut pool, 2);
    /// assert_eq!(buf, vec![0.0, 0.0]);
    /// ```
    fn pool_acquire_zeroed(pool: &mut BufferPool, len: usize) -> Vec<Self>;

    /// Return a buffer to the typed pool for later reuse.
    ///
    /// Zero-capacity buffers are ignored. The pool may drop buffers (possibly
    /// including `buf` itself) instead of retaining them so that the total
    /// retained capacity stays within the pool's retention cap.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
    ///
    /// let mut pool = BufferPool::new();
    /// let buf = vec![1.0_f32; 4];
    /// <f32 as PoolScalar>::pool_release(&mut pool, buf);
    /// assert_eq!(pool.len(), 1);
    /// ```
    fn pool_release(pool: &mut BufferPool, buf: Vec<Self>);
}
166
// Sealing module: `Sealed` is a supertrait of `PoolScalar`, and because this
// module is not public, downstream crates cannot name the trait and therefore
// cannot add `PoolScalar` implementations of their own.
mod private {
    /// Marker supertrait restricting `PoolScalar` to the dtypes listed below.
    pub trait Sealed {}

    // One impl per dtype that has a dedicated pool field in `BufferPool`.
    impl Sealed for f64 {}
    impl Sealed for f32 {}
    impl Sealed for i32 {}
    impl Sealed for i64 {}
    impl Sealed for bool {}
    impl Sealed for num_complex::Complex64 {}
    impl Sealed for num_complex::Complex32 {}
}
178
/// Removes and returns the best-fitting pooled buffer for `len` elements.
///
/// "Best fit" means the smallest capacity key that is at least `len`. Returns
/// `None` when no bucket is large enough. A bucket left empty by the removal
/// is deleted so that later lookups never land on an empty key.
fn take_best_fit<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>, len: usize) -> Option<Vec<T>> {
    let (&capacity, bucket) = pool.range_mut(len..).next()?;
    let taken = bucket.pop();
    if bucket.is_empty() {
        pool.remove(&capacity);
    }
    taken
}
190
/// Counts the buffers retained in one typed pool, across all capacity buckets.
fn pool_len<T>(pool: &BTreeMap<usize, Vec<Vec<T>>>) -> usize {
    pool.values().fold(0, |count, bucket| count + bucket.len())
}
194
/// Drops one buffer from the smallest-capacity bucket of a typed pool.
///
/// Returns the number of bytes the dropped buffer accounted for (its capacity
/// key times the element size), or `None` when the pool has no bucket or the
/// smallest bucket holds no buffer. A bucket emptied by the eviction is
/// removed from the map.
fn evict_one_from_pool<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>) -> Option<usize> {
    let (&capacity, bucket) = pool.iter_mut().next()?;
    drop(bucket.pop()?);
    if bucket.is_empty() {
        pool.remove(&capacity);
    }
    Some(capacity.saturating_mul(size_of::<T>()))
}
204
/// Tag identifying one of the per-dtype pools inside `BufferPool`.
///
/// Eviction first picks the pool holding the smallest retained buffer and then
/// uses this tag to dispatch to the matching typed map.
#[derive(Clone, Copy)]
enum TypedPoolKind {
    /// The `f64` pool.
    F64,
    /// The `f32` pool.
    F32,
    /// The `i32` pool.
    I32,
    /// The `i64` pool.
    I64,
    /// The `bool` pool.
    Bool,
    /// The `Complex64` pool.
    C64,
    /// The `Complex32` pool.
    C32,
}
215
216fn smallest_pool_candidate<T>(
217 pool: &BTreeMap<usize, Vec<Vec<T>>>,
218 kind: TypedPoolKind,
219) -> Option<(usize, TypedPoolKind)> {
220 pool.keys()
221 .next()
222 .map(|&capacity| (capacity.saturating_mul(size_of::<T>()), kind))
223}
224
/// Implements `PoolScalar` for one dtype.
///
/// * `$ty` - the scalar type.
/// * `$field` - the `BufferPool` field holding that type's capacity-keyed map.
/// * `$zero` - expression producing the type's zero value.
macro_rules! impl_pool_scalar {
    ($ty:ty, $field:ident, $zero:expr) => {
        impl PoolScalar for $ty {
            fn pool_zero() -> Self {
                $zero
            }

            // The raw acquire path intentionally hands out uninitialized storage;
            // the `unsafe fn` contract makes the caller responsible for a full
            // overwrite before any read.
            #[allow(clippy::uninit_vec)]
            unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self> {
                match take_best_fit(&mut pool.$field, len) {
                    Some(mut buf) => {
                        // The buffer leaves the pool, so stop counting it as retained.
                        pool.retained_capacity_bytes = pool
                            .retained_capacity_bytes
                            .saturating_sub(buf.capacity().saturating_mul(size_of::<Self>()));
                        // SAFETY: raw acquire requires caller full-overwrite; len <= capacity here.
                        unsafe { buf.set_len(len) };
                        buf
                    }
                    None => {
                        let mut buf = Vec::with_capacity(len);
                        // SAFETY: raw acquire requires caller full-overwrite; len == capacity here.
                        unsafe { buf.set_len(len) };
                        buf
                    }
                }
            }

            fn pool_acquire_zeroed(pool: &mut BufferPool, len: usize) -> Vec<Self> {
                match take_best_fit(&mut pool.$field, len) {
                    Some(mut buf) => {
                        // The buffer leaves the pool, so stop counting it as retained.
                        pool.retained_capacity_bytes = pool
                            .retained_capacity_bytes
                            .saturating_sub(buf.capacity().saturating_mul(size_of::<Self>()));
                        // Discard the stale length first so `resize` writes each of
                        // the `len` zeros exactly once. `len <= capacity`, so this
                        // never reallocates.
                        buf.clear();
                        buf.resize(len, Self::pool_zero());
                        buf
                    }
                    None => vec![Self::pool_zero(); len],
                }
            }

            fn pool_release(pool: &mut BufferPool, buf: Vec<Self>) {
                let cap = buf.capacity();
                if cap == 0 {
                    // Nothing to reuse.
                    return;
                }

                let bytes = cap.saturating_mul(size_of::<Self>());
                if bytes > pool.max_retained_capacity_bytes {
                    // This buffer alone can never fit under the cap. Inserting it
                    // would make smallest-first eviction throw away every smaller
                    // retained buffer before finally evicting this one, so drop it
                    // right away and keep the existing working set intact.
                    return;
                }

                pool.retained_capacity_bytes = pool.retained_capacity_bytes.saturating_add(bytes);
                pool.$field.entry(cap).or_default().push(buf);
                // The new total may exceed the cap; trim the smallest buffers if so.
                pool.enforce_retention_limit();
            }
        }
    };
}
279
// Wire each pooled dtype to its `BufferPool` field and its zero value.
impl_pool_scalar!(f64, f64_pool, 0.0);
impl_pool_scalar!(f32, f32_pool, 0.0);
impl_pool_scalar!(i32, i32_pool, 0);
impl_pool_scalar!(i64, i64_pool, 0);
impl_pool_scalar!(bool, bool_pool, false);
impl_pool_scalar!(Complex64, c64_pool, Complex64::new(0.0, 0.0));
impl_pool_scalar!(Complex32, c32_pool, Complex32::new(0.0, 0.0));
287
288impl BufferPool {
289 /// Create an empty typed buffer pool.
290 ///
291 /// # Examples
292 ///
293 /// ```rust
294 /// use tenferro_cpu::linalg_interop::BufferPool;
295 ///
296 /// let pool = BufferPool::new();
297 /// assert!(pool.is_empty());
298 /// ```
299 pub fn new() -> Self {
300 Self::with_max_retained_capacity_bytes(default_max_retained_capacity_bytes())
301 }
302
303 /// Create an empty typed buffer pool with a specific retention cap.
304 ///
305 /// A cap of zero disables retention. Use [`BufferPool::unbounded`] only for
306 /// diagnostics or workloads that are externally memory-limited.
307 ///
308 /// # Examples
309 ///
310 /// ```rust
311 /// use tenferro_cpu::linalg_interop::BufferPool;
312 ///
313 /// let pool = BufferPool::with_max_retained_capacity_bytes(1024);
314 /// assert_eq!(pool.max_retained_capacity_bytes(), 1024);
315 /// ```
316 pub fn with_max_retained_capacity_bytes(max_retained_capacity_bytes: usize) -> Self {
317 Self {
318 f64_pool: BTreeMap::new(),
319 f32_pool: BTreeMap::new(),
320 i32_pool: BTreeMap::new(),
321 i64_pool: BTreeMap::new(),
322 bool_pool: BTreeMap::new(),
323 c64_pool: BTreeMap::new(),
324 c32_pool: BTreeMap::new(),
325 retained_capacity_bytes: 0,
326 max_retained_capacity_bytes,
327 }
328 }
329
330 /// Create an empty typed buffer pool without a retention cap.
331 ///
332 /// This preserves the historical behavior and is mainly useful for
333 /// diagnostics or controlled benchmarks.
334 ///
335 /// # Examples
336 ///
337 /// ```rust
338 /// use tenferro_cpu::linalg_interop::BufferPool;
339 ///
340 /// let pool = BufferPool::unbounded();
341 /// assert_eq!(pool.max_retained_capacity_bytes(), usize::MAX);
342 /// ```
343 pub fn unbounded() -> Self {
344 Self::with_max_retained_capacity_bytes(usize::MAX)
345 }
346
347 /// Maximum retained typed host-buffer capacity in bytes.
348 ///
349 /// # Examples
350 ///
351 /// ```rust
352 /// use tenferro_cpu::linalg_interop::BufferPool;
353 ///
354 /// let pool = BufferPool::with_max_retained_capacity_bytes(4096);
355 /// assert_eq!(pool.max_retained_capacity_bytes(), 4096);
356 /// ```
357 pub fn max_retained_capacity_bytes(&self) -> usize {
358 self.max_retained_capacity_bytes
359 }
360
361 /// Update the maximum retained typed host-buffer capacity in bytes.
362 ///
363 /// Shrinking below the currently retained capacity immediately evicts
364 /// retained buffers until the new cap is satisfied. A cap of zero disables
365 /// retention.
366 ///
367 /// # Examples
368 ///
369 /// ```
370 /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
371 ///
372 /// let mut pool = BufferPool::with_max_retained_capacity_bytes(1024);
373 /// <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(128));
374 /// pool.set_max_retained_capacity_bytes(0);
375 /// assert_eq!(pool.max_retained_capacity_bytes(), 0);
376 /// assert!(pool.is_empty());
377 /// ```
378 pub fn set_max_retained_capacity_bytes(&mut self, max_retained_capacity_bytes: usize) {
379 self.max_retained_capacity_bytes = max_retained_capacity_bytes;
380 self.enforce_retention_limit();
381 }
382
383 /// Number of retained buffers across all typed pools.
384 ///
385 /// # Examples
386 ///
387 /// ```rust
388 /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
389 ///
390 /// let mut pool = BufferPool::new();
391 /// <f64 as PoolScalar>::pool_release(&mut pool, vec![0.0; 2]);
392 /// assert_eq!(pool.len(), 1);
393 /// ```
394 pub fn len(&self) -> usize {
395 self.stats().buffers
396 }
397
398 /// Total retained typed host-buffer capacity in bytes.
399 ///
400 /// This counts capacity that is still live in the pool. The operating
401 /// system RSS may remain high after clearing the pool because the process
402 /// allocator can keep freed pages for future allocations.
403 ///
404 /// # Examples
405 ///
406 /// ```rust
407 /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
408 ///
409 /// let mut pool = BufferPool::new();
410 /// <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(2));
411 /// assert_eq!(pool.retained_capacity_bytes(), 16);
412 /// ```
413 pub fn retained_capacity_bytes(&self) -> usize {
414 self.stats().capacity_bytes
415 }
416
417 /// Snapshot retained-buffer count and capacity.
418 ///
419 /// # Examples
420 ///
421 /// ```rust
422 /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
423 ///
424 /// let mut pool = BufferPool::new();
425 /// <f32 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(4));
426 /// let stats = pool.stats();
427 /// assert_eq!(stats.buffers, 1);
428 /// assert_eq!(stats.capacity_bytes, 16);
429 /// ```
430 pub fn stats(&self) -> BufferPoolStats {
431 BufferPoolStats {
432 buffers: pool_len(&self.f64_pool)
433 + pool_len(&self.f32_pool)
434 + pool_len(&self.i32_pool)
435 + pool_len(&self.i64_pool)
436 + pool_len(&self.bool_pool)
437 + pool_len(&self.c64_pool)
438 + pool_len(&self.c32_pool),
439 capacity_bytes: self.retained_capacity_bytes,
440 }
441 }
442
443 /// Return cache-style stats for the buffers retained by this pool.
444 ///
445 /// `entries` is the number of retained buffers, and `retained_bytes` is the
446 /// total retained vector capacity in bytes.
447 ///
448 /// # Examples
449 ///
450 /// ```
451 /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
452 ///
453 /// let mut pool = BufferPool::new();
454 /// <f32 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(4));
455 /// let stats = pool.cache_stats();
456 /// assert_eq!(stats.entries, 1);
457 /// assert_eq!(stats.retained_bytes, 16);
458 /// ```
459 pub fn cache_stats(&self) -> CacheStats {
460 let stats = self.stats();
461 CacheStats {
462 entries: stats.buffers,
463 retained_bytes: stats.capacity_bytes,
464 }
465 }
466
467 /// Acquire a typed vector with length 0 and at least `cap` capacity.
468 ///
469 /// Returned buffers come from the typed pool when possible and are ready
470 /// for push-based population.
471 ///
472 /// # Examples
473 ///
474 /// ```rust
475 /// use tenferro_cpu::linalg_interop::BufferPool;
476 ///
477 /// let mut pool = BufferPool::new();
478 /// let mut buf = pool.acquire_with_capacity::<f64>(4);
479 /// buf.extend_from_slice(&[1.0, 2.0]);
480 /// assert_eq!(buf.len(), 2);
481 /// assert!(buf.capacity() >= 4);
482 /// ```
483 pub fn acquire_with_capacity<T: PoolScalar>(&mut self, cap: usize) -> Vec<T> {
484 if cap == 0 {
485 return Vec::new();
486 }
487
488 // SAFETY: this push-only capacity helper clears length before any element can be read.
489 let mut buf = unsafe { T::pool_acquire(self, cap) };
490 // SAFETY: shrinking length to zero does not read pooled `Copy` elements.
491 unsafe { buf.set_len(0) };
492 buf
493 }
494
495 /// Acquire a typed vector with length `len` initialized to zero.
496 ///
497 /// Use this only when the caller may read elements before overwriting the
498 /// entire buffer. Full-overwrite kernels should use
499 /// [`PoolScalar::pool_acquire`] to avoid the initialization cost.
500 ///
501 /// # Examples
502 ///
503 /// ```rust
504 /// use tenferro_cpu::linalg_interop::BufferPool;
505 ///
506 /// let mut pool = BufferPool::new();
507 /// let buf = pool.acquire_zeroed::<f32>(3);
508 /// assert_eq!(buf, vec![0.0, 0.0, 0.0]);
509 /// ```
510 pub fn acquire_zeroed<T: PoolScalar>(&mut self, len: usize) -> Vec<T> {
511 T::pool_acquire_zeroed(self, len)
512 }
513
514 /// Whether all typed pools are empty.
515 ///
516 /// # Examples
517 ///
518 /// ```rust
519 /// use tenferro_cpu::linalg_interop::BufferPool;
520 ///
521 /// let pool = BufferPool::new();
522 /// assert!(pool.is_empty());
523 /// ```
524 pub fn is_empty(&self) -> bool {
525 self.f64_pool.is_empty()
526 && self.f32_pool.is_empty()
527 && self.i32_pool.is_empty()
528 && self.i64_pool.is_empty()
529 && self.bool_pool.is_empty()
530 && self.c64_pool.is_empty()
531 && self.c32_pool.is_empty()
532 }
533
534 /// Drop all retained buffers from the pool.
535 ///
536 /// This releases the vectors owned by the pool. The process allocator may
537 /// still keep freed pages mapped for reuse, so operating-system RSS is not
538 /// guaranteed to fall immediately.
539 ///
540 /// # Examples
541 ///
542 /// ```rust
543 /// use tenferro_cpu::linalg_interop::{BufferPool, PoolScalar};
544 ///
545 /// let mut pool = BufferPool::new();
546 /// <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(8));
547 /// pool.clear();
548 /// assert!(pool.is_empty());
549 /// ```
550 pub fn clear(&mut self) {
551 self.f64_pool.clear();
552 self.f32_pool.clear();
553 self.i32_pool.clear();
554 self.i64_pool.clear();
555 self.bool_pool.clear();
556 self.c64_pool.clear();
557 self.c32_pool.clear();
558 self.retained_capacity_bytes = 0;
559 }
560
561 fn enforce_retention_limit(&mut self) {
562 while self.retained_capacity_bytes > self.max_retained_capacity_bytes {
563 let Some(evicted_bytes) = self.evict_smallest_retained_buffer() else {
564 self.retained_capacity_bytes = 0;
565 return;
566 };
567 self.retained_capacity_bytes =
568 self.retained_capacity_bytes.saturating_sub(evicted_bytes);
569 }
570 }
571
572 fn evict_smallest_retained_buffer(&mut self) -> Option<usize> {
573 let candidates = [
574 smallest_pool_candidate(&self.f64_pool, TypedPoolKind::F64),
575 smallest_pool_candidate(&self.f32_pool, TypedPoolKind::F32),
576 smallest_pool_candidate(&self.i32_pool, TypedPoolKind::I32),
577 smallest_pool_candidate(&self.i64_pool, TypedPoolKind::I64),
578 smallest_pool_candidate(&self.bool_pool, TypedPoolKind::Bool),
579 smallest_pool_candidate(&self.c64_pool, TypedPoolKind::C64),
580 smallest_pool_candidate(&self.c32_pool, TypedPoolKind::C32),
581 ];
582 let (_, kind) = candidates
583 .into_iter()
584 .flatten()
585 .min_by_key(|(bytes, _)| *bytes)?;
586 match kind {
587 TypedPoolKind::F64 => evict_one_from_pool(&mut self.f64_pool),
588 TypedPoolKind::F32 => evict_one_from_pool(&mut self.f32_pool),
589 TypedPoolKind::I32 => evict_one_from_pool(&mut self.i32_pool),
590 TypedPoolKind::I64 => evict_one_from_pool(&mut self.i64_pool),
591 TypedPoolKind::Bool => evict_one_from_pool(&mut self.bool_pool),
592 TypedPoolKind::C64 => evict_one_from_pool(&mut self.c64_pool),
593 TypedPoolKind::C32 => evict_one_from_pool(&mut self.c32_pool),
594 }
595 }
596}
597
598fn default_max_retained_capacity_bytes() -> usize {
599 env::var(BUFFER_POOL_MAX_RETAINED_BYTES_ENV)
600 .ok()
601 .and_then(|value| value.parse().ok())
602 .unwrap_or(DEFAULT_MAX_RETAINED_CAPACITY_BYTES)
603}
604
605impl Default for BufferPool {
606 fn default() -> Self {
607 Self::new()
608 }
609}
610
611#[cfg(test)]
612mod tests;