tensor4all_tensorbackend/
tensor_element.rs1use anyhow::{anyhow, ensure, Result};
2use num_complex::{Complex32, Complex64};
3use tenferro::{DType, Tensor as NativeTensor, TensorScalar};
4
5pub trait TensorElement: TensorScalar + Copy + Send + Sync + 'static {
21 fn dense_native_tensor_from_col_major(data: &[Self], dims: &[usize]) -> Result<NativeTensor>;
23
24 fn diag_native_tensor_from_col_major(
26 data: &[Self],
27 logical_rank: usize,
28 ) -> Result<NativeTensor>;
29
30 fn scalar_native_tensor(value: Self) -> Result<NativeTensor>;
32
33 fn dense_values_from_native_col_major(tensor: &NativeTensor) -> Result<Vec<Self>>;
35
36 fn diag_values_from_native_temp(tensor: &NativeTensor) -> Result<Vec<Self>>;
38}
39
40fn tensor_dtype_name(dtype: DType) -> &'static str {
41 match dtype {
42 DType::F32 => "f32",
43 DType::F64 => "f64",
44 DType::I64 => "i64",
45 DType::C32 => "c32",
46 DType::C64 => "c64",
47 }
48}
49
50fn dense_diagonal_values<T: Copy + Default>(diag: &[T], logical_rank: usize) -> Result<Vec<T>> {
51 ensure!(
52 logical_rank >= 1,
53 "diagonal tensor construction requires at least one logical axis"
54 );
55 let diag_len = diag.len();
56 let dims = vec![diag_len; logical_rank];
57 let total_len = dims.iter().product::<usize>();
58 let mut dense = vec![T::default(); total_len];
59 let diagonal_stride = (0..logical_rank)
60 .scan(1usize, |stride, _| {
61 let current = *stride;
62 *stride = stride.saturating_mul(diag_len);
63 Some(current)
64 })
65 .sum::<usize>();
66 for (i, value) in diag.iter().copied().enumerate() {
67 dense[i * diagonal_stride] = value;
68 }
69 Ok(dense)
70}
71
/// Implements [`TensorElement`] for the scalar type `$ty`.
///
/// `$dtype` is the tenferro [`DType`] tag expected for `$ty`; it is only used to
/// name the expected dtype in mismatch errors.
macro_rules! impl_tensor_element {
    ($ty:ty, $dtype:expr) => {
        impl TensorElement for $ty {
            fn dense_native_tensor_from_col_major(
                data: &[Self],
                dims: &[usize],
            ) -> Result<NativeTensor> {
                // Checked element count: a plain `product()` wraps in release builds, so
                // e.g. dims `[1 << 32, 1 << 32]` on a 64-bit target would yield 0 and
                // silently accept an empty buffer for an enormous shape.
                let expected_len = dims
                    .iter()
                    .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
                    .ok_or_else(|| {
                        anyhow!("dense tensor dims {:?} overflow the usize element count", dims)
                    })?;
                ensure!(
                    data.len() == expected_len,
                    "dense tensor len {} does not match dims {:?} (expected {})",
                    data.len(),
                    dims,
                    expected_len
                );
                Ok(NativeTensor::from_vec(dims.to_vec(), data.to_vec()))
            }

            fn diag_native_tensor_from_col_major(
                data: &[Self],
                logical_rank: usize,
            ) -> Result<NativeTensor> {
                // The diagonal tensor is materialised densely: every axis has length
                // `data.len()` and only the (i, ..., i) entries are non-default.
                let dims = vec![data.len(); logical_rank];
                let dense = dense_diagonal_values(data, logical_rank)?;
                Self::dense_native_tensor_from_col_major(&dense, &dims)
            }

            fn scalar_native_tensor(value: Self) -> Result<NativeTensor> {
                // Empty shape => rank-0 tensor holding exactly one element.
                Ok(NativeTensor::from_vec(vec![], vec![value]))
            }

            fn dense_values_from_native_col_major(tensor: &NativeTensor) -> Result<Vec<Self>> {
                // NOTE(review): any `None` from `as_slice` is reported as a dtype
                // mismatch; confirm `as_slice` cannot fail for other reasons
                // (e.g. non-contiguous storage).
                tensor
                    .as_slice::<Self>()
                    .map(|values| values.to_vec())
                    .ok_or_else(|| {
                        anyhow!(
                            "tensor dtype mismatch: expected {}, got {}",
                            tensor_dtype_name($dtype),
                            tensor_dtype_name(tensor.dtype())
                        )
                    })
            }

            fn diag_values_from_native_temp(tensor: &NativeTensor) -> Result<Vec<Self>> {
                let shape = tensor.shape();
                ensure!(
                    !shape.is_empty(),
                    "diagonal extraction requires rank >= 1, got scalar tensor"
                );
                let diag_len = shape[0];
                ensure!(
                    shape.iter().all(|&dim| dim == diag_len),
                    "expected square/equal dims for diagonal extraction, got {:?}",
                    shape
                );
                let dense = Self::dense_values_from_native_col_major(tensor)?;
                // Offset of (i, ..., i) is i * (1 + n + ... + n^(rank-1)) with n = diag_len;
                // with equal axis lengths this is the same in column- and row-major order.
                let diagonal_stride = (0..shape.len())
                    .scan(1usize, |stride, _| {
                        let current = *stride;
                        *stride = stride.saturating_mul(diag_len);
                        Some(current)
                    })
                    .sum::<usize>();
                // Bounds-checked reads: a buffer shorter than `shape` implies becomes an
                // error instead of an index-out-of-bounds panic.
                (0..diag_len)
                    .map(|i| {
                        i.checked_mul(diagonal_stride)
                            .and_then(|offset| dense.get(offset).copied())
                            .ok_or_else(|| {
                                anyhow!(
                                    "diagonal entry {} outside buffer of len {} (shape {:?})",
                                    i,
                                    dense.len(),
                                    shape
                                )
                            })
                    })
                    .collect()
            }
        }
    };
}
141
142impl_tensor_element!(f32, DType::F32);
143impl_tensor_element!(f64, DType::F64);
144impl_tensor_element!(Complex32, DType::C32);
145impl_tensor_element!(Complex64, DType::C64);