tenferro_device/
generator.rs1use std::sync::atomic::{AtomicU64, Ordering};
2use std::time::{SystemTime, UNIX_EPOCH};
3use std::{
4 convert::TryFrom,
5 f64::consts::TAU,
6 sync::{Mutex, OnceLock},
7};
8
9use rand_core::Rng;
10
11use crate::{Error, Result};
12
/// Reports whether the zero-based CUDA device index `device_id` exists.
///
/// Returns `false` when the device-count query fails, when the query panics
/// (the call is wrapped in `catch_unwind`), when the reported count cannot be
/// represented as a `usize`, or when `device_id` is not below the count.
#[cfg(feature = "cuda")]
fn cuda_device_zero_based_is_available(device_id: usize) -> bool {
    // NOTE(review): the unwind guard presumably covers cudarc panicking when the
    // CUDA runtime library cannot be loaded -- confirm against cudarc's behavior.
    std::panic::catch_unwind(|| {
        cudarc::runtime::result::device::get_count()
            .ok()
            // A checked conversion keeps a negative count from wrapping into a huge
            // `usize` (which an `as` cast would do) and thereby accepting any id.
            .and_then(|count| usize::try_from(count).ok())
            .map_or(false, |count| device_id < count)
    })
    .unwrap_or(false)
}
22
/// Builds a best-effort distinct seed for generators created without an explicit one.
///
/// XORs together four ingredients, each rotated by a different amount:
/// the wall-clock time in nanoseconds since the Unix epoch (truncated to 64 bits,
/// or `0` when the clock reads earlier than the epoch), the process id, a
/// process-wide call counter, and the caller-supplied `extra` value.
fn default_seed(extra: u64) -> u64 {
    static CALLS: AtomicU64 = AtomicU64::new(0);

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    let process = u64::from(std::process::id());
    let call_index = CALLS.fetch_add(1, Ordering::Relaxed);

    [
        nanos,
        process.rotate_left(17),
        call_index.rotate_left(33),
        extra.rotate_left(7),
    ]
    .iter()
    .fold(0u64, |mixed, part| mixed ^ part)
}
33
34#[derive(Debug)]
35struct GeneratorState {
36 engine: mt19937::MT19937,
37 cached_normal: Option<f64>,
38}
39
40impl GeneratorState {
41 fn from_seed(seed: u64) -> Self {
42 let low = seed as u32;
43 let high = (seed >> 32) as u32;
44 let seed_words = if high == 0 {
45 vec![low]
46 } else {
47 vec![low, high]
48 };
49 Self {
50 engine: mt19937::MT19937::new_with_slice_seed(&seed_words),
51 cached_normal: None,
52 }
53 }
54}
55
56#[derive(Debug)]
73pub struct Generator {
74 state: GeneratorState,
75 #[cfg(feature = "cuda")]
76 device_id: Option<usize>,
77 #[cfg(feature = "cuda")]
78 seed: u64,
79 #[cfg(feature = "cuda")]
80 offset: u64,
81}
82
83impl Generator {
84 pub fn cpu(seed: u64) -> Self {
94 Self {
95 state: GeneratorState::from_seed(seed),
96 #[cfg(feature = "cuda")]
97 device_id: None,
98 #[cfg(feature = "cuda")]
99 seed,
100 #[cfg(feature = "cuda")]
101 offset: 0,
102 }
103 }
104
105 #[cfg(feature = "cuda")]
118 pub fn cuda(device_id: usize, seed: u64) -> Result<Self> {
119 if !cuda_device_zero_based_is_available(device_id) {
120 return Err(Error::DeviceError(format!(
121 "CUDA generator requires available device {device_id}"
122 )));
123 }
124 Ok(Self {
125 state: GeneratorState::from_seed(seed),
126 device_id: Some(device_id),
127 seed,
128 offset: 0,
129 })
130 }
131
132 #[cfg(feature = "cuda")]
133 pub(crate) fn cuda_seed_and_offset(&self, expected_device_id: usize) -> Result<(u64, u64)> {
134 match self.device_id {
135 Some(device_id) if device_id == expected_device_id => Ok((self.seed, self.offset)),
136 Some(device_id) => Err(Error::DeviceError(format!(
137 "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
138 ))),
139 None => Err(Error::InvalidArgument(
140 "CPU generator cannot drive CUDA RNG execution".into(),
141 )),
142 }
143 }
144
145 #[cfg(feature = "cuda")]
146 pub(crate) fn advance_cuda_offset(
147 &mut self,
148 expected_device_id: usize,
149 delta: u64,
150 ) -> Result<()> {
151 match self.device_id {
152 Some(device_id) if device_id == expected_device_id => {
153 self.offset = self
154 .offset
155 .checked_add(delta)
156 .ok_or_else(|| Error::DeviceError("CUDA generator offset overflow".into()))?;
157 Ok(())
158 }
159 Some(device_id) => Err(Error::DeviceError(format!(
160 "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
161 ))),
162 None => Err(Error::InvalidArgument(
163 "CPU generator cannot drive CUDA RNG execution".into(),
164 )),
165 }
166 }
167
168 pub fn sample_uniform_f64(&mut self) -> f64 {
180 mt19937::gen_res53(&mut self.state.engine)
181 }
182
183 pub fn sample_standard_normal_f64(&mut self) -> f64 {
194 if let Some(sample) = self.state.cached_normal.take() {
195 return sample;
196 }
197
198 let u1 = self.sample_uniform_f64().max(f64::MIN_POSITIVE);
199 let u2 = self.sample_uniform_f64();
200 let radius = (-2.0 * u1.ln()).sqrt();
201 let theta = TAU * u2;
202 let z0 = radius * theta.cos();
203 let z1 = radius * theta.sin();
204 self.state.cached_normal = Some(z1);
205 z0
206 }
207
208 pub fn sample_integer_i32(&mut self, low: i32, high: i32) -> Result<i32> {
224 if low >= high {
225 return Err(Error::InvalidArgument(format!(
226 "invalid integer sample range [{low}, {high})"
227 )));
228 }
229
230 let span = i64::from(high) - i64::from(low);
231 let span_u64 = u64::try_from(span).map_err(|_| {
232 Error::InvalidArgument(format!("integer sample span {span} does not fit into u64"))
233 })?;
234 let threshold = u64::MAX - (u64::MAX % span_u64);
235
236 loop {
237 let candidate = self.state.engine.next_u64();
238 if candidate < threshold {
239 let value = i64::from(low)
240 + i64::try_from(candidate % span_u64).map_err(|_| {
241 Error::InvalidArgument("integer sample conversion overflow".into())
242 })?;
243 return i32::try_from(value).map_err(|_| {
244 Error::InvalidArgument(format!("integer sample {value} does not fit into i32"))
245 });
246 }
247 }
248 }
249}
250
251fn default_cpu_generator() -> &'static Mutex<Generator> {
252 static DEFAULT_CPU_GENERATOR: OnceLock<Mutex<Generator>> = OnceLock::new();
253 DEFAULT_CPU_GENERATOR.get_or_init(|| Mutex::new(Generator::cpu(default_seed(0))))
254}
255
/// Returns the process-wide table of default CUDA generators, keyed by device id.
///
/// The table is created empty on first use and is shared behind a mutex.
#[cfg(feature = "cuda")]
fn default_cuda_generators() -> &'static Mutex<std::collections::HashMap<usize, Generator>> {
    type Registry = Mutex<std::collections::HashMap<usize, Generator>>;

    static REGISTRY: OnceLock<Registry> = OnceLock::new();
    // `Mutex<HashMap<..>>::default()` wraps an empty map, same as `Mutex::new(HashMap::new())`.
    REGISTRY.get_or_init(Registry::default)
}
263
264pub fn with_default_generator<R, F>(space: crate::LogicalMemorySpace, f: F) -> Result<R>
281where
282 F: FnOnce(&mut Generator) -> Result<R>,
283{
284 match space {
285 crate::LogicalMemorySpace::MainMemory
286 | crate::LogicalMemorySpace::PinnedMemory
287 | crate::LogicalMemorySpace::ManagedMemory => {
288 let mut guard = default_cpu_generator()
289 .lock()
290 .map_err(|_| Error::DeviceError("default CPU generator mutex poisoned".into()))?;
291 f(&mut guard)
292 }
293 crate::LogicalMemorySpace::GpuMemory { device_id } => {
294 #[cfg(feature = "cuda")]
295 {
296 use std::collections::hash_map::Entry;
297
298 let mut guard = default_cuda_generators().lock().map_err(|_| {
299 Error::DeviceError("default CUDA generator mutex poisoned".into())
300 })?;
301 let generator = match guard.entry(device_id) {
302 Entry::Occupied(entry) => entry.into_mut(),
303 Entry::Vacant(entry) => {
304 let seed = default_seed(device_id as u64);
305 let generator = Generator::cuda(device_id, seed)?;
306 entry.insert(generator)
307 }
308 };
309 f(generator)
310 }
311 #[cfg(not(feature = "cuda"))]
312 {
313 let _ = device_id;
314 Err(Error::DeviceError(
315 "default CUDA generator requires the cuda feature".into(),
316 ))
317 }
318 }
319 }
320}
321
322#[cfg(test)]
323mod tests;