Skip to main content

tenferro_tensor/
buffer_pool.rs

//! Typed host buffer pooling for reusable tensor allocations.
//!
//! # Examples
//!
//! ```ignore
//! use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
//!
//! let mut pool = BufferPool::new();
//! let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 4) };
//! buf.fill(1.0);
//! <f64 as PoolScalar>::pool_release(&mut pool, buf);
//! assert_eq!(pool.len(), 1);
//! ```

15use std::collections::BTreeMap;
16
17use num_complex::{Complex32, Complex64};
18
19/// Typed buffer pool keyed by element capacity and separated by scalar type.
20///
21/// Each supported dtype has an independent best-fit pool. Acquired buffers are
22/// returned without zero-initialization so GEMM callers can avoid redundant
23/// writes when they fully overwrite the output.
24///
25/// # Examples
26///
27/// ```ignore
28/// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
29///
30/// let mut pool = BufferPool::new();
31/// let buf = unsafe { <f32 as PoolScalar>::pool_acquire(&mut pool, 8) };
32/// <f32 as PoolScalar>::pool_release(&mut pool, buf);
33/// assert_eq!(pool.len(), 1);
34/// ```
35pub struct BufferPool {
36    f64_pool: BTreeMap<usize, Vec<Vec<f64>>>,
37    f32_pool: BTreeMap<usize, Vec<Vec<f32>>>,
38    c64_pool: BTreeMap<usize, Vec<Vec<Complex64>>>,
39    c32_pool: BTreeMap<usize, Vec<Vec<Complex32>>>,
40}
41
42/// Scalar types supported by [`BufferPool`].
43///
44/// The trait is sealed to the scalar dtypes that tenferro currently pools for
45/// CPU execution.
46///
47/// # Examples
48///
49/// ```ignore
50/// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
51///
52/// let mut pool = BufferPool::new();
53/// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
54/// buf.copy_from_slice(&[3.0, 4.0]);
55/// <f64 as PoolScalar>::pool_release(&mut pool, buf);
56/// ```
57pub trait PoolScalar: Copy + Sized + Send + private::Sealed {
58    /// Acquire a buffer with length `len`.
59    ///
60    /// The vector length is set without initializing its contents. Callers must
61    /// overwrite every element before any read.
62    ///
63    /// # Safety
64    ///
65    /// The returned vector may contain uninitialized elements. Reading any
66    /// element before writing it is undefined behavior.
67    ///
68    /// # Examples
69    ///
70    /// ```ignore
71    /// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
72    ///
73    /// let mut pool = BufferPool::new();
74    /// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
75    /// buf.copy_from_slice(&[1.0, 2.0]);
76    /// assert_eq!(buf, vec![1.0, 2.0]);
77    /// ```
78    unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self>;
79
80    /// Return a buffer to the typed pool for later reuse.
81    ///
82    /// Zero-capacity buffers are ignored.
83    ///
84    /// # Examples
85    ///
86    /// ```ignore
87    /// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
88    ///
89    /// let mut pool = BufferPool::new();
90    /// let buf = vec![1.0_f32; 4];
91    /// <f32 as PoolScalar>::pool_release(&mut pool, buf);
92    /// assert_eq!(pool.len(), 1);
93    /// ```
94    fn pool_release(pool: &mut BufferPool, buf: Vec<Self>);
95}
96
97mod private {
98    pub trait Sealed {}
99
100    impl Sealed for f64 {}
101    impl Sealed for f32 {}
102    impl Sealed for num_complex::Complex64 {}
103    impl Sealed for num_complex::Complex32 {}
104}
105
/// Remove and return one buffer from the bucket with the smallest capacity
/// key that is `>= len` (best fit), or `None` if no bucket is large enough.
///
/// A bucket is removed from the map as soon as its last buffer is taken. This
/// keeps the map free of empty buckets, so "map is empty" means "no idle
/// buffers".
fn take_best_fit<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>, len: usize) -> Option<Vec<T>> {
    // A single ordered search both finds the best-fit bucket and yields
    // mutable access to it, so popping needs no second lookup.
    let (&cap, bucket) = pool.range_mut(len..).next()?;
    let buf = bucket.pop();
    if bucket.is_empty() {
        pool.remove(&cap);
    }
    buf
}

