1use std::sync::Arc;
2
3#[cfg(feature = "cuda")]
4use num_complex::{Complex32, Complex64};
5#[cfg(feature = "cuda")]
6use std::any::TypeId;
7use tenferro_algebra::Scalar;
8#[cfg(feature = "cuda")]
9use tenferro_device::cuda::runtime::{
10 self as device_cuda, ContiguousOrder, CudaBuffer, StridedCopySpec, StridedCopyTransform,
11};
12use tenferro_device::{Error, LogicalMemorySpace, Result};
13
14use super::{Tensor, TensorParts};
15use crate::layout::compute_contiguous_strides;
16#[cfg(feature = "cuda")]
17use crate::DataBuffer;
18use crate::MemoryOrder;
19
20impl<T: Scalar> Tensor<T> {
21 pub fn stack(tensors: &[&Tensor<T>], dim: isize) -> Result<Tensor<T>> {
59 if tensors.is_empty() {
60 return Err(Error::InvalidArgument(
61 "stack requires at least one tensor".to_string(),
62 ));
63 }
64
65 for t in tensors {
66 t.wait();
67 }
68
69 let first = tensors[0];
70 let ndim = first.ndim();
71
72 let dim = if dim < 0 {
73 let wrapped = dim + (ndim as isize) + 1;
74 if wrapped < 0 {
75 return Err(Error::InvalidArgument(format!(
76 "stack dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
77 -(ndim as isize) - 1,
78 ndim
79 )));
80 }
81 wrapped as usize
82 } else if dim as usize > ndim {
83 return Err(Error::InvalidArgument(format!(
84 "stack dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
85 -(ndim as isize) - 1,
86 ndim
87 )));
88 } else {
89 dim as usize
90 };
91
92 let memory_space = first.logical_memory_space();
93 for (i, t) in tensors.iter().enumerate() {
94 if t.dims() != first.dims() {
95 return Err(Error::ShapeMismatch {
96 expected: first.dims.to_vec(),
97 got: t.dims.to_vec(),
98 });
99 }
100 if t.logical_memory_space() != memory_space {
101 return Err(Error::InvalidArgument(format!(
102 "tensor {} has different memory space {:?} (expected {:?})",
103 i, t.logical_memory_space, memory_space
104 )));
105 }
106 }
107
108 let unsqueezed: Vec<Tensor<T>> = tensors
109 .iter()
110 .map(|tensor| tensor.unsqueeze(dim as isize))
111 .collect::<Result<_>>()?;
112 let unsqueezed_refs: Vec<&Tensor<T>> = unsqueezed.iter().collect();
113
114 Tensor::cat(&unsqueezed_refs, dim as isize)
115 }
116
117 pub fn cat(tensors: &[&Tensor<T>], dim: isize) -> Result<Tensor<T>> {
158 if tensors.is_empty() {
159 return Err(Error::InvalidArgument(
160 "cat requires at least one tensor".to_string(),
161 ));
162 }
163
164 for t in tensors {
165 t.wait();
166 }
167
168 let first = tensors[0];
169 let ndim = first.ndim();
170
171 if ndim == 0 {
172 return Err(Error::InvalidArgument(
173 "cat cannot concatenate rank-0 tensors (use stack to pack scalars)".to_string(),
174 ));
175 }
176
177 let dim = if dim < 0 {
178 let wrapped = dim + (ndim as isize);
179 if wrapped < 0 {
180 return Err(Error::InvalidArgument(format!(
181 "cat dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
182 -(ndim as isize),
183 ndim - 1
184 )));
185 }
186 wrapped as usize
187 } else if dim as usize >= ndim {
188 return Err(Error::InvalidArgument(format!(
189 "cat dim {dim} out of range for tensors with {ndim} dimensions (valid: [{}, {}])",
190 -(ndim as isize),
191 ndim - 1
192 )));
193 } else {
194 dim as usize
195 };
196
197 let memory_space = first.logical_memory_space();
198 let mut total_cat_dim = 0usize;
199 for (i, t) in tensors.iter().enumerate() {
200 if t.ndim() != ndim {
201 return Err(Error::InvalidArgument(format!(
202 "tensor {} has rank {} but expected rank {}",
203 i,
204 t.ndim(),
205 ndim
206 )));
207 }
208 if t.logical_memory_space() != memory_space {
209 return Err(Error::InvalidArgument(format!(
210 "tensor {} has different memory space {:?} (expected {:?})",
211 i, t.logical_memory_space, memory_space
212 )));
213 }
214 for (axis, (&d1, &d2)) in first.dims.iter().zip(t.dims.iter()).enumerate() {
215 if axis != dim && d1 != d2 {
216 return Err(Error::ShapeMismatch {
217 expected: first.dims.to_vec(),
218 got: t.dims.to_vec(),
219 });
220 }
221 }
222 total_cat_dim = total_cat_dim.checked_add(t.dims[dim]).ok_or_else(|| {
223 Error::InvalidArgument("cat: dimension size overflow".to_string())
224 })?;
225 }
226
227 let mut result_dims: Vec<usize> = first.dims.to_vec();
228 result_dims[dim] = total_cat_dim;
229
230 let result_strides = compute_contiguous_strides(&result_dims, MemoryOrder::ColumnMajor);
231
232 #[cfg(feature = "cuda")]
233 if matches!(memory_space, LogicalMemorySpace::GpuMemory { .. }) {
234 return cat_gpu(tensors, dim, memory_space, &result_dims, &result_strides);
235 }
236
237 #[cfg(not(feature = "cuda"))]
238 if memory_space != LogicalMemorySpace::MainMemory {
239 return Err(Error::InvalidArgument(
240 "cat only supports main-memory tensors in Phase 1".to_string(),
241 ));
242 }
243 #[cfg(feature = "cuda")]
244 if memory_space != LogicalMemorySpace::MainMemory {
245 return Err(Error::InvalidArgument(format!(
246 "cat only supports main-memory or same-device GPU tensors, got {memory_space:?}"
247 )));
248 }
249
250 let result_len: usize = result_dims.iter().product();
251 let mut result_data = vec![T::zero(); result_len];
252
253 let mut cat_offset: usize = 0;
254 for tensor in tensors {
255 let contiguous_tensor = tensor.materialize_logical_contiguous(MemoryOrder::ColumnMajor);
256 let src = contiguous_tensor.buffer().as_slice().unwrap();
257 let src_strides = compute_contiguous_strides(&tensor.dims, MemoryOrder::ColumnMajor);
258
259 let mut index = vec![0usize; ndim];
260 let n_elements: usize = tensor.dims.iter().product();
261
262 if n_elements > 0 {
263 for _ in 0..n_elements {
264 let src_pos: usize = index
265 .iter()
266 .zip(src_strides.iter())
267 .map(|(&i, &s)| (i as isize) * s)
268 .sum::<isize>() as usize;
269
270 let dst_pos: usize = index
271 .iter()
272 .enumerate()
273 .zip(result_strides.iter())
274 .map(|((axis, &i), &s)| {
275 let adjusted_i = if axis == dim { i + cat_offset } else { i };
276 (adjusted_i as isize) * s
277 })
278 .sum::<isize>() as usize;
279
280 result_data[dst_pos] = src[src_pos];
281
282 for axis in (0..ndim).rev() {
283 index[axis] += 1;
284 if index[axis] < tensor.dims[axis] {
285 break;
286 }
287 index[axis] = 0;
288 }
289 }
290 }
291
292 cat_offset += tensor.dims[dim];
293 }
294
295 Ok(Tensor::from_parts(TensorParts {
296 buffer: crate::DataBuffer::from_vec(result_data),
297 dims: Arc::from(result_dims),
298 strides: Arc::from(result_strides),
299 offset: 0,
300 logical_memory_space: memory_space,
301 preferred_compute_device: None,
302 event: None,
303 conjugated: false,
304 fw_grad: None,
305 }))
306 }
307}
308
/// Packs `tensor` into a freshly allocated, column-major contiguous CUDA
/// buffer on `runtime`'s device, applying the tensor's lazy conjugation flag
/// during the copy when the element type supports it.
///
/// # Errors
/// Returns `Error::DeviceError` when the tensor's buffer is not resident on
/// the device, and propagates allocation and strided-copy failures.
#[cfg(feature = "cuda")]
fn materialize_cuda_contiguous_buffer<T: Scalar + 'static>(
    tensor: &Tensor<T>,
    runtime: &Arc<device_cuda::CudaRuntime>,
) -> Result<CudaBuffer<T>> {
    // The source must already live in device memory; this only repacks it.
    let src_ptr = tensor.buffer().as_device_ptr().ok_or_else(|| {
        Error::DeviceError("cat: GPU tensor buffer is not resident on device".into())
    })?;
    // Describe the (possibly strided / offset) source-to-contiguous copy.
    let spec = StridedCopySpec::to_contiguous(
        tensor.dims(),
        tensor.strides(),
        tensor.offset(),
        ContiguousOrder::ColumnMajor,
    )?;
    let dst = runtime.alloc::<T>(tensor.len())?;
    if tensor.is_empty() {
        // Zero elements: the fresh (empty) allocation is already contiguous.
        return Ok(dst);
    }

    // SAFETY: `src_ptr` refers to this tensor's device allocation, `dst` was
    // just allocated with room for `tensor.len()` elements, and `spec` was
    // built from the same dims/strides/offset. NOTE(review): assumes the
    // runtime's raw strided-copy contract matches these invariants — confirm
    // against tenferro_device's documentation.
    unsafe {
        if tensor.is_conjugated() && supports_conj_strided_copy::<T>() {
            // Complex element types get conjugation fused into the copy.
            runtime.copy_strided_raw_with_transform(
                src_ptr,
                dst.device_ptr(),
                &spec,
                StridedCopyTransform::Conj,
            )?;
        } else {
            runtime.copy_strided_raw(src_ptr, dst.device_ptr(), &spec)?;
        }
    }
    Ok(dst)
}
342
/// Concatenates GPU-resident tensors along `dim` on the device owning them.
///
/// Folds the inputs left-to-right: a packed column-major accumulator buffer
/// is concatenated with each subsequent (packed) input via the runtime's
/// `pack_concat_sources` kernel.
///
/// # Errors
/// Returns `Error::DeviceError` for a non-`GpuMemory` space, and propagates
/// runtime-initialization, packing, and copy failures; overflow of the
/// concatenated dimension yields `Error::InvalidArgument`.
#[cfg(feature = "cuda")]
fn cat_gpu<T: Scalar + 'static>(
    tensors: &[&Tensor<T>],
    dim: usize,
    memory_space: LogicalMemorySpace,
    result_dims: &[usize],
    result_strides: &[isize],
) -> Result<Tensor<T>> {
    let LogicalMemorySpace::GpuMemory { device_id } = memory_space else {
        return Err(Error::DeviceError(format!(
            "cat: unsupported CUDA memory space {memory_space:?}"
        )));
    };
    let runtime = device_cuda::get_or_init(device_id)?;

    // Accumulator: dims and packed buffer of everything concatenated so far.
    // (Fixed mojibake: `&current_*` references had been corrupted to `¤t_*`.)
    let mut current_dims = tensors[0].dims().to_vec();
    let mut current_buf = materialize_cuda_contiguous_buffer(tensors[0], &runtime)?;
    for next in tensors.iter().skip(1) {
        let next_buf = materialize_cuda_contiguous_buffer(next, &runtime)?;
        let current_strides = compute_contiguous_strides(&current_dims, MemoryOrder::ColumnMajor);
        let next_strides = compute_contiguous_strides(next.dims(), MemoryOrder::ColumnMajor);
        let current_spec = StridedCopySpec::to_contiguous(
            &current_dims,
            &current_strides,
            0,
            ContiguousOrder::ColumnMajor,
        )?;
        let next_spec = StridedCopySpec::to_contiguous(
            next.dims(),
            &next_strides,
            0,
            ContiguousOrder::ColumnMajor,
        )?;

        current_buf = runtime.pack_concat_sources(
            &current_buf,
            &current_spec,
            &next_buf,
            &next_spec,
            dim,
            ContiguousOrder::ColumnMajor,
        )?;
        current_dims[dim] = current_dims[dim]
            .checked_add(next.dims()[dim])
            .ok_or_else(|| Error::InvalidArgument("cat: dimension size overflow".to_string()))?;
    }

    debug_assert_eq!(current_dims.as_slice(), result_dims);

    let current_len = current_buf.len();
    let current_ptr = current_buf.device_ptr();
    // SAFETY: `current_ptr`/`current_len` describe `current_buf`, which the
    // move-closure keeps alive until the DataBuffer releases it.
    let buffer = unsafe {
        DataBuffer::from_gpu_parts(current_ptr, current_len, memory_space, move || {
            drop(current_buf)
        })
    };
    // Use the `TensorParts` form of `from_parts`, consistent with the CPU
    // path above (Rust has no overloading; the positional call could not
    // have compiled alongside the struct-taking signature).
    Ok(Tensor::from_parts(TensorParts {
        buffer,
        dims: Arc::from(result_dims.to_vec()),
        strides: Arc::from(result_strides.to_vec()),
        offset: 0,
        logical_memory_space: memory_space,
        preferred_compute_device: None,
        event: None,
        conjugated: false,
        fw_grad: None,
    }))
}
413
/// Reports whether `T` is a complex scalar type (`Complex32` / `Complex64`)
/// for which the CUDA strided-copy kernel can fuse conjugation into the copy.
#[cfg(feature = "cuda")]
fn supports_conj_strided_copy<T: 'static>() -> bool {
    let id = TypeId::of::<T>();
    [TypeId::of::<Complex32>(), TypeId::of::<Complex64>()].contains(&id)
}