use num_complex::{Complex32, Complex64};
use tenferro_ops::dim_expr::DimExpr;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_tensor::validate::validate_nonsingular_u;
use tenferro_tensor::{DType, PadConfig, SliceConfig, Tensor, TensorBackend, TypedTensor};

use crate::error::{Error, Result};

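/// Eagerly executes a single `StdTensorOp` against `backend` and returns the
/// op's output tensors (most ops produce exactly one).
///
/// A minimal usage sketch; `CpuBackend`, `a`, and `b` are hypothetical names
/// used for illustration only:
///
/// ```ignore
/// let mut backend = CpuBackend::default();
/// let outputs = exec_op_on_tensors(&StdTensorOp::Add, &[&a, &b], &mut backend)?;
/// ```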
pub fn exec_op_on_tensors<B: TensorBackend>(
    op: &StdTensorOp,
    inputs: &[&Tensor],
    backend: &mut B,
) -> Result<Vec<Tensor>> {
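    // NaryEinsum contracts all of its inputs at once and has a dedicated eager
    // planner, so it is dispatched before an exec session is opened.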
    if let StdTensorOp::NaryEinsum { subscripts, .. } = op {
        return Ok(vec![tenferro_einsum::eager_einsum(
            backend, inputs, subscripts,
        )?]);
    }

    backend.with_exec_session(|exec| {
        let result = match op {
            StdTensorOp::Add => vec![exec.add(inputs[0], inputs[1])?],
            StdTensorOp::Mul => vec![exec.mul(inputs[0], inputs[1])?],
            StdTensorOp::Neg => vec![exec.neg(inputs[0])?],
            StdTensorOp::Div => vec![exec.div(inputs[0], inputs[1])?],
            StdTensorOp::Exp => vec![exec.exp(inputs[0])?],
            StdTensorOp::Log => vec![exec.log(inputs[0])?],
            StdTensorOp::Sin => vec![exec.sin(inputs[0])?],
            StdTensorOp::Cos => vec![exec.cos(inputs[0])?],
            StdTensorOp::Tanh => vec![exec.tanh(inputs[0])?],
            StdTensorOp::Sqrt => vec![exec.sqrt(inputs[0])?],
            StdTensorOp::Rsqrt => vec![exec.rsqrt(inputs[0])?],
            StdTensorOp::Pow => vec![exec.pow(inputs[0], inputs[1])?],
            StdTensorOp::Abs => vec![exec.abs(inputs[0])?],
            StdTensorOp::Sign => vec![exec.sign(inputs[0])?],
            StdTensorOp::Conj => vec![exec.conj(inputs[0])?],
            StdTensorOp::Maximum => vec![exec.maximum(inputs[0], inputs[1])?],
            StdTensorOp::Minimum => vec![exec.minimum(inputs[0], inputs[1])?],
            StdTensorOp::Compare(dir) => vec![exec.compare(inputs[0], inputs[1], dir)?],
            StdTensorOp::Transpose { perm } => vec![exec.transpose(inputs[0], perm)?],
            StdTensorOp::ReduceSum { axes, .. } => vec![exec.reduce_sum(inputs[0], axes)?],
            StdTensorOp::DotGeneral(config) => {
                vec![exec.dot_general(inputs[0], inputs[1], config)?]
            }
            StdTensorOp::Reshape { to_shape, .. } => {
                let shape = resolve_tensor_shape_exprs(inputs, to_shape);
                vec![exec.reshape(inputs[0], &shape)?]
            }
            StdTensorOp::BroadcastInDim { shape, dims } => {
                let shape = resolve_tensor_shape_exprs(inputs, shape);
                vec![exec.broadcast_in_dim(inputs[0], &shape, dims)?]
            }
            StdTensorOp::ExtractDiag { axis_a, axis_b } => {
                vec![exec.extract_diagonal(inputs[0], *axis_a, *axis_b)?]
            }
            StdTensorOp::EmbedDiag { axis_a, axis_b } => {
                vec![exec.embed_diagonal(inputs[0], *axis_a, *axis_b)?]
            }
            StdTensorOp::Tril { k } => vec![exec.tril(inputs[0], *k)?],
            StdTensorOp::Triu { k } => vec![exec.triu(inputs[0], *k)?],
            StdTensorOp::Slice(config) => vec![exec.slice(inputs[0], config)?],
            StdTensorOp::Pad(config) => vec![exec.pad(inputs[0], config)?],
            StdTensorOp::Reverse { axes } => vec![exec.reverse(inputs[0], axes)?],
            StdTensorOp::ReduceProd { axes, .. } => vec![exec.reduce_prod(inputs[0], axes)?],
            StdTensorOp::ReduceMax { axes, .. } => vec![exec.reduce_max(inputs[0], axes)?],
            StdTensorOp::ReduceMin { axes, .. } => vec![exec.reduce_min(inputs[0], axes)?],
            StdTensorOp::Expm1 => vec![exec.expm1(inputs[0])?],
            StdTensorOp::Log1p => vec![exec.log1p(inputs[0])?],
            StdTensorOp::Convert { to, .. } => vec![exec.convert(inputs[0], *to)?],
            StdTensorOp::Constant { dtype, bytes } => vec![constant_tensor(*dtype, bytes)],
            StdTensorOp::Select => vec![exec.select(inputs[0], inputs[1], inputs[2])?],
            StdTensorOp::Clamp => vec![exec.clamp(inputs[0], inputs[1], inputs[2])?],
            StdTensorOp::Concatenate { axis } => {
                vec![exec.concatenate(inputs, *axis)?]
            }
            StdTensorOp::NaryEinsum { .. } => {
                unreachable!("NaryEinsum is handled before opening an exec session")
            }
            StdTensorOp::Gather(config) => vec![exec.gather(inputs[0], inputs[1], config)?],
            StdTensorOp::Scatter(config) => {
                vec![exec.scatter(inputs[0], inputs[1], inputs[2], config)?]
            }
            StdTensorOp::DynamicSlice { slice_sizes } => {
                vec![exec.dynamic_slice(inputs[0], inputs[1], slice_sizes)?]
            }
            StdTensorOp::Cholesky { .. } => vec![exec.cholesky(inputs[0])?],
            StdTensorOp::TriangularSolve {
                left_side,
                lower,
                transpose_a,
                unit_diagonal,
                ..
            } => {
                vec![exec.triangular_solve(
                    inputs[0],
                    inputs[1],
                    *left_side,
                    *lower,
                    *transpose_a,
                    *unit_diagonal,
                )?]
            }
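            // The factorization ops below already return a Vec of output
            // tensors (e.g. the SVD factors), so their results pass through
            // without wrapping.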
            StdTensorOp::Svd { .. } => exec.svd(inputs[0])?,
            StdTensorOp::Qr { .. } => exec.qr(inputs[0])?,
            StdTensorOp::Lu { .. } => exec.lu(inputs[0])?,
            StdTensorOp::Eigh { .. } => exec.eigh(inputs[0])?,
            StdTensorOp::Eig { .. } => exec.eig(inputs[0])?,
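            // ShapeOf reads a static axis extent and materializes it as a
            // rank-0 F64 tensor.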
            StdTensorOp::ShapeOf { axis } => {
                let input = inputs[0];
                if *axis >= input.shape().len() {
                    return Err(Error::Internal(format!(
                        "ShapeOf: axis {} out of bounds for rank {}",
                        axis,
                        input.shape().len()
                    )));
                }
                let size = input.shape()[*axis] as f64;
                vec![Tensor::F64(TypedTensor::from_vec(vec![], vec![size]))]
            }
            StdTensorOp::DynamicTruncate { axis } => {
                let input = inputs[0];
                if *axis >= input.shape().len() {
                    return Err(Error::Internal(format!(
                        "DynamicTruncate: axis {} out of bounds for rank {}",
                        axis,
                        input.shape().len()
                    )));
                }
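                // The target size arrives as a runtime scalar tensor; round it
                // and clamp it into [0, axis_extent] before slicing.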
                let size_tensor = inputs[1];
                let axis_extent = input.shape()[*axis];
                let size_f64 = match size_tensor {
                    Tensor::F64(inner) => inner.host_data()[0],
                    Tensor::F32(inner) => inner.host_data()[0] as f64,
                    _ => {
                        return Err(Error::Internal(
                            "DynamicTruncate size must be an f32 or f64 scalar".into(),
                        ))
                    }
                };
                let rounded_size = if size_f64.is_finite() {
                    size_f64.round()
                } else {
                    0.0
                };
                let size = rounded_size.max(0.0).min(axis_extent as f64) as usize;
                let rank = input.shape().len();
                let mut limits = input.shape().to_vec();
                limits[*axis] = size;
                let config = SliceConfig {
                    starts: vec![0; rank],
                    limits,
                    strides: vec![1; rank],
                };
                vec![exec.slice(input, &config)?]
            }
            StdTensorOp::PadToMatch { axis } => {
                let input = inputs[0];
                let reference = inputs[1];
                if *axis >= input.shape().len() {
                    return Err(Error::Internal(format!(
                        "PadToMatch: axis {} out of bounds for rank {}",
                        axis,
                        input.shape().len()
                    )));
                }
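                // Only grows the axis: pad the high edge until it matches the
                // reference extent; an input that is already long enough is
                // passed through unchanged.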
                let target_size = reference.shape()[*axis];
                let current_size = input.shape()[*axis];
                if current_size >= target_size {
                    vec![input.clone()]
                } else {
                    let rank = input.shape().len();
                    let mut high = vec![0i64; rank];
                    high[*axis] = (target_size - current_size) as i64;
                    let config = PadConfig {
                        edge_padding_low: vec![0i64; rank],
                        edge_padding_high: high,
                        interior_padding: vec![0i64; rank],
                    };
                    vec![exec.pad(input, &config)?]
                }
            }
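            // Validation is a pass-through: the input is returned unchanged
            // when the check succeeds.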
            StdTensorOp::ValidateNonsingular { .. } => {
                validate_nonsingular_u(inputs[0])?;
                vec![inputs[0].clone()]
            }
        };
        Ok(result)
    })
}

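/// Evaluates the `DimExpr` shape expressions in `exprs` against the runtime
/// shapes of all `inputs`.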
fn resolve_tensor_shape_exprs(inputs: &[&Tensor], exprs: &[DimExpr]) -> Vec<usize> {
    let input_shapes: Vec<&[usize]> = inputs.iter().map(|tensor| tensor.shape()).collect();
    DimExpr::eval_all(exprs, &input_shapes)
}

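/// Decodes a little-endian byte payload into a rank-0 (scalar) tensor of the
/// given `dtype`. Complex payloads store the real part first, then the
/// imaginary part.
///
/// A sketch of building a matching payload (plain std Rust, nothing
/// crate-specific):
///
/// ```ignore
/// let t = constant_tensor(DType::F64, &1.5f64.to_le_bytes());
/// ```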
fn constant_tensor(dtype: DType, bytes: &[u8]) -> Tensor {
    match dtype {
        DType::F64 => Tensor::F64(TypedTensor::from_vec(
            vec![],
            vec![f64::from_le_bytes(exact_bytes::<8>(dtype, bytes))],
        )),
        DType::F32 => Tensor::F32(TypedTensor::from_vec(
            vec![],
            vec![f32::from_le_bytes(exact_bytes::<4>(dtype, bytes))],
        )),
        DType::C64 => {
            let data = exact_bytes::<16>(dtype, bytes);
            let mut re_bytes = [0u8; 8];
            let mut im_bytes = [0u8; 8];
            re_bytes.copy_from_slice(&data[..8]);
            im_bytes.copy_from_slice(&data[8..]);
            let re = f64::from_le_bytes(re_bytes);
            let im = f64::from_le_bytes(im_bytes);
            Tensor::C64(TypedTensor::from_vec(vec![], vec![Complex64::new(re, im)]))
        }
        DType::C32 => {
            let data = exact_bytes::<8>(dtype, bytes);
            let mut re_bytes = [0u8; 4];
            let mut im_bytes = [0u8; 4];
            re_bytes.copy_from_slice(&data[..4]);
            im_bytes.copy_from_slice(&data[4..]);
            let re = f32::from_le_bytes(re_bytes);
            let im = f32::from_le_bytes(im_bytes);
            Tensor::C32(TypedTensor::from_vec(vec![], vec![Complex32::new(re, im)]))
        }
    }
}

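/// Copies `bytes` into a fixed-size array, panicking if the payload length
/// does not match the `N` bytes expected for `dtype`.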
fn exact_bytes<const N: usize>(dtype: DType, bytes: &[u8]) -> [u8; N] {
    if bytes.len() != N {
        panic!(
            "constant {:?} expected {} bytes, got {}",
            dtype,
            N,
            bytes.len()
        );
    }
    let mut out = [0u8; N];
    out.copy_from_slice(bytes);
    out
}