//! tenferro_tensor/buffer.rs

1use std::sync::Arc;
2
3use tenferro_algebra::Scalar;
4use tenferro_device::{Error, LogicalMemorySpace, Result};
5
/// Data storage for tensor elements.
///
/// Abstracts over ownership: data may be Rust-owned (`Vec<T>`),
/// externally-owned (e.g., imported via DLPack with a release callback),
/// or GPU device memory. Shape and stride metadata are not stored here.
///
/// # Shared ownership
///
/// `DataBuffer` uses `Arc`, so cloning a tensor or buffer is shallow and
/// shares storage. Deep copies happen only when a tensor is materialized into
/// a new contiguous allocation.
///
/// # Examples
///
/// ```ignore
/// use tenferro_tensor::DataBuffer;
///
/// let buf = DataBuffer::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
/// assert_eq!(buf.len(), 3);
/// ```
pub struct DataBuffer<T> {
    // Reference-counted storage; see `BufferInner` for the three ownership
    // modes (owned Vec, external CPU memory, GPU device memory). Cloning the
    // buffer only bumps this refcount.
    inner: Arc<BufferInner<T>>,
}
29
/// Internal storage variants behind `DataBuffer`.
enum BufferInner<T> {
    /// Rust-owned heap storage; freed by the `Vec`'s own destructor.
    Owned(Vec<T>),
    /// Externally-owned CPU memory (e.g., a DLPack import).
    External {
        // Start of the allocation; `DataBuffer::from_external`'s safety
        // contract guarantees validity until `release` fires.
        ptr: *const T,
        // Number of `T` elements reachable from `ptr`.
        len: usize,
        // Invoked at most once, on drop, to hand the memory back to its owner.
        release: Option<Box<dyn FnOnce() + Send>>,
    },
    #[allow(dead_code)]
    /// Device-resident memory; not directly addressable from the CPU.
    Gpu {
        // Device pointer; never dereferenced on the host in this module.
        device_ptr: *mut T,
        // Number of `T` elements in the device allocation.
        len: usize,
        // Logical memory space the allocation lives in.
        space: LogicalMemorySpace,
        // Invoked at most once, on drop, to free the device allocation.
        release: Option<Box<dyn FnOnce() + Send>>,
    },
}
45
46unsafe impl<T: Send> Send for DataBuffer<T> {}
47unsafe impl<T: Sync> Sync for DataBuffer<T> {}
48
49impl<T> Clone for DataBuffer<T> {
50    fn clone(&self) -> Self {
51        Self {
52            inner: Arc::clone(&self.inner),
53        }
54    }
55}
56
57impl<T> Drop for BufferInner<T> {
58    fn drop(&mut self) {
59        match self {
60            BufferInner::External { release, .. } | BufferInner::Gpu { release, .. } => {
61                if let Some(callback) = release.take() {
62                    callback();
63                }
64            }
65            BufferInner::Owned(_) => {}
66        }
67    }
68}
69
70impl<T> DataBuffer<T> {
71    /// Create a buffer from an owned `Vec<T>`.
72    ///
73    /// # Examples
74    ///
75    /// ```ignore
76    /// use tenferro_tensor::DataBuffer;
77    ///
78    /// let buf = DataBuffer::from_vec(vec![1.0, 2.0, 3.0]);
79    /// assert!(buf.is_owned());
80    /// ```
81    pub fn from_vec(v: Vec<T>) -> Self {
82        Self {
83            inner: Arc::new(BufferInner::Owned(v)),
84        }
85    }
86
87    /// Create a buffer from externally-owned data with a release callback.
88    ///
89    /// # Safety
90    ///
91    /// - `ptr` must point to a valid, properly aligned allocation of at least
92    ///   `len` elements of `T`.
93    /// - The allocation must remain valid until the release callback fires.
94    ///
95    /// # Examples
96    ///
97    /// ```ignore
98    /// use tenferro_tensor::DataBuffer;
99    ///
100    /// let data = vec![1.0, 2.0, 3.0];
101    /// let ptr = data.as_ptr();
102    /// let len = data.len();
103    /// let buf = unsafe { DataBuffer::from_external(ptr, len, move || drop(data)) };
104    /// assert!(!buf.is_owned());
105    /// ```
106    pub unsafe fn from_external(
107        ptr: *const T,
108        len: usize,
109        release: impl FnOnce() + Send + 'static,
110    ) -> Self {
111        Self {
112            inner: Arc::new(BufferInner::External {
113                ptr,
114                len,
115                release: Some(Box::new(release)),
116            }),
117        }
118    }
119
120    #[allow(dead_code)]
121    pub(crate) unsafe fn from_gpu_parts(
122        device_ptr: *mut T,
123        len: usize,
124        space: LogicalMemorySpace,
125        release: impl FnOnce() + Send + 'static,
126    ) -> Self {
127        Self {
128            inner: Arc::new(BufferInner::Gpu {
129                device_ptr,
130                len,
131                space,
132                release: Some(Box::new(release)),
133            }),
134        }
135    }
136
137    /// Returns the raw data as a slice for CPU-accessible buffers.
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// use tenferro_tensor::DataBuffer;
143    ///
144    /// let buf = DataBuffer::from_vec(vec![1.0, 2.0, 3.0]);
145    /// assert_eq!(buf.as_slice(), Some(&[1.0, 2.0, 3.0][..]));
146    /// ```
147    pub fn as_slice(&self) -> Option<&[T]> {
148        match &*self.inner {
149            BufferInner::Owned(v) => Some(v.as_slice()),
150            BufferInner::External { ptr, len, .. } => {
151                Some(unsafe { std::slice::from_raw_parts(*ptr, *len) })
152            }
153            BufferInner::Gpu { .. } => None,
154        }
155    }
156
157    /// Returns the raw data as a mutable slice, if uniquely owned.
158    ///
159    /// # Examples
160    ///
161    /// ```
162    /// use tenferro_tensor::DataBuffer;
163    ///
164    /// let mut buf = DataBuffer::from_vec(vec![1.0, 2.0]);
165    /// if let Some(slice) = buf.as_mut_slice() {
166    ///     slice[0] = 42.0;
167    /// }
168    /// assert_eq!(buf.as_slice().unwrap()[0], 42.0);
169    /// ```
170    pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
171        match Arc::get_mut(&mut self.inner)? {
172            BufferInner::Owned(v) => Some(v.as_mut_slice()),
173            BufferInner::External { .. } | BufferInner::Gpu { .. } => None,
174        }
175    }
176
177    /// Returns the number of elements in the buffer.
178    ///
179    /// # Examples
180    ///
181    /// ```
182    /// use tenferro_tensor::DataBuffer;
183    ///
184    /// let buf = DataBuffer::from_vec(vec![1.0, 2.0, 3.0]);
185    /// assert_eq!(buf.len(), 3);
186    /// ```
187    pub fn len(&self) -> usize {
188        match &*self.inner {
189            BufferInner::Owned(v) => v.len(),
190            BufferInner::External { len, .. } | BufferInner::Gpu { len, .. } => *len,
191        }
192    }
193
194    /// Returns `true` if the buffer has no elements.
195    ///
196    /// # Examples
197    ///
198    /// ```
199    /// use tenferro_tensor::DataBuffer;
200    ///
201    /// let buf = DataBuffer::<f64>::from_vec(vec![]);
202    /// assert!(buf.is_empty());
203    /// ```
204    pub fn is_empty(&self) -> bool {
205        self.len() == 0
206    }
207
208    /// Returns `true` if the buffer is Rust-owned.
209    ///
210    /// # Examples
211    ///
212    /// ```
213    /// use tenferro_tensor::DataBuffer;
214    ///
215    /// let buf = DataBuffer::from_vec(vec![1.0f64]);
216    /// assert!(buf.is_owned());
217    /// ```
218    pub fn is_owned(&self) -> bool {
219        matches!(&*self.inner, BufferInner::Owned(_))
220    }
221
222    /// Returns `true` if the buffer resides on GPU device memory.
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// use tenferro_tensor::DataBuffer;
228    ///
229    /// let buf = DataBuffer::from_vec(vec![1.0f64]);
230    /// assert!(!buf.is_gpu());
231    /// ```
232    pub fn is_gpu(&self) -> bool {
233        matches!(&*self.inner, BufferInner::Gpu { .. })
234    }
235
236    /// Returns `true` if this is the only reference to the underlying buffer.
237    ///
238    /// # Examples
239    ///
240    /// ```
241    /// use tenferro_tensor::DataBuffer;
242    ///
243    /// let buf = DataBuffer::from_vec(vec![1.0f64]);
244    /// assert!(buf.is_unique());
245    /// let buf2 = buf.clone();
246    /// assert!(!buf.is_unique());
247    /// ```
248    pub fn is_unique(&self) -> bool {
249        Arc::strong_count(&self.inner) == 1
250    }
251
    /// Extract the inner `Vec<T>` if this is the sole owner of a CPU-owned buffer.
    ///
    /// Returns `None` when other clones of this buffer exist, or when the
    /// storage is external/GPU memory rather than a Rust-owned `Vec`.
    pub fn try_into_vec(self) -> Option<Vec<T>> {
        // `try_unwrap` only succeeds when the strong count is exactly 1.
        let inner = Arc::try_unwrap(self.inner).ok()?;
        // `BufferInner` implements `Drop`, so its fields cannot be moved out
        // with a by-value pattern match. Suppress the destructor with
        // `ManuallyDrop`, then move the `Vec` out manually.
        let mut inner = std::mem::ManuallyDrop::new(inner);
        match &mut *inner {
            // SAFETY: `inner` is wrapped in `ManuallyDrop` and never touched
            // again, so reading the `Vec` out cannot cause a double free.
            BufferInner::Owned(v) => Some(unsafe { std::ptr::read(v as *const Vec<T>) }),
            _ => {
                // Not an owned Vec: run the normal destructor after all (this
                // fires the external/GPU release callback) and report failure.
                unsafe { std::mem::ManuallyDrop::drop(&mut inner) };
                None
            }
        }
    }
