tenferro_device/lib.rs
//! Device abstraction and shared error types for the tenferro workspace.
//!
//! This crate provides:
//! - [`LogicalMemorySpace`] enum representing where tensor data resides
//! - [`ComputeDevice`] enum representing hardware compute devices
//! - [`OpKind`] enum classifying tensor operations for device selection
//! - [`preferred_compute_devices`] for querying compatible devices
//! - [`Error`] and [`Result`] types used across all tenferro crates
//!
//! # Examples
//!
//! ```
//! use tenferro_device::{LogicalMemorySpace, ComputeDevice};
//!
//! let space = LogicalMemorySpace::MainMemory;
//! let dev = ComputeDevice::Cpu { device_id: 0 };
//! assert_eq!(format!("{dev}"), "cpu:0");
//! ```

use std::fmt;

/// Logical memory space where tensor data resides.
///
/// Separates the concept of "where data lives" from "which hardware
/// computes on it". A tensor on [`MainMemory`](LogicalMemorySpace::MainMemory)
/// can be processed by any CPU, while a tensor on
/// [`GpuMemory`](LogicalMemorySpace::GpuMemory) can be processed by any
/// compute device with access to that GPU memory space.
///
/// The variants align with DLPack `DLDeviceType` constants:
///
/// | Variant | DLPack `device_type` |
/// |---------|----------------------|
/// | `MainMemory` | `kDLCPU` (1) |
/// | `PinnedMemory` | `kDLCUDAHost` (3) / `kDLROCMHost` (11) |
/// | `GpuMemory` | `kDLCUDA` (2) / `kDLROCM` (10) |
/// | `ManagedMemory` | `kDLCUDAManaged` (13) |
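///
/// # Examples
///
/// A minimal sketch of mapping a memory space to its DLPack `device_type`
/// code, assuming the CUDA flavor for GPU-backed spaces. The helper below is
/// purely illustrative and not part of this crate's API; the values come from
/// the table above.
///
/// ```
/// use tenferro_device::LogicalMemorySpace;
///
/// // Hypothetical helper for illustration only.
/// fn dlpack_device_type(space: LogicalMemorySpace) -> i32 {
///     match space {
///         LogicalMemorySpace::MainMemory => 1,       // kDLCPU
///         LogicalMemorySpace::PinnedMemory => 3,     // kDLCUDAHost
///         LogicalMemorySpace::GpuMemory { .. } => 2, // kDLCUDA
///         LogicalMemorySpace::ManagedMemory => 13,   // kDLCUDAManaged
///     }
/// }
///
/// assert_eq!(dlpack_device_type(LogicalMemorySpace::GpuMemory { device_id: 1 }), 2);
/// ```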
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LogicalMemorySpace {
    /// System main memory (CPU-accessible RAM).
    ///
    /// Corresponds to DLPack `kDLCPU` (`device_id = 0`).
    MainMemory,
    /// Host memory pinned for fast GPU transfer (`cudaMallocHost` / `hipMallocHost`).
    ///
    /// Accessible from the CPU but optimized for DMA transfer to/from
    /// a specific GPU. Corresponds to DLPack `kDLCUDAHost` or `kDLROCMHost`
    /// (`device_id = 0`).
    PinnedMemory,
    /// GPU-resident memory identified by a device ID.
    ///
    /// Corresponds to DLPack `kDLCUDA` or `kDLROCM` with the given `device_id`.
    GpuMemory {
        /// Zero-based GPU device index (matches DLPack `device_id`).
        device_id: usize,
    },
    /// CUDA Unified/Managed memory (`cudaMallocManaged`).
    ///
    /// Accessible from both CPU and all GPUs. The CUDA runtime handles
    /// page migration automatically. Corresponds to DLPack `kDLCUDAManaged`
    /// (`device_id = 0`).
    ManagedMemory,
}

/// Compute device that can execute tensor operations.
///
/// Unlike [`LogicalMemorySpace`], which describes where data resides,
/// `ComputeDevice` identifies the hardware that performs the computation.
/// Multiple compute devices may share access to the same memory space.
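///
/// # Examples
///
/// Devices render as `backend:id` strings through their `Display`
/// implementation (defined further down in this module):
///
/// ```
/// use tenferro_device::ComputeDevice;
///
/// assert_eq!(ComputeDevice::Cuda { device_id: 1 }.to_string(), "cuda:1");
/// assert_eq!(ComputeDevice::Hip { device_id: 0 }.to_string(), "hip:0");
/// ```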
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// CPU compute device.
    Cpu {
        /// Zero-based CPU device index (0 = default global thread pool).
        device_id: usize,
    },
    /// NVIDIA CUDA compute device.
    Cuda {
        /// Zero-based CUDA device index.
        device_id: usize,
    },
    /// AMD HIP compute device.
    Hip {
        /// Zero-based HIP device index.
        device_id: usize,
    },
}

impl fmt::Display for ComputeDevice {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ComputeDevice::Cpu { device_id } => write!(f, "cpu:{device_id}"),
            ComputeDevice::Cuda { device_id } => write!(f, "cuda:{device_id}"),
            ComputeDevice::Hip { device_id } => write!(f, "hip:{device_id}"),
        }
    }
}

/// Classification of tensor operations, used to query preferred compute
/// devices for a given operation on a given memory space.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// General tensor contraction.
    Contract,
    /// Batched GEMM (matrix-matrix multiply).
    BatchedGemm,
    /// Reduction (sum, max, min) over one or more modes.
    Reduce,
    /// Trace (diagonal contraction of paired modes).
    Trace,
    /// Mode permutation (transpose).
    Permute,
    /// Element-wise multiplication.
    ElementwiseMul,
}

/// Return the preferred compute devices for a given operation on a memory space.
///
/// The returned list is ordered by preference (most preferred first).
///
/// # Errors
///
/// Returns [`Error::NoCompatibleComputeDevice`] if no compute device can
/// execute the given operation on the specified memory space.
///
/// # Examples
///
/// ```ignore
/// use tenferro_device::{
///     preferred_compute_devices, ComputeDevice, LogicalMemorySpace, OpKind,
/// };
///
/// let devices = preferred_compute_devices(
///     LogicalMemorySpace::MainMemory,
///     OpKind::BatchedGemm,
/// ).unwrap();
///
/// // Typically includes CPU for main memory workloads.
/// assert!(devices.contains(&ComputeDevice::Cpu { device_id: 0 }));
/// ```
pub fn preferred_compute_devices(
    _space: LogicalMemorySpace,
    _op_kind: OpKind,
) -> Result<Vec<ComputeDevice>> {
    todo!()
}

/// Error type used across the tenferro workspace.
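///
/// # Examples
///
/// Display messages are derived via `thiserror`; for instance, a shape
/// mismatch renders the expected and actual shapes:
///
/// ```
/// use tenferro_device::Error;
///
/// let err = Error::ShapeMismatch { expected: vec![2, 3], got: vec![3, 2] };
/// assert_eq!(err.to_string(), "shape mismatch: expected [2, 3], got [3, 2]");
/// ```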
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Tensor shapes are incompatible for the requested operation.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// Expected shape.
        expected: Vec<usize>,
        /// Actual shape.
        got: Vec<usize>,
    },

    /// Tensor ranks (number of dimensions) do not match.
    #[error("rank mismatch: expected {expected}, got {got}")]
    RankMismatch {
        /// Expected rank.
        expected: usize,
        /// Actual rank.
        got: usize,
    },

    /// An error occurred on the compute device.
    #[error("device error: {0}")]
    DeviceError(String),

    /// No compute device is compatible with the requested operation
    /// on the given memory space.
    #[error("no compatible compute device for {op:?} on {space:?}")]
    NoCompatibleComputeDevice {
        /// The memory space where the tensor data resides.
        space: LogicalMemorySpace,
        /// The operation that was requested.
        op: OpKind,
    },

    /// Operations on tensors in different memory spaces are not supported
    /// without explicit transfer.
    #[error("cross-memory-space operation between {left:?} and {right:?}")]
    CrossMemorySpaceOperation {
        /// Memory space of the first operand.
        left: LogicalMemorySpace,
        /// Memory space of the second operand.
        right: LogicalMemorySpace,
    },

    /// An invalid argument was provided.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// An error related to strided memory layout (invalid strides, out-of-bounds offset, etc.).
    #[error("stride error: {0}")]
    StrideError(String),
}

/// Result type alias using [`Error`].
pub type Result<T> = std::result::Result<T, Error>;