summaryrefslogtreecommitdiffstats
path: root/Lib/random.py
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2002-11-12 17:41:57 (GMT)
committerRaymond Hettinger <python@rcn.com>2002-11-12 17:41:57 (GMT)
commitf24eb35d185c0623315cfbd9977d37c509860dcf (patch)
tree336a055089a3bd51d89bf4fdc987fb90ce743e09 /Lib/random.py
parent3a7ad5c5843467dc67ddf4c6622dd466f0f42c74 (diff)
downloadcpython-f24eb35d185c0623315cfbd9977d37c509860dcf.zip
cpython-f24eb35d185c0623315cfbd9977d37c509860dcf.tar.gz
cpython-f24eb35d185c0623315cfbd9977d37c509860dcf.tar.bz2
SF patch 629637: Add sample(population, k) method to the random module.
Used for random sampling without replacement.
Diffstat (limited to 'Lib/random.py')
-rw-r--r--Lib/random.py58
1 files changed, 56 insertions, 2 deletions
diff --git a/Lib/random.py b/Lib/random.py
index 4d29080..e2c675c 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -7,6 +7,7 @@
sequences
---------
pick random element
+ pick random sample
generate random permutation
distributions on the real line:
@@ -77,7 +78,7 @@ from math import log as _log, exp as _exp, pi as _pi, e as _e
from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
from math import floor as _floor
-__all__ = ["Random","seed","random","uniform","randint","choice",
+__all__ = ["Random","seed","random","uniform","randint","choice","sample",
"randrange","shuffle","normalvariate","lognormvariate",
"cunifvariate","expovariate","vonmisesvariate","gammavariate",
"stdgamma","gauss","betavariate","paretovariate","weibullvariate",
@@ -373,6 +374,43 @@ class Random:
j = int(random() * (i+1))
x[i], x[j] = x[j], x[i]
+ def sample(self, population, k, random=None, int=int):
+ """Chooses k unique random elements from a population sequence.
+
+ Returns a new list containing elements from the population. The
+ list itself is in random order so that all sub-slices are also
+ random samples. The original sequence is left undisturbed.
+
+ If the population has repeated elements, then each occurence is
+ a possible selection in the sample.
+
+ If indices are needed for a large population, use xrange as an
+ argument: sample(xrange(10000000), 60)
+
+ Optional arg random is a 0-argument function returning a random
+ float in [0.0, 1.0); by default, the standard random.random.
+ """
+
+ n = len(population)
+ if not 0 <= k <= n:
+ raise ValueError, "sample larger than population"
+ if random is None:
+ random = self.random
+ if n < 6 * k: # if n len list takes less space than a k len dict
+ pool = list(population)
+ for i in xrange(n-1, n-k-1, -1):
+ j = int(random() * (i+1))
+ pool[i], pool[j] = pool[j], pool[i]
+ return pool[-k:]
+ inorder = [None] * k
+ selections = {}
+ for i in xrange(k):
+ j = int(random() * n)
+ while j in selections:
+ j = int(random() * n)
+ selections[j] = inorder[i] = population[j]
+ return inorder # return selections in the order they were picked
+
## -------------------- real-valued distributions -------------------
## -------------------- uniform distribution -------------------
@@ -711,7 +749,19 @@ def _test_generator(n, funccall):
print 'avg %g, stddev %g, min %g, max %g' % \
(avg, stddev, smallest, largest)
-def _test(N=20000):
+def _test_sample(n):
+ # For the entire allowable range of 0 <= k <= n, validate that
+ # the sample is of the correct length and contains only unique items
+ population = xrange(n)
+ for k in xrange(n+1):
+ s = sample(population, k)
+ assert len(dict([(elem,True) for elem in s])) == len(s) == k
+
+def _sample_generator(n, k):
+ # Return a fixed element from the sample. Validates random ordering.
+ return sample(xrange(n), k)[k//2]
+
+def _test(N=2000):
print 'TWOPI =', TWOPI
print 'LOG4 =', LOG4
print 'NV_MAGICCONST =', NV_MAGICCONST
@@ -735,6 +785,9 @@ def _test(N=20000):
_test_generator(N, 'betavariate(3.0, 3.0)')
_test_generator(N, 'paretovariate(1.0)')
_test_generator(N, 'weibullvariate(1.0, 1.0)')
+ _test_generator(N, '_sample_generator(50, 5)') # expected s.d.: 14.4
+ _test_generator(N, '_sample_generator(50, 45)') # expected s.d.: 14.4
+ _test_sample(1000)
# Test jumpahead.
s = getstate()
@@ -760,6 +813,7 @@ uniform = _inst.uniform
randint = _inst.randint
choice = _inst.choice
randrange = _inst.randrange
+sample = _inst.sample
shuffle = _inst.shuffle
normalvariate = _inst.normalvariate
lognormvariate = _inst.lognormvariate