// tenferro/eager_exec.rs

1use num_complex::{Complex32, Complex64};
2use tenferro_ops::dim_expr::DimExpr;
3use tenferro_ops::std_tensor_op::StdTensorOp;
4use tenferro_tensor::validate::validate_nonsingular_u;
5use tenferro_tensor::{DType, PadConfig, SliceConfig, Tensor, TensorBackend, TypedTensor};
6
7use crate::error::{Error, Result};
8
9/// Execute a single [`StdTensorOp`] on concrete tensors.
10///
11/// Most ops produce one output tensor. Multi-output linalg ops return one
12/// tensor per output slot.
13pub fn exec_op_on_tensors<B: TensorBackend>(
14    op: &StdTensorOp,
15    inputs: &[&Tensor],
16    backend: &mut B,
17) -> Result<Vec<Tensor>> {
18    if let StdTensorOp::NaryEinsum { subscripts, .. } = op {
19        return Ok(vec![tenferro_einsum::eager_einsum(
20            backend, inputs, subscripts,
21        )?]);
22    }
23
24    backend.with_exec_session(|exec| {
25        let result = match op {
26            StdTensorOp::Add => vec![exec.add(inputs[0], inputs[1])?],
27            StdTensorOp::Mul => vec![exec.mul(inputs[0], inputs[1])?],
28            StdTensorOp::Neg => vec![exec.neg(inputs[0])?],
29            StdTensorOp::Div => vec![exec.div(inputs[0], inputs[1])?],
30            StdTensorOp::Exp => vec![exec.exp(inputs[0])?],
31            StdTensorOp::Log => vec![exec.log(inputs[0])?],
32            StdTensorOp::Sin => vec![exec.sin(inputs[0])?],
33            StdTensorOp::Cos => vec![exec.cos(inputs[0])?],
34            StdTensorOp::Tanh => vec![exec.tanh(inputs[0])?],
35            StdTensorOp::Sqrt => vec![exec.sqrt(inputs[0])?],
36            StdTensorOp::Rsqrt => vec![exec.rsqrt(inputs[0])?],
37            StdTensorOp::Pow => vec![exec.pow(inputs[0], inputs[1])?],
38            StdTensorOp::Abs => vec![exec.abs(inputs[0])?],
39            StdTensorOp::Sign => vec![exec.sign(inputs[0])?],
40            StdTensorOp::Conj => vec![exec.conj(inputs[0])?],
41            StdTensorOp::Maximum => vec![exec.maximum(inputs[0], inputs[1])?],
42            StdTensorOp::Minimum => vec![exec.minimum(inputs[0], inputs[1])?],
43            StdTensorOp::Compare(dir) => vec![exec.compare(inputs[0], inputs[1], dir)?],
44            StdTensorOp::Transpose { perm } => vec![exec.transpose(inputs[0], perm)?],
45            StdTensorOp::ReduceSum { axes, .. } => vec![exec.reduce_sum(inputs[0], axes)?],
46            StdTensorOp::DotGeneral(config) => {
47                vec![exec.dot_general(inputs[0], inputs[1], config)?]
48            }
49            StdTensorOp::Reshape { to_shape, .. } => {
50                let shape = resolve_tensor_shape_exprs(inputs, to_shape);
51                vec![exec.reshape(inputs[0], &shape)?]
52            }
53            StdTensorOp::BroadcastInDim { shape, dims } => {
54                let shape = resolve_tensor_shape_exprs(inputs, shape);
55                vec![exec.broadcast_in_dim(inputs[0], &shape, dims)?]
56            }
57            StdTensorOp::ExtractDiag { axis_a, axis_b } => {
58                vec![exec.extract_diagonal(inputs[0], *axis_a, *axis_b)?]
59            }
60            StdTensorOp::EmbedDiag { axis_a, axis_b } => {
61                vec![exec.embed_diagonal(inputs[0], *axis_a, *axis_b)?]
62            }
63            StdTensorOp::Tril { k } => vec![exec.tril(inputs[0], *k)?],
64            StdTensorOp::Triu { k } => vec![exec.triu(inputs[0], *k)?],
65            StdTensorOp::Slice(config) => vec![exec.slice(inputs[0], config)?],
66            StdTensorOp::Pad(config) => vec![exec.pad(inputs[0], config)?],
67            StdTensorOp::Reverse { axes } => vec![exec.reverse(inputs[0], axes)?],
68            StdTensorOp::ReduceProd { axes, .. } => vec![exec.reduce_prod(inputs[0], axes)?],
69            StdTensorOp::ReduceMax { axes, .. } => vec![exec.reduce_max(inputs[0], axes)?],
70            StdTensorOp::ReduceMin { axes, .. } => vec![exec.reduce_min(inputs[0], axes)?],
71            StdTensorOp::Expm1 => vec![exec.expm1(inputs[0])?],
72            StdTensorOp::Log1p => vec![exec.log1p(inputs[0])?],
73            StdTensorOp::Convert { to, .. } => vec![exec.convert(inputs[0], *to)?],
74            StdTensorOp::Constant { dtype, bytes } => vec![constant_tensor(*dtype, bytes)],
75            StdTensorOp::Select => vec![exec.select(inputs[0], inputs[1], inputs[2])?],
76            StdTensorOp::Clamp => vec![exec.clamp(inputs[0], inputs[1], inputs[2])?],
77            StdTensorOp::Concatenate { axis } => {
78                vec![exec.concatenate(inputs, *axis)?]
79            }
80            StdTensorOp::NaryEinsum { .. } => {
81                unreachable!("NaryEinsum is handled before opening an exec session")
82            }
83            StdTensorOp::Gather(config) => vec![exec.gather(inputs[0], inputs[1], config)?],
84            StdTensorOp::Scatter(config) => {
85                vec![exec.scatter(inputs[0], inputs[1], inputs[2], config)?]
86            }
87            StdTensorOp::DynamicSlice { slice_sizes } => {
88                vec![exec.dynamic_slice(inputs[0], inputs[1], slice_sizes)?]
89            }
90            StdTensorOp::Cholesky { .. } => vec![exec.cholesky(inputs[0])?],
91            StdTensorOp::TriangularSolve {
92                left_side,
93                lower,
94                transpose_a,
95                unit_diagonal,
96                ..
97            } => {
98                vec![exec.triangular_solve(
99                    inputs[0],
100                    inputs[1],
101                    *left_side,
102                    *lower,
103                    *transpose_a,
104                    *unit_diagonal,
105                )?]
106            }
107            StdTensorOp::Svd { .. } => exec.svd(inputs[0])?,
108            StdTensorOp::Qr { .. } => exec.qr(inputs[0])?,
109            StdTensorOp::Lu { .. } => exec.lu(inputs[0])?,
110            StdTensorOp::Eigh { .. } => exec.eigh(inputs[0])?,
111            StdTensorOp::Eig { .. } => exec.eig(inputs[0])?,
112            StdTensorOp::ShapeOf { axis } => {
113                let input = inputs[0];
114                if *axis >= input.shape().len() {
115                    return Err(Error::Internal(format!(
116                        "ShapeOf: axis {} out of bounds for rank {}",
117                        axis,
118                        input.shape().len()
119                    )));
120                }
121                let size = input.shape()[*axis] as f64;
122                vec![Tensor::F64(TypedTensor::from_vec(vec![], vec![size]))]
123            }
124            StdTensorOp::DynamicTruncate { axis } => {
125                let input = inputs[0];
126                if *axis >= input.shape().len() {
127                    return Err(Error::Internal(format!(
128                        "DynamicTruncate: axis {} out of bounds for rank {}",
129                        axis,
130                        input.shape().len()
131                    )));
132                }
133                let size_tensor = inputs[1];
134                let axis_extent = input.shape()[*axis];
135                let size_f64 = match size_tensor {
136                    Tensor::F64(inner) => inner.host_data()[0],
137                    Tensor::F32(inner) => inner.host_data()[0] as f64,
138                    _ => {
139                        return Err(Error::Internal(
140                            "DynamicTruncate size must be an f32 or f64 scalar".into(),
141                        ))
142                    }
143                };
144                let rounded_size = if size_f64.is_finite() {
145                    size_f64.round()
146                } else {
147                    0.0
148                };
149                let size = rounded_size.max(0.0).min(axis_extent as f64) as usize;
150                let rank = input.shape().len();
151                let mut limits = input.shape().to_vec();
152                limits[*axis] = size;
153                let config = SliceConfig {
154                    starts: vec![0; rank],
155                    limits,
156                    strides: vec![1; rank],
157                };
158                vec![exec.slice(input, &config)?]
159            }
160            StdTensorOp::PadToMatch { axis } => {
161                let input = inputs[0];
162                let reference = inputs[1];
163                if *axis >= input.shape().len() {
164                    return Err(Error::Internal(format!(
165                        "PadToMatch: axis {} out of bounds for rank {}",
166                        axis,
167                        input.shape().len()
168                    )));
169                }
170                let target_size = reference.shape()[*axis];
171                let current_size = input.shape()[*axis];
172                if current_size >= target_size {
173                    vec![input.clone()]
174                } else {
175                    let rank = input.shape().len();
176                    let mut high = vec![0i64; rank];
177                    high[*axis] = (target_size - current_size) as i64;
178                    let config = PadConfig {
179                        edge_padding_low: vec![0i64; rank],
180                        edge_padding_high: high,
181                        interior_padding: vec![0i64; rank],
182                    };
183                    vec![exec.pad(input, &config)?]
184                }
185            }
186            StdTensorOp::ValidateNonsingular { .. } => {
187                validate_nonsingular_u(inputs[0])?;
188                vec![inputs[0].clone()]
189            }
190        };
191        Ok(result)
192    })
193}
194
195fn resolve_tensor_shape_exprs(inputs: &[&Tensor], exprs: &[DimExpr]) -> Vec<usize> {
196    let input_shapes: Vec<&[usize]> = inputs.iter().map(|tensor| tensor.shape()).collect();
197    DimExpr::eval_all(exprs, &input_shapes)
198}
199
200fn constant_tensor(dtype: DType, bytes: &[u8]) -> Tensor {
201    match dtype {
202        DType::F64 => Tensor::F64(TypedTensor::from_vec(
203            vec![],
204            vec![f64::from_le_bytes(exact_bytes::<8>(dtype, bytes))],
205        )),
206        DType::F32 => Tensor::F32(TypedTensor::from_vec(
207            vec![],
208            vec![f32::from_le_bytes(exact_bytes::<4>(dtype, bytes))],
209        )),
210        DType::C64 => {
211            let data = exact_bytes::<16>(dtype, bytes);
212            let mut re_bytes = [0u8; 8];
213            let mut im_bytes = [0u8; 8];
214            re_bytes.copy_from_slice(&data[..8]);
215            im_bytes.copy_from_slice(&data[8..]);
216            let re = f64::from_le_bytes(re_bytes);
217            let im = f64::from_le_bytes(im_bytes);
218            Tensor::C64(TypedTensor::from_vec(vec![], vec![Complex64::new(re, im)]))
219        }
220        DType::C32 => {
221            let data = exact_bytes::<8>(dtype, bytes);
222            let mut re_bytes = [0u8; 4];
223            let mut im_bytes = [0u8; 4];
224            re_bytes.copy_from_slice(&data[..4]);
225            im_bytes.copy_from_slice(&data[4..]);
226            let re = f32::from_le_bytes(re_bytes);
227            let im = f32::from_le_bytes(im_bytes);
228            Tensor::C32(TypedTensor::from_vec(vec![], vec![Complex32::new(re, im)]))
229        }
230    }
231}
232
233fn exact_bytes<const N: usize>(dtype: DType, bytes: &[u8]) -> [u8; N] {
234    if bytes.len() != N {
235        panic!(
236            "constant {:?} expected {} bytes, got {}",
237            dtype,
238            N,
239            bytes.len()
240        );
241    }
242    let mut out = [0u8; N];
243    out.copy_from_slice(bytes);
244    out
245}