tenferro_prims/families/
rng.rs1use tenferro_algebra::{Algebra, Standard};
2use tenferro_device::{Error, Generator, Result};
3use tenferro_tensor::Tensor;
4
5#[cfg(not(feature = "cuda"))]
6use crate::{CudaBackend, CudaContext, RocmBackend, RocmContext};
7#[cfg(feature = "cuda")]
8use crate::{RocmBackend, RocmContext};
9
/// Selects which kind of random values an RNG primitive should produce.
///
/// A descriptor is handed to [`TensorRngPrims::plan`] (by reference) and to
/// [`TensorRngPrims::has_rng_support`] (by value); backends decide which
/// variants they can serve.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RngPrimsDescriptor {
    /// Uniformly distributed samples.
    ///
    /// NOTE(review): the sampling interval (e.g. `[0, 1)`) is defined by the
    /// backend implementations, not here — confirm before relying on it.
    Uniform,
    /// Normally distributed samples.
    ///
    /// NOTE(review): presumably a standard normal (mean 0, std 1) — confirm
    /// against the backend implementations.
    Normal,
    /// Integer samples bounded by `low` and `high`.
    ///
    /// NOTE(review): whether `high` is inclusive or exclusive is not visible
    /// in this file — confirm against the backend implementations.
    Integer { low: i32, high: i32 },
}
34
35pub trait TensorRngPrims<Alg: Algebra> {
71 type Plan;
73 type Context;
75
76 fn plan(
78 ctx: &mut Self::Context,
79 desc: &RngPrimsDescriptor,
80 shapes: &[&[usize]],
81 ) -> Result<Self::Plan>;
82
83 fn execute(
85 ctx: &mut Self::Context,
86 plan: &Self::Plan,
87 generator: &mut Generator,
88 output: &mut Tensor<Alg::Scalar>,
89 ) -> Result<()>;
90
91 fn has_rng_support(desc: RngPrimsDescriptor) -> bool;
93}
94
95#[cfg(not(feature = "cuda"))]
96impl TensorRngPrims<Standard<f64>> for CudaBackend {
97 type Plan = (RngPrimsDescriptor, Vec<usize>);
98 type Context = CudaContext;
99
100 fn plan(
101 _ctx: &mut Self::Context,
102 desc: &RngPrimsDescriptor,
103 _shapes: &[&[usize]],
104 ) -> Result<Self::Plan> {
105 Err(Error::InvalidArgument(format!(
106 "RNG descriptor {desc:?} is not implemented on CudaBackend in phase 1"
107 )))
108 }
109
110 fn execute(
111 _ctx: &mut Self::Context,
112 _plan: &Self::Plan,
113 _generator: &mut Generator,
114 _output: &mut Tensor<f64>,
115 ) -> Result<()> {
116 Err(Error::InvalidArgument(
117 "RNG execution is not implemented on CudaBackend in phase 1".into(),
118 ))
119 }
120
121 fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
122 false
123 }
124}
125
126#[cfg(not(feature = "cuda"))]
127impl TensorRngPrims<Standard<i32>> for CudaBackend {
128 type Plan = (RngPrimsDescriptor, Vec<usize>);
129 type Context = CudaContext;
130
131 fn plan(
132 _ctx: &mut Self::Context,
133 desc: &RngPrimsDescriptor,
134 _shapes: &[&[usize]],
135 ) -> Result<Self::Plan> {
136 Err(Error::InvalidArgument(format!(
137 "RNG descriptor {desc:?} is not implemented on CudaBackend in phase 1"
138 )))
139 }
140
141 fn execute(
142 _ctx: &mut Self::Context,
143 _plan: &Self::Plan,
144 _generator: &mut Generator,
145 _output: &mut Tensor<i32>,
146 ) -> Result<()> {
147 Err(Error::InvalidArgument(
148 "RNG execution is not implemented on CudaBackend in phase 1".into(),
149 ))
150 }
151
152 fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
153 false
154 }
155}
156
157impl TensorRngPrims<Standard<f64>> for RocmBackend {
158 type Plan = (RngPrimsDescriptor, Vec<usize>);
159 type Context = RocmContext;
160
161 fn plan(
162 _ctx: &mut Self::Context,
163 desc: &RngPrimsDescriptor,
164 _shapes: &[&[usize]],
165 ) -> Result<Self::Plan> {
166 Err(Error::InvalidArgument(format!(
167 "RNG descriptor {desc:?} is not implemented on RocmBackend in phase 1"
168 )))
169 }
170
171 fn execute(
172 _ctx: &mut Self::Context,
173 _plan: &Self::Plan,
174 _generator: &mut Generator,
175 _output: &mut Tensor<f64>,
176 ) -> Result<()> {
177 Err(Error::InvalidArgument(
178 "RNG execution is not implemented on RocmBackend in phase 1".into(),
179 ))
180 }
181
182 fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
183 false
184 }
185}
186
187impl TensorRngPrims<Standard<i32>> for RocmBackend {
188 type Plan = (RngPrimsDescriptor, Vec<usize>);
189 type Context = RocmContext;
190
191 fn plan(
192 _ctx: &mut Self::Context,
193 desc: &RngPrimsDescriptor,
194 _shapes: &[&[usize]],
195 ) -> Result<Self::Plan> {
196 Err(Error::InvalidArgument(format!(
197 "RNG descriptor {desc:?} is not implemented on RocmBackend in phase 1"
198 )))
199 }
200
201 fn execute(
202 _ctx: &mut Self::Context,
203 _plan: &Self::Plan,
204 _generator: &mut Generator,
205 _output: &mut Tensor<i32>,
206 ) -> Result<()> {
207 Err(Error::InvalidArgument(
208 "RNG execution is not implemented on RocmBackend in phase 1".into(),
209 ))
210 }
211
212 fn has_rng_support(_desc: RngPrimsDescriptor) -> bool {
213 false
214 }
215}