// tenferro_device/lib.rs

1//! Device abstraction and shared error types for the tenferro workspace.
2//!
3//! This crate provides:
4//! - [`LogicalMemorySpace`] enum representing where tensor data resides
5//! - [`ComputeDevice`] enum representing hardware compute devices
6//! - [`OpKind`] enum classifying tensor operations for device selection
7//! - [`preferred_compute_devices`] for querying compatible devices
8//! - [`Error`] and [`Result`] types used across all tenferro crates
9//!
10//! # Examples
11//!
12//! ```
13//! use tenferro_device::{LogicalMemorySpace, ComputeDevice};
14//!
15//! let space = LogicalMemorySpace::MainMemory;
16//! let dev = ComputeDevice::Cpu { device_id: 0 };
17//! assert_eq!(format!("{dev}"), "cpu:0");
18//! ```
19
20use std::fmt;
21
22mod batch_index;
23#[cfg(feature = "cuda")]
24pub mod cuda;
25mod generator;
26
/// Logical memory space in which a tensor's backing storage lives.
///
/// This is deliberately distinct from [`ComputeDevice`]: the former says
/// *where the bytes are*, the latter says *which hardware runs the kernels*.
/// Data in [`MainMemory`](LogicalMemorySpace::MainMemory) is reachable by
/// any CPU, whereas data in [`GpuMemory`](LogicalMemorySpace::GpuMemory)
/// is reachable only by compute devices with access to that GPU's memory.
///
/// Variants map onto DLPack `DLDeviceType` constants:
///
/// | Variant | DLPack `device_type` |
/// |---------|---------------------|
/// | `MainMemory` | `kDLCPU` (1) |
/// | `PinnedMemory` | `kDLCUDAHost` (3) / `kDLROCMHost` (11) |
/// | `GpuMemory` | `kDLCUDA` (2) / `kDLROCM` (10) — vendor determined by context |
/// | `ManagedMemory` | `kDLCUDAManaged` (13) |
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LogicalMemorySpace {
    /// Ordinary CPU-accessible system RAM.
    ///
    /// DLPack: `kDLCPU` with `device_id = 0`.
    MainMemory,
    /// Page-locked host memory (`cudaMallocHost` / `hipMallocHost`).
    ///
    /// Readable from the CPU, but set up for fast DMA transfer to and from
    /// a specific GPU. DLPack: `kDLCUDAHost` or `kDLROCMHost`
    /// (`device_id = 0`).
    PinnedMemory,
    /// Memory resident on a specific GPU.
    ///
    /// DLPack: `kDLCUDA` or `kDLROCM` with the stored `device_id`.
    GpuMemory {
        /// Zero-based GPU index; mirrors the DLPack `device_id` field.
        device_id: usize,
    },
    /// CUDA unified/managed memory (`cudaMallocManaged`).
    ///
    /// Visible to both the CPU and every GPU; the CUDA runtime migrates
    /// pages on demand. DLPack: `kDLCUDAManaged` (`device_id = 0`).
    ManagedMemory,
}
69
/// Hardware device capable of executing tensor operations.
///
/// Whereas [`LogicalMemorySpace`] describes where the data lives,
/// `ComputeDevice` names the hardware that does the work; several compute
/// devices may all have access to the same memory space.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// CPU compute device.
    Cpu {
        /// Zero-based CPU index (0 = default global thread pool).
        device_id: usize,
    },
    /// NVIDIA CUDA compute device.
    Cuda {
        /// Zero-based CUDA device index.
        device_id: usize,
    },
    /// AMD ROCm compute device.
    Rocm {
        /// Zero-based ROCm device index.
        device_id: usize,
    },
}

