summaryrefslogtreecommitdiffstats
path: root/Lib/random.py
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2022-07-13 14:46:04 (GMT)
committerGitHub <noreply@github.com>2022-07-13 14:46:04 (GMT)
commited06ec1ab851544234138952d357facb32eba6c5 (patch)
tree7c1b03af2b30cf6f270bab039b5ced17a530b320 /Lib/random.py
parentf5c02afaff43f4ed7f4ac74d7c90171e56c2b2d7 (diff)
downloadcpython-ed06ec1ab851544234138952d357facb32eba6c5.zip
cpython-ed06ec1ab851544234138952d357facb32eba6c5.tar.gz
cpython-ed06ec1ab851544234138952d357facb32eba6c5.tar.bz2
GH-81620: Add random.binomialvariate() (GH-94719)
Diffstat (limited to 'Lib/random.py')
-rw-r--r--Lib/random.py95
1 files changed, 93 insertions, 2 deletions
diff --git a/Lib/random.py b/Lib/random.py
index 2166474..00849bd 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -24,6 +24,7 @@
negative exponential
gamma
beta
+ binomial
pareto
Weibull
@@ -49,6 +50,7 @@ from warnings import warn as _warn
from math import log as _log, exp as _exp, pi as _pi, e as _e, ceil as _ceil
from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
from math import tau as TWOPI, floor as _floor, isfinite as _isfinite
+from math import lgamma as _lgamma, fabs as _fabs
from os import urandom as _urandom
from _collections_abc import Sequence as _Sequence
from operator import index as _index
@@ -68,6 +70,7 @@ __all__ = [
"Random",
"SystemRandom",
"betavariate",
+ "binomialvariate",
"choice",
"choices",
"expovariate",
@@ -725,6 +728,91 @@ class Random(_random.Random):
return y / (y + self.gammavariate(beta, 1.0))
return 0.0
+
+ def binomialvariate(self, n=1, p=0.5):
+ """Binomial random variable.
+
+ Gives the number of successes for *n* independent trials
+ with the probability of success in each trial being *p*:
+
+ sum(random() < p for i in range(n))
+
+ Returns an integer in the range: 0 <= X <= n
+
+ """
+ # Error check inputs and handle edge cases
+ if n < 0:
+ raise ValueError("n must be non-negative")
+ if p <= 0.0 or p >= 1.0:
+ if p == 0.0:
+ return 0
+ if p == 1.0:
+ return n
+ raise ValueError("p must be in the range 0.0 <= p <= 1.0")
+
+ random = self.random
+
+ # Fast path for a common case
+ if n == 1:
+ return _index(random() < p)
+
+ # Exploit symmetry to establish: p <= 0.5
+ if p > 0.5:
+ return n - self.binomialvariate(n, 1.0 - p)
+
+ if n * p < 10.0:
+ # BG: Geometric method by Devroye with running time of O(np).
+ # https://dl.acm.org/doi/pdf/10.1145/42372.42381
+ x = y = 0
+ c = _log(1.0 - p)
+ if not c:
+ return x
+ while True:
+ y += _floor(_log(random()) / c) + 1
+ if y > n:
+ return x
+ x += 1
+
+ # BTRS: Transformed rejection with squeeze method by Wolfgang Hörmann
+ # https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.47.8407&rep=rep1&type=pdf
+ assert n*p >= 10.0 and p <= 0.5
+ setup_complete = False
+
+ spq = _sqrt(n * p * (1.0 - p)) # Standard deviation of the distribution
+ b = 1.15 + 2.53 * spq
+ a = -0.0873 + 0.0248 * b + 0.01 * p
+ c = n * p + 0.5
+ vr = 0.92 - 4.2 / b
+
+ while True:
+
+ u = random()
+ v = random()
+ u -= 0.5
+ us = 0.5 - _fabs(u)
+ k = _floor((2.0 * a / us + b) * u + c)
+ if k < 0 or k > n:
+ continue
+
+ # The early-out "squeeze" test substantially reduces
+ # the number of acceptance condition evaluations.
+ if us >= 0.07 and v <= vr:
+ return k
+
+ # Acceptance-rejection test.
+ # Note, the original paper errorneously omits the call to log(v)
+ # when comparing to the log of the rescaled binomial distribution.
+ if not setup_complete:
+ alpha = (2.83 + 5.1 / b) * spq
+ lpq = _log(p / (1.0 - p))
+ m = _floor((n + 1) * p) # Mode of the distribution
+ h = _lgamma(m + 1) + _lgamma(n - m + 1)
+ setup_complete = True # Only needs to be done once
+ v *= alpha / (a / (us * us) + b)
+ if _log(v) <= h - _lgamma(k + 1) - _lgamma(n - k + 1) + (k - m) * lpq:
+ return k
+
+
def paretovariate(self, alpha):
"""Pareto distribution. alpha is the shape parameter."""
# Jain, pg. 495
@@ -810,6 +898,7 @@ vonmisesvariate = _inst.vonmisesvariate
gammavariate = _inst.gammavariate
gauss = _inst.gauss
betavariate = _inst.betavariate
+binomialvariate = _inst.binomialvariate
paretovariate = _inst.paretovariate
weibullvariate = _inst.weibullvariate
getstate = _inst.getstate
@@ -834,15 +923,17 @@ def _test_generator(n, func, args):
low = min(data)
high = max(data)
- print(f'{t1 - t0:.3f} sec, {n} times {func.__name__}')
+ print(f'{t1 - t0:.3f} sec, {n} times {func.__name__}{args!r}')
print('avg %g, stddev %g, min %g, max %g\n' % (xbar, sigma, low, high))
-def _test(N=2000):
+def _test(N=10_000):
_test_generator(N, random, ())
_test_generator(N, normalvariate, (0.0, 1.0))
_test_generator(N, lognormvariate, (0.0, 1.0))
_test_generator(N, vonmisesvariate, (0.0, 1.0))
+ _test_generator(N, binomialvariate, (15, 0.60))
+ _test_generator(N, binomialvariate, (100, 0.75))
_test_generator(N, gammavariate, (0.01, 1.0))
_test_generator(N, gammavariate, (0.1, 1.0))
_test_generator(N, gammavariate, (0.1, 2.0))