tenferro_internal_frontend_core/dyn_tensor.rs

1use num_complex::{Complex32, Complex64};
2use tenferro_internal_error::{Error, Result};
3use tenferro_tensor::{MemoryOrder, Tensor};
4
5use crate::tensor_ops::{tensor_map_binary_typed, tensor_map_unary_typed, tensor_max_typed};
6use crate::{ScalarType, StructuredTensor};
7
8/// Runtime tensor wrapper for a fixed supported dtype set.
9///
10/// `DynTensor` is the canonical dynamic primal tensor type shared by tenferro
11/// frontends. Each variant carries a `StructuredTensor<T>`, so dense tensors
12/// and structured special cases such as `Diag` share the same container.
13///
14/// # Examples
15///
16/// ```rust
17/// use tenferro_internal_frontend_core::{DynTensor, ScalarType, StructuredTensor};
18/// use tenferro_tensor::{MemoryOrder, Tensor};
19///
20/// let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
21/// let x: DynTensor = StructuredTensor::from(t).into();
22/// assert_eq!(x.scalar_type(), ScalarType::F64);
23/// ```
24#[derive(Clone, Debug)]
25pub enum DynTensor {
26    F32(StructuredTensor<f32>),
27    F64(StructuredTensor<f64>),
28    C32(StructuredTensor<Complex32>),
29    C64(StructuredTensor<Complex64>),
30}
31
32#[doc(hidden)]
33pub trait DynTensorTyped: tenferro_algebra::Scalar + 'static {
34    fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>>;
35    fn into_dyn(value: StructuredTensor<Self>) -> DynTensor;
36}
37
38impl DynTensorTyped for f32 {
39    fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
40        value.as_f32()
41    }
42
43    fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
44        DynTensor::F32(value)
45    }
46}
47
48impl DynTensorTyped for f64 {
49    fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
50        value.as_f64()
51    }
52
53    fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
54        DynTensor::F64(value)
55    }
56}
57
58impl DynTensorTyped for Complex32 {
59    fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
60        value.as_c32()
61    }
62
63    fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
64        DynTensor::C32(value)
65    }
66}
67
68impl DynTensorTyped for Complex64 {
69    fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
70        value.as_c64()
71    }
72
73    fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
74        DynTensor::C64(value)
75    }
76}
77
78impl DynTensor {
79    /// Returns runtime scalar type.
80    ///
81    /// # Examples
82    ///
83    /// ```rust
84    /// use tenferro_internal_frontend_core::{DynTensor, ScalarType, StructuredTensor};
85    /// use tenferro_tensor::{MemoryOrder, Tensor};
86    ///
87    /// let t = Tensor::<f32>::from_slice(&[1.0], &[1], MemoryOrder::ColumnMajor).unwrap();
88    /// let x: DynTensor = StructuredTensor::from(t).into();
89    /// assert_eq!(x.scalar_type(), ScalarType::F32);
90    /// ```
91    pub fn scalar_type(&self) -> ScalarType {
92        match self {
93            Self::F32(_) => ScalarType::F32,
94            Self::F64(_) => ScalarType::F64,
95            Self::C32(_) => ScalarType::C32,
96            Self::C64(_) => ScalarType::C64,
97        }
98    }
99
100    /// Returns logical dimensions of the underlying tensor.
101    pub fn dims(&self) -> &[usize] {
102        match self {
103            Self::F32(t) => t.logical_dims(),
104            Self::F64(t) => t.logical_dims(),
105            Self::C32(t) => t.logical_dims(),
106            Self::C64(t) => t.logical_dims(),
107        }
108    }
109
110    /// Returns axis equivalence classes of the structured layout.
111    pub fn axis_classes(&self) -> &[usize] {
112        match self {
113            Self::F32(t) => t.axis_classes(),
114            Self::F64(t) => t.axis_classes(),
115            Self::C32(t) => t.axis_classes(),
116            Self::C64(t) => t.axis_classes(),
117        }
118    }
119
120    /// Returns `true` when the structured payload is dense.
121    pub fn is_dense(&self) -> bool {
122        match self {
123            Self::F32(t) => t.is_dense(),
124            Self::F64(t) => t.is_dense(),
125            Self::C32(t) => t.is_dense(),
126            Self::C64(t) => t.is_dense(),
127        }
128    }
129
130    /// Returns `true` when the structured payload is diagonal.
131    pub fn is_diag(&self) -> bool {
132        match self {
133            Self::F32(t) => t.is_diag(),
134            Self::F64(t) => t.is_diag(),
135            Self::C32(t) => t.is_diag(),
136            Self::C64(t) => t.is_diag(),
137        }
138    }
139
140    /// Materializes a dense snapshot with the same logical tensor values.
141    pub fn to_dense(&self) -> Result<Self> {
142        match self {
143            Self::F32(t) => Ok(Self::F32(StructuredTensor(
144                tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
145            ))),
146            Self::F64(t) => Ok(Self::F64(StructuredTensor(
147                tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
148            ))),
149            Self::C32(t) => Ok(Self::C32(StructuredTensor(
150                tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
151            ))),
152            Self::C64(t) => Ok(Self::C64(StructuredTensor(
153                tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
154            ))),
155        }
156    }
157
158    pub fn ndim(&self) -> usize {
159        self.dims().len()
160    }
161
162    pub fn len(&self) -> usize {
163        self.dims().iter().product()
164    }
165
166    pub fn is_empty(&self) -> bool {
167        self.len() == 0
168    }
169
170    pub fn as_f32(&self) -> Option<&StructuredTensor<f32>> {
171        if let Self::F32(t) = self {
172            Some(t)
173        } else {
174            None
175        }
176    }
177
178    pub fn as_f64(&self) -> Option<&StructuredTensor<f64>> {
179        if let Self::F64(t) = self {
180            Some(t)
181        } else {
182            None
183        }
184    }
185
186    pub fn as_c32(&self) -> Option<&StructuredTensor<Complex32>> {
187        if let Self::C32(t) = self {
188            Some(t)
189        } else {
190            None
191        }
192    }
193
194    pub fn as_c64(&self) -> Option<&StructuredTensor<Complex64>> {
195        if let Self::C64(t) = self {
196            Some(t)
197        } else {
198            None
199        }
200    }
201
202    pub fn payload_f32(&self) -> Option<&Tensor<f32>> {
203        self.as_f32().map(|tensor| tensor.payload())
204    }
205
206    pub fn payload_f64(&self) -> Option<&Tensor<f64>> {
207        self.as_f64().map(|tensor| tensor.payload())
208    }
209
210    pub fn payload_c32(&self) -> Option<&Tensor<Complex32>> {
211        self.as_c32().map(|tensor| tensor.payload())
212    }
213
214    pub fn payload_c64(&self) -> Option<&Tensor<Complex64>> {
215        self.as_c64().map(|tensor| tensor.payload())
216    }
217
218    #[doc(hidden)]
219    pub fn typed_ref<T>(&self) -> Option<&StructuredTensor<T>>
220    where
221        T: DynTensorTyped,
222    {
223        T::structured_ref(self)
224    }
225
226    pub fn try_sub(&self, rhs: &Self) -> Result<Self> {
227        match (self, rhs) {
228            (Self::F32(a), Self::F32(b)) => {
229                ensure_same_layout("try_sub", a, b)?;
230                Ok(Self::F32(StructuredTensor(a.0.with_payload_like(
231                    tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
232                )?)))
233            }
234            (Self::F64(a), Self::F64(b)) => {
235                ensure_same_layout("try_sub", a, b)?;
236                Ok(Self::F64(StructuredTensor(a.0.with_payload_like(
237                    tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
238                )?)))
239            }
240            (Self::C32(a), Self::C32(b)) => {
241                ensure_same_layout("try_sub", a, b)?;
242                Ok(Self::C32(StructuredTensor(a.0.with_payload_like(
243                    tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
244                )?)))
245            }
246            (Self::C64(a), Self::C64(b)) => {
247                ensure_same_layout("try_sub", a, b)?;
248                Ok(Self::C64(StructuredTensor(a.0.with_payload_like(
249                    tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
250                )?)))
251            }
252            _ => Err(Error::InvalidTensorOperands {
253                message: format!(
254                    "dtype mismatch in try_sub: lhs={:?}, rhs={:?}",
255                    self.scalar_type(),
256                    rhs.scalar_type()
257                ),
258            }),
259        }
260    }
261
262    pub fn abs_tensor(&self) -> Result<Self> {
263        match self {
264            Self::F32(a) => Ok(Self::F32(StructuredTensor(a.0.with_payload_like(
265                tensor_map_unary_typed(a.payload(), |x: f32| x.abs())?,
266            )?))),
267            Self::F64(a) => Ok(Self::F64(StructuredTensor(a.0.with_payload_like(
268                tensor_map_unary_typed(a.payload(), |x: f64| x.abs())?,
269            )?))),
270            Self::C32(a) => Ok(Self::F32(StructuredTensor(
271                tenferro_tensor::StructuredTensor::new(
272                    a.logical_dims().to_vec(),
273                    a.axis_classes().to_vec(),
274                    tensor_map_unary_typed(a.payload(), |z: Complex32| z.norm())?,
275                )?,
276            ))),
277            Self::C64(a) => Ok(Self::F64(StructuredTensor(
278                tenferro_tensor::StructuredTensor::new(
279                    a.logical_dims().to_vec(),
280                    a.axis_classes().to_vec(),
281                    tensor_map_unary_typed(a.payload(), |z: Complex64| z.norm())?,
282                )?,
283            ))),
284        }
285    }
286
287    pub fn max(&self) -> Result<Self> {
288        match self {
289            Self::F32(t) => Ok(Self::F32(StructuredTensor(
290                tenferro_tensor::StructuredTensor::from_dense(Tensor::from_slice(
291                    &[tensor_max_typed(t.payload())?],
292                    &[],
293                    MemoryOrder::ColumnMajor,
294                )?),
295            ))),
296            Self::F64(t) => Ok(Self::F64(StructuredTensor(
297                tenferro_tensor::StructuredTensor::from_dense(Tensor::from_slice(
298                    &[tensor_max_typed(t.payload())?],
299                    &[],
300                    MemoryOrder::ColumnMajor,
301                )?),
302            ))),
303            Self::C32(_) | Self::C64(_) => Err(Error::InvalidTensorOperands {
304                message: "max is undefined for complex tensors; call abs_tensor() first"
305                    .to_string(),
306            }),
307        }
308    }
309
310    pub fn max_as_f64(&self) -> Result<f64> {
311        match self.max()? {
312            Self::F32(t) => Ok(t.payload().buffer().as_slice().unwrap()[0] as f64),
313            Self::F64(t) => Ok(t.payload().buffer().as_slice().unwrap()[0]),
314            Self::C32(_) | Self::C64(_) => Err(Error::InvalidTensorOperands {
315                message: "max_as_f64 expects a real tensor".to_string(),
316            }),
317        }
318    }
319
320    pub fn max_abs_diff(&self, rhs: &Self) -> Result<f64> {
321        self.try_sub(rhs)?.abs_tensor()?.max_as_f64()
322    }
323}
324
325fn ensure_same_layout<T>(
326    op_name: &'static str,
327    lhs: &StructuredTensor<T>,
328    rhs: &StructuredTensor<T>,
329) -> Result<()>
330where
331    T: tenferro_algebra::Scalar,
332{
333    if lhs.logical_dims() != rhs.logical_dims() {
334        return Err(Error::InvalidTensorOperands {
335            message: format!(
336                "{op_name} requires matching logical_dims, got lhs={:?}, rhs={:?}",
337                lhs.logical_dims(),
338                rhs.logical_dims()
339            ),
340        });
341    }
342    if lhs.axis_classes() != rhs.axis_classes() {
343        return Err(Error::InvalidTensorOperands {
344            message: format!(
345                "{op_name} requires matching axis_classes, got lhs={:?}, rhs={:?}",
346                lhs.axis_classes(),
347                rhs.axis_classes()
348            ),
349        });
350    }
351    Ok(())
352}
353
354macro_rules! impl_dyn_tensor_from {
355    ($variant:ident, $ty:ty) => {
356        impl From<Tensor<$ty>> for DynTensor {
357            fn from(value: Tensor<$ty>) -> Self {
358                Self::$variant(StructuredTensor(
359                    tenferro_tensor::StructuredTensor::from_dense(value),
360                ))
361            }
362        }
363
364        impl From<StructuredTensor<$ty>> for DynTensor {
365            fn from(value: StructuredTensor<$ty>) -> Self {
366                Self::$variant(value)
367            }
368        }
369    };
370}
371
372impl_dyn_tensor_from!(F32, f32);
373impl_dyn_tensor_from!(F64, f64);
374impl_dyn_tensor_from!(C32, Complex32);
375impl_dyn_tensor_from!(C64, Complex64);