// tenferro_prims/gpu_stubs.rs

1use std::ffi::c_void;
2use std::marker::PhantomData;
3
4use tenferro_algebra::{Conjugate, Scalar, Standard};
5use tenferro_device::{Error, Result};
6use tenferro_tensor::Tensor;
7
8use crate::cpu::TempPool;
9use crate::{
10    MetadataCastPrimsDescriptor, MetadataPrimsDescriptor, MetadataScalarTensorRef,
11    MetadataTensorMut, MetadataTensorRef, PlanCache, SemiringCoreDescriptor,
12    SemiringFastPathDescriptor, TensorMetadataCastPrims, TensorMetadataPrims, TensorSemiringCore,
13    TensorSemiringFastPath, TensorTempPoolContext,
14};
15
16// ===========================================================================
17// CUDA stub types (used when `cuda` feature is NOT enabled)
18// ===========================================================================
19
20/// CUDA execution context (stub).
21///
22/// **Status: Stub.** This type exists as an API placeholder when the `cuda`
23/// feature is not enabled. All operations on [`CudaBackend`] return errors.
24/// Enable the `cuda` feature for the real implementation.
25///
26/// # Examples
27///
28/// ```ignore
29/// // Aspirational API — not yet functional without `cuda` feature.
30/// use tenferro_prims::CudaContext;
31/// ```
#[cfg(not(feature = "cuda"))]
pub struct CudaContext {
    // Placeholder for a CUDA stream handle; always null in the stub.
    _stream: *mut c_void,
    // Placeholder GPU workspace buffer; always empty in the stub.
    _workspace: Vec<u8>,
    // Placeholder plan cache; never populated by the stub.
    _plan_cache: PlanCache,
    // Host-side scratch-vector pool backing the `TensorTempPoolContext` impl.
    temp_pool: TempPool,
}
39
40#[cfg(not(feature = "cuda"))]
41impl CudaContext {
42    /// Create a stub CUDA context (no-op).
43    pub fn new() -> Self {
44        Self {
45            _stream: std::ptr::null_mut(),
46            _workspace: Vec::new(),
47            _plan_cache: PlanCache::new(),
48            temp_pool: TempPool::default(),
49        }
50    }
51
52    /// Return the stub CUDA device ordinal.
53    ///
54    /// # Examples
55    ///
56    /// ```ignore
57    /// let ctx = tenferro_prims::CudaContext::new();
58    /// assert_eq!(ctx.device_id(), 0);
59    /// ```
60    pub fn device_id(&self) -> usize {
61        0
62    }
63
64    /// Bind the stub CUDA context to the current device.
65    ///
66    /// # Examples
67    ///
68    /// ```ignore
69    /// let ctx = tenferro_prims::CudaContext::new();
70    /// assert!(ctx.bind_to_device().is_err());
71    /// ```
72    pub fn bind_to_device(&self) -> Result<()> {
73        Err(Error::DeviceError(
74            "CUDA feature is not enabled for this build".into(),
75        ))
76    }
77}
78
79#[cfg(not(feature = "cuda"))]
80impl Default for CudaContext {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86#[cfg(not(feature = "cuda"))]
87impl TensorTempPoolContext for CudaContext {
88    fn take_temp_vec<T: Send + 'static>(&mut self, len: usize) -> Vec<T> {
89        self.temp_pool.take_vec::<T>(len)
90    }
91
92    fn put_temp_vec<T: Send + 'static>(&mut self, vec: Vec<T>) {
93        self.temp_pool.put_vec(vec);
94    }
95}
96
/// CUDA plan (stub) — placeholder when `cuda` feature is not enabled.
///
/// **Status: Stub.** Enable the `cuda` feature for the real implementation.
#[cfg(not(feature = "cuda"))]
pub struct CudaPlan<T: Scalar> {
    // Opaque cuTENSOR plan handle; never created by the stub.
    _handle: *mut c_void,
    // Workspace size (bytes) the plan would need; unused in the stub.
    _workspace_size: usize,
    // Ties the plan to its scalar type without storing a value.
    _marker: PhantomData<T>,
}
106
107/// CUDA backend (stub) — placeholder when `cuda` feature is not enabled.
108///
109/// **Status: Stub.** All methods return errors. Enable the `cuda` feature
110/// for the real implementation backed by cuTENSOR + cudarc.
111///
112/// # Examples
113///
114/// ```ignore
115/// // Aspirational API — enable `cuda` feature for real backend.
116/// use tenferro_prims::{CudaBackend, BackendRegistry};
117///
118/// let mut registry = BackendRegistry::new();
119/// registry.load_cutensor("/usr/lib/libcutensor.so").unwrap();
120/// ```
#[cfg(not(feature = "cuda"))]
pub struct CudaBackend {
    // Opaque cuTENSOR library handle; null in the stub.
    _handle: *mut c_void,
    // Keeps the dynamically loaded library alive for the backend's lifetime;
    // dropping it unloads the library.
    _lib: libloading::Library,
}
126
#[cfg(not(feature = "cuda"))]
impl Drop for CudaBackend {
    fn drop(&mut self) {
        // Defensively clear the handle before `_lib` is dropped (struct fields
        // drop in declaration order), so no handle value can be observed after
        // the library has been unloaded.
        self._handle = std::ptr::null_mut();
    }
}
133
/// # Safety
///
/// `CudaBackend` can be safely sent across threads because:
/// - The `_handle` is an opaque pointer to a cuTENSOR handle
/// - The `_lib` (`libloading::Library`) is thread-safe after loading
/// - The handle is read-only after construction
/// - Drop clears the handle before the library is unloaded, preventing use-after-free
///
/// NOTE(review): no constructor for the stub `CudaBackend` is visible in this
/// file, so in this configuration `_handle` is presumably never non-null —
/// confirm the claims above against the real (`cuda`-feature) backend.
#[cfg(not(feature = "cuda"))]
unsafe impl Send for CudaBackend {}
143
/// # Safety
///
/// `CudaBackend` can be safely shared across threads because:
/// - The cuTENSOR handle is designed for concurrent use from multiple threads
/// - The library handle (`_lib`) is read-only after construction
/// - Symbol lookup via `dlsym` is thread-safe on POSIX systems
/// - Drop uses `&mut self`, ensuring exclusive access during cleanup
///
/// NOTE(review): as with `Send`, the stub never creates a live handle in this
/// build — verify these claims against the real backend implementation.
#[cfg(not(feature = "cuda"))]
unsafe impl Sync for CudaBackend {}
153
154#[cfg(not(feature = "cuda"))]
155impl CudaBackend {
156    /// Materialize a lazily-conjugated tensor on GPU.
157    ///
158    /// **Status: Stub fallback.** If data is CPU-accessible, this materializes
159    /// conjugation into a new non-conjugated tensor. Otherwise it returns a
160    /// clone of `src`.
161    pub fn resolve_conj<T: Scalar + Conjugate>(
162        _ctx: &mut CudaContext,
163        src: &tenferro_tensor::Tensor<T>,
164    ) -> tenferro_tensor::Tensor<T> {
165        if !src.is_conjugated() {
166            return src.clone();
167        }
168
169        let contiguous = src.contiguous(tenferro_tensor::MemoryOrder::ColumnMajor);
170        let Some(data) = contiguous.buffer().as_slice() else {
171            return src.clone();
172        };
173        let conjugated_data: Vec<T> = data.iter().map(|&v| v.conj()).collect();
174        tenferro_tensor::Tensor::from_slice(
175            &conjugated_data,
176            src.dims(),
177            tenferro_tensor::MemoryOrder::ColumnMajor,
178        )
179        .unwrap_or_else(|_| src.clone())
180    }
181}
182
183#[cfg(not(feature = "cuda"))]
184impl<S: Scalar> TensorSemiringCore<Standard<S>> for CudaBackend {
185    type Plan = CudaPlan<S>;
186    type Context = CudaContext;
187
188    fn plan(
189        _ctx: &mut CudaContext,
190        _desc: &SemiringCoreDescriptor,
191        _shapes: &[&[usize]],
192    ) -> Result<CudaPlan<S>> {
193        Err(Error::DeviceError(
194            "CUDA backend not available: load cuTENSOR library first".into(),
195        ))
196    }
197
198    fn execute(
199        _ctx: &mut CudaContext,
200        _plan: &CudaPlan<S>,
201        _alpha: S,
202        _inputs: &[&Tensor<S>],
203        _beta: S,
204        _output: &mut Tensor<S>,
205    ) -> Result<()> {
206        Err(Error::DeviceError(
207            "CUDA backend not available: load cuTENSOR library first".into(),
208        ))
209    }
210}
211
212#[cfg(not(feature = "cuda"))]
213impl TensorMetadataPrims for CudaBackend {
214    type Plan = MetadataPrimsDescriptor;
215    type Context = CudaContext;
216
217    fn plan(
218        _ctx: &mut CudaContext,
219        desc: &MetadataPrimsDescriptor,
220        _inputs: &[MetadataTensorRef<'_>],
221        _output: MetadataTensorMut<'_>,
222    ) -> Result<Self::Plan> {
223        Err(Error::DeviceError(format!(
224            "metadata family descriptor {desc:?} is not implemented on stub CudaBackend"
225        )))
226    }
227
228    fn execute(
229        _ctx: &mut CudaContext,
230        _plan: &Self::Plan,
231        _inputs: &[MetadataTensorRef<'_>],
232        _output: MetadataTensorMut<'_>,
233    ) -> Result<()> {
234        Err(Error::DeviceError(
235            "metadata family execution is not implemented on stub CudaBackend".into(),
236        ))
237    }
238
239    fn has_metadata_support(_desc: MetadataPrimsDescriptor) -> bool {
240        false
241    }
242}
243
244#[cfg(not(feature = "cuda"))]
245impl<S: Scalar> TensorSemiringFastPath<Standard<S>> for CudaBackend {
246    type Plan = CudaPlan<S>;
247    type Context = CudaContext;
248
249    fn plan(
250        _ctx: &mut CudaContext,
251        _desc: &SemiringFastPathDescriptor,
252        _shapes: &[&[usize]],
253    ) -> Result<CudaPlan<S>> {
254        Err(Error::DeviceError(
255            "CUDA backend not available: load cuTENSOR library first".into(),
256        ))
257    }
258
259    fn execute(
260        _ctx: &mut CudaContext,
261        _plan: &CudaPlan<S>,
262        _alpha: S,
263        _inputs: &[&Tensor<S>],
264        _beta: S,
265        _output: &mut Tensor<S>,
266    ) -> Result<()> {
267        Err(Error::DeviceError(
268            "CUDA backend not available: load cuTENSOR library first".into(),
269        ))
270    }
271
272    fn has_fast_path(_desc: SemiringFastPathDescriptor) -> bool {
273        false
274    }
275}
276
277#[cfg(not(feature = "cuda"))]
278impl<S: Scalar + num_traits::NumCast> TensorMetadataCastPrims<S> for CudaBackend {
279    type Plan = MetadataCastPrimsDescriptor;
280    type Context = CudaContext;
281
282    fn plan(
283        _ctx: &mut CudaContext,
284        desc: &MetadataCastPrimsDescriptor,
285        _shapes: &[&[usize]],
286    ) -> Result<Self::Plan> {
287        Err(Error::DeviceError(format!(
288            "metadata cast family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
289        )))
290    }
291
292    fn execute(
293        _ctx: &mut CudaContext,
294        _plan: &Self::Plan,
295        _alpha: S,
296        _inputs: &[MetadataScalarTensorRef<'_, S>],
297        _beta: S,
298        _output: &mut Tensor<S>,
299    ) -> Result<()> {
300        Err(Error::DeviceError(
301            "metadata cast family execution is not implemented on CudaBackend in phase 1".into(),
302        ))
303    }
304
305    fn has_metadata_cast_support(_desc: MetadataCastPrimsDescriptor) -> bool {
306        false
307    }
308}
309
310// ===========================================================================
311// ROCm stub types (always present — no real ROCm backend yet)
312// ===========================================================================
313
314/// ROCm execution context.
315///
316/// **Status: Not yet implemented.** This type exists as an API placeholder.
317/// All operations on [`RocmBackend`] currently return errors.
318///
319/// When implemented, will encapsulate ROCm-side execution resources: a HIP
320/// stream, GPU workspace buffer, and plan cache. Analogous to hipTENSOR's
321/// handle.
322///
323/// # Examples
324///
325/// ```ignore
326/// // Aspirational API — not yet functional.
327/// use tenferro_prims::RocmContext;
328///
329/// // Created internally by RocmBackend::load_hiptensor()
330/// ```
pub struct RocmContext {
    // Placeholder for a HIP stream handle; always null in the stub.
    _stream: *mut c_void,
    // Placeholder GPU workspace buffer; always empty in the stub.
    _workspace: Vec<u8>,
    // Placeholder plan cache; never populated by the stub.
    _plan_cache: PlanCache,
    // Host-side scratch-vector pool backing the `TensorTempPoolContext` impl.
    temp_pool: TempPool,
}
337
338impl RocmContext {
339    /// Create a stub ROCm context (no-op).
340    pub fn new() -> Self {
341        Self {
342            _stream: std::ptr::null_mut(),
343            _workspace: Vec::new(),
344            _plan_cache: PlanCache::new(),
345            temp_pool: TempPool::default(),
346        }
347    }
348}
349
350impl Default for RocmContext {
351    fn default() -> Self {
352        Self::new()
353    }
354}
355
356impl TensorTempPoolContext for RocmContext {
357    fn take_temp_vec<T: Send + 'static>(&mut self, len: usize) -> Vec<T> {
358        self.temp_pool.take_vec::<T>(len)
359    }
360
361    fn put_temp_vec<T: Send + 'static>(&mut self, vec: Vec<T>) {
362        self.temp_pool.put_vec(vec);
363    }
364}
365
366/// ROCm plan — wraps a hipTENSOR plan handle.
367///
368/// **Status: Not yet implemented.** This type exists as an API placeholder.
369///
370/// Created by the semiring-family `plan` methods and consumed by `execute`.
pub struct RocmPlan<T: Scalar> {
    // Opaque hipTENSOR plan handle; never created by the stub.
    _handle: *mut c_void,
    // Workspace size (bytes) the plan would need; unused in the stub.
    _workspace_size: usize,
    // Ties the plan to its scalar type without storing a value.
    _marker: PhantomData<T>,
}
376
377/// ROCm backend using hipTENSOR via runtime dlopen.
378///
379/// **Status: Not yet implemented.** All methods currently return errors.
380/// The type exists to define the intended API surface. `plan()` and
381/// `execute()` return `Err(DeviceError)`. `load_hiptensor()` on
382/// [`crate::BackendRegistry`] also returns an error.
383///
384/// When implemented, will be loaded at runtime from a user-provided `.so`
385/// path with no compile-time ROCm SDK dependency. Will implement
386/// the semiring-family traits for standard arithmetic on AMD GPUs.
387///
388/// hipTENSOR natively supports contraction, reduction, and elementwise
389/// building blocks. Structural `permute` stays a tensor view, and any required
390/// materialization path is modeled through `MakeContiguous`. `AntiTrace` and
391/// `AntiDiag` will be composed via `Contract(eye, dC)`.
392///
393/// # Examples
394///
395/// ```ignore
396/// // Aspirational API — not yet functional.
397/// use tenferro_prims::{RocmBackend, BackendRegistry};
398///
399/// let mut registry = BackendRegistry::new();
400/// registry.load_hiptensor("/usr/lib/libhiptensor.so").unwrap();
401/// ```
pub struct RocmBackend {
    // Opaque hipTENSOR library handle; null in the stub.
    _handle: *mut c_void,
    // Keeps the dynamically loaded library alive for the backend's lifetime;
    // dropping it unloads the library.
    _lib: libloading::Library,
}
406
impl Drop for RocmBackend {
    fn drop(&mut self) {
        // Defensively clear the handle before `_lib` is dropped (struct fields
        // drop in declaration order), so no handle value can be observed after
        // the library has been unloaded.
        self._handle = std::ptr::null_mut();
    }
}
412
/// # Safety
///
/// `RocmBackend` can be safely sent across threads because:
/// - The `_handle` is an opaque pointer to a hipTENSOR handle
/// - The `_lib` (`libloading::Library`) is thread-safe after loading
/// - The handle is read-only after construction
/// - Drop clears the handle before the library is unloaded, preventing use-after-free
///
/// NOTE(review): no constructor for `RocmBackend` is visible in this file, so
/// `_handle` is presumably never non-null here — confirm these claims once a
/// real ROCm backend is implemented.
unsafe impl Send for RocmBackend {}
421
/// # Safety
///
/// `RocmBackend` can be safely shared across threads because:
/// - The hipTENSOR handle is designed for concurrent use from multiple threads
/// - The library handle (`_lib`) is read-only after construction
/// - Symbol lookup via `dlsym` is thread-safe on POSIX systems
/// - Drop uses `&mut self`, ensuring exclusive access during cleanup
///
/// NOTE(review): as with `Send`, the stub never creates a live handle — verify
/// these claims once a real ROCm backend is implemented.
unsafe impl Sync for RocmBackend {}
430
431impl RocmBackend {
432    /// Materialize a lazily-conjugated tensor on GPU.
433    ///
434    /// **Status: Stub fallback.** If data is CPU-accessible, this materializes
435    /// conjugation into a new non-conjugated tensor. Otherwise it returns a
436    /// clone of `src`.
437    ///
438    /// When implemented, will use the analytic/scalar family execution traits
439    /// to produce a new tensor with `conjugated = false`.
440    pub fn resolve_conj<T: Scalar + Conjugate>(
441        _ctx: &mut RocmContext,
442        src: &tenferro_tensor::Tensor<T>,
443    ) -> tenferro_tensor::Tensor<T> {
444        if !src.is_conjugated() {
445            return src.clone();
446        }
447
448        let contiguous = src.contiguous(tenferro_tensor::MemoryOrder::ColumnMajor);
449        let Some(data) = contiguous.buffer().as_slice() else {
450            return src.clone();
451        };
452        let conjugated_data: Vec<T> = data.iter().map(|&v| v.conj()).collect();
453        tenferro_tensor::Tensor::from_slice(
454            &conjugated_data,
455            src.dims(),
456            tenferro_tensor::MemoryOrder::ColumnMajor,
457        )
458        .unwrap_or_else(|_| src.clone())
459    }
460}
461
462impl<S: Scalar + num_traits::NumCast> TensorMetadataCastPrims<S> for RocmBackend {
463    type Plan = MetadataCastPrimsDescriptor;
464    type Context = RocmContext;
465
466    fn plan(
467        _ctx: &mut RocmContext,
468        desc: &MetadataCastPrimsDescriptor,
469        _shapes: &[&[usize]],
470    ) -> Result<Self::Plan> {
471        Err(Error::DeviceError(format!(
472            "metadata cast family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
473        )))
474    }
475
476    fn execute(
477        _ctx: &mut RocmContext,
478        _plan: &Self::Plan,
479        _alpha: S,
480        _inputs: &[MetadataScalarTensorRef<'_, S>],
481        _beta: S,
482        _output: &mut Tensor<S>,
483    ) -> Result<()> {
484        Err(Error::DeviceError(
485            "metadata cast family execution is not implemented on RocmBackend in phase 1".into(),
486        ))
487    }
488
489    fn has_metadata_cast_support(_desc: MetadataCastPrimsDescriptor) -> bool {
490        false
491    }
492}
493
494impl TensorMetadataPrims for RocmBackend {
495    type Plan = MetadataPrimsDescriptor;
496    type Context = RocmContext;
497
498    fn plan(
499        _ctx: &mut RocmContext,
500        desc: &MetadataPrimsDescriptor,
501        _inputs: &[MetadataTensorRef<'_>],
502        _output: MetadataTensorMut<'_>,
503    ) -> Result<Self::Plan> {
504        Err(Error::DeviceError(format!(
505            "metadata family descriptor {desc:?} is not implemented on RocmBackend"
506        )))
507    }
508
509    fn execute(
510        _ctx: &mut RocmContext,
511        _plan: &Self::Plan,
512        _inputs: &[MetadataTensorRef<'_>],
513        _output: MetadataTensorMut<'_>,
514    ) -> Result<()> {
515        Err(Error::DeviceError(
516            "metadata family execution is not implemented on RocmBackend".into(),
517        ))
518    }
519
520    fn has_metadata_support(_desc: MetadataPrimsDescriptor) -> bool {
521        false
522    }
523}
524
525impl<S: Scalar> TensorSemiringCore<Standard<S>> for RocmBackend {
526    type Plan = RocmPlan<S>;
527    type Context = RocmContext;
528
529    fn plan(
530        _ctx: &mut RocmContext,
531        _desc: &SemiringCoreDescriptor,
532        _shapes: &[&[usize]],
533    ) -> Result<RocmPlan<S>> {
534        Err(Error::DeviceError(
535            "ROCm backend not available: load hipTENSOR library first".into(),
536        ))
537    }
538
539    fn execute(
540        _ctx: &mut RocmContext,
541        _plan: &RocmPlan<S>,
542        _alpha: S,
543        _inputs: &[&Tensor<S>],
544        _beta: S,
545        _output: &mut Tensor<S>,
546    ) -> Result<()> {
547        Err(Error::DeviceError(
548            "ROCm backend not available: load hipTENSOR library first".into(),
549        ))
550    }
551}
552
553impl<S: Scalar> TensorSemiringFastPath<Standard<S>> for RocmBackend {
554    type Plan = RocmPlan<S>;
555    type Context = RocmContext;
556
557    fn plan(
558        _ctx: &mut RocmContext,
559        _desc: &SemiringFastPathDescriptor,
560        _shapes: &[&[usize]],
561    ) -> Result<RocmPlan<S>> {
562        Err(Error::DeviceError(
563            "ROCm backend not available: load hipTENSOR library first".into(),
564        ))
565    }
566
567    fn execute(
568        _ctx: &mut RocmContext,
569        _plan: &RocmPlan<S>,
570        _alpha: S,
571        _inputs: &[&Tensor<S>],
572        _beta: S,
573        _output: &mut Tensor<S>,
574    ) -> Result<()> {
575        Err(Error::DeviceError(
576            "ROCm backend not available: load hipTENSOR library first".into(),
577        ))
578    }
579
580    fn has_fast_path(_desc: SemiringFastPathDescriptor) -> bool {
581        // Not yet implemented. When available, hipTENSOR will support the
582        // semiring fast-path family for contraction and elementwise binary ops.
583        false
584    }
585}
586
// Unit tests for the stub backends live in `tests.rs` next to this file.
#[cfg(test)]
mod tests;