tenferro_tensor/tensor/
mod.rs

1mod autodiff;
2mod combine;
3mod constructors;
4mod constructors_special;
5mod data_ops;
6mod element_access;
7mod metadata;
8mod structural;
9mod transfer;
10mod views;
11
12use std::sync::Arc;
13
14use tenferro_device::{ComputeDevice, LogicalMemorySpace};
15
16use crate::layout::{compute_contiguous_strides, is_contiguous_in_order};
17use crate::{CompletionEvent, DataBuffer, MemoryOrder};
18
19pub use structural::KeepCountScalar;
20
/// Multi-dimensional dense tensor.
///
/// `Tensor<T>` owns or shares a [`DataBuffer`] together with shape, strides,
/// and memory-space metadata.
///
/// ## Zero-copy views
///
/// Operations like [`permute`](Tensor::permute), [`broadcast`](Tensor::broadcast),
/// and [`diagonal`](Tensor::diagonal) return new tensors that share the
/// underlying buffer and only adjust metadata.
///
/// ## Accessing raw data
///
/// Use [`DataBuffer::as_slice`] via [`Tensor::buffer`] together with
/// [`Tensor::dims`], [`Tensor::strides`], and [`Tensor::offset`] to build
/// backend-specific views.
///
/// ## GPU async support
///
/// The optional [`CompletionEvent`] tracks pending GPU computation so future
/// backends can chain asynchronous work without forcing CPU synchronization.
///
/// # Examples
///
/// ```ignore
/// use tenferro_device::LogicalMemorySpace;
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let t = Tensor::<f64>::zeros(
///     &[2, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ).unwrap();
/// assert_eq!(t.dims(), &[2, 3]);
/// assert_eq!(t.len(), 6);
/// ```
pub struct Tensor<T> {
    // Element storage; shared (not copied) by views and by `Clone`.
    buffer: DataBuffer<T>,
    // Extent of each dimension.
    dims: Arc<[usize]>,
    // Signed per-dimension stride (see `compute_contiguous_strides`).
    strides: Arc<[isize]>,
    // Position of the first logical element within `buffer`.
    offset: isize,
    // Logical memory space the buffer lives in (e.g. `MainMemory`).
    logical_memory_space: LogicalMemorySpace,
    // Hint for which device should run computations on this tensor, if any.
    preferred_compute_device: Option<ComputeDevice>,
    // Pending asynchronous computation producing this tensor's data, if any.
    event: Option<CompletionEvent>,
    // Lazy-conjugation flag — NOTE(review): semantics inferred from the
    // name; confirm how the ops modules interpret it.
    conjugated: bool,
    // Forward-mode gradient (tangent) tensor, if one is attached —
    // presumably managed by the `autodiff` submodule; verify there.
    fw_grad: Option<Box<Tensor<T>>>,
}
68
/// Field-by-field bundle used to build a [`Tensor`] via
/// [`Tensor::from_parts`] with named fields instead of a long positional
/// argument list.
///
/// Mirrors every field of [`Tensor`] exactly; see the field comments there.
pub(crate) struct TensorParts<T> {
    pub(crate) buffer: DataBuffer<T>,
    pub(crate) dims: Arc<[usize]>,
    pub(crate) strides: Arc<[isize]>,
    pub(crate) offset: isize,
    pub(crate) logical_memory_space: LogicalMemorySpace,
    pub(crate) preferred_compute_device: Option<ComputeDevice>,
    pub(crate) event: Option<CompletionEvent>,
    pub(crate) conjugated: bool,
    pub(crate) fw_grad: Option<Box<Tensor<T>>>,
}
80
81impl<T> Clone for Tensor<T> {
82    /// Shallow clone: shares the underlying data buffer.
83    ///
84    /// For a deep copy, materialize into a new allocation with
85    /// [`Tensor::contiguous`] or another explicit data-producing operation.
86    fn clone(&self) -> Self {
87        Self::from_parts(TensorParts {
88            buffer: self.buffer.clone(),
89            dims: self.dims.clone(),
90            strides: self.strides.clone(),
91            offset: self.offset,
92            logical_memory_space: self.logical_memory_space,
93            preferred_compute_device: self.preferred_compute_device,
94            event: self.event.clone(),
95            conjugated: self.conjugated,
96            fw_grad: self.fw_grad.clone(),
97        })
98    }
99}
100
101impl<T> std::fmt::Debug for Tensor<T> {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        let has_pending_event = self.event.is_some();
104        let has_fw_grad = self.fw_grad.is_some();
105        f.debug_struct("Tensor")
106            .field("dtype", &std::any::type_name::<T>())
107            .field("dims", &self.dims)
108            .field("strides", &self.strides)
109            .field("offset", &self.offset)
110            .field("len", &self.len())
111            .field("logical_memory_space", &self.logical_memory_space)
112            .field("preferred_compute_device", &self.preferred_compute_device)
113            .field("is_contiguous", &self.is_contiguous())
114            .field("conjugated", &self.conjugated)
115            .field("has_pending_event", &has_pending_event)
116            .field("has_fw_grad", &has_fw_grad)
117            .finish()
118    }
119}
120
121/// Methods that require no element-type bounds at all.
122///
123/// These operate only on tensor metadata (dims, strides, offset, buffer
124/// reference) and never read or write element values.
125impl<T> Tensor<T> {
126    pub(crate) fn from_parts(parts: TensorParts<T>) -> Self {
127        let TensorParts {
128            buffer,
129            dims,
130            strides,
131            offset,
132            logical_memory_space,
133            preferred_compute_device,
134            event,
135            conjugated,
136            fw_grad,
137        } = parts;
138        Self {
139            buffer,
140            dims,
141            strides,
142            offset,
143            logical_memory_space,
144            preferred_compute_device,
145            event,
146            conjugated,
147            fw_grad,
148        }
149    }
150
151    pub(crate) fn from_owned_contiguous_data(
152        data: Vec<T>,
153        dims: Arc<[usize]>,
154        order: MemoryOrder,
155        logical_memory_space: LogicalMemorySpace,
156        preferred_compute_device: Option<ComputeDevice>,
157        conjugated: bool,
158    ) -> Self {
159        let strides = Arc::from(compute_contiguous_strides(dims.as_ref(), order));
160        Self::from_parts(TensorParts {
161            buffer: DataBuffer::from_vec(data),
162            dims,
163            strides,
164            offset: 0,
165            logical_memory_space,
166            preferred_compute_device,
167            event: None,
168            conjugated,
169            fw_grad: None,
170        })
171    }
172
173    pub(crate) fn shared_view_with(
174        &self,
175        dims: Arc<[usize]>,
176        strides: Arc<[isize]>,
177        offset: isize,
178    ) -> Self {
179        Self::from_parts(TensorParts {
180            buffer: self.buffer.clone(),
181            dims,
182            strides,
183            offset,
184            logical_memory_space: self.logical_memory_space,
185            preferred_compute_device: self.preferred_compute_device,
186            event: None,
187            conjugated: self.conjugated,
188            fw_grad: None,
189        })
190    }
191
192    pub(crate) fn materialized_from_vec(&self, data: Vec<T>, order: MemoryOrder) -> Self {
193        Self::from_owned_contiguous_data(
194            data,
195            self.dims.clone(),
196            order,
197            self.logical_memory_space,
198            self.preferred_compute_device,
199            self.conjugated,
200        )
201    }
202
203    pub(crate) fn cpu_backed_slice_or_panic(&self, operation: &str) -> &[T] {
204        self.buffer.as_slice().unwrap_or_else(|| {
205            panic!("{operation}: CPU-only operation; GPU tensors are not supported")
206        })
207    }
208
209    /// Returns `true` if the tensor data is contiguous in memory.
210    ///
211    /// # Examples
212    ///
213    /// ```ignore
214    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
215    /// assert!(t.is_contiguous());
216    /// ```
217    pub fn is_contiguous(&self) -> bool {
218        is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::ColumnMajor)
219            || is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::RowMajor)
220    }
221
222    /// Check if the tensor has column-major contiguous layout.
223    ///
224    /// # Examples
225    ///
226    /// ```ignore
227    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
228    /// assert!(t.is_col_major_contiguous());
229    /// ```
230    pub fn is_col_major_contiguous(&self) -> bool {
231        is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::ColumnMajor)
232    }
233
234    /// Check if the tensor has row-major contiguous layout.
235    ///
236    /// # Examples
237    ///
238    /// ```ignore
239    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
240    /// assert!(t.is_row_major_contiguous());
241    /// ```
242    pub fn is_row_major_contiguous(&self) -> bool {
243        is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::RowMajor)
244    }
245}