tenferro_tensor/tensor/views/
complex.rs1use super::*;
2use crate::tensor::TensorParts;
3
4fn view_as_real_strides(strides: &[isize]) -> Result<Vec<isize>> {
5 let mut output = Vec::with_capacity(strides.len() + 1);
6 for (axis, &stride) in strides.iter().enumerate() {
7 output.push(stride.checked_mul(2).ok_or_else(|| {
8 Error::StrideError(format!(
9 "view_as_real: stride overflow for axis {axis} with stride {stride}"
10 ))
11 })?);
12 }
13 output.push(1);
14 Ok(output)
15}
16
17fn view_as_complex_strides(dims: &[usize], strides: &[isize]) -> Result<Vec<isize>> {
18 if dims.is_empty() || strides.is_empty() {
19 return Err(Error::InvalidArgument(
20 "view_as_complex requires at least one dimension".into(),
21 ));
22 }
23 if dims.len() != strides.len() {
24 return Err(Error::InvalidArgument(format!(
25 "view_as_complex rank mismatch: dims={} strides={}",
26 dims.len(),
27 strides.len()
28 )));
29 }
30
31 let mut output = Vec::with_capacity(strides.len().saturating_sub(1));
32 for (axis, (&dim, &stride)) in dims[..dims.len() - 1]
33 .iter()
34 .zip(&strides[..strides.len() - 1])
35 .enumerate()
36 {
37 if dim != 1 && stride % 2 != 0 {
38 return Err(Error::InvalidArgument(format!(
39 "view_as_complex requires even strides on all leading dimensions, got stride {stride} at axis {axis}"
40 )));
41 }
42 output.push(stride / 2);
43 }
44 Ok(output)
45}
46
47fn select_complex_component<T>(view: Tensor<T>, component: usize) -> Result<Tensor<T>> {
48 view.select(view.ndim() - 1, component)
49}
50
51impl Tensor<Complex32> {
52 pub fn view_as_real(&self) -> Result<Tensor<f32>> {
74 self.wait();
75 if self.is_conjugated() {
76 return Err(Error::InvalidArgument(
77 "view_as_real requires a resolved complex tensor".into(),
78 ));
79 }
80
81 let mut dims = self.dims().to_vec();
82 dims.push(2);
83 let strides = view_as_real_strides(self.strides())?;
84 let offset = self.offset().checked_mul(2).ok_or_else(|| {
85 Error::StrideError("view_as_real: offset overflow when reinterpreting tensor".into())
86 })?;
87 let buffer_len =
88 self.buffer().len().checked_mul(2).ok_or_else(|| {
89 Error::StrideError("view_as_real: storage length overflow".into())
90 })?;
91 let buffer = self.buffer().reinterpret_as::<f32>(buffer_len)?;
92 validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
93 Ok(Tensor::from_parts(TensorParts {
94 buffer,
95 dims: Arc::from(dims),
96 strides: Arc::from(strides),
97 offset,
98 logical_memory_space: self.logical_memory_space(),
99 preferred_compute_device: self.preferred_compute_device(),
100 event: None,
101 conjugated: false,
102 fw_grad: None,
103 }))
104 }
105
106 pub fn real(&self) -> Result<Tensor<f32>> {
127 select_complex_component(self.view_as_real()?, 0)
128 }
129
130 pub fn imag(&self) -> Result<Tensor<f32>> {
151 select_complex_component(self.view_as_real()?, 1)
152 }
153}
154
155impl Tensor<Complex64> {
156 pub fn view_as_real(&self) -> Result<Tensor<f64>> {
176 self.wait();
177 if self.is_conjugated() {
178 return Err(Error::InvalidArgument(
179 "view_as_real requires a resolved complex tensor".into(),
180 ));
181 }
182
183 let mut dims = self.dims().to_vec();
184 dims.push(2);
185 let strides = view_as_real_strides(self.strides())?;
186 let offset = self.offset().checked_mul(2).ok_or_else(|| {
187 Error::StrideError("view_as_real: offset overflow when reinterpreting tensor".into())
188 })?;
189 let buffer_len =
190 self.buffer().len().checked_mul(2).ok_or_else(|| {
191 Error::StrideError("view_as_real: storage length overflow".into())
192 })?;
193 let buffer = self.buffer().reinterpret_as::<f64>(buffer_len)?;
194 validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
195 Ok(Tensor::from_parts(TensorParts {
196 buffer,
197 dims: Arc::from(dims),
198 strides: Arc::from(strides),
199 offset,
200 logical_memory_space: self.logical_memory_space(),
201 preferred_compute_device: self.preferred_compute_device(),
202 event: None,
203 conjugated: false,
204 fw_grad: None,
205 }))
206 }
207
208 pub fn real(&self) -> Result<Tensor<f64>> {
229 select_complex_component(self.view_as_real()?, 0)
230 }
231
232 pub fn imag(&self) -> Result<Tensor<f64>> {
253 select_complex_component(self.view_as_real()?, 1)
254 }
255}
256
257impl Tensor<f32> {
258 pub fn view_as_complex(&self) -> Result<Tensor<Complex32>> {
271 self.wait();
272 if self.ndim() == 0 {
273 return Err(Error::InvalidArgument(
274 "view_as_complex requires at least one dimension".into(),
275 ));
276 }
277 if self.dims().last().copied() != Some(2) {
278 return Err(Error::InvalidArgument(
279 "view_as_complex requires the last dimension to have size 2".into(),
280 ));
281 }
282 if self.strides().last().copied() != Some(1) {
283 return Err(Error::InvalidArgument(
284 "view_as_complex requires the last stride to be 1".into(),
285 ));
286 }
287 if self.offset() % 2 != 0 {
288 return Err(Error::InvalidArgument(
289 "view_as_complex requires an even element offset".into(),
290 ));
291 }
292
293 let dims = self.dims()[..self.ndim() - 1].to_vec();
294 let strides = view_as_complex_strides(self.dims(), self.strides())?;
295 let offset = self.offset() / 2;
296 let source_len = self.buffer().len() / 2;
297 let buffer = self.buffer().reinterpret_as::<Complex32>(source_len)?;
298 validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
299 Ok(Tensor::from_parts(TensorParts {
300 buffer,
301 dims: Arc::from(dims),
302 strides: Arc::from(strides),
303 offset,
304 logical_memory_space: self.logical_memory_space(),
305 preferred_compute_device: self.preferred_compute_device(),
306 event: None,
307 conjugated: false,
308 fw_grad: None,
309 }))
310 }
311}
312
313impl Tensor<f64> {
314 pub fn view_as_complex(&self) -> Result<Tensor<Complex64>> {
327 self.wait();
328 if self.ndim() == 0 {
329 return Err(Error::InvalidArgument(
330 "view_as_complex requires at least one dimension".into(),
331 ));
332 }
333 if self.dims().last().copied() != Some(2) {
334 return Err(Error::InvalidArgument(
335 "view_as_complex requires the last dimension to have size 2".into(),
336 ));
337 }
338 if self.strides().last().copied() != Some(1) {
339 return Err(Error::InvalidArgument(
340 "view_as_complex requires the last stride to be 1".into(),
341 ));
342 }
343 if self.offset() % 2 != 0 {
344 return Err(Error::InvalidArgument(
345 "view_as_complex requires an even element offset".into(),
346 ));
347 }
348
349 let dims = self.dims()[..self.ndim() - 1].to_vec();
350 let strides = view_as_complex_strides(self.dims(), self.strides())?;
351 let offset = self.offset() / 2;
352 let source_len = self.buffer().len() / 2;
353 let buffer = self.buffer().reinterpret_as::<Complex64>(source_len)?;
354 validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
355 Ok(Tensor::from_parts(TensorParts {
356 buffer,
357 dims: Arc::from(dims),
358 strides: Arc::from(strides),
359 offset,
360 logical_memory_space: self.logical_memory_space(),
361 preferred_compute_device: self.preferred_compute_device(),
362 event: None,
363 conjugated: false,
364 fw_grad: None,
365 }))
366 }
367}