diff options
-rw-r--r-- | Lib/random.py | 2 | ||||
-rw-r--r-- | Lib/test/test_random.py | 54 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst | 1 |
3 files changed, 29 insertions, 28 deletions
diff --git a/Lib/random.py b/Lib/random.py index db0e6c2..30186fc 100644 --- a/Lib/random.py +++ b/Lib/random.py @@ -479,7 +479,7 @@ class Random(_random.Random): raise TypeError('Counts must be integers') if total <= 0: raise ValueError('Total of counts must be greater than zero') - selections = sample(range(total), k=k) + selections = self.sample(range(total), k=k) bisect = _bisect return [population[bisect(cum_counts, s)] for s in selections] randbelow = self._randbelow diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index 41a26e3..35ae4e6 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -223,33 +223,6 @@ class TestBasicOps: with self.assertRaises(ValueError): sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts - def test_sample_counts_equivalence(self): - # Test the documented strong equivalence to a sample with repeated elements. - # We run this test on random.Random() which makes deterministic selections - # for a given seed value. - sample = random.sample - seed = random.seed - - colors = ['red', 'green', 'blue', 'orange', 'black', 'amber'] - counts = [500, 200, 20, 10, 5, 1 ] - k = 700 - seed(8675309) - s1 = sample(colors, counts=counts, k=k) - seed(8675309) - expanded = [color for (color, count) in zip(colors, counts) for i in range(count)] - self.assertEqual(len(expanded), sum(counts)) - s2 = sample(expanded, k=k) - self.assertEqual(s1, s2) - - pop = 'abcdefghi' - counts = [10, 9, 8, 7, 6, 5, 4, 3, 2] - seed(8675309) - s1 = ''.join(sample(pop, counts=counts, k=30)) - expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)]) - seed(8675309) - s2 = ''.join(sample(expanded, k=30)) - self.assertEqual(s1, s2) - def test_choices(self): choices = self.gen.choices data = ['red', 'green', 'blue', 'yellow'] @@ -957,6 +930,33 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase): self.assertEqual(self.gen.randbytes(n), gen2.getrandbits(n * 8).to_bytes(n, 'little')) + def test_sample_counts_equivalence(self): + # Test the documented strong equivalence to a sample with repeated elements. + # We run this test on random.Random() which makes deterministic selections + # for a given seed value. + sample = self.gen.sample + seed = self.gen.seed + + colors = ['red', 'green', 'blue', 'orange', 'black', 'amber'] + counts = [500, 200, 20, 10, 5, 1 ] + k = 700 + seed(8675309) + s1 = sample(colors, counts=counts, k=k) + seed(8675309) + expanded = [color for (color, count) in zip(colors, counts) for i in range(count)] + self.assertEqual(len(expanded), sum(counts)) + s2 = sample(expanded, k=k) + self.assertEqual(s1, s2) + + pop = 'abcdefghi' + counts = [10, 9, 8, 7, 6, 5, 4, 3, 2] + seed(8675309) + s1 = ''.join(sample(pop, counts=counts, k=30)) + expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)]) + seed(8675309) + s2 = ''.join(sample(expanded, k=30)) + self.assertEqual(s1, s2) + def gamma(z, sqrt2pi=(2.0*pi)**0.5): # Reflection to right half of complex plane diff --git a/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst b/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst new file mode 100644 index 0000000..b78d10a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst @@ -0,0 +1 @@ +Fix ``random.Random.sample`` when ``counts`` argument is not ``None``. |