tenferro_internal_frontend_core/
dyn_tensor.rs1use num_complex::{Complex32, Complex64};
2use tenferro_internal_error::{Error, Result};
3use tenferro_tensor::{MemoryOrder, Tensor};
4
5use crate::tensor_ops::{tensor_map_binary_typed, tensor_map_unary_typed, tensor_max_typed};
6use crate::{ScalarType, StructuredTensor};
7
8#[derive(Clone, Debug)]
25pub enum DynTensor {
26 F32(StructuredTensor<f32>),
27 F64(StructuredTensor<f64>),
28 C32(StructuredTensor<Complex32>),
29 C64(StructuredTensor<Complex64>),
30}
31
32#[doc(hidden)]
33pub trait DynTensorTyped: tenferro_algebra::Scalar + 'static {
34 fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>>;
35 fn into_dyn(value: StructuredTensor<Self>) -> DynTensor;
36}
37
38impl DynTensorTyped for f32 {
39 fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
40 value.as_f32()
41 }
42
43 fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
44 DynTensor::F32(value)
45 }
46}
47
48impl DynTensorTyped for f64 {
49 fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
50 value.as_f64()
51 }
52
53 fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
54 DynTensor::F64(value)
55 }
56}
57
58impl DynTensorTyped for Complex32 {
59 fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
60 value.as_c32()
61 }
62
63 fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
64 DynTensor::C32(value)
65 }
66}
67
68impl DynTensorTyped for Complex64 {
69 fn structured_ref(value: &DynTensor) -> Option<&StructuredTensor<Self>> {
70 value.as_c64()
71 }
72
73 fn into_dyn(value: StructuredTensor<Self>) -> DynTensor {
74 DynTensor::C64(value)
75 }
76}
77
78impl DynTensor {
79 pub fn scalar_type(&self) -> ScalarType {
92 match self {
93 Self::F32(_) => ScalarType::F32,
94 Self::F64(_) => ScalarType::F64,
95 Self::C32(_) => ScalarType::C32,
96 Self::C64(_) => ScalarType::C64,
97 }
98 }
99
100 pub fn dims(&self) -> &[usize] {
102 match self {
103 Self::F32(t) => t.logical_dims(),
104 Self::F64(t) => t.logical_dims(),
105 Self::C32(t) => t.logical_dims(),
106 Self::C64(t) => t.logical_dims(),
107 }
108 }
109
110 pub fn axis_classes(&self) -> &[usize] {
112 match self {
113 Self::F32(t) => t.axis_classes(),
114 Self::F64(t) => t.axis_classes(),
115 Self::C32(t) => t.axis_classes(),
116 Self::C64(t) => t.axis_classes(),
117 }
118 }
119
120 pub fn is_dense(&self) -> bool {
122 match self {
123 Self::F32(t) => t.is_dense(),
124 Self::F64(t) => t.is_dense(),
125 Self::C32(t) => t.is_dense(),
126 Self::C64(t) => t.is_dense(),
127 }
128 }
129
130 pub fn is_diag(&self) -> bool {
132 match self {
133 Self::F32(t) => t.is_diag(),
134 Self::F64(t) => t.is_diag(),
135 Self::C32(t) => t.is_diag(),
136 Self::C64(t) => t.is_diag(),
137 }
138 }
139
140 pub fn to_dense(&self) -> Result<Self> {
142 match self {
143 Self::F32(t) => Ok(Self::F32(StructuredTensor(
144 tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
145 ))),
146 Self::F64(t) => Ok(Self::F64(StructuredTensor(
147 tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
148 ))),
149 Self::C32(t) => Ok(Self::C32(StructuredTensor(
150 tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
151 ))),
152 Self::C64(t) => Ok(Self::C64(StructuredTensor(
153 tenferro_tensor::StructuredTensor::from_dense(t.to_dense()?),
154 ))),
155 }
156 }
157
158 pub fn ndim(&self) -> usize {
159 self.dims().len()
160 }
161
162 pub fn len(&self) -> usize {
163 self.dims().iter().product()
164 }
165
166 pub fn is_empty(&self) -> bool {
167 self.len() == 0
168 }
169
170 pub fn as_f32(&self) -> Option<&StructuredTensor<f32>> {
171 if let Self::F32(t) = self {
172 Some(t)
173 } else {
174 None
175 }
176 }
177
178 pub fn as_f64(&self) -> Option<&StructuredTensor<f64>> {
179 if let Self::F64(t) = self {
180 Some(t)
181 } else {
182 None
183 }
184 }
185
186 pub fn as_c32(&self) -> Option<&StructuredTensor<Complex32>> {
187 if let Self::C32(t) = self {
188 Some(t)
189 } else {
190 None
191 }
192 }
193
194 pub fn as_c64(&self) -> Option<&StructuredTensor<Complex64>> {
195 if let Self::C64(t) = self {
196 Some(t)
197 } else {
198 None
199 }
200 }
201
202 pub fn payload_f32(&self) -> Option<&Tensor<f32>> {
203 self.as_f32().map(|tensor| tensor.payload())
204 }
205
206 pub fn payload_f64(&self) -> Option<&Tensor<f64>> {
207 self.as_f64().map(|tensor| tensor.payload())
208 }
209
210 pub fn payload_c32(&self) -> Option<&Tensor<Complex32>> {
211 self.as_c32().map(|tensor| tensor.payload())
212 }
213
214 pub fn payload_c64(&self) -> Option<&Tensor<Complex64>> {
215 self.as_c64().map(|tensor| tensor.payload())
216 }
217
218 #[doc(hidden)]
219 pub fn typed_ref<T>(&self) -> Option<&StructuredTensor<T>>
220 where
221 T: DynTensorTyped,
222 {
223 T::structured_ref(self)
224 }
225
226 pub fn try_sub(&self, rhs: &Self) -> Result<Self> {
227 match (self, rhs) {
228 (Self::F32(a), Self::F32(b)) => {
229 ensure_same_layout("try_sub", a, b)?;
230 Ok(Self::F32(StructuredTensor(a.0.with_payload_like(
231 tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
232 )?)))
233 }
234 (Self::F64(a), Self::F64(b)) => {
235 ensure_same_layout("try_sub", a, b)?;
236 Ok(Self::F64(StructuredTensor(a.0.with_payload_like(
237 tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
238 )?)))
239 }
240 (Self::C32(a), Self::C32(b)) => {
241 ensure_same_layout("try_sub", a, b)?;
242 Ok(Self::C32(StructuredTensor(a.0.with_payload_like(
243 tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
244 )?)))
245 }
246 (Self::C64(a), Self::C64(b)) => {
247 ensure_same_layout("try_sub", a, b)?;
248 Ok(Self::C64(StructuredTensor(a.0.with_payload_like(
249 tensor_map_binary_typed(a.payload(), b.payload(), |x, y| x - y)?,
250 )?)))
251 }
252 _ => Err(Error::InvalidTensorOperands {
253 message: format!(
254 "dtype mismatch in try_sub: lhs={:?}, rhs={:?}",
255 self.scalar_type(),
256 rhs.scalar_type()
257 ),
258 }),
259 }
260 }
261
262 pub fn abs_tensor(&self) -> Result<Self> {
263 match self {
264 Self::F32(a) => Ok(Self::F32(StructuredTensor(a.0.with_payload_like(
265 tensor_map_unary_typed(a.payload(), |x: f32| x.abs())?,
266 )?))),
267 Self::F64(a) => Ok(Self::F64(StructuredTensor(a.0.with_payload_like(
268 tensor_map_unary_typed(a.payload(), |x: f64| x.abs())?,
269 )?))),
270 Self::C32(a) => Ok(Self::F32(StructuredTensor(
271 tenferro_tensor::StructuredTensor::new(
272 a.logical_dims().to_vec(),
273 a.axis_classes().to_vec(),
274 tensor_map_unary_typed(a.payload(), |z: Complex32| z.norm())?,
275 )?,
276 ))),
277 Self::C64(a) => Ok(Self::F64(StructuredTensor(
278 tenferro_tensor::StructuredTensor::new(
279 a.logical_dims().to_vec(),
280 a.axis_classes().to_vec(),
281 tensor_map_unary_typed(a.payload(), |z: Complex64| z.norm())?,
282 )?,
283 ))),
284 }
285 }
286
287 pub fn max(&self) -> Result<Self> {
288 match self {
289 Self::F32(t) => Ok(Self::F32(StructuredTensor(
290 tenferro_tensor::StructuredTensor::from_dense(Tensor::from_slice(
291 &[tensor_max_typed(t.payload())?],
292 &[],
293 MemoryOrder::ColumnMajor,
294 )?),
295 ))),
296 Self::F64(t) => Ok(Self::F64(StructuredTensor(
297 tenferro_tensor::StructuredTensor::from_dense(Tensor::from_slice(
298 &[tensor_max_typed(t.payload())?],
299 &[],
300 MemoryOrder::ColumnMajor,
301 )?),
302 ))),
303 Self::C32(_) | Self::C64(_) => Err(Error::InvalidTensorOperands {
304 message: "max is undefined for complex tensors; call abs_tensor() first"
305 .to_string(),
306 }),
307 }
308 }
309
310 pub fn max_as_f64(&self) -> Result<f64> {
311 match self.max()? {
312 Self::F32(t) => Ok(t.payload().buffer().as_slice().unwrap()[0] as f64),
313 Self::F64(t) => Ok(t.payload().buffer().as_slice().unwrap()[0]),
314 Self::C32(_) | Self::C64(_) => Err(Error::InvalidTensorOperands {
315 message: "max_as_f64 expects a real tensor".to_string(),
316 }),
317 }
318 }
319
320 pub fn max_abs_diff(&self, rhs: &Self) -> Result<f64> {
321 self.try_sub(rhs)?.abs_tensor()?.max_as_f64()
322 }
323}
324
325fn ensure_same_layout<T>(
326 op_name: &'static str,
327 lhs: &StructuredTensor<T>,
328 rhs: &StructuredTensor<T>,
329) -> Result<()>
330where
331 T: tenferro_algebra::Scalar,
332{
333 if lhs.logical_dims() != rhs.logical_dims() {
334 return Err(Error::InvalidTensorOperands {
335 message: format!(
336 "{op_name} requires matching logical_dims, got lhs={:?}, rhs={:?}",
337 lhs.logical_dims(),
338 rhs.logical_dims()
339 ),
340 });
341 }
342 if lhs.axis_classes() != rhs.axis_classes() {
343 return Err(Error::InvalidTensorOperands {
344 message: format!(
345 "{op_name} requires matching axis_classes, got lhs={:?}, rhs={:?}",
346 lhs.axis_classes(),
347 rhs.axis_classes()
348 ),
349 });
350 }
351 Ok(())
352}
353
354macro_rules! impl_dyn_tensor_from {
355 ($variant:ident, $ty:ty) => {
356 impl From<Tensor<$ty>> for DynTensor {
357 fn from(value: Tensor<$ty>) -> Self {
358 Self::$variant(StructuredTensor(
359 tenferro_tensor::StructuredTensor::from_dense(value),
360 ))
361 }
362 }
363
364 impl From<StructuredTensor<$ty>> for DynTensor {
365 fn from(value: StructuredTensor<$ty>) -> Self {
366 Self::$variant(value)
367 }
368 }
369 };
370}
371
372impl_dyn_tensor_from!(F32, f32);
373impl_dyn_tensor_from!(F64, f64);
374impl_dyn_tensor_from!(C32, Complex32);
375impl_dyn_tensor_from!(C64, Complex64);