// tenferro_tensor/tensor/metadata.rs

1use tenferro_device::{
2    preferred_compute_devices, ComputeDevice, Error, LogicalMemorySpace, OpKind, Result,
3};
4
5use super::Tensor;
6
7impl<T> Tensor<T> {
8    /// Returns the shape (size of each dimension).
9    ///
10    /// # Examples
11    ///
12    /// ```ignore
13    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
14    /// assert_eq!(t.dims(), &[2, 3]);
15    /// ```
16    pub fn dims(&self) -> &[usize] {
17        &self.dims
18    }
19
20    /// Returns the strides (in units of `T`).
21    ///
22    /// # Examples
23    ///
24    /// ```ignore
25    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
26    /// let _strides = t.strides();
27    /// ```
28    pub fn strides(&self) -> &[isize] {
29        &self.strides
30    }
31
32    /// Returns the element offset into the data buffer.
33    ///
34    /// # Examples
35    ///
36    /// ```ignore
37    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
38    /// assert_eq!(t.offset(), 0);
39    /// ```
40    pub fn offset(&self) -> isize {
41        self.offset
42    }
43
44    /// Returns a reference to the underlying data buffer.
45    ///
46    /// # Examples
47    ///
48    /// ```ignore
49    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
50    /// let _buf = t.buffer();
51    /// ```
52    pub fn buffer(&self) -> &crate::DataBuffer<T> {
53        &self.buffer
54    }
55
56    /// Returns a mutable reference to the underlying data buffer.
57    ///
58    /// # Examples
59    ///
60    /// ```ignore
61    /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
62    /// let _buf = t.buffer_mut();
63    /// ```
64    pub fn buffer_mut(&mut self) -> &mut crate::DataBuffer<T> {
65        &mut self.buffer
66    }
67
68    /// Returns the number of dimensions (rank).
69    ///
70    /// # Examples
71    ///
72    /// ```ignore
73    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
74    /// assert_eq!(t.ndim(), 2);
75    /// ```
76    pub fn ndim(&self) -> usize {
77        self.dims.len()
78    }
79
80    /// Returns the total number of elements.
81    ///
82    /// # Examples
83    ///
84    /// ```ignore
85    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
86    /// assert_eq!(t.len(), 6);
87    /// ```
88    pub fn len(&self) -> usize {
89        self.dims.iter().product()
90    }
91
92    /// Returns `true` if the tensor has zero elements.
93    ///
94    /// # Examples
95    ///
96    /// ```ignore
97    /// let t = Tensor::<f64>::zeros(&[0, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
98    /// assert!(t.is_empty());
99    /// ```
100    pub fn is_empty(&self) -> bool {
101        self.len() == 0
102    }
103
104    /// Returns the logical memory space where this tensor's data resides.
105    ///
106    /// # Examples
107    ///
108    /// ```ignore
109    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
110    /// assert_eq!(t.logical_memory_space(), LogicalMemorySpace::MainMemory);
111    /// ```
112    pub fn logical_memory_space(&self) -> LogicalMemorySpace {
113        self.logical_memory_space
114    }
115
116    /// Returns the preferred compute device override, if set.
117    ///
118    /// # Examples
119    ///
120    /// ```ignore
121    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
122    /// assert!(t.preferred_compute_device().is_none());
123    /// ```
124    pub fn preferred_compute_device(&self) -> Option<ComputeDevice> {
125        self.preferred_compute_device
126    }
127
128    /// Set the preferred compute device override.
129    ///
130    /// # Examples
131    ///
132    /// ```ignore
133    /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
134    /// t.set_preferred_compute_device(Some(ComputeDevice::Cpu { device_id: 0 }));
135    /// ```
136    pub fn set_preferred_compute_device(&mut self, device: Option<ComputeDevice>) {
137        self.preferred_compute_device = device;
138    }
139
140    /// Returns `true` if this tensor is logically conjugated.
141    ///
142    /// # Examples
143    ///
144    /// ```ignore
145    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
146    /// assert!(!t.is_conjugated());
147    /// ```
148    pub fn is_conjugated(&self) -> bool {
149        self.conjugated
150    }
151
152    /// Returns a reference to the forward-mode tangent, if set.
153    ///
154    /// # Examples
155    ///
156    /// ```ignore
157    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
158    /// assert!(t.fw_grad().is_none());
159    /// ```
160    pub fn fw_grad(&self) -> Option<&Tensor<T>> {
161        self.fw_grad.as_deref()
162    }
163
164    /// Returns `true` if this tensor carries a forward-mode tangent.
165    ///
166    /// # Examples
167    ///
168    /// ```ignore
169    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
170    /// assert!(!t.has_fw_grad());
171    /// ```
172    pub fn has_fw_grad(&self) -> bool {
173        self.fw_grad.is_some()
174    }
175
176    /// Attach a forward-mode tangent to this tensor.
177    ///
178    /// # Examples
179    ///
180    /// ```ignore
181    /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
182    /// let grad = Tensor::<f64>::ones(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
183    /// t.set_fw_grad(grad);
184    /// ```
185    pub fn set_fw_grad(&mut self, grad: Tensor<T>) {
186        self.fw_grad = Some(Box::new(grad));
187    }
188
189    /// Detach and return the forward-mode tangent, leaving `None`.
190    ///
191    /// # Examples
192    ///
193    /// ```ignore
194    /// let mut t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
195    /// t.set_fw_grad(Tensor::<f64>::ones(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap());
196    /// let _grad = t.detach_fw_grad().unwrap();
197    /// ```
198    pub fn detach_fw_grad(&mut self) -> Option<Tensor<T>> {
199        self.fw_grad.take().map(|boxed| *boxed)
200    }
201
202    /// Return the effective compute devices for a given operation kind.
203    ///
204    /// # Examples
205    ///
206    /// ```ignore
207    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
208    /// let _devices = t.effective_compute_devices(OpKind::BatchedGemm).unwrap();
209    /// ```
210    pub fn effective_compute_devices(&self, op_kind: OpKind) -> Result<Vec<ComputeDevice>> {
211        let compatible = preferred_compute_devices(self.logical_memory_space, op_kind)?;
212        if let Some(device) = self.preferred_compute_device {
213            if compatible.contains(&device) {
214                Ok(vec![device])
215            } else {
216                Err(Error::NoCompatibleComputeDevice {
217                    space: self.logical_memory_space,
218                    op: op_kind,
219                })
220            }
221        } else {
222            Ok(compatible)
223        }
224    }
225}