tenferro_einsum/lib.rs
//! High-level einsum with N-ary contraction tree optimization.
//!
//! This crate provides Einstein summation notation for [`Tensor`]
//! values. It supports:
//!
//! - **String notation**: `"ij,jk->ik"` (NumPy/PyTorch compatible)
//! - **Parenthesized contraction order**: `"ij,(jk,kl)->il"` to control the
//!   pairwise contraction sequence in string notation
//! - **Integer label notation**: omeinsum-rs compatible, using `u32` labels
//! - **N-ary contraction**: automatic or manual optimization of the pairwise
//!   contraction order via [`ContractionTree`]
//! - **Accumulating variants**: [`einsum_into`], [`einsum_with_subscripts_into`],
//!   and [`einsum_with_plan_into`] write into a pre-allocated output buffer with
//!   BLAS-style `alpha`/`beta` scaling, avoiding allocation in hot loops
//!
//! # Backend dispatch
//!
//! The backend is selected automatically from the tensor's
//! [`LogicalMemorySpace`](tenferro_device::LogicalMemorySpace) (PyTorch-style).
//! There is no backend type parameter in the public API.
//!
//! # Examples
//!
//! ## Common operations
//!
//! ```ignore
//! use tenferro_einsum::einsum;
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let col = MemoryOrder::ColumnMajor;
//!
//! let a = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], col).unwrap();
//! let b = Tensor::<f64>::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], col).unwrap();
//! let v = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0], &[3], col).unwrap();
//!
//! // Matrix multiplication: C = A @ B
//! let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
//!
//! // Trace: tr(A)
//! let tr = einsum("ii->", &[&a]).unwrap();
//!
//! // Outer product: v_i * v_j -> M_{ij}
//! let outer = einsum("i,j->ij", &[&v, &v]).unwrap();
//!
//! // Dot product: v . v
//! let dot = einsum("i,i->", &[&v, &v]).unwrap();
//!
//! // Matrix-vector product: A @ v
//! let mv = einsum("ij,j->i", &[&a, &v]).unwrap();
//!
//! // Diagonal embedding: vector -> diagonal matrix
//! // v = [1, 2, 3] -> [[1,0,0],[0,2,0],[0,0,3]]
//! let diag = einsum("i->ii", &[&v]).unwrap();
//! assert_eq!(diag.dims(), &[3, 3]);
//!
//! // Diagonal extraction: matrix -> diagonal vector
//! let d = einsum("ii->i", &[&a]).unwrap();
//!
//! // Higher-order diagonal: 3D tensor with repeated index
//! // Creates T_{iii} from v_i
//! let t = einsum("i->iii", &[&v]).unwrap();
//! assert_eq!(t.dims(), &[3, 3, 3]);
//!
//! // Consuming variant: operands are moved, buffers may be reused
//! use tenferro_einsum::einsum_owned;
//! let x = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], col).unwrap();
//! let y = Tensor::<f64>::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], col).unwrap();
//! let z = einsum_owned("ij,jk->ik", vec![x, y]).unwrap(); // x, y consumed
//! ```
//!
//! ## Batch operations
//!
//! ```ignore
//! use tenferro_einsum::einsum;
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let col = MemoryOrder::ColumnMajor;
//!
//! // Batched GEMM: 10 independent matrix multiplications in one call
//! // A: (batch=10, m=3, k=4), B: (batch=10, k=4, n=5) -> C: (batch=10, m=3, n=5)
//! let a = Tensor::<f64>::zeros(&[10, 3, 4], LogicalMemorySpace::MainMemory, col);
//! let b = Tensor::<f64>::zeros(&[10, 4, 5], LogicalMemorySpace::MainMemory, col);
//! let c = einsum("bij,bjk->bik", &[&a, &b]).unwrap();
//! assert_eq!(c.dims(), &[10, 3, 5]);
//!
//! // Multiple batch dimensions: (batch1=2, batch2=3, m, k) x (batch1=2, batch2=3, k, n)
//! let a = Tensor::<f64>::zeros(&[2, 3, 4, 5], LogicalMemorySpace::MainMemory, col);
//! let b = Tensor::<f64>::zeros(&[2, 3, 5, 6], LogicalMemorySpace::MainMemory, col);
//! let c = einsum("abij,abjk->abik", &[&a, &b]).unwrap();
//! assert_eq!(c.dims(), &[2, 3, 4, 6]);
//!
//! // Broadcast batch: A has a batch dim, B is shared across the batch
//! // A: (batch=10, m=3, k=4), B: (k=4, n=5) -> C: (batch=10, m=3, n=5)
//! let a = Tensor::<f64>::zeros(&[10, 3, 4], LogicalMemorySpace::MainMemory, col);
//! let b = Tensor::<f64>::zeros(&[4, 5], LogicalMemorySpace::MainMemory, col);
//! let c = einsum("bij,jk->bik", &[&a, &b]).unwrap();
//! assert_eq!(c.dims(), &[10, 3, 5]);
//! ```
//!
//! ## Integer label notation
//!
//! ```ignore
//! use tenferro_einsum::{einsum_with_subscripts, Subscripts};
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let col = MemoryOrder::ColumnMajor;
//! let a = Tensor::<f64>::zeros(&[3, 4], LogicalMemorySpace::MainMemory, col);
//! let b = Tensor::<f64>::zeros(&[4, 5], LogicalMemorySpace::MainMemory, col);
//!
//! // Same as "ij,jk->ik" but with integer labels.
//! // Useful when more than 52 distinct indices (a-z, A-Z) are needed
//! // or when labels are computed programmatically.
//! let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
//! let c = einsum_with_subscripts(&subs, &[&a, &b]).unwrap();
//! ```
//!
//! ## Contraction order control
//!
//! ```ignore
//! use tenferro_einsum::{einsum, einsum_with_plan, ContractionTree, Subscripts};
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let col = MemoryOrder::ColumnMajor;
//! let mem = LogicalMemorySpace::MainMemory;
//! let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
//! let b = Tensor::<f64>::zeros(&[4, 100], mem, col);
//! let c = Tensor::<f64>::zeros(&[100, 5], mem, col);
//!
//! // Three matrices: D = A @ B @ C
//! // Parentheses specify: contract B*C first, then A*(BC)
//! let d = einsum("ij,(jk,kl)->il", &[&a, &b, &c]).unwrap();
//!
//! // Or use ContractionTree for programmatic control
//! let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
//! let tree = ContractionTree::from_pairs(
//!     &subs,
//!     &[&[3, 4], &[4, 100], &[100, 5]],
//!     &[(1, 2), (0, 3)], // B*C first (avoids the large intermediate)
//! ).unwrap();
//! let d = einsum_with_plan(&tree, &[&a, &b, &c]).unwrap();
//! ```
//!
//! ## Accumulating into a pre-allocated output
//!
//! ```ignore
//! use tenferro_einsum::{einsum_with_plan_into, ContractionTree, Subscripts};
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::LogicalMemorySpace;
//!
//! let col = MemoryOrder::ColumnMajor;
//! let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
//! let tree = ContractionTree::optimize(&subs, &[&[3, 4], &[4, 5]]).unwrap();
//! let a = Tensor::<f64>::zeros(&[3, 4], LogicalMemorySpace::MainMemory, col);
//! let b = Tensor::<f64>::zeros(&[4, 5], LogicalMemorySpace::MainMemory, col);
//! let mut c = Tensor::<f64>::zeros(&[3, 5], LogicalMemorySpace::MainMemory, col);
//!
//! // Hot loop: reuse the output buffer, zero allocation per iteration
//! for _ in 0..1000 {
//!     // C = 1.0 * (A @ B) + 0.0 * C (overwrite)
//!     einsum_with_plan_into(&tree, &[&a, &b], 1.0, 0.0, &mut c).unwrap();
//! }
//! ```
//!
//! ## GPU async chaining (deferred evaluation)
//!
//! GPU einsum operations return immediately. The result tensor carries a
//! [`CompletionEvent`](tenferro_tensor::CompletionEvent) that tracks the
//! pending accelerator work. Passing this tensor to another einsum chains
//! via GPU stream dependencies — no CPU synchronization until data is
//! accessed from the host.
//!
//! - `wait()` — explicitly blocks until computation completes
//! - `view()`, `dims()`, `strides()` — implicitly call `wait()`
//! - For CPU tensors, `event` is always `None` (zero overhead)
//!
//! ```ignore
//! use tenferro_einsum::einsum;
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::LogicalMemorySpace;
//!
//! // In production, obtain memory spaces via BackendRegistry (future API).
//! let gpu_mem = LogicalMemorySpace::GpuMemory { device_id: 0 };
//! let col = MemoryOrder::ColumnMajor;
//!
//! let a = Tensor::<f64>::zeros(&[3, 4], gpu_mem, col);
//! let b = Tensor::<f64>::zeros(&[4, 5], gpu_mem, col);
//!
//! // Both einsum calls submit work to the GPU and return immediately.
//! // The second call detects c's pending event and chains on the stream.
//! let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
//! let d = einsum("ij,jk->ik", &[&c, &b]).unwrap();
//!
//! // wait() blocks until the GPU computation completes
//! d.wait();
//! ```
//!
//! ## Specifying a compute device
//!
//! ```ignore
//! use tenferro_einsum::einsum;
//! use tenferro_tensor::{Tensor, MemoryOrder};
//! use tenferro_device::{LogicalMemorySpace, ComputeDevice};
//!
//! let col = MemoryOrder::ColumnMajor;
//! // In production, obtain memory spaces via BackendRegistry (future API).
//! let gpu_mem = LogicalMemorySpace::GpuMemory { device_id: 0 };
//!
//! let mut a = Tensor::<f64>::zeros(&[3, 4], gpu_mem, col);
//! let mut b = Tensor::<f64>::zeros(&[4, 5], gpu_mem, col);
//!
//! // Pin tensors to CUDA device 1 (overrides automatic device selection).
//! // This works when CUDA device 1 can access GpuMemory { device_id: 0 }
//! // (e.g., same physical GPU or an NVLink-connected peer).
//! // If the device cannot access the memory space, einsum returns
//! // Err(NoCompatibleComputeDevice). In that case, transfer explicitly:
//! // let a = a.to_memory_space_async(LogicalMemorySpace::GpuMemory { device_id: 1 }).unwrap();
//! a.set_preferred_compute_device(Some(ComputeDevice::Cuda { device_id: 1 }));
//! b.set_preferred_compute_device(Some(ComputeDevice::Cuda { device_id: 1 }));
//!
//! // einsum dispatches to the specified CUDA device
//! let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
//!
//! // Clear the override to revert to automatic selection
//! // a.set_preferred_compute_device(None);
//! ```

use chainrules::{AdResult, Differentiable, DualTensor, TrackedTensor};
use tenferro_algebra::{HasAlgebra, Scalar};
use tenferro_device::Result;
use tenferro_tensor::Tensor;

/// Einsum subscripts using integer labels (omeinsum-rs compatible).
///
/// Each dimension is represented by a `u32` label. Labels shared across
/// multiple input tensors are contracted (summed over). Labels present
/// only in the output are free indices.
///
/// # Examples
///
/// ```
/// use tenferro_einsum::Subscripts;
///
/// // Matrix multiplication: C_{ik} = Σ_j A_{ij} * B_{jk}
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// assert_eq!(subs.inputs.len(), 2);
/// assert_eq!(subs.output, vec![0, 2]);
/// ```
///
/// ```ignore
/// use tenferro_einsum::Subscripts;
///
/// // Parse from string notation
/// let subs = Subscripts::parse("ij,jk->ik").unwrap();
/// assert_eq!(subs.inputs.len(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct Subscripts {
    /// Index labels for each input tensor.
    pub inputs: Vec<Vec<u32>>,
    /// Index labels for the output tensor.
    pub output: Vec<u32>,
}

impl Subscripts {
    /// Create subscripts from integer label arrays.
    ///
    /// # Arguments
    ///
    /// * `inputs` — Index labels for each input tensor
    /// * `output` — Index labels for the output tensor
    pub fn new(inputs: &[&[u32]], output: &[u32]) -> Self {
        Self {
            inputs: inputs.iter().map(|s| s.to_vec()).collect(),
            output: output.to_vec(),
        }
    }

    /// Parse subscripts from NumPy/PyTorch-style string notation.
    ///
    /// Each character (`a`–`z`, `A`–`Z`) represents a dimension label.
    /// Input tensors are separated by commas, and `->` separates inputs
    /// from the output.
    ///
    /// Parentheses can be used to specify the contraction order explicitly.
    /// Grouped operands are contracted first, enabling manual control
    /// over the pairwise contraction sequence without using
    /// [`ContractionTree::from_pairs`].
    ///
    /// # Examples
    ///
    /// - `"ij,jk->ik"` — matrix multiplication
    /// - `"ii->i"` — diagonal extraction
    /// - `"ijk->"` — full contraction (scalar result)
    /// - `"ij,(jk,kl)->il"` — contract B and C first, then with A
    ///
    /// # Errors
    ///
    /// Returns an error if the notation is malformed.
    pub fn parse(notation: &str) -> Result<Self> {
        todo!()
    }
}

/// Contraction tree determining pairwise contraction order for N-ary einsum.
///
/// When contracting more than two tensors, the order in which pairwise
/// contractions are performed significantly affects performance.
/// `ContractionTree` encodes this order as a binary tree.
///
/// # Optimization
///
/// Use [`ContractionTree::optimize`] for automatic cost-based optimization
/// (e.g., greedy algorithm based on tensor sizes), or
/// [`ContractionTree::from_pairs`] for manual specification.
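///
/// # Examples
///
/// A minimal sketch of why the order matters, with illustrative shapes
/// A[3,4], B[4,100], C[100,5]:
///
/// ```ignore
/// use tenferro_einsum::{ContractionTree, Subscripts};
///
/// // A[ij] B[jk] C[kl] -> D[il]
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
/// // (A*B)*C forms a 3x100 intermediate; A*(B*C) forms a 4x5 one.
/// // `optimize` picks the cheaper order from the shapes alone.
/// let tree = ContractionTree::optimize(&subs, &[&[3, 4], &[4, 100], &[100, 5]]).unwrap();
/// ```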
pub struct ContractionTree {
    // Internal representation is private.
    _private: (),
}

impl ContractionTree {
    /// Automatically compute an optimized contraction order.
    ///
    /// Uses a cost-based heuristic (e.g., greedy algorithm) to determine
    /// the pairwise contraction sequence that minimizes total operation count.
    ///
    /// # Arguments
    ///
    /// * `subscripts` — Einsum subscripts for all tensors
    /// * `shapes` — Shape of each input tensor
    ///
    /// # Errors
    ///
    /// Returns an error if subscripts and shapes are inconsistent.
    pub fn optimize(subscripts: &Subscripts, shapes: &[&[usize]]) -> Result<Self> {
        todo!()
    }

    /// Manually build a contraction tree from a pairwise contraction sequence.
    ///
    /// Each pair `(i, j)` specifies which two tensors (or intermediate results)
    /// to contract next. Intermediate results are assigned indices starting
    /// from the number of input tensors.
    ///
    /// # Arguments
    ///
    /// * `subscripts` — Einsum subscripts for all tensors
    /// * `shapes` — Shape of each input tensor
    /// * `pairs` — Ordered list of pairwise contractions
    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Three tensors: A[ij] B[jk] C[kl] -> D[il]
    /// // Contract B and C first, then A with the result:
    /// let subs = Subscripts::new(&[&[0, 1], &[1, 2], &[2, 3]], &[0, 3]);
    /// let shapes = [&[3, 4][..], &[4, 5], &[5, 6]];
    /// let tree = ContractionTree::from_pairs(
    ///     &subs,
    ///     &shapes,
    ///     &[(1, 2), (0, 3)], // B*C -> T (index 3), then A*T -> D
    /// ).unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the pairs do not form a valid contraction sequence.
    pub fn from_pairs(
        subscripts: &Subscripts,
        shapes: &[&[usize]],
        pairs: &[(usize, usize)],
    ) -> Result<Self> {
        todo!()
    }
}

/// Execute einsum using string notation.
///
/// Parses the subscript string, optimizes the contraction order, and
/// executes the contraction. The backend is selected automatically from
/// the tensors' memory space and compute device.
///
/// Parentheses in the subscript string specify the contraction order
/// explicitly (e.g., `"ij,(jk,kl)->il"` contracts B and C first).
/// Without parentheses, the contraction order is optimized automatically.
///
/// # Arguments
///
/// * `subscripts` — Einstein summation notation (e.g., `"ij,jk->ik"`)
/// * `operands` — Input tensors
///
/// # Examples
///
/// ```ignore
/// // Matrix multiplication
/// let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
///
/// // Trace
/// let tr = einsum("ii->", &[&a]).unwrap();
///
/// // Batch matrix multiplication
/// let c = einsum("bij,bjk->bik", &[&a, &b]).unwrap();
///
/// // Explicit contraction order: contract B*C first, then A
/// let d = einsum("ij,(jk,kl)->il", &[&a, &b, &c]).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the notation is invalid or tensor shapes are
/// incompatible with the subscripts.
pub fn einsum<T: Scalar + HasAlgebra>(
    subscripts: &str,
    operands: &[&Tensor<T>],
) -> Result<Tensor<T>> {
    todo!()
}

/// Execute einsum with pre-built [`Subscripts`].
///
/// Avoids re-parsing the subscript string on each call. Useful when the
/// same contraction pattern is applied to tensors of varying shapes.
///
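/// # Examples
///
/// A minimal sketch (`a` and `b` stand for any tensors whose shapes match
/// the `ij,jk->ik` pattern):
///
/// ```ignore
/// use tenferro_einsum::{einsum_with_subscripts, Subscripts};
///
/// // Build once, reuse across calls; no string parsing per call.
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// let c = einsum_with_subscripts(&subs, &[&a, &b]).unwrap();
/// ```
///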
/// # Errors
///
/// Returns an error if tensor shapes are incompatible with the subscripts.
pub fn einsum_with_subscripts<T: Scalar + HasAlgebra>(
    subscripts: &Subscripts,
    operands: &[&Tensor<T>],
) -> Result<Tensor<T>> {
    todo!()
}

/// Execute einsum with a pre-optimized [`ContractionTree`].
///
/// Avoids both subscript parsing and contraction order optimization.
/// Ideal for hot loops where the same contraction is executed repeatedly
/// on tensors of the same shape.
///
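/// # Examples
///
/// A minimal sketch (`a` and `b` stand for 3x4 and 4x5 tensors matching
/// the plan's shapes):
///
/// ```ignore
/// use tenferro_einsum::{einsum_with_plan, ContractionTree, Subscripts};
///
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// let tree = ContractionTree::optimize(&subs, &[&[3, 4], &[4, 5]]).unwrap();
/// // Hot loop: no parsing or order optimization per iteration.
/// for _ in 0..1000 {
///     let c = einsum_with_plan(&tree, &[&a, &b]).unwrap();
/// }
/// ```
///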
/// # Errors
///
/// Returns an error if the operand shapes do not match those used to
/// build the contraction tree.
pub fn einsum_with_plan<T: Scalar + HasAlgebra>(
    tree: &ContractionTree,
    operands: &[&Tensor<T>],
) -> Result<Tensor<T>> {
    todo!()
}

// ============================================================================
// Consuming variants (take ownership of input tensors for buffer reuse)
// ============================================================================

/// Execute einsum using string notation, consuming the input tensors.
///
/// Takes ownership of the operands, allowing the implementation to reuse
/// their buffers for intermediate results or the final output. This avoids
/// allocation when an operand buffer already has the right shape and layout.
///
/// # Examples
///
/// ```ignore
/// use tenferro_einsum::einsum_owned;
/// use tenferro_tensor::{Tensor, MemoryOrder};
///
/// let col = MemoryOrder::ColumnMajor;
/// let a = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], col).unwrap();
/// let b = Tensor::<f64>::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], col).unwrap();
///
/// // `a` and `b` are consumed; their buffers may be reused
/// let c = einsum_owned("ij,jk->ik", vec![a, b]).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the notation is invalid or tensor shapes are
/// incompatible with the subscripts.
pub fn einsum_owned<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: Vec<Tensor<T>>,
) -> Result<Tensor<T>> {
    todo!()
}

/// Execute einsum with pre-built [`Subscripts`], consuming the input tensors.
///
/// Combines the benefits of subscript caching ([`einsum_with_subscripts`])
/// with buffer reuse from owned operands.
///
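/// # Examples
///
/// A minimal sketch (`a` and `b` stand for tensors matching the pattern):
///
/// ```ignore
/// use tenferro_einsum::{einsum_with_subscripts_owned, Subscripts};
///
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// // `a` and `b` are moved; their buffers may back the result.
/// let c = einsum_with_subscripts_owned(&subs, vec![a, b]).unwrap();
/// ```
///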
/// # Errors
///
/// Returns an error if tensor shapes are incompatible with the subscripts.
pub fn einsum_with_subscripts_owned<T: Scalar + HasAlgebra>(
    _subscripts: &Subscripts,
    _operands: Vec<Tensor<T>>,
) -> Result<Tensor<T>> {
    todo!()
}

/// Execute einsum with a pre-optimized [`ContractionTree`], consuming the
/// input tensors.
///
/// Combines the benefits of plan caching ([`einsum_with_plan`]) with
/// buffer reuse from owned operands. Ideal for hot loops where the
/// caller no longer needs the input tensors after contraction.
///
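/// # Examples
///
/// A minimal sketch (`make_operands` is a hypothetical helper yielding
/// fresh 3x4 and 4x5 tensors each iteration):
///
/// ```ignore
/// use tenferro_einsum::{einsum_with_plan_owned, ContractionTree, Subscripts};
///
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// let tree = ContractionTree::optimize(&subs, &[&[3, 4], &[4, 5]]).unwrap();
/// for _ in 0..1000 {
///     let (a, b) = make_operands(); // hypothetical producer
///     // Operands are consumed; their buffers may be reused internally.
///     let c = einsum_with_plan_owned(&tree, vec![a, b]).unwrap();
/// }
/// ```
///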
/// # Errors
///
/// Returns an error if the operand shapes do not match those used to
/// build the contraction tree.
pub fn einsum_with_plan_owned<T: Scalar + HasAlgebra>(
    _tree: &ContractionTree,
    _operands: Vec<Tensor<T>>,
) -> Result<Tensor<T>> {
    todo!()
}

// ============================================================================
// Accumulating variants (write into pre-allocated output buffer)
// ============================================================================

/// Execute einsum using string notation, accumulating into an existing output.
///
/// Computes `output = alpha * einsum(operands) + beta * output`, writing
/// the result into the provided output tensor. This avoids allocating a new
/// output buffer on each call, which is critical for hot loops.
///
/// # Arguments
///
/// * `subscripts` — Einstein summation notation (e.g., `"ij,jk->ik"`)
/// * `operands` — Input tensors
/// * `alpha` — Scaling factor for the einsum result
/// * `beta` — Scaling factor for the existing output contents
/// * `output` — Pre-allocated output tensor (must have the correct shape)
///
/// # Examples
///
/// ```ignore
/// use tenferro_einsum::einsum_into;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let a = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], col).unwrap();
/// let b = Tensor::<f64>::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], col).unwrap();
/// let mut c = Tensor::<f64>::zeros(&[2, 2], LogicalMemorySpace::MainMemory, col);
///
/// // Overwrite: C = A @ B
/// einsum_into("ij,jk->ik", &[&a, &b], 1.0, 0.0, &mut c).unwrap();
///
/// // Accumulate: C += A @ B
/// einsum_into("ij,jk->ik", &[&a, &b], 1.0, 1.0, &mut c).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the notation is invalid, tensor shapes are
/// incompatible, or the output shape does not match the expected result.
pub fn einsum_into<T: Scalar + HasAlgebra>(
    subscripts: &str,
    operands: &[&Tensor<T>],
    alpha: T,
    beta: T,
    output: &mut Tensor<T>,
) -> Result<()> {
    todo!()
}

/// Execute einsum with pre-built [`Subscripts`], accumulating into an existing output.
///
/// Computes `output = alpha * einsum(operands) + beta * output`.
/// Avoids re-parsing the subscript string on each call.
///
/// # Examples
///
/// ```ignore
/// use tenferro_einsum::{einsum_with_subscripts_into, Subscripts};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
/// let b = Tensor::<f64>::zeros(&[4, 5], mem, col);
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// let mut c = Tensor::<f64>::zeros(&[3, 5], mem, col);
///
/// // C = 1.0 * (A @ B) + 0.0 * C
/// einsum_with_subscripts_into(&subs, &[&a, &b], 1.0, 0.0, &mut c).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if tensor shapes are incompatible with the subscripts
/// or the output shape does not match.
pub fn einsum_with_subscripts_into<T: Scalar + HasAlgebra>(
    subscripts: &Subscripts,
    operands: &[&Tensor<T>],
    alpha: T,
    beta: T,
    output: &mut Tensor<T>,
) -> Result<()> {
    todo!()
}

/// Execute einsum with a pre-optimized [`ContractionTree`], accumulating
/// into an existing output.
///
/// Computes `output = alpha * einsum(operands) + beta * output`.
/// Avoids both subscript parsing and contraction order optimization.
/// This is the fastest variant for hot loops with pre-allocated buffers.
///
/// # Examples
///
/// ```ignore
/// use tenferro_einsum::{einsum_with_plan_into, ContractionTree, Subscripts};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
/// let b = Tensor::<f64>::zeros(&[4, 5], mem, col);
/// let subs = Subscripts::new(&[&[0, 1], &[1, 2]], &[0, 2]);
/// let tree = ContractionTree::optimize(&subs, &[&[3, 4], &[4, 5]]).unwrap();
/// let mut c = Tensor::<f64>::zeros(&[3, 5], mem, col);
///
/// // Hot loop: reuse the output buffer, no allocation per iteration
/// for _ in 0..1000 {
///     einsum_with_plan_into(&tree, &[&a, &b], 1.0, 0.0, &mut c).unwrap();
/// }
/// ```
///
/// # Errors
///
/// Returns an error if the operand shapes do not match those used to
/// build the contraction tree, or the output shape is incorrect.
pub fn einsum_with_plan_into<T: Scalar + HasAlgebra>(
    tree: &ContractionTree,
    operands: &[&Tensor<T>],
    alpha: T,
    beta: T,
    output: &mut Tensor<T>,
) -> Result<()> {
    todo!()
}

// ============================================================================
// Automatic differentiation support
// ============================================================================

/// Tracked einsum (reverse-mode AD).
///
/// This is the AD-aware counterpart of [`einsum`]. It records the operation
/// on the reverse-mode tape so that [`chainrules::Tape::pullback`] can
/// compute gradients through it.
///
/// # Examples
///
/// ```ignore
/// use chainrules::Tape;
/// use tenferro_einsum::tracked_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let tape = Tape::<Tensor<f64>>::new();
/// let a = tape.leaf(Tensor::ones(
///     &[2, 3],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let b = tape.leaf(Tensor::ones(
///     &[3, 4],
///     LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ));
/// let c = tracked_einsum("ij,jk->ik", &[&a, &b]).unwrap();
/// let loss = tracked_einsum("ij,ij->", &[&c, &c]).unwrap();
/// let grads = tape.pullback(&loss).unwrap();
/// let _ga = grads.get(a.node_id().unwrap()).unwrap();
/// ```
pub fn tracked_einsum<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: &[&TrackedTensor<Tensor<T>>],
) -> AdResult<TrackedTensor<Tensor<T>>>
where
    Tensor<T>: Differentiable,
{
    todo!()
}

/// Dual einsum (forward-mode JVP propagation).
///
/// This is the forward-mode AD-aware counterpart of [`einsum`].
/// It propagates tangent vectors through the einsum operation.
///
/// # Examples
///
/// ```ignore
/// use chainrules::DualTensor;
/// use tenferro_einsum::dual_einsum;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::ones(&[2, 3], mem, col);
/// let da = Tensor::<f64>::ones(&[2, 3], mem, col);
/// let b = Tensor::<f64>::ones(&[3, 4], mem, col);
///
/// let a_dual = DualTensor::with_tangent(a, da).unwrap();
/// let b_dual = DualTensor::new(b);
/// let c_dual = dual_einsum("ij,jk->ik", &[&a_dual, &b_dual]).unwrap();
/// let _tangent = c_dual.tangent();
/// ```
pub fn dual_einsum<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: &[&DualTensor<Tensor<T>>],
) -> AdResult<DualTensor<Tensor<T>>>
where
    Tensor<T>: Differentiable,
{
    todo!()
}

/// Reverse-mode rule (rrule) for einsum without building a global tape.
///
/// Computes the pullback (vector-Jacobian product) for an einsum operation.
/// Returns one gradient tensor per input operand.
///
/// Named after Julia's ChainRules.jl convention.
/// This API is intended for language interop and manual AD.
///
/// # Examples
///
/// ```ignore
/// use tenferro_einsum::einsum_rrule;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::ones(&[2, 3], mem, col);
/// let b = Tensor::<f64>::ones(&[3, 4], mem, col);
/// let grad_c = Tensor::<f64>::ones(&[2, 4], mem, col);
///
/// let grads = einsum_rrule("ij,jk->ik", &[&a, &b], &grad_c).unwrap();
/// assert_eq!(grads.len(), 2);
/// ```
pub fn einsum_rrule<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _operands: &[&Tensor<T>],
    _cotangent: &Tensor<T>,
) -> Result<Vec<Tensor<T>>> {
    todo!()
}

/// Forward-mode rule (frule) for einsum without building a global tape.
///
/// Computes the pushforward (Jacobian-vector product) for an einsum operation.
/// Operands without a tangent are passed as `None`.
///
/// Named after Julia's ChainRules.jl convention.
/// This API is intended for language interop and manual AD.
///
/// # Examples
///
/// ```ignore
/// use tenferro_einsum::einsum_frule;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::ones(&[2, 3], mem, col);
/// let b = Tensor::<f64>::ones(&[3, 4], mem, col);
/// let da = Tensor::<f64>::ones(&[2, 3], mem, col);
///
/// let dc = einsum_frule("ij,jk->ik", &[&a, &b], &[Some(&da), None]).unwrap();
/// ```
pub fn einsum_frule<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _primals: &[&Tensor<T>],
    _tangents: &[Option<&Tensor<T>>],
) -> Result<Tensor<T>> {
    todo!()
}

/// Local HVP rule for einsum without building a global tape.
///
/// Computes the forward-over-reverse Hessian-vector product for an einsum
/// operation. Given primals, their tangents (the direction v), an output
/// cotangent ḡ, and its tangent dḡ, returns a `(gradient, hvp)` pair
/// for each input operand.
///
/// For `C = einsum("ij,jk->ik", [A, B])`:
/// - gradient: standard pullback, e.g. `ḡ_A = einsum("ik,jk->ij", ḡ_C, B)`
/// - hvp: tangent of the pullback by the product rule, e.g.
///   `dḡ_A = einsum("ik,jk->ij", dḡ_C, B) + einsum("ik,jk->ij", ḡ_C, dB)`
///
/// This API is intended for language interop and manual AD.
///
/// # Examples
///
/// ```ignore
/// use tenferro_einsum::einsum_hvp;
/// use tenferro_tensor::{MemoryOrder, Tensor};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::ones(&[2, 3], mem, col);
/// let b = Tensor::<f64>::ones(&[3, 4], mem, col);
/// let da = Tensor::<f64>::ones(&[2, 3], mem, col);
///
/// let grad_c = Tensor::<f64>::ones(&[2, 4], mem, col);
/// let dgrad_c = Tensor::<f64>::ones(&[2, 4], mem, col);
///
/// let results = einsum_hvp(
///     "ij,jk->ik",
///     &[&a, &b],
///     &[Some(&da), None],
///     &grad_c,
///     &dgrad_c,
/// ).unwrap();
/// assert_eq!(results.len(), 2);
/// let (_grad_a, _hvp_a) = &results[0];
/// let (_grad_b, _hvp_b) = &results[1];
/// ```
pub fn einsum_hvp<T: Scalar + HasAlgebra>(
    _subscripts: &str,
    _primals: &[&Tensor<T>],
    _tangents: &[Option<&Tensor<T>>],
    _cotangent: &Tensor<T>,
    _cotangent_tangent: &Tensor<T>,
) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
    todo!()
}