diff options
Diffstat (limited to 'Lib/heapq.py')
-rw-r--r-- | Lib/heapq.py | 349 |
1 files changed, 240 insertions, 109 deletions
diff --git a/Lib/heapq.py b/Lib/heapq.py index d615239..07af37e 100644 --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -127,8 +127,6 @@ From all times, sorting has always been a Great Art! :-) __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge', 'nlargest', 'nsmallest', 'heappushpop'] -from itertools import islice, count, tee, chain - def heappush(heap, item): """Push item onto heap, maintaining the heap invariant.""" heap.append(item) @@ -141,9 +139,8 @@ def heappop(heap): returnitem = heap[0] heap[0] = lastelt _siftup(heap, 0) - else: - returnitem = lastelt - return returnitem + return returnitem + return lastelt def heapreplace(heap, item): """Pop and return the current smallest value, and add the new item. @@ -179,12 +176,22 @@ def heapify(x): for i in reversed(range(n//2)): _siftup(x, i) -def _heappushpop_max(heap, item): - """Maxheap version of a heappush followed by a heappop.""" - if heap and item < heap[0]: - item, heap[0] = heap[0], item +def _heappop_max(heap): + """Maxheap version of a heappop.""" + lastelt = heap.pop() # raises appropriate IndexError if heap is empty + if heap: + returnitem = heap[0] + heap[0] = lastelt _siftup_max(heap, 0) - return item + return returnitem + return lastelt + +def _heapreplace_max(heap, item): + """Maxheap version of a heappop followed by a heappush.""" + returnitem = heap[0] # raises appropriate IndexError if heap is empty + heap[0] = item + _siftup_max(heap, 0) + return returnitem def _heapify_max(x): """Transform list into a maxheap, in-place, in O(len(x)) time.""" @@ -192,42 +199,6 @@ def _heapify_max(x): for i in reversed(range(n//2)): _siftup_max(x, i) -def nlargest(n, iterable): - """Find the n largest elements in a dataset. - - Equivalent to: sorted(iterable, reverse=True)[:n] - """ - if n < 0: - return [] - it = iter(iterable) - result = list(islice(it, n)) - if not result: - return result - heapify(result) - _heappushpop = heappushpop - for elem in it: - _heappushpop(result, elem) - result.sort(reverse=True) - return result - -def nsmallest(n, iterable): - """Find the n smallest elements in a dataset. - - Equivalent to: sorted(iterable)[:n] - """ - if n < 0: - return [] - it = iter(iterable) - result = list(islice(it, n)) - if not result: - return result - _heapify_max(result) - _heappushpop = _heappushpop_max - for elem in it: - _heappushpop(result, elem) - result.sort() - return result - # 'heap' is a heap at all indices >= startpos, except possibly for pos. pos # is the index of a leaf with a possibly out-of-order value. Restore the # heap invariant. @@ -340,13 +311,7 @@ def _siftup_max(heap, pos): heap[pos] = newitem _siftdown_max(heap, startpos, pos) -# If available, use C implementation -try: - from _heapq import * -except ImportError: - pass - -def merge(*iterables): +def merge(*iterables, key=None, reverse=False): '''Merge multiple sorted inputs into a single sorted output. Similar to sorted(itertools.chain(*iterables)) but returns a generator, @@ -356,51 +321,158 @@ def merge(*iterables): >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25])) [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25] + If *key* is not None, applies a key function to each element to determine + its sort order. + + >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len)) + ['dog', 'cat', 'fish', 'horse', 'kangaroo'] + ''' - _heappop, _heapreplace, _StopIteration = heappop, heapreplace, StopIteration - _len = len h = [] h_append = h.append - for itnum, it in enumerate(map(iter, iterables)): + + if reverse: + _heapify = _heapify_max + _heappop = _heappop_max + _heapreplace = _heapreplace_max + direction = -1 + else: + _heapify = heapify + _heappop = heappop + _heapreplace = heapreplace + direction = 1 + + if key is None: + for order, it in enumerate(map(iter, iterables)): + try: + next = it.__next__ + h_append([next(), order * direction, next]) + except StopIteration: + pass + _heapify(h) + while len(h) > 1: + try: + while True: + value, order, next = s = h[0] + yield value + s[0] = next() # raises StopIteration when exhausted + _heapreplace(h, s) # restore heap condition + except StopIteration: + _heappop(h) # remove empty iterator + if h: + # fast case when only a single iterator remains + value, order, next = h[0] + yield value + yield from next.__self__ + return + + for order, it in enumerate(map(iter, iterables)): try: next = it.__next__ - h_append([next(), itnum, next]) - except _StopIteration: + value = next() + h_append([key(value), order * direction, value, next]) + except StopIteration: pass - heapify(h) - - while _len(h) > 1: + _heapify(h) + while len(h) > 1: try: while True: - v, itnum, next = s = h[0] - yield v - s[0] = next() # raises StopIteration when exhausted - _heapreplace(h, s) # restore heap condition - except _StopIteration: - _heappop(h) # remove empty iterator + key_value, order, value, next = s = h[0] + yield value + value = next() + s[0] = key(value) + s[2] = value + _heapreplace(h, s) + except StopIteration: + _heappop(h) if h: - # fast case when only a single iterator remains - v, itnum, next = h[0] - yield v + key_value, order, value, next = h[0] + yield value yield from next.__self__ -# Extend the implementations of nsmallest and nlargest to use a key= argument -_nsmallest = nsmallest + +# Algorithm notes for nlargest() and nsmallest() +# ============================================== +# +# Make a single pass over the data while keeping the k most extreme values +# in a heap. Memory consumption is limited to keeping k values in a list. +# +# Measured performance for random inputs: +# +# number of comparisons +# n inputs k-extreme values (average of 5 trials) % more than min() +# ------------- ---------------- --------------------- ----------------- +# 1,000 100 3,317 231.7% +# 10,000 100 14,046 40.5% +# 100,000 100 105,749 5.7% +# 1,000,000 100 1,007,751 0.8% +# 10,000,000 100 10,009,401 0.1% +# +# Theoretical number of comparisons for k smallest of n random inputs: +# +# Step Comparisons Action +# ---- -------------------------- --------------------------- +# 1 1.66 * k heapify the first k-inputs +# 2 n - k compare remaining elements to top of heap +# 3 k * (1 + lg2(k)) * ln(n/k) replace the topmost value on the heap +# 4 k * lg2(k) - (k/2) final sort of the k most extreme values +# +# Combining and simplifying for a rough estimate gives: +# +# comparisons = n + k * (log(k, 2) * log(n/k) + log(k, 2) + log(n/k)) +# +# Computing the number of comparisons for step 3: +# ----------------------------------------------- +# * For the i-th new value from the iterable, the probability of being in the +# k most extreme values is k/i. For example, the probability of the 101st +# value seen being in the 100 most extreme values is 100/101. +# * If the value is a new extreme value, the cost of inserting it into the +# heap is 1 + log(k, 2). +# * The probability times the cost gives: +# (k/i) * (1 + log(k, 2)) +# * Summing across the remaining n-k elements gives: +# sum((k/i) * (1 + log(k, 2)) for i in range(k+1, n+1)) +# * This reduces to: +# (H(n) - H(k)) * k * (1 + log(k, 2)) +# * Where H(n) is the n-th harmonic number estimated by: +# gamma = 0.5772156649 +# H(n) = log(n, e) + gamma + 1 / (2 * n) +# http://en.wikipedia.org/wiki/Harmonic_series_(mathematics)#Rate_of_divergence +# * Substituting the H(n) formula: +# comparisons = k * (1 + log(k, 2)) * (log(n/k, e) + (1/n - 1/k) / 2) +# +# Worst-case for step 3: +# ---------------------- +# In the worst case, the input data is reversed sorted so that every new element +# must be inserted in the heap: +# +# comparisons = 1.66 * k + log(k, 2) * (n - k) +# +# Alternative Algorithms +# ---------------------- +# Other algorithms were not used because they: +# 1) Took much more auxiliary memory, +# 2) Made multiple passes over the data. +# 3) Made more comparisons in common cases (small k, large n, semi-random input). +# See the more detailed comparison of approach at: +# http://code.activestate.com/recipes/577573-compare-algorithms-for-heapqsmallest + def nsmallest(n, iterable, key=None): """Find the n smallest elements in a dataset. Equivalent to: sorted(iterable, key=key)[:n] """ - # Short-cut for n==1 is to use min() when len(iterable)>0 + + # Short-cut for n==1 is to use min() if n == 1: it = iter(iterable) - head = list(islice(it, 1)) - if not head: - return [] + sentinel = object() if key is None: - return [min(chain(head, it))] - return [min(chain(head, it), key=key)] + result = min(it, default=sentinel) + else: + result = min(it, default=sentinel, key=key) + return [] if result is sentinel else [result] # When n>=size, it's faster to use sorted() try: @@ -413,32 +485,57 @@ def nsmallest(n, iterable, key=None): # When key is none, use simpler decoration if key is None: - it = zip(iterable, count()) # decorate - result = _nsmallest(n, it) - return [r[0] for r in result] # undecorate + it = iter(iterable) + # put the range(n) first so that zip() doesn't + # consume one too many elements from the iterator + result = [(elem, i) for i, elem in zip(range(n), it)] + if not result: + return result + _heapify_max(result) + top = result[0][0] + order = n + _heapreplace = _heapreplace_max + for elem in it: + if elem < top: + _heapreplace(result, (elem, order)) + top = result[0][0] + order += 1 + result.sort() + return [r[0] for r in result] # General case, slowest method - in1, in2 = tee(iterable) - it = zip(map(key, in1), count(), in2) # decorate - result = _nsmallest(n, it) - return [r[2] for r in result] # undecorate + it = iter(iterable) + result = [(key(elem), i, elem) for i, elem in zip(range(n), it)] + if not result: + return result + _heapify_max(result) + top = result[0][0] + order = n + _heapreplace = _heapreplace_max + for elem in it: + k = key(elem) + if k < top: + _heapreplace(result, (k, order, elem)) + top = result[0][0] + order += 1 + result.sort() + return [r[2] for r in result] -_nlargest = nlargest def nlargest(n, iterable, key=None): """Find the n largest elements in a dataset. Equivalent to: sorted(iterable, key=key, reverse=True)[:n] """ - # Short-cut for n==1 is to use max() when len(iterable)>0 + # Short-cut for n==1 is to use max() if n == 1: it = iter(iterable) - head = list(islice(it, 1)) - if not head: - return [] + sentinel = object() if key is None: - return [max(chain(head, it))] - return [max(chain(head, it), key=key)] + result = max(it, default=sentinel) + else: + result = max(it, default=sentinel, key=key) + return [] if result is sentinel else [result] # When n>=size, it's faster to use sorted() try: @@ -451,26 +548,60 @@ def nlargest(n, iterable, key=None): # When key is none, use simpler decoration if key is None: - it = zip(iterable, count(0,-1)) # decorate - result = _nlargest(n, it) - return [r[0] for r in result] # undecorate + it = iter(iterable) + result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)] + if not result: + return result + heapify(result) + top = result[0][0] + order = -n + _heapreplace = heapreplace + for elem in it: + if top < elem: + _heapreplace(result, (elem, order)) + top = result[0][0] + order -= 1 + result.sort(reverse=True) + return [r[0] for r in result] # General case, slowest method - in1, in2 = tee(iterable) - it = zip(map(key, in1), count(0,-1), in2) # decorate - result = _nlargest(n, it) - return [r[2] for r in result] # undecorate + it = iter(iterable) + result = [(key(elem), i, elem) for i, elem in zip(range(0, -n, -1), it)] + if not result: + return result + heapify(result) + top = result[0][0] + order = -n + _heapreplace = heapreplace + for elem in it: + k = key(elem) + if top < k: + _heapreplace(result, (k, order, elem)) + top = result[0][0] + order -= 1 + result.sort(reverse=True) + return [r[2] for r in result] + +# If available, use C implementation +try: + from _heapq import * +except ImportError: + pass +try: + from _heapq import _heapreplace_max +except ImportError: + pass +try: + from _heapq import _heapify_max +except ImportError: + pass +try: + from _heapq import _heappop_max +except ImportError: + pass + if __name__ == "__main__": - # Simple sanity test - heap = [] - data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 0] - for item in data: - heappush(heap, item) - sort = [] - while heap: - sort.append(heappop(heap)) - print(sort) import doctest - doctest.testmod() + print(doctest.testmod()) |