summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/random.py2
-rw-r--r--Lib/test/test_random.py54
-rw-r--r--Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst1
3 files changed, 29 insertions, 28 deletions
diff --git a/Lib/random.py b/Lib/random.py
index db0e6c2..30186fc 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -479,7 +479,7 @@ class Random(_random.Random):
raise TypeError('Counts must be integers')
if total <= 0:
raise ValueError('Total of counts must be greater than zero')
- selections = sample(range(total), k=k)
+ selections = self.sample(range(total), k=k)
bisect = _bisect
return [population[bisect(cum_counts, s)] for s in selections]
randbelow = self._randbelow
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index 41a26e3..35ae4e6 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -223,33 +223,6 @@ class TestBasicOps:
with self.assertRaises(ValueError):
sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts
- def test_sample_counts_equivalence(self):
- # Test the documented strong equivalence to a sample with repeated elements.
- # We run this test on random.Random() which makes deterministic selections
- # for a given seed value.
- sample = random.sample
- seed = random.seed
-
- colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
- counts = [500, 200, 20, 10, 5, 1 ]
- k = 700
- seed(8675309)
- s1 = sample(colors, counts=counts, k=k)
- seed(8675309)
- expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
- self.assertEqual(len(expanded), sum(counts))
- s2 = sample(expanded, k=k)
- self.assertEqual(s1, s2)
-
- pop = 'abcdefghi'
- counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
- seed(8675309)
- s1 = ''.join(sample(pop, counts=counts, k=30))
- expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
- seed(8675309)
- s2 = ''.join(sample(expanded, k=30))
- self.assertEqual(s1, s2)
-
def test_choices(self):
choices = self.gen.choices
data = ['red', 'green', 'blue', 'yellow']
@@ -957,6 +930,33 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(self.gen.randbytes(n),
gen2.getrandbits(n * 8).to_bytes(n, 'little'))
+ def test_sample_counts_equivalence(self):
+ # Test the documented strong equivalence to a sample with repeated elements.
+ # We run this test on random.Random() which makes deterministic selections
+ # for a given seed value.
+ sample = self.gen.sample
+ seed = self.gen.seed
+
+ colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
+ counts = [500, 200, 20, 10, 5, 1 ]
+ k = 700
+ seed(8675309)
+ s1 = sample(colors, counts=counts, k=k)
+ seed(8675309)
+ expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
+ self.assertEqual(len(expanded), sum(counts))
+ s2 = sample(expanded, k=k)
+ self.assertEqual(s1, s2)
+
+ pop = 'abcdefghi'
+ counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
+ seed(8675309)
+ s1 = ''.join(sample(pop, counts=counts, k=30))
+ expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
+ seed(8675309)
+ s2 = ''.join(sample(expanded, k=30))
+ self.assertEqual(s1, s2)
+
def gamma(z, sqrt2pi=(2.0*pi)**0.5):
# Reflection to right half of complex plane
diff --git a/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst b/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst
new file mode 100644
index 0000000..b78d10a
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst
@@ -0,0 +1 @@
+Fix ``random.Random.sample`` when ``counts`` argument is not ``None``.