summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2007-01-04 17:53:34 (GMT)
committerRaymond Hettinger <python@rcn.com>2007-01-04 17:53:34 (GMT)
commit769a40a1d046baddbac4c01fa8feada2ee9e5207 (patch)
treead7a34d423aeb39afb406b0427996012595369c8 /Lib
parent2dc4db017483a25d9f2438ae4f36d98ec954ed37 (diff)
downloadcpython-769a40a1d046baddbac4c01fa8feada2ee9e5207.zip
cpython-769a40a1d046baddbac4c01fa8feada2ee9e5207.tar.gz
cpython-769a40a1d046baddbac4c01fa8feada2ee9e5207.tar.bz2
Fix stability of heapq's nlargest() and nsmallest().
Diffstat (limited to 'Lib')
-rw-r--r--Lib/heapq.py8
-rw-r--r--Lib/test/test_heapq.py24
2 files changed, 14 insertions, 18 deletions
diff --git a/Lib/heapq.py b/Lib/heapq.py
index 04725cd..753c3b7 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -130,7 +130,7 @@ __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'nlargest',
'nsmallest']
from itertools import islice, repeat, count, imap, izip, tee
-from operator import itemgetter
+from operator import itemgetter, neg
import bisect
def heappush(heap, item):
@@ -315,8 +315,6 @@ def nsmallest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key)[:n]
"""
- if key is None:
- return _nsmallest(n, iterable)
in1, in2 = tee(iterable)
it = izip(imap(key, in1), count(), in2) # decorate
result = _nsmallest(n, it)
@@ -328,10 +326,8 @@ def nlargest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key, reverse=True)[:n]
"""
- if key is None:
- return _nlargest(n, iterable)
in1, in2 = tee(iterable)
- it = izip(imap(key, in1), count(), in2) # decorate
+ it = izip(imap(key, in1), imap(neg, count()), in2) # decorate
result = _nlargest(n, it)
return map(itemgetter(2), result) # undecorate
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 2da4f8c..e9f2798 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -104,20 +104,20 @@ class TestHeap(unittest.TestCase):
self.assertEqual(heap_sorted, sorted(data))
def test_nsmallest(self):
- data = [random.randrange(2000) for i in range(1000)]
- f = lambda x: x * 547 % 2000
- for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
- self.assertEqual(nsmallest(n, data), sorted(data)[:n])
- self.assertEqual(nsmallest(n, data, key=f),
- sorted(data, key=f)[:n])
+ data = [(random.randrange(2000), i) for i in range(1000)]
+ for f in (None, lambda x: x[0] * 547 % 2000):
+ for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
+ self.assertEqual(nsmallest(n, data), sorted(data)[:n])
+ self.assertEqual(nsmallest(n, data, key=f),
+ sorted(data, key=f)[:n])
def test_nlargest(self):
- data = [random.randrange(2000) for i in range(1000)]
- f = lambda x: x * 547 % 2000
- for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
- self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n])
- self.assertEqual(nlargest(n, data, key=f),
- sorted(data, key=f, reverse=True)[:n])
+ data = [(random.randrange(2000), i) for i in range(1000)]
+ for f in (None, lambda x: x[0] * 547 % 2000):
+ for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
+ self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n])
+ self.assertEqual(nlargest(n, data, key=f),
+ sorted(data, key=f, reverse=True)[:n])
#==============================================================================