tenferro_device/
generator.rs

use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use std::{
    convert::TryFrom,
    f64::consts::TAU,
    sync::{Mutex, OnceLock},
};

// NOTE(review): `rand_core` exposes `RngCore` (which provides `next_u64`);
// the `Rng` extension trait lives in the `rand` crate, not `rand_core`.
use rand_core::RngCore;

use crate::{Error, Result};
12
/// Probe whether the zero-based CUDA ordinal `device_id` names a real device.
///
/// Any failure mode — a runtime error from the device-count query or a panic
/// inside the CUDA runtime probe — is treated as "not available".
#[cfg(feature = "cuda")]
fn cuda_device_zero_based_is_available(device_id: usize) -> bool {
    let probe = std::panic::catch_unwind(|| {
        match cudarc::runtime::result::device::get_count() {
            Ok(count) => device_id < count as usize,
            Err(_) => false,
        }
    });
    probe.unwrap_or(false)
}
22
/// Derive a fresh seed for a process-global default generator.
///
/// Mixes wall-clock nanoseconds, the process id, a monotonically increasing
/// call counter, and a caller-supplied discriminator (`extra`, e.g. a device
/// ordinal) so that repeated calls produce distinct seeds.
fn default_seed(extra: u64) -> u64 {
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    // Clock errors (time before the epoch) degrade to 0 rather than failing;
    // the counter term still keeps successive seeds distinct.
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    let process = u64::from(std::process::id());
    let call_index = COUNTER.fetch_add(1, Ordering::Relaxed);

    nanos ^ process.rotate_left(17) ^ call_index.rotate_left(33) ^ extra.rotate_left(7)
}
33
/// Internal CPU RNG state: the MT19937 engine plus the spare sample left
/// over from the previous Box-Muller round.
#[derive(Debug)]
struct GeneratorState {
    // Mersenne Twister engine driving all CPU sampling.
    engine: mt19937::MT19937,
    // Second normal of a Box-Muller pair, consumed by the next call to
    // `sample_standard_normal_f64` without advancing the engine.
    cached_normal: Option<f64>,
}
39
40impl GeneratorState {
41    fn from_seed(seed: u64) -> Self {
42        let low = seed as u32;
43        let high = (seed >> 32) as u32;
44        let seed_words = if high == 0 {
45            vec![low]
46        } else {
47            vec![low, high]
48        };
49        Self {
50            engine: mt19937::MT19937::new_with_slice_seed(&seed_words),
51            cached_normal: None,
52        }
53    }
54}
55
/// Pseudo-random number generator used across the tenferro workspace.
///
/// The CPU half uses an MT19937 engine seeded from a `u64`. CUDA execution
/// uses the same public `Generator` surface but advances an internal
/// seed/offset pair that device kernels consume through a Philox-style
/// counter-based scheme.
///
/// # Examples
///
/// ```
/// use tenferro_device::Generator;
///
/// let mut generator = Generator::cpu(1234);
/// let sample = generator.sample_uniform_f64();
/// assert!(sample >= 0.0 && sample < 1.0);
/// ```
#[derive(Debug)]
pub struct Generator {
    // CPU-side sampling state (engine + cached Box-Muller spare).
    state: GeneratorState,
    // `Some(ordinal)` when this generator is bound to a CUDA device,
    // `None` for a CPU-only generator.
    #[cfg(feature = "cuda")]
    device_id: Option<usize>,
    // Seed half of the Philox-style (seed, offset) pair handed to kernels.
    #[cfg(feature = "cuda")]
    seed: u64,
    // Counter half; advanced after each device RNG launch.
    #[cfg(feature = "cuda")]
    offset: u64,
}
82
83impl Generator {
84    /// Create a CPU generator from a seed.
85    ///
86    /// # Examples
87    ///
88    /// ```
89    /// use tenferro_device::Generator;
90    ///
91    /// let _generator = Generator::cpu(42);
92    /// ```
93    pub fn cpu(seed: u64) -> Self {
94        Self {
95            state: GeneratorState::from_seed(seed),
96            #[cfg(feature = "cuda")]
97            device_id: None,
98            #[cfg(feature = "cuda")]
99            seed,
100            #[cfg(feature = "cuda")]
101            offset: 0,
102        }
103    }
104
105    /// Create a CUDA generator from a seed and device ordinal.
106    ///
107    /// The CPU half of the RNG phase only records the metadata so the public
108    /// API is in place for the later CUDA implementation.
109    ///
110    /// # Examples
111    ///
112    /// ```ignore
113    /// use tenferro_device::Generator;
114    ///
115    /// let _generator = Generator::cuda(0, 1234).unwrap();
116    /// ```
117    #[cfg(feature = "cuda")]
118    pub fn cuda(device_id: usize, seed: u64) -> Result<Self> {
119        if !cuda_device_zero_based_is_available(device_id) {
120            return Err(Error::DeviceError(format!(
121                "CUDA generator requires available device {device_id}"
122            )));
123        }
124        Ok(Self {
125            state: GeneratorState::from_seed(seed),
126            device_id: Some(device_id),
127            seed,
128            offset: 0,
129        })
130    }
131
132    #[cfg(feature = "cuda")]
133    pub(crate) fn cuda_seed_and_offset(&self, expected_device_id: usize) -> Result<(u64, u64)> {
134        match self.device_id {
135            Some(device_id) if device_id == expected_device_id => Ok((self.seed, self.offset)),
136            Some(device_id) => Err(Error::DeviceError(format!(
137                "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
138            ))),
139            None => Err(Error::InvalidArgument(
140                "CPU generator cannot drive CUDA RNG execution".into(),
141            )),
142        }
143    }
144
145    #[cfg(feature = "cuda")]
146    pub(crate) fn advance_cuda_offset(
147        &mut self,
148        expected_device_id: usize,
149        delta: u64,
150    ) -> Result<()> {
151        match self.device_id {
152            Some(device_id) if device_id == expected_device_id => {
153                self.offset = self
154                    .offset
155                    .checked_add(delta)
156                    .ok_or_else(|| Error::DeviceError("CUDA generator offset overflow".into()))?;
157                Ok(())
158            }
159            Some(device_id) => Err(Error::DeviceError(format!(
160                "CUDA generator is bound to device {device_id}, expected device {expected_device_id}"
161            ))),
162            None => Err(Error::InvalidArgument(
163                "CPU generator cannot drive CUDA RNG execution".into(),
164            )),
165        }
166    }
167
168    /// Draw a floating-point sample from the half-open interval `[0, 1)`.
169    ///
170    /// # Examples
171    ///
172    /// ```
173    /// use tenferro_device::Generator;
174    ///
175    /// let mut generator = Generator::cpu(7);
176    /// let x = generator.sample_uniform_f64();
177    /// assert!(x >= 0.0 && x < 1.0);
178    /// ```
179    pub fn sample_uniform_f64(&mut self) -> f64 {
180        mt19937::gen_res53(&mut self.state.engine)
181    }
182
183    /// Draw a standard normal sample using Box-Muller sampling.
184    ///
185    /// # Examples
186    ///
187    /// ```
188    /// use tenferro_device::Generator;
189    ///
190    /// let mut generator = Generator::cpu(7);
191    /// let _z = generator.sample_standard_normal_f64();
192    /// ```
193    pub fn sample_standard_normal_f64(&mut self) -> f64 {
194        if let Some(sample) = self.state.cached_normal.take() {
195            return sample;
196        }
197
198        let u1 = self.sample_uniform_f64().max(f64::MIN_POSITIVE);
199        let u2 = self.sample_uniform_f64();
200        let radius = (-2.0 * u1.ln()).sqrt();
201        let theta = TAU * u2;
202        let z0 = radius * theta.cos();
203        let z1 = radius * theta.sin();
204        self.state.cached_normal = Some(z1);
205        z0
206    }
207
208    /// Draw an integer sample from the half-open interval `[low, high)`.
209    ///
210    /// # Errors
211    ///
212    /// Returns [`Error::InvalidArgument`] if `low >= high`.
213    ///
214    /// # Examples
215    ///
216    /// ```
217    /// use tenferro_device::Generator;
218    ///
219    /// let mut generator = Generator::cpu(7);
220    /// let x = generator.sample_integer_i32(-3, 7).unwrap();
221    /// assert!((-3..7).contains(&x));
222    /// ```
223    pub fn sample_integer_i32(&mut self, low: i32, high: i32) -> Result<i32> {
224        if low >= high {
225            return Err(Error::InvalidArgument(format!(
226                "invalid integer sample range [{low}, {high})"
227            )));
228        }
229
230        let span = i64::from(high) - i64::from(low);
231        let span_u64 = u64::try_from(span).map_err(|_| {
232            Error::InvalidArgument(format!("integer sample span {span} does not fit into u64"))
233        })?;
234        let threshold = u64::MAX - (u64::MAX % span_u64);
235
236        loop {
237            let candidate = self.state.engine.next_u64();
238            if candidate < threshold {
239                let value = i64::from(low)
240                    + i64::try_from(candidate % span_u64).map_err(|_| {
241                        Error::InvalidArgument("integer sample conversion overflow".into())
242                    })?;
243                return i32::try_from(value).map_err(|_| {
244                    Error::InvalidArgument(format!("integer sample {value} does not fit into i32"))
245                });
246            }
247        }
248    }
249}
250
251fn default_cpu_generator() -> &'static Mutex<Generator> {
252    static DEFAULT_CPU_GENERATOR: OnceLock<Mutex<Generator>> = OnceLock::new();
253    DEFAULT_CPU_GENERATOR.get_or_init(|| Mutex::new(Generator::cpu(default_seed(0))))
254}
255
/// Lazily-initialized registry of per-device default CUDA generators,
/// keyed by zero-based device ordinal.
#[cfg(feature = "cuda")]
fn default_cuda_generators() -> &'static Mutex<std::collections::HashMap<usize, Generator>> {
    use std::collections::HashMap;

    static REGISTRY: OnceLock<Mutex<HashMap<usize, Generator>>> = OnceLock::new();
    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
}
263
264/// Run a closure with a shared default generator for the requested memory space.
265///
266/// The default generator is process-global and advances across calls. CPU
267/// memory spaces share a CPU default generator, while each CUDA device gets its
268/// own device-bound default generator.
269///
270/// # Examples
271///
272/// ```
273/// use tenferro_device::{with_default_generator, LogicalMemorySpace};
274///
275/// let value = with_default_generator(LogicalMemorySpace::MainMemory, |generator| {
276///     Ok(generator.sample_uniform_f64())
277/// }).unwrap();
278/// assert!(value >= 0.0 && value < 1.0);
279/// ```
280pub fn with_default_generator<R, F>(space: crate::LogicalMemorySpace, f: F) -> Result<R>
281where
282    F: FnOnce(&mut Generator) -> Result<R>,
283{
284    match space {
285        crate::LogicalMemorySpace::MainMemory
286        | crate::LogicalMemorySpace::PinnedMemory
287        | crate::LogicalMemorySpace::ManagedMemory => {
288            let mut guard = default_cpu_generator()
289                .lock()
290                .map_err(|_| Error::DeviceError("default CPU generator mutex poisoned".into()))?;
291            f(&mut guard)
292        }
293        crate::LogicalMemorySpace::GpuMemory { device_id } => {
294            #[cfg(feature = "cuda")]
295            {
296                use std::collections::hash_map::Entry;
297
298                let mut guard = default_cuda_generators().lock().map_err(|_| {
299                    Error::DeviceError("default CUDA generator mutex poisoned".into())
300                })?;
301                let generator = match guard.entry(device_id) {
302                    Entry::Occupied(entry) => entry.into_mut(),
303                    Entry::Vacant(entry) => {
304                        let seed = default_seed(device_id as u64);
305                        let generator = Generator::cuda(device_id, seed)?;
306                        entry.insert(generator)
307                    }
308                };
309                f(generator)
310            }
311            #[cfg(not(feature = "cuda"))]
312            {
313                let _ = device_id;
314                Err(Error::DeviceError(
315                    "default CUDA generator requires the cuda feature".into(),
316                ))
317            }
318        }
319    }
320}
321
322#[cfg(test)]
323mod tests;