tenferro_prims/families/
rng.rs

1use tenferro_algebra::{Algebra, Standard};
2use tenferro_device::{Error, Generator, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext, RocmBackend, RocmContext};
7#[cfg(feature = "cuda")]
8use crate::{RocmBackend, RocmContext};
9
/// Random-number generation descriptors for dense eager tensor construction.
///
/// A descriptor names the distribution to sample from; the output shape lives
/// on the destination tensor, not here. The derives (`Eq`, `Hash`) make the
/// descriptor usable as a hash-map key, e.g. for caching plans per descriptor.
///
/// # Examples
///
/// ```rust
/// use tenferro_prims::RngPrimsDescriptor;
///
/// let desc = RngPrimsDescriptor::Uniform;
/// assert!(matches!(desc, RngPrimsDescriptor::Uniform));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RngPrimsDescriptor {
    /// Fill the output tensor with samples from the half-open interval `[0, 1)`.
    Uniform,
    /// Fill the output tensor with standard normal samples (mean 0, variance 1
    /// is the conventional reading of "standard normal" — confirm against the
    /// CPU kernel once it lands).
    Normal,
    /// Fill the output tensor with integer samples from the half-open interval `[low, high)`.
    Integer {
        /// Inclusive lower bound.
        low: i32,
        /// Exclusive upper bound.
        high: i32,
    },
}
34
/// Tensor RNG execution family.
///
/// The family is plan/execute based like the rest of `tenferro-prims`, but the
/// plan is intentionally small: it only records the descriptor and the output
/// shape so the execution side can validate the destination tensor before
/// writing into it.
///
/// # Examples
///
/// ```ignore
/// use tenferro_algebra::Standard;
/// use tenferro_device::Generator;
/// use tenferro_prims::{CpuBackend, CpuContext, RngPrimsDescriptor, TensorRngPrims};
/// use tenferro_tensor::{MemoryOrder, Tensor};
///
/// let mut ctx = CpuContext::new(1);
/// let mut generator = Generator::cpu(1234);
/// let desc = RngPrimsDescriptor::Uniform;
/// let mut output = Tensor::<f64>::zeros(
///     &[8],
///     tenferro_device::LogicalMemorySpace::MainMemory,
///     MemoryOrder::ColumnMajor,
/// ).unwrap();
/// let plan = <CpuBackend as TensorRngPrims<Standard<f64>>>::plan(
///     &mut ctx,
///     &desc,
///     &[output.dims()],
/// ).unwrap();
/// <CpuBackend as TensorRngPrims<Standard<f64>>>::execute(
///     &mut ctx,
///     &plan,
///     &mut generator,
///     &mut output,
/// ).unwrap();
/// ```
pub trait TensorRngPrims<Alg: Algebra> {
    /// Backend-specific execution plan.
    type Plan;
    /// Backend execution context.
    type Context;

    /// Plan a tensor RNG operation for the given output shape.
    ///
    /// `shapes` is a slice of dim-lists; the doc example above passes exactly
    /// one entry (`&[output.dims()]`) — NOTE(review): confirm whether more
    /// than one shape is ever meaningful here.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend cannot service the descriptor (the
    /// phase-1 GPU stubs in this file reject every descriptor this way).
    fn plan(
        ctx: &mut Self::Context,
        desc: &RngPrimsDescriptor,
        shapes: &[&[usize]],
    ) -> Result<Self::Plan>;

    /// Execute a previously planned RNG operation.
    ///
    /// Draws samples via `generator` and writes them into `output`.
    ///
    /// # Errors
    ///
    /// Returns an error when execution is unsupported on the backend, or —
    /// per the trait-level docs — when `output` fails validation against the
    /// shape recorded in `plan`.
    fn execute(
        ctx: &mut Self::Context,
        plan: &Self::Plan,
        generator: &mut Generator,
        output: &mut Tensor<Alg::Scalar>,
    ) -> Result<()>;

