summaryrefslogtreecommitdiffstats
path: root/Lib/random.py
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2020-05-08 14:53:15 (GMT)
committerGitHub <noreply@github.com>2020-05-08 14:53:15 (GMT)
commit81a5fc38e81b424869f4710f48e9371dfa2d3b77 (patch)
tree1691df16b3e62be26065f034810c95ad0deb9768 /Lib/random.py
parent2effef7453986bf43a6d921cd471a8bc0722c36a (diff)
downloadcpython-81a5fc38e81b424869f4710f48e9371dfa2d3b77.zip
cpython-81a5fc38e81b424869f4710f48e9371dfa2d3b77.tar.gz
cpython-81a5fc38e81b424869f4710f48e9371dfa2d3b77.tar.bz2
bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970)
Diffstat (limited to 'Lib/random.py')
-rw-r--r--Lib/random.py34
1 files changed, 29 insertions, 5 deletions
diff --git a/Lib/random.py b/Lib/random.py
index f2c4f39..75f70d5 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -331,7 +331,7 @@ class Random(_random.Random):
j = _int(random() * (i+1))
x[i], x[j] = x[j], x[i]
- def sample(self, population, k):
+ def sample(self, population, k, *, counts=None):
"""Chooses k unique random elements from a population sequence or set.
Returns a new list containing elements from the population while
@@ -344,9 +344,21 @@ class Random(_random.Random):
population contains repeats, then each occurrence is a possible
selection in the sample.
- To choose a sample in a range of integers, use range as an argument.
- This is especially fast and space efficient for sampling from a
- large population: sample(range(10000000), 60)
+ Repeated elements can be specified one at a time or with the optional
+ counts parameter. For example:
+
+ sample(['red', 'blue'], counts=[4, 2], k=5)
+
+ is equivalent to:
+
+ sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
+
+ To choose a sample from a range of integers, use range() for the
+ population argument. This is especially fast and space efficient
+ for sampling from a large population:
+
+ sample(range(10000000), 60)
+
"""
# Sampling without replacement entails tracking either potential
@@ -379,8 +391,20 @@ class Random(_random.Random):
population = tuple(population)
if not isinstance(population, _Sequence):
raise TypeError("Population must be a sequence. For dicts or sets, use sorted(d).")
- randbelow = self._randbelow
n = len(population)
+ if counts is not None:
+ cum_counts = list(_accumulate(counts))
+ if len(cum_counts) != n:
+ raise ValueError('The number of counts does not match the population')
+ total = cum_counts.pop()
+ if not isinstance(total, int):
+ raise TypeError('Counts must be integers')
+ if total <= 0:
+ raise ValueError('Total of counts must be greater than zero')
+ selections = sample(range(total), k=k)
+ bisect = _bisect
+ return [population[bisect(cum_counts, s)] for s in selections]
+ randbelow = self._randbelow
if not 0 <= k <= n:
raise ValueError("Sample larger than population or is negative")
result = [None] * k