tenferro_einsum/typed_eager.rs
1use tenferro_tensor::{Error, Result, Tensor, TensorBackend, TensorScalar, TypedTensor};
2
3use crate::eager_einsum;
4
5/// Execute eager einsum over typed tensors and return a typed result.
6///
7/// # Examples
8///
9/// ```
10/// use tenferro_einsum::typed_eager_einsum;
11/// use tenferro_tensor::{cpu::CpuBackend, TypedTensor};
12///
13/// let mut ctx = CpuBackend::new();
14/// let a = TypedTensor::<f64>::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
15/// let b = TypedTensor::<f64>::from_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
16///
17/// let c = typed_eager_einsum(&mut ctx, &[&a, &b], "ij,jk->ik").unwrap();
18///
19/// assert_eq!(c.shape, vec![2, 2]);
20/// assert_eq!(c.as_slice(), &[22.0, 28.0, 49.0, 64.0]);
21/// ```
22pub fn typed_eager_einsum<T: TensorScalar>(
23 ctx: &mut impl TensorBackend,
24 inputs: &[&TypedTensor<T>],
25 subscripts: &str,
26) -> Result<TypedTensor<T>> {
27 let tensors: Vec<Tensor> = inputs
28 .iter()
29 .map(|tensor| T::into_tensor(tensor.shape.clone(), tensor.host_data().to_vec()))
30 .collect();
31 let refs: Vec<&Tensor> = tensors.iter().collect();
32 let result = eager_einsum(ctx, &refs, subscripts)?;
33 let actual = result.dtype();
34 T::try_into_typed(result).ok_or_else(|| Error::DTypeMismatch {
35 op: "typed_eager_einsum",
36 lhs: actual,
37 rhs: T::dtype(),
38 })
39}