tenferro_tensor/tensor/views/basic.rs
use super::*;
use tenferro_algebra::Conjugate;
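
/// Build the permutation that swaps the last two axes of an `ndim`-dimensional
/// tensor, leaving any leading (batch) axes in place.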
fn matrix_transpose_permutation(ndim: usize) -> Result<Vec<usize>> {
    if ndim < 2 {
        return Err(Error::InvalidArgument(
            "mT requires at least 2 dimensions".into(),
        ));
    }

    let mut perm: Vec<usize> = (0..ndim).collect();
    perm.swap(ndim - 2, ndim - 1);
    Ok(perm)
}

impl<T> Tensor<T> {
    /// Permute (reorder) the dimensions of the tensor.
    ///
    /// This is a zero-copy view operation.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let transposed = t.permute(&[1, 0]).unwrap();
    /// assert_eq!(transposed.dims(), &[3, 2]);
    /// ```
    pub fn permute(&self, perm: &[usize]) -> Result<Tensor<T>> {
        self.wait();
        if perm.len() != self.ndim() {
            return Err(Error::InvalidArgument(format!(
                "permutation length {} doesn't match ndim {}",
                perm.len(),
                self.ndim()
            )));
        }

        let mut seen = vec![false; self.ndim()];
        for &axis in perm {
            if axis >= self.ndim() {
                return Err(Error::InvalidArgument(format!(
                    "permutation index {axis} out of range for ndim {}",
                    self.ndim()
                )));
            }
            if seen[axis] {
                return Err(Error::InvalidArgument(format!(
                    "duplicate index {axis} in permutation"
                )));
            }
            seen[axis] = true;
        }

        let new_dims: Arc<[usize]> = perm.iter().map(|&axis| self.dims[axis]).collect();
        let new_strides: Arc<[isize]> = perm.iter().map(|&axis| self.strides[axis]).collect();
        Ok(self.shared_view_with(new_dims, new_strides, self.offset))
    }

    /// Broadcast the tensor to a larger shape.
    ///
    /// This is a zero-copy view operation: size-1 axes are expanded to the
    /// target size without copying any data.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[1, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
    /// let b = t.broadcast(&[4, 3]).unwrap();
    /// assert_eq!(b.dims(), &[4, 3]);
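    ///
    /// // The expanded axis is realized with stride 0, so all four rows alias
    /// // the same three underlying elements.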
    /// ```
    pub fn broadcast(&self, target_dims: &[usize]) -> Result<Tensor<T>> {
        self.wait();
        if target_dims.len() != self.ndim() {
            return Err(Error::InvalidArgument(format!(
                "target dims length {} doesn't match ndim {}",
                target_dims.len(),
                self.ndim()
            )));
        }

        let mut new_strides = self.strides.to_vec();
        for (axis, (&current, &target)) in self.dims.iter().zip(target_dims).enumerate() {
            if current == target {
                continue;
            }
            if current == 1 {
                new_strides[axis] = 0;
            } else {
                return Err(Error::ShapeMismatch {
                    expected: self.dims.to_vec(),
                    got: target_dims.to_vec(),
                });
            }
        }

        Ok(self.shared_view_with(Arc::from(target_dims), Arc::from(new_strides), self.offset))
    }

    /// Extract a diagonal view by merging pairs of axes.
    ///
    /// Each pair `(i, j)` of equal-sized axes is merged into a single axis that
    /// walks their shared diagonal. Unpaired axes keep their sizes and appear
    /// first in the result, followed by the merged diagonal axes. This is a
    /// zero-copy view operation.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[3, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let d = t.diagonal(&[(0, 1)]).unwrap();
    /// assert_eq!(d.dims(), &[3]);
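    ///
    /// // Several pairs can be merged at once; a [2, 3, 2, 3] tensor collapses
    /// // to its [2, 3] double diagonal.
    /// let t4 = Tensor::<f64>::zeros(&[2, 3, 2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let d2 = t4.diagonal(&[(0, 2), (1, 3)]).unwrap();
    /// assert_eq!(d2.dims(), &[2, 3]);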
    /// ```
    pub fn diagonal(&self, axes: &[(usize, usize)]) -> Result<Tensor<T>> {
        self.wait();
        let mut used = vec![false; self.ndim()];
        let mut diag_dims = Vec::new();
        let mut diag_strides = Vec::new();

        for &(i, j) in axes {
            if i >= self.ndim() || j >= self.ndim() {
                return Err(Error::InvalidArgument(format!(
                    "axis out of range: ({i}, {j}) for tensor with {} dimensions",
                    self.ndim()
                )));
            }
            if i == j {
                return Err(Error::InvalidArgument(format!(
                    "diagonal axes must be distinct, got ({i}, {j})"
                )));
            }
            if used[i] || used[j] {
                return Err(Error::InvalidArgument(format!(
                    "axis {i} or {j} used in multiple diagonal pairs"
                )));
            }
            if self.dims[i] != self.dims[j] {
                return Err(Error::ShapeMismatch {
                    expected: vec![self.dims[i]],
                    got: vec![self.dims[j]],
                });
            }
            used[i] = true;
            used[j] = true;
            diag_dims.push(self.dims[i]);
            let stride = self.strides[i]
                .checked_add(self.strides[j])
                .ok_or_else(|| {
                    Error::InvalidArgument(format!(
                        "diagonal stride overflow for axes ({i}, {j}) with strides {} and {}",
                        self.strides[i], self.strides[j]
                    ))
                })?;
            diag_strides.push(stride);
        }

        let mut new_dims = Vec::new();
        let mut new_strides = Vec::new();
        for (axis, was_used) in used.iter().enumerate() {
            if !was_used {
                new_dims.push(self.dims[axis]);
                new_strides.push(self.strides[axis]);
            }
        }
        new_dims.extend_from_slice(&diag_dims);
        new_strides.extend_from_slice(&diag_strides);

        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
    }

    /// Return a zero-copy view with a different shape.
    ///
    /// This is the strict, metadata-only variant of reshape. The returned tensor
    /// shares storage with `self` and therefore requires the input layout to be
    /// contiguous (column-major). For PyTorch-style view-or-copy semantics that
    /// handle non-contiguous inputs, use [`reshape`](Self::reshape) instead.
    ///
    /// # Errors
    ///
    /// Returns `ShapeMismatch` if the element counts differ, and `StrideError`
    /// if the tensor is not contiguous.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let r = t.view(&[6]).unwrap();
    /// assert_eq!(r.dims(), &[6]);
    /// ```
    pub fn view(&self, new_dims: &[usize]) -> Result<Tensor<T>> {
        self.wait();
        if self.len() != new_dims.iter().product::<usize>() {
            return Err(Error::ShapeMismatch {
                expected: self.dims.to_vec(),
                got: new_dims.to_vec(),
            });
        }
        if !self.is_contiguous() {
            return Err(Error::StrideError(format!(
                "view requires contiguous data (use reshape for view-or-copy semantics): \
                 current strides={:?}, expected contiguous for shape {:?}",
                self.strides.as_ref(),
                self.dims.as_ref()
            )));
        }

        let new_strides = Arc::from(compute_contiguous_strides(
            new_dims,
            crate::MemoryOrder::ColumnMajor,
        ));
        Ok(self.shared_view_with(Arc::from(new_dims), new_strides, self.offset))
    }

    /// Reshape the tensor to a new shape.
    ///
    /// Reshape follows tenferro's internal column-major semantics and PyTorch-style
    /// view-or-copy behavior: it returns a zero-copy view when the current layout
    /// is compatible with column-major ordering, and otherwise materializes a
    /// contiguous column-major copy and returns a view of that copy.
    ///
    /// For strict zero-copy semantics that reject non-contiguous inputs, use
    /// [`view`](Self::view) instead.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
    /// let r = t.reshape(&[6]).unwrap();
    /// assert_eq!(r.dims(), &[6]);
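    ///
    /// // The row-major input above is not column-major contiguous, so this
    /// // reshape falls back to the copy path described above; a column-major
    /// // input would have been reshaped without any copy.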
    /// ```
    pub fn reshape(&self, new_dims: &[usize]) -> Result<Tensor<T>>
    where
        T: tenferro_algebra::Scalar,
    {
        if self.len() != new_dims.iter().product::<usize>() {
            return Err(Error::ShapeMismatch {
                expected: self.dims.to_vec(),
                got: new_dims.to_vec(),
            });
        }

        match self.view(new_dims) {
            Ok(view) => Ok(view),
            Err(Error::StrideError(_)) => self
                .contiguous(crate::MemoryOrder::ColumnMajor)
                .view(new_dims),
            Err(err) => Err(err),
        }
    }

    /// Create a zero-copy view with explicit dims and strides.
    ///
    /// The requested layout is validated against the underlying buffer, so the
    /// resulting view cannot address memory outside of it.
    ///
    /// # Errors
    ///
    /// Returns an error if the dims/strides/offset combination would reach
    /// outside the underlying buffer.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let view = t.view_as_strided(vec![3, 2], vec![2, 1]).unwrap();
    /// assert_eq!(view.dims(), &[3, 2]);
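    ///
    /// // Strided views may overlap: here a length-6 vector is seen as four
    /// // overlapping windows of length 3.
    /// let v = Tensor::<f64>::zeros(&[6], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let windows = v.view_as_strided(vec![4, 3], vec![1, 1]).unwrap();
    /// assert_eq!(windows.dims(), &[4, 3]);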
    /// ```
    pub fn view_as_strided(
        &self,
        new_dims: Vec<usize>,
        new_strides: Vec<isize>,
    ) -> Result<Tensor<T>> {
        self.wait();
        validate_layout_against_len(&new_dims, &new_strides, self.offset, self.buffer.len())?;
        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
    }

    /// Select a single index along a dimension, removing that dimension.
    ///
    /// This is a zero-copy view operation.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3, 4], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let slice = t.select(2, 1).unwrap();
    /// assert_eq!(slice.dims(), &[2, 3]);
    /// ```
    pub fn select(&self, dim: usize, index: usize) -> Result<Tensor<T>> {
        self.wait();
        if dim >= self.ndim() {
            return Err(Error::InvalidArgument(format!(
                "dim {dim} out of range for tensor with {} dimensions",
                self.ndim()
            )));
        }
        if index >= self.dims[dim] {
            return Err(Error::InvalidArgument(format!(
                "index {index} out of range for dimension {dim} with size {}",
                self.dims[dim]
            )));
        }

        let offset = (index as isize)
            .checked_mul(self.strides[dim])
            .and_then(|delta| self.offset.checked_add(delta))
            .ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "select offset overflow for index {index} in dimension {dim}"
                ))
            })?;
        let mut new_dims = self.dims.to_vec();
        let mut new_strides = self.strides.to_vec();
        new_dims.remove(dim);
        new_strides.remove(dim);
        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), offset))
    }

    /// Narrow (slice) a dimension to a sub-range.
    ///
    /// This is a zero-copy view operation: the narrowed dimension keeps its
    /// position but takes the requested length.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 10], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let sub = t.narrow(1, 2, 3).unwrap();
    /// assert_eq!(sub.dims(), &[2, 3]);
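    ///
    /// // Views compose: narrowing and then selecting within the narrowed
    /// // range drops the sliced dimension entirely.
    /// let first = t.narrow(1, 2, 3).unwrap().select(1, 0).unwrap();
    /// assert_eq!(first.dims(), &[2]);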
    /// ```
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Tensor<T>> {
        self.wait();
        if dim >= self.ndim() {
            return Err(Error::InvalidArgument(format!(
                "dim {dim} out of range for tensor with {} dimensions",
                self.ndim()
            )));
        }
        if start
            .checked_add(length)
            .is_none_or(|end| end > self.dims[dim])
        {
            return Err(Error::InvalidArgument(format!(
                "narrow range out of bounds for dimension {dim} with size {}",
                self.dims[dim]
            )));
        }

        let offset = (start as isize)
            .checked_mul(self.strides[dim])
            .and_then(|delta| self.offset.checked_add(delta))
            .ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "narrow offset overflow for start {start} in dimension {dim}"
                ))
            })?;
        let mut new_dims = self.dims.to_vec();
        new_dims[dim] = length;
        Ok(self.shared_view_with(Arc::from(new_dims), self.strides.clone(), offset))
    }

    /// Insert a size-1 dimension at the specified position.
    ///
    /// This is a zero-copy view operation. Negative dimensions are supported
    /// and count from the end.
    ///
    /// # Arguments
    ///
    /// * `dim` - Position to insert the new dimension. Must be in range `[-ndim-1, ndim]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the dimension is out of range.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let u = t.unsqueeze(0).unwrap();
    /// assert_eq!(u.dims(), &[1, 2, 3]);
    ///
    /// let u2 = t.unsqueeze(-1).unwrap();
    /// assert_eq!(u2.dims(), &[2, 3, 1]);
    /// ```
    pub fn unsqueeze(&self, dim: isize) -> Result<Tensor<T>> {
        self.wait();
        let ndim = self.ndim();

        let dim = if dim < 0 {
            let wrapped = dim + (ndim as isize) + 1;
            if wrapped < 0 {
                return Err(Error::InvalidArgument(format!(
                    "unsqueeze dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
                    -(ndim as isize) - 1,
                    ndim
                )));
            }
            wrapped as usize
        } else if dim as usize > ndim {
            return Err(Error::InvalidArgument(format!(
                "unsqueeze dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
                -(ndim as isize) - 1,
                ndim
            )));
        } else {
            dim as usize
        };

        let mut new_dims: Vec<usize> = self.dims.to_vec();
        new_dims.insert(dim, 1);

        let mut new_strides: Vec<isize> = self.strides.to_vec();
        let new_stride = if dim < ndim {
            self.strides[dim]
        } else if ndim > 0 {
            self.strides[ndim - 1]
        } else {
            1
        };
        new_strides.insert(dim, new_stride);

        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
    }

    /// Remove all size-1 dimensions from the tensor.
    ///
    /// This is a zero-copy view operation.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[1, 2, 1, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let s = t.squeeze().unwrap();
    /// assert_eq!(s.dims(), &[2, 3]);
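    ///
    /// // If every axis has size 1, the result is a rank-0 (scalar) view.
    /// let one = Tensor::<f64>::zeros(&[1, 1], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// assert_eq!(one.squeeze().unwrap().ndim(), 0);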
    /// ```
    pub fn squeeze(&self) -> Result<Tensor<T>> {
        self.wait();
        let new_dims: Vec<usize> = self.dims.iter().filter(|&&d| d != 1).copied().collect();
        let new_strides: Vec<isize> = self
            .dims
            .iter()
            .zip(self.strides.iter())
            .filter(|(&d, _)| d != 1)
            .map(|(_, &s)| s)
            .collect();

        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
    }

    /// Remove a specific size-1 dimension from the tensor.
    ///
    /// This is a zero-copy view operation. Negative dimensions are supported
    /// and count from the end.
    ///
    /// # Arguments
    ///
    /// * `dim` - Dimension to remove. Must be in range `[-ndim, ndim-1]` and have size 1.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The dimension is out of range
    /// - The dimension does not have size 1
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 1, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let s = t.squeeze_dim(1).unwrap();
    /// assert_eq!(s.dims(), &[2, 3]);
    ///
    /// let s2 = t.squeeze_dim(-2).unwrap();
    /// assert_eq!(s2.dims(), &[2, 3]);
    /// ```
    pub fn squeeze_dim(&self, dim: isize) -> Result<Tensor<T>> {
        self.wait();
        let ndim = self.ndim();

        if ndim == 0 {
            return Err(Error::InvalidArgument(
                "squeeze_dim: cannot squeeze a rank-0 tensor".to_string(),
            ));
        }

        let dim = if dim < 0 {
            let wrapped = dim + (ndim as isize);
            if wrapped < 0 {
                return Err(Error::InvalidArgument(format!(
                    "squeeze_dim dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
                    -(ndim as isize),
                    ndim - 1
                )));
            }
            wrapped as usize
        } else if dim as usize >= ndim {
            return Err(Error::InvalidArgument(format!(
                "squeeze_dim dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
                -(ndim as isize),
                ndim - 1
            )));
        } else {
            dim as usize
        };

        if self.dims[dim] != 1 {
            return Err(Error::InvalidArgument(format!(
                "squeeze_dim: dimension {dim} has size {} (expected 1)",
                self.dims[dim]
            )));
        }

        let mut new_dims = self.dims.to_vec();
        new_dims.remove(dim);

        let mut new_strides = self.strides.to_vec();
        new_strides.remove(dim);

        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
    }

    /// Return a zero-copy view with the last two axes transposed.
    ///
    /// This is a metadata-only operation. For batched matrices, leading batch
    /// axes are preserved and only the final two matrix axes are swapped.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let t = Tensor::<f64>::from_slice(
    ///     &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    ///     &[2, 3],
    ///     MemoryOrder::ColumnMajor,
    /// )
    /// .unwrap();
    /// let mt = t.mT().unwrap();
    /// assert_eq!(mt.dims(), &[3, 2]);
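    ///
    /// // Leading batch axes are untouched; only the trailing pair is swapped.
    /// let batched = Tensor::<f64>::zeros(&[4, 2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// assert_eq!(batched.mT().unwrap().dims(), &[4, 3, 2]);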
    /// ```
    #[allow(non_snake_case)]
    pub fn mT(&self) -> Result<Tensor<T>> {
        self.permute(&matrix_transpose_permutation(self.ndim())?)
    }
}

impl<T> Tensor<T>
where
    T: Conjugate,
{
    /// Return a zero-copy conjugate-transpose view over the last two axes.
    ///
    /// This is equivalent to `self.mT()?.conj()`: swap the trailing matrix axes
    /// and toggle the lazy conjugation flag.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use num_complex::Complex64;
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let z = Tensor::<Complex64>::from_slice(
    ///     &[Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
    ///     &[2, 1],
    ///     MemoryOrder::ColumnMajor,
    /// )
    /// .unwrap();
    /// let mh = z.mH().unwrap();
    /// assert_eq!(mh.dims(), &[1, 2]);
    /// assert!(mh.is_conjugated());
    /// ```
    #[allow(non_snake_case)]
    pub fn mH(&self) -> Result<Tensor<T>> {
        Ok(self.mT()?.conj())
    }
}