tenferro_device/
generator.rs1use std::sync::atomic::{AtomicU64, Ordering};
2use std::time::{SystemTime, UNIX_EPOCH};
3use std::{
4 f64::consts::TAU,
5 sync::{Mutex, OnceLock},
6};
7
8use rand_core::Rng;
9
10use crate::{Error, Result};
11
/// Reports whether `device_id` is a valid zero-based CUDA device index.
///
/// The device count is obtained from `cudarc`'s runtime API. Any failure to
/// obtain it — an error result or a panic raised during the query — is reported
/// as "not available" instead of being propagated.
// NOTE(review): the panic guard presumably covers `cudarc` panicking when the
// CUDA runtime library cannot be loaded — confirm against the cudarc docs.
#[cfg(feature = "cuda")]
fn cuda_device_zero_based_is_available(device_id: usize) -> bool {
    std::panic::catch_unwind(|| {
        cudarc::runtime::result::device::get_count()
            .ok()
            // A checked conversion: a negative count must not wrap around to a
            // huge `usize` (as an `as` cast would) and make every id look valid.
            .and_then(|count| usize::try_from(count).ok())
            .is_some_and(|count| device_id < count)
    })
    .unwrap_or(false)
}
21
/// Builds a non-deterministic 64-bit seed.
///
/// Four sources are XOR-ed together, each rotated by a different amount so
/// that they land on different bit positions:
///
/// * the wall-clock time in nanoseconds since the Unix epoch, truncated to
///   64 bits (`0` if the clock reports a time before the epoch),
/// * the current process id,
/// * a process-wide call counter, so two calls in the same instant still differ,
/// * the caller-supplied `extra` value.
fn default_seed(extra: u64) -> u64 {
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        // Truncation to the low 64 bits is intentional: this is only seed entropy.
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    let process_id = u64::from(std::process::id());
    let ticket = COUNTER.fetch_add(1, Ordering::Relaxed);

    [
        process_id.rotate_left(17),
        ticket.rotate_left(33),
        extra.rotate_left(7),
    ]
    .into_iter()
    .fold(nanos, |mixed, term| mixed ^ term)
}
32
33#[derive(Debug)]
34struct GeneratorState {
35 engine: mt19937::MT19937,
36 cached_normal: Option<f64>,
37}
38
39impl GeneratorState {
40 fn from_seed(seed: u64) -> Self {
41 let low = seed as u32;
42 let high = (seed >> 32) as u32;
43 let seed_words = if high == 0 {
44 vec![low]
45 } else {
46 vec![low, high]
47 };
48 Self {
49 engine: mt19937::MT19937::new_with_slice_seed(&seed_words),
50 cached_normal: None,
51 }
52 }
53}
54
55#[derive(Debug)]
72pub struct Generator {
73 state: GeneratorState,
74 #[cfg(feature = "cuda")]
75 device_id: Option<usize>,
76 #[cfg(feature = "cuda")]
77 seed: u64,
78 #[cfg(feature = "cuda")]
79 offset: u64,
80}
81
82impl Generator {
83 pub fn cpu(seed: u64) -> Self {
93 Self {
94 state: GeneratorState::from_seed(seed),
95 #[cfg(feature = "cuda")]
96 device_id: None,
97 #[cfg(feature = "cuda")]
98 seed,
99 #[cfg(feature = "cuda")]
100 offset: 0,
101 }
102 }
103
104 #[cfg(feature = "cuda")]
117 pub fn cuda(device_id: usize, seed: u64) -> Result<Self> {
118 if !cuda_device_zero_based_is_available(device_id) {
119 return Err(Error::DeviceError(format!(
120 "CUDA generator requires available device {device_id}"
121 )));
122 }
123 Ok(Self {
124 state: GeneratorState::from_seed(seed),
125 device_id: Some(device_id),
126 seed,
127 offset: 0,
128 })
129 }
130
131 #[cfg(feature = "cuda")]
132 pub(crate) fn cuda_seed_and_offset(&self, expected_device_id: usize) -> Result<(u64, u64)> {
133 match self.device_id {
134 Some(device_id) if device_id == expected_device_id => Ok((self.seed, self.offset)),
135 Some(device_id) => Err(Error::DeviceError(format!(
136 "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
137 ))),
138 None => Err(Error::InvalidArgument(
139 "CPU generator cannot drive CUDA RNG execution".into(),
140 )),
141 }
142 }
143
144 #[cfg(feature = "cuda")]
145 pub(crate) fn advance_cuda_offset(
146 &mut self,
147 expected_device_id: usize,
148 delta: u64,
149 ) -> Result<()> {
150 match self.device_id {
151 Some(device_id) if device_id == expected_device_id => {
152 self.offset = self
153 .offset
154 .checked_add(delta)
155 .ok_or_else(|| Error::DeviceError("CUDA generator offset overflow".into()))?;
156 Ok(())
157 }
158 Some(device_id) => Err(Error::DeviceError(format!(
159 "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
160 ))),
161 None => Err(Error::InvalidArgument(
162 "CPU generator cannot drive CUDA RNG execution".into(),
163 )),
164 }
165 }
166
167 pub fn sample_uniform_f64(&mut self) -> f64 {
179 mt19937::gen_res53(&mut self.state.engine)
180 }
181
182 pub fn sample_standard_normal_f64(&mut self) -> f64 {
193 if let Some(sample) = self.state.cached_normal.take() {
194 return sample;
195 }
196
197 let u1 = self.sample_uniform_f64().max(f64::MIN_POSITIVE);
198 let u2 = self.sample_uniform_f64();
199 let radius = (-2.0 * u1.ln()).sqrt();
200 let theta = TAU * u2;
201 let z0 = radius * theta.cos();
202 let z1 = radius * theta.sin();
203 self.state.cached_normal = Some(z1);
204 z0
205 }
206
207 pub fn sample_integer_i32(&mut self, low: i32, high: i32) -> Result<i32> {
223 if low >= high {
224 return Err(Error::InvalidArgument(format!(
225 "invalid integer sample range [{low}, {high})"
226 )));
227 }
228
229 let span = i64::from(high) - i64::from(low);
230 let span_u64 = span as u64;
231 let threshold = u64::MAX - (u64::MAX % span_u64);
232
233 loop {
234 let candidate = self.state.engine.next_u64();
235 if candidate < threshold {
236 let value = i64::from(low) + (candidate % span_u64) as i64;
239 return Ok(value as i32);
240 }
241 }
242 }
243}
244
245fn default_cpu_generator() -> &'static Mutex<Generator> {
246 static DEFAULT_CPU_GENERATOR: OnceLock<Mutex<Generator>> = OnceLock::new();
247 DEFAULT_CPU_GENERATOR.get_or_init(|| Mutex::new(Generator::cpu(default_seed(0))))
248}
249
/// Returns the process-wide table of default CUDA generators, keyed by
/// device id.
///
/// The table starts out empty; this function never inserts into it.
#[cfg(feature = "cuda")]
fn default_cuda_generators() -> &'static Mutex<std::collections::HashMap<usize, Generator>> {
    type Table = std::collections::HashMap<usize, Generator>;

    static TABLE: OnceLock<Mutex<Table>> = OnceLock::new();
    TABLE.get_or_init(|| Mutex::new(Table::new()))
}
257
258pub fn with_default_generator<R, F>(space: crate::LogicalMemorySpace, f: F) -> Result<R>
275where
276 F: FnOnce(&mut Generator) -> Result<R>,
277{
278 match space {
279 crate::LogicalMemorySpace::MainMemory
280 | crate::LogicalMemorySpace::PinnedMemory
281 | crate::LogicalMemorySpace::ManagedMemory => {
282 let mut guard = default_cpu_generator()
283 .lock()
284 .map_err(|_| Error::DeviceError("default CPU generator mutex poisoned".into()))?;
285 f(&mut guard)
286 }
287 crate::LogicalMemorySpace::GpuMemory { device_id } => {
288 #[cfg(feature = "cuda")]
289 {
290 use std::collections::hash_map::Entry;
291
292 let mut guard = default_cuda_generators().lock().map_err(|_| {
293 Error::DeviceError("default CUDA generator mutex poisoned".into())
294 })?;
295 let generator = match guard.entry(device_id) {
296 Entry::Occupied(entry) => entry.into_mut(),
297 Entry::Vacant(entry) => {
298 let seed = default_seed(device_id as u64);
299 let generator = Generator::cuda(device_id, seed)?;
300 entry.insert(generator)
301 }
302 };
303 f(generator)
304 }
305 #[cfg(not(feature = "cuda"))]
306 {
307 let _ = device_id;
308 Err(Error::DeviceError(
309 "default CUDA generator requires the cuda feature".into(),
310 ))
311 }
312 }
313 }
314}
315
316#[cfg(test)]
317mod tests;