tenferro_ad/eager_ops.rs
1use std::sync::Arc;
2
3use tenferro_ops::broadcast::{
4 broadcast_input_plan, broadcast_shape, broadcast_shapes, BroadcastError,
5};
6use tenferro_ops::dim_expr::DimExpr;
7use tenferro_ops::std_tensor_op::StdTensorOp;
8use tenferro_tensor::{
9 DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig, Tensor,
10 TensorValue,
11};
12
13use crate::eager::{
14 exec_single_output, exec_single_output_read, maybe_print_eager_op_profile,
15 profile_eager_op_section, record_eager_op_profile, record_eager_outputs, EagerTensor,
16};
17use crate::eager_exec::exec_dot_general_with_conj_on_tensor_reads;
18use crate::error::{Error, Result};
19use crate::metadata::push_metadata_scope;
20
21pub(crate) fn broadcast_binary(
22 op: &'static str,
23 lhs: &EagerTensor,
24 rhs: &EagerTensor,
25) -> Result<(EagerTensor, EagerTensor)> {
26 ensure_same_context(lhs, rhs)?;
27 let shape =
28 broadcast_shape(lhs.shape(), rhs.shape()).map_err(|err| broadcast_error(op, err))?;
29 Ok((
30 broadcast_to(op, lhs, &shape)?,
31 broadcast_to(op, rhs, &shape)?,
32 ))
33}
34
35pub(crate) fn broadcast_ternary(
36 op: &'static str,
37 first: &EagerTensor,
38 second: &EagerTensor,
39 third: &EagerTensor,
40) -> Result<(EagerTensor, EagerTensor, EagerTensor)> {
41 ensure_same_context(first, second)?;
42 ensure_same_context(first, third)?;
43 let shape = broadcast_shapes([first.shape(), second.shape(), third.shape()])
44 .map_err(|err| broadcast_error(op, err))?;
45 Ok((
46 broadcast_to(op, first, &shape)?,
47 broadcast_to(op, second, &shape)?,
48 broadcast_to(op, third, &shape)?,
49 ))
50}
51
52fn broadcast_to(
53 op: &'static str,
54 input: &EagerTensor,
55 target_shape: &[usize],
56) -> Result<EagerTensor> {
57 let input_shape = input.shape();
58 if input_shape == target_shape {
59 return Ok(input.clone());
60 }
61
62 let plan =
63 broadcast_input_plan(input_shape, target_shape).map_err(|err| broadcast_error(op, err))?;
64 let source = if plan.source_shape == input_shape {
65 input.clone()
66 } else {
67 input.reshape(&plan.source_shape)?
68 };
69 source.broadcast_in_dim(target_shape, &plan.dims)
70}
71
72fn broadcast_error(op: &'static str, err: BroadcastError) -> Error {
73 match err {
74 BroadcastError::IncompatibleBinary { lhs, rhs } => {
75 tenferro_tensor::Error::ShapeMismatch { op, lhs, rhs }.into()
76 }
77 BroadcastError::IncompatibleInput { input, output }
78 | BroadcastError::RankTooLarge { input, output } => tenferro_tensor::Error::InvalidConfig {
79 op,
80 message: format!("cannot broadcast shape {input:?} to {output:?}"),
81 }
82 .into(),
83 }
84}
85
86fn ensure_same_context(lhs: &EagerTensor, rhs: &EagerTensor) -> Result<()> {
87 if !lhs.same_context(rhs) {
88 return Err(Error::ContextMismatch {
89 lhs: lhs.ctx_id(),
90 rhs: rhs.ctx_id(),
91 });
92 }
93 Ok(())
94}
95
96impl std::ops::Add for &EagerTensor {
97 type Output = Result<EagerTensor>;
98
99 fn add(self, rhs: &EagerTensor) -> Result<EagerTensor> {
100 EagerTensor::add(self, rhs)
101 }
102}
103
104impl std::ops::Sub for &EagerTensor {
105 type Output = Result<EagerTensor>;
106
107 fn sub(self, rhs: &EagerTensor) -> Result<EagerTensor> {
108 EagerTensor::sub(self, rhs)
109 }
110}
111
112impl std::ops::Mul for &EagerTensor {
113 type Output = Result<EagerTensor>;
114
115 fn mul(self, rhs: &EagerTensor) -> Result<EagerTensor> {
116 EagerTensor::mul(self, rhs)
117 }
118}
119
120impl std::ops::Div for &EagerTensor {
121 type Output = Result<EagerTensor>;
122
123 fn div(self, rhs: &EagerTensor) -> Result<EagerTensor> {
124 EagerTensor::div(self, rhs)
125 }
126}
127
128impl std::ops::Neg for &EagerTensor {
129 type Output = Result<EagerTensor>;
130
131 fn neg(self) -> Result<EagerTensor> {
132 EagerTensor::neg(self)
133 }
134}
135
136impl EagerTensor {
137 /// Elementwise addition.
138 ///
139 /// # Examples
140 ///
141 /// ```
142 /// use tenferro_cpu::CpuBackend;
143 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
144 ///
145 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
146 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
147 /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
148 /// let z = x.add(&y).unwrap();
149 ///
150 /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0, 6.0]);
151 /// ```
152 pub fn add(&self, other: &Self) -> Result<Self> {
153 let (lhs, rhs) = broadcast_binary("add", self, other)?;
154 lhs.binary_op(&rhs, StdTensorOp::Add)
155 }
156
157 /// Elementwise subtraction.
158 pub fn sub(&self, other: &Self) -> Result<Self> {
159 let (lhs, rhs) = broadcast_binary("sub", self, other)?;
160 let rhs = rhs.neg()?;
161 lhs.binary_op(&rhs, StdTensorOp::Add)
162 }
163
164 /// Elementwise multiplication.
165 ///
166 /// # Examples
167 ///
168 /// ```
169 /// use tenferro_cpu::CpuBackend;
170 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
171 ///
172 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
173 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
174 /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
175 /// let z = x.mul(&y).unwrap();
176 ///
177 /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 8.0]);
178 /// ```
179 pub fn mul(&self, other: &Self) -> Result<Self> {
180 let (lhs, rhs) = broadcast_binary("mul", self, other)?;
181 lhs.binary_op(&rhs, StdTensorOp::Mul)
182 }
183
184 /// Negate the tensor.
185 ///
186 /// # Examples
187 ///
188 /// ```
189 /// use tenferro_cpu::CpuBackend;
190 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
191 ///
192 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
193 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, -2.0]).unwrap(), ctx.clone()).unwrap();
194 /// let y = x.neg().unwrap();
195 ///
196 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[-1.0, 2.0]);
197 /// ```
198 pub fn neg(&self) -> Result<Self> {
199 self.unary_op(StdTensorOp::Neg)
200 }
201
202 /// Elementwise exponential.
203 ///
204 /// # Examples
205 ///
206 /// ```
207 /// use tenferro_cpu::CpuBackend;
208 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
209 ///
210 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
211 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![0.0_f64]).unwrap(), ctx.clone()).unwrap();
212 /// let y = x.exp().unwrap();
213 ///
214 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0]);
215 /// ```
216 pub fn exp(&self) -> Result<Self> {
217 self.unary_op(StdTensorOp::Exp)
218 }
219
220 /// Reduce sum over the requested axes.
221 ///
222 /// # Examples
223 ///
224 /// ```
225 /// use tenferro_cpu::CpuBackend;
226 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
227 ///
228 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
229 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
230 /// let y = x.reduce_sum(&[0, 1]).unwrap();
231 ///
232 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[10.0]);
233 /// ```
234 pub fn reduce_sum(&self, axes: &[usize]) -> Result<Self> {
235 self.unary_op(StdTensorOp::ReduceSum {
236 axes: axes.to_vec(),
237 })
238 }
239
240 /// Execute a dot-general contraction eagerly.
241 ///
242 /// # Examples
243 ///
244 /// ```
245 /// use tenferro_cpu::CpuBackend;
246 /// use tenferro_ad::{DotGeneralConfig, EagerRuntime, EagerTensor, Tensor};
247 ///
248 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
249 /// let a = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), ctx.clone()).unwrap();
250 /// let b = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), ctx.clone()).unwrap();
251 /// let c = a.dot_general(&b, DotGeneralConfig {
252 /// lhs_contracting_dims: vec![1],
253 /// rhs_contracting_dims: vec![0],
254 /// lhs_batch_dims: vec![],
255 /// rhs_batch_dims: vec![],
256 /// }).unwrap();
257 ///
258 /// assert_eq!(c.shape(), &[2, 2]);
259 /// ```
260 pub fn dot_general(&self, other: &Self, config: DotGeneralConfig) -> Result<Self> {
261 self.binary_op(other, StdTensorOp::DotGeneral { config })
262 }
263
264 /// Execute a dot-general contraction, optionally conjugating either operand.
265 ///
266 /// Untracked tensors route the conjugation flags directly to the backend so
267 /// the conjugated operand does not need to be materialized. Tracked tensors
268 /// fall back to explicit `Conj` plus `DotGeneral` so reverse-mode AD keeps
269 /// the same graph semantics as the standard eager ops.
270 pub fn dot_general_with_conj(
271 &self,
272 other: &Self,
273 config: &DotGeneralConfig,
274 lhs_conj: bool,
275 rhs_conj: bool,
276 ) -> Result<Self> {
277 if !self.same_context(other) {
278 return Err(Error::ContextMismatch {
279 lhs: self.ctx_id(),
280 rhs: other.ctx_id(),
281 });
282 }
283
284 if !self.requires_grad && !other.requires_grad {
285 let ctx = Arc::clone(&self.ctx);
286 let output = ctx.with_backend_mut(|backend| {
287 exec_dot_general_with_conj_on_tensor_reads(
288 self.tensor_read(),
289 other.tensor_read(),
290 config,
291 lhs_conj,
292 rhs_conj,
293 backend,
294 )
295 })??;
296 return Self::new_untracked_result(ctx, output);
297 }
298
299 match (lhs_conj, rhs_conj) {
300 (false, false) => self.dot_general(other, config.clone()),
301 (true, false) => self.conj()?.dot_general(other, config.clone()),
302 (false, true) => {
303 let rhs = other.conj()?;
304 self.dot_general(&rhs, config.clone())
305 }
306 (true, true) => {
307 let lhs = self.conj()?;
308 let rhs = other.conj()?;
309 lhs.dot_general(&rhs, config.clone())
310 }
311 }
312 }
313
314 /// Matrix multiplication for rank-2 tensors.
315 ///
316 /// This is a convenience wrapper over [`Self::dot_general`] that
317 /// contracts the left matrix's column axis with the right matrix's row
318 /// axis.
319 ///
320 /// # Examples
321 ///
322 /// ```
323 /// use tenferro_cpu::CpuBackend;
324 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
325 ///
326 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
327 /// let a = EagerTensor::from_tensor_in(
328 /// Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(),
329 /// ctx.clone(),
330 /// ).unwrap();
331 /// let b = EagerTensor::from_tensor_in(
332 /// Tensor::from_vec_col_major(vec![2, 1], vec![5.0_f64, 6.0]).unwrap(),
333 /// ctx,
334 /// ).unwrap();
335 /// let c = a.matmul(&b).unwrap();
336 ///
337 /// assert_eq!(c.shape(), &[2, 1]);
338 /// assert_eq!(c.materialized().unwrap().as_slice::<f64>().unwrap(), &[23.0, 34.0]);
339 /// ```
340 pub fn matmul(&self, other: &Self) -> Result<Self> {
341 let lhs_shape = self.shape();
342 let rhs_shape = other.shape();
343 if lhs_shape.len() != 2 {
344 return Err(tenferro_tensor::Error::RankMismatch {
345 op: "matmul",
346 expected: 2,
347 actual: lhs_shape.len(),
348 }
349 .into());
350 }
351 if rhs_shape.len() != 2 {
352 return Err(tenferro_tensor::Error::RankMismatch {
353 op: "matmul",
354 expected: 2,
355 actual: rhs_shape.len(),
356 }
357 .into());
358 }
359 if lhs_shape[1] != rhs_shape[0] {
360 return Err(tenferro_tensor::Error::ShapeMismatch {
361 op: "matmul",
362 lhs: lhs_shape.to_vec(),
363 rhs: rhs_shape.to_vec(),
364 }
365 .into());
366 }
367 self.dot_general(
368 other,
369 DotGeneralConfig {
370 lhs_contracting_dims: vec![1],
371 rhs_contracting_dims: vec![0],
372 lhs_batch_dims: vec![],
373 rhs_batch_dims: vec![],
374 },
375 )
376 }
377
378 /// Permute tensor axes.
379 ///
380 /// # Examples
381 ///
382 /// ```
383 /// use tenferro_cpu::CpuBackend;
384 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
385 ///
386 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
387 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
388 /// vec![2, 3],
389 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
390 /// ).unwrap(), ctx.clone()).unwrap();
391 /// let y = x.transpose(&[1, 0]).unwrap();
392 ///
393 /// assert_eq!(y.shape(), &[3, 2]);
394 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
395 /// ```
396 pub fn transpose(&self, perm: &[usize]) -> Result<Self> {
397 let op = StdTensorOp::Transpose {
398 perm: perm.to_vec(),
399 };
400 let value = self
401 .value
402 .transpose_view(perm)
403 .map_err(Error::TensorRuntime)?;
404 Self::nary_value_op(&[self], op, value)
405 }
406
407 /// Reshape without changing element order.
408 ///
409 /// # Examples
410 ///
411 /// ```
412 /// use tenferro_cpu::CpuBackend;
413 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
414 ///
415 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
416 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
417 /// vec![2, 3],
418 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
419 /// ).unwrap(), ctx.clone()).unwrap();
420 /// let y = x.reshape(&[6]).unwrap();
421 ///
422 /// assert_eq!(y.shape(), &[6]);
423 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
424 /// ```
425 pub fn reshape(&self, shape: &[usize]) -> Result<Self> {
426 let op = StdTensorOp::Reshape {
427 to_shape: DimExpr::from_concrete(shape),
428 };
429 if let Ok(value) = self.value.reshape_view(shape) {
430 return Self::nary_value_op(&[self], op, value);
431 }
432 self.unary_op(op)
433 }
434
435 /// Slice with explicit start, limit, and stride per axis.
436 ///
437 /// # Examples
438 ///
439 /// ```
440 /// use tenferro_cpu::CpuBackend;
441 /// use tenferro_ad::{EagerRuntime, EagerTensor, SliceConfig, Tensor};
442 ///
443 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
444 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
445 /// let y = x
446 /// .slice(SliceConfig {
447 /// starts: vec![1],
448 /// limits: vec![3],
449 /// strides: vec![1],
450 /// })
451 /// .unwrap();
452 ///
453 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[2.0, 3.0]);
454 /// ```
455 pub fn slice(&self, config: SliceConfig) -> Result<Self> {
456 let value = self
457 .value
458 .slice_view(&config)
459 .map_err(Error::TensorRuntime)?;
460 Self::nary_value_op(&[self], StdTensorOp::Slice(config), value)
461 }
462
463 /// Broadcast into a larger shape with explicit dimension placement.
464 ///
465 /// # Examples
466 ///
467 /// ```
468 /// use tenferro_cpu::CpuBackend;
469 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
470 ///
471 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
472 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx.clone()).unwrap();
473 /// let y = x.broadcast_in_dim(&[3, 2], &[0]).unwrap();
474 ///
475 /// assert_eq!(y.shape(), &[3, 2]);
476 /// ```
477 pub fn broadcast_in_dim(&self, shape: &[usize], dims: &[usize]) -> Result<Self> {
478 let op = StdTensorOp::BroadcastInDim {
479 shape: DimExpr::from_concrete(shape),
480 dims: dims.to_vec(),
481 };
482 let value = self
483 .value
484 .broadcast_in_dim_view(shape, dims)
485 .map_err(Error::TensorRuntime)?;
486 Self::nary_value_op(&[self], op, value)
487 }
488
489 /// Convert the tensor to a different dtype using checked conversion.
490 ///
491 /// Use [`cast`](Self::cast) when a lossy dtype projection is intended.
492 ///
493 /// # Examples
494 ///
495 /// ```
496 /// use tenferro_cpu::CpuBackend;
497 /// use tenferro_ad::{DType, EagerRuntime, EagerTensor, Tensor};
498 ///
499 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
500 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, -2.0]).unwrap(), ctx.clone()).unwrap();
501 /// let y = x.convert(DType::C64).unwrap();
502 ///
503 /// assert_eq!(y.dtype(), DType::C64);
504 /// assert_eq!(y.shape(), &[2]);
505 /// ```
506 ///
507 /// # Errors
508 ///
509 /// Returns an error when the requested conversion is outside tenferro's
510 /// checked dtype-promotion lattice. Use [`cast`](Self::cast) for explicit
511 /// lossy dtype projection.
512 pub fn convert(&self, to: DType) -> Result<Self> {
513 tenferro_tensor::validate::validate_convert_dtype("EagerTensor::convert", self.dtype(), to)
514 .map_err(Error::TensorRuntime)?;
515 self.cast(to)
516 }
517
518 /// Cast the tensor to a different dtype using explicit dtype projection.
519 ///
520 /// `cast` may truncate, narrow precision, project complex values to their
521 /// real component, or use boolean truthiness where the backend supports the
522 /// requested projection.
523 ///
524 /// # Examples
525 ///
526 /// ```
527 /// use tenferro_cpu::CpuBackend;
528 /// use tenferro_ad::{DType, EagerRuntime, EagerTensor, Tensor};
529 ///
530 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
531 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.2_f64, -2.8]).unwrap(), ctx.clone()).unwrap();
532 /// let y = x.cast(DType::I32).unwrap();
533 ///
534 /// assert_eq!(y.materialized().unwrap().as_slice::<i32>().unwrap(), &[1, -2]);
535 /// ```
536 pub fn cast(&self, to: DType) -> Result<Self> {
537 self.unary_op(StdTensorOp::Convert {
538 from: self.dtype(),
539 to,
540 })
541 }
542
543 /// Pad with zeros using StableHLO-style edge and interior padding.
544 ///
545 /// # Examples
546 ///
547 /// ```
548 /// use tenferro_cpu::CpuBackend;
549 /// use tenferro_ad::{EagerRuntime, EagerTensor, PadConfig, Tensor};
550 ///
551 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
552 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
553 /// let y = x
554 /// .pad(PadConfig {
555 /// edge_padding_low: vec![1],
556 /// edge_padding_high: vec![1],
557 /// interior_padding: vec![1],
558 /// })
559 /// .unwrap();
560 ///
561 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0, 1.0, 0.0, 2.0, 0.0]);
562 /// ```
563 pub fn pad(&self, config: PadConfig) -> Result<Self> {
564 self.unary_op(StdTensorOp::Pad(config))
565 }
566
567 /// Reverse the order of elements along the requested axes.
568 ///
569 /// # Examples
570 ///
571 /// ```
572 /// use tenferro_cpu::CpuBackend;
573 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
574 ///
575 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
576 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![4], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
577 /// let y = x.reverse(&[0]).unwrap();
578 ///
579 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0, 3.0, 2.0, 1.0]);
580 /// ```
581 pub fn reverse(&self, axes: &[usize]) -> Result<Self> {
582 self.unary_op(StdTensorOp::Reverse {
583 axes: axes.to_vec(),
584 })
585 }
586
587 /// Gather slices from `self` using integer start indices.
588 ///
589 /// # Examples
590 ///
591 /// ```
592 /// use tenferro_cpu::CpuBackend;
593 /// use tenferro_ad::{EagerRuntime, EagerTensor, GatherConfig, Tensor};
594 ///
595 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
596 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
597 /// vec![5],
598 /// vec![10.0_f64, 20.0, 30.0, 40.0, 50.0],
599 /// ).unwrap(), ctx.clone()).unwrap();
600 /// let indices = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![4_i64, 1, 0]).unwrap(), ctx.clone()).unwrap();
601 /// let y = x
602 /// .gather(
603 /// &indices,
604 /// GatherConfig {
605 /// offset_dims: vec![],
606 /// collapsed_slice_dims: vec![0],
607 /// start_index_map: vec![0],
608 /// index_vector_dim: 1,
609 /// slice_sizes: vec![1],
610 /// },
611 /// )
612 /// .unwrap();
613 ///
614 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[50.0, 20.0, 10.0]);
615 /// ```
616 pub fn gather(&self, indices: &Self, config: GatherConfig) -> Result<Self> {
617 self.binary_op(indices, StdTensorOp::Gather(config))
618 }
619
620 /// Scatter updates into `self` using StableHLO scatter semantics.
621 ///
622 /// # Examples
623 ///
624 /// ```
625 /// use tenferro_cpu::CpuBackend;
626 /// use tenferro_ad::{EagerRuntime, EagerTensor, ScatterConfig, Tensor};
627 ///
628 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
629 /// let operand = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![4], vec![0.0_f64, 0.0, 0.0, 0.0]).unwrap(), ctx.clone()).unwrap();
630 /// let indices = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 1], vec![1_i64, 3]).unwrap(), ctx.clone()).unwrap();
631 /// let updates = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![5.0_f64, 7.0]).unwrap(), ctx.clone()).unwrap();
632 /// let result = operand
633 /// .scatter(
634 /// &indices,
635 /// &updates,
636 /// ScatterConfig {
637 /// update_window_dims: vec![],
638 /// inserted_window_dims: vec![0],
639 /// scatter_dims_to_operand_dims: vec![0],
640 /// index_vector_dim: 1,
641 /// },
642 /// )
643 /// .unwrap();
644 ///
645 /// assert_eq!(result.materialized().unwrap().as_slice::<f64>().unwrap(), &[0.0, 5.0, 0.0, 7.0]);
646 /// ```
647 pub fn scatter(&self, indices: &Self, updates: &Self, config: ScatterConfig) -> Result<Self> {
648 self.ternary_op(indices, updates, StdTensorOp::Scatter(config))
649 }
650
651 /// Slice using runtime start indices.
652 ///
653 /// # Examples
654 ///
655 /// ```
656 /// use tenferro_cpu::CpuBackend;
657 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
658 ///
659 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
660 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![5], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0]).unwrap(), ctx.clone()).unwrap();
661 /// let starts = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![1], vec![2_i64]).unwrap(), ctx.clone()).unwrap();
662 /// let y = x.dynamic_slice(&starts, &[2]).unwrap();
663 ///
664 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[3.0, 4.0]);
665 /// ```
666 pub fn dynamic_slice(&self, starts: &Self, sizes: &[usize]) -> Result<Self> {
667 self.binary_op(
668 starts,
669 StdTensorOp::DynamicSlice {
670 slice_sizes: sizes.to_vec(),
671 },
672 )
673 }
674
675 /// Concatenate tensors along one axis.
676 ///
677 /// # Examples
678 ///
679 /// ```
680 /// use tenferro_cpu::CpuBackend;
681 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
682 ///
683 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
684 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap(), ctx.clone()).unwrap();
685 /// let y = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2], vec![3.0_f64, 4.0]).unwrap(), ctx.clone()).unwrap();
686 /// let z = EagerTensor::concatenate(&[&x, &y], 0).unwrap();
687 ///
688 /// assert_eq!(z.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
689 /// ```
690 pub fn concatenate(tensors: &[&Self], axis: usize) -> Result<Self> {
691 Self::nary_op(
692 tensors,
693 StdTensorOp::Concatenate {
694 axis,
695 input_count: tensors.len(),
696 },
697 )
698 }
699
700 /// Extract the diagonal along two axes.
701 ///
702 /// # Examples
703 ///
704 /// ```
705 /// use tenferro_cpu::CpuBackend;
706 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
707 ///
708 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
709 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(
710 /// vec![3, 3],
711 /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
712 /// ).unwrap(), ctx.clone()).unwrap();
713 /// let y = x.extract_diag(0, 1).unwrap();
714 ///
715 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 5.0, 9.0]);
716 /// ```
717 pub fn extract_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
718 self.unary_op(StdTensorOp::ExtractDiag { axis_a, axis_b })
719 }
720
721 /// Embed a vector or lower-rank tensor along a diagonal.
722 ///
723 /// # Examples
724 ///
725 /// ```
726 /// use tenferro_cpu::CpuBackend;
727 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
728 ///
729 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
730 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![3], vec![1.0_f64, 2.0, 3.0]).unwrap(), ctx.clone()).unwrap();
731 /// let y = x.embed_diag(0, 1).unwrap();
732 ///
733 /// assert_eq!(y.shape(), &[3, 3]);
734 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
735 /// ```
736 pub fn embed_diag(&self, axis_a: usize, axis_b: usize) -> Result<Self> {
737 self.unary_op(StdTensorOp::EmbedDiag { axis_a, axis_b })
738 }
739
740 /// Keep the lower triangle and zero the rest.
741 ///
742 /// # Examples
743 ///
744 /// ```
745 /// use tenferro_cpu::CpuBackend;
746 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
747 ///
748 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
749 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
750 /// let y = x.tril(0).unwrap();
751 ///
752 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0, 0.0, 4.0]);
753 /// ```
754 pub fn tril(&self, k: i64) -> Result<Self> {
755 self.unary_op(StdTensorOp::Tril { k })
756 }
757
758 /// Keep the upper triangle and zero the rest.
759 ///
760 /// # Examples
761 ///
762 /// ```
763 /// use tenferro_cpu::CpuBackend;
764 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
765 ///
766 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
767 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
768 /// let y = x.triu(0).unwrap();
769 ///
770 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0, 0.0, 3.0, 4.0]);
771 /// ```
772 pub fn triu(&self, k: i64) -> Result<Self> {
773 self.unary_op(StdTensorOp::Triu { k })
774 }
775
776 /// Reduce product over the requested axes.
777 ///
778 /// # Examples
779 ///
780 /// ```
781 /// use tenferro_cpu::CpuBackend;
782 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
783 ///
784 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
785 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
786 /// let y = x.reduce_prod(&[0, 1]).unwrap();
787 ///
788 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[24.0]);
789 /// ```
790 pub fn reduce_prod(&self, axes: &[usize]) -> Result<Self> {
791 self.unary_op(StdTensorOp::ReduceProd {
792 axes: axes.to_vec(),
793 })
794 }
795
796 /// Reduce maximum over the requested axes.
797 ///
798 /// # Examples
799 ///
800 /// ```
801 /// use tenferro_cpu::CpuBackend;
802 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
803 ///
804 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
805 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
806 /// let y = x.reduce_max(&[0, 1]).unwrap();
807 ///
808 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[4.0]);
809 /// ```
810 pub fn reduce_max(&self, axes: &[usize]) -> Result<Self> {
811 self.unary_op(StdTensorOp::ReduceMax {
812 axes: axes.to_vec(),
813 })
814 }
815
816 /// Reduce minimum over the requested axes.
817 ///
818 /// # Examples
819 ///
820 /// ```
821 /// use tenferro_cpu::CpuBackend;
822 /// use tenferro_ad::{EagerRuntime, EagerTensor, Tensor};
823 ///
824 /// let ctx = EagerRuntime::with_cpu_backend(CpuBackend::new());
825 /// let x = EagerTensor::from_tensor_in(Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]).unwrap(), ctx.clone()).unwrap();
826 /// let y = x.reduce_min(&[0, 1]).unwrap();
827 ///
828 /// assert_eq!(y.materialized().unwrap().as_slice::<f64>().unwrap(), &[1.0]);
829 /// ```
830 pub fn reduce_min(&self, axes: &[usize]) -> Result<Self> {
831 self.unary_op(StdTensorOp::ReduceMin {
832 axes: axes.to_vec(),
833 })
834 }
835
836 pub(crate) fn unary_op(&self, op: StdTensorOp) -> Result<Self> {
837 Self::nary_op(&[self], op)
838 }
839
840 pub(crate) fn binary_op(&self, other: &Self, op: StdTensorOp) -> Result<Self> {
841 Self::nary_op(&[self, other], op)
842 }
843
844 pub(crate) fn ternary_op(&self, b: &Self, c: &Self, op: StdTensorOp) -> Result<Self> {
845 Self::nary_op(&[self, b, c], op)
846 }
847
848 pub(crate) fn nary_value_op(
849 tensors: &[&Self],
850 op: StdTensorOp,
851 value: TensorValue,
852 ) -> Result<Self> {
853 let Some(first) = tensors.first() else {
854 return Err(empty_nary_input_error(&op));
855 };
856
857 let ctx = Arc::clone(&first.ctx);
858 for tensor in tensors.iter().skip(1) {
859 if !first.same_context(tensor) {
860 return Err(Error::ContextMismatch {
861 lhs: first.ctx_id(),
862 rhs: tensor.ctx_id(),
863 });
864 }
865 }
866
867 if !tensors.iter().any(|tensor| tensor.requires_grad) {
868 return Ok(Self::new_untracked_value_result(ctx, value));
869 }
870
871 let output = Arc::new(value.to_tensor().map_err(Error::from)?);
872 let outputs = vec![Arc::clone(&output)];
873 let mut recorded = record_eager_outputs(&op, &outputs, tensors)?;
874 let trace = recorded.traces.pop().ok_or_else(|| {
875 Error::Internal(format!("expected one eager trace for {:?}, got 0", op))
876 })?;
877 let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
878 for tensor in tensors {
879 for scope in &tensor.metadata_scopes {
880 push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
881 }
882 }
883
884 Self::new_result_value(
885 ctx,
886 trace.key,
887 value,
888 trace.requires_grad,
889 trace.trace,
890 metadata_scopes,
891 )
892 }
893
894 pub(crate) fn nary_op(tensors: &[&Self], op: StdTensorOp) -> Result<Self> {
895 let total_started = std::time::Instant::now();
896 let Some(first) = tensors.first() else {
897 return Err(empty_nary_input_error(&op));
898 };
899
900 let ctx = Arc::clone(&first.ctx);
901 profile_eager_op_section("nary_op.context_check", || -> Result<()> {
902 for tensor in tensors.iter().skip(1) {
903 if !first.same_context(tensor) {
904 return Err(Error::ContextMismatch {
905 lhs: first.ctx_id(),
906 rhs: tensor.ctx_id(),
907 });
908 }
909 }
910 Ok(())
911 })?;
912
913 let any_requires_grad = profile_eager_op_section("nary_op.requires_grad_scan", || {
914 tensors.iter().any(|tensor| tensor.requires_grad)
915 });
916 if !any_requires_grad {
917 let input_reads = profile_eager_op_section("nary_op.collect_input_reads", || {
918 tensors
919 .iter()
920 .map(|tensor| tensor.tensor_read())
921 .collect::<Vec<_>>()
922 });
923 let output = profile_eager_op_section("nary_op.exec_single_output_read", || {
924 exec_single_output_read(&op, &input_reads, &ctx)
925 })?;
926 let result = profile_eager_op_section("nary_op.new_untracked_result", || {
927 Self::new_untracked_result(ctx, output)
928 });
929 record_eager_op_profile("nary_op.total", total_started.elapsed());
930 maybe_print_eager_op_profile();
931 return result;
932 }
933
934 let input_arcs = profile_eager_op_section("nary_op.materialize_inputs", || {
935 tensors
936 .iter()
937 .map(|tensor| tensor.materialized_arc())
938 .collect::<Result<Vec<_>>>()
939 })?;
940 let inputs: Vec<&Tensor> = profile_eager_op_section("nary_op.collect_inputs", || {
941 input_arcs.iter().map(|tensor| tensor.as_ref()).collect()
942 });
943 let output = profile_eager_op_section("nary_op.exec_single_output", || {
944 exec_single_output(&op, &inputs, &ctx)
945 })?;
946
947 let output = Arc::new(output);
948 let outputs = vec![Arc::clone(&output)];
949 let mut recorded = profile_eager_op_section("nary_op.record_outputs", || {
950 record_eager_outputs(&op, &outputs, tensors)
951 })?;
952 let trace = recorded.traces.pop().ok_or_else(|| {
953 Error::Internal(format!("expected one eager trace for {:?}, got 0", op))
954 })?;
955 let mut metadata_scopes = vec![Arc::clone(&recorded.metadata_scope)];
956 for tensor in tensors {
957 for scope in &tensor.metadata_scopes {
958 push_metadata_scope(&mut metadata_scopes, Arc::clone(scope));
959 }
960 }
961
962 let result = profile_eager_op_section("nary_op.new_tracked_result", || {
963 Self::new_result_arc(
964 ctx,
965 trace.key,
966 output,
967 trace.requires_grad,
968 trace.trace,
969 metadata_scopes,
970 )
971 });
972 record_eager_op_profile("nary_op.total", total_started.elapsed());
973 maybe_print_eager_op_profile();
974 result
975 }
976}
977
978fn empty_nary_input_error(op: &StdTensorOp) -> Error {
979 Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
980 op: eager_validation_op_name(op),
981 message: "operation requires at least one input tensor".to_string(),
982 })
983}
984
985fn eager_validation_op_name(op: &StdTensorOp) -> &'static str {
986 match op {
987 StdTensorOp::Concatenate { .. } => "concatenate",
988 _ => "eager_nary_op",
989 }
990}