Skip to main content

tenferro_device/
generator.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2use std::time::{SystemTime, UNIX_EPOCH};
3use std::{
4    f64::consts::TAU,
5    sync::{Mutex, OnceLock},
6};
7
8use rand_core::Rng;
9
10use crate::{Error, Result};
11
/// Report whether the zero-based CUDA ordinal `device_id` is below the device
/// count returned by the CUDA runtime.
///
/// Any failure while querying the count — an error result or a panic raised
/// inside the query — is treated as "not available".
#[cfg(feature = "cuda")]
fn cuda_device_zero_based_is_available(device_id: usize) -> bool {
    let probe = || match cudarc::runtime::result::device::get_count() {
        Ok(count) => device_id < count as usize,
        Err(_) => false,
    };
    // A panic inside the probe is converted into `false` instead of unwinding
    // into the caller.
    std::panic::catch_unwind(probe).unwrap_or(false)
}
21
/// Build a seed for a default generator by mixing several sources.
///
/// The result XORs together the wall-clock time in nanoseconds since the Unix
/// epoch (truncated to 64 bits, or `0` if the clock reads earlier than the
/// epoch), the process id, a process-wide call counter, and the
/// caller-supplied `extra` value. Each of the last three is rotated by a
/// different amount before mixing.
fn default_seed(extra: u64) -> u64 {
    // Incremented on every call so that calls landing on the same clock
    // reading still contribute a different counter value.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    let ticket = COUNTER.fetch_add(1, Ordering::Relaxed);
    let process = u64::from(std::process::id());

    nanos ^ process.rotate_left(17) ^ ticket.rotate_left(33) ^ extra.rotate_left(7)
}
32
33#[derive(Debug)]
34struct GeneratorState {
35    engine: mt19937::MT19937,
36    cached_normal: Option<f64>,
37}
38
39impl GeneratorState {
40    fn from_seed(seed: u64) -> Self {
41        let low = seed as u32;
42        let high = (seed >> 32) as u32;
43        let seed_words = if high == 0 {
44            vec![low]
45        } else {
46            vec![low, high]
47        };
48        Self {
49            engine: mt19937::MT19937::new_with_slice_seed(&seed_words),
50            cached_normal: None,
51        }
52    }
53}
54
55/// Pseudo-random number generator used across the tenferro workspace.
56///
57/// The CPU half uses an MT19937 engine seeded from a `u64`. CUDA execution
58/// uses the same public `Generator` surface but advances an internal
59/// seed/offset pair that device kernels consume through a Philox-style
60/// counter-based scheme.
61///
62/// # Examples
63///
64/// ```
65/// use tenferro_device::Generator;
66///
67/// let mut generator = Generator::cpu(1234);
68/// let sample = generator.sample_uniform_f64();
69/// assert!(sample >= 0.0 && sample < 1.0);
70/// ```
71#[derive(Debug)]
72pub struct Generator {
73    state: GeneratorState,
74    #[cfg(feature = "cuda")]
75    device_id: Option<usize>,
76    #[cfg(feature = "cuda")]
77    seed: u64,
78    #[cfg(feature = "cuda")]
79    offset: u64,
80}
81
82impl Generator {
83    /// Create a CPU generator from a seed.
84    ///
85    /// # Examples
86    ///
87    /// ```
88    /// use tenferro_device::Generator;
89    ///
90    /// let _generator = Generator::cpu(42);
91    /// ```
92    pub fn cpu(seed: u64) -> Self {
93        Self {
94            state: GeneratorState::from_seed(seed),
95            #[cfg(feature = "cuda")]
96            device_id: None,
97            #[cfg(feature = "cuda")]
98            seed,
99            #[cfg(feature = "cuda")]
100            offset: 0,
101        }
102    }
103
104    /// Create a CUDA generator from a seed and device ordinal.
105    ///
106    /// The CPU half of the RNG phase only records the metadata so the public
107    /// API is in place for the later CUDA implementation.
108    ///
109    /// # Examples
110    ///
111    /// ```ignore
112    /// use tenferro_device::Generator;
113    ///
114    /// let _generator = Generator::cuda(0, 1234).unwrap();
115    /// ```
116    #[cfg(feature = "cuda")]
117    pub fn cuda(device_id: usize, seed: u64) -> Result<Self> {
118        if !cuda_device_zero_based_is_available(device_id) {
119            return Err(Error::DeviceError(format!(
120                "CUDA generator requires available device {device_id}"
121            )));
122        }
123        Ok(Self {
124            state: GeneratorState::from_seed(seed),
125            device_id: Some(device_id),
126            seed,
127            offset: 0,
128        })
129    }
130
131    #[cfg(feature = "cuda")]
132    pub(crate) fn cuda_seed_and_offset(&self, expected_device_id: usize) -> Result<(u64, u64)> {
133        match self.device_id {
134            Some(device_id) if device_id == expected_device_id => Ok((self.seed, self.offset)),
135            Some(device_id) => Err(Error::DeviceError(format!(
136                "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
137            ))),
138            None => Err(Error::InvalidArgument(
139                "CPU generator cannot drive CUDA RNG execution".into(),
140            )),
141        }
142    }
143
144    #[cfg(feature = "cuda")]
145    pub(crate) fn advance_cuda_offset(
146        &mut self,
147        expected_device_id: usize,
148        delta: u64,
149    ) -> Result<()> {
150        match self.device_id {
151            Some(device_id) if device_id == expected_device_id => {
152                self.offset = self
153                    .offset
154                    .checked_add(delta)
155                    .ok_or_else(|| Error::DeviceError("CUDA generator offset overflow".into()))?;
156                Ok(())
157            }
158            Some(device_id) => Err(Error::DeviceError(format!(
159                "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
160            ))),
161            None => Err(Error::InvalidArgument(
162                "CPU generator cannot drive CUDA RNG execution".into(),
163            )),
164        }
165    }
166
167    /// Draw a floating-point sample from the half-open interval `[0, 1)`.
168    ///
169    /// # Examples
170    ///
171    /// ```
172    /// use tenferro_device::Generator;
173    ///
174    /// let mut generator = Generator::cpu(7);
175    /// let x = generator.sample_uniform_f64();
176    /// assert!(x >= 0.0 && x < 1.0);
177    /// ```
178    pub fn sample_uniform_f64(&mut self) -> f64 {
179        mt19937::gen_res53(&mut self.state.engine)
180    }
181
182    /// Draw a standard normal sample using Box-Muller sampling.
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// use tenferro_device::Generator;
188    ///
189    /// let mut generator = Generator::cpu(7);
190    /// let _z = generator.sample_standard_normal_f64();
191    /// ```
192    pub fn sample_standard_normal_f64(&mut self) -> f64 {
193        if let Some(sample) = self.state.cached_normal.take() {
194            return sample;
195        }
196
197        let u1 = self.sample_uniform_f64().max(f64::MIN_POSITIVE);
198        let u2 = self.sample_uniform_f64();
199        let radius = (-2.0 * u1.ln()).sqrt();
200        let theta = TAU * u2;
201        let z0 = radius * theta.cos();
202        let z1 = radius * theta.sin();
203        self.state.cached_normal = Some(z1);
204        z0
205    }
206
207    /// Draw an integer sample from the half-open interval `[low, high)`.
208    ///
209    /// # Errors
210    ///
211    /// Returns [`Error::InvalidArgument`] if `low >= high`.
212    ///
213    /// # Examples
214    ///
215    /// ```
216    /// use tenferro_device::Generator;
217    ///
218    /// let mut generator = Generator::cpu(7);
219    /// let x = generator.sample_integer_i32(-3, 7).unwrap();
220    /// assert!((-3..7).contains(&x));
221    /// ```
222    pub fn sample_integer_i32(&mut self, low: i32, high: i32) -> Result<i32> {
223        if low >= high {
224            return Err(Error::InvalidArgument(format!(
225                "invalid integer sample range [{low}, {high})"
226            )));
227        }
228
229        let span = i64::from(high) - i64::from(low);
230        let span_u64 = span as u64;
231        let threshold = u64::MAX - (u64::MAX % span_u64);
232
233        loop {
234            let candidate = self.state.engine.next_u64();
235            if candidate < threshold {
236                // `candidate % span_u64` is always in `[0, span)`, so adding it to `low`
237                // stays inside the validated half-open interval `[low, high)`.
238                let value = i64::from(low) + (candidate % span_u64) as i64;
239                return Ok(value as i32);
240            }
241        }
242    }
243}
244
245fn default_cpu_generator() -> &'static Mutex<Generator> {
246    static DEFAULT_CPU_GENERATOR: OnceLock<Mutex<Generator>> = OnceLock::new();
247    DEFAULT_CPU_GENERATOR.get_or_init(|| Mutex::new(Generator::cpu(default_seed(0))))
248}
249
/// Return the process-wide table of default CUDA generators keyed by device
/// ordinal. The table starts out empty.
#[cfg(feature = "cuda")]
fn default_cuda_generators() -> &'static Mutex<std::collections::HashMap<usize, Generator>> {
    type Registry = std::collections::HashMap<usize, Generator>;

    static REGISTRY: OnceLock<Mutex<Registry>> = OnceLock::new();
    REGISTRY.get_or_init(|| Mutex::new(Registry::new()))
}
257
258/// Run a closure with a shared default generator for the requested memory space.
259///
260/// The default generator is process-global and advances across calls. CPU
261/// memory spaces share a CPU default generator, while each CUDA device gets its
262/// own device-bound default generator.
263///
264/// # Examples
265///
266/// ```
267/// use tenferro_device::{with_default_generator, LogicalMemorySpace};
268///
269/// let value = with_default_generator(LogicalMemorySpace::MainMemory, |generator| {
270///     Ok(generator.sample_uniform_f64())
271/// }).unwrap();
272/// assert!(value >= 0.0 && value < 1.0);
273/// ```
274pub fn with_default_generator<R, F>(space: crate::LogicalMemorySpace, f: F) -> Result<R>
275where
276    F: FnOnce(&mut Generator) -> Result<R>,
277{
278    match space {
279        crate::LogicalMemorySpace::MainMemory
280        | crate::LogicalMemorySpace::PinnedMemory
281        | crate::LogicalMemorySpace::ManagedMemory => {
282            let mut guard = default_cpu_generator()
283                .lock()
284                .map_err(|_| Error::DeviceError("default CPU generator mutex poisoned".into()))?;
285            f(&mut guard)
286        }
287        crate::LogicalMemorySpace::GpuMemory { device_id } => {
288            #[cfg(feature = "cuda")]
289            {
290                use std::collections::hash_map::Entry;
291
292                let mut guard = default_cuda_generators().lock().map_err(|_| {
293                    Error::DeviceError("default CUDA generator mutex poisoned".into())
294                })?;
295                let generator = match guard.entry(device_id) {
296                    Entry::Occupied(entry) => entry.into_mut(),
297                    Entry::Vacant(entry) => {
298                        let seed = default_seed(device_id as u64);
299                        let generator = Generator::cuda(device_id, seed)?;
300                        entry.insert(generator)
301                    }
302                };
303                f(generator)
304            }
305            #[cfg(not(feature = "cuda"))]
306            {
307                let _ = device_id;
308                Err(Error::DeviceError(
309                    "default CUDA generator requires the cuda feature".into(),
310                ))
311            }
312        }
313    }
314}
315
316#[cfg(test)]
317mod tests;