// tenferro/eager_ops.rs
1use std::sync::Arc;
2
3use tidu::{GradEdge, GradNode};
4
5use tenferro_ops::dim_expr::DimExpr;
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_tensor::{
8 DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig, Tensor,
9 TensorBackend,
10};
11
12use crate::eager::{
13 eager_val_key, exec_single_output, saved_forward_values, saved_forward_values_multi,
14 EagerTensor,
15};
16use crate::eager_exec::exec_op_on_tensors;
17use crate::error::{Error, Result};
18
19impl<B: TensorBackend> EagerTensor<B> {
20 /// Elementwise addition.
21 ///
22 /// # Examples
23 ///
24 /// ```
25 /// use tenferro::{EagerTensor, Tensor};
26 ///
27 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
28 /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
29 /// let z = x.add(&y).unwrap();
30 ///
31 /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[4.0, 6.0]);
32 /// ```
33 pub fn add(&self, other: &Self) -> Result<Self> {
34 self.binary_op(other, StdTensorOp::Add)
35 }
36
37 /// Elementwise multiplication.
38 ///
39 /// # Examples
40 ///
41 /// ```
42 /// use tenferro::{EagerTensor, Tensor};
43 ///
44 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
45 /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
46 /// let z = x.mul(&y).unwrap();
47 ///
48 /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[3.0, 8.0]);
49 /// ```
50 pub fn mul(&self, other: &Self) -> Result<Self> {
51 self.binary_op(other, StdTensorOp::Mul)
52 }
53
54 /// Negate the tensor.
55 ///
56 /// # Examples
57 ///
58 /// ```
59 /// use tenferro::{EagerTensor, Tensor};
60 ///
61 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, -2.0]));
62 /// let y = x.neg().unwrap();
63 ///
64 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[-1.0, 2.0]);
65 /// ```
66 pub fn neg(&self) -> Result<Self> {
67 self.unary_op(StdTensorOp::Neg)
68 }
69
70 /// Elementwise exponential.
71 ///
72 /// # Examples
73 ///
74 /// ```
75 /// use tenferro::{EagerTensor, Tensor};
76 ///
77 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![0.0_f64]));
78 /// let y = x.exp().unwrap();
79 ///
80 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0]);
81 /// ```
82 pub fn exp(&self) -> Result<Self> {
83 self.unary_op(StdTensorOp::Exp)
84 }
85
86 /// Reduce sum over the requested axes.
87 ///
88 /// # Examples
89 ///
90 /// ```
91 /// use tenferro::{EagerTensor, Tensor};
92 ///
93 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
94 /// let y = x.reduce_sum(&[0, 1]).unwrap();
95 ///
96 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[10.0]);
97 /// ```
98 pub fn reduce_sum(&self, axes: &[usize]) -> Result<Self> {
99 self.unary_op(StdTensorOp::ReduceSum {
100 axes: axes.to_vec(),
101 input_shape: DimExpr::from_concrete(self.data.shape()),
102 })
103 }
104
105 /// Execute a dot-general contraction eagerly.
106 ///
107 /// # Examples
108 ///
109 /// ```
110 /// use tenferro::{DotGeneralConfig, EagerTensor, Tensor};
111 ///
112 /// let a = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]));
113 /// let b = EagerTensor::from_tensor(Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]));
114 /// let c = a.dot_general(&b, DotGeneralConfig {
115 /// lhs_contracting_dims: vec![1],
116 /// rhs_contracting_dims: vec![0],
117 /// lhs_batch_dims: vec![],
118 /// rhs_batch_dims: vec![],
119 /// lhs_rank: 2,
120 /// rhs_rank: 2,
121 /// }).unwrap();
122 ///
123 /// assert_eq!(c.data().shape(), &[2, 2]);
124 /// ```
125 pub fn dot_general(&self, other: &Self, config: DotGeneralConfig) -> Result<Self> {
126 self.binary_op(other, StdTensorOp::DotGeneral(config))
127 }
128
129 /// Permute tensor axes.
130 ///
131 /// # Examples
132 ///
133 /// ```
134 /// use tenferro::{EagerTensor, Tensor};
135 ///
136 /// let x = EagerTensor::from_tensor(Tensor::from_vec(
137 /// vec![2, 3],
138 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
139 /// ));
140 /// let y = x.transpose(&[1, 0]).unwrap();
141 ///
142 /// assert_eq!(y.data().shape(), &[3, 2]);
143 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
144 /// ```
145 pub fn transpose(&self, perm: &[usize]) -> Result<Self> {
146 self.unary_op(StdTensorOp::Transpose {
147 perm: perm.to_vec(),
148 })
149 }
150
151 /// Reshape without changing element order.
152 ///
153 /// # Examples
154 ///
155 /// ```
156 /// use tenferro::{EagerTensor, Tensor};
157 ///
158 /// let x = EagerTensor::from_tensor(Tensor::from_vec(
159 /// vec![2, 3],
160 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
161 /// ));
162 /// let y = x.reshape(&[6]).unwrap();
163 ///
164 /// assert_eq!(y.data().shape(), &[6]);
165 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
166 /// ```
167 pub fn reshape(&self, shape: &[usize]) -> Result<Self> {
168 self.unary_op(StdTensorOp::Reshape {
169 from_shape: DimExpr::from_concrete(self.data.shape()),
170 to_shape: DimExpr::from_concrete(shape),
171 })
172 }
173
174 /// Slice with explicit start, limit, and stride per axis.
175 ///
176 /// # Examples
177 ///
178 /// ```
179 /// use tenferro::{EagerTensor, SliceConfig, Tensor};
180 ///
181 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]));
182 /// let y = x
183 /// .slice(SliceConfig {
184 /// starts: vec![1],
185 /// limits: vec![3],
186 /// strides: vec![1],
187 /// })
188 /// .unwrap();
189 ///
190 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[2.0, 3.0]);
191 /// ```
192 pub fn slice(&self, config: SliceConfig) -> Result<Self> {
193 self.unary_op(StdTensorOp::Slice(config))
194 }
195
196 /// Broadcast into a larger shape with explicit dimension placement.
197 ///
198 /// # Examples
199 ///
200 /// ```
201 /// use tenferro::{EagerTensor, Tensor};
202 ///
203 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]));
204 /// let y = x.broadcast_in_dim(&[3, 2], &[0]).unwrap();
205 ///
206 /// assert_eq!(y.data().shape(), &[3, 2]);
207 /// ```
208 pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> Result<Self> {
209 self.unary_op(StdTensorOp::BroadcastInDim {
210 shape: DimExpr::from_concrete(shape),
211 dims: dims.to_vec(),
212 })
213 }
214
215 /// Convert the tensor to a different dtype.
216 ///
217 /// # Examples
218 ///
219 /// ```
220 /// use tenferro::{DType, EagerTensor, Tensor};
221 ///
222 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, -2.0]));
223 /// let y = x.convert(DType::C64).unwrap();
224 ///
225 /// assert_eq!(y.data().dtype(), DType::C64);
226 /// assert_eq!(y.data().shape(), &[2]);
227 /// ```
228 pub fn convert(&self, to: DType) -> Result<Self> {
229 self.unary_op(StdTensorOp::Convert {
230 from: self.data.dtype(),
231 to,
232 })
233 }
234
235 /// Pad with zeros using StableHLO-style edge and interior padding.
236 ///
237 /// # Examples
238 ///
239 /// ```
240 /// use tenferro::{EagerTensor, PadConfig, Tensor};
241 ///
242 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
243 /// let y = x
244 /// .pad(PadConfig {
245 /// edge_padding_low: vec![1],
246 /// edge_padding_high: vec![1],
247 /// interior_padding: vec![1],
248 /// })
249 /// .unwrap();
250 ///
251 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[0.0, 1.0, 0.0, 2.0, 0.0]);
252 /// ```
253 pub fn pad(&self, config: PadConfig) -> Result<Self> {
254 self.unary_op(StdTensorOp::Pad(config))
255 }
256
257 /// Reverse the order of elements along the requested axes.
258 ///
259 /// # Examples
260 ///
261 /// ```
262 /// use tenferro::{EagerTensor, Tensor};
263 ///
264 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]));
265 /// let y = x.reverse(&[0]).unwrap();
266 ///
267 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[4.0, 3.0, 2.0, 1.0]);
268 /// ```
269 pub fn reverse(&self, axes: &[usize]) -> Result<Self> {
270 self.unary_op(StdTensorOp::Reverse {
271 axes: axes.to_vec(),
272 })
273 }
274
275 /// Gather slices from `self` using integer start indices.
276 ///
277 /// # Examples
278 ///
279 /// ```
280 /// use tenferro::{EagerTensor, GatherConfig, Tensor};
281 ///
282 /// let x = EagerTensor::from_tensor(Tensor::from_vec(
283 /// vec![5],
284 /// vec![10.0_f64, 20.0, 30.0, 40.0, 50.0],
285 /// ));
286 /// let indices = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![4.0_f64, 1.0, 0.0]));
287 /// let y = x
288 /// .gather(
289 /// &indices,
290 /// GatherConfig {
291 /// offset_dims: vec![],
292 /// collapsed_slice_dims: vec![0],
293 /// start_index_map: vec![0],
294 /// index_vector_dim: 1,
295 /// slice_sizes: vec![1],
296 /// },
297 /// )
298 /// .unwrap();
299 ///
300 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[50.0, 20.0, 10.0]);
301 /// ```
302 pub fn gather(&self, indices: &Self, config: GatherConfig) -> Result<Self> {
303 self.binary_op(indices, StdTensorOp::Gather(config))
304 }
305
306 /// Scatter updates into `self` using StableHLO scatter semantics.
307 ///
308 /// # Examples
309 ///
310 /// ```
311 /// use tenferro::{EagerTensor, ScatterConfig, Tensor};
312 ///
313 /// let operand = EagerTensor::from_tensor(Tensor::from_vec(vec![4], vec![0.0_f64, 0.0, 0.0, 0.0]));
314 /// let indices = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 1], vec![1.0_f64, 3.0]));
315 /// let updates = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![5.0_f64, 7.0]));
316 /// let result = operand
317 /// .scatter(
318 /// &indices,
319 /// &updates,
320 /// ScatterConfig {
321 /// update_window_dims: vec![],
322 /// inserted_window_dims: vec![0],
323 /// scatter_dims_to_operand_dims: vec![0],
324 /// index_vector_dim: 1,
325 /// },
326 /// )
327 /// .unwrap();
328 ///
329 /// assert_eq!(result.data().as_slice::<f64>().unwrap(), &[0.0, 5.0, 0.0, 7.0]);
330 /// ```
331 pub fn scatter(&self, indices: &Self, updates: &Self, config: ScatterConfig) -> Result<Self> {
332 self.ternary_op(indices, updates, StdTensorOp::Scatter(config))
333 }
334
335 /// Slice using runtime start indices.
336 ///
337 /// # Examples
338 ///
339 /// ```
340 /// use tenferro::{EagerTensor, Tensor};
341 ///
342 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![5], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]));
343 /// let starts = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![2.0_f64]));
344 /// let y = x.dynamic_slice(&starts, &[2]).unwrap();
345 ///
346 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[3.0, 4.0]);
347 /// ```
348 pub fn dynamic_slice(&self, starts: &Self, sizes: &[usize]) -> Result<Self> {
349 self.binary_op(
350 starts,
351 StdTensorOp::DynamicSlice {
352 slice_sizes: sizes.to_vec(),
353 },
354 )
355 }
356
357 /// Concatenate tensors along one axis.
358 ///
359 /// # Examples
360 ///
361 /// ```
362 /// use tenferro::{EagerTensor, Tensor};
363 ///
364 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
365 /// let y = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]));
366 /// let z = EagerTensor::concatenate(&[&x, &y], 0).unwrap();
367 ///
368 /// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
369 /// ```
370 pub fn concatenate(tensors: &[&Self], axis: usize) -> Result<Self> {
371 Self::nary_op(tensors, StdTensorOp::Concatenate { axis })
372 }
373
374 /// Extract the diagonal along two axes.
375 ///
376 /// # Examples
377 ///
378 /// ```
379 /// use tenferro::{EagerTensor, Tensor};
380 ///
381 /// let x = EagerTensor::from_tensor(Tensor::from_vec(
382 /// vec![3, 3],
383 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
384 /// ));
385 /// let y = x.extract_diag(0, 1).unwrap();
386 ///
387 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 5.0, 9.0]);
388 /// ```
389 pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
390 self.unary_op(StdTensorOp::ExtractDiag { axis_a, axis_b })
391 }
392
393 /// Embed a vector or lower-rank tensor along a diagonal.
394 ///
395 /// # Examples
396 ///
397 /// ```
398 /// use tenferro::{EagerTensor, Tensor};
399 ///
400 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]));
401 /// let y = x.embed_diag(0, 1).unwrap();
402 ///
403 /// assert_eq!(y.data().shape(), &[3, 3]);
404 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
405 /// ```
406 pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
407 self.unary_op(StdTensorOp::EmbedDiag { axis_a, axis_b })
408 }
409
410 /// Keep the lower triangle and zero the rest.
411 ///
412 /// # Examples
413 ///
414 /// ```
415 /// use tenferro::{EagerTensor, Tensor};
416 ///
417 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
418 /// let y = x.tril(0).unwrap();
419 ///
420 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 2.0, 0.0, 4.0]);
421 /// ```
422 pub fn tril(&self, k: i64) -> Result<Self> {
423 self.unary_op(StdTensorOp::Tril { k })
424 }
425
426 /// Keep the upper triangle and zero the rest.
427 ///
428 /// # Examples
429 ///
430 /// ```
431 /// use tenferro::{EagerTensor, Tensor};
432 ///
433 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
434 /// let y = x.triu(0).unwrap();
435 ///
436 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 0.0, 3.0, 4.0]);
437 /// ```
438 pub fn triu(&self, k: i64) -> Result<Self> {
439 self.unary_op(StdTensorOp::Triu { k })
440 }
441
442 /// Reduce product over the requested axes.
443 ///
444 /// # Examples
445 ///
446 /// ```
447 /// use tenferro::{EagerTensor, Tensor};
448 ///
449 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
450 /// let y = x.reduce_prod(&[0, 1]).unwrap();
451 ///
452 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[24.0]);
453 /// ```
454 pub fn reduce_prod(&self, axes: &[usize]) -> Result<Self> {
455 self.unary_op(StdTensorOp::ReduceProd {
456 axes: axes.to_vec(),
457 input_shape: DimExpr::from_concrete(self.data.shape()),
458 })
459 }
460
461 /// Reduce maximum over the requested axes.
462 ///
463 /// # Examples
464 ///
465 /// ```
466 /// use tenferro::{EagerTensor, Tensor};
467 ///
468 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
469 /// let y = x.reduce_max(&[0, 1]).unwrap();
470 ///
471 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[4.0]);
472 /// ```
473 pub fn reduce_max(&self, axes: &[usize]) -> Result<Self> {
474 self.unary_op(StdTensorOp::ReduceMax {
475 axes: axes.to_vec(),
476 input_shape: DimExpr::from_concrete(self.data.shape()),
477 })
478 }
479
480 /// Reduce minimum over the requested axes.
481 ///
482 /// # Examples
483 ///
484 /// ```
485 /// use tenferro::{EagerTensor, Tensor};
486 ///
487 /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]));
488 /// let y = x.reduce_min(&[0, 1]).unwrap();
489 ///
490 /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0]);
491 /// ```
492 pub fn reduce_min(&self, axes: &[usize]) -> Result<Self> {
493 self.unary_op(StdTensorOp::ReduceMin {
494 axes: axes.to_vec(),
495 input_shape: DimExpr::from_concrete(self.data.shape()),
496 })
497 }
498
499 pub(crate) fn unary_op(&self, op: StdTensorOp) -> Result<Self> {
500 let output = exec_single_output(&op, &[self.data.as_ref()], &self.ctx)?;
501 let result_key = eager_val_key();
502 let input_aliases = vec![eager_val_key()];
503 let grad_node = self.requires_grad.then(|| {
504 Arc::new(GradNode {
505 op: op.clone(),
506 primal_in_keys: input_aliases.clone(),
507 primal_out_keys: vec![result_key.clone()],
508 saved_data: saved_forward_values(
509 &op,
510 &input_aliases,
511 &[Arc::clone(&self.data)],
512 Arc::new(output.clone()),
513 ),
514 input_edges: vec![GradEdge {
515 node: self.grad_node.clone(),
516 key: self.key.clone(),
517 requires_grad: self.requires_grad,
518 }],
519 output_idx: 0,
520 })
521 });
522 Ok(Self::new_result(
523 Arc::clone(&self.ctx),
524 result_key,
525 output,
526 self.requires_grad,
527 grad_node,
528 ))
529 }
530
531 pub(crate) fn binary_op(&self, other: &Self, op: StdTensorOp) -> Result<Self> {
532 Self::nary_op(&[self, other], op)
533 }
534
535 pub(crate) fn multi_output_unary_op(
536 &self,
537 op: StdTensorOp,
538 num_outputs: usize,
539 ) -> Result<Vec<Self>> {
540 let outputs = {
541 let mut backend = self.ctx.backend.lock().unwrap();
542 exec_op_on_tensors(&op, &[self.data.as_ref()], &mut *backend)?
543 };
544 if outputs.len() != num_outputs {
545 return Err(Error::Internal(format!(
546 "expected {} eager outputs for {:?}, got {}",
547 num_outputs,
548 op,
549 outputs.len()
550 )));
551 }
552
553 let outputs: Vec<Arc<Tensor>> = outputs.into_iter().map(Arc::new).collect();
554 let output_keys: Vec<_> = (0..num_outputs).map(|_| eager_val_key()).collect();
555 let input_aliases = vec![eager_val_key()];
556 let grad_node = self.requires_grad.then(|| {
557 Arc::new(GradNode {
558 op: op.clone(),
559 primal_in_keys: input_aliases.clone(),
560 primal_out_keys: output_keys.clone(),
561 saved_data: saved_forward_values_multi(
562 &op,
563 &input_aliases,
564 &[Arc::clone(&self.data)],
565 num_outputs,
566 &outputs,
567 ),
568 input_edges: vec![GradEdge {
569 node: self.grad_node.clone(),
570 key: self.key.clone(),
571 requires_grad: self.requires_grad,
572 }],
573 output_idx: 0,
574 })
575 });
576
577 Ok(output_keys
578 .into_iter()
579 .zip(outputs)
580 .map(|(output_key, output)| {
581 Self::new_result(
582 Arc::clone(&self.ctx),
583 output_key,
584 output.as_ref().clone(),
585 self.requires_grad,
586 grad_node.clone(),
587 )
588 })
589 .collect())
590 }
591
592 pub(crate) fn ternary_op(&self, b: &Self, c: &Self, op: StdTensorOp) -> Result<Self> {
593 Self::nary_op(&[self, b, c], op)
594 }
595
    /// Execute an op over an arbitrary number of eager inputs, building a
    /// grad node when any input requires gradients.
    ///
    /// Returns `Error::Internal` if `tensors` is empty, since there is no
    /// context to execute on.
    pub(crate) fn nary_op(tensors: &[&Self], op: StdTensorOp) -> Result<Self> {
        let Some(first) = tensors.first() else {
            return Err(Error::Internal(
                "nary eager op requires at least one input tensor".to_string(),
            ));
        };

        // Run everything on the first input's context; any input living on a
        // different context is folded in via `absorb_from` (presumably merging
        // its state into `ctx` — confirm in `absorb_from`'s definition).
        let ctx = Arc::clone(&first.ctx);
        for tensor in tensors.iter().skip(1) {
            if !Arc::ptr_eq(&ctx, &tensor.ctx) {
                ctx.absorb_from(&tensor.ctx);
            }
        }

        let inputs: Vec<&Tensor> = tensors.iter().map(|tensor| tensor.data.as_ref()).collect();
        let output = exec_single_output(&op, &inputs, &ctx)?;
        // The result tracks gradients if any input does.
        let requires_grad = tensors.iter().any(|tensor| tensor.requires_grad);
        // Fresh keys: one for the result, then one alias per input, generated
        // in this exact order.
        let result_key = eager_val_key();
        let input_aliases: Vec<_> = tensors.iter().map(|_| eager_val_key()).collect();
        let input_data: Vec<_> = tensors
            .iter()
            .map(|tensor| Arc::clone(&tensor.data))
            .collect();
        let grad_node = requires_grad.then(|| {
            Arc::new(GradNode {
                op: op.clone(),
                primal_in_keys: input_aliases.clone(),
                primal_out_keys: vec![result_key.clone()],
                // Capture whichever forward values `saved_forward_values`
                // selects for this op's backward pass.
                saved_data: saved_forward_values(
                    &op,
                    &input_aliases,
                    &input_data,
                    Arc::new(output.clone()),
                ),
                // One grad edge per input, preserving input order.
                input_edges: tensors
                    .iter()
                    .map(|tensor| GradEdge {
                        node: tensor.grad_node.clone(),
                        key: tensor.key.clone(),
                        requires_grad: tensor.requires_grad,
                    })
                    .collect(),
                output_idx: 0,
            })
        });

        Ok(Self::new_result(
            ctx,
            result_key,
            output,
            requires_grad,
            grad_node,
        ))
    }
650}