264
265    /// Returns a raw CPU pointer to the data, or `None` for GPU buffers.
266    ///
267    /// # Examples
268    ///
269    /// ```
270    /// use tenferro_tensor::DataBuffer;
271    ///
272    /// let buf = DataBuffer::from_vec(vec![1.0f64]);
273    /// assert!(buf.as_ptr().is_some());
274    /// ```
275    pub fn as_ptr(&self) -> Option<*const T> {
276        match &*self.inner {
277            BufferInner::Owned(v) => Some(v.as_ptr()),
278            BufferInner::External { ptr, .. } => Some(*ptr),
279            BufferInner::Gpu { .. } => None,
280        }
281    }
282
283    /// Returns the GPU device pointer, or `None` for CPU buffers.
284    ///
285    /// # Examples
286    ///
287    /// ```
288    /// use tenferro_tensor::DataBuffer;
289    ///
290    /// let buf = DataBuffer::from_vec(vec![1.0f64]);
291    /// assert!(buf.as_device_ptr().is_none());
292    /// ```
293    pub fn as_device_ptr(&self) -> Option<*const T> {
294        match &*self.inner {
295            BufferInner::Gpu { device_ptr, .. } => Some(*device_ptr as *const T),
296            _ => None,
297        }
298    }
299
300    /// Returns the logical memory space of a GPU buffer, or `None` for CPU buffers.
301    ///
302    /// # Examples
303    ///
304    /// ```
305    /// use tenferro_tensor::DataBuffer;
306    ///
307    /// let buf = DataBuffer::from_vec(vec![1.0f64]);
308    /// assert!(buf.gpu_memory_space().is_none());
309    /// ```
310    pub fn gpu_memory_space(&self) -> Option<LogicalMemorySpace> {
311        match &*self.inner {
312            BufferInner::Gpu { space, .. } => Some(*space),
313            _ => None,
314        }
315    }
316
317    pub(crate) fn reinterpret_as<U>(&self, new_len: usize) -> Result<DataBuffer<U>>
318    where
319        T: Send + Sync + 'static,
320        U: Send + Sync + 'static,
321    {
322        let src_bytes = self
323            .len()
324            .checked_mul(std::mem::size_of::<T>())
325            .ok_or_else(|| Error::InvalidArgument("buffer byte-size overflow".into()))?;
326        let dst_bytes = new_len
327            .checked_mul(std::mem::size_of::<U>())
328            .ok_or_else(|| {
329                Error::InvalidArgument("reinterpreted buffer byte-size overflow".into())
330            })?;
331        if dst_bytes > src_bytes {
332            return Err(Error::InvalidArgument(format!(
333                "buffer reinterpretation exceeds source byte size: src_bytes={src_bytes} dst_bytes={dst_bytes}"
334            )));
335        }
336
337        let align = std::mem::align_of::<U>();
338        match &*self.inner {
339            BufferInner::Owned(v) => {
340                let ptr = v.as_ptr() as *const U;
341                if new_len != 0 && !(ptr as usize).is_multiple_of(align) {
342                    return Err(Error::InvalidArgument(format!(
343                        "buffer reinterpretation would violate alignment {}",
344                        align
345                    )));
346                }
347                let owner = self.clone();
348                Ok(unsafe { DataBuffer::from_external(ptr, new_len, move || drop(owner)) })
349            }
350            BufferInner::External { ptr, .. } => {
351                let ptr = *ptr as *const U;
352                if new_len != 0 && !(ptr as usize).is_multiple_of(align) {
353                    return Err(Error::InvalidArgument(format!(
354                        "external buffer reinterpretation would violate alignment {}",
355                        align
356                    )));
357                }
358                let owner = self.clone();
359                Ok(unsafe { DataBuffer::from_external(ptr, new_len, move || drop(owner)) })
360            }
361            BufferInner::Gpu {
362                device_ptr, space, ..
363            } => {
364                let ptr = *device_ptr as *mut U;
365                if new_len != 0 && !(ptr as usize).is_multiple_of(align) {
366                    return Err(Error::InvalidArgument(format!(
367                        "gpu buffer reinterpretation would violate alignment {}",
368                        align
369                    )));
370                }
371                let owner = self.clone();
372                Ok(
373                    unsafe {
374                        DataBuffer::from_gpu_parts(ptr, new_len, *space, move || drop(owner))
375                    },
376                )
377            }
378        }
379    }
380}
381
382impl<T: Scalar> DataBuffer<T> {
383    /// Allocate a zero-filled buffer on the specified device.
384    ///
385    /// For CPU (`MainMemory`) this creates an owned `Vec<T>` filled with
386    /// `T::zero()`. For CUDA GPU memory (when the `cuda` feature is enabled)
387    /// this allocates directly on the device and zero-fills via host-to-device
388    /// copy.
389    ///
390    /// # Errors
391    ///
392    /// Returns an error for unsupported memory spaces or if device allocation
393    /// fails.
394    ///
395    /// # Examples
396    ///
397    /// ```ignore
398    /// use tenferro_tensor::DataBuffer;
399    /// use tenferro_device::LogicalMemorySpace;
400    ///
401    /// let buf = DataBuffer::<f64>::zeros_on_device(4, LogicalMemorySpace::MainMemory).unwrap();
402    /// assert_eq!(buf.as_slice().unwrap(), &[0.0; 4]);
403    /// ```
404    pub fn zeros_on_device(n_elements: usize, memory_space: LogicalMemorySpace) -> Result<Self> {
405        match memory_space {
406            LogicalMemorySpace::MainMemory => Ok(Self::from_vec(vec![T::zero(); n_elements])),
407            #[cfg(feature = "cuda")]
408            LogicalMemorySpace::GpuMemory { .. } => {
409                crate::cuda_runtime::alloc_zeros_gpu(n_elements, memory_space)
410            }
411            _ => Err(Error::DeviceError(format!(
412                "zeros_on_device: unsupported memory space {memory_space:?}"
413            ))),
414        }
415    }
416
417    /// Allocate an uninitialized buffer on the specified device.
418    ///
419    /// For CPU (`MainMemory`) this creates an owned `Vec<T>` with the
420    /// requested capacity. The contents are **unspecified** (currently
421    /// zero-filled for safety, but callers must not rely on that).
422    ///
423    /// For CUDA GPU memory (when the `cuda` feature is enabled) this
424    /// allocates device memory without initialization.
425    ///
426    /// # Errors
427    ///
428    /// Returns an error for unsupported memory spaces or if device allocation
429    /// fails.
430    ///
431    /// # Examples
432    ///
433    /// ```ignore
434    /// use tenferro_tensor::DataBuffer;
435    /// use tenferro_device::LogicalMemorySpace;
436    ///
437    /// let buf = DataBuffer::<f64>::allocate_uninit_on_device(4, LogicalMemorySpace::MainMemory).unwrap();
438    /// assert_eq!(buf.len(), 4);
439    /// ```
440    pub fn allocate_uninit_on_device(
441        n_elements: usize,
442        memory_space: LogicalMemorySpace,
443    ) -> Result<Self> {
444        match memory_space {
445            LogicalMemorySpace::MainMemory => {
446                // Zero-fill for safety on CPU; the "uninit" contract means
447                // callers will overwrite, but we avoid UB by initializing.
448                Ok(Self::from_vec(vec![T::zero(); n_elements]))
449            }
450            #[cfg(feature = "cuda")]
451            LogicalMemorySpace::GpuMemory { .. } => {
452                // True uninit allocation on GPU — caller is responsible for
453                // writing before reading.
454                crate::cuda_runtime::alloc_gpu_uninit(n_elements, memory_space)
455            }
456            _ => Err(Error::DeviceError(format!(
457                "allocate_uninit_on_device: unsupported memory space {memory_space:?}"
458            ))),
459        }
460    }
461}