tenferro_ext_burn/lib.rs
1//! Bridge between the [Burn](https://burn.dev) deep learning framework and
2//! tenferro tensor network operations.
3//!
4//! This crate allows Burn tensors to be used with tenferro's einsum and
5//! tensor network contraction routines, enabling seamless integration of
6//! tensor network methods into Burn-based deep learning pipelines.
7//!
8//! Burn tensors are treated as row-major boundary values. The bridge
9//! normalizes them into tenferro's internal column-major canonical layout for
10//! computation, then materializes row-major buffers again when exporting back
11//! to Burn.
12//!
13//! # Examples
14//!
15//! ```ignore
16//! use burn::backend::NdArray;
17//! use burn::tensor::Tensor;
18//! use tenferro_ext_burn::einsum;
19//!
20//! // Matrix multiplication via einsum
21//! let a: Tensor<NdArray<f64>, 2> = Tensor::ones([3, 4], &Default::default());
22//! let b: Tensor<NdArray<f64>, 2> = Tensor::ones([4, 5], &Default::default());
23//! let c: Tensor<NdArray<f64>, 2> = einsum("ij,jk->ik", vec![a, b]);
24//! ```
25
26pub mod backward;
27pub mod convert;
28pub mod forward;
29
30#[cfg(test)]
31mod tests;
32
33pub use convert::{burn_to_tenferro, tenferro_to_burn};
34
35use tenferro_algebra::Standard;
36use tenferro_prims::{CpuBackend, CpuContext};
37use thiserror::Error as ThisError;
38
39use burn::tensor::backend::Backend;
40use burn::tensor::ops::FloatTensor;
41use burn::tensor::{Tensor, TensorMetadata, TensorPrimitive};
42
/// Error type for Burn/tenferro bridge failures.
///
/// The `Display` text of each variant is the fixed prefix from its `#[error]`
/// attribute followed by the wrapped message.
///
/// # Examples
///
/// ```
/// use tenferro_ext_burn::Error;
///
/// let err = Error::InternalInvariant("example");
/// assert!(err.to_string().contains("example"));
/// ```
#[derive(Debug, ThisError)]
pub enum Error {
    /// A caller-supplied argument was rejected; the payload describes why.
    ///
    /// In this file it is produced for an empty einsum input list, for a
    /// failure reported by the underlying einsum call, and for an output whose
    /// rank does not match the requested const rank.
    #[error("invalid tenferro-ext-burn argument: {0}")]
    InvalidArgument(String),
    /// An invariant the bridge relies on did not hold; the payload names it.
    #[error("tenferro-ext-burn internal invariant violated: {0}")]
    InternalInvariant(&'static str),
}
60
/// Result type for Burn/tenferro bridge operations.
///
/// Shorthand for [`std::result::Result`] with the error type fixed to
/// [`Error`].
///
/// # Examples
///
/// ```
/// use tenferro_ext_burn::{Error, Result};
///
/// let result: Result<()> = Err(Error::InvalidArgument("bad einsum".into()));
/// assert!(result.is_err());
/// ```
pub type Result<T> = std::result::Result<T, Error>;
72
73pub(crate) fn panic_on_error<T>(result: Result<T>) -> T {
74 result.unwrap_or_else(|err| panic!("{err}"))
75}
76
/// Trait for backends that support tenferro tensor network operations.
///
/// Implement this trait for a Burn backend to enable `einsum` and other
/// tensor network primitives on that backend's tensors.
///
/// The supertrait bound restricts implementors to backends whose float
/// element type is `f64`.
///
/// # Examples
///
/// ```ignore
/// use burn::backend::NdArray;
/// use tenferro_ext_burn::TensorNetworkOps;
///
/// // NdArray<f64> implements TensorNetworkOps
/// let result = <NdArray<f64> as TensorNetworkOps>::tn_einsum(
///     "ij,jk->ik",
///     vec![a_primitive, b_primitive],
/// );
/// ```
pub trait TensorNetworkOps: Backend<FloatElem = f64> {
    /// Perform an einsum contraction on raw backend tensor primitives.
    ///
    /// This operates at the primitive level. Prefer the high-level [`einsum`]
    /// function for typical usage.
    ///
    /// `subscripts` is the einsum specification string (for example
    /// `"ij,jk->ik"`) and `inputs` holds the operand primitives.
    ///
    /// The return type is a bare tensor rather than a `Result`, so an
    /// implementation has no way to report failure through the return value.
    fn tn_einsum(subscripts: &str, inputs: Vec<FloatTensor<Self>>) -> FloatTensor<Self>;
}
101
102pub(crate) fn try_primitive_einsum<B: Backend<FloatElem = f64>>(
103 subscripts: &str,
104 inputs: Vec<FloatTensor<B>>,
105) -> Result<FloatTensor<B>> {
106 let first = inputs.first().ok_or_else(|| {
107 Error::InvalidArgument(
108 "tenferro-ext-burn::einsum requires at least one input tensor".into(),
109 )
110 })?;
111 let device = B::float_device(first);
112 let tenferro_inputs: Vec<_> = inputs
113 .iter()
114 .cloned()
115 .map(convert::try_burn_to_tenferro::<B>)
116 .collect::<Result<_>>()?;
117 let operand_refs: Vec<_> = tenferro_inputs.iter().collect();
118 let mut ctx = CpuContext::new(1);
119 let output = tenferro_einsum::einsum::<Standard<f64>, CpuBackend>(
120 &mut ctx,
121 subscripts,
122 &operand_refs,
123 None,
124 )
125 .map_err(|err| Error::InvalidArgument(err.to_string()))?;
126
127 convert::try_tenferro_to_burn::<B>(output, &device)
128}
129
130pub(crate) fn primitive_einsum<B: Backend<FloatElem = f64>>(
131 subscripts: &str,
132 inputs: Vec<FloatTensor<B>>,
133) -> FloatTensor<B> {
134 panic_on_error(try_primitive_einsum::<B>(subscripts, inputs))
135}
136
137/// Fallible high-level einsum on Burn tensors, dispatching to the backend's
138/// [`TensorNetworkOps::tn_einsum`] implementation.
139///
140/// The const rank `D` is shared by the input and output Burn tensors, so this
141/// wrapper is only suitable for contractions whose output rank stays equal to
142/// the input rank. Use [`TensorNetworkOps::tn_einsum`] directly for
143/// rank-changing contractions.
144///
145/// # Examples
146///
147/// ```ignore
148/// use burn::backend::NdArray;
149/// use burn::tensor::Tensor;
150/// use tenferro_ext_burn::try_einsum;
151///
152/// let a: Tensor<NdArray<f64>, 2> = Tensor::ones([3, 4], &Default::default());
153/// let b: Tensor<NdArray<f64>, 2> = Tensor::ones([4, 5], &Default::default());
154/// let c: Tensor<NdArray<f64>, 2> = try_einsum("ij,jk->ik", vec![a, b]).unwrap();
155/// ```
156pub fn try_einsum<B: TensorNetworkOps, const D: usize>(
157 subscripts: &str,
158 inputs: Vec<Tensor<B, D>>,
159) -> Result<Tensor<B, D>> {
160 let primitive_inputs: Vec<_> = inputs
161 .into_iter()
162 .map(|tensor| tensor.into_primitive().tensor())
163 .collect();
164 let output = B::tn_einsum(subscripts, primitive_inputs);
165
166 if output.rank() != D {
167 return Err(Error::InvalidArgument(format!(
168 "tenferro-ext-burn::einsum expected output rank {D}, got {}",
169 output.rank()
170 )));
171 }
172
173 Ok(Tensor::from_primitive(TensorPrimitive::Float(output)))
174}
175
176/// High-level infallible einsum convenience wrapper.
177///
178/// # Examples
179///
180/// ```ignore
181/// use burn::backend::NdArray;
182/// use burn::tensor::Tensor;
183/// use tenferro_ext_burn::einsum;
184///
185/// let a: Tensor<NdArray<f64>, 2> = Tensor::ones([2, 2], &Default::default());
186/// let b: Tensor<NdArray<f64>, 2> = Tensor::ones([2, 2], &Default::default());
187/// let c = einsum("ij,jk->ik", vec![a, b]);
188/// assert_eq!(c.dims(), [2, 2]);
189/// ```
190pub fn einsum<B: TensorNetworkOps, const D: usize>(
191 subscripts: &str,
192 inputs: Vec<Tensor<B, D>>,
193) -> Tensor<B, D> {
194 panic_on_error(try_einsum(subscripts, inputs))
195}