tenferro_tensor/tensor/views/
basic.rs

1use super::*;
2use tenferro_algebra::Conjugate;
3
4fn matrix_transpose_permutation(ndim: usize) -> Result<Vec<usize>> {
5    if ndim < 2 {
6        return Err(Error::InvalidArgument(
7            "mT requires at least 2 dimensions".into(),
8        ));
9    }
10
11    let mut perm: Vec<usize> = (0..ndim).collect();
12    perm.swap(ndim - 2, ndim - 1);
13    Ok(perm)
14}
15
16impl<T> Tensor<T> {
17    /// Permute (reorder) the dimensions of the tensor.
18    ///
19    /// # Examples
20    ///
21    /// ```ignore
22    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
23    /// let transposed = t.permute(&[1, 0]).unwrap();
24    /// assert_eq!(transposed.dims(), &[3, 2]);
25    /// ```
26    pub fn permute(&self, perm: &[usize]) -> Result<Tensor<T>> {
27        self.wait();
28        if perm.len() != self.ndim() {
29            return Err(Error::InvalidArgument(format!(
30                "permutation length {} doesn't match ndim {}",
31                perm.len(),
32                self.ndim()
33            )));
34        }
35
36        let mut seen = vec![false; self.ndim()];
37        for &axis in perm {
38            if axis >= self.ndim() {
39                return Err(Error::InvalidArgument(format!(
40                    "permutation index {axis} out of range for ndim {}",
41                    self.ndim()
42                )));
43            }
44            if seen[axis] {
45                return Err(Error::InvalidArgument(format!(
46                    "duplicate index {axis} in permutation"
47                )));
48            }
49            seen[axis] = true;
50        }
51
52        let new_dims: Arc<[usize]> = perm.iter().map(|&axis| self.dims[axis]).collect();
53        let new_strides: Arc<[isize]> = perm.iter().map(|&axis| self.strides[axis]).collect();
54        Ok(self.shared_view_with(new_dims, new_strides, self.offset))
55    }
56
57    /// Broadcast the tensor to a larger shape.
58    ///
59    /// # Examples
60    ///
61    /// ```ignore
62    /// let t = Tensor::<f64>::zeros(&[1, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
63    /// let b = t.broadcast(&[4, 3]).unwrap();
64    /// assert_eq!(b.dims(), &[4, 3]);
65    /// ```
66    pub fn broadcast(&self, target_dims: &[usize]) -> Result<Tensor<T>> {
67        self.wait();
68        if target_dims.len() != self.ndim() {
69            return Err(Error::InvalidArgument(format!(
70                "target dims length {} doesn't match ndim {}",
71                target_dims.len(),
72                self.ndim()
73            )));
74        }
75
76        let mut new_strides = self.strides.to_vec();
77        for (axis, (&current, &target)) in self.dims.iter().zip(target_dims).enumerate() {
78            if current == target {
79                continue;
80            }
81            if current == 1 {
82                new_strides[axis] = 0;
83            } else {
84                return Err(Error::ShapeMismatch {
85                    expected: self.dims.to_vec(),
86                    got: target_dims.to_vec(),
87                });
88            }
89        }
90
91        Ok(self.shared_view_with(Arc::from(target_dims), Arc::from(new_strides), self.offset))
92    }
93
94    /// Extract a diagonal view by merging pairs of axes.
95    ///
96    /// # Examples
97    ///
98    /// ```ignore
99    /// let t = Tensor::<f64>::zeros(&[3, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
100    /// let d = t.diagonal(&[(0, 1)]).unwrap();
101    /// assert_eq!(d.dims(), &[3]);
102    /// ```
103    pub fn diagonal(&self, axes: &[(usize, usize)]) -> Result<Tensor<T>> {
104        self.wait();
105        let mut used = vec![false; self.ndim()];
106        let mut diag_dims = Vec::new();
107        let mut diag_strides = Vec::new();
108
109        for &(i, j) in axes {
110            if i >= self.ndim() || j >= self.ndim() {
111                return Err(Error::InvalidArgument(format!(
112                    "axis out of range: ({i}, {j}) for tensor with {} dimensions",
113                    self.ndim()
114                )));
115            }
116            if i == j {
117                return Err(Error::InvalidArgument(format!(
118                    "diagonal axes must be distinct, got ({i}, {j})"
119                )));
120            }
121            if used[i] || used[j] {
122                return Err(Error::InvalidArgument(format!(
123                    "axis {i} or {j} used in multiple diagonal pairs"
124                )));
125            }
126            if self.dims[i] != self.dims[j] {
127                return Err(Error::ShapeMismatch {
128                    expected: vec![self.dims[i]],
129                    got: vec![self.dims[j]],
130                });
131            }
132            used[i] = true;
133            used[j] = true;
134            diag_dims.push(self.dims[i]);
135            let stride = self.strides[i]
136                .checked_add(self.strides[j])
137                .ok_or_else(|| {
138                    Error::InvalidArgument(format!(
139                        "diagonal stride overflow for axes ({i}, {j}) with strides {} and {}",
140                        self.strides[i], self.strides[j]
141                    ))
142                })?;
143            diag_strides.push(stride);
144        }
145
146        let mut new_dims = Vec::new();
147        let mut new_strides = Vec::new();
148        for (axis, was_used) in used.iter().enumerate() {
149            if !was_used {
150                new_dims.push(self.dims[axis]);
151                new_strides.push(self.strides[axis]);
152            }
153        }
154        new_dims.extend_from_slice(&diag_dims);
155        new_strides.extend_from_slice(&diag_strides);
156
157        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
158    }
159
160    /// Return a zero-copy view with a different shape.
161    ///
162    /// This is the strict metadata-only variant of reshape. The returned tensor
163    /// shares storage with `self` and therefore requires the input layout to be
164    /// contiguous (column-major). For PyTorch-style view-or-copy semantics that
165    /// handle non-contiguous inputs, use [`reshape`](Self::reshape) instead.
166    ///
167    /// # Errors
168    ///
169    /// Returns `StrideError` if the tensor is not contiguous.
170    ///
171    /// # Examples
172    ///
173    /// ```ignore
174    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
175    /// let r = t.view(&[6]).unwrap();
176    /// assert_eq!(r.dims(), &[6]);
177    /// ```
178    pub fn view(&self, new_dims: &[usize]) -> Result<Tensor<T>> {
179        self.wait();
180        if self.len() != new_dims.iter().product::<usize>() {
181            return Err(Error::ShapeMismatch {
182                expected: self.dims.to_vec(),
183                got: new_dims.to_vec(),
184            });
185        }
186        if !self.is_contiguous() {
187            return Err(Error::StrideError(format!(
188                "view requires contiguous data (use reshape for view-or-copy semantics): \
189                 current strides={:?}, expected contiguous for shape {:?}",
190                self.strides.as_ref(),
191                self.dims.as_ref()
192            )));
193        }
194
195        let new_strides = Arc::from(compute_contiguous_strides(
196            new_dims,
197            crate::MemoryOrder::ColumnMajor,
198        ));
199        Ok(self.shared_view_with(Arc::from(new_dims), new_strides, self.offset))
200    }
201
202    /// Reshape the tensor to a new shape.
203    ///
204    /// Reshape follows tenferro's internal column-major semantics and PyTorch-style
205    /// view-or-copy behavior: it returns a zero-copy view when the current layout
206    /// is compatible with column-major ordering, and otherwise materializes a
207    /// contiguous column-major copy first before returning the view.
208    ///
209    /// For strict zero-copy semantics that reject non-contiguous inputs, use
210    /// [`view`](Self::view) instead.
211    ///
212    /// # Examples
213    ///
214    /// ```ignore
215    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
216    /// let r = t.reshape(&[6]).unwrap();
217    /// assert_eq!(r.dims(), &[6]);
218    /// ```
219    pub fn reshape(&self, new_dims: &[usize]) -> Result<Tensor<T>>
220    where
221        T: tenferro_algebra::Scalar,
222    {
223        if self.len() != new_dims.iter().product::<usize>() {
224            return Err(Error::ShapeMismatch {
225                expected: self.dims.to_vec(),
226                got: new_dims.to_vec(),
227            });
228        }
229
230        match self.view(new_dims) {
231            Ok(view) => Ok(view),
232            Err(Error::StrideError(_)) => self
233                .contiguous(crate::MemoryOrder::ColumnMajor)
234                .view(new_dims),
235            Err(err) => Err(err),
236        }
237    }
238
239    /// Create a zero-copy view with explicit dims and strides.
240    ///
241    /// # Examples
242    ///
243    /// ```ignore
244    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
245    /// let view = t.view_as_strided(vec![3, 2], vec![2, 1]).unwrap();
246    /// assert_eq!(view.dims(), &[3, 2]);
247    /// ```
248    pub fn view_as_strided(
249        &self,
250        new_dims: Vec<usize>,
251        new_strides: Vec<isize>,
252    ) -> Result<Tensor<T>> {
253        self.wait();
254        validate_layout_against_len(&new_dims, &new_strides, self.offset, self.buffer.len())?;
255        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
256    }
257
258    /// Select a single index along a dimension, removing that dimension.
259    ///
260    /// # Examples
261    ///
262    /// ```ignore
263    /// let t = Tensor::<f64>::zeros(&[2, 3, 4], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
264    /// let slice = t.select(2, 1).unwrap();
265    /// assert_eq!(slice.dims(), &[2, 3]);
266    /// ```
267    pub fn select(&self, dim: usize, index: usize) -> Result<Tensor<T>> {
268        self.wait();
269        if dim >= self.ndim() {
270            return Err(Error::InvalidArgument(format!(
271                "dim {dim} out of range for tensor with {} dimensions",
272                self.ndim()
273            )));
274        }
275        if index >= self.dims[dim] {
276            return Err(Error::InvalidArgument(format!(
277                "index {index} out of range for dimension {dim} with size {}",
278                self.dims[dim]
279            )));
280        }
281
282        let offset = (index as isize)
283            .checked_mul(self.strides[dim])
284            .and_then(|delta| self.offset.checked_add(delta))
285            .ok_or_else(|| {
286                Error::InvalidArgument(format!(
287                    "select offset overflow for index {index} in dimension {dim}"
288                ))
289            })?;
290        let mut new_dims = self.dims.to_vec();
291        let mut new_strides = self.strides.to_vec();
292        new_dims.remove(dim);
293        new_strides.remove(dim);
294        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), offset))
295    }
296
297    /// Narrow (slice) a dimension to a sub-range.
298    ///
299    /// # Examples
300    ///
301    /// ```ignore
302    /// let t = Tensor::<f64>::zeros(&[2, 10], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
303    /// let sub = t.narrow(1, 2, 3).unwrap();
304    /// assert_eq!(sub.dims(), &[2, 3]);
305    /// ```
306    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Tensor<T>> {
307        self.wait();
308        if dim >= self.ndim() {
309            return Err(Error::InvalidArgument(format!(
310                "dim {dim} out of range for tensor with {} dimensions",
311                self.ndim()
312            )));
313        }
314        if start
315            .checked_add(length)
316            .is_none_or(|end| end > self.dims[dim])
317        {
318            return Err(Error::InvalidArgument(format!(
319                "narrow range out of bounds for dimension {dim} with size {}",
320                self.dims[dim]
321            )));
322        }
323
324        let offset = (start as isize)
325            .checked_mul(self.strides[dim])
326            .and_then(|delta| self.offset.checked_add(delta))
327            .ok_or_else(|| {
328                Error::InvalidArgument(format!(
329                    "narrow offset overflow for start {start} in dimension {dim}"
330                ))
331            })?;
332        let mut new_dims = self.dims.to_vec();
333        new_dims[dim] = length;
334        Ok(self.shared_view_with(Arc::from(new_dims), self.strides.clone(), offset))
335    }
336
337    /// Insert a size-1 dimension at the specified position.
338    ///
339    /// This is a zero-copy view operation. Negative dimensions are supported
340    /// and count from the end.
341    ///
342    /// # Arguments
343    ///
344    /// * `dim` - Position to insert the new dimension. Must be in range `[-ndim-1, ndim]`.
345    ///
346    /// # Errors
347    ///
348    /// Returns an error if the dimension is out of range.
349    ///
350    /// # Examples
351    ///
352    /// ```ignore
353    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
354    /// let u = t.unsqueeze(0).unwrap();
355    /// assert_eq!(u.dims(), &[1, 2, 3]);
356    ///
357    /// let u2 = t.unsqueeze(-1).unwrap();
358    /// assert_eq!(u2.dims(), &[2, 3, 1]);
359    /// ```
360    pub fn unsqueeze(&self, dim: isize) -> Result<Tensor<T>> {
361        self.wait();
362        let ndim = self.ndim();
363
364        let dim = if dim < 0 {
365            let wrapped = dim + (ndim as isize) + 1;
366            if wrapped < 0 {
367                return Err(Error::InvalidArgument(format!(
368                    "unsqueeze dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
369                    -(ndim as isize) - 1,
370                    ndim
371                )));
372            }
373            wrapped as usize
374        } else if dim as usize > ndim {
375            return Err(Error::InvalidArgument(format!(
376                "unsqueeze dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
377                -(ndim as isize) - 1,
378                ndim
379            )));
380        } else {
381            dim as usize
382        };
383
384        let mut new_dims: Vec<usize> = self.dims.to_vec();
385        new_dims.insert(dim, 1);
386
387        let mut new_strides: Vec<isize> = self.strides.to_vec();
388        let new_stride = if dim < ndim {
389            self.strides[dim]
390        } else {
391            if ndim > 0 {
392                self.strides[ndim - 1]
393            } else {
394                1
395            }
396        };
397        new_strides.insert(dim, new_stride);
398
399        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
400    }
401
402    /// Remove all size-1 dimensions from the tensor.
403    ///
404    /// This is a zero-copy view operation.
405    ///
406    /// # Examples
407    ///
408    /// ```ignore
409    /// let t = Tensor::<f64>::zeros(&[1, 2, 1, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
410    /// let s = t.squeeze().unwrap();
411    /// assert_eq!(s.dims(), &[2, 3]);
412    /// ```
413    pub fn squeeze(&self) -> Result<Tensor<T>> {
414        self.wait();
415        let new_dims: Vec<usize> = self.dims.iter().filter(|&&d| d != 1).copied().collect();
416        let new_strides: Vec<isize> = self
417            .dims
418            .iter()
419            .zip(self.strides.iter())
420            .filter(|(&d, _)| d != 1)
421            .map(|(_, &s)| s)
422            .collect();
423
424        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
425    }
426
427    /// Remove a specific size-1 dimension from the tensor.
428    ///
429    /// This is a zero-copy view operation. Negative dimensions are supported
430    /// and count from the end.
431    ///
432    /// # Arguments
433    ///
434    /// * `dim` - Dimension to remove. Must be in range `[-ndim, ndim-1]` and have size 1.
435    ///
436    /// # Errors
437    ///
438    /// Returns an error if:
439    /// - The dimension is out of range
440    /// - The dimension does not have size 1
441    ///
442    /// # Examples
443    ///
444    /// ```ignore
445    /// let t = Tensor::<f64>::zeros(&[2, 1, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
446    /// let s = t.squeeze_dim(1).unwrap();
447    /// assert_eq!(s.dims(), &[2, 3]);
448    ///
449    /// let s2 = t.squeeze_dim(-2).unwrap();
450    /// assert_eq!(s2.dims(), &[2, 3]);
451    /// ```
452    pub fn squeeze_dim(&self, dim: isize) -> Result<Tensor<T>> {
453        self.wait();
454        let ndim = self.ndim();
455
456        if ndim == 0 {
457            return Err(Error::InvalidArgument(
458                "squeeze_dim: cannot squeeze a rank-0 tensor".to_string(),
459            ));
460        }
461
462        let dim = if dim < 0 {
463            let wrapped = dim + (ndim as isize);
464            if wrapped < 0 {
465                return Err(Error::InvalidArgument(format!(
466                    "squeeze_dim dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
467                    -(ndim as isize),
468                    ndim - 1
469                )));
470            }
471            wrapped as usize
472        } else if dim as usize >= ndim {
473            return Err(Error::InvalidArgument(format!(
474                "squeeze_dim dim {dim} out of range for tensor with {ndim} dimensions (valid: [{}, {}])",
475                -(ndim as isize),
476                ndim - 1
477            )));
478        } else {
479            dim as usize
480        };
481
482        if self.dims[dim] != 1 {
483            return Err(Error::InvalidArgument(format!(
484                "squeeze_dim: dimension {dim} has size {} (expected 1)",
485                self.dims[dim]
486            )));
487        }
488
489        let mut new_dims = self.dims.to_vec();
490        new_dims.remove(dim);
491
492        let mut new_strides = self.strides.to_vec();
493        new_strides.remove(dim);
494
495        Ok(self.shared_view_with(Arc::from(new_dims), Arc::from(new_strides), self.offset))
496    }
497
498    /// Return a zero-copy view with the last two axes transposed.
499    ///
500    /// This is a metadata-only operation. For batched matrices, leading batch
501    /// axes are preserved and only the final two matrix axes are swapped.
502    ///
503    /// # Examples
504    ///
505    /// ```ignore
506    /// use tenferro_tensor::{MemoryOrder, Tensor};
507    ///
508    /// let t = Tensor::<f64>::from_slice(
509    ///     &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
510    ///     &[2, 3],
511    ///     MemoryOrder::ColumnMajor,
512    /// )
513    /// .unwrap();
514    /// let mt = t.mT().unwrap();
515    /// assert_eq!(mt.dims(), &[3, 2]);
516    /// ```
517    #[allow(non_snake_case)]
518    pub fn mT(&self) -> Result<Tensor<T>> {
519        self.permute(&matrix_transpose_permutation(self.ndim())?)
520    }
521}
522
523impl<T> Tensor<T>
524where
525    T: Conjugate,
526{
527    /// Return a zero-copy conjugate-transpose view over the last two axes.
528    ///
529    /// This is equivalent to `self.mT()?.conj()`: swap the trailing matrix axes
530    /// and toggle the lazy conjugation flag.
531    ///
532    /// # Examples
533    ///
534    /// ```ignore
535    /// use num_complex::Complex64;
536    /// use tenferro_tensor::{MemoryOrder, Tensor};
537    ///
538    /// let z = Tensor::<Complex64>::from_slice(
539    ///     &[Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)],
540    ///     &[2, 1],
541    ///     MemoryOrder::ColumnMajor,
542    /// )
543    /// .unwrap();
544    /// let mh = z.mH().unwrap();
545    /// assert_eq!(mh.dims(), &[1, 2]);
546    /// assert!(mh.is_conjugated());
547    /// ```
548    #[allow(non_snake_case)]
549    pub fn mH(&self) -> Result<Tensor<T>> {
550        Ok(self.mT()?.conj())
551    }
552}