tenferro_device/
lib.rs

//! Device abstraction and shared error types for the tenferro workspace.
//!
//! This crate provides:
//! - [`LogicalMemorySpace`] enum representing where tensor data resides
//! - [`ComputeDevice`] enum representing hardware compute devices
//! - [`OpKind`] enum classifying tensor operations for device selection
//! - [`preferred_compute_devices`] for querying compatible devices
//! - [`Error`] and [`Result`] types used across all tenferro crates
//!
//! # Examples
//!
//! ```
//! use tenferro_device::{LogicalMemorySpace, ComputeDevice};
//!
//! let space = LogicalMemorySpace::MainMemory;
//! let dev = ComputeDevice::Cpu { device_id: 0 };
//! assert_eq!(format!("{dev}"), "cpu:0");
//! ```

use std::fmt;

/// Logical memory space where tensor data resides.
///
/// Separates the concept of "where data lives" from "which hardware
/// computes on it". A tensor on [`MainMemory`](LogicalMemorySpace::MainMemory)
/// can be processed by any CPU, while a tensor on
/// [`GpuMemory`](LogicalMemorySpace::GpuMemory) can be processed by any
/// compute device with access to that GPU memory space.
///
/// The variants align with DLPack `DLDeviceType` constants:
///
/// | Variant | DLPack `device_type` |
/// |---------|---------------------|
/// | `MainMemory` | `kDLCPU` (1) |
/// | `PinnedMemory` | `kDLCUDAHost` (3) / `kDLROCMHost` (11) |
/// | `GpuMemory` | `kDLCUDA` (2) / `kDLROCM` (10) |
/// | `ManagedMemory` | `kDLCUDAManaged` (13) |
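///
/// # Examples
///
/// The sketch below is illustrative only: it shows how a caller might
/// branch on the memory space, e.g. to decide whether a host-side copy is
/// needed. The decisions in the match arms are not part of this crate.
///
/// ```
/// use tenferro_device::LogicalMemorySpace;
///
/// let space = LogicalMemorySpace::GpuMemory { device_id: 1 };
///
/// // Main, pinned, and managed memory are CPU-accessible; GPU-resident
/// // buffers would need an explicit transfer first.
/// let cpu_accessible = match space {
///     LogicalMemorySpace::MainMemory
///     | LogicalMemorySpace::PinnedMemory
///     | LogicalMemorySpace::ManagedMemory => true,
///     LogicalMemorySpace::GpuMemory { .. } => false,
/// };
/// assert!(!cpu_accessible);
/// ```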
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LogicalMemorySpace {
    /// System main memory (CPU-accessible RAM).
    ///
    /// Corresponds to DLPack `kDLCPU` (`device_id = 0`).
    MainMemory,
    /// Host memory pinned for fast GPU transfer (`cudaMallocHost` / `hipMallocHost`).
    ///
    /// Accessible from the CPU but optimized for DMA transfer to/from
    /// a specific GPU. Corresponds to DLPack `kDLCUDAHost` or `kDLROCMHost`
    /// (`device_id = 0`).
    PinnedMemory,
    /// GPU-resident memory identified by a device ID.
    ///
    /// Corresponds to DLPack `kDLCUDA` or `kDLROCM` with the given `device_id`.
    GpuMemory {
        /// Zero-based GPU device index (matches DLPack `device_id`).
        device_id: usize,
    },
    /// CUDA Unified/Managed memory (`cudaMallocManaged`).
    ///
    /// Accessible from both CPU and all GPUs. The CUDA runtime handles
    /// page migration automatically. Corresponds to DLPack `kDLCUDAManaged`
    /// (`device_id = 0`).
    ManagedMemory,
}

/// Compute device that can execute tensor operations.
///
/// Unlike [`LogicalMemorySpace`], which describes where data resides,
/// `ComputeDevice` identifies the hardware that performs the computation.
/// Multiple compute devices may share access to the same memory space.
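///
/// # Examples
///
/// Constructing devices and rendering their `Display` form; the
/// `cpu:`/`cuda:`/`hip:` prefixes come from this crate's `Display` impl:
///
/// ```
/// use tenferro_device::ComputeDevice;
///
/// let cpu = ComputeDevice::Cpu { device_id: 0 };
/// let gpu = ComputeDevice::Cuda { device_id: 1 };
/// let amd = ComputeDevice::Hip { device_id: 0 };
///
/// assert_eq!(cpu.to_string(), "cpu:0");
/// assert_eq!(gpu.to_string(), "cuda:1");
/// assert_eq!(amd.to_string(), "hip:0");
/// ```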
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// CPU compute device.
    Cpu {
        /// Zero-based CPU device index (0 = default global thread pool).
        device_id: usize,
    },
    /// NVIDIA CUDA compute device.
    Cuda {
        /// Zero-based CUDA device index.
        device_id: usize,
    },
    /// AMD HIP compute device.
    Hip {
        /// Zero-based HIP device index.
        device_id: usize,
    },
}

impl fmt::Display for ComputeDevice {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ComputeDevice::Cpu { device_id } => write!(f, "cpu:{device_id}"),
            ComputeDevice::Cuda { device_id } => write!(f, "cuda:{device_id}"),
            ComputeDevice::Hip { device_id } => write!(f, "hip:{device_id}"),
        }
    }
}

/// Classification of tensor operations, used to query preferred compute
/// devices for a given operation on a given memory space.
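///
/// # Examples
///
/// `OpKind` is `Copy + Eq + Hash`, so it can key lookup tables. The tuning
/// table below is a hypothetical illustration, not part of this crate:
///
/// ```
/// use std::collections::HashMap;
/// use tenferro_device::OpKind;
///
/// // Hypothetical per-operation tile sizes keyed by operation kind.
/// let mut tile_sizes: HashMap<OpKind, usize> = HashMap::new();
/// tile_sizes.insert(OpKind::BatchedGemm, 64);
/// tile_sizes.insert(OpKind::Permute, 32);
///
/// assert_eq!(tile_sizes.get(&OpKind::BatchedGemm), Some(&64));
/// ```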
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// General tensor contraction.
    Contract,
    /// Batched GEMM (matrix-matrix multiply).
    BatchedGemm,
    /// Reduction (sum, max, min) over one or more modes.
    Reduce,
    /// Trace (diagonal contraction of paired modes).
    Trace,
    /// Mode permutation (transpose).
    Permute,
    /// Element-wise multiplication.
    ElementwiseMul,
}

/// Return the preferred compute devices for a given operation on a memory space.
///
/// The returned list is ordered by preference (most preferred first).
///
/// # Errors
///
/// Returns [`Error::NoCompatibleComputeDevice`] if no compute device can
/// execute the given operation on the specified memory space.
///
/// # Examples
///
/// ```ignore
/// use tenferro_device::{
///     preferred_compute_devices, ComputeDevice, LogicalMemorySpace, OpKind,
/// };
///
/// let devices = preferred_compute_devices(
///     LogicalMemorySpace::MainMemory,
///     OpKind::BatchedGemm,
/// ).unwrap();
///
/// // Typically includes CPU for main memory workloads.
/// assert!(devices.contains(&ComputeDevice::Cpu { device_id: 0 }));
/// ```
pub fn preferred_compute_devices(
    _space: LogicalMemorySpace,
    _op_kind: OpKind,
) -> Result<Vec<ComputeDevice>> {
    todo!()
}

/// Error type used across the tenferro workspace.
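///
/// # Examples
///
/// Display messages are generated by `thiserror` from the `#[error(...)]`
/// attribute on each variant:
///
/// ```
/// use tenferro_device::Error;
///
/// let err = Error::ShapeMismatch {
///     expected: vec![2, 3],
///     got: vec![3, 2],
/// };
/// assert_eq!(err.to_string(), "shape mismatch: expected [2, 3], got [3, 2]");
/// ```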
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Tensor shapes are incompatible for the requested operation.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// Expected shape.
        expected: Vec<usize>,
        /// Actual shape.
        got: Vec<usize>,
    },

    /// Tensor ranks (number of dimensions) do not match.
    #[error("rank mismatch: expected {expected}, got {got}")]
    RankMismatch {
        /// Expected rank.
        expected: usize,
        /// Actual rank.
        got: usize,
    },

    /// An error occurred on the compute device.
    #[error("device error: {0}")]
    DeviceError(String),

    /// No compute device is compatible with the requested operation
    /// on the given memory space.
    #[error("no compatible compute device for {op:?} on {space:?}")]
    NoCompatibleComputeDevice {
        /// The memory space where the tensor data resides.
        space: LogicalMemorySpace,
        /// The operation that was requested.
        op: OpKind,
    },

    /// Operations on tensors in different memory spaces are not supported
    /// without explicit transfer.
    #[error("cross-memory-space operation between {left:?} and {right:?}")]
    CrossMemorySpaceOperation {
        /// Memory space of the first operand.
        left: LogicalMemorySpace,
        /// Memory space of the second operand.
        right: LogicalMemorySpace,
    },

    /// An invalid argument was provided.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// An error related to strided memory layout (invalid strides, out-of-bounds offset, etc.).
    #[error("stride error: {0}")]
    StrideError(String),
}

/// Result type alias using [`Error`].
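///
/// # Examples
///
/// A minimal sketch of returning this alias from a fallible helper; the
/// `check_rank` function is hypothetical and not part of this crate:
///
/// ```
/// use tenferro_device::{Error, Result};
///
/// // Hypothetical validation helper, shown only to illustrate the alias.
/// fn check_rank(rank: usize) -> Result<()> {
///     if rank == 0 {
///         return Err(Error::InvalidArgument("rank must be non-zero".to_string()));
///     }
///     Ok(())
/// }
///
/// assert!(check_rank(2).is_ok());
/// assert!(check_rank(0).is_err());
/// ```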
pub type Result<T> = std::result::Result<T, Error>;