1#![allow(clippy::multiple_bound_locations)]
2
3pub use tenferro_tensor::{DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig};
19
20mod checkpoint;
21pub mod compiler;
22mod eager;
23pub mod eager_einsum;
24mod eager_emitter;
25pub mod eager_exec;
26pub(crate) mod eager_ops;
27pub(crate) mod eager_ops_elementwise;
28pub(crate) mod eager_ops_linalg;
29pub mod einsum;
30pub mod engine;
31pub mod error;
32pub mod exec;
33mod linalg_api;
34pub mod segment;
35pub mod shape_infer;
36pub mod sym_dim;
37pub mod traced;
38
39pub use eager::{EagerContext, EagerTensor};
40pub use engine::Engine;
41pub use linalg_api::{
42 cholesky, convert, det, eig, eigh, eigh_with_eps, eigvals, eigvalsh, inv, lu, norm, pinv,
43 pinv_with_rtol, qr, slogdet, solve, svd, svd_with_eps, triangular_solve,
44};
45pub use sym_dim::SymDim;
46pub use tenferro_tensor::cpu::CpuBackend;
47pub use tenferro_tensor::{DType, Tensor, TensorBackend, TensorScalar, TypedTensor};
48pub use traced::TracedTensor;
49
50pub fn matmul(a: &TracedTensor, b: &TracedTensor) -> TracedTensor {
60 let config = DotGeneralConfig {
61 lhs_contracting_dims: vec![a.rank - 1],
62 rhs_contracting_dims: vec![0],
63 lhs_batch_dims: vec![],
64 rhs_batch_dims: vec![],
65 lhs_rank: a.rank,
66 rhs_rank: b.rank,
67 };
68 a.dot_general(b, config)
69}
70
71pub fn pow(base: &TracedTensor, exp: &TracedTensor) -> TracedTensor {
79 base.pow(exp)
80}