Skip to main content

tenferro/
lib.rs

1#![allow(clippy::multiple_bound_locations)]
2
3//! `tenferro`: traced tensor computation with StableHLO-style IR.
4//!
5//! This crate provides a tracing-based tensor computation framework where
6//! operations are recorded into a StableHLO-compatible intermediate
7//! representation, then compiled and executed on a backend (e.g., CPU).
8//!
9//! # Examples
10//!
11//! ```rust,ignore
12//! use tenferro::{CpuBackend, Engine, TracedTensor};
13//!
14//! let mut engine = Engine::new(CpuBackend::default());
15//! // ... build and execute traced computations
16//! ```
17
18pub use tenferro_tensor::{DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig};
19
20mod checkpoint;
21pub mod compiler;
22mod eager;
23pub mod eager_einsum;
24mod eager_emitter;
25pub mod eager_exec;
26pub(crate) mod eager_ops;
27pub(crate) mod eager_ops_elementwise;
28pub(crate) mod eager_ops_linalg;
29pub mod einsum;
30pub mod engine;
31pub mod error;
32pub mod exec;
33mod linalg_api;
34pub mod segment;
35pub mod shape_infer;
36pub mod sym_dim;
37pub mod traced;
38
39pub use eager::{EagerContext, EagerTensor};
40pub use engine::Engine;
41pub use linalg_api::{
42    cholesky, convert, det, eig, eigh, eigh_with_eps, eigvals, eigvalsh, inv, lu, norm, pinv,
43    pinv_with_rtol, qr, slogdet, solve, svd, svd_with_eps, triangular_solve,
44};
45pub use sym_dim::SymDim;
46pub use tenferro_tensor::cpu::CpuBackend;
47pub use tenferro_tensor::{DType, Tensor, TensorBackend, TensorScalar, TypedTensor};
48pub use traced::TracedTensor;
49
50/// Matrix multiplication helper for rank-2 traced tensors.
51///
52/// This contracts the last dimension of `a` with the first dimension of `b`.
53///
54/// # Examples
55///
56/// ```rust,ignore
57/// let c = tenferro::matmul(&a, &b);
58/// ```
59pub fn matmul(a: &TracedTensor, b: &TracedTensor) -> TracedTensor {
60    let config = DotGeneralConfig {
61        lhs_contracting_dims: vec![a.rank - 1],
62        rhs_contracting_dims: vec![0],
63        lhs_batch_dims: vec![],
64        rhs_batch_dims: vec![],
65        lhs_rank: a.rank,
66        rhs_rank: b.rank,
67    };
68    a.dot_general(b, config)
69}
70
/// Elementwise power helper with NumPy-style broadcasting.
///
/// Free-function form of the `pow` method: it forwards to `base.pow(exp)` and
/// returns that result unchanged, so broadcasting behavior is whatever that
/// method implements.
///
/// # Examples
///
/// ```rust,ignore
/// let y = tenferro::pow(&base, &exp);
/// ```
pub fn pow(base: &TracedTensor, exp: &TracedTensor) -> TracedTensor {
    base.pow(exp)
}