// MPFRNormal.hpp (4.88 KB)
/**
* \file MPFRNormal.hpp
* \brief Header for MPFRNormal
*
* Sampling exactly from the normal distribution for MPFR.
*
* Copyright (c) Charles Karney (2012) <charles@karney.com> and licensed under
* the MIT/X11 License. For more information, see
* http://randomlib.sourceforge.net/
**********************************************************************/
#if !defined(RANDOMLIB_MPFRNORMAL_HPP)
#define RANDOMLIB_MPFRNORMAL_HPP 1
#include <algorithm> // for max/min
#include <RandomLib/MPFRRandom.hpp>
#if HAVE_MPFR || defined(DOXYGEN)
namespace RandomLib {
/**
 * \brief The normal distribution for MPFR.
 *
 * This is a transcription of ExactNormal (version 1.3) for use with MPFR.
 *
 * Although the sampling operators are const, they work through mutable
 * private scratch objects (_tt, _x, _p, _q).  Consequently one MPFRNormal
 * instance must not be shared between threads; give every thread its own
 * object.  Copying and assignment are disabled.
 *
 * @tparam bits the number of bits in each digit.
 **********************************************************************/
template<int bits = 32> class MPFRNormal {
public:
  /**
   * Construct the sampler; initializes the GMP scratch integer.
   **********************************************************************/
  MPFRNormal() { mpz_init(_tt); }
  /**
   * Destroy the sampler; releases the GMP scratch integer.
   **********************************************************************/
  ~MPFRNormal() { mpz_clear(_tt); }
  /**
   * Draw a standard normal deviate (mean 0, variance 1) as an MPFRRandom.
   *
   * @param[out] t receives the sample (exchanged with the internal result
   *   by swap).
   * @param[in,out] r a GMP random generator.
   **********************************************************************/
  void operator()(MPFRRandom<bits>& t, gmp_randstate_t r) const {
    Compute(r);
    _x.swap(t);
  }
  /**
   * Draw a standard normal deviate (mean 0, variance 1) and round it into an
   * mpfr_t.
   *
   * @param[out] val receives the rounded sample.
   * @param[in,out] r a GMP random generator; it is also handed to the
   *   MPFRRandom rounding step.
   * @param[in] round the MPFR rounding direction.
   * @return the MPFR ternary result: +1 if val is larger than the exact
   *   sample, -1 if it is smaller.
   **********************************************************************/
  int operator()(mpfr_t val, gmp_randstate_t r, mpfr_rnd_t round) const {
    Compute(r);
    return _x(val, r, round);
  }
private:
  // Non-copyable: declared but intentionally left undefined.
  MPFRNormal(const MPFRNormal&);
  MPFRNormal& operator=(const MPFRNormal&);

  // Return 1 with probability exp(-1/2), otherwise 0.
  // A fresh _p passing TestHighBit succeeds at once; otherwise a von Neumann
  // rejection loop alternately refills _q and _p, stopping at the first
  // comparison that fails.
  int ExpProbH(gmp_randstate_t r) const {
    _p.Init();
    if (_p.TestHighBit(r))
      return 1;
    for (;;) {
      _q.Init();
      if (!_q.LessThan(r, _p))
        return 0;
      _p.Init();
      if (!_p.LessThan(r, _q))
        return 1;
    }
  }

  // Return 1 with probability exp(-n/2): n independent exp(-1/2) events must
  // all succeed; bail out with 0 on the first failure.
  int ExpProb(gmp_randstate_t r, unsigned n) const {
    for (; n > 0; --n)
      if (!ExpProbH(r))
        return 0;
    return 1;
  }

  // Return n with probability (1 - exp(-1/2)) * exp(-n/2): the number of
  // consecutive exp(-1/2) successes before the first failure.
  unsigned ExpProbN(gmp_randstate_t r) const {
    unsigned count = 0;
    for (; ExpProbH(r); ++count) {}
    return count;
  }

  // Three-way choice with m = 2k + 2:
  //    1 with probability 2k/m
  //    0 with probability 1/m
  //   -1 with probability 1/m
  // Random 15-bit digits refine the interval [lo, hi] (scaled by m) until it
  // settles on one outcome.  The 15-bit digit size keeps the shifts and the
  // multiply inside int range.
  int Choose(gmp_randstate_t r, int k) const {
    const int b = 15; // To avoid integer overflow on multiplication
    const int m = 2 * k + 2;
    int lo = m - 2;
    int hi = m - 1;
    for (;;) {
      mpz_urandomb(_tt, r, b);
      const int digit_times_m = int(mpz_get_ui(_tt)) * m;
      lo = (std::max)((lo << b) - digit_times_m, 0);
      if (lo >= m)
        return 1;
      hi = (std::min)((hi << b) - digit_times_m, m);
      if (hi <= 0)
        return -1;
      if (lo == 0 && hi == m)
        return 0;
    }
  }

  // Core sampler: leaves a standard normal deviate in _x.
  //
  // Each pass of the outer loop:
  //  1. draws a candidate integer part k (ExpProbN) and keeps it with
  //     probability exp(-k(k-1)/2) (ExpProb);
  //  2. starts a fresh fractional part _x and runs up to k + 1 trials.  Each
  //     trial walks a chain of random comparisons until one fails.  A trial
  //     passes when an even number of chain steps completed before that
  //     failure;
  //  3. if every trial passed, adds k to _x, attaches a random sign and
  //     returns; otherwise the pass is repeated from scratch.
  // Random draws happen in a fixed order; keep that order intact when editing.
  void Compute(gmp_randstate_t r) const {
    for (;;) {
      const unsigned k = ExpProbN(r); // candidate integer part
      // For k == 0 the unsigned product (k - 1) * k wraps to 0, so ExpProb
      // always accepts.
      if (!ExpProb(r, (k - 1) * k))
        continue;
      _x.Init();
      bool passed = true;          // has every trial so far passed?
      for (unsigned trial = 0; passed && trial <= k; ++trial) {
        bool even = true;          // even number of completed steps so far
        bool initial = true;       // the first step compares against _x
        for (;;) {
          if (k == 0 && _x.Boolean(r))
            break;
          _q.Init();
          if (!_q.LessThan(r, initial ? _x : _p))
            break;
          const int y = (k == 0) ? 0 : Choose(r, static_cast<int>(k));
          if (y < 0)
            break;
          if (y == 0) {
            _p.Init();
            if (!_p.LessThan(r, _x))
              break;
          }
          _p.swap(_q);             // a fast way of doing p = q
          even = !even;
          initial = false;
        }
        passed = even;
      }
      if (passed) {
        _x.AddInteger(k);
        if (_x.Boolean(r))
          _x.Negate();
        return;
      }
    }
  }

  mutable mpz_t _tt;               // scratch integer used by Choose
  mutable MPFRRandom<bits> _x;     // the sample under construction / result
  mutable MPFRRandom<bits> _p;     // scratch MPFRRandom
  mutable MPFRRandom<bits> _q;     // scratch MPFRRandom
};
} // namespace RandomLib
#endif // HAVE_MPFR
#endif // RANDOMLIB_MPFRNORMAL_HPP