impl fmt::Display for ComputeDevice {
    /// Formats as `"<backend>:<device_id>"`, e.g. `"cuda:1"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (backend, id) = match self {
            Self::Cpu { device_id } => ("cpu", device_id),
            Self::Cuda { device_id } => ("cuda", device_id),
            Self::Rocm { device_id } => ("rocm", device_id),
        };
        write!(f, "{backend}:{id}")
    }
}
103
/// Kind of tensor operation, used when asking which compute devices are
/// preferred for an operation on a given memory space.
///
/// # Examples
///
/// ```
/// use tenferro_device::OpKind;
///
/// let op = OpKind::BatchedGemm;
/// assert_eq!(format!("{op:?}"), "BatchedGemm");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// General tensor contraction.
    Contract,
    /// Batched GEMM (matrix-matrix multiply).
    BatchedGemm,
    /// Reduction (sum, max, min) over one or more modes.
    Reduce,
    /// Trace (diagonal contraction of paired modes).
    Trace,
    /// Mode permutation (transpose).
    Permute,
    /// Element-wise multiplication.
    ElementwiseMul,
}
130
131/// Return the preferred compute devices for a given memory space.
132///
133/// Returns compute devices that can execute operations on tensors in the
134/// specified memory space. The selection follows these rules:
135///
136/// - [`MainMemory`](LogicalMemorySpace::MainMemory): Returns CPU devices
137/// - [`GpuMemory`](LogicalMemorySpace::GpuMemory): Returns the corresponding
138///   CUDA device (when compiled with `cuda` feature and device is available)
139/// - [`PinnedMemory`](LogicalMemorySpace::PinnedMemory): Prefers GPU compute
140///   if available, falls back to CPU
141/// - [`ManagedMemory`](LogicalMemorySpace::ManagedMemory): Prefers GPU compute
142///   if available, falls back to CPU
143///
144/// The returned list is ordered by preference (most preferred first).
145///
146/// # Errors
147///
148/// Returns [`Error::NoCompatibleComputeDevice`] if no compute device can
149/// execute the given operation on the specified memory space.
150///
151/// # Examples
152///
153/// ```
154/// use tenferro_device::{
155///     preferred_compute_devices, ComputeDevice, LogicalMemorySpace, OpKind,
156/// };
157///
158/// let devices = preferred_compute_devices(
159///     LogicalMemorySpace::MainMemory,
160///     OpKind::BatchedGemm,
161/// ).unwrap();
162///
163/// // Typically includes CPU for main memory workloads.
164/// assert!(devices.contains(&ComputeDevice::Cpu { device_id: 0 }));
165/// ```
166pub fn preferred_compute_devices(
167    space: LogicalMemorySpace,
168    _op_kind: OpKind,
169) -> Result<Vec<ComputeDevice>> {
170    match space {
171        LogicalMemorySpace::MainMemory => Ok(vec![ComputeDevice::Cpu { device_id: 0 }]),
172        LogicalMemorySpace::GpuMemory { device_id } => {
173            #[cfg(feature = "cuda")]
174            {
175                if is_cuda_device_available(device_id) {
176                    Ok(vec![ComputeDevice::Cuda { device_id }])
177                } else {
178                    Err(Error::NoCompatibleComputeDevice {
179                        space,
180                        op: _op_kind,
181                    })
182                }
183            }
184            #[cfg(not(feature = "cuda"))]
185            {
186                let _ = device_id;
187                Err(Error::NoCompatibleComputeDevice {
188                    space,
189                    op: _op_kind,
190                })
191            }
192        }
193        LogicalMemorySpace::PinnedMemory | LogicalMemorySpace::ManagedMemory => {
194            #[cfg(feature = "cuda")]
195            {
196                let mut devices = Vec::new();
197                if let Some(cuda_device) = first_available_cuda_device() {
198                    devices.push(cuda_device);
199                }
200                devices.push(ComputeDevice::Cpu { device_id: 0 });
201                Ok(devices)
202            }
203            #[cfg(not(feature = "cuda"))]
204            {
205                Ok(vec![ComputeDevice::Cpu { device_id: 0 }])
206            }
207        }
208    }
209}
210
/// Number of CUDA devices visible to the runtime; 0 on any failure.
///
/// The query is wrapped in `catch_unwind` so that a panicking CUDA runtime
/// (e.g. missing driver) degrades to "no devices" instead of aborting.
#[cfg(feature = "cuda")]
fn cuda_device_count() -> usize {
    let probe = || match cudarc::runtime::result::device::get_count() {
        Ok(count) => count as usize,
        Err(_) => 0,
    };
    std::panic::catch_unwind(probe).unwrap_or(0)
}
220
/// True when `device_id` refers to an existing CUDA device.
#[cfg(feature = "cuda")]
fn is_cuda_device_available(device_id: usize) -> bool {
    (0..cuda_device_count()).contains(&device_id)
}
225
/// First available CUDA compute device, or `None` when no device exists.
#[cfg(feature = "cuda")]
fn first_available_cuda_device() -> Option<ComputeDevice> {
    // Device 0 always exists when the count is non-zero.
    (cuda_device_count() > 0).then(|| ComputeDevice::Cuda { device_id: 0 })
}
235
/// Error type used across the tenferro workspace.
///
/// The user-facing `Display` messages are generated by `thiserror` from the
/// `#[error(...)]` attribute on each variant.
///
/// # Examples
///
/// ```
/// use tenferro_device::Error;
///
/// let err = Error::InvalidArgument("bad index".into());
/// assert!(err.to_string().contains("bad index"));
/// ```
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Tensor shapes are incompatible for the requested operation.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// Expected shape.
        expected: Vec<usize>,
        /// Actual shape.
        got: Vec<usize>,
    },

    /// Tensor ranks (number of dimensions) do not match.
    #[error("rank mismatch: expected {expected}, got {got}")]
    RankMismatch {
        /// Expected rank.
        expected: usize,
        /// Actual rank.
        got: usize,
    },

    /// An error occurred on the compute device.
    #[error("device error: {0}")]
    DeviceError(String),

    /// No compute device is compatible with the requested operation
    /// on the given memory space.
    #[error("no compatible compute device for {op:?} on {space:?}")]
    NoCompatibleComputeDevice {
        /// The memory space where the tensor data resides.
        space: LogicalMemorySpace,
        /// The operation that was requested.
        op: OpKind,
    },

    /// Operations on tensors in different memory spaces are not supported
    /// without explicit transfer.
    ///
    /// Reserved for future cross-device operations (e.g., GPU-CPU transfers).
    // NOTE(review): `#[allow(dead_code)]` has no effect on a variant of a
    // `pub` enum (the dead-code lint does not fire for public items) —
    // consider removing it.
    #[allow(dead_code)]
    #[error("cross-memory-space operation between {left:?} and {right:?}")]
    CrossMemorySpaceOperation {
        /// Memory space of the first operand.
        left: LogicalMemorySpace,
        /// Memory space of the second operand.
        right: LogicalMemorySpace,
    },

    /// An invalid argument was provided.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// An error related to strided memory layout (invalid strides, out-of-bounds offset, etc.).
    #[error("stride error: {0}")]
    StrideError(String),
}
301
/// Result type alias using [`Error`].
///
/// All fallible APIs in the tenferro crates return this alias so callers
/// only need a single error type.
///
/// # Examples
///
/// ```
/// use tenferro_device::Result;
///
/// fn compute() -> Result<usize> {
///     Ok(42)
/// }
/// assert_eq!(compute().unwrap(), 42);
/// ```
pub type Result<T> = std::result::Result<T, Error>;
315
316#[doc(hidden)]
317pub use batch_index::{
318    broadcast_batch_dims, checked_batch_count, flatten_col_major_index,
319    unflatten_col_major_index_into, BroadcastBatchIndexer,
320};
321pub use generator::{with_default_generator, Generator};
322
#[cfg(test)]
mod tests {
    use super::*;

