Skip to main content

tenferro_einsum/
typed_eager.rs

1use tenferro_tensor::{Error, Result, Tensor, TensorBackend, TensorScalar, TypedTensor};
2
3use crate::eager_einsum;
4
5/// Execute eager einsum over typed tensors and return a typed result.
6///
7/// # Examples
8///
9/// ```
10/// use tenferro_einsum::typed_eager_einsum;
11/// use tenferro_tensor::{cpu::CpuBackend, TypedTensor};
12///
13/// let mut ctx = CpuBackend::new();
14/// let a = TypedTensor::<f64>::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
15/// let b = TypedTensor::<f64>::from_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
16///
17/// let c = typed_eager_einsum(&mut ctx, &[&a, &b], "ij,jk->ik").unwrap();
18///
19/// assert_eq!(c.shape, vec![2, 2]);
20/// assert_eq!(c.as_slice(), &[22.0, 28.0, 49.0, 64.0]);
21/// ```
22pub fn typed_eager_einsum<T: TensorScalar>(
23    ctx: &mut impl TensorBackend,
24    inputs: &[&TypedTensor<T>],
25    subscripts: &str,
26) -> Result<TypedTensor<T>> {
27    let tensors: Vec<Tensor> = inputs
28        .iter()
29        .map(|tensor| T::into_tensor(tensor.shape.clone(), tensor.host_data().to_vec()))
30        .collect();
31    let refs: Vec<&Tensor> = tensors.iter().collect();
32    let result = eager_einsum(ctx, &refs, subscripts)?;
33    let actual = result.dtype();
34    T::try_into_typed(result).ok_or_else(|| Error::DTypeMismatch {
35        op: "typed_eager_einsum",
36        lhs: actual,
37        rhs: T::dtype(),
38    })
39}