tenferro_tensor/tensor/
combine.rs

1use std::sync::Arc;
2
3#[cfg(feature = "cuda")]
4use num_complex::{Complex32, Complex64};
5#[cfg(feature = "cuda")]
6use std::any::TypeId;
7use tenferro_algebra::Scalar;
8#[cfg(feature = "cuda")]
9use tenferro_device::cuda::runtime::{
10    self as device_cuda, ContiguousOrder, CudaBuffer, StridedCopySpec, StridedCopyTransform,
11};
12use tenferro_device::{Error, LogicalMemorySpace, Result};
13
14use super::{Tensor, TensorParts};
15use crate::layout::compute_contiguous_strides;
16#[cfg(feature = "cuda")]
17use crate::DataBuffer;
18use crate::MemoryOrder;
19
impl<T: Scalar> Tensor<T> {
    /// Stack tensors along a new dimension.
    ///
    /// Creates a new dimension and concatenates the input tensors along it.
    /// All input tensors must have the same shape. Negative dimensions are
    /// supported and count from the end.
    ///
    /// This is a dense materialization operation that allocates a new buffer.
    /// It is implemented by inserting a size-1 axis with
    /// [`Tensor::unsqueeze`] and then delegating to [`Tensor::cat`], so it
    /// materializes logical values, resolves conjugation, and supports the same
    /// CPU and same-device CUDA paths as concatenation.
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of input tensors to stack. Must not be empty.
    /// * `dim` - Position to insert the new dimension. Must be in range `[-ndim-1, ndim]`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The input list is empty
    /// - Tensors have different shapes
    /// - Tensors have different memory spaces or devices
    /// - The dimension is out of range
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_device::LogicalMemorySpace;
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let a = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let b = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    ///
    /// let stacked = Tensor::stack(&[&a, &b], 0).unwrap();
    /// assert_eq!(stacked.dims(), &[2, 2, 3]);
    /// ```
    pub fn stack(tensors: &[&Tensor<T>], dim: isize) -> Result<Tensor<T>> {
        if tensors.is_empty() {
            return Err(Error::InvalidArgument(
                "stack requires at least one tensor".to_string(),
            ));
        }

        // Synchronize with any pending asynchronous producer of each input
        // before inspecting shapes or reading data.
        for t in tensors {
            t.wait();
        }

        let first = tensors[0];
        let ndim = first.ndim();

        // Normalize `dim` into [0, ndim]. The new axis may be inserted before
        // any existing axis or appended at the end, so the valid range is one
        // wider than for `cat`; negative values count back from the end.
        let dim = if dim < 0 {
            let wrapped = dim + (ndim as isize) + 1;
            if wrapped < 0 {
                return Err(Error::InvalidArgument(format!(
                    "stack dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
                    -(ndim as isize) - 1,
                    ndim
                )));
            }
            wrapped as usize
        } else if dim as usize > ndim {
            return Err(Error::InvalidArgument(format!(
                "stack dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
                -(ndim as isize) - 1,
                ndim
            )));
        } else {
            dim as usize
        };

        // All inputs must agree on shape and logical memory space; report the
        // first offender with its index.
        let memory_space = first.logical_memory_space();
        for (i, t) in tensors.iter().enumerate() {
            if t.dims() != first.dims() {
                return Err(Error::ShapeMismatch {
                    expected: first.dims.to_vec(),
                    got: t.dims.to_vec(),
                });
            }
            if t.logical_memory_space() != memory_space {
                return Err(Error::InvalidArgument(format!(
                    "tensor {} has different memory space {:?} (expected {:?})",
                    i, t.logical_memory_space, memory_space
                )));
            }
        }

        // Give every input a size-1 axis at `dim`, then delegate to `cat`,
        // which performs the actual allocation and copying.
        let unsqueezed: Vec<Tensor<T>> = tensors
            .iter()
            .map(|tensor| tensor.unsqueeze(dim as isize))
            .collect::<Result<_>>()?;
        let unsqueezed_refs: Vec<&Tensor<T>> = unsqueezed.iter().collect();

        Tensor::cat(&unsqueezed_refs, dim as isize)
    }

