Skip to main content

tenferro_ad/
eager_ops.rs

1use std::sync::Arc;
2
3use tenferro_ops::broadcast::{
4    broadcast_input_plan, broadcast_shape, broadcast_shapes, BroadcastError,
5};
6use tenferro_ops::dim_expr::DimExpr;
7use tenferro_ops::std_tensor_op::StdTensorOp;
8use tenferro_tensor::{
9    DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig, Tensor,
10    TensorValue,
11};
12
13use crate::eager::{
14    exec_single_output, exec_single_output_read, maybe_print_eager_op_profile,
15    profile_eager_op_section, record_eager_op_profile, record_eager_outputs, EagerTensor,
16};
17use crate::eager_exec::exec_dot_general_with_conj_on_tensor_reads;
18use crate::error::{Error, Result};
19use crate::metadata::push_metadata_scope;
20
21pub(crate) fn broadcast_binary(
22    op: &'static str,
23    lhs: &EagerTensor,
24    rhs: &EagerTensor,
25) -> Result<(EagerTensor, EagerTensor)> {
26    ensure_same_context(lhs, rhs)?;
27    let shape =
28        broadcast_shape(lhs.shape(), rhs.shape()).map_err(|err| broadcast_error(op, err))?;
29    Ok((
30        broadcast_to(op, lhs, &shape)?,
31        broadcast_to(op, rhs, &shape)?,
32    ))
33}
34
35pub(crate) fn broadcast_ternary(
36    op: &'static str,
37    first: &EagerTensor,
38    second: &EagerTensor,
39    third: &EagerTensor,
40) -> Result<(EagerTensor, EagerTensor, EagerTensor)> {
41    ensure_same_context(first, second)?;
42    ensure_same_context(first, third)?;
43    let shape = broadcast_shapes([first.shape(), second.shape(), third.shape()])
44        .map_err(|err| broadcast_error(op, err))?;
45    Ok((
46        broadcast_to(op, first, &shape)?,
47        broadcast_to(op, second, &shape)?,
48        broadcast_to(op, third, &shape)?,
49    ))
50}
51
52fn broadcast_to(
53    op: &'static str,
54    input: &EagerTensor,
55    target_shape: &[usize],
56) -> Result<EagerTensor> {
57    let input_shape = input.shape();
58    if input_shape == target_shape {
59        return Ok(input.clone());
60    }
61
62    let plan =
63        broadcast_input_plan(input_shape, target_shape).map_err(|err| broadcast_error(op, err))?;
64    let source = if plan.source_shape == input_shape {
65        input.clone()
66    } else {
67        input.reshape(&plan.source_shape)?
68    };
69    source.broadcast_in_dim(target_shape, &plan.dims)
70}
71
72fn broadcast_error(op: &'static str, err: BroadcastError) -> Error {
73    match err {
74        BroadcastError::IncompatibleBinary { lhs, rhs } => {
75            tenferro_tensor::Error::ShapeMismatch { op, lhs, rhs }.into()
76        }
77        BroadcastError::IncompatibleInput { input, output }
78        | BroadcastError::RankTooLarge { input, output } => tenferro_tensor::Error::InvalidConfig {
79            op,
80            message: format!("cannot broadcast shape {input:?} to {output:?}"),
81        }
82        .into(),
83    }
84}
85
86fn ensure_same_context(lhs: &EagerTensor, rhs: &EagerTensor) -> Result<()> {
87    if !lhs.same_context(rhs) {
88        return Err(Error::ContextMismatch {
89            lhs: lhs.ctx_id(),
90            rhs: rhs.ctx_id(),
91        });
92    }
93    Ok(())
94}
95
96impl std::ops::Add for &EagerTensor {
97    type Output = Result<EagerTensor>;
98
99    fn add(self, rhs: &EagerTensor) -> Result<EagerTensor> {
100        EagerTensor::add(self, rhs)
101    }
102}
103
104impl std::ops::Sub for &EagerTensor {
105    type Output = Result<EagerTensor>;
106
107    fn sub(self, rhs: &EagerTensor) -> Result<EagerTensor> {
108        EagerTensor::sub(self, rhs)
109    }
110}
111
112impl std::ops::Mul for &EagerTensor {
113    type Output = Result<EagerTensor>;
114
115    fn mul(self, rhs: &EagerTensor) -> Result<EagerTensor> {
116        EagerTensor::mul(self, rhs)
117    }
118}
119
120impl std::ops::Div for &EagerTensor {
121    type Output = Result<EagerTensor>;
122
123    fn div(self, rhs: &EagerTensor) -> Result<EagerTensor> {
124        EagerTensor::div(self, rhs)
125    }
126}
127
128impl std::ops::Neg for &EagerTensor {
129    type Output = Result<EagerTensor>;
130
131    fn neg(self) -> Result<EagerTensor> {
132        EagerTensor::neg(self)
133    }
134}
135
136impl EagerTensor {
137    /// Elementwise addition.
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// use tenferro_cpu::CpuBackend;
143    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
144    ///
145    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
146    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
147    /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
148    /// let z = x.add(&y).unwrap();
149    ///
150    /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0, 6.0]);
151    /// ```
152    pub fn add(&self, other: &Self) -> Result<Self> {
153        let (lhs, rhs) = broadcast_binary("add", self, other)?;
154        lhs.binary_op(&rhs, StdTensorOp::Add)
155    }
156
157    /// Elementwise subtraction.
158    pub fn sub(&self, other: &Self) -> Result<Self> {
159        let (lhs, rhs) = broadcast_binary("sub", self, other)?;
160        let rhs = rhs.neg()?;
161        lhs.binary_op(&rhs, StdTensorOp::Add)
162    }
163
164    /// Elementwise multiplication.
165    ///
166    /// # Examples
167    ///
168    /// ```
169    /// use tenferro_cpu::CpuBackend;
170    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
171    ///
172    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
173    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
174    /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
175    /// let z = x.mul(&y).unwrap();
176    ///
177    /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 8.0]);
178    /// ```
179    pub fn mul(&self, other: &Self) -> Result<Self> {
180        let (lhs, rhs) = broadcast_binary("mul", self, other)?;
181        lhs.binary_op(&rhs, StdTensorOp::Mul)
182    }
183
184    /// Negate the tensor.
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// use tenferro_cpu::CpuBackend;
190    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
191    ///
192    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
193    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, -2.0]).unwrap(), ctx.clone()).unwrap();
194    /// let y = x.neg().unwrap();
195    ///
196    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[-1.0, 2.0]);
197    /// ```
198    pub fn neg(&self) -> Result<Self> {
199        self.unary_op(StdTensorOp::Neg)
200    }
201
202    /// Elementwise exponential.
203    ///
204    /// # Examples
205    ///
206    /// ```
207    /// use tenferro_cpu::CpuBackend;
208    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
209    ///
210    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
211    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
212    /// let y = x.exp().unwrap();
213    ///
214    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0]);
215    /// ```
216    pub fn exp(&self) -> Result<Self> {
217        self.unary_op(StdTensorOp::Exp)
218    }
219
220    /// Reduce sum over the requested axes.
221    ///
222    /// # Examples
223    ///
224    /// ```
225    /// use tenferro_cpu::CpuBackend;
226    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
227    ///
228    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
229    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
230    /// let y = x.reduce_sum(&[0, 1]).unwrap();
231    ///
232    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[10.0]);
233    /// ```
234    pub fn reduce_sum(&self, axes: &[usize]) -> Result<Self> {
235        self.unary_op(StdTensorOp::ReduceSum {
236            axes: axes.to_vec(),
237        })
238    }
239
240    /// Execute a dot-general contraction eagerly.
241    ///
242    /// # Examples
243    ///
244    /// ```
245    /// use tenferro_cpu::CpuBackend;
246    /// use tenferro_ad::{DotGeneralConfig, EagerRuntime, EagerTensor, Tensor};
247    ///
248    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
249    /// let a = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), ctx.clone()).unwrap();
250    /// let b = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), ctx.clone()).unwrap();
251    /// let c = a.dot_general(&b, DotGeneralConfig {
252    ///     lhs_contracting_dims: vec![1],
253    ///     rhs_contracting_dims: vec![0],
254    ///     lhs_batch_dims: vec![],
255    ///     rhs_batch_dims: vec![],
256    /// }).unwrap();
257    ///
258    /// assert_eq!(c.shape(), &[2, 2]);
259    /// ```
260    pub fn dot_general(&self, other: &Self, config: DotGeneralConfig) -> Result<Self> {
261        self.binary_op(other, StdTensorOp::DotGeneral { config })
262    }
263
264    /// Execute a dot-general contraction, optionally conjugating either operand.
265    ///
266    /// Untracked tensors route the conjugation flags directly to the backend so
267    /// the conjugated operand does not need to be materialized. Tracked tensors
268    /// fall back to explicit `Conj` plus `DotGeneral` so reverse-mode AD keeps
269    /// the same graph semantics as the standard eager ops.
270    pub fn dot_general_with_conj(
271        &self,
272        other: &Self,
273        config: &DotGeneralConfig,
274        lhs_conj: bool,
275        rhs_conj: bool,
276    ) -> Result<Self> {
277        if !self.same_context(other) {
278            return Err(Error::ContextMismatch {
279                lhs: self.ctx_id(),
280                rhs: other.ctx_id(),
281            });
282        }
283
284        if !self.requires_grad && !other.requires_grad {
285            let ctx = Arc::clone(&self.ctx);
286            let output = ctx.with_backend_mut(|backend| {
287                exec_dot_general_with_conj_on_tensor_reads(
288                    self.tensor_read(),
289                    other.tensor_read(),
290                    config,
291                    lhs_conj,
292                    rhs_conj,
293                    backend,
294                )
295            })??;
296            return Self::new_untracked_result(ctx, output);
297        }
298
299        match (lhs_conj, rhs_conj) {
300            (false, false) => self.dot_general(other, config.clone()),
301            (true, false) => self.conj()?.dot_general(other, config.clone()),
302            (false, true) => {
303                let rhs = other.conj()?;
304                self.dot_general(&rhs, config.clone())
305            }
306            (true, true) => {
307                let lhs = self.conj()?;
308                let rhs = other.conj()?;
309                lhs.dot_general(&rhs, config.clone())
310            }
311        }
312    }
313
314    /// Matrix multiplication for rank-2 tensors.
315    ///
316    /// This is a convenience wrapper over [`Self::dot_general`] that
317    /// contracts the left matrix's column axis with the right matrix's row
318    /// axis.
319    ///
320    /// # Examples
321    ///
322    /// ```
323    /// use tenferro_cpu::CpuBackend;
324    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
325    ///
326    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
327    /// let a = EagerTensor::from_tensor_in(
328    ///     Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
329    ///     ctx.clone(),
330    /// ).unwrap();
331    /// let b = EagerTensor::from_tensor_in(
332    ///     Tensor::from_vec_col_major(vec![2, 1], vec![5.0_f64, 6.0]).unwrap(),
333    ///     ctx,
334    /// ).unwrap();
335    /// let c = a.matmul(&b).unwrap();
336    ///
337    /// assert_eq!(c.shape(), &[2, 1]);
338    /// assert_eq!(c.materialized().unwrap().as_slice::<f64>().unwrap(), &[23.0, 34.0]);
339    /// ```
340    pub fn matmul(&self, other: &Self) -> Result<Self> {
341        let lhs_shape = self.shape();
342        let rhs_shape = other.shape();
343        if lhs_shape.len() != 2 {
344            return Err(tenferro_tensor::Error::RankMismatch {
345                op: "matmul",
346                expected: 2,
347                actual: lhs_shape.len(),
348            }
349            .into());
350        }
351        if rhs_shape.len() != 2 {
352            return Err(tenferro_tensor::Error::RankMismatch {
353                op: "matmul",
354                expected: 2,
355                actual: rhs_shape.len(),
356            }
357            .into());
358        }
359        if lhs_shape[1] != rhs_shape[0] {
360            return Err(tenferro_tensor::Error::ShapeMismatch {
361                op: "matmul",
362                lhs: lhs_shape.to_vec(),
363                rhs: rhs_shape.to_vec(),
364            }
365            .into());
366        }
367        self.dot_general(
368            other,
369            DotGeneralConfig {
370                lhs_contracting_dims: vec![1],
371                rhs_contracting_dims: vec![0],
372                lhs_batch_dims: vec![],
373                rhs_batch_dims: vec![],
374            },
375        )
376    }
377
378    /// Permute tensor axes.
379    ///
380    /// # Examples
381    ///
382    /// ```
383    /// use tenferro_cpu::CpuBackend;
384    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
385    ///
386    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
387    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
388    ///     vec![2, 3],
389    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
390    /// ).unwrap(), ctx.clone()).unwrap();
391    /// let y = x.transpose(&[1, 0]).unwrap();
392    ///
393    /// assert_eq!(y.shape(), &[3, 2]);
394    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
395    /// ```
396    pub fn transpose(&self, perm: &[usize]) -> Result<Self> {
397        let op = StdTensorOp::Transpose {
398            perm: perm.to_vec(),
399        };
400        let value = self
401            .value
402            .transpose_view(perm)
403            .map_err(Error::TensorRuntime)?;
404        Self::nary_value_op(&[self], op, value)
405    }
406
407    /// Reshape without changing element order.
408    ///
409    /// # Examples
410    ///
411    /// ```
412    /// use tenferro_cpu::CpuBackend;
413    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
414    ///
415    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
416    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
417    ///     vec![2, 3],
418    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
419    /// ).unwrap(), ctx.clone()).unwrap();
420    /// let y = x.reshape(&[6]).unwrap();
421    ///
422    /// assert_eq!(y.shape(), &[6]);
423    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
424    /// ```
425    pub fn reshape(&self, shape: &[usize]) -> Result<Self> {
426        let op = StdTensorOp::Reshape {
427            to_shape: DimExpr::from_concrete(shape),
428        };
429        if let Ok(value) = self.value.reshape_view(shape) {
430            return Self::nary_value_op(&[self], op, value);
431        }
432        self.unary_op(op)
433    }
434
435    /// Slice with explicit start, limit, and stride per axis.
436    ///
437    /// # Examples
438    ///
439    /// ```
440    /// use tenferro_cpu::CpuBackend;
441    /// use tenferro_ad::{EagerRuntime, EagerTensor, SliceConfig, Tensor};
442    ///
443    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
444    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
445    /// let y = x
446    ///     .slice(SliceConfig {
447    ///         starts: vec![1],
448    ///         limits: vec![3],
449    ///         strides: vec![1],
450    ///     })
451    ///     .unwrap();
452    ///
453    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0, 3.0]);
454    /// ```
455    pub fn slice(&self, config: SliceConfig) -> Result<Self> {
456        let value = self
457            .value
458            .slice_view(&config)
459            .map_err(Error::TensorRuntime)?;
460        Self::nary_value_op(&[self], StdTensorOp::Slice(config), value)
461    }
462
463    /// Broadcast into a larger shape with explicit dimension placement.
464    ///
465    /// # Examples
466    ///
467    /// ```
468    /// use tenferro_cpu::CpuBackend;
469    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
470    ///
471    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
472    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx.clone()).unwrap();
473    /// let y = x.broadcast_in_dim(&[3, 2], &[0]).unwrap();
474    ///
475    /// assert_eq!(y.shape(), &[3, 2]);
476    /// ```
477    pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> Result<Self> {
478        let op = StdTensorOp::BroadcastInDim {
479            shape: DimExpr::from_concrete(shape),
480            dims: dims.to_vec(),
481        };
482        let value = self
483            .value
484            .broadcast_in_dim_view(shape, dims)
485            .map_err(Error::TensorRuntime)?;
486        Self::nary_value_op(&[self], op, value)
487    }
488
489    /// Convert the tensor to a different dtype using checked conversion.
490    ///
491    /// Use [`cast`](Self::cast) when a lossy dtype projection is intended.
492    ///
493    /// # Examples
494    ///
495    /// ```
496    /// use tenferro_cpu::CpuBackend;
497    /// use tenferro_ad::{DType, EagerRuntime, EagerTensor, Tensor};
498    ///
499    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
500    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, -2.0]).unwrap(), ctx.clone()).unwrap();
501    /// let y = x.convert(DType::C64).unwrap();
502    ///
503    /// assert_eq!(y.dtype(), DType::C64);
504    /// assert_eq!(y.shape(), &[2]);
505    /// ```
506    ///
507    /// # Errors
508    ///
509    /// Returns an error when the requested conversion is outside tenferro's
510    /// checked dtype-promotion lattice. Use [`cast`](Self::cast) for explicit
511    /// lossy dtype projection.
512    pub fn convert(&self, to: DType) -> Result<Self> {
513        tenferro_tensor::validate::validate_convert_dtype("EagerTensor::convert", self.dtype(), to)
514            .map_err(Error::TensorRuntime)?;
515        self.cast(to)
516    }
517
518    /// Cast the tensor to a different dtype using explicit dtype projection.
519    ///
520    /// `cast` may truncate, narrow precision, project complex values to their
521    /// real component, or use boolean truthiness where the backend supports the
522    /// requested projection.
523    ///
524    /// # Examples
525    ///
526    /// ```
527    /// use tenferro_cpu::CpuBackend;
528    /// use tenferro_ad::{DType, EagerRuntime, EagerTensor, Tensor};
529    ///
530    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
531    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.2_f64, -2.8]).unwrap(), ctx.clone()).unwrap();
532    /// let y = x.cast(DType::I32).unwrap();
533    ///
534    /// assert_eq!(y.materialized().unwrap().as_slice::<i32>().unwrap(), &[1, -2]);
535    /// ```
536    pub fn cast(&self, to: DType) -> Result<Self> {
537        self.unary_op(StdTensorOp::Convert {
538            from: self.dtype(),
539            to,
540        })
541    }
542
543    /// Pad with zeros using StableHLO-style edge and interior padding.
544    ///
545    /// # Examples
546    ///
547    /// ```
548    /// use tenferro_cpu::CpuBackend;
549    /// use tenferro_ad::{EagerRuntime, EagerTensor, PadConfig, Tensor};
550    ///
551    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
552    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
553    /// let y = x
554    ///     .pad(PadConfig {
555    ///         edge_padding_low: vec![1],
556    ///         edge_padding_high: vec![1],
557    ///         interior_padding: vec![1],
558    ///     })
559    ///     .unwrap();
560    ///
561    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0, 1.0, 0.0, 2.0, 0.0]);
562    /// ```
563    pub fn pad(&self, config: PadConfig) -> Result<Self> {
564        self.unary_op(StdTensorOp::Pad(config))
565    }
566
567    /// Reverse the order of elements along the requested axes.
568    ///
569    /// # Examples
570    ///
571    /// ```
572    /// use tenferro_cpu::CpuBackend;
573    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
574    ///
575    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
576    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
577    /// let y = x.reverse(&[0]).unwrap();
578    ///
579    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0, 3.0, 2.0, 1.0]);
580    /// ```
581    pub fn reverse(&self, axes: &[usize]) -> Result<Self> {
582        self.unary_op(StdTensorOp::Reverse {
583            axes: axes.to_vec(),
584        })
585    }
586
587    /// Gather slices from `self` using integer start indices.
588    ///
589    /// # Examples
590    ///
591    /// ```
592    /// use tenferro_cpu::CpuBackend;
593    /// use tenferro_ad::{EagerRuntime, EagerTensor, GatherConfig, Tensor};
594    ///
595    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
596    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
597    ///     vec![5],
598    ///     vec![10.0_f64, 20.0, 30.0, 40.0, 50.0],
599    /// ).unwrap(), ctx.clone()).unwrap();
600    /// let indices = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![4_i64, 1, 0]).unwrap(), ctx.clone()).unwrap();
601    /// let y = x
602    ///     .gather(
603    ///         &indices,
604    ///         GatherConfig {
605    ///             offset_dims: vec![],
606    ///             collapsed_slice_dims: vec![0],
607    ///             start_index_map: vec![0],
608    ///             index_vector_dim: 1,
609    ///             slice_sizes: vec![1],
610    ///         },
611    ///     )
612    ///     .unwrap();
613    ///
614    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[50.0, 20.0, 10.0]);
615    /// ```
616    pub fn gather(&self, indices: &Self, config: GatherConfig) -> Result<Self> {
617        self.binary_op(indices, StdTensorOp::Gather(config))
618    }
619
620    /// Scatter updates into `self` using StableHLO scatter semantics.
621    ///
622    /// # Examples
623    ///
624    /// ```
625    /// use tenferro_cpu::CpuBackend;
626    /// use tenferro_ad::{EagerRuntime, EagerTensor, ScatterConfig, Tensor};
627    ///
628    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
629    /// let operand = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![4], vec![0.0_f64, 0.0, 0.0, 0.0]).unwrap(), ctx.clone()).unwrap();
630    /// let indices = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 1], vec![1_i64, 3]).unwrap(), ctx.clone()).unwrap();
631    /// let updates = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![5.0_f64, 7.0]).unwrap(), ctx.clone()).unwrap();
632    /// let result = operand
633    ///     .scatter(
634    ///         &indices,
635    ///         &updates,
636    ///         ScatterConfig {
637    ///             update_window_dims: vec![],
638    ///             inserted_window_dims: vec![0],
639    ///             scatter_dims_to_operand_dims: vec![0],
640    ///             index_vector_dim: 1,
641    ///         },
642    ///     )
643    ///     .unwrap();
644    ///
645    /// assert_eq!(result.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0, 5.0, 0.0, 7.0]);
646    /// ```
647    pub fn scatter(&self, indices: &Self, updates: &Self, config: ScatterConfig) -> Result<Self> {
648        self.ternary_op(indices, updates, StdTensorOp::Scatter(config))
649    }
650
651    /// Slice using runtime start indices.
652    ///
653    /// # Examples
654    ///
655    /// ```
656    /// use tenferro_cpu::CpuBackend;
657    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
658    ///
659    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
660    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![5], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]).unwrap(), ctx.clone()).unwrap();
661    /// let starts = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![2_i64]).unwrap(), ctx.clone()).unwrap();
662    /// let y = x.dynamic_slice(&starts, &[2]).unwrap();
663    ///
664    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 4.0]);
665    /// ```
666    pub fn dynamic_slice(&self, starts: &Self, sizes: &[usize]) -> Result<Self> {
667        self.binary_op(
668            starts,
669            StdTensorOp::DynamicSlice {
670                slice_sizes: sizes.to_vec(),
671            },
672        )
673    }
674
675    /// Concatenate tensors along one axis.
676    ///
677    /// # Examples
678    ///
679    /// ```
680    /// use tenferro_cpu::CpuBackend;
681    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
682    ///
683    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
684    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
685    /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
686    /// let z = EagerTensor::concatenate(&[&x, &y], 0).unwrap();
687    ///
688    /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
689    /// ```
690    pub fn concatenate(tensors: &[&Self], axis: usize) -> Result<Self> {
691        Self::nary_op(
692            tensors,
693            StdTensorOp::Concatenate {
694                axis,
695                input_count: tensors.len(),
696            },
697        )
698    }
699
700    /// Extract the diagonal along two axes.
701    ///
702    /// # Examples
703    ///
704    /// ```
705    /// use tenferro_cpu::CpuBackend;
706    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
707    ///
708    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
709    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
710    ///     vec![3, 3],
711    ///     vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
712    /// ).unwrap(), ctx.clone()).unwrap();
713    /// let y = x.extract_diag(0, 1).unwrap();
714    ///
715    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 5.0, 9.0]);
716    /// ```
717    pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
718        self.unary_op(StdTensorOp::ExtractDiag { axis_a, axis_b })
719    }
720
721    /// Embed a vector or lower-rank tensor along a diagonal.
722    ///
723    /// # Examples
724    ///
725    /// ```
726    /// use tenferro_cpu::CpuBackend;
727    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
728    ///
729    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
730    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx.clone()).unwrap();
731    /// let y = x.embed_diag(0, 1).unwrap();
732    ///
733    /// assert_eq!(y.shape(), &[3, 3]);
734    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
735    /// ```
736    pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
737        self.unary_op(StdTensorOp::EmbedDiag { axis_a, axis_b })
738    }
739
740    /// Keep the lower triangle and zero the rest.
741    ///
742    /// # Examples
743    ///
744    /// ```
745    /// use tenferro_cpu::CpuBackend;
746    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
747    ///
748    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
749    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
750    /// let y = x.tril(0).unwrap();
751    ///
752    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0, 0.0, 4.0]);
753    /// ```
754    pub fn tril(&self, k: i64) -> Result<Self> {
755        self.unary_op(StdTensorOp::Tril { k })
756    }
757
758    /// Keep the upper triangle and zero the rest.
759    ///
760    /// # Examples
761    ///
762    /// ```
763    /// use tenferro_cpu::CpuBackend;
764    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
765    ///
766    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
767    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
768    /// let y = x.triu(0).unwrap();
769    ///
770    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 0.0, 3.0, 4.0]);
771    /// ```
772    pub fn triu(&self, k: i64) -> Result<Self> {
773        self.unary_op(StdTensorOp::Triu { k })
774    }
775
776    /// Reduce product over the requested axes.
777    ///
778    /// # Examples
779    ///
780    /// ```
781    /// use tenferro_cpu::CpuBackend;
782    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
783    ///
784    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
785    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
786    /// let y = x.reduce_prod(&[0, 1]).unwrap();
787    ///
788    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[24.0]);
789    /// ```
790    pub fn reduce_prod(&self, axes: &[usize]) -> Result<Self> {
791        self.unary_op(StdTensorOp::ReduceProd {
792            axes: axes.to_vec(),
793        })
794    }
795
796    /// Reduce maximum over the requested axes.
797    ///
798    /// # Examples
799    ///
800    /// ```
801    /// use tenferro_cpu::CpuBackend;
802    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
803    ///
804    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
805    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
806    /// let y = x.reduce_max(&[0, 1]).unwrap();
807    ///
808    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0]);
809    /// ```
810    pub fn reduce_max(&self, axes: &[usize]) -> Result<Self> {
811        self.unary_op(StdTensorOp::ReduceMax {
812            axes: axes.to_vec(),
813        })
814    }
815
816    /// Reduce minimum over the requested axes.
817    ///
818    /// # Examples
819    ///
820    /// ```
821    /// use tenferro_cpu::CpuBackend;
822    /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
823    ///
824    /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
825    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
826    /// let y = x.reduce_min(&[0, 1]).unwrap();
827    ///
828    /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0]);
829    /// ```
830    pub fn reduce_min(&self, axes: &[usize]) -> Result<Self> {
831        self.unary_op(StdTensorOp::ReduceMin {
832            axes: axes.to_vec(),
833        })
834    }
835
836    pub(crate) fn unary_op(&self, op: StdTensorOp) -> Result<Self> {
837        Self::nary_op(&[self], op)
838    }
839
840    pub(crate) fn binary_op(&self, other: &Self, op: StdTensorOp) -> Result<Self> {
841        Self::nary_op(&[self, other], op)
842    }
843
844    pub(crate) fn ternary_op(&self, b: &Self, c: &Self, op: StdTensorOp) -> Result<Self> {
845        Self::nary_op(&[self, b, c], op)
846    }
847
848    pub(crate) fn nary_value_op(
849        tensors: &[&Self],
850        op: StdTensorOp,
851        value: TensorValue,
852    ) -> Result<Self> {
853        let Some(first) = tensors.first() else {
854            return Err(empty_nary_input_error(&op));
855        };
856
857        let ctx = Arc::clone(&first.ctx);
858        for tensor in tensors.iter().skip(1) {
859            if !first.same_context(tensor) {
860                return Err(Error::ContextMismatch {
861                    lhs: first.ctx_id(),
862                    rhs: tensor.ctx_id(),
863                });
864            }
865        }
866
867        if !tensors.iter().any(|tensor| tensor.requires_grad) {
868            return Ok(Self::new_untracked_value_result(ctx, value));
869        }
870
871        let output = Arc::new(value.to_tensor().map_err(Error::from)?);
872        let outputs = vec![Arc::clone(&output)];
873        let mut recorded = record_eager_outputs(&op, &outputs, tensors)?;
874        let trace = recorded.traces.pop().ok_or_else(|| {
875            Error::Internal(format!("expected one eager trace for {:?}, got 0", op))
876        })?;
877        let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
878        for tensor in tensors {
879            for scope in &tensor.metadata_scopes {
880                push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
881            }
882        }
883
884        Self::new_result_value(
885            ctx,
886            trace.key,
887            value,
888            trace.requires_grad,
889            trace.trace,
890            metadata_scopes,
891        )
892    }
893
894    pub(crate) fn nary_op(tensors: &[&Self], op: StdTensorOp) -> Result<Self> {
895        let total_started = std::time::Instant::now();
896        let Some(first) = tensors.first() else {
897            return Err(empty_nary_input_error(&op));
898        };
899
900        let ctx = Arc::clone(&first.ctx);
901        profile_eager_op_section("nary_op.context_check", || -> Result<()> {
902            for tensor in tensors.iter().skip(1) {
903                if !first.same_context(tensor) {
904                    return Err(Error::ContextMismatch {
905                        lhs: first.ctx_id(),
906                        rhs: tensor.ctx_id(),
907                    });
908                }
909            }
910            Ok(())
911        })?;
912
913        let any_requires_grad = profile_eager_op_section("nary_op.requires_grad_scan", || {
914            tensors.iter().any(|tensor| tensor.requires_grad)
915        });
916        if !any_requires_grad {
917            let input_reads = profile_eager_op_section("nary_op.collect_input_reads", || {
918                tensors
919                    .iter()
920                    .map(|tensor| tensor.tensor_read())
921                    .collect::<Vec<_>>()
922            });
923            let output = profile_eager_op_section("nary_op.exec_single_output_read", || {
924                exec_single_output_read(&op, &input_reads, &ctx)
925            })?;
926            let result = profile_eager_op_section("nary_op.new_untracked_result", || {
927                Self::new_untracked_result(ctx, output)
928            });
929            record_eager_op_profile("nary_op.total", total_started.elapsed());
930            maybe_print_eager_op_profile();
931            return result;
932        }
933
934        let input_arcs = profile_eager_op_section("nary_op.materialize_inputs", || {
935            tensors
936                .iter()
937                .map(|tensor| tensor.materialized_arc())
938                .collect::<Result<Vec<_>>>()
939        })?;
940        let inputs: Vec<&Tensor> = profile_eager_op_section("nary_op.collect_inputs", || {
941            input_arcs.iter().map(|tensor| tensor.as_ref()).collect()
942        });
943        let output = profile_eager_op_section("nary_op.exec_single_output", || {
944            exec_single_output(&op, &inputs, &ctx)
945        })?;
946
947        let output = Arc::new(output);
948        let outputs = vec![Arc::clone(&output)];
949        let mut recorded = profile_eager_op_section("nary_op.record_outputs", || {
950            record_eager_outputs(&op, &outputs, tensors)
951        })?;
952        let trace = recorded.traces.pop().ok_or_else(|| {
953            Error::Internal(format!("expected one eager trace for {:?}, got 0", op))
954        })?;
955        let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
956        for tensor in tensors {
957            for scope in &tensor.metadata_scopes {
958                push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
959            }
960        }
961
962        let result = profile_eager_op_section("nary_op.new_tracked_result", || {
963            Self::new_result_arc(
964                ctx,
965                trace.key,
966                output,
967                trace.requires_grad,
968                trace.trace,
969                metadata_scopes,
970            )
971        });
972        record_eager_op_profile("nary_op.total", total_started.elapsed());
973        maybe_print_eager_op_profile();
974        result
975    }
976}
977
978fn empty_nary_input_error(op: &StdTensorOp) -> Error {
979    Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
980        op: eager_validation_op_name(op),
981        message: "operation requires at least one input tensor".to_string(),
982    })
983}
984
985fn eager_validation_op_name(op: &StdTensorOp) -> &'static str {
986    match op {
987        StdTensorOp::Concatenate { .. } => "concatenate",
988        _ => "eager_nary_op",
989    }
990}