Skip to main content

tensor4all_tensorbackend/
tensor_element.rs

1use anyhow::{anyhow, ensure, Result};
2use num_complex::{Complex32, Complex64};
3use tenferro::{DType, Tensor as NativeTensor, TensorScalar};
4
/// Public scalar element types supported by tensor4all dense/diag constructors.
///
/// Implemented for `f32`, `f64`, `Complex32`, and `Complex64`.
///
/// All buffers handled by this trait are flat slices in column-major order.
///
/// # Examples
///
/// ```
/// use tensor4all_tensorbackend::TensorElement;
///
/// let t = f64::dense_native_tensor_from_col_major(&[1.0, 2.0], &[2]).unwrap();
/// assert_eq!(t.shape(), &[2]);
///
/// let vals = f64::dense_values_from_native_col_major(&t).unwrap();
/// assert_eq!(vals, vec![1.0, 2.0]);
/// ```
pub trait TensorElement: TensorScalar + Copy + Send + Sync + 'static {
    /// Build a dense native tensor from column-major data.
    ///
    /// `dims` gives the length of each axis.
    ///
    /// # Errors
    ///
    /// The implementations in this file return an error when `data.len()`
    /// differs from the product of `dims`.
    fn dense_native_tensor_from_col_major(data: &[Self], dims: &[usize]) -> Result<NativeTensor>;

    /// Build a diagonal native tensor from column-major diagonal payload data.
    ///
    /// The implementations in this file materialize a dense tensor with
    /// `logical_rank` axes, each of length `data.len()`, where `data[i]` sits
    /// at index `(i, i, ..., i)` and every other element is the type's
    /// default value.
    ///
    /// # Errors
    ///
    /// The implementations in this file return an error when `logical_rank`
    /// is zero.
    fn diag_native_tensor_from_col_major(
        data: &[Self],
        logical_rank: usize,
    ) -> Result<NativeTensor>;

    /// Build a rank-0 native tensor holding `value`.
    fn scalar_native_tensor(value: Self) -> Result<NativeTensor>;

    /// Materialize dense column-major values from a native tensor.
    ///
    /// # Errors
    ///
    /// The implementations in this file return a dtype-mismatch error when
    /// the tensor cannot be viewed as a slice of `Self`.
    fn dense_values_from_native_col_major(tensor: &NativeTensor) -> Result<Vec<Self>>;

    /// Materialize diagonal values from a dense native tensor.
    ///
    /// The implementations in this file return `shape[0]` values, reading the
    /// element at index `(i, i, ..., i)` for each `i`.
    ///
    /// # Errors
    ///
    /// The implementations in this file return an error when the tensor has
    /// rank 0, when its axes do not all have the same length, or when its
    /// values cannot be read as `Self`.
    fn diag_values_from_native_temp(tensor: &NativeTensor) -> Result<Vec<Self>>;
}
39
40fn tensor_dtype_name(dtype: DType) -> &'static str {
41    match dtype {
42        DType::F32 => "f32",
43        DType::F64 => "f64",
44        DType::I64 => "i64",
45        DType::C32 => "c32",
46        DType::C64 => "c64",
47    }
48}
49
50fn dense_diagonal_values<T: Copy + Default>(diag: &[T], logical_rank: usize) -> Result<Vec<T>> {
51    ensure!(
52        logical_rank >= 1,
53        "diagonal tensor construction requires at least one logical axis"
54    );
55    let diag_len = diag.len();
56    let dims = vec![diag_len; logical_rank];
57    let total_len = dims.iter().product::<usize>();
58    let mut dense = vec![T::default(); total_len];
59    let diagonal_stride = (0..logical_rank)
60        .scan(1usize, |stride, _| {
61            let current = *stride;
62            *stride = stride.saturating_mul(diag_len);
63            Some(current)
64        })
65        .sum::<usize>();
66    for (i, value) in diag.iter().copied().enumerate() {
67        dense[i * diagonal_stride] = value;
68    }
69    Ok(dense)
70}
71
/// Implements [`TensorElement`] for a concrete scalar type.
///
/// `$ty` is the Rust element type; `$dtype` is the `DType` value used to name
/// the expected dtype in mismatch error messages.
macro_rules! impl_tensor_element {
    ($ty:ty, $dtype:expr) => {
        impl TensorElement for $ty {
            /// Copies `data` into a new native tensor of shape `dims`.
            ///
            /// Fails when `data.len()` is not the product of `dims`, or when
            /// that product does not fit in `usize`.
            fn dense_native_tensor_from_col_major(
                data: &[Self],
                dims: &[usize],
            ) -> Result<NativeTensor> {
                // A zero-length axis makes the tensor empty regardless of the
                // other dims; otherwise multiply with overflow checks so a
                // wrapped product can never accidentally match `data.len()`.
                let expected_len = if dims.contains(&0) {
                    0
                } else {
                    dims.iter()
                        .try_fold(1usize, |len, &dim| len.checked_mul(dim))
                        .ok_or_else(|| {
                            anyhow!("dense tensor dims {:?} overflow usize element count", dims)
                        })?
                };
                ensure!(
                    data.len() == expected_len,
                    "dense tensor len {} does not match dims {:?} (expected {})",
                    data.len(),
                    dims,
                    expected_len
                );
                Ok(NativeTensor::from_vec(dims.to_vec(), data.to_vec()))
            }

            /// Expands the diagonal payload `data` into a dense tensor with
            /// `logical_rank` axes of length `data.len()`.
            fn diag_native_tensor_from_col_major(
                data: &[Self],
                logical_rank: usize,
            ) -> Result<NativeTensor> {
                // Validate and expand first; `dims` is only needed on success.
                let dense = dense_diagonal_values(data, logical_rank)?;
                let dims = vec![data.len(); logical_rank];
                Self::dense_native_tensor_from_col_major(&dense, &dims)
            }

            /// Wraps `value` in a tensor with an empty shape.
            fn scalar_native_tensor(value: Self) -> Result<NativeTensor> {
                Ok(NativeTensor::from_vec(vec![], vec![value]))
            }

            /// Copies the tensor's values out as a `Vec<Self>`.
            ///
            /// When the tensor cannot be viewed as a slice of `Self`, the
            /// failure is reported as a dtype mismatch.
            fn dense_values_from_native_col_major(tensor: &NativeTensor) -> Result<Vec<Self>> {
                tensor
                    .as_slice::<Self>()
                    .map(|values| values.to_vec())
                    .ok_or_else(|| {
                        anyhow!(
                            "tensor dtype mismatch: expected {}, got {}",
                            tensor_dtype_name($dtype),
                            tensor_dtype_name(tensor.dtype())
                        )
                    })
            }

            /// Reads the elements at indices `(i, i, ..., i)` of a tensor
            /// whose axes all have the same length.
            fn diag_values_from_native_temp(tensor: &NativeTensor) -> Result<Vec<Self>> {
                let shape = tensor.shape();
                ensure!(
                    !shape.is_empty(),
                    "diagonal extraction requires rank >= 1, got scalar tensor"
                );
                let diag_len = shape[0];
                ensure!(
                    shape.iter().all(|&dim| dim == diag_len),
                    "expected square/equal dims for diagonal extraction, got {:?}",
                    shape
                );
                let dense = Self::dense_values_from_native_col_major(tensor)?;
                // Column-major stride of axis `k` is `diag_len^k`; stepping
                // every index by one moves by the sum of all axis strides.
                // Saturating arithmetic keeps this panic-free; an oversized
                // stride is rejected by the checked lookup below.
                let diagonal_stride = (0..shape.len())
                    .scan(1usize, |stride, _| {
                        let current = *stride;
                        *stride = stride.saturating_mul(diag_len);
                        Some(current)
                    })
                    .fold(0usize, usize::saturating_add);
                // Checked offsets turn a storage/shape disagreement into an
                // error instead of an out-of-bounds panic.
                (0..diag_len)
                    .map(|i| {
                        i.checked_mul(diagonal_stride)
                            .and_then(|offset| dense.get(offset))
                            .copied()
                            .ok_or_else(|| {
                                anyhow!(
                                    "tensor storage of {} elements is too short for diagonal of shape {:?}",
                                    dense.len(),
                                    shape
                                )
                            })
                    })
                    .collect()
            }
        }
    };
}
141
// Instantiate `TensorElement` for each supported scalar type, pairing it with
// the `DType` used to name the expected dtype in mismatch errors.
impl_tensor_element!(f32, DType::F32);
impl_tensor_element!(f64, DType::F64);
impl_tensor_element!(Complex32, DType::C32);
impl_tensor_element!(Complex64, DType::C64);