tenferro_prims/cpu/rng.rs

1use tenferro_algebra::Standard;
2use tenferro_device::{Error, Generator, Result};
3use tenferro_tensor::Tensor;
4
5use crate::cpu::tensor_to_view_mut;
6use crate::{
7    validate_shape_count, validate_shape_eq, CpuBackend, CpuContext, RngPrimsDescriptor,
8    TensorRngPrims,
9};
10
11/// CPU execution plan for the RNG family.
12///
13/// The plan stores the descriptor together with the output shape so execution
14/// can revalidate the destination tensor before mutating it.
15///
16/// # Examples
17///
18/// ```ignore
19/// use tenferro_prims::CpuRngPlan;
20/// let _plan: CpuRngPlan = (tenferro_prims::RngPrimsDescriptor::Uniform, vec![2, 2]);
21/// ```
22pub type CpuRngPlan = (RngPrimsDescriptor, Vec<usize>);
23
24fn fill_tensor_with_f64_samples<F>(output: &mut Tensor<f64>, mut sample: F) -> Result<()>
25where
26    F: FnMut() -> Result<f64>,
27{
28    let dims = output.dims().to_vec();
29    let mut view = tensor_to_view_mut(output)?;
30    let mut error = None;
31    crate::for_each_index(&dims, |idx| {
32        if error.is_some() {
33            return;
34        }
35        match sample() {
36            Ok(value) => view.set(idx, value),
37            Err(err) => error = Some(err),
38        }
39    });
40
41    match error {
42        Some(err) => Err(err),
43        None => Ok(()),
44    }
45}
46
47fn fill_tensor_with_i32_samples<F>(output: &mut Tensor<i32>, mut sample: F) -> Result<()>
48where
49    F: FnMut() -> Result<i32>,
50{
51    let dims = output.dims().to_vec();
52    let mut view = tensor_to_view_mut(output)?;
53    let mut error = None;
54    crate::for_each_index(&dims, |idx| {
55        if error.is_some() {
56            return;
57        }
58        match sample() {
59            Ok(value) => view.set(idx, value),
60            Err(err) => error = Some(err),
61        }
62    });
63
64    match error {
65        Some(err) => Err(err),
66        None => Ok(()),
67    }
68}
69
70fn validate_rng_plan(
71    desc: &RngPrimsDescriptor,
72    shapes: &[&[usize]],
73    op_name: &str,
74) -> Result<Vec<usize>> {
75    validate_shape_count(shapes, 1, op_name)?;
76    let output_shape = shapes[0].to_vec();
77    match desc {
78        RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => Ok(output_shape),
79        RngPrimsDescriptor::Integer { low, high } => {
80            if low >= high {
81                return Err(Error::InvalidArgument(format!(
82                    "{op_name} requires low < high (got low={low}, high={high})"
83                )));
84            }
85            Ok(output_shape)
86        }
87    }
88}
89
90impl TensorRngPrims<Standard<f64>> for CpuBackend {
91    type Plan = CpuRngPlan;
92    type Context = CpuContext;
93
94    fn plan(
95        _ctx: &mut Self::Context,
96        desc: &RngPrimsDescriptor,
97        shapes: &[&[usize]],
98    ) -> Result<Self::Plan> {
99        let output_shape = validate_rng_plan(desc, shapes, "CpuRng")?;
100        match desc {
101            RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => {
102                Ok((desc.clone(), output_shape))
103            }
104            RngPrimsDescriptor::Integer { .. } => Err(Error::InvalidArgument(
105                "integer RNG planning is only supported for Tensor<i32>".into(),
106            )),
107        }
108    }
109
110    fn execute(
111        _ctx: &mut Self::Context,
112        plan: &Self::Plan,
113        generator: &mut Generator,
114        output: &mut Tensor<f64>,
115    ) -> Result<()> {
116        validate_shape_eq(output.dims(), &plan.1, "CpuRng output")?;
117        match &plan.0 {
118            RngPrimsDescriptor::Uniform => {
119                fill_tensor_with_f64_samples(output, || Ok(generator.sample_uniform_f64()))
120            }
121            RngPrimsDescriptor::Normal => {
122                fill_tensor_with_f64_samples(output, || Ok(generator.sample_standard_normal_f64()))
123            }
124            RngPrimsDescriptor::Integer { .. } => Err(Error::InvalidArgument(
125                "integer RNG execution is only supported for Tensor<i32>".into(),
126            )),
127        }
128    }
129
130    fn has_rng_support(desc: RngPrimsDescriptor) -> bool {
131        matches!(
132            desc,
133            RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal
134        )
135    }
136}
137
138impl TensorRngPrims<Standard<i32>> for CpuBackend {
139    type Plan = CpuRngPlan;
140    type Context = CpuContext;
141
142    fn plan(
143        _ctx: &mut Self::Context,
144        desc: &RngPrimsDescriptor,
145        shapes: &[&[usize]],
146    ) -> Result<Self::Plan> {
147        let output_shape = validate_rng_plan(desc, shapes, "CpuRng")?;
148        match desc {
149            RngPrimsDescriptor::Integer { low, high } => {
150                if low >= high {
151                    return Err(Error::InvalidArgument(format!(
152                        "CpuRng requires low < high (got low={low}, high={high})"
153                    )));
154                }
155                Ok((desc.clone(), output_shape))
156            }
157            RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => {
158                Err(Error::InvalidArgument(
159                    "floating-point RNG planning is only supported for Tensor<f64>".into(),
160                ))
161            }
162        }
163    }
164
165    fn execute(
166        _ctx: &mut Self::Context,
167        plan: &Self::Plan,
168        generator: &mut Generator,
169        output: &mut Tensor<i32>,
170    ) -> Result<()> {
171        validate_shape_eq(output.dims(), &plan.1, "CpuRng output")?;
172        match &plan.0 {
173            RngPrimsDescriptor::Integer { low, high } => {
174                fill_tensor_with_i32_samples(output, || generator.sample_integer_i32(*low, *high))
175            }
176            RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => {
177                Err(Error::InvalidArgument(
178                    "floating-point RNG execution is only supported for Tensor<f64>".into(),
179                ))
180            }
181        }
182    }
183
184    fn has_rng_support(desc: RngPrimsDescriptor) -> bool {
185        matches!(desc, RngPrimsDescriptor::Integer { .. })
186    }
187}