118macro_rules! impl_pool_scalar {
119    ($ty:ty, $field:ident) => {
120        impl PoolScalar for $ty {
121            #[allow(clippy::uninit_vec)]
122            unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self> {
123                match take_best_fit(&mut pool.$field, len) {
124                    Some(mut buf) => {
125                        // SAFETY: caller upholds that elements will be written
126                        // before any read. len <= capacity by construction.
127                        unsafe { buf.set_len(len) };
128                        buf
129                    }
130                    None => {
131                        let mut buf = Vec::with_capacity(len);
132                        // SAFETY: caller upholds that elements will be written
133                        // before any read. len == capacity here.
134                        unsafe { buf.set_len(len) };
135                        buf
136                    }
137                }
138            }
139
140            fn pool_release(pool: &mut BufferPool, buf: Vec<Self>) {
141                let cap = buf.capacity();
142                if cap > 0 {
143                    pool.$field.entry(cap).or_default().push(buf);
144                }
145            }
146        }
147    };
148}
149
150impl_pool_scalar!(f64, f64_pool);
151impl_pool_scalar!(f32, f32_pool);
152impl_pool_scalar!(Complex64, c64_pool);
153impl_pool_scalar!(Complex32, c32_pool);
154
155impl BufferPool {
156    /// Create an empty typed buffer pool.
157    ///
158    /// # Examples
159    ///
160    /// ```ignore
161    /// use tenferro_tensor::buffer_pool::BufferPool;
162    ///
163    /// let pool = BufferPool::new();
164    /// assert!(pool.is_empty());
165    /// ```
166    pub fn new() -> Self {
167        Self {
168            f64_pool: BTreeMap::new(),
169            f32_pool: BTreeMap::new(),
170            c64_pool: BTreeMap::new(),
171            c32_pool: BTreeMap::new(),
172        }
173    }
174
175    /// Number of retained buffers across all typed pools.
176    ///
177    /// # Examples
178    ///
179    /// ```ignore
180    /// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
181    ///
182    /// let mut pool = BufferPool::new();
183    /// <f64 as PoolScalar>::pool_release(&mut pool, vec![0.0; 2]);
184    /// assert_eq!(pool.len(), 1);
185    /// ```
186    pub fn len(&self) -> usize {
187        self.f64_pool.values().map(Vec::len).sum::<usize>()
188            + self.f32_pool.values().map(Vec::len).sum::<usize>()
189            + self.c64_pool.values().map(Vec::len).sum::<usize>()
190            + self.c32_pool.values().map(Vec::len).sum::<usize>()
191    }
192
193    /// Acquire a typed vector with length 0 and at least `cap` capacity.
194    ///
195    /// Returned buffers come from the typed pool when possible and are ready
196    /// for push-based population.
197    ///
198    /// # Examples
199    ///
200    /// ```ignore
201    /// use tenferro_tensor::buffer_pool::BufferPool;
202    ///
203    /// let mut pool = BufferPool::new();
204    /// let mut buf = pool.acquire_with_capacity::<f64>(4);
205    /// buf.extend_from_slice(&[1.0, 2.0]);
206    /// assert_eq!(buf.len(), 2);
207    /// assert!(buf.capacity() >= 4);
208    /// ```
209    pub fn acquire_with_capacity<T: PoolScalar>(&mut self, cap: usize) -> Vec<T> {
210        if cap == 0 {
211            return Vec::new();
212        }
213
214        let mut buf = unsafe { T::pool_acquire(self, cap) };
215        // SAFETY: shrinking the length to zero does not read the buffer. The
216        // pool only stores `PoolScalar` values, which are `Copy`.
217        unsafe { buf.set_len(0) };
218        buf
219    }
220
221    /// Whether all typed pools are empty.
222    ///
223    /// # Examples
224    ///
225    /// ```ignore
226    /// use tenferro_tensor::buffer_pool::BufferPool;
227    ///
228    /// let pool = BufferPool::new();
229    /// assert!(pool.is_empty());
230    /// ```
231    pub fn is_empty(&self) -> bool {
232        self.f64_pool.is_empty()
233            && self.f32_pool.is_empty()
234            && self.c64_pool.is_empty()
235            && self.c32_pool.is_empty()
236    }
237}
238
239impl Default for BufferPool {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
#[cfg(test)]
mod tests {
    use num_complex::{Complex32, Complex64};

    use super::{BufferPool, PoolScalar};

    #[test]
    fn acquire_release_reuse() {
        let mut pool = BufferPool::new();

        let buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 64) };
        let ptr = buf.as_ptr();
        let cap = buf.capacity();
        <f64 as PoolScalar>::pool_release(&mut pool, buf);

        let reused = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 64) };
        assert_eq!(reused.as_ptr(), ptr);
        assert_eq!(reused.capacity(), cap);
        assert!(pool.is_empty());
    }

    #[test]
    fn best_fit() {
        let mut pool = BufferPool::new();
        <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(100));
        <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(200));
        <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(300));

        let reused = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 150) };
        assert_eq!(reused.capacity(), 200);
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn type_separation() {
        let mut pool = BufferPool::new();
        <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(16));
        assert_eq!(pool.len(), 1);

        let f32_buf = unsafe { <f32 as PoolScalar>::pool_acquire(&mut pool, 16) };
        assert_eq!(f32_buf.capacity(), 16);
        assert_eq!(pool.len(), 1);

        let f64_buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 16) };
        assert_eq!(f64_buf.capacity(), 16);
        assert!(pool.is_empty());
    }

    #[test]
    fn fresh_alloc_fallback() {
        let mut pool = BufferPool::new();
        let buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 32) };
        assert_eq!(buf.len(), 32);
        assert!(buf.capacity() >= 32);
        assert!(pool.is_empty());
    }

    #[test]
    fn zero_len_not_pooled() {
        let mut pool = BufferPool::new();
        <f64 as PoolScalar>::pool_release(&mut pool, Vec::new());
        assert!(pool.is_empty());
    }

    #[test]
    fn acquire_with_capacity_reuses_buffer_as_empty_vec() {
        let mut pool = BufferPool::new();

        let buf = vec![1.0_f64; 8];
        let ptr = buf.as_ptr();
        let cap = buf.capacity();
        <f64 as PoolScalar>::pool_release(&mut pool, buf);

        let reused = pool.acquire_with_capacity::<f64>(8);
        assert_eq!(reused.as_ptr(), ptr);
        assert_eq!(reused.len(), 0);
        assert_eq!(reused.capacity(), cap);
        assert!(pool.is_empty());
    }

    // A zero-capacity request must not consume a pooled buffer of any size.
    #[test]
    fn acquire_with_capacity_zero_leaves_pool_untouched() {
        let mut pool = BufferPool::new();
        <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(4));

        let buf = pool.acquire_with_capacity::<f64>(0);
        assert_eq!(buf.capacity(), 0);
        assert_eq!(pool.len(), 1);
    }

    // Buffers smaller than the request stay pooled; a fresh one is allocated.
    #[test]
    fn undersized_buffers_are_not_reused() {
        let mut pool = BufferPool::new();
        <f64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(4));

        let buf = pool.acquire_with_capacity::<f64>(16);
        assert!(buf.capacity() >= 16);
        assert_eq!(pool.len(), 1);
    }

    // Several buffers of the same capacity are counted individually, and the
    // bucket only disappears once the last one is taken.
    #[test]
    fn equal_capacity_buffers_share_a_bucket() {
        let mut pool = BufferPool::new();
        <f32 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(10));
        <f32 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(10));
        assert_eq!(pool.len(), 2);

        let first = pool.acquire_with_capacity::<f32>(10);
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let second = pool.acquire_with_capacity::<f32>(10);
        assert!(pool.is_empty());
        assert_ne!(first.as_ptr(), second.as_ptr());
    }

    // Complex dtypes have their own pools, just like the real ones.
    #[test]
    fn complex_pools_are_separate() {
        let mut pool = BufferPool::new();
        <Complex64 as PoolScalar>::pool_release(&mut pool, Vec::with_capacity(8));
        assert_eq!(pool.len(), 1);

        let c32 = pool.acquire_with_capacity::<Complex32>(8);
        assert!(c32.capacity() >= 8);
        assert_eq!(pool.len(), 1);

        let c64 = pool.acquire_with_capacity::<Complex64>(8);
        assert!(c64.capacity() >= 8);
        assert!(pool.is_empty());
    }
}