tenferro_tensor/buffer_pool.rs
1//! Typed host buffer pooling for reusable tensor allocations.
2//!
3//! # Examples
4//!
5//! ```ignore
6//! use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
7//!
8//! let mut pool = BufferPool::new();
9//! let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 4) };
10//! buf.fill(1.0);
11//! <f64 as PoolScalar>::pool_release(&mut pool, buf);
12//! assert_eq!(pool.len(), 1);
13//! ```
14
15use std::collections::BTreeMap;
16
17use num_complex::{Complex32, Complex64};
18
19/// Typed buffer pool keyed by element capacity and separated by scalar type.
20///
21/// Each supported dtype has an independent best-fit pool. Acquired buffers are
22/// returned without zero-initialization so GEMM callers can avoid redundant
23/// writes when they fully overwrite the output.
24///
25/// # Examples
26///
27/// ```ignore
28/// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
29///
30/// let mut pool = BufferPool::new();
31/// let buf = unsafe { <f32 as PoolScalar>::pool_acquire(&mut pool, 8) };
32/// <f32 as PoolScalar>::pool_release(&mut pool, buf);
33/// assert_eq!(pool.len(), 1);
34/// ```
35pub struct BufferPool {
36 f64_pool: BTreeMap<usize, Vec<Vec<f64>>>,
37 f32_pool: BTreeMap<usize, Vec<Vec<f32>>>,
38 c64_pool: BTreeMap<usize, Vec<Vec<Complex64>>>,
39 c32_pool: BTreeMap<usize, Vec<Vec<Complex32>>>,
40}
41
42/// Scalar types supported by [`BufferPool`].
43///
44/// The trait is sealed to the scalar dtypes that tenferro currently pools for
45/// CPU execution.
46///
47/// # Examples
48///
49/// ```ignore
50/// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
51///
52/// let mut pool = BufferPool::new();
53/// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
54/// buf.copy_from_slice(&[3.0, 4.0]);
55/// <f64 as PoolScalar>::pool_release(&mut pool, buf);
56/// ```
57pub trait PoolScalar: Copy + Sized + Send + private::Sealed {
58 /// Acquire a buffer with length `len`.
59 ///
60 /// The vector length is set without initializing its contents. Callers must
61 /// overwrite every element before any read.
62 ///
63 /// # Safety
64 ///
65 /// The returned vector may contain uninitialized elements. Reading any
66 /// element before writing it is undefined behavior.
67 ///
68 /// # Examples
69 ///
70 /// ```ignore
71 /// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
72 ///
73 /// let mut pool = BufferPool::new();
74 /// let mut buf = unsafe { <f64 as PoolScalar>::pool_acquire(&mut pool, 2) };
75 /// buf.copy_from_slice(&[1.0, 2.0]);
76 /// assert_eq!(buf, vec![1.0, 2.0]);
77 /// ```
78 unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self>;
79
80 /// Return a buffer to the typed pool for later reuse.
81 ///
82 /// Zero-capacity buffers are ignored.
83 ///
84 /// # Examples
85 ///
86 /// ```ignore
87 /// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
88 ///
89 /// let mut pool = BufferPool::new();
90 /// let buf = vec![1.0_f32; 4];
91 /// <f32 as PoolScalar>::pool_release(&mut pool, buf);
92 /// assert_eq!(pool.len(), 1);
93 /// ```
94 fn pool_release(pool: &mut BufferPool, buf: Vec<Self>);
95}
96
97mod private {
98 pub trait Sealed {}
99
100 impl Sealed for f64 {}
101 impl Sealed for f32 {}
102 impl Sealed for num_complex::Complex64 {}
103 impl Sealed for num_complex::Complex32 {}
104}
105
/// Removes and returns one buffer from the bucket with the smallest key that
/// is at least `len`.
///
/// Keys are the element capacities under which buffers were parked, so the
/// first bucket at or above `len` is the tightest fit. Within a bucket the
/// most recently parked buffer is returned first. Returns `None` when no
/// bucket is large enough. A bucket left without buffers is deleted from the
/// map so that only non-empty buckets remain.
fn take_best_fit<T>(pool: &mut BTreeMap<usize, Vec<Vec<T>>>, len: usize) -> Option<Vec<T>> {
    let (&bucket_cap, bucket) = pool.range_mut(len..).next()?;
    let taken = bucket.pop();
    let bucket_drained = bucket.is_empty();
    if bucket_drained {
        pool.remove(&bucket_cap);
    }
    taken
}
117
118macro_rules! impl_pool_scalar {
119 ($ty:ty, $field:ident) => {
120 impl PoolScalar for $ty {
121 #[allow(clippy::uninit_vec)]
122 unsafe fn pool_acquire(pool: &mut BufferPool, len: usize) -> Vec<Self> {
123 match take_best_fit(&mut pool.$field, len) {
124 Some(mut buf) => {
125 // SAFETY: caller upholds that elements will be written
126 // before any read. len <= capacity by construction.
127 unsafe { buf.set_len(len) };
128 buf
129 }
130 None => {
131 let mut buf = Vec::with_capacity(len);
132 // SAFETY: caller upholds that elements will be written
133 // before any read. len == capacity here.
134 unsafe { buf.set_len(len) };
135 buf
136 }
137 }
138 }
139
140 fn pool_release(pool: &mut BufferPool, buf: Vec<Self>) {
141 let cap = buf.capacity();
142 if cap > 0 {
143 pool.$field.entry(cap).or_default().push(buf);
144 }
145 }
146 }
147 };
148}
149
150impl_pool_scalar!(f64, f64_pool);
151impl_pool_scalar!(f32, f32_pool);
152impl_pool_scalar!(Complex64, c64_pool);
153impl_pool_scalar!(Complex32, c32_pool);
154
155impl BufferPool {
156 /// Create an empty typed buffer pool.
157 ///
158 /// # Examples
159 ///
160 /// ```ignore
161 /// use tenferro_tensor::buffer_pool::BufferPool;
162 ///
163 /// let pool = BufferPool::new();
164 /// assert!(pool.is_empty());
165 /// ```
166 pub fn new() -> Self {
167 Self {
168 f64_pool: BTreeMap::new(),
169 f32_pool: BTreeMap::new(),
170 c64_pool: BTreeMap::new(),
171 c32_pool: BTreeMap::new(),
172 }
173 }
174
175 /// Number of retained buffers across all typed pools.
176 ///
177 /// # Examples
178 ///
179 /// ```ignore
180 /// use tenferro_tensor::buffer_pool::{BufferPool, PoolScalar};
181 ///
182 /// let mut pool = BufferPool::new();
183 /// <f64 as PoolScalar>::pool_release(&mut pool, vec![0.0; 2]);
184 /// assert_eq!(pool.len(), 1);
185 /// ```
186 pub fn len(&self) -> usize {
187 self.f64_pool.values().map(Vec::len).sum::<usize>()
188 + self.f32_pool.values().map(Vec::len).sum::<usize>()
189 + self.c64_pool.values().map(Vec::len).sum::<usize>()
190 + self.c32_pool.values().map(Vec::len).sum::<usize>()
191 }
192
193 /// Acquire a typed vector with length 0 and at least `cap` capacity.
194 ///
195 /// Returned buffers come from the typed pool when possible and are ready
196 /// for push-based population.
197 ///
198 /// # Examples
199 ///
200 /// ```ignore
201 /// use tenferro_tensor::buffer_pool::BufferPool;
202 ///
203 /// let mut pool = BufferPool::new();
204 /// let mut buf = pool.acquire_with_capacity::<f64>(4);
205 /// buf.extend_from_slice(&[1.0, 2.0]);
206 /// assert_eq!(buf.len(), 2);
207 /// assert!(buf.capacity() >= 4);
208 /// ```
209 pub fn acquire_with_capacity<T: PoolScalar>(&mut self, cap: usize) -> Vec<T> {
210 if cap == 0 {
211 return Vec::new();
212 }
213
214 let mut buf = unsafe { T::pool_acquire(self, cap) };
215 // SAFETY: shrinking the length to zero does not read the buffer. The
216 // pool only stores `PoolScalar` values, which are `Copy`.
217 unsafe { buf.set_len(0) };
218 buf
219 }
220
221 /// Whether all typed pools are empty.
222 ///
223 /// # Examples
224 ///
225 /// ```ignore
226 /// use tenferro_tensor::buffer_pool::BufferPool;
227 ///
228 /// let pool = BufferPool::new();
229 /// assert!(pool.is_empty());
230 /// ```
231 pub fn is_empty(&self) -> bool {
232 self.f64_pool.is_empty()
233 && self.f32_pool.is_empty()
234 && self.c64_pool.is_empty()
235 && self.c32_pool.is_empty()
236 }
237}
238
239impl Default for BufferPool {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::{BufferPool, PoolScalar};

    /// A released buffer is handed back (same allocation) on the next request
    /// of the same size.
    #[test]
    fn acquire_release_reuse() {
        let mut pool = BufferPool::new();

        let first = unsafe { f64::pool_acquire(&mut pool, 64) };
        let (first_ptr, first_cap) = (first.as_ptr(), first.capacity());
        f64::pool_release(&mut pool, first);

        let second = unsafe { f64::pool_acquire(&mut pool, 64) };
        assert_eq!(second.as_ptr(), first_ptr);
        assert_eq!(second.capacity(), first_cap);
        assert!(pool.is_empty());
    }

    /// The smallest parked capacity that satisfies the request is chosen.
    #[test]
    fn best_fit() {
        let mut pool = BufferPool::new();
        f64::pool_release(&mut pool, Vec::with_capacity(100));
        f64::pool_release(&mut pool, Vec::with_capacity(200));
        f64::pool_release(&mut pool, Vec::with_capacity(300));

        let picked = unsafe { f64::pool_acquire(&mut pool, 150) };
        assert_eq!(picked.capacity(), 200);
        assert_eq!(pool.len(), 2);
    }

    /// A buffer parked for one scalar type is never served to another.
    #[test]
    fn type_separation() {
        let mut pool = BufferPool::new();
        f64::pool_release(&mut pool, Vec::with_capacity(16));
        assert_eq!(pool.len(), 1);

        // The `f32` request must allocate and leave the `f64` buffer parked.
        let single = unsafe { f32::pool_acquire(&mut pool, 16) };
        assert_eq!(single.capacity(), 16);
        assert_eq!(pool.len(), 1);

        let double = unsafe { f64::pool_acquire(&mut pool, 16) };
        assert_eq!(double.capacity(), 16);
        assert!(pool.is_empty());
    }

    /// With nothing parked, acquisition falls back to a new allocation.
    #[test]
    fn fresh_alloc_fallback() {
        let mut pool = BufferPool::new();
        let fresh = unsafe { f64::pool_acquire(&mut pool, 32) };
        assert_eq!(fresh.len(), 32);
        assert!(fresh.capacity() >= 32);
        assert!(pool.is_empty());
    }

    /// Releasing a vector without an allocation leaves the pool untouched.
    #[test]
    fn zero_len_not_pooled() {
        let mut pool = BufferPool::new();
        f64::pool_release(&mut pool, Vec::new());
        assert!(pool.is_empty());
    }

    /// `acquire_with_capacity` reuses the parked allocation and returns it
    /// with length zero.
    #[test]
    fn acquire_with_capacity_reuses_buffer_as_empty_vec() {
        let mut pool = BufferPool::new();

        let parked = vec![1.0_f64; 8];
        let (parked_ptr, parked_cap) = (parked.as_ptr(), parked.capacity());
        f64::pool_release(&mut pool, parked);

        let recycled = pool.acquire_with_capacity::<f64>(8);
        assert_eq!(recycled.as_ptr(), parked_ptr);
        assert_eq!(recycled.len(), 0);
        assert_eq!(recycled.capacity(), parked_cap);
        assert!(pool.is_empty());
    }
}
322}