Trait TensorNetworkOps 

pub trait TensorNetworkOps: Backend {
    // Required method
    fn tn_einsum(
        subscripts: &str,
        inputs: Vec<FloatTensor<Self>>,
    ) -> FloatTensor<Self>;
}

Trait for backends that support tenferro tensor network operations.

Implement this trait for a Burn backend to enable einsum and other tensor network primitives on that backend’s tensors.

§Examples

use burn::backend::NdArray;
use tenferro_burn::TensorNetworkOps;

// NdArray<f64> implements TensorNetworkOps. Here `a_primitive` and
// `b_primitive` are backend float tensor primitives (FloatTensor<NdArray<f64>>).
let result = <NdArray<f64> as TensorNetworkOps>::tn_einsum(
    "ij,jk->ik",
    vec![a_primitive, b_primitive],
);
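
The result is again a raw backend primitive (FloatTensor<NdArray<f64>>) rather than a Tensor; a sketch of converting between the two around tn_einsum is given under the required method below.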

Required Methods§

fn tn_einsum(subscripts: &str, inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self>

Perform an einsum contraction on raw backend tensor primitives.

This operates at the primitive level. Prefer the high-level einsum function for typical usage.
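
As a rough sketch of wiring the primitive-level call into ordinary Tensor values: the conversion helpers used below (Tensor::into_primitive, TensorPrimitive::tensor, Tensor::from_primitive) are standard Burn APIs, but their exact shape varies between Burn versions, so treat this as illustrative rather than the crate's prescribed pattern.

use burn::backend::NdArray;
use burn::tensor::{Tensor, TensorPrimitive};
use tenferro_burn::TensorNetworkOps;

type B = NdArray<f64>;

let device = Default::default();
let a: Tensor<B, 2> = Tensor::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
let b: Tensor<B, 2> = Tensor::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);

// Drop down to the backend primitives, contract, then wrap the result back
// into a Tensor for further use.
let c_prim = <B as TensorNetworkOps>::tn_einsum(
    "ij,jk->ik",
    vec![a.into_primitive().tensor(), b.into_primitive().tensor()],
);
let c: Tensor<B, 2> = Tensor::from_primitive(TensorPrimitive::Float(c_prim));

For everyday code, the high-level einsum entry point mentioned above avoids this conversion boilerplate entirely.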

Dyn Compatibility§

This trait is not dyn compatible: tn_einsum is an associated function with no self receiver, so the trait cannot be used as a trait object.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementations on Foreign Types§

impl TensorNetworkOps for NdArray<f64>

fn tn_einsum(_subscripts: &str, _inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self>

impl<B, C> TensorNetworkOps for Autodiff<B, C>
where B: TensorNetworkOps, C: CheckpointStrategy,

fn tn_einsum(_subscripts: &str, _inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self>

Perform an einsum contraction, recording the operation on the autodiff tape so that gradients can be computed during the backward pass.

§Future implementation plan

The backward pass will invoke tenferro’s einsum_rrule to obtain the vector-Jacobian product (VJP) for each input tensor. The contraction tree used in the forward pass will be cached (or re-derived) for the backward pass so that each partial derivative is itself an optimised einsum, as the sketch below illustrates for a simple case.
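
For intuition, here is what those per-input VJPs look like for the matrix-multiplication contraction from the example above. The snippet assumes a high-level einsum(subscripts, tensors) helper operating on Tensor values; that name and signature are illustrative, not the crate’s confirmed API, and this is not the Autodiff implementation itself.

// Forward: c = einsum("ij,jk->ik", vec![a, b]) with a: (m, n) and b: (n, k).
// Given the upstream gradient grad_c (same shape as c), the reverse rule
// yields one einsum contraction per input:
let grad_a = einsum("ik,jk->ij", vec![grad_c.clone(), b]); // dL/da, shape (m, n)
let grad_b = einsum("ij,ik->jk", vec![a, grad_c]);         // dL/db, shape (n, k)
// Both partial derivatives are themselves einsum contractions, which is why
// reusing the forward contraction tree keeps the backward pass optimised.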

Implementors§