tenferro/eager_einsum.rs
1//! Eager einsum helpers exposed through the `tenferro` facade.
2//!
//! This module provides [`eager_einsum`] for immediate eager execution on concrete
//! tensors and [`eager_einsum_ad`] for eager execution with reverse-mode autodiff on
//! tracked tensors.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use tenferro::eager_einsum::{eager_einsum, eager_einsum_ad};
10//! use tenferro::{CpuBackend, EagerTensor, Tensor};
11//!
12//! let mut backend = CpuBackend::new();
13//! let a = Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]);
14//! let b = Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]);
15//! let dot = eager_einsum(&mut backend, &[&a, &b], "i,i->").unwrap();
16//!
17//! assert_eq!(dot.as_slice::<f64>().unwrap(), &[32.0]);
18//!
19//! let x = EagerTensor::requires_grad(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]));
20//! let y = EagerTensor::requires_grad(Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]));
21//! let loss = eager_einsum_ad(&[&x, &y], "i,i->").unwrap();
22//! let _ = loss.backward().unwrap();
23//!
24//! assert_eq!(x.grad().unwrap().as_slice::<f64>().unwrap(), &[4.0, 5.0, 6.0]);
25//! assert_eq!(y.grad().unwrap().as_slice::<f64>().unwrap(), &[1.0, 2.0, 3.0]);
26//! ```
27
28use tenferro_ops::std_tensor_op::StdTensorOp;
29use tenferro_tensor::TensorBackend;
30
31use crate::eager::EagerTensor;
32use crate::error::Result;
33
34pub use tenferro_einsum::eager_einsum;
35
36/// Execute an einsum eagerly and record it for reverse-mode autodiff.
37///
38/// # Examples
39///
40/// ```
41/// use tenferro::eager_einsum::eager_einsum_ad;
42/// use tenferro::{EagerTensor, Tensor};
43///
44/// let a = EagerTensor::from_tensor(Tensor::from_vec(
45/// vec![2, 3],
46/// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
47/// ));
48/// let b = EagerTensor::from_tensor(Tensor::from_vec(
49/// vec![3, 2],
50/// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0],
51/// ));
52/// let c = eager_einsum_ad(&[&a, &b], "ij,jk->ik").unwrap();
53///
54/// assert_eq!(c.data().shape(), &[2, 2]);
55/// assert_eq!(c.data().as_slice::<f64>().unwrap(), &[22.0, 28.0, 49.0, 64.0]);
56/// ```
57pub fn eager_einsum_ad<B: TensorBackend>(
58 inputs: &[&EagerTensor<B>],
59 subscripts: &str,
60) -> Result<EagerTensor<B>> {
61 EagerTensor::nary_op(
62 inputs,
63 StdTensorOp::NaryEinsum {
64 subscripts: subscripts.to_string(),
65 n_inputs: inputs.len(),
66 },
67 )
68}