// tenferro_linalg/lib.rs
1//! Batched matrix linear algebra decompositions with AD rules.
2//!
3//! This crate provides SVD, QR, LU, and eigendecomposition for tensors
4//! with shape `(m, n, *)`, adapted from PyTorch's `torch.linalg` for
5//! column-major layout:
6//!
7//! - **First 2 dimensions** are the matrix (`m × n`).
8//! - **All following dimensions** (`*`) are independent batch dimensions.
9//! - Input must be **column-major contiguous** (LAPACK/cuSOLVER native).
10//!
11//! This convention mirrors PyTorch's `(*, m, n)` but is flipped for
12//! col-major: in col-major the first dimensions are contiguous, so
13//! placing the matrix there ensures LAPACK can operate directly without
14//! transposition.
15//!
16//! This module is **context-agnostic**: it does not know about tensor
17//! networks, MPS, or any specific application. If you need to decompose
18//! a tensor along arbitrary legs, `permute` + `reshape` +
19//! `contiguous(ColumnMajor)` before calling these functions.
20//!
21//! # AD rules
22//!
23//! Each decomposition has stateless `_rrule` (reverse-mode / VJP) and
24//! `_frule` (forward-mode / JVP) functions. These implement matrix-level
//! AD formulas (Mathieu et al., 2019) using batched operations that
26//! naturally broadcast over batch dimensions `*`.
27//!
28//! There are no `tracked_*` / `dual_*` functions — the chainrules tape
29//! engine composes `permute_backward` + `reshape_backward` + `svd_rrule`
30//! via the standard chain rule automatically.
31//!
32//! # Examples
33//!
34//! ## SVD of a matrix
35//!
36//! ```ignore
37//! use tenferro_linalg::{svd, SvdOptions};
38//! use tenferro_tensor::{Tensor, MemoryOrder};
39//! use tenferro_device::LogicalMemorySpace;
40//!
41//! let col = MemoryOrder::ColumnMajor;
42//! let mem = LogicalMemorySpace::MainMemory;
43//!
44//! // 2D matrix: shape [3, 4]
45//! let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
46//! let result = svd(&a, None).unwrap();
47//! // result.u: shape [3, 3] (m × k, k = min(m,n) = 3)
48//! // result.s: shape [3] (singular values)
49//! // result.vt: shape [3, 4] (k × n)
50//! ```
51//!
52//! ## Batched SVD
53//!
54//! ```ignore
55//! use tenferro_linalg::svd;
56//! use tenferro_tensor::{Tensor, MemoryOrder};
57//! use tenferro_device::LogicalMemorySpace;
58//!
59//! let col = MemoryOrder::ColumnMajor;
60//! let mem = LogicalMemorySpace::MainMemory;
61//!
62//! // Batched: shape [m, n, batch] = [3, 4, 10]
63//! let a = Tensor::<f64>::zeros(&[3, 4, 10], mem, col);
64//! let result = svd(&a, None).unwrap();
65//! // result.u: shape [3, 3, 10]
66//! // result.s: shape [3, 10]
67//! // result.vt: shape [3, 4, 10]
68//! ```
69//!
70//! ## Decomposing a 4D tensor along specific legs
71//!
72//! ```ignore
73//! use tenferro_linalg::svd;
74//! use tenferro_tensor::{Tensor, MemoryOrder};
75//! use tenferro_device::LogicalMemorySpace;
76//!
77//! let col = MemoryOrder::ColumnMajor;
78//! let mem = LogicalMemorySpace::MainMemory;
79//!
80//! // 4D tensor [2, 3, 4, 5] — want SVD with left=[0,1], right=[2,3]
81//! let t = Tensor::<f64>::zeros(&[2, 3, 4, 5], mem, col);
82//!
83//! // User's responsibility: permute + reshape + contiguous
84//! let mat = t.permute(&[0, 1, 2, 3]) // already in order
85//! .reshape(&[6, 20]).unwrap() // m = 2*3 = 6, n = 4*5 = 20
86//! .contiguous(col);
87//! let result = svd(&mat, None).unwrap();
88//! // Then reshape result.u, result.vt back to desired tensor shape
89//! ```
90//!
91//! ## Reverse-mode AD (stateless rrule)
92//!
93//! ```ignore
94//! use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
95//! use tenferro_tensor::{Tensor, MemoryOrder};
96//! use tenferro_device::LogicalMemorySpace;
97//!
98//! let col = MemoryOrder::ColumnMajor;
99//! let mem = LogicalMemorySpace::MainMemory;
100//!
101//! let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
102//! let result = svd(&a, None).unwrap();
103//!
104//! // Full cotangent: gradient through U, S, and Vt
105//! let cotangent = SvdCotangent {
106//! u: Some(Tensor::ones(&[3, 3], mem, col)),
107//! s: Some(Tensor::ones(&[3], mem, col)),
108//! vt: Some(Tensor::ones(&[3, 4], mem, col)),
109//! };
110//! let grad_a = svd_rrule(&a, &cotangent, None).unwrap();
111//! // grad_a has same shape as a: [3, 4]
112//!
113//! // Partial cotangent: gradient only through singular values (always stable)
114//! let cotangent_s_only = SvdCotangent {
115//! u: None,
116//! s: Some(Tensor::ones(&[3], mem, col)),
117//! vt: None,
118//! };
119//! let grad_a2 = svd_rrule(&a, &cotangent_s_only, None).unwrap();
120//! ```
121
122use chainrules_core::AdResult;
123use tenferro_algebra::Scalar;
124use tenferro_device::Result;
125use tenferro_prims::UnaryOp;
126use tenferro_tensor::Tensor;
127
128// ============================================================================
129// Result types
130// ============================================================================
131
/// SVD result: `A = U * diag(S) * Vt`.
///
/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
///
/// - `u`: shape `(m, k, *)`
/// - `s`: shape `(k, *)` (singular values, descending order)
/// - `vt`: shape `(k, n, *)`
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::svd;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let a = Tensor::<f64>::zeros(&[3, 4],
///     LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
/// let result = svd(&a, None).unwrap();
/// assert_eq!(result.s.ndim(), 1);
/// ```
pub struct SvdResult<T: Scalar> {
    /// Left singular vectors. Shape: `(m, k, *)`.
    pub u: Tensor<T>,
    /// Singular values (descending order). Shape: `(k, *)`.
    ///
    /// NOTE(review): singular values are mathematically real, yet this field
    /// is stored as `Tensor<T>` — for complex `T`, confirm whether the
    /// intended representation is `T` with zero imaginary parts or a real
    /// associated type of `Scalar`.
    pub s: Tensor<T>,
    /// Right singular vectors (conjugate-transposed). Shape: `(k, n, *)`.
    pub vt: Tensor<T>,
}
160
161/// Options for truncated SVD.
162///
163/// When both `max_rank` and `cutoff` are specified, the more restrictive
164/// constraint applies.
165///
166/// # Examples
167///
168/// ```
169/// use tenferro_linalg::SvdOptions;
170///
171/// // Keep at most 10 singular values above 1e-12
172/// let opts = SvdOptions {
173/// max_rank: Some(10),
174/// cutoff: Some(1e-12),
175/// };
176/// ```
// Deriving `Default` replaces the previous hand-written `impl Default`,
// which returned exactly the all-`None` value the derive produces
// (clippy::derivable_impls). Behavior of `SvdOptions::default()` is
// unchanged: `max_rank: None`, `cutoff: None`.
#[derive(Debug, Clone, Default)]
pub struct SvdOptions {
    /// Maximum number of singular values to keep. `None` means no limit.
    pub max_rank: Option<usize>,
    /// Discard singular values below this threshold. `None` means no cutoff.
    pub cutoff: Option<f64>,
}
193
/// QR decomposition result: `A = Q * R`.
///
/// This is the thin (reduced) QR: with `k = min(m, n)`, `q` has only `k`
/// columns rather than a full `m × m` orthogonal factor.
///
/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
///
/// - `q`: shape `(m, k, *)` (orthonormal columns)
/// - `r`: shape `(k, n, *)` (upper triangular)
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::qr;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let a = Tensor::<f64>::zeros(&[4, 3],
///     LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
/// let result = qr(&a).unwrap();
/// assert_eq!(result.q.dims(), &[4, 3]);
/// assert_eq!(result.r.dims(), &[3, 3]);
/// ```
pub struct QrResult<T: Scalar> {
    /// Orthonormal factor. Shape: `(m, k, *)`.
    pub q: Tensor<T>,
    /// Upper triangular factor. Shape: `(k, n, *)`.
    pub r: Tensor<T>,
}
220
/// LU decomposition result: `A = P * L * U` (partial pivoting).
///
/// For an input of shape `(m, n, *)` with `k = min(m, n)`:
///
/// - `p`: permutation indices, shape `(m, *)`
/// - `l`: shape `(m, k, *)` (unit lower triangular)
/// - `u`: shape `(k, n, *)` (upper triangular)
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::lu;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let a = Tensor::<f64>::zeros(&[3, 3],
///     LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
/// let result = lu(&a).unwrap();
/// ```
pub struct LuResult<T: Scalar> {
    /// Row permutation indices (partial pivoting). Shape: `(m, *)`.
    ///
    /// NOTE(review): documented as batched shape `(m, *)` but typed as a
    /// flat `Vec<usize>`, unlike the `Tensor` factors — confirm the intended
    /// batch layout (e.g. `m * batch` flattened, batch-major?) or whether
    /// this should be a tensor of indices like the other fields.
    pub p: Vec<usize>,
    /// Unit lower triangular factor. Shape: `(m, k, *)`.
    pub l: Tensor<T>,
    /// Upper triangular factor. Shape: `(k, n, *)`.
    pub u: Tensor<T>,
}
248
/// Eigendecomposition result: `A * V = V * diag(values)`.
///
/// Only valid for square matrices (`m == n`).
///
/// - `values`: shape `(n, *)` (eigenvalues)
/// - `vectors`: shape `(n, n, *)` (right eigenvectors as columns)
///
/// NOTE(review): a general (non-symmetric) real matrix can have complex
/// eigenvalues, but both fields are `Tensor<T>` — confirm whether this API
/// is restricted to symmetric/Hermitian input or requires complex `T`.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::eigen;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let a = Tensor::<f64>::zeros(&[3, 3],
///     LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
/// let result = eigen(&a).unwrap();
/// assert_eq!(result.values.dims(), &[3]);
/// assert_eq!(result.vectors.dims(), &[3, 3]);
/// ```
pub struct EigenResult<T: Scalar> {
    /// Eigenvalues. Shape: `(n, *)`.
    pub values: Tensor<T>,
    /// Right eigenvectors (columns). Shape: `(n, n, *)`.
    pub vectors: Tensor<T>,
}
275
276// ============================================================================
277// Primary decomposition functions
278// ============================================================================
279
/// Compute the SVD of a batched matrix.
///
/// Input shape: `(m, n, *)`. Must be column-major contiguous.
///
/// # Arguments
///
/// * `tensor` — Input tensor of shape `(m, n, *)`
/// * `options` — Optional truncation parameters
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::{svd, SvdOptions};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let a = Tensor::<f64>::zeros(&[3, 4],
///     LogicalMemorySpace::MainMemory, col);
///
/// // Full SVD
/// let result = svd(&a, None).unwrap();
///
/// // Truncated SVD
/// let opts = SvdOptions { max_rank: Some(2), cutoff: None };
/// let result = svd(&a, Some(&opts)).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the input has fewer than 2 dimensions.
pub fn svd<T: Scalar>(_tensor: &Tensor<T>, _options: Option<&SvdOptions>) -> Result<SvdResult<T>> {
    // NOTE(review): stub — currently panics via `todo!()`; the "# Errors"
    // section above documents the intended final contract.
    todo!()
}
314
/// Compute the QR decomposition of a batched matrix.
///
/// Input shape: `(m, n, *)`. Must be column-major contiguous.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::qr;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let a = Tensor::<f64>::zeros(&[4, 3],
///     LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
/// let result = qr(&a).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the input has fewer than 2 dimensions.
pub fn qr<T: Scalar>(_tensor: &Tensor<T>) -> Result<QrResult<T>> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
337
/// Compute the LU decomposition of a batched matrix (partial pivoting).
///
/// Input shape: `(m, n, *)`. Must be column-major contiguous.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::lu;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let a = Tensor::<f64>::zeros(&[3, 3],
///     LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
/// let result = lu(&a).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the input has fewer than 2 dimensions.
pub fn lu<T: Scalar>(_tensor: &Tensor<T>) -> Result<LuResult<T>> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
360
/// Compute the eigendecomposition of a batched square matrix.
///
/// Input shape: `(n, n, *)`. Must be column-major contiguous.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::eigen;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let a = Tensor::<f64>::zeros(&[3, 3],
///     LogicalMemorySpace::MainMemory, MemoryOrder::ColumnMajor);
/// let result = eigen(&a).unwrap();
/// ```
///
/// # Errors
///
/// Returns an error if the input has fewer than 2 dimensions or
/// the first two dimensions are not equal.
pub fn eigen<T: Scalar>(_tensor: &Tensor<T>) -> Result<EigenResult<T>> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
384
385// ============================================================================
386// AD cotangent types
387// ============================================================================
388
/// Cotangent (adjoint) for SVD outputs.
///
/// Each field is `Option` because the user may not need gradients for
/// all outputs (e.g., only `s` for singular value optimization).
///
/// NOTE(review): a `None` field is presumably treated as a zero cotangent
/// by `svd_rrule` — confirm once the rrule is implemented.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::SvdCotangent;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
///
/// // Only need gradient through singular values
/// let cotangent = SvdCotangent {
///     u: None,
///     s: Some(Tensor::<f64>::ones(&[3], mem, col)),
///     vt: None,
/// };
/// ```
pub struct SvdCotangent<T: Scalar> {
    /// Cotangent for U. Shape must match `SvdResult::u`.
    pub u: Option<Tensor<T>>,
    /// Cotangent for S. Shape must match `SvdResult::s`.
    pub s: Option<Tensor<T>>,
    /// Cotangent for Vt. Shape must match `SvdResult::vt`.
    pub vt: Option<Tensor<T>>,
}
419
/// Cotangent (adjoint) for QR outputs.
///
/// `None` fields are presumably treated as zero cotangents — confirm in
/// `qr_rrule` once implemented.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::QrCotangent;
///
/// let cotangent = QrCotangent::<f64> { q: None, r: None };
/// ```
pub struct QrCotangent<T: Scalar> {
    /// Cotangent for Q. Shape must match `QrResult::q`.
    pub q: Option<Tensor<T>>,
    /// Cotangent for R. Shape must match `QrResult::r`.
    pub r: Option<Tensor<T>>,
}
435
/// Cotangent (adjoint) for LU outputs.
///
/// Note: the permutation `p` is discrete and has no gradient.
///
/// `None` fields are presumably treated as zero cotangents — confirm in
/// `lu_rrule` once implemented.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::LuCotangent;
///
/// let cotangent = LuCotangent::<f64> { l: None, u: None };
/// ```
pub struct LuCotangent<T: Scalar> {
    /// Cotangent for L. Shape must match `LuResult::l`.
    pub l: Option<Tensor<T>>,
    /// Cotangent for U. Shape must match `LuResult::u`.
    pub u: Option<Tensor<T>>,
}
453
/// Cotangent (adjoint) for eigendecomposition outputs.
///
/// `None` fields are presumably treated as zero cotangents — confirm in
/// `eigen_rrule` once implemented.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::EigenCotangent;
///
/// let cotangent = EigenCotangent::<f64> { values: None, vectors: None };
/// ```
pub struct EigenCotangent<T: Scalar> {
    /// Cotangent for eigenvalues. Shape must match `EigenResult::values`.
    pub values: Option<Tensor<T>>,
    /// Cotangent for eigenvectors. Shape must match `EigenResult::vectors`.
    pub vectors: Option<Tensor<T>>,
}
469
470// ============================================================================
471// AD functions: rrule (reverse-mode, stateless)
472// ============================================================================
473
/// Reverse-mode AD rule for SVD (VJP / pullback).
///
/// Computes the gradient of the input given cotangents for the SVD outputs.
/// Uses batched matrix operations (Mathieu 2019) that broadcast over `*`.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::{svd, svd_rrule, SvdCotangent};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
///
/// let cotangent = SvdCotangent {
///     u: None,
///     s: Some(Tensor::ones(&[3], mem, col)),
///     vt: None,
/// };
/// let grad_a = svd_rrule(&a, &cotangent, None).unwrap();
/// ```
pub fn svd_rrule<T: Scalar>(
    _tensor: &Tensor<T>,
    _cotangent: &SvdCotangent<T>,
    _options: Option<&SvdOptions>,
) -> AdResult<Tensor<T>> {
    // SVD reverse-mode AD (Mathieu 2019).
    //
    // NOTE(review): cross-check the formula (and the citation) against
    // Townsend 2016, "Differentiating the Singular Value Decomposition",
    // especially for the complex-valued case before implementing.
    //
    // Given: A = U · diag(S) · Vt, with cotangents dU, dS, dVt.
    //
    // Algorithm (all operations batched over `*` dims):
    //
    // 1. Forward pass: (U, S, Vt) = svd(A)
    //    → Already computed by the caller; recompute or cache as needed.
    //
    // 2. Build F-matrix: F_ij = 1/(σ_j² - σ_i²) for i≠j, 0 for i=j.
    //    Ops: ElementwiseMul(S, S) → S², broadcast, subtract, Reciprocal,
    //    zero diagonal.
    //    Prims used: ElementwiseMul, ElementwiseUnary(Reciprocal).
    //
    // 3. Compute Ut·dU (k×k batched):
    //    Ops: BatchedGemm(Ut, dU)
    //    Prims used: BatchedGemm.
    //
    // 4. Symmetrize: M = Ut·dU - (Ut·dU)^T via permute.
    //    Ops: permute (zero-copy), alpha/beta subtraction.
    //    Prims used: Permute (metadata only), BatchedGemm with beta=-1.
    //
    // 5. Hadamard product: F ⊙ M
    //    Prims used: ElementwiseMul.
    //
    // 6. Add diagonal dS: F⊙M + diag(dS)
    //    Prims used: AntiTrace (embed 1D → diagonal of 2D).
    //
    // 7. Assemble: dA = U · (F⊙M + diag(dS)) · Vt
    //    Prims used: BatchedGemm (two multiplications).
    //
    // 8. (Full-rank case, m > n) Add projector term:
    //    dA += (I - U·Ut) · dU · diag(1/S) · Vt
    //    Prims used: eye, BatchedGemm, ElementwiseUnary(Reciprocal).
    //
    // NOTE(review): this plan only lists the m > n projector term; the
    // symmetric term for the n > m (wide) case — acting on dVt with the
    // (I - V·Vt) projector — appears to be missing. Confirm before
    // implementing.

    // Suppress unused import warning in skeleton.
    let _ = UnaryOp::Reciprocal;

    todo!("SVD rrule: implement steps 1-8 using tenferro-prims operations")
}
542
/// Reverse-mode AD rule for QR (VJP / pullback).
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::{qr_rrule, QrCotangent};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[4, 3], mem, col);
/// let cotangent = QrCotangent {
///     q: Some(Tensor::ones(&[4, 3], mem, col)),
///     r: None,
/// };
/// let grad_a = qr_rrule(&a, &cotangent).unwrap();
/// ```
pub fn qr_rrule<T: Scalar>(
    _tensor: &Tensor<T>,
    _cotangent: &QrCotangent<T>,
) -> AdResult<Tensor<T>> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
567
/// Reverse-mode AD rule for LU (VJP / pullback).
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::{lu_rrule, LuCotangent};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col);
/// let cotangent = LuCotangent {
///     l: Some(Tensor::ones(&[3, 3], mem, col)),
///     u: None,
/// };
/// let grad_a = lu_rrule(&a, &cotangent).unwrap();
/// ```
pub fn lu_rrule<T: Scalar>(
    _tensor: &Tensor<T>,
    _cotangent: &LuCotangent<T>,
) -> AdResult<Tensor<T>> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
592
/// Reverse-mode AD rule for eigendecomposition (VJP / pullback).
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::{eigen_rrule, EigenCotangent};
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col);
/// let cotangent = EigenCotangent {
///     values: Some(Tensor::ones(&[3], mem, col)),
///     vectors: None,
/// };
/// let grad_a = eigen_rrule(&a, &cotangent).unwrap();
/// ```
pub fn eigen_rrule<T: Scalar>(
    _tensor: &Tensor<T>,
    _cotangent: &EigenCotangent<T>,
) -> AdResult<Tensor<T>> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
617
618// ============================================================================
619// AD functions: frule (forward-mode, stateless)
620// ============================================================================
621
/// Forward-mode AD rule for SVD (JVP / pushforward).
///
/// Computes the JVP of all SVD outputs given a tangent for the input.
/// Uses batched matrix operations that broadcast over `*`.
///
/// Returns `(primal, tangent)` pairs: the decomposition of the input and
/// the directional derivatives of its factors.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::svd_frule;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 4], mem, col);
/// let da = Tensor::<f64>::ones(&[3, 4], mem, col);
/// let (result, dresult) = svd_frule(&a, &da, None).unwrap();
/// ```
pub fn svd_frule<T: Scalar>(
    _tensor: &Tensor<T>,
    _tangent: &Tensor<T>,
    _options: Option<&SvdOptions>,
) -> AdResult<(SvdResult<T>, SvdResult<T>)> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
647
/// Forward-mode AD rule for QR (JVP / pushforward).
///
/// Returns `(primal, tangent)`: the QR factors of the input and their
/// directional derivatives.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::qr_frule;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[4, 3], mem, col);
/// let da = Tensor::<f64>::ones(&[4, 3], mem, col);
/// let (result, dresult) = qr_frule(&a, &da).unwrap();
/// ```
pub fn qr_frule<T: Scalar>(
    _tensor: &Tensor<T>,
    _tangent: &Tensor<T>,
) -> AdResult<(QrResult<T>, QrResult<T>)> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
669
/// Forward-mode AD rule for LU (JVP / pushforward).
///
/// Returns `(primal, tangent)`: the LU factors of the input and their
/// directional derivatives.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::lu_frule;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col);
/// let da = Tensor::<f64>::ones(&[3, 3], mem, col);
/// let (result, dresult) = lu_frule(&a, &da).unwrap();
/// ```
pub fn lu_frule<T: Scalar>(
    _tensor: &Tensor<T>,
    _tangent: &Tensor<T>,
) -> AdResult<(LuResult<T>, LuResult<T>)> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}
691
/// Forward-mode AD rule for eigendecomposition (JVP / pushforward).
///
/// Returns `(primal, tangent)`: the eigendecomposition of the input and
/// its directional derivatives.
///
/// # Examples
///
/// ```ignore
/// use tenferro_linalg::eigen_frule;
/// use tenferro_tensor::{Tensor, MemoryOrder};
/// use tenferro_device::LogicalMemorySpace;
///
/// let col = MemoryOrder::ColumnMajor;
/// let mem = LogicalMemorySpace::MainMemory;
/// let a = Tensor::<f64>::zeros(&[3, 3], mem, col);
/// let da = Tensor::<f64>::ones(&[3, 3], mem, col);
/// let (result, dresult) = eigen_frule(&a, &da).unwrap();
/// ```
pub fn eigen_frule<T: Scalar>(
    _tensor: &Tensor<T>,
    _tangent: &Tensor<T>,
) -> AdResult<(EigenResult<T>, EigenResult<T>)> {
    // NOTE(review): stub — currently panics via `todo!()`.
    todo!()
}