tenferro_prims/cpu/
rng.rs1use tenferro_algebra::Standard;
2use tenferro_device::{Error, Generator, Result};
3use tenferro_tensor::Tensor;
4
5use crate::cpu::tensor_to_view_mut;
6use crate::{
7 validate_shape_count, validate_shape_eq, CpuBackend, CpuContext, RngPrimsDescriptor,
8 TensorRngPrims,
9};
10
/// Plan produced by the CPU RNG primitives: the descriptor that was planned,
/// paired with the output shape that `execute` later checks the tensor against.
11pub type CpuRngPlan = (RngPrimsDescriptor, Vec<usize>);
23
24fn fill_tensor_with_f64_samples<F>(output: &mut Tensor<f64>, mut sample: F) -> Result<()>
25where
26 F: FnMut() -> Result<f64>,
27{
28 let dims = output.dims().to_vec();
29 let mut view = tensor_to_view_mut(output)?;
30 let mut error = None;
31 crate::for_each_index(&dims, |idx| {
32 if error.is_some() {
33 return;
34 }
35 match sample() {
36 Ok(value) => view.set(idx, value),
37 Err(err) => error = Some(err),
38 }
39 });
40
41 match error {
42 Some(err) => Err(err),
43 None => Ok(()),
44 }
45}
46
47fn fill_tensor_with_i32_samples<F>(output: &mut Tensor<i32>, mut sample: F) -> Result<()>
48where
49 F: FnMut() -> Result<i32>,
50{
51 let dims = output.dims().to_vec();
52 let mut view = tensor_to_view_mut(output)?;
53 let mut error = None;
54 crate::for_each_index(&dims, |idx| {
55 if error.is_some() {
56 return;
57 }
58 match sample() {
59 Ok(value) => view.set(idx, value),
60 Err(err) => error = Some(err),
61 }
62 });
63
64 match error {
65 Some(err) => Err(err),
66 None => Ok(()),
67 }
68}
69
70fn validate_rng_plan(
71 desc: &RngPrimsDescriptor,
72 shapes: &[&[usize]],
73 op_name: &str,
74) -> Result<Vec<usize>> {
75 validate_shape_count(shapes, 1, op_name)?;
76 let output_shape = shapes[0].to_vec();
77 match desc {
78 RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => Ok(output_shape),
79 RngPrimsDescriptor::Integer { low, high } => {
80 if low >= high {
81 return Err(Error::InvalidArgument(format!(
82 "{op_name} requires low < high (got low={low}, high={high})"
83 )));
84 }
85 Ok(output_shape)
86 }
87 }
88}
89
90impl TensorRngPrims<Standard<f64>> for CpuBackend {
91 type Plan = CpuRngPlan;
92 type Context = CpuContext;
93
94 fn plan(
95 _ctx: &mut Self::Context,
96 desc: &RngPrimsDescriptor,
97 shapes: &[&[usize]],
98 ) -> Result<Self::Plan> {
99 let output_shape = validate_rng_plan(desc, shapes, "CpuRng")?;
100 match desc {
101 RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => {
102 Ok((desc.clone(), output_shape))
103 }
104 RngPrimsDescriptor::Integer { .. } => Err(Error::InvalidArgument(
105 "integer RNG planning is only supported for Tensor<i32>".into(),
106 )),
107 }
108 }
109
110 fn execute(
111 _ctx: &mut Self::Context,
112 plan: &Self::Plan,
113 generator: &mut Generator,
114 output: &mut Tensor<f64>,
115 ) -> Result<()> {
116 validate_shape_eq(output.dims(), &plan.1, "CpuRng output")?;
117 match &plan.0 {
118 RngPrimsDescriptor::Uniform => {
119 fill_tensor_with_f64_samples(output, || Ok(generator.sample_uniform_f64()))
120 }
121 RngPrimsDescriptor::Normal => {
122 fill_tensor_with_f64_samples(output, || Ok(generator.sample_standard_normal_f64()))
123 }
124 RngPrimsDescriptor::Integer { .. } => Err(Error::InvalidArgument(
125 "integer RNG execution is only supported for Tensor<i32>".into(),
126 )),
127 }
128 }
129
130 fn has_rng_support(desc: RngPrimsDescriptor) -> bool {
131 matches!(
132 desc,
133 RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal
134 )
135 }
136}
137
138impl TensorRngPrims<Standard<i32>> for CpuBackend {
139 type Plan = CpuRngPlan;
140 type Context = CpuContext;
141
142 fn plan(
143 _ctx: &mut Self::Context,
144 desc: &RngPrimsDescriptor,
145 shapes: &[&[usize]],
146 ) -> Result<Self::Plan> {
147 let output_shape = validate_rng_plan(desc, shapes, "CpuRng")?;
148 match desc {
149 RngPrimsDescriptor::Integer { low, high } => {
150 if low >= high {
151 return Err(Error::InvalidArgument(format!(
152 "CpuRng requires low < high (got low={low}, high={high})"
153 )));
154 }
155 Ok((desc.clone(), output_shape))
156 }
157 RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => {
158 Err(Error::InvalidArgument(
159 "floating-point RNG planning is only supported for Tensor<f64>".into(),
160 ))
161 }
162 }
163 }
164
165 fn execute(
166 _ctx: &mut Self::Context,
167 plan: &Self::Plan,
168 generator: &mut Generator,
169 output: &mut Tensor<i32>,
170 ) -> Result<()> {
171 validate_shape_eq(output.dims(), &plan.1, "CpuRng output")?;
172 match &plan.0 {
173 RngPrimsDescriptor::Integer { low, high } => {
174 fill_tensor_with_i32_samples(output, || generator.sample_integer_i32(*low, *high))
175 }
176 RngPrimsDescriptor::Uniform | RngPrimsDescriptor::Normal => {
177 Err(Error::InvalidArgument(
178 "floating-point RNG execution is only supported for Tensor<f64>".into(),
179 ))
180 }
181 }
182 }
183
184 fn has_rng_support(desc: RngPrimsDescriptor) -> bool {
185 matches!(desc, RngPrimsDescriptor::Integer { .. })
186 }
187}