1use std::fmt;
21
22mod batch_index;
23#[cfg(feature = "cuda")]
24pub mod cuda;
25mod generator;
26
/// Tag describing which kind of memory a buffer logically lives in.
///
/// The value is a small `Copy` tag that can be compared and hashed; two
/// `GpuMemory` values are equal only when their `device_id`s are equal.
///
/// NOTE(review): the hardware meaning of each variant is inferred from its
/// name — confirm against the allocator code. What this file does establish
/// is how `preferred_compute_devices` treats each variant, as described below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LogicalMemorySpace {
    /// Main (host) memory. `preferred_compute_devices` maps it to CPU 0 only.
    MainMemory,
    /// Pinned memory. `preferred_compute_devices` offers CUDA device 0 first
    /// (when the `cuda` feature is on and a device is found) and CPU 0 as the
    /// fallback; otherwise CPU 0 alone.
    PinnedMemory,
    /// Memory tied to one GPU, identified by `device_id`.
    /// `preferred_compute_devices` either returns the CUDA device with the
    /// same id or fails with `Error::NoCompatibleComputeDevice`.
    GpuMemory { device_id: usize },
    /// Managed memory. Handled exactly like `PinnedMemory` by
    /// `preferred_compute_devices`.
    ManagedMemory,
}
69
/// A device that can execute an operation, identified by backend and index.
///
/// The value is a small `Copy` tag that can be compared and hashed. Its
/// `Display` form is `<backend>:<device_id>` (for example `cuda:2`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// A CPU device. `preferred_compute_devices` only ever produces
    /// `device_id: 0` for this variant.
    Cpu { device_id: usize },
    /// A CUDA device with the given index.
    Cuda { device_id: usize },
    /// A ROCm device with the given index. Nothing in the visible part of
    /// this file returns this variant from `preferred_compute_devices`.
    Rocm { device_id: usize },
}
93
94impl fmt::Display for ComputeDevice {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 match self {
97 ComputeDevice::Cpu { device_id } => write!(f, "cpu:{device_id}"),
98 ComputeDevice::Cuda { device_id } => write!(f, "cuda:{device_id}"),
99 ComputeDevice::Rocm { device_id } => write!(f, "rocm:{device_id}"),
100 }
101 }
102}
103
/// Tag naming the operation a compute device is being selected for.
///
/// In the visible part of this file the value does not influence which
/// devices `preferred_compute_devices` returns; it is only carried into
/// `Error::NoCompatibleComputeDevice` so the failure can name the operation.
///
/// NOTE(review): the one-line variant descriptions below are read off the
/// variant names — confirm against the kernels that consume them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Contraction.
    Contract,
    /// Batched GEMM.
    BatchedGemm,
    /// Reduction.
    Reduce,
    /// Trace.
    Trace,
    /// Permutation.
    Permute,
    /// Element-wise multiplication.
    ElementwiseMul,
}
130
131pub fn preferred_compute_devices(
167 space: LogicalMemorySpace,
168 _op_kind: OpKind,
169) -> Result<Vec<ComputeDevice>> {
170 match space {
171 LogicalMemorySpace::MainMemory => Ok(vec![ComputeDevice::Cpu { device_id: 0 }]),
172 LogicalMemorySpace::GpuMemory { device_id } => {
173 #[cfg(feature = "cuda")]
174 {
175 if is_cuda_device_available(device_id) {
176 Ok(vec![ComputeDevice::Cuda { device_id }])
177 } else {
178 Err(Error::NoCompatibleComputeDevice {
179 space,
180 op: _op_kind,
181 })
182 }
183 }
184 #[cfg(not(feature = "cuda"))]
185 {
186 let _ = device_id;
187 Err(Error::NoCompatibleComputeDevice {
188 space,
189 op: _op_kind,
190 })
191 }
192 }
193 LogicalMemorySpace::PinnedMemory | LogicalMemorySpace::ManagedMemory => {
194 #[cfg(feature = "cuda")]
195 {
196 let mut devices = Vec::new();
197 if let Some(cuda_device) = first_available_cuda_device() {
198 devices.push(cuda_device);
199 }
200 devices.push(ComputeDevice::Cpu { device_id: 0 });
201 Ok(devices)
202 }
203 #[cfg(not(feature = "cuda"))]
204 {
205 Ok(vec![ComputeDevice::Cpu { device_id: 0 }])
206 }
207 }
208 }
209}
210
/// Returns the number of CUDA devices reported by the CUDA runtime, or 0 when
/// the count cannot be obtained.
///
/// Every failure mode collapses to 0: an error returned by `get_count`, a
/// panic raised inside the call (caught with `catch_unwind`), and a count that
/// does not fit in `usize` (for example a negative value).
///
/// NOTE(review): the `catch_unwind` presumably guards against `cudarc`
/// panicking when the CUDA libraries cannot be loaded — confirm.
#[cfg(feature = "cuda")]
fn cuda_device_count() -> usize {
    std::panic::catch_unwind(|| {
        cudarc::runtime::result::device::get_count()
            .ok()
            // A checked conversion instead of `as usize`: a negative count
            // would otherwise wrap to a huge value and make every device id
            // look available.
            .and_then(|count| usize::try_from(count).ok())
            .unwrap_or(0)
    })
    .unwrap_or(0)
}
220
/// Reports whether `device_id` is a valid index into the CUDA devices counted
/// by `cuda_device_count`, i.e. whether it lies in `0..count`.
#[cfg(feature = "cuda")]
fn is_cuda_device_available(device_id: usize) -> bool {
    let visible_devices = 0..cuda_device_count();
    visible_devices.contains(&device_id)
}
225
/// Returns CUDA device 0 when `cuda_device_count` reports at least one
/// device, and `None` otherwise.
#[cfg(feature = "cuda")]
fn first_available_cuda_device() -> Option<ComputeDevice> {
    match cuda_device_count() {
        0 => None,
        _ => Some(ComputeDevice::Cuda { device_id: 0 }),
    }
}
235
/// Errors reported by this crate.
///
/// The `Display` text of each variant is given by its `#[error(...)]`
/// attribute.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A shape did not match the expected one.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },

    /// A rank did not match the expected one.
    #[error("rank mismatch: expected {expected}, got {got}")]
    RankMismatch {
        /// The rank that was required.
        expected: usize,
        /// The rank that was actually supplied.
        got: usize,
    },

    /// A device-level failure, described by a free-form message.
    #[error("device error: {0}")]
    DeviceError(String),

    /// No compute device can run `op` on data living in `space`.
    ///
    /// Returned by `preferred_compute_devices` for GPU memory when no
    /// matching CUDA device is usable.
    #[error("no compatible compute device for {op:?} on {space:?}")]
    NoCompatibleComputeDevice {
        /// The memory space the data lives in.
        space: LogicalMemorySpace,
        /// The operation that was requested.
        op: OpKind,
    },

    /// An operation was attempted between two different memory spaces.
    // NOTE(review): nothing in the visible part of this file constructs this
    // variant; the `allow` below presumably silences the resulting lint —
    // confirm it is still needed.
    #[allow(dead_code)]
    #[error("cross-memory-space operation between {left:?} and {right:?}")]
    CrossMemorySpaceOperation {
        /// Memory space of the left operand.
        left: LogicalMemorySpace,
        /// Memory space of the right operand.
        right: LogicalMemorySpace,
    },

    /// An argument was invalid, described by a free-form message.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// A stride-related failure, described by a free-form message.
    #[error("stride error: {0}")]
    StrideError(String),
}
301
/// Convenience alias for `std::result::Result` whose error type is this
/// module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
315
316#[doc(hidden)]
317pub use batch_index::{
318 broadcast_batch_dims, checked_batch_count, flatten_col_major_index,
319 unflatten_col_major_index_into, BroadcastBatchIndexer,
320};
321pub use generator::{with_default_generator, Generator};
322
/// Unit tests for device selection, `Display` formatting and error text.
///
/// CUDA-specific tests are compiled only with the `cuda` feature and adapt to
/// the number of devices actually present, so they also pass on machines
/// without a GPU.
#[cfg(test)]
mod tests {
    use super::*;

