1mod autodiff;
2mod combine;
3mod constructors;
4mod constructors_special;
5mod data_ops;
6mod element_access;
7mod metadata;
8mod structural;
9mod transfer;
10mod views;
11
12use std::sync::Arc;
13
14use tenferro_device::{ComputeDevice, LogicalMemorySpace};
15
16use crate::layout::{compute_contiguous_strides, is_contiguous_in_order};
17use crate::{CompletionEvent, DataBuffer, MemoryOrder};
18
19pub use structural::KeepCountScalar;
20
/// A dense multi-dimensional array of elements of type `T`.
///
/// Element storage lives in `buffer` and may be shared between several
/// tensors (see `Clone` and `shared_view_with`); `dims`, `strides`, and
/// `offset` describe how this particular tensor views that storage.
pub struct Tensor<T> {
    // Backing element storage; cloning a Tensor clones this handle, and
    // views created via `shared_view_with` reuse it.
    buffer: DataBuffer<T>,
    // Size of each dimension; `Arc` so views can share it without copying.
    dims: Arc<[usize]>,
    // Per-dimension element strides; signed (`isize`) — presumably negative
    // strides are allowed for reversed views (TODO confirm).
    strides: Arc<[isize]>,
    // Element offset of this view's first element within `buffer`.
    offset: isize,
    // Logical memory space the buffer resides in (semantics defined by
    // `tenferro_device`).
    logical_memory_space: LogicalMemorySpace,
    // Optional hint for which compute device should execute operations.
    preferred_compute_device: Option<ComputeDevice>,
    // Completion event; `None` when no asynchronous work is pending.
    // NOTE(review): presumably set while an async producer is in flight —
    // confirm against the transfer/data_ops modules.
    event: Option<CompletionEvent>,
    // Conjugation flag carried with the view (presumably lazy complex
    // conjugation applied on element access — confirm).
    conjugated: bool,
    // Forward-mode gradient tangent, if attached (see `mod autodiff`).
    fw_grad: Option<Box<Tensor<T>>>,
}
68
/// Crate-internal bundle of every `Tensor` field, used to assemble a tensor
/// via `Tensor::from_parts` without a long positional argument list.
///
/// Fields mirror those of [`Tensor`] one-to-one; see that type for their
/// meanings.
pub(crate) struct TensorParts<T> {
    pub(crate) buffer: DataBuffer<T>,
    pub(crate) dims: Arc<[usize]>,
    pub(crate) strides: Arc<[isize]>,
    pub(crate) offset: isize,
    pub(crate) logical_memory_space: LogicalMemorySpace,
    pub(crate) preferred_compute_device: Option<ComputeDevice>,
    pub(crate) event: Option<CompletionEvent>,
    pub(crate) conjugated: bool,
    pub(crate) fw_grad: Option<Box<Tensor<T>>>,
}
80
81impl<T> Clone for Tensor<T> {
82 fn clone(&self) -> Self {
87 Self::from_parts(TensorParts {
88 buffer: self.buffer.clone(),
89 dims: self.dims.clone(),
90 strides: self.strides.clone(),
91 offset: self.offset,
92 logical_memory_space: self.logical_memory_space,
93 preferred_compute_device: self.preferred_compute_device,
94 event: self.event.clone(),
95 conjugated: self.conjugated,
96 fw_grad: self.fw_grad.clone(),
97 })
98 }
99}
100
101impl<T> std::fmt::Debug for Tensor<T> {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 let has_pending_event = self.event.is_some();
104 let has_fw_grad = self.fw_grad.is_some();
105 f.debug_struct("Tensor")
106 .field("dtype", &std::any::type_name::<T>())
107 .field("dims", &self.dims)
108 .field("strides", &self.strides)
109 .field("offset", &self.offset)
110 .field("len", &self.len())
111 .field("logical_memory_space", &self.logical_memory_space)
112 .field("preferred_compute_device", &self.preferred_compute_device)
113 .field("is_contiguous", &self.is_contiguous())
114 .field("conjugated", &self.conjugated)
115 .field("has_pending_event", &has_pending_event)
116 .field("has_fw_grad", &has_fw_grad)
117 .finish()
118 }
119}
120
121impl<T> Tensor<T> {
126 pub(crate) fn from_parts(parts: TensorParts<T>) -> Self {
127 let TensorParts {
128 buffer,
129 dims,
130 strides,
131 offset,
132 logical_memory_space,
133 preferred_compute_device,
134 event,
135 conjugated,
136 fw_grad,
137 } = parts;
138 Self {
139 buffer,
140 dims,
141 strides,
142 offset,
143 logical_memory_space,
144 preferred_compute_device,
145 event,
146 conjugated,
147 fw_grad,
148 }
149 }
150
151 pub(crate) fn from_owned_contiguous_data(
152 data: Vec<T>,
153 dims: Arc<[usize]>,
154 order: MemoryOrder,
155 logical_memory_space: LogicalMemorySpace,
156 preferred_compute_device: Option<ComputeDevice>,
157 conjugated: bool,
158 ) -> Self {
159 let strides = Arc::from(compute_contiguous_strides(dims.as_ref(), order));
160 Self::from_parts(TensorParts {
161 buffer: DataBuffer::from_vec(data),
162 dims,
163 strides,
164 offset: 0,
165 logical_memory_space,
166 preferred_compute_device,
167 event: None,
168 conjugated,
169 fw_grad: None,
170 })
171 }
172
173 pub(crate) fn shared_view_with(
174 &self,
175 dims: Arc<[usize]>,
176 strides: Arc<[isize]>,
177 offset: isize,
178 ) -> Self {
179 Self::from_parts(TensorParts {
180 buffer: self.buffer.clone(),
181 dims,
182 strides,
183 offset,
184 logical_memory_space: self.logical_memory_space,
185 preferred_compute_device: self.preferred_compute_device,
186 event: None,
187 conjugated: self.conjugated,
188 fw_grad: None,
189 })
190 }
191
192 pub(crate) fn materialized_from_vec(&self, data: Vec<T>, order: MemoryOrder) -> Self {
193 Self::from_owned_contiguous_data(
194 data,
195 self.dims.clone(),
196 order,
197 self.logical_memory_space,
198 self.preferred_compute_device,
199 self.conjugated,
200 )
201 }
202
203 pub(crate) fn cpu_backed_slice_or_panic(&self, operation: &str) -> &[T] {
204 self.buffer.as_slice().unwrap_or_else(|| {
205 panic!("{operation}: CPU-only operation; GPU tensors are not supported")
206 })
207 }
208
209 pub fn is_contiguous(&self) -> bool {
218 is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::ColumnMajor)
219 || is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::RowMajor)
220 }
221
222 pub fn is_col_major_contiguous(&self) -> bool {
231 is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::ColumnMajor)
232 }
233
234 pub fn is_row_major_contiguous(&self) -> bool {
243 is_contiguous_in_order(&self.dims, &self.strides, MemoryOrder::RowMajor)
244 }
245}