// tenferro_einsum/lib.rs

// Crate-wide lint allowances, justified by the einsum API surface:
// - `too_many_arguments`: accumulating variants (e.g. `einsum_with_plan_into`)
//   explicitly take context, plan, operands, alpha, beta, output, and options
//   in BLAS style, exceeding clippy's default argument limit.
// - `multiple_bound_locations` / `type_complexity`: generic backend dispatch
//   (`Backend: EinsumBackend<Alg>` with `BackendContext<Alg, Backend>`) places
//   bounds in several positions and produces nested generic types by design.
#![allow(
    clippy::multiple_bound_locations,
    clippy::too_many_arguments,
    clippy::type_complexity
)]
6
7//! High-level einsum with N-ary contraction tree optimization.
8//!
9//! This crate provides Einstein summation notation for [`tenferro_tensor::Tensor`]
10//! values. It supports:
11//!
12//! - **String notation**: `"ij,jk->ik"` (NumPy/PyTorch compatible)
13//! - **Parenthesized notation**: `"ij,(jk,kl)->il"` respects user-specified
14//!   contraction order via [`NestedEinsum`] (OMEinsum.jl-compatible)
15//! - **Integer label notation**: omeinsum-rs compatible, using `u32` labels
16//! - **N-ary contraction**: Automatic or manual optimization of pairwise
17//!   contraction order via [`ContractionTree`]
18//! - **Binary primitive**: Public two-input einsum APIs (`einsum_binary*`) for
19//!   composing explicit contraction paths in higher layers
20//! - **Accumulating variants**: [`einsum_into`], [`einsum_with_subscripts_into`],
21//!   [`einsum_with_plan_into`] write into a pre-allocated output buffer with
22//!   BLAS-style `alpha`/`beta` scaling, avoiding allocation in hot loops
23//!
24//! # Backend dispatch
25//!
26//! The backend is passed explicitly as a type parameter
27//! `Backend: EinsumBackend<Alg>` with a mutable context
28//! [`BackendContext<Alg, Backend>`](crate::BackendContext). This follows Rust
29//! idiom of explicit ownership and mutability (no global/thread-local state).
30//! The backend contract is the semiring core plus optional semiring fast paths.
31//!
32//! # Examples
33//!
34//! ## Common operations
35//!
36//! ```ignore
37//! use tenferro_algebra::Standard;
38//! use tenferro_einsum::einsum;
39//! use tenferro_tensor::{Tensor, MemoryOrder};
40//! use tenferro_device::LogicalMemorySpace;
41//! use tenferro_prims::{CpuBackend, CpuContext};
42//!
43//! let col = MemoryOrder::ColumnMajor;
44//! let mut ctx = CpuContext::new(4);
45//!
46//! let a = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], col).unwrap();
47//! let b = Tensor::<f64>::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], col).unwrap();
48//! let v = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0], &[3], col).unwrap();
49//!
50//! // Matrix multiplication: C = A @ B
51//! let c = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ij,jk->ik", &[&a, &b], None).unwrap();
52//!
53//! // Trace: tr(A)
54//! let tr = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ii->", &[&a], None).unwrap();
55//!
56//! // Outer product: v_i * v_j -> M_{ij}
57//! let outer =
58//!     einsum::<Standard<f64>, CpuBackend>(&mut ctx, "i,j->ij", &[&v, &v], None).unwrap();
59//!
60//! // Dot product: v . v
61//! let dot = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "i,i->", &[&v, &v], None).unwrap();
62//!
63//! // Matrix-vector product: A @ v
64//! let mv = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ij,j->i", &[&a, &v], None).unwrap();
65//!
66//! // Diagonal embedding: vector -> diagonal matrix
67//! // v = [1, 2, 3] -> [[1,0,0],[0,2,0],[0,0,3]]
68//! let diag = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "i->ii", &[&v], None).unwrap();
69//! assert_eq!(diag.dims(), &[3, 3]);
70//!
71//! // Diagonal extraction: matrix -> diagonal vector
72//! let d = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ii->i", &[&a], None).unwrap();
73//!
74//! // Higher-order diagonal: 3D tensor with repeated index
75//! // Creates T_{iii} from v_i
76//! let t = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "i->iii", &[&v], None).unwrap();
77//! assert_eq!(t.dims(), &[3, 3, 3]);
78//!
79//! // Consuming variant: operands are moved (buffer reuse not yet implemented)
80//! use tenferro_einsum::einsum_owned;
81//! let x = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], col).unwrap();
82//! let y = Tensor::<f64>::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], col).unwrap();
83//! let z =
84//!     einsum_owned::<Standard<f64>, CpuBackend>(&mut ctx, "ij,jk->ik", vec![x, y], None)
85//!         .unwrap();
86//! ```
87//!
88//! ## Batch operations
89//!
90//! ```ignore
91//! use tenferro_algebra::Standard;
92//! // Batched GEMM: 10 independent matrix multiplications in one call
93//! // A: (batch=10, m=3, k=4), B: (batch=10, k=4, n=5) -> C: (batch=10, m=3, n=5)
94//! let a = Tensor::<f64>::zeros(&[10, 3, 4], LogicalMemorySpace::MainMemory, col).unwrap();
95//! let b = Tensor::<f64>::zeros(&[10, 4, 5], LogicalMemorySpace::MainMemory, col).unwrap();
96//! let c =
97//!     einsum::<Standard<f64>, CpuBackend>(&mut ctx, "bij,bjk->bik", &[&a, &b], None).unwrap();
98//! assert_eq!(c.dims(), &[10, 3, 5]);
99//!
100//! // Multiple batch dimensions: (batch1=2, batch2=3, m, k) x (batch1=2, batch2=3, k, n)
101//! let a = Tensor::<f64>::zeros(&[2, 3, 4, 5], LogicalMemorySpace::MainMemory, col).unwrap();
102//! let b = Tensor::<f64>::zeros(&[2, 3, 5, 6], LogicalMemorySpace::MainMemory, col).unwrap();
103//! let c = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "abij,abjk->abik", &[&a, &b], None)
104//!     .unwrap();
105//! assert_eq!(c.dims(), &[2, 3, 4, 6]);
106//!
107//! // Broadcast batch: A has batch dim, B is shared across batch
108//! // A: (batch=10, m=3, k=4), B: (k=4, n=5) -> C: (batch=10, m=3, n=5)
109//! let a = Tensor::<f64>::zeros(&[10, 3, 4], LogicalMemorySpace::MainMemory, col).unwrap();
110//! let b = Tensor::<f64>::zeros(&[4, 5], LogicalMemorySpace::MainMemory, col).unwrap();
111//! let c = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "bij,jk->bik", &[&a, &b], None)
112//!     .unwrap();
113//! assert_eq!(c.dims(), &[10, 3, 5]);
114//! ```
115//!
116//! ## Ellipsis notation for batch dimensions
117//!
118//! NumPy/PyTorch/JAX-style ellipsis notation (`...`) is fully supported for
119//! batch dimensions, allowing generic code that works with any number of batch
120//! dimensions (resolves issue #529).
121//!
122//! ```ignore
123//! use tenferro_algebra::Standard;
124//!
125//! // Batched matrix multiply with ellipsis: works with any number of batch dims
126//! // 1 batch dim: A[2,3,4] @ B[2,4,5] -> C[2,3,5]
127//! let a = Tensor::<f64>::zeros(&[2, 3, 4], LogicalMemorySpace::MainMemory, col).unwrap();
128//! let b = Tensor::<f64>::zeros(&[2, 4, 5], LogicalMemorySpace::MainMemory, col).unwrap();
129//! let c = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "...ij,...jk->...ik", &[&a, &b], None)
130//!     .unwrap();
131//! assert_eq!(c.dims(), &[2, 3, 5]);
132//!
133//! // 2 batch dims: A[2,3,4,5] @ B[2,3,5,6] -> C[2,3,4,6]
134//! let a = Tensor::<f64>::zeros(&[2, 3, 4, 5], LogicalMemorySpace::MainMemory, col).unwrap();
135//! let b = Tensor::<f64>::zeros(&[2, 3, 5, 6], LogicalMemorySpace::MainMemory, col).unwrap();
136//! let c = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "...ij,...jk->...ik", &[&a, &b], None)
137//!     .unwrap();
138//! assert_eq!(c.dims(), &[2, 3, 4, 6]);
139//!
140//! // No batch dims: A[3,4] @ B[4,5] -> C[3,5]
141//! let a = Tensor::<f64>::zeros(&[3, 4], LogicalMemorySpace::MainMemory, col).unwrap();
142//! let b = Tensor::<f64>::zeros(&[4, 5], LogicalMemorySpace::MainMemory, col).unwrap();
143//! let c = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "...ij,...jk->...ik", &[&a, &b], None)
144//!     .unwrap();
145//! assert_eq!(c.dims(), &[3, 5]);
146//! ```
147//!
148//! ## Integer label notation
149//!
150//! ```ignore
151//! use tenferro_algebra::Standard;
152//! use tenferro_einsum::{einsum_with_subscripts, Subscripts};
153//!
154//! // Same as "ij,jk->ik" but with integer labels
155//! // Useful when indices exceed 52 (a-z, A-Z) or are computed programmatically
156//! let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
157//! let c = einsum_with_subscripts::<Standard<f64>, CpuBackend>(&mut ctx, &subs, &[&a, &b], None)
158//!     .unwrap();
159//! ```
160//!
161//! ## Contraction order control
162//!
163//! ```ignore
164//! use tenferro_algebra::Standard;
165//! // Three matrices: D = A @ B @ C
166//! // Parentheses specify: contract B*C first, then A*(BC)
167//! let d = einsum::<Standard<f64>, CpuBackend>(&mut ctx, "ij,(jk,kl)->il", &[&a, &b, &c], None)
168//!     .unwrap();
169//!
170//! // Or use ContractionTree for programmatic control
171//! use tenferro_einsum::ContractionTree;
172//! let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
173//! let tree = ContractionTree::from_pairs(
174//!     &subs,
175//!     &[&[3, 4], &[4, 5], &[5, 6]],
176//!     &[(1, 2), (0, 3)],  // B*C first (avoids large intermediate)
177//! ).unwrap();
178//! let d = einsum_with_plan::<Standard<f64>, CpuBackend>(&mut ctx, &tree, &[&a, &b, &c], None)
179//!     .unwrap();
180//! ```
181//!
182//! ## Accumulating into a pre-allocated output
183//!
184//! ```ignore
185//! use tenferro_algebra::Standard;
186//! use tenferro_einsum::{einsum_with_plan_into, ContractionTree, Subscripts};
187//! use tenferro_tensor::{Tensor, MemoryOrder};
188//! use tenferro_device::LogicalMemorySpace;
189//! use tenferro_prims::{CpuBackend, CpuContext};
190//!
191//! let col = MemoryOrder::ColumnMajor;
192//! let mut ctx = CpuContext::new(4);
193//! let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
194//! let tree = ContractionTree::optimize(&subs, &[&[3, 4], &[4, 5]]).unwrap();
195//! let a = Tensor::<f64>::zeros(&[3, 4], LogicalMemorySpace::MainMemory, col).unwrap();
196//! let b = Tensor::<f64>::zeros(&[4, 5], LogicalMemorySpace::MainMemory, col).unwrap();
197//! let mut c = Tensor::<f64>::zeros(&[3, 5], LogicalMemorySpace::MainMemory, col).unwrap();
198//!
199//! // Hot loop: reuse output buffer, zero allocation per iteration
200//! for _ in 0..1000 {
201//!     // C = 1.0 * (A @ B) + 0.0 * C  (overwrite)
202//!     einsum_with_plan_into::<Standard<f64>, CpuBackend>(
203//!         &mut ctx, &tree, &[&a, &b], 1.0, 0.0, &mut c, None,
204//!     ).unwrap();
205//! }
206//! ```
207//!
208//! ## GPU async chaining (deferred evaluation)
209//!
210//! > **Status: Not yet implemented.** GPU backends do not exist yet.
211//! > The examples below are aspirational design targets, not working code.
212//!
213//! GPU einsum operations return immediately. The result tensor carries a
214//! [`CompletionEvent`](tenferro_tensor::CompletionEvent) that tracks the
215//! pending accelerator work. Passing this tensor to another einsum chains
216//! via GPU stream dependencies — no CPU synchronization until data is
217//! accessed from the host.
218//!
219//! - `wait()` — explicitly blocks until computation completes
220//! - `dims()`, `strides()` — implicitly call `wait()`
221//! - For CPU tensors, `event` is always `None` (zero overhead)
222//!
223//! ```ignore
224//! use tenferro_algebra::Standard;
225//! use tenferro_einsum::einsum;
226//! use tenferro_tensor::{Tensor, MemoryOrder};
227//! use tenferro_device::LogicalMemorySpace;
228//! use tenferro_prims::CudaBackend; // future
229//!
230//! // In production, obtain memory spaces via BackendRegistry (future API).
231//! let gpu_mem = LogicalMemorySpace::GpuMemory { device_id: 0 };
232//! let col = MemoryOrder::ColumnMajor;
233//! let mut gpu_ctx = /* CudaContext from BackendRegistry */;
234//!
235//! let a = Tensor::<f64>::zeros(&[3, 4], gpu_mem, col).unwrap();
236//! let b = Tensor::<f64>::zeros(&[4, 5], gpu_mem, col).unwrap();
237//!
238//! // Both einsum calls submit work to the GPU and return immediately.
239//! // The second call detects c's pending event and chains on the stream.
240//! let c = einsum::<Standard<f64>, CudaBackend>(&mut gpu_ctx, "ij,jk->ik", &[&a, &b], None)
241//!     .unwrap();
242//! let d = einsum::<Standard<f64>, CudaBackend>(&mut gpu_ctx, "ij,jk->ik", &[&c, &b], None)
243//!     .unwrap();
244//!
245//! // wait() blocks until GPU computation completes
246//! d.wait();
247//! ```
248//!
249//! ## Specifying a compute device
250//!
251//! > **Status: Not yet implemented.** See GPU note above.
252//!
//! ```ignore
//! use tenferro_algebra::Standard;
//! use tenferro_einsum::einsum;
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::{LogicalMemorySpace, ComputeDevice};
//! use tenferro_prims::CudaBackend; // future
//!
//! let col = MemoryOrder::ColumnMajor;
//! // In production, obtain memory spaces via BackendRegistry (future API).
//! let gpu_mem = LogicalMemorySpace::GpuMemory { device_id: 0 };
//! let mut gpu_ctx = /* CudaContext from BackendRegistry */;
261//!
262//! let mut a = Tensor::<f64>::zeros(&[3, 4], gpu_mem, col).unwrap();
263//! let mut b = Tensor::<f64>::zeros(&[4, 5], gpu_mem, col).unwrap();
264//!
265//! // Pin tensors to CUDA device 1 (overrides automatic device selection).
266//! // This works when CUDA device 1 can access GpuMemory { device_id: 0 }
267//! // (e.g., same physical GPU or NVLink-connected peer).
268//! // If the device cannot access the memory space, einsum returns
269//! // Err(NoCompatibleComputeDevice). In that case, transfer explicitly:
270//! //   let a = a.to_memory_space_async(GpuMemory { device_id: 1 }).unwrap();
271//! a.set_preferred_compute_device(Some(ComputeDevice::Cuda { device_id: 1 }));
272//! b.set_preferred_compute_device(Some(ComputeDevice::Cuda { device_id: 1 }));
273//!
274//! // einsum dispatches to the specified CUDA device
275//! let c = einsum::<Standard<f64>, CudaBackend>(&mut gpu_ctx, "ij,jk->ik", &[&a, &b], None)
276//!     .unwrap();
277//!
278//! // Clear override — revert to automatic device selection
279//! // a.set_preferred_compute_device(None);
280//! ```
281
282// Internal modules
283pub(crate) mod ad;
284pub(crate) mod api;
285mod execution;
286mod layout;
287mod planning;
288mod syntax;
289
290// Public re-exports: types
291pub use execution::{BackendContext, EinsumBackend};
292pub use planning::{ContractionOptimizerOptions, ContractionTree};
293pub use syntax::{NestedEinsum, Subscripts};
294
295// Public re-exports: functions
296pub use api::{
297    einsum, einsum_binary, einsum_binary_into, einsum_binary_with_subscripts,
298    einsum_binary_with_subscripts_into, einsum_into, einsum_owned, einsum_with_path,
299    einsum_with_path_into, einsum_with_plan, einsum_with_plan_into, einsum_with_plan_owned,
300    einsum_with_subscripts, einsum_with_subscripts_into, einsum_with_subscripts_owned,
301};
302
303pub use ad::{einsum_frule, einsum_rrule};
304
305#[cfg(feature = "profile-dispatch")]
306pub use execution::print_and_reset_profile;
307
308// ============================================================================
309// Tests
310// ============================================================================
311
312#[cfg(test)]
313mod tests;