    /// The only CPU device `preferred_compute_devices` ever returns.
    const CPU0: ComputeDevice = ComputeDevice::Cpu { device_id: 0 };

    /// Without the `cuda` feature, `space` must resolve to exactly CPU 0.
    #[cfg(not(feature = "cuda"))]
    fn assert_cpu_only(space: LogicalMemorySpace) {
        let devices = preferred_compute_devices(space, OpKind::BatchedGemm).unwrap();
        assert_eq!(devices, vec![CPU0]);
    }

    /// With the `cuda` feature, `space` must list CUDA device 0 first (when a
    /// device exists) and CPU 0 last. Checking positions rather than mere
    /// membership is what verifies the "preferred" / "fallback" ordering.
    #[cfg(feature = "cuda")]
    fn assert_cuda_preferred_cpu_fallback(space: LogicalMemorySpace, label: &str) {
        let devices = preferred_compute_devices(space, OpKind::BatchedGemm).unwrap();
        if cuda_device_count() > 0 {
            assert_eq!(
                devices.first(),
                Some(&ComputeDevice::Cuda { device_id: 0 }),
                "CUDA device should be preferred for {label}"
            );
        }
        assert_eq!(
            devices.last(),
            Some(&CPU0),
            "CPU should be fallback for {label}"
        );
    }

