summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAntoine Pitrou <antoine@python.org>2020-04-17 17:32:14 (GMT)
committerGitHub <noreply@github.com>2020-04-17 17:32:14 (GMT)
commit75a3378810bab03949ad9f653f78d933bdf3879c (patch)
tree96ec61f0d287ca5632da7ee091079817d4197733 /Lib
parentd7c657d4b121164caa439253da5266b2e29a1bed (diff)
downloadcpython-75a3378810bab03949ad9f653f78d933bdf3879c.zip
cpython-75a3378810bab03949ad9f653f78d933bdf3879c.tar.gz
cpython-75a3378810bab03949ad9f653f78d933bdf3879c.tar.bz2
bpo-40282: Allow random.getrandbits(0) (GH-19539)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/random.py6
-rw-r--r--Lib/test/test_random.py69
2 files changed, 33 insertions, 42 deletions
diff --git a/Lib/random.py b/Lib/random.py
index 82345fa..3243938 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -261,6 +261,8 @@ class Random(_random.Random):
def _randbelow_with_getrandbits(self, n):
"Return a random int in the range [0,n). Raises ValueError if n==0."
+ if not n:
+ raise ValueError("Boundary cannot be zero")
getrandbits = self.getrandbits
k = n.bit_length() # don't use (n-1) here because n can be 1
r = getrandbits(k) # 0 <= r < 2**k
@@ -733,8 +735,8 @@ class SystemRandom(Random):
def getrandbits(self, k):
"""getrandbits(k) -> x. Generates an int with k random bits."""
- if k <= 0:
- raise ValueError('number of bits must be greater than zero')
+ if k < 0:
+ raise ValueError('number of bits must be non-negative')
numbytes = (k + 7) // 8 # bits / 8 and rounded up
x = int.from_bytes(_urandom(numbytes), 'big')
return x >> (numbytes * 8 - k) # trim excess bits
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index f709e52..efac36a 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -263,6 +263,31 @@ class TestBasicOps:
self.assertEqual(x1, x2)
self.assertEqual(y1, y2)
+ def test_getrandbits(self):
+ # Verify ranges
+ for k in range(1, 1000):
+ self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
+ self.assertEqual(self.gen.getrandbits(0), 0)
+
+ # Verify all bits active
+ getbits = self.gen.getrandbits
+ for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
+ all_bits = 2**span-1
+ cum = 0
+ cpl_cum = 0
+ for i in range(100):
+ v = getbits(span)
+ cum |= v
+ cpl_cum |= all_bits ^ v
+ self.assertEqual(cum, all_bits)
+ self.assertEqual(cpl_cum, all_bits)
+
+ # Verify argument checking
+ self.assertRaises(TypeError, self.gen.getrandbits)
+ self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
+ self.assertRaises(ValueError, self.gen.getrandbits, -1)
+ self.assertRaises(TypeError, self.gen.getrandbits, 10.1)
+
def test_pickling(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
state = pickle.dumps(self.gen, proto)
@@ -390,26 +415,6 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
raises(0, 42, 0)
raises(0, 42, 3.14159)
- def test_genrandbits(self):
- # Verify ranges
- for k in range(1, 1000):
- self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
-
- # Verify all bits active
- getbits = self.gen.getrandbits
- for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
- cum = 0
- for i in range(100):
- cum |= getbits(span)
- self.assertEqual(cum, 2**span-1)
-
- # Verify argument checking
- self.assertRaises(TypeError, self.gen.getrandbits)
- self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
- self.assertRaises(ValueError, self.gen.getrandbits, 0)
- self.assertRaises(ValueError, self.gen.getrandbits, -1)
- self.assertRaises(TypeError, self.gen.getrandbits, 10.1)
-
def test_randbelow_logic(self, _log=log, int=int):
# check bitcount transition points: 2**i and 2**(i+1)-1
# show that: k = int(1.001 + _log(n, 2))
@@ -629,34 +634,18 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(set(range(start,stop)),
set([self.gen.randrange(start,stop) for i in range(100)]))
- def test_genrandbits(self):
+ def test_getrandbits(self):
+ super().test_getrandbits()
+
# Verify cross-platform repeatability
self.gen.seed(1234567)
self.assertEqual(self.gen.getrandbits(100),
97904845777343510404718956115)
- # Verify ranges
- for k in range(1, 1000):
- self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
-
- # Verify all bits active
- getbits = self.gen.getrandbits
- for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
- cum = 0
- for i in range(100):
- cum |= getbits(span)
- self.assertEqual(cum, 2**span-1)
-
- # Verify argument checking
- self.assertRaises(TypeError, self.gen.getrandbits)
- self.assertRaises(TypeError, self.gen.getrandbits, 'a')
- self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
- self.assertRaises(ValueError, self.gen.getrandbits, 0)
- self.assertRaises(ValueError, self.gen.getrandbits, -1)
def test_randrange_uses_getrandbits(self):
# Verify use of getrandbits by randrange
# Use same seed as in the cross-platform repeatability test
- # in test_genrandbits above.
+ # in test_getrandbits above.
self.gen.seed(1234567)
# If randrange uses getrandbits, it should pick getrandbits(100)
# when called with a 100-bits stop argument.