Skip to main content

tenferro_ad/
lib.rs

1//! Automatic differentiation APIs for tenferro.
2//!
3//! This crate is the explicit opt-in boundary for traced and eager automatic
4//! differentiation. Primal graph construction and execution live in
5//! `tenferro-runtime`; tensor storage lives in `tenferro-tensor`, and CPU
6//! execution lives in `tenferro-cpu`.
7//!
8//! Use [`EagerRuntime`] and [`EagerTensor`] for PyTorch-style immediate
9//! execution where tracked variables accumulate gradients after `backward()`.
10//! Use [`TracedTensorAdExt`] or [`AdContext`] for JAX-style graph transforms
11//! such as `grad`, `vjp`, and `jvp` on [`tenferro_runtime::TracedTensor`]
12//! values. `AdContext` is the explicit place to add extension AD rule sets for
13//! operation-family crates such as `tenferro-linalg`.
14//!
15//! User-facing guides live at
16//! <https://tensor4all.org/tenferro-rs/guides/autodiff.html> and
17//! <https://tensor4all.org/tenferro-rs/guides/choosing-an-api.html>.
18//!
19//! # Examples
20//!
21//! ```rust
22//! use tenferro_ad::AdContext;
23//! use tenferro_runtime::TracedTensor;
24//!
25//! let ad = AdContext::builder().build().unwrap();
26//! let x = TracedTensor::from_vec_col_major(vec![], vec![3.0_f64]).unwrap();
27//! let loss = (&x * &x).unwrap();
28//! let dx = ad.grad(&loss, &x).unwrap();
29//! assert_eq!(dx.rank, 0);
30//! ```
31
32mod context;
33mod eager;
34mod eager_backend;
35mod eager_builder;
36pub(crate) mod eager_exec;
37pub(crate) mod eager_ops;
38pub(crate) mod eager_ops_elementwise;
39pub mod extension;
40mod shape_packing;
41pub mod traced;
42
43pub use context::{AdContext, AdContextBuilder};
44pub use eager::{EagerRuntime, EagerRuntimeCacheStats, EagerTensor};
45pub use eager_backend::EagerBackend;
46pub(crate) use tenferro_runtime::{extension_cache, extension_runtime, scalar_semantics};
47pub(crate) mod shape_infer {
48    pub use tenferro_runtime::extension::{
49        promote_dtype, promote_dtype_for_binary_op, promote_dtypes,
50    };
51}
52pub use tenferro_runtime::{
53    CompareDir, DType, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
54    Tensor,
55};
56pub use traced::TracedTensorAdExt;
57
58pub use tenferro_runtime::{ContextId, Error, Result};
59
60pub mod error {
61    pub use tenferro_runtime::{ContextId, Error, Result};
62}
63
64pub(crate) mod metadata {
65    pub use tenferro_runtime::ad_support::{
66        metadata_scopes_for_scope, push_metadata_scope, register_scoped_live_graph_metadata,
67        register_scoped_metadata_batch, register_scoped_value_metadata, tensor_meta_from_tensor,
68        GlobalMetadataScope,
69    };
70}