summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/heapq.py156
-rw-r--r--Lib/test/test_heapq.py2
-rw-r--r--Misc/NEWS4
-rw-r--r--Modules/_heapqmodule.c92
4 files changed, 137 insertions, 117 deletions
diff --git a/Lib/heapq.py b/Lib/heapq.py
index fc73df9..14a7a86 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -127,7 +127,7 @@ From all times, sorting has always been a Great Art! :-)
__all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge',
'nlargest', 'nsmallest', 'heappushpop']
-from itertools import islice, count, tee, chain
+from itertools import islice, count
def heappush(heap, item):
"""Push item onto heap, maintaining the heap invariant."""
@@ -179,12 +179,12 @@ def heapify(x):
for i in reversed(range(n//2)):
_siftup(x, i)
-def _heappushpop_max(heap, item):
- """Maxheap version of a heappush followed by a heappop."""
- if heap and item < heap[0]:
- item, heap[0] = heap[0], item
- _siftup_max(heap, 0)
- return item
+def _heapreplace_max(heap, item):
+ """Maxheap version of a heappop followed by a heappush."""
+ returnitem = heap[0] # raises appropriate IndexError if heap is empty
+ heap[0] = item
+ _siftup_max(heap, 0)
+ return returnitem
def _heapify_max(x):
"""Transform list into a maxheap, in-place, in O(len(x)) time."""
@@ -192,24 +192,6 @@ def _heapify_max(x):
for i in reversed(range(n//2)):
_siftup_max(x, i)
-def nsmallest(n, iterable):
- """Find the n smallest elements in a dataset.
-
- Equivalent to: sorted(iterable)[:n]
- """
- if n <= 0:
- return []
- it = iter(iterable)
- result = list(islice(it, n))
- if not result:
- return result
- _heapify_max(result)
- _heappushpop = _heappushpop_max
- for elem in it:
- _heappushpop(result, elem)
- result.sort()
- return result
-
# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos
# is the index of a leaf with a possibly out-of-order value. Restore the
# heap invariant.
@@ -327,6 +309,10 @@ try:
from _heapq import *
except ImportError:
pass
+try:
+ from _heapq import _heapreplace_max
+except ImportError:
+ pass
def merge(*iterables):
'''Merge multiple sorted inputs into a single sorted output.
@@ -367,22 +353,86 @@ def merge(*iterables):
yield v
yield from next.__self__
-# Extend the implementations of nsmallest and nlargest to use a key= argument
-_nsmallest = nsmallest
+
+# Algorithm notes for nlargest() and nsmallest()
+# ==============================================
+#
+# Makes just a single pass over the data while keeping the k most extreme values
+# in a heap. Memory consumption is limited to keeping k values in a list.
+#
+# Measured performance for random inputs:
+#
+# number of comparisons
+# n inputs k-extreme values (average of 5 trials) % more than min()
+# ------------- ---------------- - ------------------- -----------------
+# 1,000 100 3,317 133.2%
+# 10,000 100 14,046 40.5%
+# 100,000 100 105,749 5.7%
+# 1,000,000 100 1,007,751 0.8%
+# 10,000,000 100 10,009,401 0.1%
+#
+# Theoretical number of comparisons for k smallest of n random inputs:
+#
+# Step Comparisons Action
+# ---- -------------------------- ---------------------------
+# 1 1.66 * k heapify the first k-inputs
+# 2 n - k compare remaining elements to top of heap
+# 3 k * (1 + lg2(k)) * ln(n/k) replace the topmost value on the heap
+# 4 k * lg2(k) - (k/2) final sort of the k most extreme values
+# Combining and simplifying for a rough estimate gives:
+# comparisons = n + k * (1 + log(n/k)) * (1 + log(k, 2))
+#
+# Computing the number of comparisons for step 3:
+# -----------------------------------------------
+# * For the i-th new value from the iterable, the probability of being in the
+# k most extreme values is k/i. For example, the probability of the 101st
+# value seen being in the 100 most extreme values is 100/101.
+# * If the value is a new extreme value, the cost of inserting it into the
+# heap is 1 + log(k, 2).
+# * The probabilty times the cost gives:
+# (k/i) * (1 + log(k, 2))
+# * Summing across the remaining n-k elements gives:
+# sum((k/i) * (1 + log(k, 2)) for xrange(k+1, n+1))
+# * This reduces to:
+# (H(n) - H(k)) * k * (1 + log(k, 2))
+# * Where H(n) is the n-th harmonic number estimated by:
+# gamma = 0.5772156649
+# H(n) = log(n, e) + gamma + 1.0 / (2.0 * n)
+# http://en.wikipedia.org/wiki/Harmonic_series_(mathematics)#Rate_of_divergence
+# * Substituting the H(n) formula:
+# comparisons = k * (1 + log(k, 2)) * (log(n/k, e) + (1/n - 1/k) / 2)
+#
+# Worst-case for step 3:
+# ----------------------
+# In the worst case, the input data is reversed sorted so that every new element
+# must be inserted in the heap:
+#
+# comparisons = 1.66 * k + log(k, 2) * (n - k)
+#
+# Alternative Algorithms
+# ----------------------
+# Other algorithms were not used because they:
+# 1) Took much more auxiliary memory,
+# 2) Made multiple passes over the data.
+# 3) Made more comparisons in common cases (small k, large n, semi-random input).
+# See the more detailed comparison of approach at:
+# http://code.activestate.com/recipes/577573-compare-algorithms-for-heapqsmallest
+
def nsmallest(n, iterable, key=None):
"""Find the n smallest elements in a dataset.
Equivalent to: sorted(iterable, key=key)[:n]
"""
+
# Short-cut for n==1 is to use min() when len(iterable)>0
if n == 1:
it = iter(iterable)
- head = list(islice(it, 1))
- if not head:
- return []
+ sentinel = object()
if key is None:
- return [min(chain(head, it))]
- return [min(chain(head, it), key=key)]
+ result = min(it, default=sentinel)
+ else:
+ result = min(it, default=sentinel, key=key)
+ return [] if result is sentinel else [result]
# When n>=size, it's faster to use sorted()
try:
@@ -395,15 +445,39 @@ def nsmallest(n, iterable, key=None):
# When key is none, use simpler decoration
if key is None:
- it = zip(iterable, count()) # decorate
- result = _nsmallest(n, it)
- return [r[0] for r in result] # undecorate
+ it = iter(iterable)
+ result = list(islice(zip(it, count()), n))
+ if not result:
+ return result
+ _heapify_max(result)
+ order = n
+ top = result[0][0]
+ _heapreplace = _heapreplace_max
+ for elem in it:
+ if elem < top:
+ _heapreplace(result, (elem, order))
+ top = result[0][0]
+ order += 1
+ result.sort()
+ return [r[0] for r in result]
# General case, slowest method
- in1, in2 = tee(iterable)
- it = zip(map(key, in1), count(), in2) # decorate
- result = _nsmallest(n, it)
- return [r[2] for r in result] # undecorate
+ it = iter(iterable)
+ result = [(key(elem), i, elem) for i, elem in zip(range(n), it)]
+ if not result:
+ return result
+ _heapify_max(result)
+ order = n
+ top = result[0][0]
+ _heapreplace = _heapreplace_max
+ for elem in it:
+ k = key(elem)
+ if k < top:
+ _heapreplace(result, (k, order, elem))
+ top = result[0][0]
+ order += 1
+ result.sort()
+ return [r[2] for r in result]
def nlargest(n, iterable, key=None):
"""Find the n largest elements in a dataset.
@@ -442,9 +516,9 @@ def nlargest(n, iterable, key=None):
_heapreplace = heapreplace
for elem in it:
if top < elem:
- order -= 1
_heapreplace(result, (elem, order))
top = result[0][0]
+ order -= 1
result.sort(reverse=True)
return [r[0] for r in result]
@@ -460,9 +534,9 @@ def nlargest(n, iterable, key=None):
for elem in it:
k = key(elem)
if top < k:
- order -= 1
_heapreplace(result, (k, order, elem))
top = result[0][0]
+ order -= 1
result.sort(reverse=True)
return [r[2] for r in result]
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 1735a19..59c7029 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -13,7 +13,7 @@ c_heapq = support.import_fresh_module('heapq', fresh=['_heapq'])
# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
# _heapq is imported, so check them there
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
- 'heapreplace', '_nsmallest']
+ 'heapreplace', '_heapreplace_max']
class TestModules(TestCase):
def test_py_functions(self):
diff --git a/Misc/NEWS b/Misc/NEWS
index b80c066..43e5b60 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -84,8 +84,8 @@ Library
- Issue #21156: importlib.abc.InspectLoader.source_to_code() is now a
staticmethod.
-- Issue #21424: Simplified and optimized heaqp.nlargest() to make fewer
- tuple comparisons.
+- Issue #21424: Simplified and optimized heaqp.nlargest() and nmsmallest()
+ to make fewer tuple comparisons.
- Issue #21396: Fix TextIOWrapper(..., write_through=True) to not force a
flush() on the underlying binary stream. Patch by akira.
diff --git a/Modules/_heapqmodule.c b/Modules/_heapqmodule.c
index 964f511..ad190df 100644
--- a/Modules/_heapqmodule.c
+++ b/Modules/_heapqmodule.c
@@ -354,88 +354,34 @@ _siftupmax(PyListObject *heap, Py_ssize_t pos)
}
static PyObject *
-nsmallest(PyObject *self, PyObject *args)
+_heapreplace_max(PyObject *self, PyObject *args)
{
- PyObject *heap=NULL, *elem, *iterable, *los, *it, *oldelem;
- Py_ssize_t i, n;
- int cmp;
+ PyObject *heap, *item, *returnitem;
- if (!PyArg_ParseTuple(args, "nO:nsmallest", &n, &iterable))
+ if (!PyArg_UnpackTuple(args, "_heapreplace_max", 2, 2, &heap, &item))
return NULL;
- it = PyObject_GetIter(iterable);
- if (it == NULL)
+ if (!PyList_Check(heap)) {
+ PyErr_SetString(PyExc_TypeError, "heap argument must be a list");
return NULL;
-
- heap = PyList_New(0);
- if (heap == NULL)
- goto fail;
-
- for (i=0 ; i<n ; i++ ){
- elem = PyIter_Next(it);
- if (elem == NULL) {
- if (PyErr_Occurred())
- goto fail;
- else
- goto sortit;
- }
- if (PyList_Append(heap, elem) == -1) {
- Py_DECREF(elem);
- goto fail;
- }
- Py_DECREF(elem);
}
- n = PyList_GET_SIZE(heap);
- if (n == 0)
- goto sortit;
-
- for (i=n/2-1 ; i>=0 ; i--)
- if(_siftupmax((PyListObject *)heap, i) == -1)
- goto fail;
-
- los = PyList_GET_ITEM(heap, 0);
- while (1) {
- elem = PyIter_Next(it);
- if (elem == NULL) {
- if (PyErr_Occurred())
- goto fail;
- else
- goto sortit;
- }
- cmp = PyObject_RichCompareBool(elem, los, Py_LT);
- if (cmp == -1) {
- Py_DECREF(elem);
- goto fail;
- }
- if (cmp == 0) {
- Py_DECREF(elem);
- continue;
- }
- oldelem = PyList_GET_ITEM(heap, 0);
- PyList_SET_ITEM(heap, 0, elem);
- Py_DECREF(oldelem);
- if (_siftupmax((PyListObject *)heap, 0) == -1)
- goto fail;
- los = PyList_GET_ITEM(heap, 0);
+ if (PyList_GET_SIZE(heap) < 1) {
+ PyErr_SetString(PyExc_IndexError, "index out of range");
+ return NULL;
}
-sortit:
- if (PyList_Sort(heap) == -1)
- goto fail;
- Py_DECREF(it);
- return heap;
-
-fail:
- Py_DECREF(it);
- Py_XDECREF(heap);
- return NULL;
+ returnitem = PyList_GET_ITEM(heap, 0);
+ Py_INCREF(item);
+ PyList_SET_ITEM(heap, 0, item);
+ if (_siftupmax((PyListObject *)heap, 0) == -1) {
+ Py_DECREF(returnitem);
+ return NULL;
+ }
+ return returnitem;
}
-PyDoc_STRVAR(nsmallest_doc,
-"Find the n smallest elements in a dataset.\n\
-\n\
-Equivalent to: sorted(iterable)[:n]\n");
+PyDoc_STRVAR(heapreplace_max_doc, "Maxheap variant of heapreplace");
static PyMethodDef heapq_methods[] = {
{"heappush", (PyCFunction)heappush,
@@ -448,8 +394,8 @@ static PyMethodDef heapq_methods[] = {
METH_VARARGS, heapreplace_doc},
{"heapify", (PyCFunction)heapify,
METH_O, heapify_doc},
- {"nsmallest", (PyCFunction)nsmallest,
- METH_VARARGS, nsmallest_doc},
+ {"_heapreplace_max",(PyCFunction)_heapreplace_max,
+ METH_VARARGS, heapreplace_max_doc},
{NULL, NULL} /* sentinel */
};