// tenferro_tensor/tensor/views/complex.rs

1use super::*;
2use crate::tensor::TensorParts;
3
4fn view_as_real_strides(strides: &[isize]) -> Result<Vec<isize>> {
5    let mut output = Vec::with_capacity(strides.len() + 1);
6    for (axis, &stride) in strides.iter().enumerate() {
7        output.push(stride.checked_mul(2).ok_or_else(|| {
8            Error::StrideError(format!(
9                "view_as_real: stride overflow for axis {axis} with stride {stride}"
10            ))
11        })?);
12    }
13    output.push(1);
14    Ok(output)
15}
16
17fn view_as_complex_strides(dims: &[usize], strides: &[isize]) -> Result<Vec<isize>> {
18    if dims.is_empty() || strides.is_empty() {
19        return Err(Error::InvalidArgument(
20            "view_as_complex requires at least one dimension".into(),
21        ));
22    }
23    if dims.len() != strides.len() {
24        return Err(Error::InvalidArgument(format!(
25            "view_as_complex rank mismatch: dims={} strides={}",
26            dims.len(),
27            strides.len()
28        )));
29    }
30
31    let mut output = Vec::with_capacity(strides.len().saturating_sub(1));
32    for (axis, (&dim, &stride)) in dims[..dims.len() - 1]
33        .iter()
34        .zip(&strides[..strides.len() - 1])
35        .enumerate()
36    {
37        if dim != 1 && stride % 2 != 0 {
38            return Err(Error::InvalidArgument(format!(
39                "view_as_complex requires even strides on all leading dimensions, got stride {stride} at axis {axis}"
40            )));
41        }
42        output.push(stride / 2);
43    }
44    Ok(output)
45}
46
47fn select_complex_component<T>(view: Tensor<T>, component: usize) -> Result<Tensor<T>> {
48    view.select(view.ndim() - 1, component)
49}
50
51impl Tensor<Complex32> {
52    /// Return a zero-copy real view of a complex tensor.
53    ///
54    /// The last logical axis is expanded to length `2`, exposing the real and
55    /// imaginary parts as adjacent real-valued elements.
56    ///
57    /// # Examples
58    ///
59    /// ```ignore
60    /// use num_complex::Complex32;
61    /// use tenferro_device::LogicalMemorySpace;
62    /// use tenferro_tensor::{MemoryOrder, Tensor};
63    ///
64    /// let z = Tensor::<Complex32>::from_slice(
65    ///     &[Complex32::new(1.0, 2.0)],
66    ///     &[1],
67    ///     MemoryOrder::ColumnMajor,
68    /// ).unwrap();
69    /// let r = z.view_as_real().unwrap();
70    /// assert_eq!(r.dims(), &[1, 2]);
71    /// assert_eq!(r.logical_memory_space(), LogicalMemorySpace::MainMemory);
72    /// ```
73    pub fn view_as_real(&self) -> Result<Tensor<f32>> {
74        self.wait();
75        if self.is_conjugated() {
76            return Err(Error::InvalidArgument(
77                "view_as_real requires a resolved complex tensor".into(),
78            ));
79        }
80
81        let mut dims = self.dims().to_vec();
82        dims.push(2);
83        let strides = view_as_real_strides(self.strides())?;
84        let offset = self.offset().checked_mul(2).ok_or_else(|| {
85            Error::StrideError("view_as_real: offset overflow when reinterpreting tensor".into())
86        })?;
87        let buffer_len =
88            self.buffer().len().checked_mul(2).ok_or_else(|| {
89                Error::StrideError("view_as_real: storage length overflow".into())
90            })?;
91        let buffer = self.buffer().reinterpret_as::<f32>(buffer_len)?;
92        validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
93        Ok(Tensor::from_parts(TensorParts {
94            buffer,
95            dims: Arc::from(dims),
96            strides: Arc::from(strides),
97            offset,
98            logical_memory_space: self.logical_memory_space(),
99            preferred_compute_device: self.preferred_compute_device(),
100            event: None,
101            conjugated: false,
102            fw_grad: None,
103        }))
104    }
105
106    /// Return a zero-copy view of the real part of a resolved complex tensor.
107    ///
108    /// This is implemented as `view_as_real()` followed by selecting the real
109    /// lane of the trailing size-2 axis.
110    ///
111    /// # Examples
112    ///
113    /// ```ignore
114    /// use num_complex::Complex32;
115    /// use tenferro_tensor::{MemoryOrder, Tensor};
116    ///
117    /// let z = Tensor::<Complex32>::from_slice(
118    ///     &[Complex32::new(1.0, 2.0)],
119    ///     &[1],
120    ///     MemoryOrder::ColumnMajor,
121    /// )
122    /// .unwrap();
123    /// let real = z.real().unwrap();
124    /// assert_eq!(real.dims(), &[1]);
125    /// ```
126    pub fn real(&self) -> Result<Tensor<f32>> {
127        select_complex_component(self.view_as_real()?, 0)
128    }
129
130    /// Return a zero-copy view of the imaginary part of a resolved complex tensor.
131    ///
132    /// This is implemented as `view_as_real()` followed by selecting the
133    /// imaginary lane of the trailing size-2 axis.
134    ///
135    /// # Examples
136    ///
137    /// ```ignore
138    /// use num_complex::Complex32;
139    /// use tenferro_tensor::{MemoryOrder, Tensor};
140    ///
141    /// let z = Tensor::<Complex32>::from_slice(
142    ///     &[Complex32::new(1.0, 2.0)],
143    ///     &[1],
144    ///     MemoryOrder::ColumnMajor,
145    /// )
146    /// .unwrap();
147    /// let imag = z.imag().unwrap();
148    /// assert_eq!(imag.dims(), &[1]);
149    /// ```
150    pub fn imag(&self) -> Result<Tensor<f32>> {
151        select_complex_component(self.view_as_real()?, 1)
152    }
153}
154
155impl Tensor<Complex64> {
156    /// Return a zero-copy real view of a complex tensor.
157    ///
158    /// The last logical axis is expanded to length `2`, exposing the real and
159    /// imaginary parts as adjacent real-valued elements.
160    ///
161    /// # Examples
162    ///
163    /// ```ignore
164    /// use num_complex::Complex64;
165    /// use tenferro_tensor::{MemoryOrder, Tensor};
166    ///
167    /// let z = Tensor::<Complex64>::from_slice(
168    ///     &[Complex64::new(1.0, 2.0)],
169    ///     &[1],
170    ///     MemoryOrder::ColumnMajor,
171    /// ).unwrap();
172    /// let r = z.view_as_real().unwrap();
173    /// assert_eq!(r.dims(), &[1, 2]);
174    /// ```
175    pub fn view_as_real(&self) -> Result<Tensor<f64>> {
176        self.wait();
177        if self.is_conjugated() {
178            return Err(Error::InvalidArgument(
179                "view_as_real requires a resolved complex tensor".into(),
180            ));
181        }
182
183        let mut dims = self.dims().to_vec();
184        dims.push(2);
185        let strides = view_as_real_strides(self.strides())?;
186        let offset = self.offset().checked_mul(2).ok_or_else(|| {
187            Error::StrideError("view_as_real: offset overflow when reinterpreting tensor".into())
188        })?;
189        let buffer_len =
190            self.buffer().len().checked_mul(2).ok_or_else(|| {
191                Error::StrideError("view_as_real: storage length overflow".into())
192            })?;
193        let buffer = self.buffer().reinterpret_as::<f64>(buffer_len)?;
194        validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
195        Ok(Tensor::from_parts(TensorParts {
196            buffer,
197            dims: Arc::from(dims),
198            strides: Arc::from(strides),
199            offset,
200            logical_memory_space: self.logical_memory_space(),
201            preferred_compute_device: self.preferred_compute_device(),
202            event: None,
203            conjugated: false,
204            fw_grad: None,
205        }))
206    }
207
208    /// Return a zero-copy view of the real part of a resolved complex tensor.
209    ///
210    /// This is implemented as `view_as_real()` followed by selecting the real
211    /// lane of the trailing size-2 axis.
212    ///
213    /// # Examples
214    ///
215    /// ```ignore
216    /// use num_complex::Complex64;
217    /// use tenferro_tensor::{MemoryOrder, Tensor};
218    ///
219    /// let z = Tensor::<Complex64>::from_slice(
220    ///     &[Complex64::new(1.0, 2.0)],
221    ///     &[1],
222    ///     MemoryOrder::ColumnMajor,
223    /// )
224    /// .unwrap();
225    /// let real = z.real().unwrap();
226    /// assert_eq!(real.dims(), &[1]);
227    /// ```
228    pub fn real(&self) -> Result<Tensor<f64>> {
229        select_complex_component(self.view_as_real()?, 0)
230    }
231
232    /// Return a zero-copy view of the imaginary part of a resolved complex tensor.
233    ///
234    /// This is implemented as `view_as_real()` followed by selecting the
235    /// imaginary lane of the trailing size-2 axis.
236    ///
237    /// # Examples
238    ///
239    /// ```ignore
240    /// use num_complex::Complex64;
241    /// use tenferro_tensor::{MemoryOrder, Tensor};
242    ///
243    /// let z = Tensor::<Complex64>::from_slice(
244    ///     &[Complex64::new(1.0, 2.0)],
245    ///     &[1],
246    ///     MemoryOrder::ColumnMajor,
247    /// )
248    /// .unwrap();
249    /// let imag = z.imag().unwrap();
250    /// assert_eq!(imag.dims(), &[1]);
251    /// ```
252    pub fn imag(&self) -> Result<Tensor<f64>> {
253        select_complex_component(self.view_as_real()?, 1)
254    }
255}
256
257impl Tensor<f32> {
258    /// Return a zero-copy complex view of a real tensor whose last dimension
259    /// stores paired real and imaginary components.
260    ///
261    /// # Examples
262    ///
263    /// ```ignore
264    /// use tenferro_tensor::{MemoryOrder, Tensor};
265    ///
266    /// let r = Tensor::<f32>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
267    /// let z = r.view_as_complex().unwrap();
268    /// assert_eq!(z.dims(), &[] as &[usize]);
269    /// ```
270    pub fn view_as_complex(&self) -> Result<Tensor<Complex32>> {
271        self.wait();
272        if self.ndim() == 0 {
273            return Err(Error::InvalidArgument(
274                "view_as_complex requires at least one dimension".into(),
275            ));
276        }
277        if self.dims().last().copied() != Some(2) {
278            return Err(Error::InvalidArgument(
279                "view_as_complex requires the last dimension to have size 2".into(),
280            ));
281        }
282        if self.strides().last().copied() != Some(1) {
283            return Err(Error::InvalidArgument(
284                "view_as_complex requires the last stride to be 1".into(),
285            ));
286        }
287        if self.offset() % 2 != 0 {
288            return Err(Error::InvalidArgument(
289                "view_as_complex requires an even element offset".into(),
290            ));
291        }
292
293        let dims = self.dims()[..self.ndim() - 1].to_vec();
294        let strides = view_as_complex_strides(self.dims(), self.strides())?;
295        let offset = self.offset() / 2;
296        let source_len = self.buffer().len() / 2;
297        let buffer = self.buffer().reinterpret_as::<Complex32>(source_len)?;
298        validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
299        Ok(Tensor::from_parts(TensorParts {
300            buffer,
301            dims: Arc::from(dims),
302            strides: Arc::from(strides),
303            offset,
304            logical_memory_space: self.logical_memory_space(),
305            preferred_compute_device: self.preferred_compute_device(),
306            event: None,
307            conjugated: false,
308            fw_grad: None,
309        }))
310    }
311}
312
313impl Tensor<f64> {
314    /// Return a zero-copy complex view of a real tensor whose last dimension
315    /// stores paired real and imaginary components.
316    ///
317    /// # Examples
318    ///
319    /// ```ignore
320    /// use tenferro_tensor::{MemoryOrder, Tensor};
321    ///
322    /// let r = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
323    /// let z = r.view_as_complex().unwrap();
324    /// assert_eq!(z.dims(), &[] as &[usize]);
325    /// ```
326    pub fn view_as_complex(&self) -> Result<Tensor<Complex64>> {
327        self.wait();
328        if self.ndim() == 0 {
329            return Err(Error::InvalidArgument(
330                "view_as_complex requires at least one dimension".into(),
331            ));
332        }
333        if self.dims().last().copied() != Some(2) {
334            return Err(Error::InvalidArgument(
335                "view_as_complex requires the last dimension to have size 2".into(),
336            ));
337        }
338        if self.strides().last().copied() != Some(1) {
339            return Err(Error::InvalidArgument(
340                "view_as_complex requires the last stride to be 1".into(),
341            ));
342        }
343        if self.offset() % 2 != 0 {
344            return Err(Error::InvalidArgument(
345                "view_as_complex requires an even element offset".into(),
346            ));
347        }
348
349        let dims = self.dims()[..self.ndim() - 1].to_vec();
350        let strides = view_as_complex_strides(self.dims(), self.strides())?;
351        let offset = self.offset() / 2;
352        let source_len = self.buffer().len() / 2;
353        let buffer = self.buffer().reinterpret_as::<Complex64>(source_len)?;
354        validate_layout_against_len(&dims, &strides, offset, buffer.len())?;
355        Ok(Tensor::from_parts(TensorParts {
356            buffer,
357            dims: Arc::from(dims),
358            strides: Arc::from(strides),
359            offset,
360            logical_memory_space: self.logical_memory_space(),
361            preferred_compute_device: self.preferred_compute_device(),
362            event: None,
363            conjugated: false,
364            fw_grad: None,
365        }))
366    }
367}