summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2014-05-30 09:28:36 (GMT)
committerRaymond Hettinger <python@rcn.com>2014-05-30 09:28:36 (GMT)
commit35db43955cf231a4040d32c77b1ff5a4b639039f (patch)
tree670f590844d11613b16ab61b018caaeb93e9bab1 /Lib
parente7bfe13635e4201660c9d016b62de10c2f7c9de3 (diff)
downloadcpython-35db43955cf231a4040d32c77b1ff5a4b639039f.zip
cpython-35db43955cf231a4040d32c77b1ff5a4b639039f.tar.gz
cpython-35db43955cf231a4040d32c77b1ff5a4b639039f.tar.bz2
Issue #13742: Add key and reverse parameters to heapq.merge()
Diffstat (limited to 'Lib')
-rw-r--r--Lib/heapq.py74
-rw-r--r--Lib/test/test_heapq.py19
2 files changed, 78 insertions, 15 deletions
diff --git a/Lib/heapq.py b/Lib/heapq.py
index ae7ac96..79b46fe 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -176,6 +176,16 @@ def heapify(x):
for i in reversed(range(n//2)):
_siftup(x, i)
+def _heappop_max(heap):
+ """Maxheap version of a heappop."""
+ lastelt = heap.pop() # raises appropriate IndexError if heap is empty
+ if heap:
+ returnitem = heap[0]
+ heap[0] = lastelt
+ _siftup_max(heap, 0)
+ return returnitem
+ return lastelt
+
def _heapreplace_max(heap, item):
"""Maxheap version of a heappop followed by a heappush."""
returnitem = heap[0] # raises appropriate IndexError if heap is empty
@@ -311,7 +321,7 @@ try:
except ImportError:
pass
-def merge(*iterables):
+def merge(*iterables, key=None, reverse=False):
'''Merge multiple sorted inputs into a single sorted output.
Similar to sorted(itertools.chain(*iterables)) but returns a generator,
@@ -321,31 +331,73 @@ def merge(*iterables):
>>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25]))
[0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25]
+ If *key* is not None, applies a key function to each element to determine
+ its sort order.
+
+ >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len))
+ ['dog', 'cat', 'fish', 'horse', 'kangaroo']
+
'''
h = []
h_append = h.append
+
+ if reverse:
+ _heapify = _heapify_max
+ _heappop = _heappop_max
+ _heapreplace = _heapreplace_max
+ direction = -1
+ else:
+ _heapify = heapify
+ _heappop = heappop
+ _heapreplace = heapreplace
+ direction = 1
+
+ if key is None:
+ for order, it in enumerate(map(iter, iterables)):
+ try:
+ next = it.__next__
+ h_append([next(), order * direction, next])
+ except StopIteration:
+ pass
+ _heapify(h)
+ while len(h) > 1:
+ try:
+ while True:
+ value, order, next = s = h[0]
+ yield value
+ s[0] = next() # raises StopIteration when exhausted
+ _heapreplace(h, s) # restore heap condition
+ except StopIteration:
+ _heappop(h) # remove empty iterator
+ if h:
+ # fast case when only a single iterator remains
+ value, order, next = h[0]
+ yield value
+ yield from next.__self__
+ return
+
for order, it in enumerate(map(iter, iterables)):
try:
next = it.__next__
- h_append([next(), order, next])
+ value = next()
+ h_append([key(value), order * direction, value, next])
except StopIteration:
pass
- heapify(h)
-
- _heapreplace = heapreplace
+ _heapify(h)
while len(h) > 1:
try:
while True:
- value, order, next = s = h[0]
+ key_value, order, value, next = s = h[0]
yield value
- s[0] = next() # raises StopIteration when exhausted
- _heapreplace(h, s) # restore heap condition
+ value = next()
+ s[0] = key(value)
+ s[2] = value
+ _heapreplace(h, s)
except StopIteration:
- heappop(h) # remove empty iterator
+ _heappop(h)
if h:
- # fast case when only a single iterator remains
- value, order, next = h[0]
+ key_value, order, value, next = h[0]
yield value
yield from next.__self__
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 59c7029..685797a 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -6,6 +6,7 @@ import unittest
from test import support
from unittest import TestCase, skipUnless
+from operator import itemgetter
py_heapq = support.import_fresh_module('heapq', blocked=['_heapq'])
c_heapq = support.import_fresh_module('heapq', fresh=['_heapq'])
@@ -152,11 +153,21 @@ class TestHeap:
def test_merge(self):
inputs = []
- for i in range(random.randrange(5)):
- row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
+ for i in range(random.randrange(25)):
+ row = []
+ for j in range(random.randrange(100)):
+ tup = random.choice('ABC'), random.randrange(-500, 500)
+ row.append(tup)
inputs.append(row)
- self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
- self.assertEqual(list(self.module.merge()), [])
+
+ for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
+ for reverse in [False, True]:
+ seqs = []
+ for seq in inputs:
+ seqs.append(sorted(seq, key=key, reverse=reverse))
+ self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse),
+ list(self.module.merge(*seqs, key=key, reverse=reverse)))
+ self.assertEqual(list(self.module.merge()), [])
def test_merge_does_not_suppress_index_error(self):
# Issue 19018: Heapq.merge suppresses IndexError from user generator