1use std::any::TypeId;
2
3use num_complex::{Complex32, Complex64};
4use tenferro_algebra::Scalar;
5#[cfg(feature = "cuda")]
6use tenferro_device::LogicalMemorySpace;
7use tenferro_device::{checked_batch_count, unflatten_col_major_index_into};
8
9use super::{Tensor, TensorParts};
10use crate::layout::{compute_contiguous_strides, copy_strided, is_contiguous_in_order};
11use crate::{DataBuffer, MemoryOrder};
12
/// Selects which triangular half of a matrix is kept by `triangular_part`.
enum TriangularHalf {
    /// Keep entries on or below the shifted diagonal (`col - row <= diagonal`), i.e. `tril`.
    Lower,
    /// Keep entries on or above the shifted diagonal (`col - row >= diagonal`), i.e. `triu`.
    Upper,
}
17
18impl<T> Tensor<T> {
19 pub fn conj(&self) -> Tensor<T>
32 where
33 T: tenferro_algebra::Conjugate,
34 {
35 Tensor::from_parts(TensorParts {
36 buffer: self.buffer.clone(),
37 dims: self.dims.clone(),
38 strides: self.strides.clone(),
39 offset: self.offset,
40 logical_memory_space: self.logical_memory_space,
41 preferred_compute_device: self.preferred_compute_device,
42 event: self.event.clone(),
43 conjugated: !self.conjugated,
44 fw_grad: None,
45 })
46 }
47
48 pub fn into_conj(self) -> Tensor<T>
58 where
59 T: tenferro_algebra::Conjugate,
60 {
61 Tensor::from_parts(TensorParts {
62 buffer: self.buffer,
63 dims: self.dims,
64 strides: self.strides,
65 offset: self.offset,
66 logical_memory_space: self.logical_memory_space,
67 preferred_compute_device: self.preferred_compute_device,
68 event: self.event,
69 conjugated: !self.conjugated,
70 fw_grad: None,
71 })
72 }
73}
74
75impl<T: Scalar> Tensor<T> {
76 pub fn deep_clone(&self) -> Tensor<T> {
100 self.wait();
101 let order = MemoryOrder::ColumnMajor;
102 let mut data = vec![T::zero(); self.len()];
103 if !data.is_empty() {
104 let dst_strides = compute_contiguous_strides(&self.dims, order);
105 copy_strided(
106 self.cpu_backed_slice_or_panic("deep_clone"),
107 &self.dims,
108 &self.strides,
109 self.offset,
110 &mut data,
111 &dst_strides,
112 );
113 }
114 self.materialized_from_vec(data, order)
115 }
116
117 pub fn contiguous(&self, order: MemoryOrder) -> Tensor<T> {
131 self.wait();
132 if is_contiguous_in_order(&self.dims, &self.strides, order) && self.offset == 0 {
133 return Tensor::from_parts(TensorParts {
134 buffer: self.buffer.clone(),
135 dims: self.dims.clone(),
136 strides: self.strides.clone(),
137 offset: self.offset,
138 logical_memory_space: self.logical_memory_space,
139 preferred_compute_device: self.preferred_compute_device,
140 event: self.event.clone(),
141 conjugated: self.conjugated,
142 fw_grad: self.fw_grad.clone(),
143 });
144 }
145
146 #[cfg(feature = "cuda")]
147 if matches!(
148 self.logical_memory_space,
149 LogicalMemorySpace::GpuMemory { .. }
150 ) {
151 return crate::cuda_runtime::contiguous_tensor(self, order)
152 .unwrap_or_else(|err| panic!("contiguous: GPU materialization failed: {err}"));
153 }
154
155 let mut data = vec![T::zero(); self.len()];
156 if !data.is_empty() {
157 let dst_strides = compute_contiguous_strides(&self.dims, order);
158 copy_strided(
159 self.cpu_backed_slice_or_panic("contiguous"),
160 &self.dims,
161 &self.strides,
162 self.offset,
163 &mut data,
164 &dst_strides,
165 );
166 }
167 self.materialized_from_vec(data, order)
168 }
169
170 pub fn into_contiguous(self, order: MemoryOrder) -> Tensor<T> {
180 if is_contiguous_in_order(&self.dims, &self.strides, order) && self.offset == 0 {
181 return Tensor::from_parts(TensorParts {
182 buffer: self.buffer,
183 dims: self.dims,
184 strides: self.strides,
185 offset: self.offset,
186 logical_memory_space: self.logical_memory_space,
187 preferred_compute_device: self.preferred_compute_device,
188 event: self.event,
189 conjugated: self.conjugated,
190 fw_grad: self.fw_grad,
191 });
192 }
193 self.contiguous(order)
194 }
195
196 pub fn into_column_major(self) -> Tensor<T> {
209 self.into_contiguous(MemoryOrder::ColumnMajor)
210 }
211
212 pub fn to_vec(&self) -> Vec<T> {
239 let c = self.contiguous(MemoryOrder::ColumnMajor);
240 let slice = c
241 .buffer()
242 .as_slice()
243 .expect("to_vec: CPU-only operation; GPU tensors are not supported");
244 slice.to_vec()
245 }
246
    /// Copies `self` into a fresh, column-major tensor, zeroing every entry
    /// outside the requested triangular `half` of the leading two dims
    /// (`dims[0]` = rows `i`, `dims[1]` = cols `j`); any trailing dims are
    /// treated as independent batch dims. `diagonal` shifts the kept
    /// boundary on `j - i` (0 = main diagonal).
    ///
    /// Tensors with fewer than two dims are returned as a plain contiguous
    /// column-major copy. GPU-resident tensors are handled by the CUDA
    /// runtime (when the `cuda` feature is enabled); otherwise the copy
    /// runs on the CPU and panics for non-CPU-backed buffers.
    fn triangular_part(&self, diagonal: isize, half: TriangularHalf) -> Tensor<T> {
        self.wait();
        // Nothing to mask for scalars/vectors: just materialize.
        if self.ndim() <= 1 {
            return self.contiguous(MemoryOrder::ColumnMajor);
        }

        #[cfg(feature = "cuda")]
        if matches!(
            self.logical_memory_space,
            LogicalMemorySpace::GpuMemory { .. }
        ) {
            return crate::cuda_runtime::triangular_part_tensor(
                self,
                diagonal,
                matches!(half, TriangularHalf::Lower),
            )
            .unwrap_or_else(|err| panic!("triangular_part: GPU materialization failed: {err}"));
        }

        let m = self.dims[0];
        let n = self.dims[1];
        // Output is always column-major contiguous, zero-initialized so the
        // masked-out entries need no explicit write.
        let out_strides = compute_contiguous_strides(&self.dims, MemoryOrder::ColumnMajor);
        let mut data = vec![T::zero(); self.len()];
        if data.is_empty() {
            return self.materialized_from_vec(data, MemoryOrder::ColumnMajor);
        }

        let src = self.cpu_backed_slice_or_panic(match half {
            TriangularHalf::Lower => "tril",
            TriangularHalf::Upper => "triu",
        });
        let batch_dims = &self.dims[2..];
        let n_batch = checked_batch_count(batch_dims).unwrap_or_else(|err| {
            panic!(
                "triangular_part: invalid batch dims {:?}: {err}",
                batch_dims
            )
        });
        // Reused scratch for the per-batch multi-index.
        let mut batch_index = vec![0usize; batch_dims.len()];

        for batch in 0..n_batch {
            if !batch_dims.is_empty() {
                unflatten_col_major_index_into(batch, batch_dims, &mut batch_index)
                    .unwrap_or_else(|err| {
                        panic!(
                            "triangular_part: failed to unflatten batch index {batch} for dims {:?}: {err}",
                            batch_dims
                        )
                    });
            }
            // Linear offsets of this batch slice in source and destination;
            // all arithmetic is overflow-checked and panics on overflow.
            let src_batch_off: isize = batch_index
                .iter()
                .enumerate()
                .try_fold(0isize, |acc, (axis, &idx)| {
                    (idx as isize).checked_mul(self.strides[axis + 2]).and_then(|v| acc.checked_add(v))
                })
                .unwrap_or_else(|| {
                    panic!(
                        "triangular_part: source batch offset overflow with batch_index {:?}, strides {:?}",
                        batch_index, self.strides
                    )
                });
            let dst_batch_off: isize = batch_index
                .iter()
                .enumerate()
                .try_fold(0isize, |acc, (axis, &idx)| {
                    (idx as isize).checked_mul(out_strides[axis + 2]).and_then(|v| acc.checked_add(v))
                })
                .unwrap_or_else(|| {
                    panic!(
                        "triangular_part: destination batch offset overflow with batch_index {:?}, strides {:?}",
                        batch_index, out_strides
                    )
                });

            for j in 0..n {
                for i in 0..m {
                    // Keep rule on the signed column-minus-row distance.
                    let keep = match half {
                        TriangularHalf::Lower => (j as isize - i as isize) <= diagonal,
                        TriangularHalf::Upper => (j as isize - i as isize) >= diagonal,
                    };
                    if !keep {
                        continue;
                    }

                    // src_pos = offset + batch_off + i*stride0 + j*stride1,
                    // checked at every step and converted to usize at the end.
                    let src_pos = self
                        .offset
                        .checked_add(src_batch_off)
                        .and_then(|off| (i as isize).checked_mul(self.strides[0]).and_then(|v| off.checked_add(v)))
                        .and_then(|off| (j as isize).checked_mul(self.strides[1]).and_then(|v| off.checked_add(v)))
                        .and_then(|pos| usize::try_from(pos).ok())
                        .unwrap_or_else(|| {
                            panic!(
                                "triangular_part: source position overflow at ({}, {}) with offset {}, batch_off {}, strides {:?}",
                                i, j, self.offset, src_batch_off, self.strides
                            )
                        });
                    // Destination has no base offset (fresh buffer).
                    let dst_pos = (i as isize)
                        .checked_mul(out_strides[0])
                        .and_then(|v| dst_batch_off.checked_add(v))
                        .and_then(|off| (j as isize).checked_mul(out_strides[1]).and_then(|v| off.checked_add(v)))
                        .and_then(|pos| usize::try_from(pos).ok())
                        .unwrap_or_else(|| {
                            panic!(
                                "triangular_part: destination position overflow at ({}, {}) with batch_off {}, strides {:?}",
                                i, j, dst_batch_off, out_strides
                            )
                        });
                    data[dst_pos] = src[src_pos];
                }
            }
        }

        Tensor::from_parts(TensorParts {
            buffer: DataBuffer::from_vec(data),
            dims: self.dims.clone(),
            strides: std::sync::Arc::from(out_strides),
            offset: 0,
            logical_memory_space: self.logical_memory_space,
            preferred_compute_device: self.preferred_compute_device,
            event: None,
            conjugated: self.conjugated,
            fw_grad: None,
        })
    }
372
373 pub fn tril(&self, diagonal: isize) -> Tensor<T> {
383 self.triangular_part(diagonal, TriangularHalf::Lower)
384 }
385
386 pub fn triu(&self, diagonal: isize) -> Tensor<T> {
396 self.triangular_part(diagonal, TriangularHalf::Upper)
397 }
398}
399
400impl<T: Scalar> Tensor<T> {
401 pub(crate) fn materialize_logical_contiguous(&self, order: MemoryOrder) -> Tensor<T> {
407 self.wait();
408
409 #[cfg(feature = "cuda")]
410 if matches!(
411 self.logical_memory_space,
412 LogicalMemorySpace::GpuMemory { .. }
413 ) {
414 return crate::cuda_runtime::materialize_logical_contiguous_tensor(self, order)
415 .unwrap_or_else(|err| {
416 panic!("materialize_logical_contiguous: GPU materialization failed: {err}")
417 });
418 }
419
420 let mut data = vec![T::zero(); self.len()];
421 if !data.is_empty() {
422 let dst_strides = compute_contiguous_strides(&self.dims, order);
423 copy_strided(
424 self.cpu_backed_slice_or_panic("materialize_logical_contiguous"),
425 &self.dims,
426 &self.strides,
427 self.offset,
428 &mut data,
429 &dst_strides,
430 );
431 apply_logical_conjugation_if_needed(&mut data, self.conjugated);
432 }
433
434 Tensor::from_owned_contiguous_data(
435 data,
436 self.dims.clone(),
437 order,
438 self.logical_memory_space,
439 None,
440 false,
441 )
442 }
443}
444
445fn apply_logical_conjugation_if_needed<T: Scalar + 'static>(data: &mut [T], conjugated: bool) {
451 if !conjugated || data.is_empty() {
452 return;
453 }
454
455 if TypeId::of::<T>() == TypeId::of::<Complex32>() {
456 let data = unsafe {
458 std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<Complex32>(), data.len())
459 };
460 for value in data {
461 *value = value.conj();
462 }
463 return;
464 }
465
466 if TypeId::of::<T>() == TypeId::of::<Complex64>() {
467 let data = unsafe {
469 std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<Complex64>(), data.len())
470 };
471 for value in data {
472 *value = value.conj();
473 }
474 }
475}