tenferro_prims/families/context.rs

1use num_complex::ComplexFloat;
2use tenferro_algebra::{Algebra, Conjugate, Scalar, Semiring};
3use tenferro_tensor::Tensor;
4
5use crate::{
6    CpuBackend, CpuContext, CudaBackend, CudaContext, RocmBackend, RocmContext,
7    TensorComplexRealPrims, TensorComplexScalePrims, TensorIndexingPrims, TensorMetadataContextFor,
8    TensorMetadataPrims, TensorScalarPrims, TensorSemiringCore, TensorSortPrims,
9};
10
11/// Bridge trait that binds a semiring execution context to its backend.
12///
13/// High-level crates use this trait to stay generic over runtime context types
14/// while still dispatching semiring execution through the correct backend
15/// marker type.
16///
17/// # Examples
18///
19/// ```ignore
20/// use tenferro_prims::{CpuContext, TensorSemiringContextFor};
21///
22/// fn accepts_context<C>(_: &mut C)
23/// where
24///     C: TensorSemiringContextFor<tenferro_algebra::Standard<f64>>,
25/// {
26/// }
27///
28/// let mut ctx = CpuContext::new(1);
29/// accepts_context(&mut ctx);
30/// ```
31pub trait TensorSemiringContextFor<Alg: Semiring> {
32    /// Backend associated with this context for the given algebra family.
33    type SemiringBackend: TensorSemiringCore<Alg, Context = Self>;
34}
35
36/// Bridge trait that binds a scalar-family execution context to its backend.
37///
38/// High-level crates use this trait to stay generic over runtime context types
39/// while dispatching pointwise and reduction scalar families through the
40/// correct backend marker type.
41///
42/// # Examples
43///
44/// ```ignore
45/// use tenferro_algebra::Standard;
46/// use tenferro_prims::{CpuContext, TensorScalarContextFor};
47///
48/// fn accepts_context<C>(_: &mut C)
49/// where
50///     C: TensorScalarContextFor<Standard<f64>>,
51/// {
52/// }
53///
54/// let mut ctx = CpuContext::new(1);
55/// accepts_context(&mut ctx);
56/// ```
57pub trait TensorScalarContextFor<Alg: Algebra> {
58    /// Backend associated with this context for the scalar family.
59    type ScalarBackend: TensorScalarPrims<Alg, Context = Self>;
60}
61
62/// Bridge trait that binds a complex-to-real execution context to its backend.
63///
64/// High-level crates use this trait to stay generic over runtime context types
65/// while dispatching cross-dtype complex-to-real families through the correct
66/// backend marker type.
67///
68/// # Examples
69///
70/// ```ignore
71/// use num_complex::Complex64;
72/// use tenferro_prims::{CpuContext, TensorComplexRealContextFor};
73///
74/// fn accepts_context<C>(_: &mut C)
75/// where
76///     C: TensorComplexRealContextFor<Complex64>,
77/// {
78/// }
79///
80/// let mut ctx = CpuContext::new(1);
81/// accepts_context(&mut ctx);
82/// ```
83pub trait TensorComplexRealContextFor<Input: ComplexFloat + Scalar> {
84    /// Backend associated with this context for the complex-to-real family.
85    type ComplexRealBackend: TensorComplexRealPrims<Input, Context = Self, Real = Input::Real>;
86}
87
88/// Bridge trait that binds a complex-by-real execution context to its backend.
89///
90/// # Examples
91///
92/// ```ignore
93/// use num_complex::Complex64;
94/// use tenferro_prims::{CpuContext, TensorComplexScaleContextFor};
95///
96/// fn accepts_context<C>(_: &mut C)
97/// where
98///     C: TensorComplexScaleContextFor<Complex64>,
99/// {
100/// }
101///
102/// let mut ctx = CpuContext::new(1);
103/// accepts_context(&mut ctx);
104/// ```
105pub trait TensorComplexScaleContextFor<Input: ComplexFloat + Scalar>
106where
107    Input::Real: Scalar + Send + Sync,
108{
109    /// Backend associated with the complex-by-real family.
110    type ComplexScaleBackend: TensorComplexScalePrims<Input, Context = Self>;
111}
112
113/// Bridge trait for backend-specific lazy-conjugation resolution.
114///
115/// High-level crates use this trait to stay generic over runtime context types
116/// while still materializing unresolved conjugation through the correct backend
117/// utility.
118///
119/// # Examples
120///
121/// ```ignore
122/// use num_complex::Complex64;
123/// use tenferro_prims::{CpuContext, TensorResolveConjContextFor};
124/// use tenferro_tensor::{MemoryOrder, Tensor};
125///
126/// let mut ctx = CpuContext::new(1);
127/// let base = Tensor::from_slice(
128///     &[Complex64::new(1.0, 2.0), Complex64::new(3.0, -4.0)],
129///     &[2],
130///     MemoryOrder::ColumnMajor,
131/// )
132/// .unwrap();
133/// let lazy = base.conj();
134/// let resolved = <CpuContext as TensorResolveConjContextFor<Complex64>>::resolve_conj(
135///     &mut ctx,
136///     &lazy,
137/// );
138/// assert!(!resolved.is_conjugated());
139/// ```
140pub trait TensorResolveConjContextFor<T: Scalar + Conjugate> {
141    /// Materialize a lazily-conjugated tensor using the backend tied to `Self`.
142    fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T>;
143}
144
145/// Bridge trait that binds an indexing-family execution context to its backend.
146///
147/// High-level crates use this trait to stay generic over runtime context types
148/// while dispatching index-based selection, gathering, and scattering through
149/// the correct backend marker type.
150///
151/// # Examples
152///
153/// ```ignore
154/// use tenferro_algebra::Standard;
155/// use tenferro_prims::{CpuContext, TensorIndexingContextFor};
156///
157/// fn accepts_context<C>(_: &mut C)
158/// where
159///     C: TensorIndexingContextFor<Standard<f64>>,
160/// {
161/// }
162///
163/// let mut ctx = CpuContext::new(1);
164/// accepts_context(&mut ctx);
165/// ```
166pub trait TensorIndexingContextFor<Alg: Algebra> {
167    /// Backend associated with this context for the indexing family.
168    type IndexingBackend: TensorIndexingPrims<Alg, Context = Self>;
169}
170
171/// Bridge trait that binds a sort-family execution context to its backend.
172///
173/// High-level crates use this trait to stay generic over runtime context types
174/// while dispatching sort, argsort, and top-k operations through the correct
175/// backend marker type.
176///
177/// # Examples
178///
179/// ```ignore
180/// use tenferro_algebra::Standard;
181/// use tenferro_prims::{CpuContext, TensorSortContextFor};
182///
183/// fn accepts_context<C>(_: &mut C)
184/// where
185///     C: TensorSortContextFor<Standard<f64>>,
186/// {
187/// }
188///
189/// let mut ctx = CpuContext::new(1);
190/// accepts_context(&mut ctx);
191/// ```
192pub trait TensorSortContextFor<Alg: Algebra>
193where
194    Alg::Scalar: PartialOrd,
195{
196    /// Backend associated with this context for the sort family.
197    type SortBackend: TensorSortPrims<Alg, Context = Self>;
198}
199
200impl<Alg> TensorSemiringContextFor<Alg> for CpuContext
201where
202    Alg: Semiring,
203    CpuBackend: TensorSemiringCore<Alg, Context = CpuContext>,
204{
205    type SemiringBackend = CpuBackend;
206}
207
208impl<Alg> TensorScalarContextFor<Alg> for CpuContext
209where
210    Alg: Algebra,
211    CpuBackend: TensorScalarPrims<Alg, Context = CpuContext>,
212{
213    type ScalarBackend = CpuBackend;
214}
215
216impl<Input> TensorComplexRealContextFor<Input> for CpuContext
217where
218    Input: ComplexFloat + Scalar,
219    Input::Real: Scalar,
220    CpuBackend: TensorComplexRealPrims<Input, Context = CpuContext, Real = Input::Real>,
221{
222    type ComplexRealBackend = CpuBackend;
223}
224
225impl<T> TensorResolveConjContextFor<T> for CpuContext
226where
227    T: Scalar + Conjugate,
228{
229    fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T> {
230        CpuBackend::resolve_conj(ctx, src)
231    }
232}
233
234impl TensorMetadataContextFor for CpuContext
235where
236    CpuBackend: TensorMetadataPrims<Context = CpuContext>,
237{
238    type MetadataBackend = CpuBackend;
239}
240
241impl<Input> TensorComplexScaleContextFor<Input> for CpuContext
242where
243    Input: ComplexFloat + Scalar,
244    Input::Real: Scalar + Send + Sync,
245    CpuBackend: TensorComplexScalePrims<Input, Context = CpuContext>,
246{
247    type ComplexScaleBackend = CpuBackend;
248}
249
250impl<Alg> TensorIndexingContextFor<Alg> for CpuContext
251where
252    Alg: Algebra,
253    CpuBackend: TensorIndexingPrims<Alg, Context = CpuContext>,
254{
255    type IndexingBackend = CpuBackend;
256}
257
258impl<Alg> TensorSortContextFor<Alg> for CpuContext
259where
260    Alg: Algebra,
261    Alg::Scalar: PartialOrd,
262    CpuBackend: TensorSortPrims<Alg, Context = CpuContext>,
263{
264    type SortBackend = CpuBackend;
265}
266
267impl<Alg> TensorSemiringContextFor<Alg> for CudaContext
268where
269    Alg: Semiring,
270    CudaBackend: TensorSemiringCore<Alg, Context = CudaContext>,
271{
272    type SemiringBackend = CudaBackend;
273}
274
275impl<Alg> TensorScalarContextFor<Alg> for CudaContext
276where
277    Alg: Algebra,
278    CudaBackend: TensorScalarPrims<Alg, Context = CudaContext>,
279{
280    type ScalarBackend = CudaBackend;
281}
282
283impl<Input> TensorComplexRealContextFor<Input> for CudaContext
284where
285    Input: ComplexFloat + Scalar,
286    Input::Real: Scalar,
287    CudaBackend: TensorComplexRealPrims<Input, Context = CudaContext, Real = Input::Real>,
288{
289    type ComplexRealBackend = CudaBackend;
290}
291
292impl<T> TensorResolveConjContextFor<T> for CudaContext
293where
294    T: Scalar + Conjugate + 'static,
295{
296    fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T> {
297        CudaBackend::resolve_conj(ctx, src)
298    }
299}
300
301impl TensorMetadataContextFor for CudaContext
302where
303    CudaBackend: TensorMetadataPrims<Context = CudaContext>,
304{
305    type MetadataBackend = CudaBackend;
306}
307
308impl<Input> TensorComplexScaleContextFor<Input> for CudaContext
309where
310    Input: ComplexFloat + Scalar,
311    Input::Real: Scalar + Send + Sync,
312    CudaBackend: TensorComplexScalePrims<Input, Context = CudaContext>,
313{
314    type ComplexScaleBackend = CudaBackend;
315}
316
317impl<Alg> TensorIndexingContextFor<Alg> for CudaContext
318where
319    Alg: Algebra,
320    CudaBackend: TensorIndexingPrims<Alg, Context = CudaContext>,
321{
322    type IndexingBackend = CudaBackend;
323}
324
325impl<Alg> TensorSortContextFor<Alg> for CudaContext
326where
327    Alg: Algebra,
328    Alg::Scalar: PartialOrd,
329    CudaBackend: TensorSortPrims<Alg, Context = CudaContext>,
330{
331    type SortBackend = CudaBackend;
332}
333
334impl<Alg> TensorSemiringContextFor<Alg> for RocmContext
335where
336    Alg: Semiring,
337    RocmBackend: TensorSemiringCore<Alg, Context = RocmContext>,
338{
339    type SemiringBackend = RocmBackend;
340}
341
342impl<Alg> TensorScalarContextFor<Alg> for RocmContext
343where
344    Alg: Algebra,
345    RocmBackend: TensorScalarPrims<Alg, Context = RocmContext>,
346{
347    type ScalarBackend = RocmBackend;
348}
349
350impl TensorMetadataContextFor for RocmContext
351where
352    RocmBackend: TensorMetadataPrims<Context = RocmContext>,
353{
354    type MetadataBackend = RocmBackend;
355}
356
357impl<Input> TensorComplexRealContextFor<Input> for RocmContext
358where
359    Input: ComplexFloat + Scalar,
360    Input::Real: Scalar,
361    RocmBackend: TensorComplexRealPrims<Input, Context = RocmContext, Real = Input::Real>,
362{
363    type ComplexRealBackend = RocmBackend;
364}
365
366impl<T> TensorResolveConjContextFor<T> for RocmContext
367where
368    T: Scalar + Conjugate,
369{
370    fn resolve_conj(ctx: &mut Self, src: &Tensor<T>) -> Tensor<T> {
371        RocmBackend::resolve_conj(ctx, src)
372    }
373}
374
375impl<Input> TensorComplexScaleContextFor<Input> for RocmContext
376where
377    Input: ComplexFloat + Scalar,
378    Input::Real: Scalar + Send + Sync,
379    RocmBackend: TensorComplexScalePrims<Input, Context = RocmContext>,
380{
381    type ComplexScaleBackend = RocmBackend;
382}
383
384impl<Alg> TensorIndexingContextFor<Alg> for RocmContext
385where
386    Alg: Algebra,
387    RocmBackend: TensorIndexingPrims<Alg, Context = RocmContext>,
388{
389    type IndexingBackend = RocmBackend;
390}
391
392impl<Alg> TensorSortContextFor<Alg> for RocmContext
393where
394    Alg: Algebra,
395    Alg::Scalar: PartialOrd,
396    RocmBackend: TensorSortPrims<Alg, Context = RocmContext>,
397{
398    type SortBackend = RocmBackend;
399}