pub trait TensorNetworkOps: Backend {
// Required method
fn tn_einsum(
subscripts: &str,
inputs: Vec<FloatTensor<Self>>,
) -> FloatTensor<Self>;
}
Trait for backends that support tenferro tensor network operations.
Implement this trait for a Burn backend to enable einsum and other
tensor network primitives on that backend’s tensors.
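A rough sketch of what an implementation could look like. `MyBackend` and `contract_on_my_backend` are hypothetical placeholders, not part of tenferro or Burn; a real implementation would dispatch to the backend's own contraction kernels.

use burn::tensor::ops::FloatTensor;
use tenferro_burn::TensorNetworkOps;

// Sketch only: `MyBackend` is a hypothetical Burn backend and
// `contract_on_my_backend` a hypothetical kernel entry point.
impl TensorNetworkOps for MyBackend {
    fn tn_einsum(
        subscripts: &str,
        inputs: Vec<FloatTensor<Self>>,
    ) -> FloatTensor<Self> {
        // Parse the subscripts, plan a contraction order, and run it
        // with the backend's native kernels.
        contract_on_my_backend(subscripts, inputs)
    }
}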
§Examples
use burn::backend::NdArray;
use tenferro_burn::TensorNetworkOps;

// NdArray<f64> implements TensorNetworkOps. `a_primitive` and
// `b_primitive` are FloatTensor<NdArray<f64>> values, e.g. unwrapped
// from ordinary Tensors via `into_primitive()`.
let result = <NdArray<f64> as TensorNetworkOps>::tn_einsum(
    "ij,jk->ik",
    vec![a_primitive, b_primitive],
);

Required Methods§

fn tn_einsum(
    subscripts: &str,
    inputs: Vec<FloatTensor<Self>>,
) -> FloatTensor<Self>

Dyn Compatibility§
This trait is not dyn compatible: its required method has no `self` receiver, so it cannot be dispatched through a trait object.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
Implementations on Foreign Types§
impl TensorNetworkOps for NdArray<f64>
impl<B, C> TensorNetworkOps for Autodiff<B, C>
where
    B: TensorNetworkOps,
    C: CheckpointStrategy,
fn tn_einsum(
    _subscripts: &str,
    _inputs: Vec<FloatTensor<Self>>,
) -> FloatTensor<Self>
Perform an einsum contraction, recording the operation on the autodiff tape so that gradients can be computed during the backward pass.
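For example (a sketch; `a_prim` and `b_prim` stand for float primitives of the Autodiff backend, obtained as in the trait-level example):

use burn::backend::{Autodiff, NdArray};
use tenferro_burn::TensorNetworkOps;

type B = Autodiff<NdArray<f64>>;

// The contraction runs on the inner NdArray backend and is recorded
// on the autodiff tape, so `c_prim` carries the graph needed for a
// later backward pass.
let c_prim = <B as TensorNetworkOps>::tn_einsum("ij,jk->ik", vec![a_prim, b_prim]);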
§Future implementation plan
The backward pass will invoke tenferro’s einsum_rrule to obtain the
VJP for each input tensor. The contraction tree used in the forward
pass will be cached (or re-derived) for the backward pass so that
each partial derivative is itself an optimised einsum.
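As an illustration of the reverse rule (hand-derived subscripts, not the actual einsum_rrule API): for the matrix-product contraction from the example above, each VJP is itself an einsum over the output cotangent.

// Forward:  C = einsum("ij,jk->ik", [A, B]), i.e. C_ik = Σ_j A_ij B_jk.
// Given the output cotangent dC, the input cotangents are:
//   dA_ij = Σ_k dC_ik B_jk  =>  dA = einsum("ik,jk->ij", [dC, B])
//   dB_jk = Σ_i A_ij dC_ik  =>  dB = einsum("ij,ik->jk", [A, dC])
// Both are plain einsum calls, so the backward pass can reuse the
// same contraction machinery (and cached tree) as the forward pass.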