1use std::ffi::c_void;
2use std::marker::PhantomData;
3
4use tenferro_algebra::{Conjugate, Scalar, Standard};
5use tenferro_device::{Error, Result};
6use tenferro_tensor::Tensor;
7
8use crate::cpu::TempPool;
9use crate::{
10 MetadataCastPrimsDescriptor, MetadataPrimsDescriptor, MetadataScalarTensorRef,
11 MetadataTensorMut, MetadataTensorRef, PlanCache, SemiringCoreDescriptor,
12 SemiringFastPathDescriptor, TensorMetadataCastPrims, TensorMetadataPrims, TensorSemiringCore,
13 TensorSemiringFastPath, TensorTempPoolContext,
14};
15
/// Stub execution context compiled when the `cuda` feature is disabled.
///
/// Only `temp_pool` is actually used (it backs `TensorTempPoolContext`);
/// the remaining fields are inert placeholders — `_stream` is always null
/// and `_workspace` is always empty in this build.
#[cfg(not(feature = "cuda"))]
pub struct CudaContext {
    _stream: *mut c_void,    // always null here; never dereferenced
    _workspace: Vec<u8>,     // always empty here
    _plan_cache: PlanCache,  // unused by the stub code paths in this file
    temp_pool: TempPool,     // scratch-buffer pool for TensorTempPoolContext
}
39
40#[cfg(not(feature = "cuda"))]
41impl CudaContext {
42 pub fn new() -> Self {
44 Self {
45 _stream: std::ptr::null_mut(),
46 _workspace: Vec::new(),
47 _plan_cache: PlanCache::new(),
48 temp_pool: TempPool::default(),
49 }
50 }
51
52 pub fn device_id(&self) -> usize {
61 0
62 }
63
64 pub fn bind_to_device(&self) -> Result<()> {
73 Err(Error::DeviceError(
74 "CUDA feature is not enabled for this build".into(),
75 ))
76 }
77}
78
79#[cfg(not(feature = "cuda"))]
80impl Default for CudaContext {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86#[cfg(not(feature = "cuda"))]
87impl TensorTempPoolContext for CudaContext {
88 fn take_temp_vec<T: Send + 'static>(&mut self, len: usize) -> Vec<T> {
89 self.temp_pool.take_vec::<T>(len)
90 }
91
92 fn put_temp_vec<T: Send + 'static>(&mut self, vec: Vec<T>) {
93 self.temp_pool.put_vec(vec);
94 }
95}
96
/// Stub plan type for the non-CUDA build.
///
/// No code path in this file ever constructs one — every `plan` method on the
/// stub backend returns an error — so this type exists only to satisfy the
/// associated-type requirements of the backend traits.
#[cfg(not(feature = "cuda"))]
pub struct CudaPlan<T: Scalar> {
    _handle: *mut c_void,     // presumably mirrors a cuTENSOR plan handle — TODO confirm
    _workspace_size: usize,   // presumably required workspace bytes — TODO confirm
    _marker: PhantomData<T>,  // ties the plan to its scalar type without storing a T
}
106
/// Stub CUDA backend compiled when the `cuda` feature is off.
///
/// Nothing in this file constructs it; all trait impls below are associated
/// functions that return errors without touching these fields.
#[cfg(not(feature = "cuda"))]
pub struct CudaBackend {
    _handle: *mut c_void,       // opaque handle; only ever set to null in this file
    _lib: libloading::Library,  // NOTE(review): never read here — confirm the stub needs to hold it
}
126
#[cfg(not(feature = "cuda"))]
impl Drop for CudaBackend {
    fn drop(&mut self) {
        // NOTE(review): clearing a field inside `drop` is unobservable to
        // callers — the value is being destroyed anyway, and `_lib` is
        // released by its own Drop. Presumably this is a placeholder for the
        // real teardown in the cuda-enabled build; confirm and consider
        // removing if not.
        self._handle = std::ptr::null_mut();
    }
}
133
// SAFETY: in this stub build `_handle` is only ever null and is never
// dereferenced, so the raw pointer carries no thread-affine state.
// NOTE(review): this impl also vouches for `_lib` — confirm
// `libloading::Library` may be moved across threads on all target platforms.
#[cfg(not(feature = "cuda"))]
unsafe impl Send for CudaBackend {}
143
// SAFETY: no method in this file mutates `CudaBackend` through a shared
// reference, and the null `_handle` is never dereferenced.
// NOTE(review): confirm shared access to `_lib` is sound on all targets.
#[cfg(not(feature = "cuda"))]
unsafe impl Sync for CudaBackend {}
153
154#[cfg(not(feature = "cuda"))]
155impl CudaBackend {
156 pub fn resolve_conj<T: Scalar + Conjugate>(
162 _ctx: &mut CudaContext,
163 src: &tenferro_tensor::Tensor<T>,
164 ) -> tenferro_tensor::Tensor<T> {
165 if !src.is_conjugated() {
166 return src.clone();
167 }
168
169 let contiguous = src.contiguous(tenferro_tensor::MemoryOrder::ColumnMajor);
170 let Some(data) = contiguous.buffer().as_slice() else {
171 return src.clone();
172 };
173 let conjugated_data: Vec<T> = data.iter().map(|&v| v.conj()).collect();
174 tenferro_tensor::Tensor::from_slice(
175 &conjugated_data,
176 src.dims(),
177 tenferro_tensor::MemoryOrder::ColumnMajor,
178 )
179 .unwrap_or_else(|_| src.clone())
180 }
181}
182
183#[cfg(not(feature = "cuda"))]
184impl<S: Scalar> TensorSemiringCore<Standard<S>> for CudaBackend {
185 type Plan = CudaPlan<S>;
186 type Context = CudaContext;
187
188 fn plan(
189 _ctx: &mut CudaContext,
190 _desc: &SemiringCoreDescriptor,
191 _shapes: &[&[usize]],
192 ) -> Result<CudaPlan<S>> {
193 Err(Error::DeviceError(
194 "CUDA backend not available: load cuTENSOR library first".into(),
195 ))
196 }
197
198 fn execute(
199 _ctx: &mut CudaContext,
200 _plan: &CudaPlan<S>,
201 _alpha: S,
202 _inputs: &[&Tensor<S>],
203 _beta: S,
204 _output: &mut Tensor<S>,
205 ) -> Result<()> {
206 Err(Error::DeviceError(
207 "CUDA backend not available: load cuTENSOR library first".into(),
208 ))
209 }
210}
211
212#[cfg(not(feature = "cuda"))]
213impl TensorMetadataPrims for CudaBackend {
214 type Plan = MetadataPrimsDescriptor;
215 type Context = CudaContext;
216
217 fn plan(
218 _ctx: &mut CudaContext,
219 desc: &MetadataPrimsDescriptor,
220 _inputs: &[MetadataTensorRef<'_>],
221 _output: MetadataTensorMut<'_>,
222 ) -> Result<Self::Plan> {
223 Err(Error::DeviceError(format!(
224 "metadata family descriptor {desc:?} is not implemented on stub CudaBackend"
225 )))
226 }
227
228 fn execute(
229 _ctx: &mut CudaContext,
230 _plan: &Self::Plan,
231 _inputs: &[MetadataTensorRef<'_>],
232 _output: MetadataTensorMut<'_>,
233 ) -> Result<()> {
234 Err(Error::DeviceError(
235 "metadata family execution is not implemented on stub CudaBackend".into(),
236 ))
237 }
238
239 fn has_metadata_support(_desc: MetadataPrimsDescriptor) -> bool {
240 false
241 }
242}
243
244#[cfg(not(feature = "cuda"))]
245impl<S: Scalar> TensorSemiringFastPath<Standard<S>> for CudaBackend {
246 type Plan = CudaPlan<S>;
247 type Context = CudaContext;
248
249 fn plan(
250 _ctx: &mut CudaContext,
251 _desc: &SemiringFastPathDescriptor,
252 _shapes: &[&[usize]],
253 ) -> Result<CudaPlan<S>> {
254 Err(Error::DeviceError(
255 "CUDA backend not available: load cuTENSOR library first".into(),
256 ))
257 }
258
259 fn execute(
260 _ctx: &mut CudaContext,
261 _plan: &CudaPlan<S>,
262 _alpha: S,
263 _inputs: &[&Tensor<S>],
264 _beta: S,
265 _output: &mut Tensor<S>,
266 ) -> Result<()> {
267 Err(Error::DeviceError(
268 "CUDA backend not available: load cuTENSOR library first".into(),
269 ))
270 }
271
272 fn has_fast_path(_desc: SemiringFastPathDescriptor) -> bool {
273 false
274 }
275}
276
277#[cfg(not(feature = "cuda"))]
278impl<S: Scalar + num_traits::NumCast> TensorMetadataCastPrims<S> for CudaBackend {
279 type Plan = MetadataCastPrimsDescriptor;
280 type Context = CudaContext;
281
282 fn plan(
283 _ctx: &mut CudaContext,
284 desc: &MetadataCastPrimsDescriptor,
285 _shapes: &[&[usize]],
286 ) -> Result<Self::Plan> {
287 Err(Error::DeviceError(format!(
288 "metadata cast family descriptor {desc:?} is not implemented on CudaBackend in phase 1"
289 )))
290 }
291
292 fn execute(
293 _ctx: &mut CudaContext,
294 _plan: &Self::Plan,
295 _alpha: S,
296 _inputs: &[MetadataScalarTensorRef<'_, S>],
297 _beta: S,
298 _output: &mut Tensor<S>,
299 ) -> Result<()> {
300 Err(Error::DeviceError(
301 "metadata cast family execution is not implemented on CudaBackend in phase 1".into(),
302 ))
303 }
304
305 fn has_metadata_cast_support(_desc: MetadataCastPrimsDescriptor) -> bool {
306 false
307 }
308}
309
/// Stub ROCm execution context.
///
/// Unlike the CUDA stub this type is not feature-gated — presumably because
/// no `rocm` cargo feature exists yet (TODO confirm). Only `temp_pool` is
/// used; the other fields are inert placeholders in this file.
pub struct RocmContext {
    _stream: *mut c_void,    // always null here; never dereferenced
    _workspace: Vec<u8>,     // always empty here
    _plan_cache: PlanCache,  // unused by the stub code paths in this file
    temp_pool: TempPool,     // scratch-buffer pool for TensorTempPoolContext
}
337
338impl RocmContext {
339 pub fn new() -> Self {
341 Self {
342 _stream: std::ptr::null_mut(),
343 _workspace: Vec::new(),
344 _plan_cache: PlanCache::new(),
345 temp_pool: TempPool::default(),
346 }
347 }
348}
349
350impl Default for RocmContext {
351 fn default() -> Self {
352 Self::new()
353 }
354}
355
356impl TensorTempPoolContext for RocmContext {
357 fn take_temp_vec<T: Send + 'static>(&mut self, len: usize) -> Vec<T> {
358 self.temp_pool.take_vec::<T>(len)
359 }
360
361 fn put_temp_vec<T: Send + 'static>(&mut self, vec: Vec<T>) {
362 self.temp_pool.put_vec(vec);
363 }
364}
365
/// Stub plan type for the ROCm backend.
///
/// No code path in this file ever constructs one — every `plan` method on the
/// stub backend returns an error — so this type exists only to satisfy the
/// associated-type requirements of the backend traits.
pub struct RocmPlan<T: Scalar> {
    _handle: *mut c_void,     // presumably mirrors a hipTENSOR plan handle — TODO confirm
    _workspace_size: usize,   // presumably required workspace bytes — TODO confirm
    _marker: PhantomData<T>,  // ties the plan to its scalar type without storing a T
}
376
/// Stub ROCm backend.
///
/// Nothing in this file constructs it; all trait impls below are associated
/// functions that return errors without touching these fields.
pub struct RocmBackend {
    _handle: *mut c_void,       // opaque handle; only ever set to null in this file
    _lib: libloading::Library,  // NOTE(review): never read here — confirm the stub needs to hold it
}
406
impl Drop for RocmBackend {
    fn drop(&mut self) {
        // NOTE(review): clearing a field inside `drop` is unobservable to
        // callers — the value is being destroyed anyway, and `_lib` is
        // released by its own Drop. Presumably a placeholder for real
        // hipTENSOR teardown; confirm and consider removing if not.
        self._handle = std::ptr::null_mut();
    }
}
412
// SAFETY: in this stub `_handle` is only ever null and is never dereferenced,
// so the raw pointer carries no thread-affine state. NOTE(review): this impl
// also vouches for `_lib` — confirm `libloading::Library` may be moved across
// threads on all target platforms.
unsafe impl Send for RocmBackend {}
421
// SAFETY: no method in this file mutates `RocmBackend` through a shared
// reference, and the null `_handle` is never dereferenced.
// NOTE(review): confirm shared access to `_lib` is sound on all targets.
unsafe impl Sync for RocmBackend {}
430
431impl RocmBackend {
432 pub fn resolve_conj<T: Scalar + Conjugate>(
441 _ctx: &mut RocmContext,
442 src: &tenferro_tensor::Tensor<T>,
443 ) -> tenferro_tensor::Tensor<T> {
444 if !src.is_conjugated() {
445 return src.clone();
446 }
447
448 let contiguous = src.contiguous(tenferro_tensor::MemoryOrder::ColumnMajor);
449 let Some(data) = contiguous.buffer().as_slice() else {
450 return src.clone();
451 };
452 let conjugated_data: Vec<T> = data.iter().map(|&v| v.conj()).collect();
453 tenferro_tensor::Tensor::from_slice(
454 &conjugated_data,
455 src.dims(),
456 tenferro_tensor::MemoryOrder::ColumnMajor,
457 )
458 .unwrap_or_else(|_| src.clone())
459 }
460}
461
462impl<S: Scalar + num_traits::NumCast> TensorMetadataCastPrims<S> for RocmBackend {
463 type Plan = MetadataCastPrimsDescriptor;
464 type Context = RocmContext;
465
466 fn plan(
467 _ctx: &mut RocmContext,
468 desc: &MetadataCastPrimsDescriptor,
469 _shapes: &[&[usize]],
470 ) -> Result<Self::Plan> {
471 Err(Error::DeviceError(format!(
472 "metadata cast family descriptor {desc:?} is not implemented on RocmBackend in phase 1"
473 )))
474 }
475
476 fn execute(
477 _ctx: &mut RocmContext,
478 _plan: &Self::Plan,
479 _alpha: S,
480 _inputs: &[MetadataScalarTensorRef<'_, S>],
481 _beta: S,
482 _output: &mut Tensor<S>,
483 ) -> Result<()> {
484 Err(Error::DeviceError(
485 "metadata cast family execution is not implemented on RocmBackend in phase 1".into(),
486 ))
487 }
488
489 fn has_metadata_cast_support(_desc: MetadataCastPrimsDescriptor) -> bool {
490 false
491 }
492}
493
494impl TensorMetadataPrims for RocmBackend {
495 type Plan = MetadataPrimsDescriptor;
496 type Context = RocmContext;
497
498 fn plan(
499 _ctx: &mut RocmContext,
500 desc: &MetadataPrimsDescriptor,
501 _inputs: &[MetadataTensorRef<'_>],
502 _output: MetadataTensorMut<'_>,
503 ) -> Result<Self::Plan> {
504 Err(Error::DeviceError(format!(
505 "metadata family descriptor {desc:?} is not implemented on RocmBackend"
506 )))
507 }
508
509 fn execute(
510 _ctx: &mut RocmContext,
511 _plan: &Self::Plan,
512 _inputs: &[MetadataTensorRef<'_>],
513 _output: MetadataTensorMut<'_>,
514 ) -> Result<()> {
515 Err(Error::DeviceError(
516 "metadata family execution is not implemented on RocmBackend".into(),
517 ))
518 }
519
520 fn has_metadata_support(_desc: MetadataPrimsDescriptor) -> bool {
521 false
522 }
523}
524
525impl<S: Scalar> TensorSemiringCore<Standard<S>> for RocmBackend {
526 type Plan = RocmPlan<S>;
527 type Context = RocmContext;
528
529 fn plan(
530 _ctx: &mut RocmContext,
531 _desc: &SemiringCoreDescriptor,
532 _shapes: &[&[usize]],
533 ) -> Result<RocmPlan<S>> {
534 Err(Error::DeviceError(
535 "ROCm backend not available: load hipTENSOR library first".into(),
536 ))
537 }
538
539 fn execute(
540 _ctx: &mut RocmContext,
541 _plan: &RocmPlan<S>,
542 _alpha: S,
543 _inputs: &[&Tensor<S>],
544 _beta: S,
545 _output: &mut Tensor<S>,
546 ) -> Result<()> {
547 Err(Error::DeviceError(
548 "ROCm backend not available: load hipTENSOR library first".into(),
549 ))
550 }
551}
552
553impl<S: Scalar> TensorSemiringFastPath<Standard<S>> for RocmBackend {
554 type Plan = RocmPlan<S>;
555 type Context = RocmContext;
556
557 fn plan(
558 _ctx: &mut RocmContext,
559 _desc: &SemiringFastPathDescriptor,
560 _shapes: &[&[usize]],
561 ) -> Result<RocmPlan<S>> {
562 Err(Error::DeviceError(
563 "ROCm backend not available: load hipTENSOR library first".into(),
564 ))
565 }
566
567 fn execute(
568 _ctx: &mut RocmContext,
569 _plan: &RocmPlan<S>,
570 _alpha: S,
571 _inputs: &[&Tensor<S>],
572 _beta: S,
573 _output: &mut Tensor<S>,
574 ) -> Result<()> {
575 Err(Error::DeviceError(
576 "ROCm backend not available: load hipTENSOR library first".into(),
577 ))
578 }
579
580 fn has_fast_path(_desc: SemiringFastPathDescriptor) -> bool {
581 false
584 }
585}
586
// Unit tests for these stub backends live in the sibling `tests` module file;
// compiled only for `cargo test`.
#[cfg(test)]
mod tests;