tenferro_burn/lib.rs
1//! Bridge between the [Burn](https://burn.dev) deep learning framework and
2//! tenferro tensor network operations.
3//!
4//! This crate allows Burn tensors to be used with tenferro's einsum and
5//! tensor network contraction routines, enabling seamless integration of
6//! tensor network methods into Burn-based deep learning pipelines.
7//!
8//! # Examples
9//!
10//! ```ignore
11//! use burn::backend::NdArray;
12//! use burn::tensor::Tensor;
13//! use tenferro_burn::einsum;
14//!
15//! // Matrix multiplication via einsum
16//! let a: Tensor<NdArray<f64>, 2> = Tensor::ones([3, 4], &Default::default());
17//! let b: Tensor<NdArray<f64>, 2> = Tensor::ones([4, 5], &Default::default());
18//! let c: Tensor<NdArray<f64>, 2> = einsum("ij,jk->ik", vec![a, b]);
19//! ```
20
21pub mod backward;
22pub mod convert;
23pub mod forward;
24
25use burn::tensor::backend::Backend;
26use burn::tensor::ops::FloatTensor;
27use burn::tensor::Tensor;
28
29/// Trait for backends that support tenferro tensor network operations.
30///
31/// Implement this trait for a Burn backend to enable `einsum` and other
32/// tensor network primitives on that backend's tensors.
33///
34/// # Examples
35///
36/// ```ignore
37/// use burn::backend::NdArray;
38/// use tenferro_burn::TensorNetworkOps;
39///
40/// // NdArray<f64> implements TensorNetworkOps
41/// let result = <NdArray<f64> as TensorNetworkOps>::tn_einsum(
42/// "ij,jk->ik",
43/// vec![a_primitive, b_primitive],
44/// );
45/// ```
46pub trait TensorNetworkOps: Backend {
47 /// Perform an einsum contraction on raw backend tensor primitives.
48 ///
49 /// This operates at the primitive level. Prefer the high-level [`einsum`]
50 /// function for typical usage.
51 fn tn_einsum(subscripts: &str, inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self>;
52}
53
54/// High-level einsum on Burn tensors, dispatching to the backend's
55/// [`TensorNetworkOps::tn_einsum`] implementation.
56///
57/// # Examples
58///
59/// ```ignore
60/// use burn::backend::NdArray;
61/// use burn::tensor::Tensor;
62/// use tenferro_burn::einsum;
63///
64/// let a: Tensor<NdArray<f64>, 2> = Tensor::ones([3, 4], &Default::default());
65/// let b: Tensor<NdArray<f64>, 2> = Tensor::ones([4, 5], &Default::default());
66/// let c: Tensor<NdArray<f64>, 2> = einsum("ij,jk->ik", vec![a, b]);
67/// ```
68pub fn einsum<B: TensorNetworkOps, const D: usize>(
69 _subscripts: &str,
70 _inputs: Vec<Tensor<B, D>>,
71) -> Tensor<B, D> {
72 todo!()
73}