MPFRNormalR.hpp
8.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
/**
* \file MPFRNormalR.hpp
* \brief Header for MPFRNormalR
*
* Sampling exactly from the normal distribution for MPFR using the ratio
* method.
*
* Copyright (c) Charles Karney (2012) <charles@karney.com> and licensed under
* the MIT/X11 License. For more information, see
* http://randomlib.sourceforge.net/
**********************************************************************/
#if !defined(RANDOMLIB_MPFRNORMALR_HPP)
#define RANDOMLIB_MPFRNORMALR_HPP 1
#include <algorithm> // for max/min
#include <cmath> // for pow
#include <mpfr.h>
#define HAVE_MPFR (MPFR_VERSION_MAJOR >= 3)
#if HAVE_MPFR || defined(DOXYGEN)
namespace RandomLib {
/**
* \brief The normal distribution for MPFR (ratio method).
*
* This class is <b>DEPRECATED</b>. It is included for illustrative purposes
* only. The MPFRNormal class provides a much more efficient method for
* sampling from the normal distribution.
*
* This is an adaption of NormalDistribution to MPFR. The changes are
* - Use MPFR's random number generator
* - Use sufficient precision internally to ensure that a correctly rounded
* result is returned.
*
* This class uses a mutable private object. So a single MPFRNormalR
* object cannot safely be used by multiple threads. In a multi-processing
* environment, each thread should use a thread-specific MPFRNormalR
* object.
**********************************************************************/
class MPFRNormalR {
private:
// The number of bits of randomness to add at a time. Require that Leva's
// bounds "work" at a precision of 2^-chunk and that an unsigned long can
// hold this many bits.
static const long chunk_ = 32;
static const unsigned long m = 3684067834; // ceil(2^chunk*sqrt(2/e))
public:
/**
* Initialize the MPFRNormalR object.
**********************************************************************/
MPFRNormalR() {
mpz_init(_ui);
mpz_init(_vi);
mpfr_init2(_eps, chunk_);
mpfr_init2(_u, chunk_);
mpfr_init2(_v, chunk_);
mpfr_init2(_up, chunk_);
mpfr_init2(_vp, chunk_);
mpfr_init2(_vx, chunk_);
mpfr_init2(_x1, chunk_);
mpfr_init2(_x2, chunk_);
}
/**
* Destroy the MPFRNormalR object.
**********************************************************************/
~MPFRNormalR() {
mpfr_clear(_x2);
mpfr_clear(_x1);
mpfr_clear(_vx);
mpfr_clear(_vp);
mpfr_clear(_up);
mpfr_clear(_v);
mpfr_clear(_u);
mpfr_clear(_eps);
mpz_clear(_vi);
mpz_clear(_ui);
}
/**
* Sample from the normal distribution with mean 0 and variance 1.
*
* @param[out] val the sample from the normal distribution
* @param[in,out] r a GMP random generator.
* @param[in] round the rounding direction.
* @return the MPFR ternary result (±1 if val is larger/smaller than
* the exact sample).
**********************************************************************/
int operator()(mpfr_t val, gmp_randstate_t r, mpfr_rnd_t round) const {
const double
s = 0.449871, // Constants from Leva
t = -0.386595,
a = 0.19600 ,
b = 0.25472 ,
r1 = 0.27597 ,
r2 = 0.27846 ,
u1 = 0.606530, // sqrt(1/e) rounded down and up
u2 = 0.606531,
scale = std::pow(2.0, -chunk_); // for turning randoms into doubles
while (true) {
mpz_urandomb(_vi, r, chunk_);
if (mpz_cmp_ui(_vi, m) >= 0) continue; // Very early reject
double vf = (mpz_get_ui(_vi) + 0.5) * scale;
mpz_urandomb(_ui, r, chunk_);
double uf = (mpz_get_ui(_ui) + 0.5) * scale;
double
x = uf - s,
y = vf - t,
Q = x*x + y * (a*y - b*x);
if (Q >= r2) continue; // Early reject
mpfr_set_ui_2exp(_eps, 1u, -chunk_, MPFR_RNDN);
mpfr_prec_t prec = chunk_;
mpfr_set_prec(_u, prec);
mpfr_set_prec(_v, prec);
// (u,v) = sw corner of range
mpfr_set_z_2exp(_u, _ui, -prec, MPFR_RNDN);
mpfr_set_z_2exp(_v, _vi, -prec, MPFR_RNDN);
mpfr_set_prec(_up, prec);
mpfr_set_prec(_vp, prec);
// (up,vp) = ne corner of range
mpfr_add(_up, _u, _eps, MPFR_RNDN);
mpfr_add(_vp, _v, _eps, MPFR_RNDN);
// Estimate how many extra bits will be needed to achieve the desired
// precision.
mpfr_prec_t prec_guard = 3 + chunk_ -
(std::max)(mpz_sizeinbase(_ui, 2), mpz_sizeinbase(_vi, 2));
if (Q > r1) {
int reject;
while (true) {
// Rejection curve v^2 + 4 * u^2 * log(u) < 0 has a peak at u =
// exp(-1/2) = 0.60653066. So treat uf in (0.606530, 0.606531) =
// (u1, u2) specially
// Try for rejection first
if (uf <= u1)
reject = Reject(_u, _vp, prec, MPFR_RNDU);
else if (uf >= u2)
reject = Reject(_up, _vp, prec, MPFR_RNDU);
else { // u in (u1, u2)
mpfr_set_prec(_vx, prec);
mpfr_add(_vx, _vp, _eps, MPFR_RNDN);
reject = Reject(_u, _vx, prec, MPFR_RNDU); // Could use _up too
}
if (reject < 0) break; // tried to reject but failed, so accept
// Try for acceptance
if (uf <= u1)
reject = Reject(_up, _v, prec, MPFR_RNDD);
else if (uf >= u2)
reject = Reject(_u, _v, prec, MPFR_RNDD);
else { // u in (u2, u2)
mpfr_sub(_vx, _v, _eps, MPFR_RNDN);
reject = Reject(_u, _vx, prec, MPFR_RNDD); // Could use _up too
}
if (reject > 0) break; // tried to accept but failed, so reject
prec = Refine(r, prec); // still can't decide, to refine
}
if (reject > 0) continue; // reject, back to outer loop
}
// Now evaluate v/u to the necessary precision
mpfr_prec_t prec0 = mpfr_get_prec (val);
// while (prec < prec0 + prec_guard) prec = Refine(r, prec);
if (prec < prec0 + prec_guard)
prec = Refine(r, prec,
(prec0 + prec_guard - prec + chunk_ - 1) / chunk_);
mpfr_set_prec(_x1, prec0);
mpfr_set_prec(_x2, prec0);
int flag;
while (true) {
int
f1 = mpfr_div(_x1, _v, _up, round), // min slope
f2 = mpfr_div(_x2, _vp, _u, round); // max slope
if (f1 == f2 && mpfr_equal_p(_x1, _x2)) {
flag = f1;
break;
}
prec = Refine(r, prec);
}
mpz_urandomb(_ui, r, 1);
if (mpz_tstbit(_ui, 0)) {
flag = -flag;
mpfr_neg(val, _x1, MPFR_RNDN);
} else
mpfr_set(val, _x1, MPFR_RNDN);
// std::cerr << uf << " " << vf << " " << Q << "\n";
return flag;
}
}
private:
// disable copy constructor and assignment operator
MPFRNormalR(const MPFRNormalR&);
MPFRNormalR& operator=(const MPFRNormalR&);
// Refine the random square
mpfr_prec_t Refine(gmp_randstate_t r, mpfr_prec_t prec, long num = 1)
const {
if (num <= 0) return prec;
// Use _vx as scratch
prec += num * chunk_;
mpfr_div_2ui(_eps, _eps, num * chunk_, MPFR_RNDN);
mpz_urandomb(_ui, r, num * chunk_);
mpfr_set_prec(_up, prec);
mpfr_set_z_2exp(_up, _ui, -prec, MPFR_RNDN);
mpfr_set_prec(_vx, prec);
mpfr_add(_vx, _u, _up, MPFR_RNDN);
mpfr_swap(_u, _vx); // u = vx
mpfr_add(_up, _u, _eps, MPFR_RNDN);
mpz_urandomb(_vi, r, num * chunk_);
mpfr_set_prec(_vp, prec);
mpfr_set_z_2exp(_vp, _vi, -prec, MPFR_RNDN);
mpfr_set_prec(_vx, prec);
mpfr_add(_vx, _v, _vp, MPFR_RNDN);
mpfr_swap(_v, _vx); // v = vx
mpfr_add(_vp, _v, _eps, MPFR_RNDN);
return prec;
}
// Evaluate the sign of the rejection condition v^2 + 4*u^2*log(u)
int Reject(mpfr_t u, mpfr_t v, mpfr_prec_t prec, mpfr_rnd_t round) const {
// Use x1, x2 as scratch
mpfr_set_prec(_x1, prec);
mpfr_log(_x1, u, round);
mpfr_mul(_x1, _x1, u, round); // Important to do the multiplications in
mpfr_mul(_x1, _x1, u, round); // this order so that rounding works right.
mpfr_mul_2ui(_x1, _x1, 2u, round); // 4*u^2*log(u)
mpfr_set_prec(_x2, prec);
mpfr_mul(_x2, v, v, round); // v^2
mpfr_add(_x1, _x1, _x2, round); // v^2 + 4*u^2*log(u)
return mpfr_sgn(_x1);
}
mutable mpz_t _ui;
mutable mpz_t _vi;
mutable mpfr_t _eps;
mutable mpfr_t _u;
mutable mpfr_t _v;
mutable mpfr_t _up;
mutable mpfr_t _vp;
mutable mpfr_t _vx;
mutable mpfr_t _x1;
mutable mpfr_t _x2;
};
} // namespace RandomLib
#endif // HAVE_MPFR
#endif // RANDOMLIB_MPFRNORMALR_HPP