Skip to main content

tenferro/
eager_ops.rs

1use std::sync::Arc;
2
3use tidu::{GradEdge, GradNode};
4
5use tenferro_ops::dim_expr::DimExpr;
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_tensor::{
8    DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig, Tensor,
9    TensorBackend,
10};
11
12use crate::eager::{
13    eager_val_key, exec_single_output, saved_forward_values, saved_forward_values_multi,
14    EagerTensor,
15};
16use crate::eager_exec::exec_op_on_tensors;
17use crate::error::{Error, Result};
18
19impl<B: TensorBackend> EagerTensor<B> {
20    /// Elementwise addition.
21    ///
22    /// # Examples
23    ///
24    /// ```
25    /// use tenferro::{EagerTensor, Tensor};
26    ///
27    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
28    /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
29    /// let z = x.add(&y).unwrap();
30    ///
31    /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[4.0, 6.0]);
32    /// ```
33    pub fn add(&self, other: &Self) -> Result<Self> {
34        self.binary_op(other, StdTensorOp::Add)
35    }
36
37    /// Elementwise multiplication.
38    ///
39    /// # Examples
40    ///
41    /// ```
42    /// use tenferro::{EagerTensor, Tensor};
43    ///
44    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
45    /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
46    /// let z = x.mul(&y).unwrap();
47    ///
48    /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[3.0, 8.0]);
49    /// ```
50    pub fn mul(&self, other: &Self) -> Result<Self> {
51        self.binary_op(other, StdTensorOp::Mul)
52    }
53
54    /// Negate the tensor.
55    ///
56    /// # Examples
57    ///
58    /// ```
59    /// use tenferro::{EagerTensor, Tensor};
60    ///
61    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, -2.0]));
62    /// let y = x.neg().unwrap();
63    ///
64    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[-1.0, 2.0]);
65    /// ```
66    pub fn neg(&self) -> Result<Self> {
67        self.unary_op(StdTensorOp::Neg)
68    }
69
70    /// Elementwise exponential.
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use tenferro::{EagerTensor, Tensor};
76    ///
77    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
78    /// let y = x.exp().unwrap();
79    ///
80    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0]);
81    /// ```
82    pub fn exp(&self) -> Result<Self> {
83        self.unary_op(StdTensorOp::Exp)
84    }
85
86    /// Reduce sum over the requested axes.
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// use tenferro::{EagerTensor, Tensor};
92    ///
93    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
94    /// let y = x.reduce_sum(&[0, 1]).unwrap();
95    ///
96    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[10.0]);
97    /// ```
98    pub fn reduce_sum(&self, axes: &[usize]) -> Result<Self> {
99        self.unary_op(StdTensorOp::ReduceSum {
100            axes: axes.to_vec(),
101            input_shape: DimExpr::from_concrete(self.data.shape()),
102        })
103    }
104
105    /// Execute a dot-general contraction eagerly.
106    ///
107    /// # Examples
108    ///
109    /// ```
110    /// use tenferro::{DotGeneralConfig, EagerTensor, Tensor};
111    ///
112    /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]));
113    /// let b = EagerTensor::from_tensor(Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]));
114    /// let c = a.dot_general(&b, DotGeneralConfig {
115    ///     lhs_contracting_dims: vec![1],
116    ///     rhs_contracting_dims: vec![0],
117    ///     lhs_batch_dims: vec![],
118    ///     rhs_batch_dims: vec![],
119    ///     lhs_rank: 2,
120    ///     rhs_rank: 2,
121    /// }).unwrap();
122    ///
123    /// assert_eq!(c.data().shape(), &[2, 2]);
124    /// ```
125    pub fn dot_general(&self, other: &Self, config: DotGeneralConfig) -> Result<Self> {
126        self.binary_op(other, StdTensorOp::DotGeneral(config))
127    }
128
129    /// Permute tensor axes.
130    ///
131    /// # Examples
132    ///
133    /// ```
134    /// use tenferro::{EagerTensor, Tensor};
135    ///
136    /// let x = EagerTensor::from_tensor(Tensor::from_vec(
137    ///     vec![2, 3],
138    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
139    /// ));
140    /// let y = x.transpose(&[1, 0]).unwrap();
141    ///
142    /// assert_eq!(y.data().shape(), &[3, 2]);
143    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
144    /// ```
145    pub fn transpose(&self, perm: &[usize]) -> Result<Self> {
146        self.unary_op(StdTensorOp::Transpose {
147            perm: perm.to_vec(),
148        })
149    }
150
151    /// Reshape without changing element order.
152    ///
153    /// # Examples
154    ///
155    /// ```
156    /// use tenferro::{EagerTensor, Tensor};
157    ///
158    /// let x = EagerTensor::from_tensor(Tensor::from_vec(
159    ///     vec![2, 3],
160    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
161    /// ));
162    /// let y = x.reshape(&[6]).unwrap();
163    ///
164    /// assert_eq!(y.data().shape(), &[6]);
165    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
166    /// ```
167    pub fn reshape(&self, shape: &[usize]) -> Result<Self> {
168        self.unary_op(StdTensorOp::Reshape {
169            from_shape: DimExpr::from_concrete(self.data.shape()),
170            to_shape: DimExpr::from_concrete(shape),
171        })
172    }
173
174    /// Slice with explicit start, limit, and stride per axis.
175    ///
176    /// # Examples
177    ///
178    /// ```
179    /// use tenferro::{EagerTensor, SliceConfig, Tensor};
180    ///
181    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]));
182    /// let y = x
183    ///     .slice(SliceConfig {
184    ///         starts: vec![1],
185    ///         limits: vec![3],
186    ///         strides: vec![1],
187    ///     })
188    ///     .unwrap();
189    ///
190    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[2.0, 3.0]);
191    /// ```
192    pub fn slice(&self, config: SliceConfig) -> Result<Self> {
193        self.unary_op(StdTensorOp::Slice(config))
194    }
195
196    /// Broadcast into a larger shape with explicit dimension placement.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// use tenferro::{EagerTensor, Tensor};
202    ///
203    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]));
204    /// let y = x.broadcast_in_dim(&[3, 2], &[0]).unwrap();
205    ///
206    /// assert_eq!(y.data().shape(), &[3, 2]);
207    /// ```
208    pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> Result<Self> {
209        self.unary_op(StdTensorOp::BroadcastInDim {
210            shape: DimExpr::from_concrete(shape),
211            dims: dims.to_vec(),
212        })
213    }
214
215    /// Convert the tensor to a different dtype.
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use tenferro::{DType, EagerTensor, Tensor};
221    ///
222    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, -2.0]));
223    /// let y = x.convert(DType::C64).unwrap();
224    ///
225    /// assert_eq!(y.data().dtype(), DType::C64);
226    /// assert_eq!(y.data().shape(), &[2]);
227    /// ```
228    pub fn convert(&self, to: DType) -> Result<Self> {
229        self.unary_op(StdTensorOp::Convert {
230            from: self.data.dtype(),
231            to,
232        })
233    }
234
235    /// Pad with zeros using StableHLO-style edge and interior padding.
236    ///
237    /// # Examples
238    ///
239    /// ```
240    /// use tenferro::{EagerTensor, PadConfig, Tensor};
241    ///
242    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
243    /// let y = x
244    ///     .pad(PadConfig {
245    ///         edge_padding_low: vec![1],
246    ///         edge_padding_high: vec![1],
247    ///         interior_padding: vec![1],
248    ///     })
249    ///     .unwrap();
250    ///
251    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0, 1.0, 0.0, 2.0, 0.0]);
252    /// ```
253    pub fn pad(&self, config: PadConfig) -> Result<Self> {
254        self.unary_op(StdTensorOp::Pad(config))
255    }
256
257    /// Reverse the order of elements along the requested axes.
258    ///
259    /// # Examples
260    ///
261    /// ```
262    /// use tenferro::{EagerTensor, Tensor};
263    ///
264    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]));
265    /// let y = x.reverse(&[0]).unwrap();
266    ///
267    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[4.0, 3.0, 2.0, 1.0]);
268    /// ```
269    pub fn reverse(&self, axes: &[usize]) -> Result<Self> {
270        self.unary_op(StdTensorOp::Reverse {
271            axes: axes.to_vec(),
272        })
273    }
274
275    /// Gather slices from `self` using integer start indices.
276    ///
277    /// # Examples
278    ///
279    /// ```
280    /// use tenferro::{EagerTensor, GatherConfig, Tensor};
281    ///
282    /// let x = EagerTensor::from_tensor(Tensor::from_vec(
283    ///     vec![5],
284    ///     vec![10.0_f64, 20.0, 30.0, 40.0, 50.0],
285    /// ));
286    /// let indices = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![4.0_f64, 1.0, 0.0]));
287    /// let y = x
288    ///     .gather(
289    ///         &indices,
290    ///         GatherConfig {
291    ///             offset_dims: vec![],
292    ///             collapsed_slice_dims: vec![0],
293    ///             start_index_map: vec![0],
294    ///             index_vector_dim: 1,
295    ///             slice_sizes: vec![1],
296    ///         },
297    ///     )
298    ///     .unwrap();
299    ///
300    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[50.0, 20.0, 10.0]);
301    /// ```
302    pub fn gather(&self, indices: &Self, config: GatherConfig) -> Result<Self> {
303        self.binary_op(indices, StdTensorOp::Gather(config))
304    }
305
306    /// Scatter updates into `self` using StableHLO scatter semantics.
307    ///
308    /// # Examples
309    ///
310    /// ```
311    /// use tenferro::{EagerTensor, ScatterConfig, Tensor};
312    ///
313    /// let operand = EagerTensor::from_tensor(Tensor::from_vec(vec![4], vec![0.0_f64, 0.0, 0.0, 0.0]));
314    /// let indices = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 1], vec![1.0_f64, 3.0]));
315    /// let updates = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![5.0_f64, 7.0]));
316    /// let result = operand
317    ///     .scatter(
318    ///         &indices,
319    ///         &updates,
320    ///         ScatterConfig {
321    ///             update_window_dims: vec![],
322    ///             inserted_window_dims: vec![0],
323    ///             scatter_dims_to_operand_dims: vec![0],
324    ///             index_vector_dim: 1,
325    ///         },
326    ///     )
327    ///     .unwrap();
328    ///
329    /// assert_eq!(result.data().as_slice::<f64>().unwrap(), &[0.0, 5.0, 0.0, 7.0]);
330    /// ```
331    pub fn scatter(&self, indices: &Self, updates: &Self, config: ScatterConfig) -> Result<Self> {
332        self.ternary_op(indices, updates, StdTensorOp::Scatter(config))
333    }
334
335    /// Slice using runtime start indices.
336    ///
337    /// # Examples
338    ///
339    /// ```
340    /// use tenferro::{EagerTensor, Tensor};
341    ///
342    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![5], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]));
343    /// let starts = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![2.0_f64]));
344    /// let y = x.dynamic_slice(&starts, &[2]).unwrap();
345    ///
346    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[3.0, 4.0]);
347    /// ```
348    pub fn dynamic_slice(&self, starts: &Self, sizes: &[usize]) -> Result<Self> {
349        self.binary_op(
350            starts,
351            StdTensorOp::DynamicSlice {
352                slice_sizes: sizes.to_vec(),
353            },
354        )
355    }
356
357    /// Concatenate tensors along one axis.
358    ///
359    /// # Examples
360    ///
361    /// ```
362    /// use tenferro::{EagerTensor, Tensor};
363    ///
364    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
365    /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
366    /// let z = EagerTensor::concatenate(&[&x, &y], 0).unwrap();
367    ///
368    /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
369    /// ```
370    pub fn concatenate(tensors: &[&Self], axis: usize) -> Result<Self> {
371        Self::nary_op(tensors, StdTensorOp::Concatenate { axis })
372    }
373
374    /// Extract the diagonal along two axes.
375    ///
376    /// # Examples
377    ///
378    /// ```
379    /// use tenferro::{EagerTensor, Tensor};
380    ///
381    /// let x = EagerTensor::from_tensor(Tensor::from_vec(
382    ///     vec![3, 3],
383    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
384    /// ));
385    /// let y = x.extract_diag(0, 1).unwrap();
386    ///
387    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 5.0, 9.0]);
388    /// ```
389    pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
390        self.unary_op(StdTensorOp::ExtractDiag { axis_a, axis_b })
391    }
392
393    /// Embed a vector or lower-rank tensor along a diagonal.
394    ///
395    /// # Examples
396    ///
397    /// ```
398    /// use tenferro::{EagerTensor, Tensor};
399    ///
400    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]));
401    /// let y = x.embed_diag(0, 1).unwrap();
402    ///
403    /// assert_eq!(y.data().shape(), &[3, 3]);
404    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
405    /// ```
406    pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
407        self.unary_op(StdTensorOp::EmbedDiag { axis_a, axis_b })
408    }
409
410    /// Keep the lower triangle and zero the rest.
411    ///
412    /// # Examples
413    ///
414    /// ```
415    /// use tenferro::{EagerTensor, Tensor};
416    ///
417    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
418    /// let y = x.tril(0).unwrap();
419    ///
420    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 2.0, 0.0, 4.0]);
421    /// ```
422    pub fn tril(&self, k: i64) -> Result<Self> {
423        self.unary_op(StdTensorOp::Tril { k })
424    }
425
426    /// Keep the upper triangle and zero the rest.
427    ///
428    /// # Examples
429    ///
430    /// ```
431    /// use tenferro::{EagerTensor, Tensor};
432    ///
433    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
434    /// let y = x.triu(0).unwrap();
435    ///
436    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 0.0, 3.0, 4.0]);
437    /// ```
438    pub fn triu(&self, k: i64) -> Result<Self> {
439        self.unary_op(StdTensorOp::Triu { k })
440    }
441
442    /// Reduce product over the requested axes.
443    ///
444    /// # Examples
445    ///
446    /// ```
447    /// use tenferro::{EagerTensor, Tensor};
448    ///
449    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
450    /// let y = x.reduce_prod(&[0, 1]).unwrap();
451    ///
452    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[24.0]);
453    /// ```
454    pub fn reduce_prod(&self, axes: &[usize]) -> Result<Self> {
455        self.unary_op(StdTensorOp::ReduceProd {
456            axes: axes.to_vec(),
457            input_shape: DimExpr::from_concrete(self.data.shape()),
458        })
459    }
460
461    /// Reduce maximum over the requested axes.
462    ///
463    /// # Examples
464    ///
465    /// ```
466    /// use tenferro::{EagerTensor, Tensor};
467    ///
468    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
469    /// let y = x.reduce_max(&[0, 1]).unwrap();
470    ///
471    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[4.0]);
472    /// ```
473    pub fn reduce_max(&self, axes: &[usize]) -> Result<Self> {
474        self.unary_op(StdTensorOp::ReduceMax {
475            axes: axes.to_vec(),
476            input_shape: DimExpr::from_concrete(self.data.shape()),
477        })
478    }
479
480    /// Reduce minimum over the requested axes.
481    ///
482    /// # Examples
483    ///
484    /// ```
485    /// use tenferro::{EagerTensor, Tensor};
486    ///
487    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
488    /// let y = x.reduce_min(&[0, 1]).unwrap();
489    ///
490    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0]);
491    /// ```
492    pub fn reduce_min(&self, axes: &[usize]) -> Result<Self> {
493        self.unary_op(StdTensorOp::ReduceMin {
494            axes: axes.to_vec(),
495            input_shape: DimExpr::from_concrete(self.data.shape()),
496        })
497    }
498
499    pub(crate) fn unary_op(&self, op: StdTensorOp) -> Result<Self> {
500        let output = exec_single_output(&op, &[self.data.as_ref()], &self.ctx)?;
501        let result_key = eager_val_key();
502        let input_aliases = vec![eager_val_key()];
503        let grad_node = self.requires_grad.then(|| {
504            Arc::new(GradNode {
505                op: op.clone(),
506                primal_in_keys: input_aliases.clone(),
507                primal_out_keys: vec![result_key.clone()],
508                saved_data: saved_forward_values(
509                    &op,
510                    &input_aliases,
511                    &[Arc::clone(&self.data)],
512                    Arc::new(output.clone()),
513                ),
514                input_edges: vec![GradEdge {
515                    node: self.grad_node.clone(),
516                    key: self.key.clone(),
517                    requires_grad: self.requires_grad,
518                }],
519                output_idx: 0,
520            })
521        });
522        Ok(Self::new_result(
523            Arc::clone(&self.ctx),
524            result_key,
525            output,
526            self.requires_grad,
527            grad_node,
528        ))
529    }
530
531    pub(crate) fn binary_op(&self, other: &Self, op: StdTensorOp) -> Result<Self> {
532        Self::nary_op(&[self, other], op)
533    }
534
535    pub(crate) fn multi_output_unary_op(
536        &self,
537        op: StdTensorOp,
538        num_outputs: usize,
539    ) -> Result<Vec<Self>> {
540        let outputs = {
541            let mut backend = self.ctx.backend.lock().unwrap();
542            exec_op_on_tensors(&op, &[self.data.as_ref()], &mut *backend)?
543        };
544        if outputs.len() != num_outputs {
545            return Err(Error::Internal(format!(
546                "expected {} eager outputs for {:?}, got {}",
547                num_outputs,
548                op,
549                outputs.len()
550            )));
551        }
552
553        let outputs: Vec<Arc<Tensor>> = outputs.into_iter().map(Arc::new).collect();
554        let output_keys: Vec<_> = (0..num_outputs).map(|_| eager_val_key()).collect();
555        let input_aliases = vec![eager_val_key()];
556        let grad_node = self.requires_grad.then(|| {
557            Arc::new(GradNode {
558                op: op.clone(),
559                primal_in_keys: input_aliases.clone(),
560                primal_out_keys: output_keys.clone(),
561                saved_data: saved_forward_values_multi(
562                    &op,
563                    &input_aliases,
564                    &[Arc::clone(&self.data)],
565                    num_outputs,
566                    &outputs,
567                ),
568                input_edges: vec![GradEdge {
569                    node: self.grad_node.clone(),
570                    key: self.key.clone(),
571                    requires_grad: self.requires_grad,
572                }],
573                output_idx: 0,
574            })
575        });
576
577        Ok(output_keys
578            .into_iter()
579            .zip(outputs)
580            .map(|(output_key, output)| {
581                Self::new_result(
582                    Arc::clone(&self.ctx),
583                    output_key,
584                    output.as_ref().clone(),
585                    self.requires_grad,
586                    grad_node.clone(),
587                )
588            })
589            .collect())
590    }
591
592    pub(crate) fn ternary_op(&self, b: &Self, c: &Self, op: StdTensorOp) -> Result<Self> {
593        Self::nary_op(&[self, b, c], op)
594    }
595
596    pub(crate) fn nary_op(tensors: &[&Self], op: StdTensorOp) -> Result<Self> {
597        let Some(first) = tensors.first() else {
598            return Err(Error::Internal(
599                "nary eager op requires at least one input tensor".to_string(),
600            ));
601        };
602
603        let ctx = Arc::clone(&first.ctx);
604        for tensor in tensors.iter().skip(1) {
605            if !Arc::ptr_eq(&ctx, &tensor.ctx) {
606                ctx.absorb_from(&tensor.ctx);
607            }
608        }
609
610        let inputs: Vec<&Tensor> = tensors.iter().map(|tensor| tensor.data.as_ref()).collect();
611        let output = exec_single_output(&op, &inputs, &ctx)?;
612        let requires_grad = tensors.iter().any(|tensor| tensor.requires_grad);
613        let result_key = eager_val_key();
614        let input_aliases: Vec<_> = tensors.iter().map(|_| eager_val_key()).collect();
615        let input_data: Vec<_> = tensors
616            .iter()
617            .map(|tensor| Arc::clone(&tensor.data))
618            .collect();
619        let grad_node = requires_grad.then(|| {
620            Arc::new(GradNode {
621                op: op.clone(),
622                primal_in_keys: input_aliases.clone(),
623                primal_out_keys: vec![result_key.clone()],
624                saved_data: saved_forward_values(
625                    &op,
626                    &input_aliases,
627                    &input_data,
628                    Arc::new(output.clone()),
629                ),
630                input_edges: tensors
631                    .iter()
632                    .map(|tensor| GradEdge {
633                        node: tensor.grad_node.clone(),
634                        key: tensor.key.clone(),
635                        requires_grad: tensor.requires_grad,
636                    })
637                    .collect(),
638                output_idx: 0,
639            })
640        });
641
642        Ok(Self::new_result(
643            ctx,
644            result_key,
645            output,
646            requires_grad,
647            grad_node,
648        ))
649    }
650}