diff options
author | Raymond Hettinger <python@rcn.com> | 2007-01-04 17:53:34 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2007-01-04 17:53:34 (GMT) |
commit | 769a40a1d046baddbac4c01fa8feada2ee9e5207 (patch) | |
tree | ad7a34d423aeb39afb406b0427996012595369c8 | |
parent | 2dc4db017483a25d9f2438ae4f36d98ec954ed37 (diff) | |
download | cpython-769a40a1d046baddbac4c01fa8feada2ee9e5207.zip cpython-769a40a1d046baddbac4c01fa8feada2ee9e5207.tar.gz cpython-769a40a1d046baddbac4c01fa8feada2ee9e5207.tar.bz2 |
Fix stability of heapq's nlargest() and nsmallest().
-rw-r--r-- | Lib/heapq.py | 8 | ||||
-rw-r--r-- | Lib/test/test_heapq.py | 24 |
2 files changed, 14 insertions, 18 deletions
diff --git a/Lib/heapq.py b/Lib/heapq.py index 04725cd..753c3b7 100644 --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -130,7 +130,7 @@ __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'nlargest', 'nsmallest'] from itertools import islice, repeat, count, imap, izip, tee -from operator import itemgetter +from operator import itemgetter, neg import bisect def heappush(heap, item): @@ -315,8 +315,6 @@ def nsmallest(n, iterable, key=None): Equivalent to: sorted(iterable, key=key)[:n] """ - if key is None: - return _nsmallest(n, iterable) in1, in2 = tee(iterable) it = izip(imap(key, in1), count(), in2) # decorate result = _nsmallest(n, it) @@ -328,10 +326,8 @@ def nlargest(n, iterable, key=None): Equivalent to: sorted(iterable, key=key, reverse=True)[:n] """ - if key is None: - return _nlargest(n, iterable) in1, in2 = tee(iterable) - it = izip(imap(key, in1), count(), in2) # decorate + it = izip(imap(key, in1), imap(neg, count()), in2) # decorate result = _nlargest(n, it) return map(itemgetter(2), result) # undecorate diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index 2da4f8c..e9f2798 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -104,20 +104,20 @@ class TestHeap(unittest.TestCase): self.assertEqual(heap_sorted, sorted(data)) def test_nsmallest(self): - data = [random.randrange(2000) for i in range(1000)] - f = lambda x: x * 547 % 2000 - for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): - self.assertEqual(nsmallest(n, data), sorted(data)[:n]) - self.assertEqual(nsmallest(n, data, key=f), - sorted(data, key=f)[:n]) + data = [(random.randrange(2000), i) for i in range(1000)] + for f in (None, lambda x: x[0] * 547 % 2000): + for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): + self.assertEqual(nsmallest(n, data), sorted(data)[:n]) + self.assertEqual(nsmallest(n, data, key=f), + sorted(data, key=f)[:n]) def test_nlargest(self): - data = [random.randrange(2000) for i in range(1000)] - f = lambda x: x * 547 % 2000 - for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): - self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n]) - self.assertEqual(nlargest(n, data, key=f), - sorted(data, key=f, reverse=True)[:n]) + data = [(random.randrange(2000), i) for i in range(1000)] + for f in (None, lambda x: x[0] * 547 % 2000): + for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): + self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n]) + self.assertEqual(nlargest(n, data, key=f), + sorted(data, key=f, reverse=True)[:n]) #============================================================================== |