    /// Concatenate tensors along an existing dimension.
    ///
    /// Joins tensors along the specified dimension. All tensors must have
    /// the same rank and matching sizes on non-concatenated dimensions.
    /// Negative dimensions are supported and count from the end.
    ///
    /// This is a dense materialization operation that allocates a new buffer.
    /// Logical conjugation is materialized per input, the output is resolved
    /// (`conjugated = false`), and any preferred compute-device hint is cleared.
    /// Main-memory tensors are always supported; with `cuda` enabled, same-device
    /// GPU tensors are also supported.
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of input tensors to concatenate. Must not be empty.
    /// * `dim` - Dimension along which to concatenate. Must be in range `[-ndim, ndim-1]`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The input list is empty
    /// - Any tensor is rank-0 (scalars cannot be concatenated)
    /// - Tensors have different ranks
    /// - Tensors have mismatched sizes on non-concatenated dimensions
    /// - Tensors have different memory spaces
    /// - The dimension is out of range
    /// - Non-main-memory tensors are provided without `cuda` support
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_device::LogicalMemorySpace;
    /// use tenferro_tensor::{MemoryOrder, Tensor};
    ///
    /// let a = Tensor::<f64>::zeros(&[2, 3], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    /// let b = Tensor::<f64>::zeros(&[2, 4], LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor).unwrap();
    ///
    /// let concatenated = Tensor::cat(&[&a, &b], 1).unwrap();
    /// assert_eq!(concatenated.dims(), &[2, 7]);
    /// ```
    pub fn cat(tensors: &[&Tensor<T>], dim: isize) -> Result<Tensor<T>> {
        if tensors.is_empty() {
            return Err(Error::InvalidArgument(
                "cat requires at least one tensor".to_string(),
            ));
        }

        // Synchronize with any pending asynchronous producer of each input
        // before inspecting shapes or reading data.
        for t in tensors {
            t.wait();
        }

        let first = tensors[0];
        let ndim = first.ndim();

        // Rank-0 tensors have no axis to concatenate along; `stack` is the
        // right tool for packing scalars.
        if ndim == 0 {
            return Err(Error::InvalidArgument(
                "cat cannot concatenate rank-0 tensors (use stack to pack scalars)".to_string(),
            ));
        }

        // Normalize a possibly negative axis into [0, ndim).
        let dim = if dim < 0 {
            let wrapped = dim + (ndim as isize);
            if wrapped < 0 {
                return Err(Error::InvalidArgument(format!(
                    "cat dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
                    -(ndim as isize),
                    ndim - 1
                )));
            }
            wrapped as usize
        } else if dim as usize >= ndim {
            return Err(Error::InvalidArgument(format!(
                "cat dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
                -(ndim as isize),
                ndim - 1
            )));
        } else {
            dim as usize
        };

        // Validate that every input matches `first` in rank, memory space, and
        // size on every axis except `dim`, while summing the concat-axis sizes
        // with overflow checking.
        let memory_space = first.logical_memory_space();
        let mut total_cat_dim = 0usize;
        for (i, t) in tensors.iter().enumerate() {
            if t.ndim() != ndim {
                return Err(Error::InvalidArgument(format!(
                    "tensor {} has rank {} but expected rank {}",
                    i,
                    t.ndim(),
                    ndim
                )));
            }
            if t.logical_memory_space() != memory_space {
                return Err(Error::InvalidArgument(format!(
                    "tensor {} has different memory space {:?} (expected {:?})",
                    i, t.logical_memory_space, memory_space
                )));
            }
            for (axis, (&d1, &d2)) in first.dims.iter().zip(t.dims.iter()).enumerate() {
                if axis != dim && d1 != d2 {
                    return Err(Error::ShapeMismatch {
                        expected: first.dims.to_vec(),
                        got: t.dims.to_vec(),
                    });
                }
            }
            total_cat_dim = total_cat_dim.checked_add(t.dims[dim]).ok_or_else(|| {
                Error::InvalidArgument("cat: dimension size overflow".to_string())
            })?;
        }

        // Output shape: the common shape with the concat axis widened to the
        // total; the result is always materialized dense column-major.
        let mut result_dims: Vec<usize> = first.dims.to_vec();
        result_dims[dim] = total_cat_dim;

        let result_strides = compute_contiguous_strides(&result_dims, MemoryOrder::ColumnMajor);

        // Same-device GPU inputs are concatenated entirely on the device.
        #[cfg(feature = "cuda")]
        if matches!(memory_space, LogicalMemorySpace::GpuMemory { .. }) {
            return cat_gpu(tensors, dim, memory_space, &result_dims, &result_strides);
        }

        // Without CUDA support, only main-memory tensors can be handled.
        #[cfg(not(feature = "cuda"))]
        if memory_space != LogicalMemorySpace::MainMemory {
            return Err(Error::InvalidArgument(
                "cat only supports main-memory tensors in Phase 1".to_string(),
            ));
        }
        // With CUDA enabled, GPU memory was handled above; any other non-main
        // memory space is unsupported.
        #[cfg(feature = "cuda")]
        if memory_space != LogicalMemorySpace::MainMemory {
            return Err(Error::InvalidArgument(format!(
                "cat only supports main-memory or same-device GPU tensors, got {memory_space:?}"
            )));
        }

        // CPU path: allocate the dense output, then scatter each input into it
        // at a running offset along the concat axis.
        let result_len: usize = result_dims.iter().product();
        let mut result_data = vec![T::zero(); result_len];

        let mut cat_offset: usize = 0;
        for tensor in tensors {
            // Materialize the input dense column-major (resolving strides,
            // offset, and logical conjugation) so it can be read linearly.
            let contiguous_tensor = tensor.materialize_logical_contiguous(MemoryOrder::ColumnMajor);
            let src = contiguous_tensor.buffer().as_slice().unwrap();
            let src_strides = compute_contiguous_strides(&tensor.dims, MemoryOrder::ColumnMajor);

            let mut index = vec![0usize; ndim];
            let n_elements: usize = tensor.dims.iter().product();

            if n_elements > 0 {
                for _ in 0..n_elements {
                    // Linear position of `index` within the dense source.
                    let src_pos: usize = index
                        .iter()
                        .zip(src_strides.iter())
                        .map(|(&i, &s)| (i as isize) * s)
                        .sum::<isize>() as usize;

                    // Linear position in the output: identical index, except
                    // the concat axis is shifted by the sizes of all inputs
                    // already copied.
                    let dst_pos: usize = index
                        .iter()
                        .enumerate()
                        .zip(result_strides.iter())
                        .map(|((axis, &i), &s)| {
                            let adjusted_i = if axis == dim { i + cat_offset } else { i };
                            (adjusted_i as isize) * s
                        })
                        .sum::<isize>() as usize;

                    result_data[dst_pos] = src[src_pos];

                    // Advance the multi-index odometer-style: increment the
                    // last axis first, carrying into earlier axes on wrap.
                    for axis in (0..ndim).rev() {
                        index[axis] += 1;
                        if index[axis] < tensor.dims[axis] {
                            break;
                        }
                        index[axis] = 0;
                    }
                }
            }

            // The next input lands after this one along the concat axis.
            cat_offset += tensor.dims[dim];
        }

        // Fresh dense buffer: conjugation is resolved, and no event, device
        // hint, or forward gradient is carried over.
        Ok(Tensor::from_parts(TensorParts {
            buffer: crate::DataBuffer::from_vec(result_data),
            dims: Arc::from(result_dims),
            strides: Arc::from(result_strides),
            offset: 0,
            logical_memory_space: memory_space,
            preferred_compute_device: None,
            event: None,
            conjugated: false,
            fw_grad: None,
        }))
    }
}
308
/// Copy a (possibly strided, possibly logically conjugated) GPU tensor into a
/// freshly allocated, column-major contiguous device buffer.
///
/// Conjugation is resolved during the copy when the strided-copy kernel
/// supports the element type; otherwise a plain strided copy is performed.
#[cfg(feature = "cuda")]
fn materialize_cuda_contiguous_buffer<T: Scalar + 'static>(
    tensor: &Tensor<T>,
    runtime: &Arc<device_cuda::CudaRuntime>,
) -> Result<CudaBuffer<T>> {
    let source_ptr = tensor.buffer().as_device_ptr().ok_or_else(|| {
        Error::DeviceError("cat: GPU tensor buffer is not resident on device".into())
    })?;

    // Describe the gather from the tensor's (offset, strides) view into a
    // dense column-major layout.
    let copy_spec = StridedCopySpec::to_contiguous(
        tensor.dims(),
        tensor.strides(),
        tensor.offset(),
        ContiguousOrder::ColumnMajor,
    )?;

    let staged = runtime.alloc::<T>(tensor.len())?;
    if tensor.is_empty() {
        // Nothing to copy; hand back the zero-length allocation directly.
        return Ok(staged);
    }

    // Resolve logical conjugation inline when the kernel supports it for `T`.
    let transform = (tensor.is_conjugated() && supports_conj_strided_copy::<T>())
        .then_some(StridedCopyTransform::Conj);

    // SAFETY: `source_ptr` and `staged.device_ptr()` are live device
    // allocations covering the source view and `tensor.len()` destination
    // elements described by `copy_spec`.
    unsafe {
        match transform {
            Some(conj) => runtime.copy_strided_raw_with_transform(
                source_ptr,
                staged.device_ptr(),
                &copy_spec,
                conj,
            )?,
            None => runtime.copy_strided_raw(source_ptr, staged.device_ptr(), &copy_spec)?,
        }
    }
    Ok(staged)
}
342
/// Concatenate same-device GPU tensors along `dim` into a freshly allocated
/// dense column-major device buffer.
///
/// Inputs are assumed to be pre-validated by [`Tensor::cat`] (matching ranks,
/// compatible shapes, identical GPU memory space). Each source is first staged
/// into a contiguous buffer (resolving conjugation), then folded pairwise
/// through the device concat-pack kernel.
#[cfg(feature = "cuda")]
fn cat_gpu<T: Scalar + 'static>(
    tensors: &[&Tensor<T>],
    dim: usize,
    memory_space: LogicalMemorySpace,
    result_dims: &[usize],
    result_strides: &[isize],
) -> Result<Tensor<T>> {
    let LogicalMemorySpace::GpuMemory { device_id } = memory_space else {
        return Err(Error::DeviceError(format!(
            "cat: unsupported CUDA memory space {memory_space:?}"
        )));
    };
    let runtime = device_cuda::get_or_init(device_id)?;

    // Stage each source into a contiguous GPU buffer so we can reuse the
    // concat-pack substrate without falling back to host materialization.
    let mut current_dims = tensors[0].dims().to_vec();
    let mut current_buf = materialize_cuda_contiguous_buffer(tensors[0], &runtime)?;
    for next in tensors.iter().skip(1) {
        let next_buf = materialize_cuda_contiguous_buffer(next, &runtime)?;
        let current_strides = compute_contiguous_strides(&current_dims, MemoryOrder::ColumnMajor);
        let next_strides = compute_contiguous_strides(next.dims(), MemoryOrder::ColumnMajor);
        let current_spec = StridedCopySpec::to_contiguous(
            &current_dims,
            &current_strides,
            0,
            ContiguousOrder::ColumnMajor,
        )?;
        let next_spec = StridedCopySpec::to_contiguous(
            next.dims(),
            &next_strides,
            0,
            ContiguousOrder::ColumnMajor,
        )?;

        // Pack the running accumulator and the next source into a new buffer.
        current_buf = runtime.pack_concat_sources(
            &current_buf,
            &current_spec,
            &next_buf,
            &next_spec,
            dim,
            ContiguousOrder::ColumnMajor,
        )?;
        current_dims[dim] = current_dims[dim]
            .checked_add(next.dims()[dim])
            .ok_or_else(|| Error::InvalidArgument("cat: dimension size overflow".to_string()))?;
    }

    debug_assert_eq!(current_dims.as_slice(), result_dims);

    let current_len = current_buf.len();
    let current_ptr = current_buf.device_ptr();
    // SAFETY: `current_ptr`/`current_len` describe the live allocation owned
    // by `current_buf`; the deleter closure keeps `current_buf` alive until
    // the `DataBuffer` is dropped.
    let buffer = unsafe {
        DataBuffer::from_gpu_parts(current_ptr, current_len, memory_space, move || {
            drop(current_buf)
        })
    };
    // Fix: construct the tensor via `TensorParts`, matching the CPU path in
    // `Tensor::cat` — `Tensor::from_parts` takes a parts struct, not
    // positional arguments.
    Ok(Tensor::from_parts(TensorParts {
        buffer,
        dims: Arc::from(result_dims.to_vec()),
        strides: Arc::from(result_strides.to_vec()),
        offset: 0,
        logical_memory_space: memory_space,
        preferred_compute_device: None,
        event: None,
        conjugated: false,
        fw_grad: None,
    }))
}
413
/// Whether the CUDA strided-copy kernel can apply complex conjugation inline
/// for element type `T` (only `Complex32` and `Complex64` qualify).
#[cfg(feature = "cuda")]
fn supports_conj_strided_copy<T: 'static>() -> bool {
    let id = TypeId::of::<T>();
    [TypeId::of::<Complex32>(), TypeId::of::<Complex64>()].contains(&id)
}