// tenferro_tensor/tensor/data_ops.rs

use std::any::TypeId;

use num_complex::{Complex32, Complex64};
use tenferro_algebra::Scalar;
#[cfg(feature = "cuda")]
use tenferro_device::LogicalMemorySpace;
use tenferro_device::{checked_batch_count, unflatten_col_major_index_into};

use super::{Tensor, TensorParts};
use crate::layout::{compute_contiguous_strides, copy_strided, is_contiguous_in_order};
use crate::{DataBuffer, MemoryOrder};
12
/// Which half of a matrix a triangular extraction keeps.
///
/// Internal selector shared by [`Tensor::tril`] and [`Tensor::triu`] via
/// `triangular_part`.
enum TriangularHalf {
    // Keep entries on or below the selected diagonal (`tril`).
    Lower,
    // Keep entries on or above the selected diagonal (`triu`).
    Upper,
}
17
impl<T> Tensor<T> {
    /// Return a lazily-conjugated tensor (shared buffer, flag flip).
    ///
    /// No element data is touched: the result shares this tensor's buffer
    /// (`Arc` clone) and only toggles the `conjugated` flag, to be resolved
    /// later during materialization.
    ///
    /// NOTE(review): the forward-mode gradient is dropped (`fw_grad: None`)
    /// instead of being conjugated alongside the primal — confirm this is
    /// intended.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use num_complex::Complex64;
    ///
    /// let data = vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -4.0)];
    /// let a = Tensor::from_slice(&data, &[2], MemoryOrder::ColumnMajor).unwrap();
    /// let a_conj = a.conj();
    /// assert!(a_conj.is_conjugated());
    /// ```
    pub fn conj(&self) -> Tensor<T>
    where
        T: tenferro_algebra::Conjugate,
    {
        // Everything except `conjugated` (flipped) and `fw_grad` (cleared)
        // is a shallow copy of this tensor's parts.
        Tensor::from_parts(TensorParts {
            buffer: self.buffer.clone(),
            dims: self.dims.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
            logical_memory_space: self.logical_memory_space,
            preferred_compute_device: self.preferred_compute_device,
            event: self.event.clone(),
            conjugated: !self.conjugated,
            fw_grad: None,
        })
    }

