tenferro_tensor/tensor/metadata.rs
1use tenferro_device::{
2 preferred_compute_devices, ComputeDevice, Error, LogicalMemorySpace, OpKind, Result,
3};
4
5use super::Tensor;
6
7impl<T> Tensor<T> {
8 /// Returns the shape (size of each dimension).
9 ///
10 /// # Examples
11 ///
12 /// ```ignore
13 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
14 /// assert_eq!(t.dims(), &[2, 3]);
15 /// ```
16 pub fn dims(&self) -> &[usize] {
17 &self.dims
18 }
19
20 /// Returns the strides (in units of `T`).
21 ///
22 /// # Examples
23 ///
24 /// ```ignore
25 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
26 /// let _strides = t.strides();
27 /// ```
28 pub fn strides(&self) -> &[isize] {
29 &self.strides
30 }
31
32 /// Returns the element offset into the data buffer.
33 ///
34 /// # Examples
35 ///
36 /// ```ignore
37 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
38 /// assert_eq!(t.offset(), 0);
39 /// ```
40 pub fn offset(&self) -> isize {
41 self.offset
42 }
43
44 /// Returns a reference to the underlying data buffer.
45 ///
46 /// # Examples
47 ///
48 /// ```ignore
49 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
50 /// let _buf = t.buffer();
51 /// ```
52 pub fn buffer(&self) -> &crate::DataBuffer<T> {
53 &self.buffer
54 }
55
56 /// Returns a mutable reference to the underlying data buffer.
57 ///
58 /// # Examples
59 ///
60 /// ```ignore
61 /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
62 /// let _buf = t.buffer_mut();
63 /// ```
64 pub fn buffer_mut(&mut self) -> &mut crate::DataBuffer<T> {
65 &mut self.buffer
66 }
67
68 /// Returns the number of dimensions (rank).
69 ///
70 /// # Examples
71 ///
72 /// ```ignore
73 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
74 /// assert_eq!(t.ndim(), 2);
75 /// ```
76 pub fn ndim(&self) -> usize {
77 self.dims.len()
78 }
79
80 /// Returns the total number of elements.
81 ///
82 /// # Examples
83 ///
84 /// ```ignore
85 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
86 /// assert_eq!(t.len(), 6);
87 /// ```
88 pub fn len(&self) -> usize {
89 self.dims.iter().product()
90 }
91
92 /// Returns `true` if the tensor has zero elements.
93 ///
94 /// # Examples
95 ///
96 /// ```ignore
97 /// let t = Tensor::<f64>::zeros(&[0, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
98 /// assert!(t.is_empty());
99 /// ```
100 pub fn is_empty(&self) -> bool {
101 self.len() == 0
102 }
103
104 /// Returns the logical memory space where this tensor's data resides.
105 ///
106 /// # Examples
107 ///
108 /// ```ignore
109 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
110 /// assert_eq!(t.logical_memory_space(), LogicalMemorySpace::MainMemory);
111 /// ```
112 pub fn logical_memory_space(&self) -> LogicalMemorySpace {
113 self.logical_memory_space
114 }
115
116 /// Returns the preferred compute device override, if set.
117 ///
118 /// # Examples
119 ///
120 /// ```ignore
121 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
122 /// assert!(t.preferred_compute_device().is_none());
123 /// ```
124 pub fn preferred_compute_device(&self) -> Option<ComputeDevice> {
125 self.preferred_compute_device
126 }
127
128 /// Set the preferred compute device override.
129 ///
130 /// # Examples
131 ///
132 /// ```ignore
133 /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
134 /// t.set_preferred_compute_device(Some(ComputeDevice::Cpu { device_id: 0 }));
135 /// ```
136 pub fn set_preferred_compute_device(&mut self, device: Option<ComputeDevice>) {
137 self.preferred_compute_device = device;
138 }
139
140 /// Returns `true` if this tensor is logically conjugated.
141 ///
142 /// # Examples
143 ///
144 /// ```ignore
145 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
146 /// assert!(!t.is_conjugated());
147 /// ```
148 pub fn is_conjugated(&self) -> bool {
149 self.conjugated
150 }
151
152 /// Returns a reference to the forward-mode tangent, if set.
153 ///
154 /// # Examples
155 ///
156 /// ```ignore
157 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
158 /// assert!(t.fw_grad().is_none());
159 /// ```
160 pub fn fw_grad(&self) -> Option<&Tensor<T>> {
161 self.fw_grad.as_deref()
162 }
163
164 /// Returns `true` if this tensor carries a forward-mode tangent.
165 ///
166 /// # Examples
167 ///
168 /// ```ignore
169 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
170 /// assert!(!t.has_fw_grad());
171 /// ```
172 pub fn has_fw_grad(&self) -> bool {
173 self.fw_grad.is_some()
174 }
175
176 /// Attach a forward-mode tangent to this tensor.
177 ///
178 /// # Examples
179 ///
180 /// ```ignore
181 /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
182 /// let grad = Tensor::<f64>::ones(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
183 /// t.set_fw_grad(grad);
184 /// ```
185 pub fn set_fw_grad(&mut self, grad: Tensor<T>) {
186 self.fw_grad = Some(Box::new(grad));
187 }
188
189 /// Detach and return the forward-mode tangent, leaving `None`.
190 ///
191 /// # Examples
192 ///
193 /// ```ignore
194 /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
195 /// t.set_fw_grad(Tensor::<f64>::ones(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap());
196 /// let _grad = t.detach_fw_grad().unwrap();
197 /// ```
198 pub fn detach_fw_grad(&mut self) -> Option<Tensor<T>> {
199 self.fw_grad.take().map(|boxed| *boxed)
200 }
201
202 /// Return the effective compute devices for a given operation kind.
203 ///
204 /// # Examples
205 ///
206 /// ```ignore
207 /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
208 /// let _devices = t.effective_compute_devices(OpKind::BatchedGemm).unwrap();
209 /// ```
210 pub fn effective_compute_devices(&self, op_kind: OpKind) -> Result<Vec<ComputeDevice>> {
211 let compatible = preferred_compute_devices(self.logical_memory_space, op_kind)?;
212 if let Some(device) = self.preferred_compute_device {
213 if compatible.contains(&device) {
214 Ok(vec![device])
215 } else {
216 Err(Error::NoCompatibleComputeDevice {
217 space: self.logical_memory_space,
218 op: op_kind,
219 })
220 }
221 } else {
222 Ok(compatible)
223 }
224 }
225}