    #[test]
    fn test_preferred_compute_devices_main_memory() {
        let devices =
            preferred_compute_devices(LogicalMemorySpace::MainMemory, OpKind::BatchedGemm).unwrap();
        assert_eq!(devices, vec![CPU0]);
    }

    #[test]
    fn test_preferred_compute_devices_main_memory_all_ops() {
        for op in [
            OpKind::Contract,
            OpKind::BatchedGemm,
            OpKind::Reduce,
            OpKind::Trace,
            OpKind::Permute,
            OpKind::ElementwiseMul,
        ] {
            let devices = preferred_compute_devices(LogicalMemorySpace::MainMemory, op).unwrap();
            assert!(
                devices.contains(&CPU0),
                "CPU should be available for {op:?}"
            );
        }
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_preferred_compute_devices_gpu_memory_without_cuda_feature() {
        let requested = LogicalMemorySpace::GpuMemory { device_id: 0 };
        match preferred_compute_devices(requested, OpKind::BatchedGemm) {
            Err(Error::NoCompatibleComputeDevice { space, op }) => {
                assert_eq!(space, requested);
                assert_eq!(op, OpKind::BatchedGemm);
            }
            other => panic!("Expected NoCompatibleComputeDevice error, got {other:?}"),
        }
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_preferred_compute_devices_pinned_memory_without_cuda_feature() {
        assert_cpu_only(LogicalMemorySpace::PinnedMemory);
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_preferred_compute_devices_managed_memory_without_cuda_feature() {
        assert_cpu_only(LogicalMemorySpace::ManagedMemory);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_gpu_memory_with_cuda_available() {
        // Only meaningful on machines that actually have a CUDA device.
        if cuda_device_count() > 0 {
            let devices = preferred_compute_devices(
                LogicalMemorySpace::GpuMemory { device_id: 0 },
                OpKind::BatchedGemm,
            )
            .unwrap();
            assert_eq!(devices, vec![ComputeDevice::Cuda { device_id: 0 }]);
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_gpu_memory_invalid_device() {
        let requested = LogicalMemorySpace::GpuMemory {
            device_id: cuda_device_count() + 100,
        };
        // Check the variant and its payload, not just that some error occurred.
        match preferred_compute_devices(requested, OpKind::BatchedGemm) {
            Err(Error::NoCompatibleComputeDevice { space, op }) => {
                assert_eq!(space, requested);
                assert_eq!(op, OpKind::BatchedGemm);
            }
            other => panic!("Expected NoCompatibleComputeDevice error, got {other:?}"),
        }
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_pinned_memory_with_cuda() {
        assert_cuda_preferred_cpu_fallback(LogicalMemorySpace::PinnedMemory, "pinned memory");
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_preferred_compute_devices_managed_memory_with_cuda() {
        assert_cuda_preferred_cpu_fallback(LogicalMemorySpace::ManagedMemory, "managed memory");
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_device_count_safe() {
        // Mainly checks that the call neither panics nor returns a wrapped value.
        let count = cuda_device_count();
        assert!(count < 1000, "device count should be reasonable");
    }

    #[test]
    fn test_compute_device_display() {
        assert_eq!(CPU0.to_string(), "cpu:0");
        assert_eq!(ComputeDevice::Cuda { device_id: 2 }.to_string(), "cuda:2");
        assert_eq!(ComputeDevice::Rocm { device_id: 1 }.to_string(), "rocm:1");
    }

    #[test]
    fn test_logical_memory_space_equality() {
        let gpu0 = LogicalMemorySpace::GpuMemory { device_id: 0 };
        assert_eq!(gpu0, LogicalMemorySpace::GpuMemory { device_id: 0 });
        assert_ne!(gpu0, LogicalMemorySpace::GpuMemory { device_id: 1 });
        assert_ne!(gpu0, LogicalMemorySpace::MainMemory);
    }

    #[test]
    fn test_error_display() {
        let err = Error::NoCompatibleComputeDevice {
            space: LogicalMemorySpace::GpuMemory { device_id: 0 },
            op: OpKind::BatchedGemm,
        };
        let msg = err.to_string();
        assert!(msg.contains("no compatible compute device"));
        assert!(msg.contains("BatchedGemm"));
        // The message must also name the memory space, not just the operation.
        assert!(msg.contains("GpuMemory"));
    }
}