    /// Consume this tensor and return a lazily-conjugated version.
    ///
    /// Same as [`conj`](Tensor::conj) but moves the parts instead of
    /// bumping refcounts. The forward-mode gradient is likewise cleared.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let tc = t.into_conj();
    /// assert!(tc.is_conjugated());
    /// ```
    pub fn into_conj(self) -> Tensor<T>
    where
        T: tenferro_algebra::Conjugate,
    {
        Tensor::from_parts(TensorParts {
            buffer: self.buffer,
            dims: self.dims,
            strides: self.strides,
            offset: self.offset,
            logical_memory_space: self.logical_memory_space,
            preferred_compute_device: self.preferred_compute_device,
            event: self.event,
            conjugated: !self.conjugated,
            fw_grad: None,
        })
    }
}
74
impl<T: Scalar> Tensor<T> {
    /// Create a deep copy with an exclusively-owned contiguous buffer.
    ///
    /// Unlike [`clone`](Clone::clone) (which is a shallow `Arc` refcount
    /// bump), this always allocates a fresh buffer and copies element data.
    /// The returned tensor is contiguous in column-major order and has
    /// `buffer.is_unique() == true`, so [`set`](Tensor::set) and
    /// [`get_mut`](Tensor::get_mut) are guaranteed to succeed.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is not CPU-backed (this method has no GPU path;
    /// see `cpu_backed_slice_or_panic`).
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let a = Tensor::<f64>::from_slice(
    ///     &[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor,
    /// ).unwrap();
    /// let b = a.clone(); // shallow — shares buffer
    ///
    /// let mut c = a.deep_clone(); // deep — independent buffer
    /// c.set(&[0, 0], 99.0).unwrap();
    /// assert_eq!(c.get(&[0, 0]), Some(&99.0));
    /// assert_eq!(a.get(&[0, 0]), Some(&1.0)); // original unchanged
    /// ```
    pub fn deep_clone(&self) -> Tensor<T> {
        // Settle any pending event on this tensor before reading its buffer.
        self.wait();
        let order = MemoryOrder::ColumnMajor;
        let mut data = vec![T::zero(); self.len()];
        if !data.is_empty() {
            // Gather the (possibly strided / offset) source elements into a
            // dense column-major destination buffer.
            let dst_strides = compute_contiguous_strides(&self.dims, order);
            copy_strided(
                self.cpu_backed_slice_or_panic("deep_clone"),
                &self.dims,
                &self.strides,
                self.offset,
                &mut data,
                &dst_strides,
            );
        }
        self.materialized_from_vec(data, order)
    }

    /// Return a contiguous copy of this tensor in the given memory order.
    ///
    /// `order` controls the materialized output buffer only. It does not change
    /// the internal column-major semantics used by view operations such as
    /// [`reshape`](Tensor::reshape).
    ///
    /// When the tensor is already contiguous in `order` with a zero offset,
    /// no copy is made — the parts (buffer, event, `conjugated` flag,
    /// `fw_grad`) are shared shallowly.
    ///
    /// # Panics
    ///
    /// Panics if GPU materialization fails, or if the tensor is not
    /// CPU-backed and the `cuda` feature is disabled.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
    /// let c = t.contiguous(MemoryOrder::RowMajor);
    /// assert!(c.is_contiguous());
    /// ```
    pub fn contiguous(&self, order: MemoryOrder) -> Tensor<T> {
        self.wait();
        // Fast path: already laid out as requested — share everything.
        if is_contiguous_in_order(&self.dims, &self.strides, order) && self.offset == 0 {
            return Tensor::from_parts(TensorParts {
                buffer: self.buffer.clone(),
                dims: self.dims.clone(),
                strides: self.strides.clone(),
                offset: self.offset,
                logical_memory_space: self.logical_memory_space,
                preferred_compute_device: self.preferred_compute_device,
                event: self.event.clone(),
                conjugated: self.conjugated,
                fw_grad: self.fw_grad.clone(),
            });
        }

        // GPU tensors are materialized by the CUDA runtime instead of the
        // CPU strided-copy path below.
        #[cfg(feature = "cuda")]
        if matches!(
            self.logical_memory_space,
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::contiguous_tensor(self, order)
                .unwrap_or_else(|err| panic!("contiguous: GPU materialization failed: {err}"));
        }

        let mut data = vec![T::zero(); self.len()];
        if !data.is_empty() {
            let dst_strides = compute_contiguous_strides(&self.dims, order);
            copy_strided(
                self.cpu_backed_slice_or_panic("contiguous"),
                &self.dims,
                &self.strides,
                self.offset,
                &mut data,
                &dst_strides,
            );
        }
        self.materialized_from_vec(data, order)
    }

    /// Consume this tensor and return a contiguous version.
    ///
    /// The already-contiguous fast path moves the parts without cloning and,
    /// unlike [`contiguous`](Tensor::contiguous), does not call `wait()` —
    /// the pending `event` is carried through unchanged. Otherwise this
    /// delegates to `contiguous` (which waits and copies).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let c = t.into_contiguous(MemoryOrder::ColumnMajor);
    /// assert!(c.is_contiguous());
    /// ```
    pub fn into_contiguous(self, order: MemoryOrder) -> Tensor<T> {
        if is_contiguous_in_order(&self.dims, &self.strides, order) && self.offset == 0 {
            return Tensor::from_parts(TensorParts {
                buffer: self.buffer,
                dims: self.dims,
                strides: self.strides,
                offset: self.offset,
                logical_memory_space: self.logical_memory_space,
                preferred_compute_device: self.preferred_compute_device,
                event: self.event,
                conjugated: self.conjugated,
                fw_grad: self.fw_grad,
            });
        }
        self.contiguous(order)
    }

    /// Consume this tensor and return a contiguous column-major version.
    ///
    /// This is a convenience wrapper around `into_contiguous(MemoryOrder::ColumnMajor)`
    /// since column-major is tenferro's canonical internal layout.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let t = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::RowMajor).unwrap();
    /// let col_major = t.into_column_major();
    /// assert!(col_major.is_col_major_contiguous());
    /// ```
    pub fn into_column_major(self) -> Tensor<T> {
        self.into_contiguous(MemoryOrder::ColumnMajor)
    }

    /// Copy tensor data into a flat `Vec<T>` in column-major order.
    ///
    /// The returned vector has length `self.len()` with elements laid out
    /// in column-major (Fortran) order. For a 2-D tensor with shape
    /// `[m, n]`, the first `m` elements are column 0, the next `m` are
    /// column 1, and so on.
    ///
    /// This method internally materializes a contiguous copy when the
    /// tensor is not already column-major contiguous, so it always
    /// returns owned data regardless of the original layout.
    ///
    /// # Panics
    ///
    /// Panics for GPU tensors: this is a CPU-only operation.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let t = Tensor::<f64>::from_row_major_slice(
    ///     &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    ///     &[2, 3],
    /// ).unwrap();
    /// // Matrix (row-major input):
    /// //   [[1, 2, 3],
    /// //    [4, 5, 6]]
    /// // Column-major output: col0=[1,4], col1=[2,5], col2=[3,6]
    /// assert_eq!(t.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    /// ```
    pub fn to_vec(&self) -> Vec<T> {
        let c = self.contiguous(MemoryOrder::ColumnMajor);
        let slice = c
            .buffer()
            .as_slice()
            .expect("to_vec: CPU-only operation; GPU tensors are not supported");
        slice.to_vec()
    }

    /// Shared implementation of [`tril`](Tensor::tril) / [`triu`](Tensor::triu).
    ///
    /// Treats the first two axes as an `m x n` matrix and every remaining
    /// axis as a batch dimension (iterated in column-major order). The
    /// output is a fresh column-major contiguous tensor (offset 0, no
    /// pending event); the lazy `conjugated` flag is carried through
    /// unchanged. All index arithmetic is checked and panics on overflow.
    fn triangular_part(&self, diagonal: isize, half: TriangularHalf) -> Tensor<T> {
        self.wait();
        // Scalars and vectors have no triangle to mask: return a plain
        // contiguous copy unchanged.
        if self.ndim() <= 1 {
            return self.contiguous(MemoryOrder::ColumnMajor);
        }

        #[cfg(feature = "cuda")]
        if matches!(
            self.logical_memory_space,
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::triangular_part_tensor(
                self,
                diagonal,
                matches!(half, TriangularHalf::Lower),
            )
            .unwrap_or_else(|err| panic!("triangular_part: GPU materialization failed: {err}"));
        }

        // ndim() >= 2 is guaranteed by the early return above, so these
        // indexing operations cannot panic.
        let m = self.dims[0];
        let n = self.dims[1];
        let out_strides = compute_contiguous_strides(&self.dims, MemoryOrder::ColumnMajor);
        // Output starts zero-filled; only the kept triangle is written.
        let mut data = vec![T::zero(); self.len()];
        if data.is_empty() {
            return self.materialized_from_vec(data, MemoryOrder::ColumnMajor);
        }

        let src = self.cpu_backed_slice_or_panic(match half {
            TriangularHalf::Lower => "tril",
            TriangularHalf::Upper => "triu",
        });
        let batch_dims = &self.dims[2..];
        let n_batch = checked_batch_count(batch_dims).unwrap_or_else(|err| {
            panic!(
                "triangular_part: invalid batch dims {:?}: {err}",
                batch_dims
            )
        });
        let mut batch_index = vec![0usize; batch_dims.len()];

        for batch in 0..n_batch {
            // Decode the flat batch counter into per-axis indices
            // (column-major). Skipped when there are no batch axes.
            if !batch_dims.is_empty() {
                unflatten_col_major_index_into(batch, batch_dims, &mut batch_index)
                    .unwrap_or_else(|err| {
                        panic!(
                            "triangular_part: failed to unflatten batch index {batch} for dims {:?}: {err}",
                            batch_dims
                        )
                    });
            }
            // Element offset of this batch slice in the source (uses the
            // source's actual strides for axes 2..).
            let src_batch_off: isize = batch_index
                .iter()
                .enumerate()
                .try_fold(0isize, |acc, (axis, &idx)| {
                    (idx as isize).checked_mul(self.strides[axis + 2]).and_then(|v| acc.checked_add(v))
                })
                .unwrap_or_else(|| {
                    panic!(
                        "triangular_part: source batch offset overflow with batch_index {:?}, strides {:?}",
                        batch_index, self.strides
                    )
                });
            // Same, but with the dense column-major output strides.
            let dst_batch_off: isize = batch_index
                .iter()
                .enumerate()
                .try_fold(0isize, |acc, (axis, &idx)| {
                    (idx as isize).checked_mul(out_strides[axis + 2]).and_then(|v| acc.checked_add(v))
                })
                .unwrap_or_else(|| {
                    panic!(
                        "triangular_part: destination batch offset overflow with batch_index {:?}, strides {:?}",
                        batch_index, out_strides
                    )
                });

            for j in 0..n {
                for i in 0..m {
                    // Lower keeps j - i <= diagonal (on/below the
                    // `diagonal`-th diagonal); Upper keeps j - i >= diagonal.
                    let keep = match half {
                        TriangularHalf::Lower => (j as isize - i as isize) <= diagonal,
                        TriangularHalf::Upper => (j as isize - i as isize) >= diagonal,
                    };
                    if !keep {
                        continue;
                    }

                    // offset + batch + i*stride0 + j*stride1, all checked,
                    // then narrowed to usize (negative => panic).
                    let src_pos = self
                        .offset
                        .checked_add(src_batch_off)
                        .and_then(|off| (i as isize).checked_mul(self.strides[0]).and_then(|v| off.checked_add(v)))
                        .and_then(|off| (j as isize).checked_mul(self.strides[1]).and_then(|v| off.checked_add(v)))
                        .and_then(|pos| usize::try_from(pos).ok())
                        .unwrap_or_else(|| {
                            panic!(
                        "triangular_part: source position overflow at ({}, {}) with offset {}, batch_off {}, strides {:?}",
                        i, j, self.offset, src_batch_off, self.strides
                    )
                        });
                    let dst_pos = (i as isize)
                        .checked_mul(out_strides[0])
                        .and_then(|v| dst_batch_off.checked_add(v))
                        .and_then(|off| (j as isize).checked_mul(out_strides[1]).and_then(|v| off.checked_add(v)))
                        .and_then(|pos| usize::try_from(pos).ok())
                        .unwrap_or_else(|| {
                            panic!(
                        "triangular_part: destination position overflow at ({}, {}) with batch_off {}, strides {:?}",
                        i, j, dst_batch_off, out_strides
                    )
                        });
                    data[dst_pos] = src[src_pos];
                }
            }
        }

        Tensor::from_parts(TensorParts {
            buffer: DataBuffer::from_vec(data),
            dims: self.dims.clone(),
            strides: std::sync::Arc::from(out_strides),
            offset: 0,
            logical_memory_space: self.logical_memory_space,
            preferred_compute_device: self.preferred_compute_device,
            event: None,
            conjugated: self.conjugated,
            fw_grad: None,
        })
    }

    /// Extract the lower triangular part of a matrix.
    ///
    /// Keeps elements `(i, j)` with `j - i <= diagonal` and zeroes the
    /// rest; `diagonal = 0` keeps the main diagonal and everything below.
    /// Axes beyond the first two are treated as batch dimensions; tensors
    /// with fewer than two dimensions are returned as-is (contiguous copy).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let a = Tensor::<f64>::ones(&[3, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let lower = a.tril(0);
    /// assert_eq!(lower.dims(), &[3, 3]);
    /// ```
    pub fn tril(&self, diagonal: isize) -> Tensor<T> {
        self.triangular_part(diagonal, TriangularHalf::Lower)
    }

    /// Extract the upper triangular part of a matrix.
    ///
    /// Keeps elements `(i, j)` with `j - i >= diagonal` and zeroes the
    /// rest; `diagonal = 0` keeps the main diagonal and everything above.
    /// Axes beyond the first two are treated as batch dimensions; tensors
    /// with fewer than two dimensions are returned as-is (contiguous copy).
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let a = Tensor::<f64>::ones(&[3, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let upper = a.triu(0);
    /// assert_eq!(upper.dims(), &[3, 3]);
    /// ```
    pub fn triu(&self, diagonal: isize) -> Tensor<T> {
        self.triangular_part(diagonal, TriangularHalf::Upper)
    }
}
399
impl<T: Scalar> Tensor<T> {
    /// Materialize the logical tensor values into a fresh contiguous tensor.
    ///
    /// The returned tensor is resolved (`conjugated = false`) and clears any
    /// preferred compute-device hint. GPU tensors use the Layer 0 logical-copy
    /// substrate when available.
    ///
    /// # Panics
    ///
    /// Panics when GPU materialization fails, or when the tensor is not
    /// CPU-backed and the `cuda` feature is disabled.
    pub(crate) fn materialize_logical_contiguous(&self, order: MemoryOrder) -> Tensor<T> {
        // Settle any pending event before reading the buffer.
        self.wait();

        #[cfg(feature = "cuda")]
        if matches!(
            self.logical_memory_space,
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::materialize_logical_contiguous_tensor(self, order)
                .unwrap_or_else(|err| {
                    panic!("materialize_logical_contiguous: GPU materialization failed: {err}")
                });
        }

        let mut data = vec![T::zero(); self.len()];
        if !data.is_empty() {
            // Dense copy first, then resolve the lazy conjugation flag in
            // place (no-op for real-valued element types).
            let dst_strides = compute_contiguous_strides(&self.dims, order);
            copy_strided(
                self.cpu_backed_slice_or_panic("materialize_logical_contiguous"),
                &self.dims,
                &self.strides,
                self.offset,
                &mut data,
                &dst_strides,
            );
            apply_logical_conjugation_if_needed(&mut data, self.conjugated);
        }

        // `None` compute-device hint and `conjugated = false`: the result
        // is a fully resolved, freshly owned tensor.
        Tensor::from_owned_contiguous_data(
            data,
            self.dims.clone(),
            order,
            self.logical_memory_space,
            None,
            false,
        )
    }
}
444
445/// Apply logical conjugation in place when the element type supports it.
446///
447/// This uses a narrow private runtime dispatch so `Tensor::cat` / `Tensor::stack`
448/// can stay generic over `Scalar` while built-in complex tensors still materialize
449/// their logical values correctly.
450fn apply_logical_conjugation_if_needed<T: Scalar + 'static>(data: &mut [T], conjugated: bool) {
451    if !conjugated || data.is_empty() {
452        return;
453    }
454
455    if TypeId::of::<T>() == TypeId::of::<Complex32>() {
456        // SAFETY: the type check above guarantees `T` really is `Complex32`.
457        let data = unsafe {
458            std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<Complex32>(), data.len())
459        };
460        for value in data {
461            *value = value.conj();
462        }
463        return;
464    }
465
466    if TypeId::of::<T>() == TypeId::of::<Complex64>() {
467        // SAFETY: the type check above guarantees `T` really is `Complex64`.
468        let data = unsafe {
469            std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<Complex64>(), data.len())
470        };
471        for value in data {
472            *value = value.conj();
473        }
474    }
475}