// tenferro_ext_burn/lib.rs

1//! Bridge between the [Burn](https://burn.dev) deep learning framework and
2//! tenferro tensor network operations.
3//!
4//! This crate allows Burn tensors to be used with tenferro's einsum and
5//! tensor network contraction routines, enabling seamless integration of
6//! tensor network methods into Burn-based deep learning pipelines.
7//!
8//! Burn tensors are treated as row-major boundary values. The bridge
9//! normalizes them into tenferro's internal column-major canonical layout for
10//! computation, then materializes row-major buffers again when exporting back
11//! to Burn.
12//!
13//! # Examples
14//!
15//! ```ignore
16//! use burn::backend::NdArray;
17//! use burn::tensor::Tensor;
18//! use tenferro_ext_burn::einsum;
19//!
20//! // Matrix multiplication via einsum
21//! let a: Tensor<NdArray<f64>, 2> = Tensor::ones([3, 4], &Default::default());
22//! let b: Tensor<NdArray<f64>, 2> = Tensor::ones([4, 5], &Default::default());
23//! let c: Tensor<NdArray<f64>, 2> = einsum("ij,jk->ik", vec![a, b]);
24//! ```
25
26pub mod backward;
27pub mod convert;
28pub mod forward;
29
30#[cfg(test)]
31mod tests;
32
33pub use convert::{burn_to_tenferro, tenferro_to_burn};
34
35use tenferro_algebra::Standard;
36use tenferro_prims::{CpuBackend, CpuContext};
37use thiserror::Error as ThisError;
38
39use burn::tensor::backend::Backend;
40use burn::tensor::ops::FloatTensor;
41use burn::tensor::{Tensor, TensorMetadata, TensorPrimitive};
42
/// Error type for Burn/tenferro bridge failures.
///
/// # Examples
///
/// ```
/// use tenferro_ext_burn::Error;
///
/// let err = Error::InternalInvariant("example");
/// assert!(err.to_string().contains("example"));
/// ```
#[derive(Debug, ThisError)]
pub enum Error {
    /// A caller-supplied argument was rejected — e.g. an empty operand list,
    /// bad einsum subscripts, or an error reported by tenferro itself.
    #[error("invalid tenferro-ext-burn argument: {0}")]
    InvalidArgument(String),
    /// An internal invariant of the bridge was violated; this points at a bug
    /// in the bridge rather than at caller input.
    #[error("tenferro-ext-burn internal invariant violated: {0}")]
    InternalInvariant(&'static str),
}
60
/// Result type for Burn/tenferro bridge operations.
///
/// Shorthand for `std::result::Result` with this crate's [`Error`] as the
/// error type; used by all fallible bridge entry points.
///
/// # Examples
///
/// ```
/// use tenferro_ext_burn::{Error, Result};
///
/// let result: Result<()> = Err(Error::InvalidArgument("bad einsum".into()));
/// assert!(result.is_err());
/// ```
pub type Result<T> = std::result::Result<T, Error>;
72
73pub(crate) fn panic_on_error<T>(result: Result<T>) -> T {
74    result.unwrap_or_else(|err| panic!("{err}"))
75}
76
/// Trait for backends that support tenferro tensor network operations.
///
/// Implement this trait for a Burn backend to enable `einsum` and other
/// tensor network primitives on that backend's tensors. The bound
/// `FloatElem = f64` restricts implementations to double-precision backends.
///
/// # Examples
///
/// ```ignore
/// use burn::backend::NdArray;
/// use tenferro_ext_burn::TensorNetworkOps;
///
/// // NdArray<f64> implements TensorNetworkOps
/// let result = <NdArray<f64> as TensorNetworkOps>::tn_einsum(
///     "ij,jk->ik",
///     vec![a_primitive, b_primitive],
/// );
/// ```
pub trait TensorNetworkOps: Backend<FloatElem = f64> {
    /// Perform an einsum contraction on raw backend tensor primitives.
    ///
    /// This operates at the primitive level. Prefer the high-level [`einsum`]
    /// function for typical usage.
    fn tn_einsum(subscripts: &str, inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self>;
}
101
102pub(crate) fn try_primitive_einsum<B: Backend<FloatElem = f64>>(
103    subscripts: &str,
104    inputs: Vec<FloatTensor<B>>,
105) -> Result<FloatTensor<B>> {
106    let first = inputs.first().ok_or_else(|| {
107        Error::InvalidArgument(
108            "tenferro-ext-burn::einsum requires at least one input tensor".into(),
109        )
110    })?;
111    let device = B::float_device(first);
112    let tenferro_inputs: Vec<_> = inputs
113        .iter()
114        .cloned()
115        .map(convert::try_burn_to_tenferro::<B>)
116        .collect::<Result<_>>()?;
117    let operand_refs: Vec<_> = tenferro_inputs.iter().collect();
118    let mut ctx = CpuContext::new(1);
119    let output = tenferro_einsum::einsum::<Standard<f64>, CpuBackend>(
120        &mut ctx,
121        subscripts,
122        &operand_refs,
123        None,
124    )
125    .map_err(|err| Error::InvalidArgument(err.to_string()))?;
126
127    convert::try_tenferro_to_burn::<B>(output, &device)
128}
129
130pub(crate) fn primitive_einsum<B: Backend<FloatElem = f64>>(
131    subscripts: &str,
132    inputs: Vec<FloatTensor<B>>,
133) -> FloatTensor<B> {
134    panic_on_error(try_primitive_einsum::<B>(subscripts, inputs))
135}
136
137/// Fallible high-level einsum on Burn tensors, dispatching to the backend's
138/// [`TensorNetworkOps::tn_einsum`] implementation.
139///
140/// The const rank `D` is shared by the input and output Burn tensors, so this
141/// wrapper is only suitable for contractions whose output rank stays equal to
142/// the input rank. Use [`TensorNetworkOps::tn_einsum`] directly for
143/// rank-changing contractions.
144///
145/// # Examples
146///
147/// ```ignore
148/// use burn::backend::NdArray;
149/// use burn::tensor::Tensor;
150/// use tenferro_ext_burn::try_einsum;
151///
152/// let a: Tensor<NdArray<f64>, 2> = Tensor::ones([3, 4], &Default::default());
153/// let b: Tensor<NdArray<f64>, 2> = Tensor::ones([4, 5], &Default::default());
154/// let c: Tensor<NdArray<f64>, 2> = try_einsum("ij,jk->ik", vec![a, b]).unwrap();
155/// ```
156pub fn try_einsum<B: TensorNetworkOps, const D: usize>(
157    subscripts: &str,
158    inputs: Vec<Tensor<B, D>>,
159) -> Result<Tensor<B, D>> {
160    let primitive_inputs: Vec<_> = inputs
161        .into_iter()
162        .map(|tensor| tensor.into_primitive().tensor())
163        .collect();
164    let output = B::tn_einsum(subscripts, primitive_inputs);
165
166    if output.rank() != D {
167        return Err(Error::InvalidArgument(format!(
168            "tenferro-ext-burn::einsum expected output rank {D}, got {}",
169            output.rank()
170        )));
171    }
172
173    Ok(Tensor::from_primitive(TensorPrimitive::Float(output)))
174}
175
176/// High-level infallible einsum convenience wrapper.
177///
178/// # Examples
179///
180/// ```ignore
181/// use burn::backend::NdArray;
182/// use burn::tensor::Tensor;
183/// use tenferro_ext_burn::einsum;
184///
185/// let a: Tensor<NdArray<f64>, 2> = Tensor::ones([2, 2], &Default::default());
186/// let b: Tensor<NdArray<f64>, 2> = Tensor::ones([2, 2], &Default::default());
187/// let c = einsum("ij,jk->ik", vec![a, b]);
188/// assert_eq!(c.dims(), [2, 2]);
189/// ```
190pub fn einsum<B: TensorNetworkOps, const D: usize>(
191    subscripts: &str,
192    inputs: Vec<Tensor<B, D>>,
193) -> Tensor<B, D> {
194    panic_on_error(try_einsum(subscripts, inputs))
195}