    // Feature-independent behavior: main memory always resolves to the
    // default CPU device, regardless of op kind.

    #[test]
    fn test_preferred_compute_devices_main_memory() {
        let devices =
            preferred_compute_devices(LogicalMemorySpace::MainMemory, OpKind::BatchedGemm).unwrap();
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0], ComputeDevice::Cpu { device_id: 0 });
    }

    #[test]
    fn test_preferred_compute_devices_main_memory_all_ops() {
        // Exhaustive over all OpKind variants; keep in sync when adding one.
        for op in [
            OpKind::Contract,
            OpKind::BatchedGemm,
            OpKind::Reduce,
            OpKind::Trace,
            OpKind::Permute,
            OpKind::ElementwiseMul,
        ] {
            let devices = preferred_compute_devices(LogicalMemorySpace::MainMemory, op).unwrap();
            assert!(
                devices.contains(&ComputeDevice::Cpu { device_id: 0 }),
                "CPU should be available for {:?}",
                op
            );
        }
    }

    // Builds WITHOUT the `cuda` feature: GPU memory has no compute device,
    // pinned/managed memory fall back to CPU only.

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_preferred_compute_devices_gpu_memory_without_cuda_feature() {
        let result = preferred_compute_devices(
            LogicalMemorySpace::GpuMemory { device_id: 0 },
            OpKind::BatchedGemm,
        );
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::NoCompatibleComputeDevice { space, op } => {
                assert_eq!(space, LogicalMemorySpace::GpuMemory { device_id: 0 });
                assert_eq!(op, OpKind::BatchedGemm);
            }
            _ => panic!("Expected NoCompatibleComputeDevice error"),
        }
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_preferred_compute_devices_pinned_memory_without_cuda_feature() {
        let devices =
            preferred_compute_devices(LogicalMemorySpace::PinnedMemory, OpKind::BatchedGemm)
                .unwrap();
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0], ComputeDevice::Cpu { device_id: 0 });
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_preferred_compute_devices_managed_memory_without_cuda_feature() {
        let devices =
            preferred_compute_devices(LogicalMemorySpace::ManagedMemory, OpKind::BatchedGemm)
                .unwrap();
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0], ComputeDevice::Cpu { device_id: 0 });
    }

    // Builds WITH the `cuda` feature: assertions about CUDA devices are
    // guarded by a runtime device-count check so the tests also pass on
    // machines without a GPU.

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_gpu_memory_with_cuda_available() {
        if cuda_device_count() > 0 {
            let devices = preferred_compute_devices(
                LogicalMemorySpace::GpuMemory { device_id: 0 },
                OpKind::BatchedGemm,
            )
            .unwrap();
            assert_eq!(devices.len(), 1);
            assert_eq!(devices[0], ComputeDevice::Cuda { device_id: 0 });
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_gpu_memory_invalid_device() {
        // An id past the device count must be rejected even when CUDA works.
        let invalid_device_id = cuda_device_count() + 100;
        let result = preferred_compute_devices(
            LogicalMemorySpace::GpuMemory {
                device_id: invalid_device_id,
            },
            OpKind::BatchedGemm,
        );
        assert!(result.is_err());
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_pinned_memory_with_cuda() {
        let devices =
            preferred_compute_devices(LogicalMemorySpace::PinnedMemory, OpKind::BatchedGemm)
                .unwrap();
        if cuda_device_count() > 0 {
            assert!(
                devices.contains(&ComputeDevice::Cuda { device_id: 0 }),
                "CUDA device should be preferred for pinned memory"
            );
        }
        assert!(
            devices.contains(&ComputeDevice::Cpu { device_id: 0 }),
            "CPU should be fallback for pinned memory"
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_managed_memory_with_cuda() {
        let devices =
            preferred_compute_devices(LogicalMemorySpace::ManagedMemory, OpKind::BatchedGemm)
                .unwrap();
        if cuda_device_count() > 0 {
            assert!(
                devices.contains(&ComputeDevice::Cuda { device_id: 0 }),
                "CUDA device should be preferred for managed memory"
            );
        }
        assert!(
            devices.contains(&ComputeDevice::Cpu { device_id: 0 }),
            "CPU should be fallback for managed memory"
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_device_count_safe() {
        // Sanity check that the panic-guarded probe returns a plausible value.
        let count = cuda_device_count();
        assert!(count < 1000, "device count should be reasonable");
    }

    // Plain type-behavior tests (Display, equality, error messages).

    #[test]
    fn test_compute_device_display() {
        assert_eq!(format!("{}", ComputeDevice::Cpu { device_id: 0 }), "cpu:0");
        assert_eq!(
            format!("{}", ComputeDevice::Cuda { device_id: 2 }),
            "cuda:2"
        );
        assert_eq!(
            format!("{}", ComputeDevice::Rocm { device_id: 1 }),
            "rocm:1"
        );
    }

    #[test]
    fn test_logical_memory_space_equality() {
        assert_eq!(
            LogicalMemorySpace::GpuMemory { device_id: 0 },
            LogicalMemorySpace::GpuMemory { device_id: 0 }
        );
        assert_ne!(
            LogicalMemorySpace::GpuMemory { device_id: 0 },
            LogicalMemorySpace::GpuMemory { device_id: 1 }
        );
        assert_ne!(
            LogicalMemorySpace::GpuMemory { device_id: 0 },
            LogicalMemorySpace::MainMemory
        );
    }

    #[test]
    fn test_error_display() {
        let err = Error::NoCompatibleComputeDevice {
            space: LogicalMemorySpace::GpuMemory { device_id: 0 },
            op: OpKind::BatchedGemm,
        };
        let msg = err.to_string();
        assert!(msg.contains("no compatible compute device"));
        assert!(msg.contains("BatchedGemm"));
    }
}