Skip to main content

tenferro_linalg/
lib.rs

1//! Linear algebra extension operations for tenferro.
2//!
3//! This crate owns the graph-facing linalg op payloads and their runtime registration.
4//! Tensor-facing operations are exposed through extension traits such as [`TracedTensorLinalgExt`].
5//! CPU backend kernels live in this crate behind the [`LinalgBackend`] trait.
6//!
7//! # Examples
8//!
9//! ```
10//! use tenferro_linalg::TracedTensorLinalgExt;
11//! use tenferro_cpu::CpuBackend;
12//! use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
13//!
14//! let a = TracedTensor::from_vec_col_major(
15//!     vec![2, 2],
16//!     vec![4.0_f64, 2.0, 2.0, 3.0],
17//! )
18//! .unwrap();
19//! let l = a.cholesky().unwrap();
20//!
21//! let mut compiler = GraphCompiler::new();
22//! let program = compiler.compile(&l).unwrap();
23//! let mut executor = GraphExecutor::new(CpuBackend::new());
24//! executor.register_extension(tenferro_linalg::register_runtime).unwrap();
25//! let out = executor.run(&program).unwrap();
26//! assert_eq!(out.shape(), &[2, 2]);
27//! ```
28
29#[cfg(feature = "autodiff")]
30mod ad;
31pub mod backend;
32mod cpu;
33#[cfg(feature = "autodiff")]
34mod eager_backend;
35#[cfg(feature = "autodiff")]
36mod eager_ext;
37mod extension;
38#[cfg(feature = "cuda")]
39mod gpu;
40mod traced;
41
42#[cfg(feature = "autodiff")]
43pub use ad::ad_rules;
44#[cfg(feature = "autodiff")]
45pub use ad::support::{
46    all_linalg_ad_support, linalg_ad_support, LinalgAdOpKind, LinalgAdOutputSupport,
47    LinalgAdRuleSupport, LinalgAdSupport,
48};
49pub use backend::LinalgBackend;
50#[cfg(feature = "autodiff")]
51pub use eager_ext::EagerTensorLinalgExt;
52pub use extension::{register_runtime, LINALG_EXTENSION_FAMILY_ID};
53pub use traced::TracedTensorLinalgExt;