summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2016-09-07 00:15:29 (GMT)
committerRaymond Hettinger <python@rcn.com>2016-09-07 00:15:29 (GMT)
commite8f1e002c642e30b820181cd87ae9d187d709f59 (patch)
treef5cb3c6514eec58c8bf2071dbaa7b71473f5d2d4
parent63d98bcd4c88eea1c4b50dae95da662284813114 (diff)
downloadcpython-e8f1e002c642e30b820181cd87ae9d187d709f59.zip
cpython-e8f1e002c642e30b820181cd87ae9d187d709f59.tar.gz
cpython-e8f1e002c642e30b820181cd87ae9d187d709f59.tar.bz2
Issue #18844: Add random.weighted_choices()
-rw-r--r--Doc/library/random.rst21
-rw-r--r--Lib/random.py28
-rw-r--r--Lib/test/test_random.py68
-rw-r--r--Misc/NEWS2
4 files changed, 118 insertions, 1 deletions
diff --git a/Doc/library/random.rst b/Doc/library/random.rst
index 6dc54d2..330cce1 100644
--- a/Doc/library/random.rst
+++ b/Doc/library/random.rst
@@ -124,6 +124,27 @@ Functions for sequences:
Return a random element from the non-empty sequence *seq*. If *seq* is empty,
raises :exc:`IndexError`.
+.. function:: weighted_choices(k, population, weights=None, *, cum_weights=None)
+
+ Return a *k* sized list of elements chosen from the *population* with replacement.
+ If the *population* is empty, raises :exc:`IndexError`.
+
+ If a *weights* sequence is specified, selections are made according to the
+ relative weights. Alternatively, if a *cum_weights* sequence is given, the
+ selections are made according to the cumulative weights. For example, the
+ relative weights ``[10, 5, 30, 5]`` are equivalent to the cumulative
+ weights ``[10, 15, 45, 50]``. Internally, the relative weights are
+ converted to cumulative weights before making selections, so supplying the
+ cumulative weights saves work.
+
+ If neither *weights* nor *cum_weights* are specified, selections are made
+ with equal probability. If a weights sequence is supplied, it must be
+ the same length as the *population* sequence. It is a :exc:`TypeError`
+ to specify both *weights* and *cum_weights*.
+
+ The *weights* or *cum_weights* can use any numeric type that interoperates
+ with the :class:`float` values returned by :func:`random` (that includes
+ integers, floats, and fractions but excludes decimals).
.. function:: shuffle(x[, random])
diff --git a/Lib/random.py b/Lib/random.py
index 82f6013..136395e 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -8,6 +8,7 @@
---------
pick random element
pick random sample
+ pick weighted random sample
generate random permutation
distributions on the real line:
@@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
from os import urandom as _urandom
from _collections_abc import Set as _Set, Sequence as _Sequence
from hashlib import sha512 as _sha512
+import itertools as _itertools
+import bisect as _bisect
__all__ = ["Random","seed","random","uniform","randint","choice","sample",
"randrange","shuffle","normalvariate","lognormvariate",
"expovariate","vonmisesvariate","gammavariate","triangular",
"gauss","betavariate","paretovariate","weibullvariate",
- "getstate","setstate", "getrandbits",
+ "getstate","setstate", "getrandbits", "weighted_choices",
"SystemRandom"]
NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
@@ -334,6 +337,28 @@ class Random(_random.Random):
result[i] = population[j]
return result
+ def weighted_choices(self, k, population, weights=None, *, cum_weights=None):
+ """Return a k sized list of population elements chosen with replacement.
+
+ If the relative weights or cumulative weights are not specified,
+ the selections are made with equal probability.
+
+ """
+ if cum_weights is None:
+ if weights is None:
+ choice = self.choice
+ return [choice(population) for i in range(k)]
+ else:
+ cum_weights = list(_itertools.accumulate(weights))
+ elif weights is not None:
+ raise TypeError('Cannot specify both weights and cumulative_weights')
+ if len(cum_weights) != len(population):
+ raise ValueError('The number of weights does not match the population')
+ bisect = _bisect.bisect
+ random = self.random
+ total = cum_weights[-1]
+ return [population[bisect(cum_weights, random() * total)] for i in range(k)]
+
## -------------------- real-valued distributions -------------------
## -------------------- uniform distribution -------------------
@@ -724,6 +749,7 @@ choice = _inst.choice
randrange = _inst.randrange
sample = _inst.sample
shuffle = _inst.shuffle
+weighted_choices = _inst.weighted_choices
normalvariate = _inst.normalvariate
lognormvariate = _inst.lognormvariate
expovariate = _inst.expovariate
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index e80ed17..b3741a8 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -7,6 +7,7 @@ import warnings
from functools import partial
from math import log, exp, pi, fsum, sin
from test import support
+from fractions import Fraction
class TestBasicOps:
# Superclass with tests common to all generators.
@@ -141,6 +142,73 @@ class TestBasicOps:
def test_sample_on_dicts(self):
self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2)
+ def test_weighted_choices(self):
+ weighted_choices = self.gen.weighted_choices
+ data = ['red', 'green', 'blue', 'yellow']
+ str_data = 'abcd'
+ range_data = range(4)
+ set_data = set(range(4))
+
+ # basic functionality
+ for sample in [
+ weighted_choices(5, data),
+ weighted_choices(5, data, range(4)),
+ weighted_choices(k=5, population=data, weights=range(4)),
+ weighted_choices(k=5, population=data, cum_weights=range(4)),
+ ]:
+ self.assertEqual(len(sample), 5)
+ self.assertEqual(type(sample), list)
+ self.assertTrue(set(sample) <= set(data))
+
+ # test argument handling
+ with self.assertRaises(TypeError): # missing arguments
+ weighted_choices(2)
+
+ self.assertEqual(weighted_choices(0, data), []) # k == 0
+ self.assertEqual(weighted_choices(-1, data), []) # negative k behaves like ``[0] * -1``
+ with self.assertRaises(TypeError):
+ weighted_choices(2.5, data) # k is a float
+
+ self.assertTrue(set(weighted_choices(5, str_data)) <= set(str_data)) # population is a string sequence
+ self.assertTrue(set(weighted_choices(5, range_data)) <= set(range_data)) # population is a range
+ with self.assertRaises(TypeError):
+ weighted_choices(2.5, set_data) # population is not a sequence
+
+ self.assertTrue(set(weighted_choices(5, data, None)) <= set(data)) # weights is None
+ self.assertTrue(set(weighted_choices(5, data, weights=None)) <= set(data))
+ with self.assertRaises(ValueError):
+ weighted_choices(5, data, [1,2]) # len(weights) != len(population)
+ with self.assertRaises(IndexError):
+ weighted_choices(5, data, [0]*4) # weights sum to zero
+ with self.assertRaises(TypeError):
+ weighted_choices(5, data, 10) # non-iterable weights
+ with self.assertRaises(TypeError):
+ weighted_choices(5, data, [None]*4) # non-numeric weights
+ for weights in [
+ [15, 10, 25, 30], # integer weights
+ [15.1, 10.2, 25.2, 30.3], # float weights
+ [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional weights
+ [True, False, True, False] # booleans (include / exclude)
+ ]:
+ self.assertTrue(set(weighted_choices(5, data, weights)) <= set(data))
+
+ with self.assertRaises(ValueError):
+ weighted_choices(5, data, cum_weights=[1,2]) # len(weights) != len(population)
+ with self.assertRaises(IndexError):
+ weighted_choices(5, data, cum_weights=[0]*4) # cum_weights sum to zero
+ with self.assertRaises(TypeError):
+ weighted_choices(5, data, cum_weights=10) # non-iterable cum_weights
+ with self.assertRaises(TypeError):
+ weighted_choices(5, data, cum_weights=[None]*4) # non-numeric cum_weights
+ with self.assertRaises(TypeError):
+ weighted_choices(5, data, range(4), cum_weights=range(4)) # both weights and cum_weights
+ for weights in [
+ [15, 10, 25, 30], # integer cum_weights
+ [15.1, 10.2, 25.2, 30.3], # float cum_weights
+ [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional cum_weights
+ ]:
+ self.assertTrue(set(weighted_choices(5, data, cum_weights=weights)) <= set(data))
+
def test_gauss(self):
# Ensure that the seed() method initializes all the hidden state. In
# particular, through 2.2.1 it failed to reset a piece of state used
diff --git a/Misc/NEWS b/Misc/NEWS
index e913ef8..fbf7b2b 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -101,6 +101,8 @@ Library
- Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name
fields in X.509 certs.
+- Issue #18844: Add random.weighted_choices().
+
- Issue #25761: Improved error reporting about truncated pickle data in
C implementation of unpickler. UnpicklingError is now raised instead of
AttributeError and ValueError in some cases.