    /// Report whether the backend advertises support for the given descriptor.
    ///
    /// Takes the descriptor by value (it is `Clone` and small); a `true`
    /// result means `plan` is expected to succeed for that descriptor.
    fn has_rng_support(desc: RngPrimsDescriptor) -> bool;
}
94
95#[cfg(not(feature = "cuda"))]
96impl TensorRngPrims<Standard<f64>> for CudaBackend {
97    type Plan = (RngPrimsDescriptor, Vec<usize>);
98    type Context = CudaContext;
99
100    fn plan(
101        _ctx: &mut Self::Context,
102        desc: &RngPrimsDescriptor,
103        _shapes: &[&[usize]],
104    ) -> Result<Self::Plan> {
105        Err(Error::InvalidArgument(format!(
106            "RNG descriptor {desc:?} is not implemented on CudaBackend in phase 1"
107        )))
108    }
109
110    fn execute(
111        _ctx: &mut Self::Context,
112        _plan: &Self::Plan,
113        _generator: &mut Generator,
114        _output: &mut Tensor<f64>,
115    ) -> Result<()> {
116        Err(Error::InvalidArgument(
117            "RNG execution is not implemented on CudaBackend in phase 1".into(),
118        ))
119    }
120
121    fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
122        false
123    }
124}
125
126#[cfg(not(feature = "cuda"))]
127impl TensorRngPrims<Standard<i32>> for CudaBackend {
128    type Plan = (RngPrimsDescriptor, Vec<usize>);
129    type Context = CudaContext;
130
131    fn plan(
132        _ctx: &mut Self::Context,
133        desc: &RngPrimsDescriptor,
134        _shapes: &[&[usize]],
135    ) -> Result<Self::Plan> {
136        Err(Error::InvalidArgument(format!(
137            "RNG descriptor {desc:?} is not implemented on CudaBackend in phase 1"
138        )))
139    }
140
141    fn execute(
142        _ctx: &mut Self::Context,
143        _plan: &Self::Plan,
144        _generator: &mut Generator,
145        _output: &mut Tensor<i32>,
146    ) -> Result<()> {
147        Err(Error::InvalidArgument(
148            "RNG execution is not implemented on CudaBackend in phase 1".into(),
149        ))
150    }
151
152    fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
153        false
154    }
155}
156
157impl TensorRngPrims<Standard<f64>> for RocmBackend {
158    type Plan = (RngPrimsDescriptor, Vec<usize>);
159    type Context = RocmContext;
160
161    fn plan(
162        _ctx: &mut Self::Context,
163        desc: &RngPrimsDescriptor,
164        _shapes: &[&[usize]],
165    ) -> Result<Self::Plan> {
166        Err(Error::InvalidArgument(format!(
167            "RNG descriptor {desc:?} is not implemented on RocmBackend in phase 1"
168        )))
169    }
170
171    fn execute(
172        _ctx: &mut Self::Context,
173        _plan: &Self::Plan,
174        _generator: &mut Generator,
175        _output: &mut Tensor<f64>,
176    ) -> Result<()> {
177        Err(Error::InvalidArgument(
178            "RNG execution is not implemented on RocmBackend in phase 1".into(),
179        ))
180    }
181
182    fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
183        false
184    }
185}
186
187impl TensorRngPrims<Standard<i32>> for RocmBackend {
188    type Plan = (RngPrimsDescriptor, Vec<usize>);
189    type Context = RocmContext;
190
191    fn plan(
192        _ctx: &mut Self::Context,
193        desc: &RngPrimsDescriptor,
194        _shapes: &[&[usize]],
195    ) -> Result<Self::Plan> {
196        Err(Error::InvalidArgument(format!(
197            "RNG descriptor {desc:?} is not implemented on RocmBackend in phase 1"
198        )))
199    }
200
201    fn execute(
202        _ctx: &mut Self::Context,
203        _plan: &Self::Plan,
204        _generator: &mut Generator,
205        _output: &mut Tensor<i32>,
206    ) -> Result<()> {
207        Err(Error::InvalidArgument(
208            "RNG execution is not implemented on RocmBackend in phase 1".into(),
209        ))
210    }
211
212    fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
213        false
214    }
215}