tenferro/
lib.rs

1#![allow(clippy::multiple_bound_locations)]
2
3//! `tenferro`: AD-aware tensor interface layer on top of `tenferro-rs`.
4//!
5//! The public surface includes reverse-mode helpers plus a narrow
6//! Jacobian-vector product (JVP) transform, [`jvp`], which returns [`JvpResult`]
7//! with primal outputs and optional output tangents for the currently wired operations.
8//!
9//! # Examples
10//!
11//! ```rust
12//! use tenferro::{jvp, Tensor};
13//!
14//! let x = Tensor::from_slice(&[1.0_f64, 2.0], &[2]).unwrap();
15//! let primals = [x];
16//! let tangents = [Some(Tensor::from_slice(&[1.0_f64, 0.0], &[2]).unwrap())];
17//!
18//! let result = jvp(
19//!     |inputs| Ok(vec![inputs[0].add(&inputs[0]).unwrap().exp().unwrap().sum().unwrap()]),
20//!     &primals,
21//!     &tangents,
22//! )
23//! .unwrap();
24//!
25//! assert_eq!(result.outputs.len(), 1);
26//! assert_eq!(result.output_tangents.len(), 1);
27//! ```
28//!
29//! Builder `.run()` execution is configured through [`set_default_runtime`].
30//! The primary public frontend is [`Tensor`], backed by `tidu`'s
31//! `Value<DynTensor>` carrier. Custom downstream operations should use
32//! [`LinearizableOp`] and [`LinearizedOp`] directly; `jvp` is not a public
33//! dual-builder API and does not imply higher-order forward-mode or
34//! Hessian-vector product (HVP) support.
35//!
36//! Runtime-dispatched tensor operations such as [`Tensor::einsum`],
37//! [`Tensor::solve`], [`Tensor::solve_triangular`], [`Tensor::det`],
38//! [`Tensor::inv`], [`Tensor::slogdet`], [`Tensor::cholesky`],
39//! [`Tensor::lstsq`], [`Tensor::lu`], [`Tensor::norm`],
40//! [`Tensor::vector_norm`], [`Tensor::matrix_norm`], [`Tensor::qr`],
41//! [`Tensor::svd`], [`Tensor::eig`], [`Tensor::eigh`], [`Tensor::pinv`],
42//! and [`Tensor::matrix_exp`] require an installed runtime via
43//! [`set_default_runtime`] or [`runtime::with_runtime`].
44
45mod core;
46pub mod error;
47#[path = "jvp.rs"]
48mod jvp_api;
49pub mod runtime;
50mod scalar_value;
51pub mod snapshot;
52
53pub use core::{
54    CholeskyExResult, EigResult, EighResult, InvExResult, LstsqResult, LuFactorExResult,
55    LuFactorResult, LuPivot, LuResult, QrResult, ScalarType, SlogdetResult, SolveExResult,
56    SvdResult, Tensor,
57};
58pub use error::{Error, Result};
59pub use jvp_api::{jvp, JvpResult};
60pub use runtime::{set_default_runtime, with_default_runtime, DefaultRuntimeGuard, RuntimeContext};
61pub use scalar_value::ScalarValue;
62pub use tenferro_device::{ComputeDevice, LogicalMemorySpace};
63pub use tenferro_internal_ad_surface::{
64    backward, grad, with_ad_policy, AdExecutionPolicy, BackwardOptions, CheckpointHint,
65    CheckpointMode, GradOptions, LinearizableOp, LinearizedOp, MatrixNormOrd, NormKind, Schema,
66    SlotSchema, SvdOptions, Value, VectorNormOrd,
67};
68pub use tenferro_tensor::MemoryOrder;