1use num_complex::ComplexFloat;
2use tenferro_algebra::{Algebra, Conjugate, Scalar, Semiring};
3use tenferro_tensor::Tensor;
4
5use crate::{
6 CpuBackend, CpuContext, CudaBackend, CudaContext, RocmBackend, RocmContext,
7 TensorComplexRealPrims, TensorComplexScalePrims, TensorIndexingPrims, TensorMetadataContextFor,
8 TensorMetadataPrims, TensorScalarPrims, TensorSemiringCore, TensorSortPrims,
9};
10
11pub trait TensorSemiringContextFor<Alg: Semiring> {
32 type SemiringBackend: TensorSemiringCore<Alg, Context = Self>;
34}
35
36pub trait TensorScalarContextFor<Alg: Algebra> {
58 type ScalarBackend: TensorScalarPrims<Alg, Context = Self>;
60}
61
62pub trait TensorComplexRealContextFor<Input: ComplexFloat + Scalar> {
84 type ComplexRealBackend: TensorComplexRealPrims<Input, Context = Self, Real = Input::Real>;
86}
87
88pub trait TensorComplexScaleContextFor<Input: ComplexFloat + Scalar>
106where
107 Input::Real: Scalar + Send + Sync,
108{
109 type ComplexScaleBackend: TensorComplexScalePrims<Input, Context = Self>;
111}
112
113pub trait TensorResolveConjContextFor<T: Scalar + Conjugate> {
141 fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T>;
143}
144
145pub trait TensorIndexingContextFor<Alg: Algebra> {
167 type IndexingBackend: TensorIndexingPrims<Alg, Context = Self>;
169}
170
171pub trait TensorSortContextFor<Alg: Algebra>
193where
194 Alg::Scalar: PartialOrd,
195{
196 type SortBackend: TensorSortPrims<Alg, Context = Self>;
198}
199
200impl<Alg> TensorSemiringContextFor<Alg> for CpuContext
201where
202 Alg: Semiring,
203 CpuBackend: TensorSemiringCore<Alg, Context = CpuContext>,
204{
205 type SemiringBackend = CpuBackend;
206}
207
208impl<Alg> TensorScalarContextFor<Alg> for CpuContext
209where
210 Alg: Algebra,
211 CpuBackend: TensorScalarPrims<Alg, Context = CpuContext>,
212{
213 type ScalarBackend = CpuBackend;
214}
215
216impl<Input> TensorComplexRealContextFor<Input> for CpuContext
217where
218 Input: ComplexFloat + Scalar,
219 Input::Real: Scalar,
220 CpuBackend: TensorComplexRealPrims<Input, Context = CpuContext, Real = Input::Real>,
221{
222 type ComplexRealBackend = CpuBackend;
223}
224
225impl<T> TensorResolveConjContextFor<T> for CpuContext
226where
227 T: Scalar + Conjugate,
228{
229 fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T> {
230 CpuBackend::resolve_conj(ctx, src)
231 }
232}
233
234impl TensorMetadataContextFor for CpuContext
235where
236 CpuBackend: TensorMetadataPrims<Context = CpuContext>,
237{
238 type MetadataBackend = CpuBackend;
239}
240
241impl<Input> TensorComplexScaleContextFor<Input> for CpuContext
242where
243 Input: ComplexFloat + Scalar,
244 Input::Real: Scalar + Send + Sync,
245 CpuBackend: TensorComplexScalePrims<Input, Context = CpuContext>,
246{
247 type ComplexScaleBackend = CpuBackend;
248}
249
250impl<Alg> TensorIndexingContextFor<Alg> for CpuContext
251where
252 Alg: Algebra,
253 CpuBackend: TensorIndexingPrims<Alg, Context = CpuContext>,
254{
255 type IndexingBackend = CpuBackend;
256}
257
258impl<Alg> TensorSortContextFor<Alg> for CpuContext
259where
260 Alg: Algebra,
261 Alg::Scalar: PartialOrd,
262 CpuBackend: TensorSortPrims<Alg, Context = CpuContext>,
263{
264 type SortBackend = CpuBackend;
265}
266
267impl<Alg> TensorSemiringContextFor<Alg> for CudaContext
268where
269 Alg: Semiring,
270 CudaBackend: TensorSemiringCore<Alg, Context = CudaContext>,
271{
272 type SemiringBackend = CudaBackend;
273}
274
275impl<Alg> TensorScalarContextFor<Alg> for CudaContext
276where
277 Alg: Algebra,
278 CudaBackend: TensorScalarPrims<Alg, Context = CudaContext>,
279{
280 type ScalarBackend = CudaBackend;
281}
282
283impl<Input> TensorComplexRealContextFor<Input> for CudaContext
284where
285 Input: ComplexFloat + Scalar,
286 Input::Real: Scalar,
287 CudaBackend: TensorComplexRealPrims<Input, Context = CudaContext, Real = Input::Real>,
288{
289 type ComplexRealBackend = CudaBackend;
290}
291
292impl<T> TensorResolveConjContextFor<T> for CudaContext
293where
294 T: Scalar + Conjugate + 'static,
295{
296 fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T> {
297 CudaBackend::resolve_conj(ctx, src)
298 }
299}
300
301impl TensorMetadataContextFor for CudaContext
302where
303 CudaBackend: TensorMetadataPrims<Context = CudaContext>,
304{
305 type MetadataBackend = CudaBackend;
306}
307
308impl<Input> TensorComplexScaleContextFor<Input> for CudaContext
309where
310 Input: ComplexFloat + Scalar,
311 Input::Real: Scalar + Send + Sync,
312 CudaBackend: TensorComplexScalePrims<Input, Context = CudaContext>,
313{
314 type ComplexScaleBackend = CudaBackend;
315}
316
317impl<Alg> TensorIndexingContextFor<Alg> for CudaContext
318where
319 Alg: Algebra,
320 CudaBackend: TensorIndexingPrims<Alg, Context = CudaContext>,
321{
322 type IndexingBackend = CudaBackend;
323}
324
325impl<Alg> TensorSortContextFor<Alg> for CudaContext
326where
327 Alg: Algebra,
328 Alg::Scalar: PartialOrd,
329 CudaBackend: TensorSortPrims<Alg, Context = CudaContext>,
330{
331 type SortBackend = CudaBackend;
332}
333
334impl<Alg> TensorSemiringContextFor<Alg> for RocmContext
335where
336 Alg: Semiring,
337 RocmBackend: TensorSemiringCore<Alg, Context = RocmContext>,
338{
339 type SemiringBackend = RocmBackend;
340}
341
342impl<Alg> TensorScalarContextFor<Alg> for RocmContext
343where
344 Alg: Algebra,
345 RocmBackend: TensorScalarPrims<Alg, Context = RocmContext>,
346{
347 type ScalarBackend = RocmBackend;
348}
349
350impl TensorMetadataContextFor for RocmContext
351where
352 RocmBackend: TensorMetadataPrims<Context = RocmContext>,
353{
354 type MetadataBackend = RocmBackend;
355}
356
357impl<Input> TensorComplexRealContextFor<Input> for RocmContext
358where
359 Input: ComplexFloat + Scalar,
360 Input::Real: Scalar,
361 RocmBackend: TensorComplexRealPrims<Input, Context = RocmContext, Real = Input::Real>,
362{
363 type ComplexRealBackend = RocmBackend;
364}
365
366impl<T> TensorResolveConjContextFor<T> for RocmContext
367where
368 T: Scalar + Conjugate,
369{
370 fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T> {
371 RocmBackend::resolve_conj(ctx, src)
372 }
373}
374
375impl<Input> TensorComplexScaleContextFor<Input> for RocmContext
376where
377 Input: ComplexFloat + Scalar,
378 Input::Real: Scalar + Send + Sync,
379 RocmBackend: TensorComplexScalePrims<Input, Context = RocmContext>,
380{
381 type ComplexScaleBackend = RocmBackend;
382}
383
384impl<Alg> TensorIndexingContextFor<Alg> for RocmContext
385where
386 Alg: Algebra,
387 RocmBackend: TensorIndexingPrims<Alg, Context = RocmContext>,
388{
389 type IndexingBackend = RocmBackend;
390}
391
392impl<Alg> TensorSortContextFor<Alg> for RocmContext
393where
394 Alg: Algebra,
395 Alg::Scalar: PartialOrd,
396 RocmBackend: TensorSortPrims<Alg, Context = RocmContext>,
397{
398 type SortBackend = RocmBackend;
399}