diff options
author | Raymond Hettinger <python@rcn.com> | 2014-05-30 09:28:36 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2014-05-30 09:28:36 (GMT) |
commit | 35db43955cf231a4040d32c77b1ff5a4b639039f (patch) | |
tree | 670f590844d11613b16ab61b018caaeb93e9bab1 /Lib | |
parent | e7bfe13635e4201660c9d016b62de10c2f7c9de3 (diff) | |
download | cpython-35db43955cf231a4040d32c77b1ff5a4b639039f.zip cpython-35db43955cf231a4040d32c77b1ff5a4b639039f.tar.gz cpython-35db43955cf231a4040d32c77b1ff5a4b639039f.tar.bz2 |
Issue #13742: Add key and reverse parameters to heapq.merge()
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/heapq.py | 74 | ||||
-rw-r--r-- | Lib/test/test_heapq.py | 19 |
2 files changed, 78 insertions, 15 deletions
diff --git a/Lib/heapq.py b/Lib/heapq.py index ae7ac96..79b46fe 100644 --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -176,6 +176,16 @@ def heapify(x): for i in reversed(range(n//2)): _siftup(x, i) +def _heappop_max(heap): + """Maxheap version of a heappop.""" + lastelt = heap.pop() # raises appropriate IndexError if heap is empty + if heap: + returnitem = heap[0] + heap[0] = lastelt + _siftup_max(heap, 0) + return returnitem + return lastelt + def _heapreplace_max(heap, item): """Maxheap version of a heappop followed by a heappush.""" returnitem = heap[0] # raises appropriate IndexError if heap is empty @@ -311,7 +321,7 @@ try: except ImportError: pass -def merge(*iterables): +def merge(*iterables, key=None, reverse=False): '''Merge multiple sorted inputs into a single sorted output. Similar to sorted(itertools.chain(*iterables)) but returns a generator, @@ -321,31 +331,73 @@ def merge(*iterables): >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25])) [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25] + If *key* is not None, applies a key function to each element to determine + its sort order. + + >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len)) + ['dog', 'cat', 'fish', 'horse', 'kangaroo'] + ''' h = [] h_append = h.append + + if reverse: + _heapify = _heapify_max + _heappop = _heappop_max + _heapreplace = _heapreplace_max + direction = -1 + else: + _heapify = heapify + _heappop = heappop + _heapreplace = heapreplace + direction = 1 + + if key is None: + for order, it in enumerate(map(iter, iterables)): + try: + next = it.__next__ + h_append([next(), order * direction, next]) + except StopIteration: + pass + _heapify(h) + while len(h) > 1: + try: + while True: + value, order, next = s = h[0] + yield value + s[0] = next() # raises StopIteration when exhausted + _heapreplace(h, s) # restore heap condition + except StopIteration: + _heappop(h) # remove empty iterator + if h: + # fast case when only a single iterator remains + value, order, next = h[0] + yield value + yield from next.__self__ + return + for order, it in enumerate(map(iter, iterables)): try: next = it.__next__ - h_append([next(), order, next]) + value = next() + h_append([key(value), order * direction, value, next]) except StopIteration: pass - heapify(h) - - _heapreplace = heapreplace + _heapify(h) while len(h) > 1: try: while True: - value, order, next = s = h[0] + key_value, order, value, next = s = h[0] yield value - s[0] = next() # raises StopIteration when exhausted - _heapreplace(h, s) # restore heap condition + value = next() + s[0] = key(value) + s[2] = value + _heapreplace(h, s) except StopIteration: - heappop(h) # remove empty iterator + _heappop(h) if h: - # fast case when only a single iterator remains - value, order, next = h[0] + key_value, order, value, next = h[0] yield value yield from next.__self__ diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index 59c7029..685797a 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -6,6 +6,7 @@ import unittest from test import support from unittest import TestCase, skipUnless +from operator import itemgetter py_heapq = support.import_fresh_module('heapq', blocked=['_heapq']) c_heapq = support.import_fresh_module('heapq', fresh=['_heapq']) @@ -152,11 +153,21 @@ class TestHeap: def test_merge(self): inputs = [] - for i in range(random.randrange(5)): - row = sorted(random.randrange(1000) for j in range(random.randrange(10))) + for i in range(random.randrange(25)): + row = [] + for j in range(random.randrange(100)): + tup = random.choice('ABC'), random.randrange(-500, 500) + row.append(tup) inputs.append(row) - self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs))) - self.assertEqual(list(self.module.merge()), []) + + for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]: + for reverse in [False, True]: + seqs = [] + for seq in inputs: + seqs.append(sorted(seq, key=key, reverse=reverse)) + self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse), + list(self.module.merge(*seqs, key=key, reverse=reverse))) + self.assertEqual(list(self.module.merge()), []) def test_merge_does_not_suppress_index_error(self): # Issue 19018: Heapq.merge suppresses IndexError from user generator |